From 05cac355846f4c17e259d1e57fd69f20d0b36cd6 Mon Sep 17 00:00:00 2001 From: Allen Cheng Date: Thu, 11 Jun 2026 18:47:55 -0700 Subject: [PATCH] feat: add hybrid retrieval API --- README.md | 10 + crates/lance-context-api/src/lib.rs | 50 ++++ crates/lance-context-client/src/lib.rs | 23 ++ crates/lance-context-core/src/api_impl.rs | 51 +++- crates/lance-context-core/src/lib.rs | 2 +- crates/lance-context-core/src/record.rs | 116 ++++++-- crates/lance-context-core/src/store.rs | 268 +++++++++++++++++- crates/lance-context-server/src/routes/mod.rs | 1 + .../lance-context-server/src/routes/search.rs | 62 +++- crates/lance-context/src/lib.rs | 6 +- crates/lance-context/src/unified.rs | 7 +- python/python/lance_context/api.py | 67 +++++ python/src/lib.rs | 150 +++++----- python/tests/test_async.py | 19 ++ python/tests/test_persistence.py | 60 ++++ python/tests/test_search.py | 154 +++++++++- 16 files changed, 927 insertions(+), 119 deletions(-) diff --git a/README.md b/README.md index a594d45..663675f 100644 --- a/README.md +++ b/README.md @@ -137,6 +137,16 @@ hits = ctx.search( ) service_context = ctx.related("service://service-a", relation="describes") +# Hybrid retrieval combines lexical recall, vector recall, and existing filters +# over the same context records. +hybrid_hits = ctx.retrieve( + text="service-a runbook owner", + vector=runbook_embedding, + limit=5, + filters={"tenant": "example-org", "scope": "team"}, +) +print(hybrid_hits[0]["matched_channels"], hybrid_hits[0]["score"]) + from PIL import Image image = Image.new("RGB", (2, 2), color="teal") ctx.add("assistant", image) diff --git a/crates/lance-context-api/src/lib.rs b/crates/lance-context-api/src/lib.rs index fc69e6b..89ae788 100644 --- a/crates/lance-context-api/src/lib.rs +++ b/crates/lance-context-api/src/lib.rs @@ -49,6 +49,11 @@ pub trait ContextStoreApi { include_relationships: bool, ) -> impl Future>> + Send; + fn retrieve( + &self, + request: &RetrieveRequest, + ) -> impl Future>> + Send; + fn version(&self) -> u64; fn checkout(&mut self, version: u64) -> impl Future> + Send; @@ -244,6 +249,47 @@ pub struct SearchResponse { pub results: Vec, } +// --------------------------------------------------------------------------- +// Hybrid retrieval +// --------------------------------------------------------------------------- + +#[derive(Debug, Serialize, Deserialize)] +pub struct RetrieveRequest { + #[serde(default, skip_serializing_if = "Option::is_none")] + pub text: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub vector: Option>, + #[serde(default = "default_search_limit")] + pub limit: usize, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub filters: Option, + #[serde(default)] + pub include_expired: bool, + #[serde(default)] + pub include_retired: bool, + #[serde(default)] + pub include_relationships: bool, + #[serde(default = "default_retrieve_fusion")] + pub fusion: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct RetrieveResultDto { + pub record: RecordDto, + pub score: f32, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub vector_distance: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub text_score: Option, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub matched_channels: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct RetrieveResponse { + pub results: Vec, +} + // --------------------------------------------------------------------------- // Versioning // --------------------------------------------------------------------------- @@ -320,6 +366,10 @@ fn default_search_limit() -> usize { 10 } +fn default_retrieve_fusion() -> String { + "rrf".to_string() +} + fn serialize_base64_opt(data: &Option>, serializer: S) -> Result where S: serde::Serializer, diff --git a/crates/lance-context-client/src/lib.rs b/crates/lance-context-client/src/lib.rs index 8f7409d..ce8ab59 100644 --- a/crates/lance-context-client/src/lib.rs +++ b/crates/lance-context-client/src/lib.rs @@ -99,6 +99,15 @@ impl ContextStoreApi for RemoteContextStore { Ok(resp.results) } + async fn retrieve(&self, request: &RetrieveRequest) -> ContextResult> { + let resp = self + .client + .retrieve(&self.context_name, request) + .await + .map_err(to_ctx_err)?; + Ok(resp.results) + } + fn version(&self) -> u64 { self.cached_version } @@ -273,6 +282,20 @@ impl ContextClient { Self::handle_response(resp).await } + pub async fn retrieve( + &self, + name: &str, + req: &RetrieveRequest, + ) -> Result { + let resp = self + .http + .post(self.url(&format!("/contexts/{}/retrieve", name))) + .json(req) + .send() + .await?; + Self::handle_response(resp).await + } + pub async fn get_version(&self, name: &str) -> Result { let resp = self .http diff --git a/crates/lance-context-core/src/api_impl.rs b/crates/lance-context-core/src/api_impl.rs index 600b3aa..2875494 100644 --- a/crates/lance-context-core/src/api_impl.rs +++ b/crates/lance-context-core/src/api_impl.rs @@ -3,11 +3,14 @@ use uuid::Uuid; use lance_context_api::{ AddRecordRequest, AddRecordsResponse, CompactRequest, CompactResponse, CompactStatsResponse, - ContextError, ContextResult, ContextStoreApi, RecordDto, RelationshipDto, SearchResultDto, - StateMetadataDto, + ContextError, ContextResult, ContextStoreApi, RecordDto, RelationshipDto, RetrieveRequest, + RetrieveResultDto, SearchResultDto, StateMetadataDto, }; -use crate::record::{ContextRecord, Relationship, StateMetadata, LIFECYCLE_ACTIVE}; +use crate::record::{ + ContextRecord, LifecycleQueryOptions, RecordFilters, Relationship, StateMetadata, + LIFECYCLE_ACTIVE, +}; use crate::store::{CompactionConfig, ContextStore}; impl ContextStoreApi for ContextStore { @@ -102,6 +105,48 @@ impl ContextStoreApi for ContextStore { .collect()) } + async fn retrieve(&self, request: &RetrieveRequest) -> ContextResult> { + if request.fusion != "rrf" { + return Err(ContextError::InvalidRequest( + "retrieve fusion currently supports only 'rrf'".to_string(), + )); + } + + let filters = request + .filters + .clone() + .map(RecordFilters::from_json_value) + .transpose() + .map_err(ContextError::InvalidRequest)?; + let options = LifecycleQueryOptions::new(request.include_expired, request.include_retired); + let results = self + .retrieve_filtered_with_options( + request.text.as_deref(), + request.vector.as_deref(), + Some(request.limit), + filters.as_ref(), + options, + ) + .await + .map_err(to_ctx_err)?; + + Ok(results + .into_iter() + .map(|mut result| { + if !request.include_relationships { + result.record.relationships.clear(); + } + RetrieveResultDto { + record: record_to_dto(result.record), + score: result.score, + vector_distance: result.vector_distance, + text_score: result.text_score, + matched_channels: result.matched_channels, + } + }) + .collect()) + } + fn version(&self) -> u64 { ContextStore::version(self) } diff --git a/crates/lance-context-core/src/lib.rs b/crates/lance-context-core/src/lib.rs index a3f6dbf..a20bf1b 100644 --- a/crates/lance-context-core/src/lib.rs +++ b/crates/lance-context-core/src/lib.rs @@ -10,7 +10,7 @@ mod store; pub use context::{Context, ContextEntry, Snapshot}; pub use record::{ ContextRecord, LifecycleQueryOptions, MetadataFilter, RecordFilters, Relationship, - SearchResult, StateMetadata, LIFECYCLE_ACTIVE, LIFECYCLE_CONTRADICTED, + RetrieveResult, SearchResult, StateMetadata, LIFECYCLE_ACTIVE, LIFECYCLE_CONTRADICTED, }; pub use store::{ CompactionConfig, CompactionStats, ContextStore, ContextStoreOptions, IdIndexType, diff --git a/crates/lance-context-core/src/record.rs b/crates/lance-context-core/src/record.rs index 1be7717..9eec2bc 100644 --- a/crates/lance-context-core/src/record.rs +++ b/crates/lance-context-core/src/record.rs @@ -129,6 +129,16 @@ pub struct SearchResult { pub distance: f32, } +/// Result returned from hybrid retrieval over context records. +#[derive(Debug, Clone)] +pub struct RetrieveResult { + pub record: ContextRecord, + pub score: f32, + pub vector_distance: Option, + pub text_score: Option, + pub matched_channels: Vec, +} + /// Metadata matching operation for filtered retrieval. #[derive(Debug, Clone, PartialEq)] pub enum MetadataFilter { @@ -149,6 +159,42 @@ pub struct RecordFilters { } impl RecordFilters { + pub fn from_json_value(value: Value) -> Result { + let Value::Object(object) = value else { + return Err("filters must be a JSON object".to_string()); + }; + + let mut filters = RecordFilters::default(); + for (key, value) in object { + match key.as_str() { + "bot_id" => filters.bot_id = filter_string(key.as_str(), value)?, + "session_id" => filters.session_id = filter_string(key.as_str(), value)?, + "role" => filters.role = filter_string(key.as_str(), value)?, + "content_type" => filters.content_type = filter_string(key.as_str(), value)?, + "created_at" => apply_created_at_filter(&mut filters, value)?, + "created_at_start" | "created_after" | "created_at_gte" => { + filters.created_at_start = Some(parse_filter_datetime(&key, &value)?); + } + "created_at_end" | "created_before" | "created_at_lte" => { + filters.created_at_end = Some(parse_filter_datetime(&key, &value)?); + } + _ => { + let filter = match value { + Value::Object(mut object) + if object.len() == 1 && object.contains_key("contains") => + { + MetadataFilter::Contains(object.remove("contains").unwrap()) + } + value => MetadataFilter::Equals(value), + }; + filters.metadata.insert(key, filter); + } + } + } + + Ok(filters) + } + #[must_use] pub fn is_empty(&self) -> bool { self.bot_id.is_none() @@ -218,6 +264,47 @@ impl RecordFilters { } } +fn filter_string(name: &str, value: Value) -> Result, String> { + match value { + Value::Null => Ok(None), + Value::String(value) => Ok(Some(value)), + _ => Err(format!("filter '{name}' must be a string or null")), + } +} + +fn apply_created_at_filter(filters: &mut RecordFilters, value: Value) -> Result<(), String> { + let Value::Object(object) = value else { + return Err("filter 'created_at' must be an object with gte/lte bounds".to_string()); + }; + + for (key, value) in object { + match key.as_str() { + "gte" | "start" | "after" => { + filters.created_at_start = Some(parse_filter_datetime(&key, &value)?); + } + "lte" | "end" | "before" => { + filters.created_at_end = Some(parse_filter_datetime(&key, &value)?); + } + other => { + return Err(format!("unsupported created_at filter operator '{other}'")); + } + } + } + + Ok(()) +} + +fn parse_filter_datetime(name: &str, value: &Value) -> Result, String> { + let Some(value) = value.as_str() else { + return Err(format!( + "filter '{name}' must be an ISO-8601 timestamp string" + )); + }; + DateTime::parse_from_rfc3339(value) + .map(|value| value.with_timezone(&Utc)) + .map_err(|err| err.to_string()) +} + fn metadata_contains(value: &Value, expected: &Value) -> bool { match (value, expected) { (Value::Array(items), expected) => items.iter().any(|item| item == expected), @@ -264,22 +351,19 @@ mod tests { #[test] fn filters_match_builtin_fields_timestamps_and_metadata() { - let mut filters = RecordFilters { - bot_id: Some("support-bot".to_string()), - session_id: Some("incident-1".to_string()), - role: Some("assistant".to_string()), - content_type: Some("text/plain".to_string()), - created_at_start: Some(Utc.with_ymd_and_hms(2026, 6, 9, 2, 0, 0).unwrap()), - created_at_end: Some(Utc.with_ymd_and_hms(2026, 6, 9, 4, 0, 0).unwrap()), - metadata: HashMap::new(), - }; - filters - .metadata - .insert("scope".to_string(), MetadataFilter::Equals(json!("team"))); - filters.metadata.insert( - "tags".to_string(), - MetadataFilter::Contains(json!("runbook")), - ); + let mut filters = RecordFilters::from_json_value(json!({ + "bot_id": "support-bot", + "session_id": "incident-1", + "role": "assistant", + "content_type": "text/plain", + "created_at": { + "gte": "2026-06-09T02:00:00Z", + "lte": "2026-06-09T04:00:00Z" + }, + "scope": "team", + "tags": {"contains": "runbook"} + })) + .unwrap(); assert!(filters.matches(&record())); diff --git a/crates/lance-context-core/src/store.rs b/crates/lance-context-core/src/store.rs index 734afab..527d8ba 100644 --- a/crates/lance-context-core/src/store.rs +++ b/crates/lance-context-core/src/store.rs @@ -1,3 +1,4 @@ +use std::cmp::Ordering; use std::collections::{HashMap, HashSet}; use std::sync::Arc; use std::time::Duration; @@ -34,8 +35,8 @@ use tracing::{error, info, warn}; use uuid::Uuid; use crate::record::{ - ContextRecord, LifecycleQueryOptions, RecordFilters, Relationship, SearchResult, StateMetadata, - LIFECYCLE_ACTIVE, + ContextRecord, LifecycleQueryOptions, RecordFilters, Relationship, RetrieveResult, + SearchResult, StateMetadata, LIFECYCLE_ACTIVE, }; use crate::serde::CONTENT_TYPE_TOMBSTONE; @@ -43,6 +44,7 @@ use crate::serde::CONTENT_TYPE_TOMBSTONE; const DEFAULT_EMBEDDING_DIM: i32 = 1536; const DEFAULT_SEARCH_LIMIT: usize = 10; const DEFAULT_MANIFEST_SCAN_BATCH_SIZE: usize = 16; +const RRF_K: f32 = 60.0; const ID_INDEX_NAME: &str = "id_idx"; const RELATIONSHIPS_COLUMN: &str = "relationships"; @@ -646,14 +648,7 @@ impl ContextStore { filters: Option<&RecordFilters>, options: LifecycleQueryOptions, ) -> LanceResult> { - if query.len() != DEFAULT_EMBEDDING_DIM as usize { - return Err(ArrowError::InvalidArgumentError(format!( - "query length {} does not match embedding dimension {}", - query.len(), - DEFAULT_EMBEDDING_DIM - )) - .into()); - } + validate_query_dimension(query)?; let top_k = limit.unwrap_or(DEFAULT_SEARCH_LIMIT); if top_k == 0 { @@ -674,6 +669,100 @@ impl ContextStore { Ok(results) } + /// Retrieve records using optional text and vector channels, after filters and lifecycle visibility. + pub async fn retrieve_filtered_with_options( + &self, + text: Option<&str>, + vector: Option<&[f32]>, + limit: Option, + filters: Option<&RecordFilters>, + options: LifecycleQueryOptions, + ) -> LanceResult> { + let text_terms = text.map(unique_query_terms).unwrap_or_default(); + let has_text = !text_terms.is_empty(); + + if !has_text && vector.is_none() { + return Err(ArrowError::InvalidArgumentError( + "retrieve requires text or vector".to_string(), + ) + .into()); + } + + if let Some(query) = vector { + validate_query_dimension(query)?; + } + + let top_k = limit.unwrap_or(DEFAULT_SEARCH_LIMIT); + if top_k == 0 { + return Ok(Vec::new()); + } + + let records = self + .list_filtered_with_options(None, None, filters, options) + .await?; + let mut candidates: HashMap = HashMap::new(); + + if let Some(query) = vector { + let mut vector_hits: Vec<(usize, f32)> = records + .iter() + .enumerate() + .filter_map(|(index, record)| { + let distance = l2_distance(query, record.embedding.as_ref()?); + Some((index, distance)) + }) + .collect(); + vector_hits.sort_by(|left, right| { + left.1 + .total_cmp(&right.1) + .then_with(|| records[left.0].id.cmp(&records[right.0].id)) + }); + + for (rank, (index, distance)) in vector_hits.into_iter().enumerate() { + add_retrieve_channel( + &mut candidates, + &records[index], + rank + 1, + "vector", + Some(distance), + None, + ); + } + } + + if has_text { + let mut text_hits: Vec<(usize, f32)> = records + .iter() + .enumerate() + .filter_map(|(index, record)| { + lexical_score(&text_terms, record.text_payload.as_deref()) + .map(|score| (index, score)) + }) + .collect(); + text_hits.sort_by(|left, right| { + right + .1 + .total_cmp(&left.1) + .then_with(|| records[left.0].id.cmp(&records[right.0].id)) + }); + + for (rank, (index, score)) in text_hits.into_iter().enumerate() { + add_retrieve_channel( + &mut candidates, + &records[index], + rank + 1, + "text", + None, + Some(score), + ); + } + } + + let mut results: Vec = candidates.into_values().collect(); + results.sort_by(compare_retrieve_results); + results.truncate(top_k); + Ok(results) + } + async fn lsm_scanner(&self) -> LanceResult { let object_store = self.dataset.object_store(None).await?; let branch_location = self.dataset.branch_location(); @@ -1755,6 +1844,127 @@ fn l2_distance(left: &[f32], right: &[f32]) -> f32 { .sqrt() } +fn validate_query_dimension(query: &[f32]) -> LanceResult<()> { + if query.len() != DEFAULT_EMBEDDING_DIM as usize { + return Err(ArrowError::InvalidArgumentError(format!( + "query length {} does not match embedding dimension {}", + query.len(), + DEFAULT_EMBEDDING_DIM + )) + .into()); + } + Ok(()) +} + +fn unique_query_terms(text: &str) -> Vec { + let mut seen = HashSet::new(); + tokenize_for_retrieval(text) + .into_iter() + .filter(|term| seen.insert(term.clone())) + .collect() +} + +fn tokenize_for_retrieval(text: &str) -> Vec { + let mut terms = Vec::new(); + let mut current = String::new(); + + for character in text.chars() { + if character.is_alphanumeric() { + current.extend(character.to_lowercase()); + } else if !current.is_empty() { + terms.push(std::mem::take(&mut current)); + } + } + + if !current.is_empty() { + terms.push(current); + } + + terms +} + +fn lexical_score(query_terms: &[String], text: Option<&str>) -> Option { + let text = text?; + if query_terms.is_empty() { + return None; + } + + let payload_terms: HashSet = tokenize_for_retrieval(text).into_iter().collect(); + if payload_terms.is_empty() { + return None; + } + + let matched_terms = query_terms + .iter() + .filter(|term| payload_terms.contains(*term)) + .count(); + if matched_terms == 0 { + return None; + } + + Some(matched_terms as f32 / query_terms.len() as f32) +} + +fn add_retrieve_channel( + candidates: &mut HashMap, + record: &ContextRecord, + rank: usize, + channel: &str, + vector_distance: Option, + text_score: Option, +) { + let candidate = candidates + .entry(record.id.clone()) + .or_insert_with(|| RetrieveResult { + record: record.clone(), + score: 0.0, + vector_distance: None, + text_score: None, + matched_channels: Vec::new(), + }); + candidate.score += 1.0 / (RRF_K + rank as f32); + if let Some(distance) = vector_distance { + candidate.vector_distance = Some(distance); + } + if let Some(score) = text_score { + candidate.text_score = Some(score); + } + if !candidate + .matched_channels + .iter() + .any(|existing| existing == channel) + { + candidate.matched_channels.push(channel.to_string()); + } +} + +fn compare_retrieve_results(left: &RetrieveResult, right: &RetrieveResult) -> Ordering { + right + .score + .total_cmp(&left.score) + .then_with(|| compare_optional_distance(left.vector_distance, right.vector_distance)) + .then_with(|| compare_optional_score(left.text_score, right.text_score)) + .then_with(|| left.record.id.cmp(&right.record.id)) +} + +fn compare_optional_distance(left: Option, right: Option) -> Ordering { + match (left, right) { + (Some(left), Some(right)) => left.total_cmp(&right), + (Some(_), None) => Ordering::Less, + (None, Some(_)) => Ordering::Greater, + (None, None) => Ordering::Equal, + } +} + +fn compare_optional_score(left: Option, right: Option) -> Ordering { + match (left, right) { + (Some(left), Some(right)) => right.total_cmp(&left), + (Some(_), None) => Ordering::Less, + (None, Some(_)) => Ordering::Greater, + (None, None) => Ordering::Equal, + } +} + fn column_as<'a, A>(batch: &'a RecordBatch, name: &str) -> LanceResult<&'a A> where A: Array + 'static, @@ -1866,6 +2076,44 @@ mod tests { }); } + #[test] + fn retrieve_fuses_text_and_vector_channels() { + let dir = TempDir::new().unwrap(); + let uri = dir.path().to_string_lossy().to_string(); + let runtime = tokio::runtime::Runtime::new().unwrap(); + runtime.block_on(async { + let mut store = ContextStore::open(&uri).await.unwrap(); + let mut semantic_near = text_record("semantic-near", 0.0); + semantic_near.text_payload = Some("general rollout risk guidance".to_string()); + let mut exact_policy = text_record("exact-policy", 1.0); + exact_policy.text_payload = Some("POLICY-123 blocks service-a rollouts".to_string()); + + store + .add(&[semantic_near.clone(), exact_policy.clone()]) + .await + .unwrap(); + + let query = make_embedding(0.0); + let results = store + .retrieve_filtered_with_options( + Some("POLICY-123 service-a"), + Some(&query), + Some(2), + None, + LifecycleQueryOptions::default(), + ) + .await + .unwrap(); + + assert_eq!(results.len(), 2); + assert_eq!(results[0].record.id, exact_policy.id); + assert!(results[0].score > results[1].score); + assert!(results[0].vector_distance.is_some()); + assert_eq!(results[0].text_score, Some(1.0)); + assert_eq!(results[0].matched_channels, ["vector", "text"]); + }); + } + #[test] fn list_hides_expired_and_retired_records_by_default() { let dir = TempDir::new().unwrap(); diff --git a/crates/lance-context-server/src/routes/mod.rs b/crates/lance-context-server/src/routes/mod.rs index 2337ac8..8e353fe 100644 --- a/crates/lance-context-server/src/routes/mod.rs +++ b/crates/lance-context-server/src/routes/mod.rs @@ -32,6 +32,7 @@ pub fn router() -> Router> { get(records::get_record), ) .route("/api/v1/contexts/{name}/search", post(search::search)) + .route("/api/v1/contexts/{name}/retrieve", post(search::retrieve)) .route( "/api/v1/contexts/{name}/version", get(versions::get_version), diff --git a/crates/lance-context-server/src/routes/search.rs b/crates/lance-context-server/src/routes/search.rs index feb8c04..6cae04f 100644 --- a/crates/lance-context-server/src/routes/search.rs +++ b/crates/lance-context-server/src/routes/search.rs @@ -2,7 +2,11 @@ use std::sync::Arc; use axum::extract::{Path, State}; use axum::Json; -use lance_context_api::{SearchRequest, SearchResponse, SearchResultDto}; +use lance_context_api::{ + RetrieveRequest, RetrieveResponse, RetrieveResultDto, SearchRequest, SearchResponse, + SearchResultDto, +}; +use lance_context_core::{LifecycleQueryOptions, RecordFilters}; use crate::error::AppError; use crate::routes::records::record_to_dto; @@ -41,3 +45,59 @@ pub async fn search( Ok(Json(SearchResponse { results: dtos })) } + +pub async fn retrieve( + State(state): State>, + Path(name): Path, + Json(req): Json, +) -> Result, AppError> { + if req.fusion != "rrf" { + return Err(AppError::InvalidRequest( + "retrieve fusion currently supports only 'rrf'".to_string(), + )); + } + + let filters = req + .filters + .clone() + .map(RecordFilters::from_json_value) + .transpose() + .map_err(AppError::InvalidRequest)?; + + let stores = state.stores.read().await; + let store_lock = stores + .get(&name) + .ok_or_else(|| AppError::NotFound(format!("Context '{}' does not exist", name)))? + .clone(); + drop(stores); + + let store = store_lock.read().await; + let results = store + .retrieve_filtered_with_options( + req.text.as_deref(), + req.vector.as_deref(), + Some(req.limit), + filters.as_ref(), + LifecycleQueryOptions::new(req.include_expired, req.include_retired), + ) + .await + .map_err(AppError::from_lance)?; + + let dtos: Vec = results + .into_iter() + .map(|mut result| { + if !req.include_relationships { + result.record.relationships.clear(); + } + RetrieveResultDto { + record: record_to_dto(result.record), + score: result.score, + vector_distance: result.vector_distance, + text_score: result.text_score, + matched_channels: result.matched_channels, + } + }) + .collect(); + + Ok(Json(RetrieveResponse { results: dtos })) +} diff --git a/crates/lance-context/src/lib.rs b/crates/lance-context/src/lib.rs index 8d757e8..44f9ed4 100644 --- a/crates/lance-context/src/lib.rs +++ b/crates/lance-context/src/lib.rs @@ -5,12 +5,14 @@ pub use lance_context_core::serde; pub use lance_context_core::{ CompactionConfig, CompactionMetrics, CompactionStats, Context, ContextEntry, ContextRecord, ContextStoreOptions, IdIndexType, LifecycleQueryOptions, MetadataFilter, RecordFilters, - Relationship, SearchResult, Snapshot, StateMetadata, LIFECYCLE_ACTIVE, LIFECYCLE_CONTRADICTED, + Relationship, RetrieveResult, SearchResult, Snapshot, StateMetadata, LIFECYCLE_ACTIVE, + LIFECYCLE_CONTRADICTED, }; pub use lance_context_api::{ AddRecordRequest, AddRecordsResponse, CompactRequest, CompactResponse, CompactStatsResponse, - ContextError, ContextResult, ContextStoreApi, RecordDto, RelationshipDto, SearchResultDto, + ContextError, ContextResult, ContextStoreApi, RecordDto, RelationshipDto, RetrieveRequest, + RetrieveResponse, RetrieveResultDto, SearchResultDto, }; #[cfg(feature = "remote")] diff --git a/crates/lance-context/src/unified.rs b/crates/lance-context/src/unified.rs index c664c89..95269bf 100644 --- a/crates/lance-context/src/unified.rs +++ b/crates/lance-context/src/unified.rs @@ -2,7 +2,8 @@ use std::collections::HashSet; use lance_context_api::{ AddRecordRequest, AddRecordsResponse, CompactRequest, CompactResponse, CompactStatsResponse, - ContextError, ContextResult, ContextStoreApi, RecordDto, SearchResultDto, + ContextError, ContextResult, ContextStoreApi, RecordDto, RetrieveRequest, RetrieveResultDto, + SearchResultDto, }; use lance_context_core::{ContextStore as LocalStore, ContextStoreOptions, IdIndexType}; @@ -130,6 +131,10 @@ impl ContextStoreApi for ContextStore { dispatch_ref!(self, search, query, limit, include_relationships) } + async fn retrieve(&self, request: &RetrieveRequest) -> ContextResult> { + dispatch_ref!(self, retrieve, request) + } + fn version(&self) -> u64 { dispatch_sync!(self, version) } diff --git a/python/python/lance_context/api.py b/python/python/lance_context/api.py index 1a6ff0f..0995eef 100644 --- a/python/python/lance_context/api.py +++ b/python/python/lance_context/api.py @@ -163,6 +163,16 @@ def _normalize_search_hit(raw: dict[str, Any]) -> dict[str, Any]: return result +def _normalize_retrieve_hit(raw: dict[str, Any]) -> dict[str, Any]: + """Normalize a retrieve hit with hybrid ranking diagnostics.""" + result = _normalize_record(raw) + result["score"] = raw.get("score") + result["vector_distance"] = raw.get("vector_distance") + result["text_score"] = raw.get("text_score") + result["matched_channels"] = list(raw.get("matched_channels") or []) + return result + + _AWS_KWARG_MAP: dict[str, str] = { "aws_access_key_id": "aws_access_key_id", "aws_secret_access_key": "aws_secret_access_key", @@ -512,6 +522,36 @@ def search( ) return [_normalize_search_hit(item) for item in results] + def retrieve( + self, + *, + text: str | None = None, + vector: Any | None = None, + limit: int | None = None, + filters: dict[str, Any] | None = None, + include_expired: bool = False, + include_retired: bool = False, + include_relationships: bool = False, + fusion: str = "rrf", + ) -> list[dict[str, Any]]: + if text is None and vector is None: + raise ValueError("retrieve requires text or vector") + if fusion != "rrf": + raise ValueError("retrieve fusion currently supports only 'rrf'") + + coerced_vector = _coerce_vector(vector) if vector is not None else None + results = self._inner.retrieve( + text, + coerced_vector, + limit, + _json_dumps(filters, "filters"), + include_expired, + include_retired, + include_relationships, + fusion, + ) + return [_normalize_retrieve_hit(item) for item in results] + def list( self, limit: int | None = None, @@ -780,6 +820,33 @@ async def search( ), ) + async def retrieve( + self, + *, + text: str | None = None, + vector: Any | None = None, + limit: int | None = None, + filters: dict[str, Any] | None = None, + include_expired: bool = False, + include_retired: bool = False, + include_relationships: bool = False, + fusion: str = "rrf", + ) -> list[dict[str, Any]]: + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + None, + lambda: self._sync.retrieve( + text=text, + vector=vector, + limit=limit, + filters=filters, + include_expired=include_expired, + include_retired=include_retired, + include_relationships=include_relationships, + fusion=fusion, + ), + ) + async def related( self, target_id: str, diff --git a/python/src/lib.rs b/python/src/lib.rs index b58ad44..134e2d8 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -14,8 +14,8 @@ use tokio::runtime::Runtime; use lance_context_core::serde::CONTENT_TYPE_TEXT; use lance_context_core::{ CompactionConfig, CompactionMetrics, CompactionStats, Context as RustContext, ContextRecord, - ContextStore, ContextStoreOptions, IdIndexType, LifecycleQueryOptions, MetadataFilter, - RecordFilters, Relationship, SearchResult, LIFECYCLE_ACTIVE, + ContextStore, ContextStoreOptions, IdIndexType, LifecycleQueryOptions, RecordFilters, + Relationship, RetrieveResult, SearchResult, LIFECYCLE_ACTIVE, }; const DEFAULT_BINARY_CONTENT_TYPE: &str = "application/octet-stream"; @@ -159,86 +159,9 @@ fn filters_from_json(filters_json: Option) -> PyResult filters.bot_id = filter_string(key.as_str(), value)?, - "session_id" => filters.session_id = filter_string(key.as_str(), value)?, - "role" => filters.role = filter_string(key.as_str(), value)?, - "content_type" => filters.content_type = filter_string(key.as_str(), value)?, - "created_at" => apply_created_at_filter(&mut filters, value)?, - "created_at_start" | "created_after" | "created_at_gte" => { - filters.created_at_start = Some(parse_filter_datetime(&key, &value)?); - } - "created_at_end" | "created_before" | "created_at_lte" => { - filters.created_at_end = Some(parse_filter_datetime(&key, &value)?); - } - _ => { - let filter = match value { - Value::Object(mut object) - if object.len() == 1 && object.contains_key("contains") => - { - MetadataFilter::Contains(object.remove("contains").unwrap()) - } - value => MetadataFilter::Equals(value), - }; - filters.metadata.insert(key, filter); - } - } - } - - Ok(Some(filters)) -} - -fn filter_string(name: &str, value: Value) -> PyResult> { - match value { - Value::Null => Ok(None), - Value::String(value) => Ok(Some(value)), - _ => Err(PyRuntimeError::new_err(format!( - "filter '{name}' must be a string or null" - ))), - } -} - -fn apply_created_at_filter(filters: &mut RecordFilters, value: Value) -> PyResult<()> { - let Value::Object(object) = value else { - return Err(PyRuntimeError::new_err( - "filter 'created_at' must be an object with gte/lte bounds", - )); - }; - - for (key, value) in object { - match key.as_str() { - "gte" | "start" | "after" => { - filters.created_at_start = Some(parse_filter_datetime(&key, &value)?); - } - "lte" | "end" | "before" => { - filters.created_at_end = Some(parse_filter_datetime(&key, &value)?); - } - other => { - return Err(PyRuntimeError::new_err(format!( - "unsupported created_at filter operator '{other}'" - ))); - } - } - } - - Ok(()) -} - -fn parse_filter_datetime(name: &str, value: &Value) -> PyResult> { - let Some(value) = value.as_str() else { - return Err(PyRuntimeError::new_err(format!( - "filter '{name}' must be an ISO-8601 timestamp string" - ))); - }; - DateTime::parse_from_rfc3339(value) - .map(|value| value.with_timezone(&Utc)) - .map_err(to_py_err) + RecordFilters::from_json_value(value) + .map(Some) + .map_err(PyRuntimeError::new_err) } #[pymethods] @@ -442,6 +365,44 @@ impl Context { .collect() } + #[allow(clippy::too_many_arguments)] + #[pyo3(signature = (text = None, vector = None, limit = None, filters_json = None, include_expired = false, include_retired = false, include_relationships = false, fusion = None))] + fn retrieve( + &self, + py: Python<'_>, + text: Option, + vector: Option>, + limit: Option, + filters_json: Option, + include_expired: bool, + include_retired: bool, + include_relationships: bool, + fusion: Option, + ) -> PyResult> { + if fusion.as_deref().is_some_and(|value| value != "rrf") { + return Err(PyRuntimeError::new_err( + "retrieve fusion currently supports only 'rrf'", + )); + } + + let filters = filters_from_json(filters_json)?; + let options = LifecycleQueryOptions::new(include_expired, include_retired); + let hits_res = py.allow_threads(|| { + self.runtime + .block_on(self.store.retrieve_filtered_with_options( + text.as_deref(), + vector.as_deref(), + limit, + filters.as_ref(), + options, + )) + }); + let hits = hits_res.map_err(to_py_err)?; + hits.into_iter() + .map(|hit| retrieve_hit_to_py(py, hit, include_relationships)) + .collect() + } + #[pyo3(signature = (limit = None, offset = None, filters_json = None, include_expired = false, include_retired = false))] fn list( &self, @@ -823,6 +784,31 @@ fn search_hit_to_py( Ok(dict) } +fn retrieve_hit_to_py( + py: Python<'_>, + hit: RetrieveResult, + include_relationships: bool, +) -> PyResult { + let RetrieveResult { + record, + score, + vector_distance, + text_score, + matched_channels, + } = hit; + let mut record = record; + if !include_relationships { + record.relationships.clear(); + } + let dict = record_to_py(py, record)?; + let dict_ref = dict.downcast_bound::(py)?; + dict_ref.set_item("score", score)?; + dict_ref.set_item("vector_distance", vector_distance)?; + dict_ref.set_item("text_score", text_score)?; + dict_ref.set_item("matched_channels", matched_channels)?; + Ok(dict) +} + fn record_to_py(py: Python<'_>, record: ContextRecord) -> PyResult { let ContextRecord { id, diff --git a/python/tests/test_async.py b/python/tests/test_async.py index 4c24e54..359589e 100644 --- a/python/tests/test_async.py +++ b/python/tests/test_async.py @@ -118,6 +118,25 @@ async def test_search(tmp_path: Path) -> None: assert results[0]["text"] == "hello" +@pytest.mark.asyncio +async def test_retrieve(tmp_path: Path) -> None: + uri = str(tmp_path / "ctx.lance") + ctx = await AsyncContext.create(uri) + + dim = 1536 + near = [0.0] * dim + far = [0.0] * dim + far[0] = 1.0 + + await ctx.add("assistant", "general rollout guidance", embedding=near) + await ctx.add("assistant", "POLICY-123 blocks service-a", embedding=far) + + results = await ctx.retrieve(text="POLICY-123 service-a", vector=near, limit=1) + assert len(results) == 1 + assert results[0]["text"] == "POLICY-123 blocks service-a" + assert results[0]["matched_channels"] == ["vector", "text"] + + @pytest.mark.asyncio async def test_metadata_filters(tmp_path: Path) -> None: uri = str(tmp_path / "ctx.lance") diff --git a/python/tests/test_persistence.py b/python/tests/test_persistence.py index af3605b..988c1bf 100644 --- a/python/tests/test_persistence.py +++ b/python/tests/test_persistence.py @@ -289,6 +289,66 @@ def test_search_applies_filters_before_limit(tmp_path: Path) -> None: assert hits[0]["metadata"] == {"scope": "team", "tags": ["runbook"]} +def test_retrieve_fuses_text_vector_and_filters(tmp_path: Path) -> None: + uri = tmp_path / "context.lance" + ctx = Context.create(str(uri)) + near = [0.0] * 1536 + far = [0.0] * 1536 + far[0] = 1.0 + + ctx.add( + "assistant", + "general rollout risk guidance", + embedding=near, + metadata={"scope": "team", "tags": ["runbook"]}, + ) + ctx.add( + "assistant", + "POLICY-123 blocks service-a rollouts", + embedding=far, + metadata={"scope": "team", "tags": ["policy"]}, + ) + ctx.add( + "assistant", + "POLICY-123 personal note for service-a", + embedding=far, + metadata={"scope": "personal", "tags": ["policy"]}, + ) + + hits = ctx.retrieve( + text="POLICY-123 service-a", + vector=near, + limit=2, + filters={"scope": "team"}, + ) + + assert [hit["text"] for hit in hits] == [ + "POLICY-123 blocks service-a rollouts", + "general rollout risk guidance", + ] + assert hits[0]["matched_channels"] == ["vector", "text"] + assert hits[0]["score"] > hits[1]["score"] + assert hits[0]["vector_distance"] is not None + assert hits[0]["text_score"] == 1.0 + assert hits[1]["matched_channels"] == ["vector"] + + +def test_retrieve_supports_text_only(tmp_path: Path) -> None: + uri = tmp_path / "context.lance" + ctx = Context.create(str(uri)) + + ctx.add("assistant", "The rollout owner is service-a.") + ctx.add("assistant", "The unrelated deployment note mentions service-b.") + + hits = ctx.retrieve(text="service-a rollout", limit=1) + + assert len(hits) == 1 + assert hits[0]["text"] == "The rollout owner is service-a." + assert hits[0]["matched_channels"] == ["text"] + assert hits[0]["vector_distance"] is None + assert hits[0]["text_score"] == 1.0 + + def test_lifecycle_fields_round_trip_and_default_filtering(tmp_path: Path) -> None: uri = tmp_path / "context.lance" ctx = Context.create(str(uri)) diff --git a/python/tests/test_search.py b/python/tests/test_search.py index 3a4ef1d..a527849 100644 --- a/python/tests/test_search.py +++ b/python/tests/test_search.py @@ -7,6 +7,7 @@ Context, _coerce_vector, _normalize_record, + _normalize_retrieve_hit, _normalize_search_hit, ) @@ -16,6 +17,18 @@ def __init__(self) -> None: self.search_calls: list[tuple[list[float], int | None, str | None]] = [] self.search_lifecycle_calls: list[tuple[bool, bool]] = [] self.search_relationship_calls: list[bool] = [] + self.retrieve_calls: list[ + tuple[ + str | None, + list[float] | None, + int | None, + str | None, + bool, + bool, + bool, + str, + ] + ] = [] self.list_calls: list[tuple[int | None, int | None, str | None]] = [] self.list_lifecycle_calls: list[tuple[bool, bool]] = [] self.related_calls: list[tuple[str, str | None, int | None, bool, bool]] = [] @@ -132,6 +145,51 @@ def search( ] return [hit] + def retrieve( + self, + text: str | None, + vector: list[float] | None, + limit: int | None, + filters_json: str | None, + include_expired: bool = False, + include_retired: bool = False, + include_relationships: bool = False, + fusion: str = "rrf", + ): + self.retrieve_calls.append( + ( + text, + vector, + limit, + filters_json, + include_expired, + include_retired, + include_relationships, + fusion, + ) + ) + hit = self.search( + vector or [0.0, 0.0], + limit, + filters_json, + include_expired, + include_retired, + include_relationships, + )[0] + hit.pop("distance", None) + hit["score"] = 0.032 + hit["vector_distance"] = 0.12 if vector is not None else None + hit["text_score"] = 1.0 if text else None + hit["matched_channels"] = [ + channel + for channel, enabled in ( + ("vector", vector is not None), + ("text", text is not None), + ) + if enabled + ] + return [hit] + def add_many(self, records: list[dict[str, Any]]): self.add_many_calls.append(records) @@ -294,6 +352,72 @@ def test_context_search_can_include_relationships(): ] +def test_context_retrieve_forwards_hybrid_arguments(): + ctx = Context.__new__(Context) + dummy = DummyInner() + ctx._inner = dummy # type: ignore[attr-defined] + + hits = ctx.retrieve( + text="POLICY-123 service-a", + vector=[0.5, 0.4], + limit=3, + filters={"bot_id": "support_bot", "scope": "team"}, + include_expired=True, + include_retired=True, + include_relationships=True, + ) + + assert dummy.retrieve_calls == [ + ( + "POLICY-123 service-a", + [0.5, 0.4], + 3, + '{"bot_id":"support_bot","scope":"team"}', + True, + True, + True, + "rrf", + ) + ] + assert hits[0]["score"] == 0.032 + assert hits[0]["vector_distance"] == 0.12 + assert hits[0]["text_score"] == 1.0 + assert hits[0]["matched_channels"] == ["vector", "text"] + assert hits[0]["relationships"] == [ + {"target_id": "doc-1#chunk-1", "relation": "cites", "weight": 0.75} + ] + + +def test_context_retrieve_accepts_text_only(): + ctx = Context.__new__(Context) + dummy = DummyInner() + ctx._inner = dummy # type: ignore[attr-defined] + + hits = ctx.retrieve(text="runbook") + + assert dummy.retrieve_calls[0][0] == "runbook" + assert dummy.retrieve_calls[0][1] is None + assert hits[0]["matched_channels"] == ["text"] + + +def test_context_retrieve_requires_text_or_vector(): + ctx = Context.__new__(Context) + dummy = DummyInner() + ctx._inner = dummy # type: ignore[attr-defined] + + with pytest.raises(ValueError, match="requires text or vector"): + ctx.retrieve() + + +def test_context_retrieve_rejects_unknown_fusion(): + ctx = Context.__new__(Context) + dummy = DummyInner() + ctx._inner = dummy # type: ignore[attr-defined] + + with pytest.raises(ValueError, match="supports only 'rrf'"): + ctx.retrieve(text="runbook", fusion="weighted") + + def test_normalize_record_without_distance(): result = _normalize_record( { @@ -339,6 +463,32 @@ def test_normalize_record_with_relationships(): ] +def test_normalize_retrieve_hit_with_scores(): + result = _normalize_retrieve_hit( + { + "id": "rec-1", + "external_id": None, + "created_at": "2024-01-01T00:00:00Z", + "content_type": "text/plain", + "text_payload": "hello", + "binary_payload": None, + "embedding": None, + "run_id": "run-1", + "role": "user", + "state_metadata": None, + "score": 0.032, + "vector_distance": 0.12, + "text_score": 1.0, + "matched_channels": ["vector", "text"], + } + ) + + assert result["score"] == 0.032 + assert result["vector_distance"] == 0.12 + assert result["text_score"] == 1.0 + assert result["matched_channels"] == ["vector", "text"] + + def test_context_list_returns_entries(): ctx = Context.__new__(Context) dummy = DummyInner() @@ -876,9 +1026,7 @@ def test_context_add_many_forwards_relationships(): { "role": "user", "content": "hello", - "relationships": [ - {"target_id": "doc-1#chunk-1", "relation": "cites"} - ], + "relationships": [{"target_id": "doc-1#chunk-1", "relation": "cites"}], } ] )