diff --git a/crates/lance-context-api/src/lib.rs b/crates/lance-context-api/src/lib.rs index ebc595b..bac85e7 100644 --- a/crates/lance-context-api/src/lib.rs +++ b/crates/lance-context-api/src/lib.rs @@ -34,6 +34,16 @@ pub trait ContextStoreApi { records: &[AddRecordRequest], ) -> impl Future> + Send; + fn upsert( + &mut self, + request: &UpsertRecordRequest, + ) -> impl Future> + Send; + + fn update( + &mut self, + request: &UpdateRecordRequest, + ) -> impl Future> + Send; + fn get(&self, id: &str) -> impl Future>> + Send; fn get_by_external_id( @@ -194,6 +204,82 @@ pub struct AddRecordsResponse { pub count: usize, } +#[derive(Debug, Serialize, Deserialize)] +pub struct UpsertRecordRequest { + pub record: AddRecordRequest, + #[serde(default = "default_upsert_key")] + pub key: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct UpsertRecordResponse { + pub version: u64, + pub inserted: bool, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub replaced_id: Option, + pub record: RecordDto, +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct RecordPatchDto { + #[serde(default, skip_serializing_if = "Option::is_none")] + pub bot_id: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub session_id: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub state_metadata: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub metadata: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub relationships: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub expires_at: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub retention_policy: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub lifecycle_status: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub retired_at: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub retired_reason: Option, +} + +impl RecordPatchDto { + #[must_use] + pub fn is_empty(&self) -> bool { + self.bot_id.is_none() + && self.session_id.is_none() + && self.state_metadata.is_none() + && self.metadata.is_none() + && self.relationships.is_none() + && self.expires_at.is_none() + && self.retention_policy.is_none() + && self.lifecycle_status.is_none() + && self.retired_at.is_none() + && self.retired_reason.is_none() + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct UpdateRecordRequest { + #[serde(default, skip_serializing_if = "Option::is_none")] + pub id: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub external_id: Option, + #[serde(default)] + pub patch: RecordPatchDto, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct UpdateRecordResponse { + pub version: u64, + pub updated: bool, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub replaced_id: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub record: Option, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct RecordDto { pub id: String, @@ -396,6 +482,10 @@ fn default_role() -> String { "user".to_string() } +fn default_upsert_key() -> String { + "external_id".to_string() +} + fn default_search_limit() -> usize { 10 } diff --git a/crates/lance-context-client/src/lib.rs b/crates/lance-context-client/src/lib.rs index 3c52600..f88f4d5 100644 --- a/crates/lance-context-client/src/lib.rs +++ b/crates/lance-context-client/src/lib.rs @@ -58,6 +58,32 @@ impl ContextStoreApi for RemoteContextStore { Ok(resp) } + async fn upsert( + &mut self, + request: &UpsertRecordRequest, + ) -> ContextResult { + let resp = self + .client + .upsert_record(&self.context_name, request) + .await + .map_err(to_ctx_err)?; + self.cached_version = resp.version; + Ok(resp) + } + + async fn update( + &mut self, + request: &UpdateRecordRequest, + ) -> ContextResult { + let resp = self + .client + .update_record(&self.context_name, request) + .await + .map_err(to_ctx_err)?; + self.cached_version = resp.version; + Ok(resp) + } + async fn get(&self, id: &str) -> ContextResult> { let resp = self .client @@ -292,6 +318,34 @@ impl ContextClient { Self::handle_response(resp).await } + pub async fn upsert_record( + &self, + name: &str, + req: &UpsertRecordRequest, + ) -> Result { + let resp = self + .http + .put(self.url(&format!("/contexts/{}/records", name))) + .json(req) + .send() + .await?; + Self::handle_response(resp).await + } + + pub async fn update_record( + &self, + name: &str, + req: &UpdateRecordRequest, + ) -> Result { + let resp = self + .http + .patch(self.url(&format!("/contexts/{}/records", name))) + .json(req) + .send() + .await?; + Self::handle_response(resp).await + } + pub async fn get_record(&self, name: &str, id: &str) -> Result { let resp = self .http diff --git a/crates/lance-context-core/src/api_impl.rs b/crates/lance-context-core/src/api_impl.rs index a3896ce..a05ed14 100644 --- a/crates/lance-context-core/src/api_impl.rs +++ b/crates/lance-context-core/src/api_impl.rs @@ -3,12 +3,13 @@ use uuid::Uuid; use lance_context_api::{ AddRecordRequest, AddRecordsResponse, CompactRequest, CompactResponse, CompactStatsResponse, - ContextError, ContextResult, ContextStoreApi, DeleteRecordResponse, RecordDto, RelationshipDto, - RetrieveRequest, RetrieveResultDto, SearchResultDto, StateMetadataDto, + ContextError, ContextResult, ContextStoreApi, DeleteRecordResponse, RecordDto, RecordPatchDto, + RelationshipDto, RetrieveRequest, RetrieveResultDto, SearchResultDto, StateMetadataDto, + UpdateRecordRequest, UpdateRecordResponse, UpsertRecordRequest, UpsertRecordResponse, }; use crate::record::{ - ContextRecord, LifecycleQueryOptions, RecordFilters, Relationship, StateMetadata, + ContextRecord, LifecycleQueryOptions, RecordFilters, RecordPatch, Relationship, StateMetadata, LIFECYCLE_ACTIVE, }; use crate::store::{CompactionConfig, ContextStore}; @@ -22,39 +23,7 @@ impl ContextStoreApi for ContextStore { for r in records { let id = Uuid::new_v4().to_string(); ids.push(id.clone()); - core_records.push(ContextRecord { - id, - external_id: r.external_id.clone(), - run_id: run_id.clone(), - bot_id: r.bot_id.clone(), - session_id: r.session_id.clone(), - created_at: Utc::now(), - role: r.role.clone(), - state_metadata: r.state_metadata.as_ref().map(|sm| StateMetadata { - step: sm.step, - active_plan_id: sm.active_plan_id.clone(), - tokens_used: sm.tokens_used, - custom: sm.custom.clone(), - }), - metadata: r.metadata.clone(), - relationships: r - .relationships - .iter() - .cloned() - .map(dto_to_relationship) - .collect(), - expires_at: r.expires_at, - retention_policy: r.retention_policy.clone(), - lifecycle_status: LIFECYCLE_ACTIVE.to_string(), - retired_at: None, - retired_reason: None, - supersedes_id: r.supersedes_id.clone(), - superseded_by_id: None, - content_type: r.content_type.clone(), - text_payload: r.text_payload.clone(), - binary_payload: r.binary_payload.clone(), - embedding: r.embedding.clone(), - }); + core_records.push(record_from_add_request(r, id, run_id.clone())); } let count = core_records.len(); @@ -66,6 +35,88 @@ impl ContextStoreApi for ContextStore { }) } + async fn upsert( + &mut self, + request: &UpsertRecordRequest, + ) -> ContextResult { + if request.key != "external_id" { + return Err(ContextError::InvalidRequest(format!( + "upsert key '{}' is not supported; use 'external_id'", + request.key + ))); + } + if request + .record + .external_id + .as_deref() + .is_none_or(str::is_empty) + { + return Err(ContextError::InvalidRequest( + "upsert requires record.external_id".to_string(), + )); + } + + let record = record_from_add_request( + &request.record, + Uuid::new_v4().to_string(), + Uuid::new_v4().to_string(), + ); + let result = ContextStore::upsert_by_external_id(self, record) + .await + .map_err(to_ctx_err)?; + Ok(UpsertRecordResponse { + version: result.version, + inserted: result.inserted, + replaced_id: result.replaced_id, + record: record_to_dto(result.record), + }) + } + + async fn update( + &mut self, + request: &UpdateRecordRequest, + ) -> ContextResult { + if request.patch.is_empty() { + return Err(ContextError::InvalidRequest( + "update requires at least one patch field".to_string(), + )); + } + + let patch = patch_from_dto(&request.patch); + let result = match (&request.id, &request.external_id) { + (Some(id), None) => ContextStore::update_by_id(self, id, patch).await, + (None, Some(external_id)) => { + ContextStore::update_by_external_id(self, external_id, patch).await + } + (None, None) => { + return Err(ContextError::InvalidRequest( + "update requires either id or external_id".to_string(), + )); + } + (Some(_), Some(_)) => { + return Err(ContextError::InvalidRequest( + "update accepts only one of id or external_id".to_string(), + )); + } + } + .map_err(to_ctx_err)?; + + Ok(match result { + Some(result) => UpdateRecordResponse { + version: result.version, + updated: true, + replaced_id: Some(result.replaced_id), + record: Some(record_to_dto(result.record)), + }, + None => UpdateRecordResponse { + version: ContextStore::version(self), + updated: false, + replaced_id: None, + record: None, + }, + }) + } + async fn get(&self, id: &str) -> ContextResult> { let record = ContextStore::get(self, id).await.map_err(to_ctx_err)?; Ok(record.map(record_to_dto)) @@ -256,6 +307,68 @@ fn relationship_to_dto(r: Relationship) -> RelationshipDto { } } +fn patch_from_dto(patch: &RecordPatchDto) -> RecordPatch { + RecordPatch { + bot_id: patch.bot_id.clone(), + session_id: patch.session_id.clone(), + state_metadata: patch.state_metadata.as_ref().map(|sm| StateMetadata { + step: sm.step, + active_plan_id: sm.active_plan_id.clone(), + tokens_used: sm.tokens_used, + custom: sm.custom.clone(), + }), + metadata: patch.metadata.clone(), + relationships: patch.relationships.as_ref().map(|relationships| { + relationships + .iter() + .cloned() + .map(dto_to_relationship) + .collect() + }), + expires_at: patch.expires_at, + retention_policy: patch.retention_policy.clone(), + lifecycle_status: patch.lifecycle_status.clone(), + retired_at: patch.retired_at, + retired_reason: patch.retired_reason.clone(), + } +} + +fn record_from_add_request(r: &AddRecordRequest, id: String, run_id: String) -> ContextRecord { + ContextRecord { + id, + external_id: r.external_id.clone(), + run_id, + bot_id: r.bot_id.clone(), + session_id: r.session_id.clone(), + created_at: Utc::now(), + role: r.role.clone(), + state_metadata: r.state_metadata.as_ref().map(|sm| StateMetadata { + step: sm.step, + active_plan_id: sm.active_plan_id.clone(), + tokens_used: sm.tokens_used, + custom: sm.custom.clone(), + }), + metadata: r.metadata.clone(), + relationships: r + .relationships + .iter() + .cloned() + .map(dto_to_relationship) + .collect(), + expires_at: r.expires_at, + retention_policy: r.retention_policy.clone(), + lifecycle_status: LIFECYCLE_ACTIVE.to_string(), + retired_at: None, + retired_reason: None, + supersedes_id: r.supersedes_id.clone(), + superseded_by_id: None, + content_type: r.content_type.clone(), + text_payload: r.text_payload.clone(), + binary_payload: r.binary_payload.clone(), + embedding: r.embedding.clone(), + } +} + fn record_to_dto(r: ContextRecord) -> RecordDto { RecordDto { id: r.id, diff --git a/crates/lance-context-core/src/lib.rs b/crates/lance-context-core/src/lib.rs index 42b2438..160eac3 100644 --- a/crates/lance-context-core/src/lib.rs +++ b/crates/lance-context-core/src/lib.rs @@ -9,8 +9,9 @@ mod store; pub use context::{Context, ContextEntry, Snapshot}; pub use record::{ - ContextRecord, LifecycleQueryOptions, MetadataFilter, RecordFilters, Relationship, - RetrieveResult, SearchResult, StateMetadata, LIFECYCLE_ACTIVE, LIFECYCLE_CONTRADICTED, + ContextRecord, LifecycleQueryOptions, MetadataFilter, RecordFilters, RecordPatch, Relationship, + RetrieveResult, SearchResult, StateMetadata, UpdateResult, UpsertResult, LIFECYCLE_ACTIVE, + LIFECYCLE_CONTRADICTED, }; pub use store::{ CompactionConfig, CompactionStats, ContextStore, ContextStoreOptions, DistanceMetric, diff --git a/crates/lance-context-core/src/record.rs b/crates/lance-context-core/src/record.rs index d31bd22..abfa3c2 100644 --- a/crates/lance-context-core/src/record.rs +++ b/crates/lance-context-core/src/record.rs @@ -143,6 +143,54 @@ pub struct RetrieveResult { pub matched_channels: Vec, } +/// Result returned from insert-or-replace operations. +#[derive(Debug, Clone)] +pub struct UpsertResult { + pub record: ContextRecord, + pub inserted: bool, + pub replaced_id: Option, + pub version: u64, +} + +/// Mutable fields that can be patched without resupplying the payload. +#[derive(Debug, Clone, Default)] +pub struct RecordPatch { + pub bot_id: Option, + pub session_id: Option, + pub state_metadata: Option, + pub metadata: Option, + pub relationships: Option>, + pub expires_at: Option>, + pub retention_policy: Option, + pub lifecycle_status: Option, + pub retired_at: Option>, + pub retired_reason: Option, +} + +impl RecordPatch { + #[must_use] + pub fn is_empty(&self) -> bool { + self.bot_id.is_none() + && self.session_id.is_none() + && self.state_metadata.is_none() + && self.metadata.is_none() + && self.relationships.is_none() + && self.expires_at.is_none() + && self.retention_policy.is_none() + && self.lifecycle_status.is_none() + && self.retired_at.is_none() + && self.retired_reason.is_none() + } +} + +/// Result returned from partial record update operations. +#[derive(Debug, Clone)] +pub struct UpdateResult { + pub record: ContextRecord, + pub replaced_id: String, + pub version: u64, +} + /// Metadata matching operation for filtered retrieval. #[derive(Debug, Clone, PartialEq)] pub enum MetadataFilter { diff --git a/crates/lance-context-core/src/store.rs b/crates/lance-context-core/src/store.rs index abe6cc6..de97f5b 100644 --- a/crates/lance-context-core/src/store.rs +++ b/crates/lance-context-core/src/store.rs @@ -35,8 +35,8 @@ use tracing::{error, info, warn}; use uuid::Uuid; use crate::record::{ - ContextRecord, LifecycleQueryOptions, RecordFilters, Relationship, RetrieveResult, - SearchResult, StateMetadata, LIFECYCLE_ACTIVE, + ContextRecord, LifecycleQueryOptions, RecordFilters, RecordPatch, Relationship, RetrieveResult, + SearchResult, StateMetadata, UpdateResult, UpsertResult, LIFECYCLE_ACTIVE, }; use crate::serde::CONTENT_TYPE_TOMBSTONE; @@ -411,6 +411,192 @@ impl ContextStore { Ok(true) } + /// Insert a record or replace the currently-visible record with the same external id. + /// + /// Replacement is append-only: the new record keeps the same `external_id` + /// and gets `supersedes_id` set to the old record id. Default reads hide + /// the superseded record while `include_retired` reads can still inspect + /// both versions. Caller-supplied supersession fields are ignored because + /// this method manages replacement by `external_id`. + pub async fn upsert_by_external_id( + &mut self, + mut record: ContextRecord, + ) -> LanceResult { + let Some(external_id) = record.external_id.clone() else { + return Err(ArrowError::InvalidArgumentError( + "upsert_by_external_id requires external_id".to_string(), + ) + .into()); + }; + if external_id.is_empty() { + return Err(ArrowError::InvalidArgumentError( + "upsert_by_external_id requires a non-empty external_id".to_string(), + ) + .into()); + } + if record.is_tombstone() { + return Err(ArrowError::InvalidArgumentError(format!( + "content_type '{}' is reserved for internal tombstones", + CONTENT_TYPE_TOMBSTONE + )) + .into()); + } + record.supersedes_id = None; + record.superseded_by_id = None; + self.validate_new_record_id(&record).await?; + + let matches: Vec = self + .list(None, None) + .await? + .into_iter() + .filter(|existing| existing.external_id.as_deref() == Some(external_id.as_str())) + .collect(); + + match matches.as_slice() { + [] => { + let version = self.add(std::slice::from_ref(&record)).await?; + Ok(UpsertResult { + record, + inserted: true, + replaced_id: None, + version, + }) + } + [existing] => { + record.supersedes_id = Some(existing.id.clone()); + let version = self.write_entries(std::slice::from_ref(&record)).await?; + Ok(UpsertResult { + record, + inserted: false, + replaced_id: Some(existing.id.clone()), + version, + }) + } + _ => Err(ArrowError::InvalidArgumentError(format!( + "external_id '{}' matches multiple visible records", + external_id + )) + .into()), + } + } + + /// Partially update mutable fields on a visible record by internal id. + /// + /// The update is append-only: it writes a replacement record that + /// supersedes the current visible record, preserving the original payload + /// and embedding while changing only the requested patch fields. + pub async fn update_by_id( + &mut self, + id: &str, + patch: RecordPatch, + ) -> LanceResult> { + if id.is_empty() { + return Err(ArrowError::InvalidArgumentError( + "update_by_id requires a non-empty id".to_string(), + ) + .into()); + } + let Some(existing) = self.get_by_id(id).await? else { + return Ok(None); + }; + self.update_visible_record(existing, patch).await.map(Some) + } + + /// Partially update mutable fields on a visible record by external id. + /// + /// Returns `Ok(None)` when no visible record currently has the external id. + pub async fn update_by_external_id( + &mut self, + external_id: &str, + patch: RecordPatch, + ) -> LanceResult> { + if external_id.is_empty() { + return Err(ArrowError::InvalidArgumentError( + "update_by_external_id requires a non-empty external_id".to_string(), + ) + .into()); + } + + let matches: Vec = self + .list(None, None) + .await? + .into_iter() + .filter(|existing| existing.external_id.as_deref() == Some(external_id)) + .collect(); + + match matches.as_slice() { + [] => Ok(None), + [existing] => self + .update_visible_record(existing.clone(), patch) + .await + .map(Some), + _ => Err(ArrowError::InvalidArgumentError(format!( + "external_id '{}' matches multiple visible records", + external_id + )) + .into()), + } + } + + async fn update_visible_record( + &mut self, + existing: ContextRecord, + patch: RecordPatch, + ) -> LanceResult { + if patch.is_empty() { + return Err(ArrowError::InvalidArgumentError( + "update requires at least one patch field".to_string(), + ) + .into()); + } + + let mut record = existing.clone(); + record.id = Uuid::new_v4().to_string(); + record.run_id = Uuid::new_v4().to_string(); + record.created_at = Utc::now(); + record.supersedes_id = Some(existing.id.clone()); + record.superseded_by_id = None; + + if let Some(bot_id) = patch.bot_id { + record.bot_id = Some(bot_id); + } + if let Some(session_id) = patch.session_id { + record.session_id = Some(session_id); + } + if let Some(state_metadata) = patch.state_metadata { + record.state_metadata = Some(state_metadata); + } + if let Some(metadata) = patch.metadata { + record.metadata = Some(metadata); + } + if let Some(relationships) = patch.relationships { + record.relationships = relationships; + } + if let Some(expires_at) = patch.expires_at { + record.expires_at = Some(expires_at); + } + if let Some(retention_policy) = patch.retention_policy { + record.retention_policy = Some(retention_policy); + } + if let Some(lifecycle_status) = patch.lifecycle_status { + record.lifecycle_status = lifecycle_status; + } + if let Some(retired_at) = patch.retired_at { + record.retired_at = Some(retired_at); + } + if let Some(retired_reason) = patch.retired_reason { + record.retired_reason = Some(retired_reason); + } + + self.validate_new_record_id(&record).await?; + let version = self.write_entries(std::slice::from_ref(&record)).await?; + Ok(UpdateResult { + record, + replaced_id: existing.id, + version, + }) + } + async fn write_tombstone_for(&mut self, record: ContextRecord) -> LanceResult { let tombstone = ContextRecord { id: record.id, @@ -492,6 +678,22 @@ impl ContextStore { Ok(()) } + async fn validate_new_record_id(&self, entry: &ContextRecord) -> LanceResult<()> { + for record in self + .list_with_options(None, None, LifecycleQueryOptions::new(true, true)) + .await? + { + if record.id == entry.id { + return Err(ArrowError::InvalidArgumentError(format!( + "id '{}' already exists", + entry.id + )) + .into()); + } + } + Ok(()) + } + fn derive_region_id(bot_id: &Option, session_id: &Option) -> Uuid { let mut input = String::new(); @@ -2598,6 +2800,119 @@ mod tests { }); } + #[test] + fn upsert_by_external_id_inserts_then_replaces_visible_record() { + let dir = TempDir::new().unwrap(); + let uri = dir.path().to_string_lossy().to_string(); + let runtime = tokio::runtime::Runtime::new().unwrap(); + runtime.block_on(async { + let mut store = ContextStore::open(&uri).await.unwrap(); + + let mut first = text_record("first", 0.0); + first.external_id = Some("doc-123#chunk-1".to_string()); + let inserted = store.upsert_by_external_id(first.clone()).await.unwrap(); + assert!(inserted.inserted); + assert_eq!(inserted.replaced_id, None); + assert_eq!(inserted.record.id, first.id); + + let mut replacement = text_record("replacement", 1.0); + replacement.external_id = first.external_id.clone(); + let replaced = store + .upsert_by_external_id(replacement.clone()) + .await + .unwrap(); + assert!(!replaced.inserted); + assert_eq!(replaced.replaced_id.as_deref(), Some(first.id.as_str())); + assert_eq!( + replaced.record.supersedes_id.as_deref(), + Some(first.id.as_str()) + ); + + let visible = store.list(None, None).await.unwrap(); + assert_eq!(visible.len(), 1); + assert_eq!(visible[0].id, replacement.id); + + let by_external_id = store + .get_by_external_id("doc-123#chunk-1") + .await + .unwrap() + .unwrap(); + assert_eq!(by_external_id.id, replacement.id); + + let history = store + .list_with_options(None, None, LifecycleQueryOptions::new(false, true)) + .await + .unwrap(); + assert_eq!(history.len(), 2); + assert!(history.iter().any(|record| record.id == first.id)); + assert!(history.iter().any(|record| record.id == replacement.id)); + }); + } + + #[test] + fn update_by_external_id_patches_mutable_fields_and_preserves_payload() { + let dir = TempDir::new().unwrap(); + let uri = dir.path().to_string_lossy().to_string(); + let runtime = tokio::runtime::Runtime::new().unwrap(); + runtime.block_on(async { + let mut store = ContextStore::open(&uri).await.unwrap(); + + let mut record = text_record("stable", 0.0); + record.external_id = Some("doc-123#chunk-1".to_string()); + record.metadata = Some(serde_json::json!({"revision": 1})); + store.add(std::slice::from_ref(&record)).await.unwrap(); + + let patch = RecordPatch { + bot_id: Some("bot-a".to_string()), + session_id: Some("session-a".to_string()), + metadata: Some(serde_json::json!({"revision": 2, "confidence": 0.9})), + relationships: Some(vec![Relationship { + target_id: "doc-123".to_string(), + relation: "derived_from".to_string(), + weight: None, + }]), + ..Default::default() + }; + let updated = store + .update_by_external_id("doc-123#chunk-1", patch) + .await + .unwrap() + .unwrap(); + + assert_eq!(updated.replaced_id, record.id); + assert_ne!(updated.record.id, record.id); + assert_eq!(updated.record.external_id, record.external_id); + assert_eq!(updated.record.text_payload, record.text_payload); + assert_eq!(updated.record.embedding, record.embedding); + assert_eq!(updated.record.bot_id.as_deref(), Some("bot-a")); + assert_eq!(updated.record.session_id.as_deref(), Some("session-a")); + assert_eq!( + updated.record.metadata, + Some(serde_json::json!({"revision": 2, "confidence": 0.9})) + ); + assert_eq!(updated.record.relationships.len(), 1); + assert_eq!( + updated.record.supersedes_id.as_deref(), + Some(record.id.as_str()) + ); + + let visible = store + .get_by_external_id("doc-123#chunk-1") + .await + .unwrap() + .unwrap(); + assert_eq!(visible.id, updated.record.id); + + let history = store + .list_with_options(None, None, LifecycleQueryOptions::new(false, true)) + .await + .unwrap(); + assert_eq!(history.len(), 2); + assert!(history.iter().any(|item| item.id == record.id)); + assert!(history.iter().any(|item| item.id == updated.record.id)); + }); + } + #[test] fn relationships_roundtrip_and_support_related_lookup() { let dir = TempDir::new().unwrap(); diff --git a/crates/lance-context-server/src/routes/mod.rs b/crates/lance-context-server/src/routes/mod.rs index a5d6596..f7182c6 100644 --- a/crates/lance-context-server/src/routes/mod.rs +++ b/crates/lance-context-server/src/routes/mod.rs @@ -7,7 +7,7 @@ pub mod versions; use std::sync::Arc; -use axum::routing::{delete, get, post}; +use axum::routing::{delete, get, patch, post, put}; use axum::Router; use crate::state::AppState; @@ -23,6 +23,14 @@ pub fn router() -> Router> { "/api/v1/contexts/{name}/records", post(records::add_records), ) + .route( + "/api/v1/contexts/{name}/records", + put(records::upsert_record), + ) + .route( + "/api/v1/contexts/{name}/records", + patch(records::update_record), + ) .route( "/api/v1/contexts/{name}/records", get(records::list_records), diff --git a/crates/lance-context-server/src/routes/records.rs b/crates/lance-context-server/src/routes/records.rs index a37b07a..fcd8344 100644 --- a/crates/lance-context-server/src/routes/records.rs +++ b/crates/lance-context-server/src/routes/records.rs @@ -5,10 +5,12 @@ use axum::Json; use chrono::Utc; use lance_context_api::{ AddRecordsRequest, AddRecordsResponse, DeleteRecordResponse, GetRecordResponse, - ListRecordsResponse, RecordDto, RelationshipDto, StateMetadataDto, + ListRecordsResponse, RecordDto, RecordPatchDto, RelationshipDto, StateMetadataDto, + UpdateRecordRequest, UpdateRecordResponse, UpsertRecordRequest, UpsertRecordResponse, }; use lance_context_core::{ - ContextRecord, LifecycleQueryOptions, Relationship, StateMetadata, LIFECYCLE_ACTIVE, + ContextRecord, LifecycleQueryOptions, RecordPatch, Relationship, StateMetadata, + LIFECYCLE_ACTIVE, }; use uuid::Uuid; @@ -40,39 +42,7 @@ pub async fn add_records( for r in &req.records { let id = Uuid::new_v4().to_string(); ids.push(id.clone()); - core_records.push(ContextRecord { - id, - external_id: r.external_id.clone(), - run_id: run_id.clone(), - bot_id: r.bot_id.clone(), - session_id: r.session_id.clone(), - created_at: Utc::now(), - role: r.role.clone(), - state_metadata: r.state_metadata.as_ref().map(|sm| StateMetadata { - step: sm.step, - active_plan_id: sm.active_plan_id.clone(), - tokens_used: sm.tokens_used, - custom: sm.custom.clone(), - }), - metadata: r.metadata.clone(), - relationships: r - .relationships - .iter() - .cloned() - .map(dto_to_relationship) - .collect(), - expires_at: r.expires_at, - retention_policy: r.retention_policy.clone(), - lifecycle_status: LIFECYCLE_ACTIVE.to_string(), - retired_at: None, - retired_reason: None, - supersedes_id: r.supersedes_id.clone(), - superseded_by_id: None, - content_type: r.content_type.clone(), - text_payload: r.text_payload.clone(), - binary_payload: r.binary_payload.clone(), - embedding: r.embedding.clone(), - }); + core_records.push(record_from_add_request(r, id, run_id.clone())); } let count = core_records.len(); @@ -92,6 +62,109 @@ pub async fn add_records( )) } +pub async fn upsert_record( + State(state): State>, + Path(name): Path, + Json(req): Json, +) -> Result<(axum::http::StatusCode, Json), AppError> { + if req.key != "external_id" { + return Err(AppError::InvalidRequest(format!( + "upsert key '{}' is not supported; use 'external_id'", + req.key + ))); + } + if req.record.external_id.as_deref().is_none_or(str::is_empty) { + return Err(AppError::InvalidRequest( + "upsert requires record.external_id".to_string(), + )); + } + + let stores = state.stores.read().await; + let store_lock = stores + .get(&name) + .ok_or_else(|| AppError::NotFound(format!("Context '{}' does not exist", name)))? + .clone(); + drop(stores); + + let record = record_from_add_request( + &req.record, + Uuid::new_v4().to_string(), + Uuid::new_v4().to_string(), + ); + let mut store = store_lock.write().await; + let result = store + .upsert_by_external_id(record) + .await + .map_err(AppError::from_lance)?; + let status = if result.inserted { + axum::http::StatusCode::CREATED + } else { + axum::http::StatusCode::OK + }; + + Ok(( + status, + Json(UpsertRecordResponse { + version: result.version, + inserted: result.inserted, + replaced_id: result.replaced_id, + record: record_to_dto(result.record), + }), + )) +} + +pub async fn update_record( + State(state): State>, + Path(name): Path, + Json(req): Json, +) -> Result, AppError> { + if req.patch.is_empty() { + return Err(AppError::InvalidRequest( + "update requires at least one patch field".to_string(), + )); + } + + let stores = state.stores.read().await; + let store_lock = stores + .get(&name) + .ok_or_else(|| AppError::NotFound(format!("Context '{}' does not exist", name)))? + .clone(); + drop(stores); + + let patch = patch_from_dto(&req.patch); + let mut store = store_lock.write().await; + let result = match (&req.id, &req.external_id) { + (Some(id), None) => store.update_by_id(id, patch).await, + (None, Some(external_id)) => store.update_by_external_id(external_id, patch).await, + (None, None) => { + return Err(AppError::InvalidRequest( + "update requires either id or external_id".to_string(), + )); + } + (Some(_), Some(_)) => { + return Err(AppError::InvalidRequest( + "update accepts only one of id or external_id".to_string(), + )); + } + } + .map_err(AppError::from_lance)?; + + Ok(Json(match result { + Some(result) => UpdateRecordResponse { + version: result.version, + updated: true, + replaced_id: Some(result.replaced_id), + record: Some(record_to_dto(result.record)), + }, + None => UpdateRecordResponse { + version: store.version(), + updated: false, + replaced_id: None, + record: None, + }, + })) +} + pub async fn get_record( State(state): State>, Path((name, id)): Path<(String, String)>, @@ -301,6 +374,72 @@ fn relationship_to_dto(r: Relationship) -> RelationshipDto { } } +fn patch_from_dto(patch: &RecordPatchDto) -> RecordPatch { + RecordPatch { + bot_id: patch.bot_id.clone(), + session_id: patch.session_id.clone(), + state_metadata: patch.state_metadata.as_ref().map(|sm| StateMetadata { + step: sm.step, + active_plan_id: sm.active_plan_id.clone(), + tokens_used: sm.tokens_used, + custom: sm.custom.clone(), + }), + metadata: patch.metadata.clone(), + relationships: patch.relationships.as_ref().map(|relationships| { + relationships + .iter() + .cloned() + .map(dto_to_relationship) + .collect() + }), + expires_at: patch.expires_at, + retention_policy: patch.retention_policy.clone(), + lifecycle_status: patch.lifecycle_status.clone(), + retired_at: patch.retired_at, + retired_reason: patch.retired_reason.clone(), + } +} + +fn record_from_add_request( + r: &lance_context_api::AddRecordRequest, + id: String, + run_id: String, +) -> ContextRecord { + ContextRecord { + id, + external_id: r.external_id.clone(), + run_id, + bot_id: r.bot_id.clone(), + session_id: r.session_id.clone(), + created_at: Utc::now(), + role: r.role.clone(), + state_metadata: r.state_metadata.as_ref().map(|sm| StateMetadata { + step: sm.step, + active_plan_id: sm.active_plan_id.clone(), + tokens_used: sm.tokens_used, + custom: sm.custom.clone(), + }), + metadata: r.metadata.clone(), + relationships: r + .relationships + .iter() + .cloned() + .map(dto_to_relationship) + .collect(), + expires_at: r.expires_at, + retention_policy: r.retention_policy.clone(), + lifecycle_status: LIFECYCLE_ACTIVE.to_string(), + retired_at: None, + retired_reason: None, + supersedes_id: r.supersedes_id.clone(), + superseded_by_id: None, + content_type: r.content_type.clone(), + text_payload: r.text_payload.clone(), + binary_payload: r.binary_payload.clone(), + embedding: r.embedding.clone(), + } +} + #[cfg(test)] mod tests { use std::collections::HashMap; @@ -308,7 +447,10 @@ mod tests { use axum::extract::{Path, Query, State}; use axum::Json; - use lance_context_api::{AddRecordRequest, AddRecordsRequest}; + use lance_context_api::{ + AddRecordRequest, AddRecordsRequest, RecordPatchDto, UpdateRecordRequest, + UpsertRecordRequest, + }; use lance_context_core::ContextStore; use tempfile::TempDir; use tokio::sync::RwLock; @@ -429,6 +571,131 @@ mod tests { assert!(!second_response.deleted); } + #[tokio::test] + async fn upsert_by_external_id_inserts_then_replaces_visible_record() { + let context_name = "ctx"; + let (state, _dir) = test_state(context_name).await; + let external_id = "doc-123#chunk-1"; + + let mut first = text_record("old value"); + first.external_id = Some(external_id.to_string()); + let (insert_status, Json(inserted)) = upsert_record( + State(state.clone()), + Path(context_name.to_string()), + Json(UpsertRecordRequest { + record: first, + key: "external_id".to_string(), + }), + ) + .await + .unwrap(); + assert_eq!(insert_status, axum::http::StatusCode::CREATED); + assert!(inserted.inserted); + assert!(inserted.replaced_id.is_none()); + + let mut replacement = text_record("new value"); + replacement.external_id = Some(external_id.to_string()); + let (replace_status, Json(replaced)) = upsert_record( + State(state.clone()), + Path(context_name.to_string()), + Json(UpsertRecordRequest { + record: replacement, + key: "external_id".to_string(), + }), + ) + .await + .unwrap(); + assert_eq!(replace_status, axum::http::StatusCode::OK); + assert!(!replaced.inserted); + assert_eq!( + replaced.replaced_id.as_deref(), + Some(inserted.record.id.as_str()) + ); + assert_eq!( + replaced.record.supersedes_id.as_deref(), + Some(inserted.record.id.as_str()) + ); + + let Json(response) = list_records( + State(state), + Path(context_name.to_string()), + Query(ListParams { + limit: None, + offset: None, + }), + ) + .await + .unwrap(); + assert_eq!(response.records.len(), 1); + assert_eq!( + response.records[0].text_payload.as_deref(), + Some("new value") + ); + } + + #[tokio::test] + async fn update_by_external_id_patches_visible_record() { + let context_name = "ctx"; + let (state, _dir) = test_state(context_name).await; + let external_id = "doc-123#chunk-1"; + + let mut record = text_record("stable value"); + record.external_id = Some(external_id.to_string()); + let (_, Json(add_response)) = add_records( + State(state.clone()), + Path(context_name.to_string()), + Json(AddRecordsRequest { + records: vec![record], + }), + ) + .await + .unwrap(); + let old_id = add_response.ids[0].clone(); + + let Json(updated) = update_record( + State(state.clone()), + Path(context_name.to_string()), + Json(UpdateRecordRequest { + id: None, + external_id: Some(external_id.to_string()), + patch: RecordPatchDto { + metadata: Some(serde_json::json!({"revision": 2})), + relationships: Some(vec![RelationshipDto { + target_id: "doc-123".to_string(), + relation: "derived_from".to_string(), + weight: None, + }]), + ..Default::default() + }, + }), + ) + .await + .unwrap(); + + assert!(updated.updated); + assert_eq!(updated.replaced_id.as_deref(), Some(old_id.as_str())); + let record = updated.record.unwrap(); + assert_ne!(record.id, old_id); + assert_eq!(record.external_id.as_deref(), Some(external_id)); + assert_eq!(record.text_payload.as_deref(), Some("stable value")); + assert_eq!(record.metadata, Some(serde_json::json!({"revision": 2}))); + assert_eq!(record.relationships.len(), 1); + assert_eq!(record.supersedes_id.as_deref(), Some(old_id.as_str())); + + let Json(response) = list_records( + State(state), + Path(context_name.to_string()), + Query(ListParams { + limit: None, + offset: None, + }), + ) + .await + .unwrap(); + assert_eq!(response.records.len(), 1); + assert_eq!(response.records[0].id, record.id); + } + #[tokio::test] async fn related_records_filters_by_target_and_relation() { let context_name = "ctx"; diff --git a/crates/lance-context/src/lib.rs b/crates/lance-context/src/lib.rs index 00fc48f..f3007f6 100644 --- a/crates/lance-context/src/lib.rs +++ b/crates/lance-context/src/lib.rs @@ -12,7 +12,8 @@ pub use lance_context_core::{ pub use lance_context_api::{ AddRecordRequest, AddRecordsResponse, CompactRequest, CompactResponse, CompactStatsResponse, ContextError, ContextResult, ContextStoreApi, DeleteRecordResponse, RecordDto, RelationshipDto, - RetrieveRequest, RetrieveResponse, RetrieveResultDto, SearchResultDto, + RetrieveRequest, RetrieveResponse, RetrieveResultDto, SearchResultDto, UpsertRecordRequest, + UpsertRecordResponse, }; #[cfg(feature = "remote")] diff --git a/crates/lance-context/src/unified.rs b/crates/lance-context/src/unified.rs index 5b92374..d109331 100644 --- a/crates/lance-context/src/unified.rs +++ b/crates/lance-context/src/unified.rs @@ -3,7 +3,8 @@ use std::collections::HashSet; use lance_context_api::{ AddRecordRequest, AddRecordsResponse, CompactRequest, CompactResponse, CompactStatsResponse, ContextError, ContextResult, ContextStoreApi, DeleteRecordResponse, RecordDto, RetrieveRequest, - RetrieveResultDto, SearchResultDto, + RetrieveResultDto, SearchResultDto, UpdateRecordRequest, UpdateRecordResponse, + UpsertRecordRequest, UpsertRecordResponse, }; use lance_context_core::{ ContextStore as LocalStore, ContextStoreOptions, DistanceMetric, IdIndexType, @@ -119,6 +120,20 @@ impl ContextStoreApi for ContextStore { dispatch_mut!(self, add, records) } + async fn upsert( + &mut self, + request: &UpsertRecordRequest, + ) -> ContextResult { + dispatch_mut!(self, upsert, request) + } + + async fn update( + &mut self, + request: &UpdateRecordRequest, + ) -> ContextResult { + dispatch_mut!(self, update, request) + } + async fn get(&self, id: &str) -> ContextResult> { dispatch_ref!(self, get, id) } diff --git a/python/python/lance_context/api.py b/python/python/lance_context/api.py index 3db9f81..957d6a5 100644 --- a/python/python/lance_context/api.py +++ b/python/python/lance_context/api.py @@ -446,6 +446,113 @@ def add( _json_dumps(relationships, "relationships"), ) + def upsert( + self, + role: str, + content: Any, + content_type: str | None = None, + data_type: str | None = None, + embedding: list[float] | None = None, + bot_id: str | None = None, + session_id: str | None = None, + external_id: str | None = None, + metadata: dict[str, Any] | None = None, + relationships: list[dict[str, Any]] | None = None, + expires_at: datetime | str | None = None, + retention_policy: str | None = None, + lifecycle_status: str | None = None, + retired_at: datetime | str | None = None, + retired_reason: str | None = None, + *, + key: str = "external_id", + ) -> dict[str, Any]: + """Insert a record, or replace the visible record with the same external_id.""" + if content_type is not None and data_type is not None: + raise ValueError("Specify only one of content_type or data_type") + if key != "external_id": + raise ValueError("Only key='external_id' is currently supported") + if not external_id: + raise ValueError("upsert requires external_id") + if content_type is None: + content_type = data_type + + payload, resolved_type = _normalize_content(content, content_type) + result = self._inner.upsert( + role, + payload, + resolved_type, + embedding, + bot_id, + session_id, + external_id, + _json_dumps(metadata, "metadata"), + _coerce_timestamp(expires_at, field_name="expires_at"), + retention_policy, + lifecycle_status, + _coerce_timestamp(retired_at, field_name="retired_at"), + retired_reason, + _json_dumps(relationships, "relationships"), + key, + ) + return { + "inserted": bool(result["inserted"]), + "replaced_id": result.get("replaced_id"), + "version": result["version"], + "record": _normalize_record(result["record"]), + } + + def update( + self, + *, + id: str | None = None, + external_id: str | None = None, + bot_id: str | None = None, + session_id: str | None = None, + metadata: dict[str, Any] | None = None, + relationships: list[dict[str, Any]] | None = None, + expires_at: datetime | str | None = None, + retention_policy: str | None = None, + lifecycle_status: str | None = None, + retired_at: datetime | str | None = None, + retired_reason: str | None = None, + ) -> dict[str, Any]: + """Patch mutable fields on a visible record by id or external_id.""" + if (id is None) == (external_id is None): + raise ValueError("Specify exactly one of id or external_id") + if ( + bot_id is None + and session_id is None + and metadata is None + and relationships is None + and expires_at is None + and retention_policy is None + and lifecycle_status is None + and retired_at is None + and retired_reason is None + ): + raise ValueError("update requires at least one patch field") + + result = self._inner.update( + id, + external_id, + bot_id, + session_id, + _json_dumps(metadata, "metadata"), + _json_dumps(relationships, "relationships"), + _coerce_timestamp(expires_at, field_name="expires_at"), + retention_policy, + lifecycle_status, + _coerce_timestamp(retired_at, field_name="retired_at"), + retired_reason, + ) + record = result.get("record") + return { + "updated": bool(result["updated"]), + "replaced_id": result.get("replaced_id"), + "version": result["version"], + "record": _normalize_record(record) if record is not None else None, + } + def add_many(self, records: Iterable[Mapping[str, Any]]) -> None: """Append multiple records in one storage operation. @@ -800,6 +907,82 @@ async def add( ), ) + async def upsert( + self, + role: str, + content: Any, + content_type: str | None = None, + data_type: str | None = None, + embedding: list[float] | None = None, + bot_id: str | None = None, + session_id: str | None = None, + external_id: str | None = None, + metadata: dict[str, Any] | None = None, + relationships: list[dict[str, Any]] | None = None, + expires_at: datetime | str | None = None, + retention_policy: str | None = None, + lifecycle_status: str | None = None, + retired_at: datetime | str | None = None, + retired_reason: str | None = None, + *, + key: str = "external_id", + ) -> dict[str, Any]: + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + None, + lambda: self._sync.upsert( + role, + content, + content_type=content_type, + data_type=data_type, + embedding=embedding, + bot_id=bot_id, + session_id=session_id, + external_id=external_id, + metadata=metadata, + relationships=relationships, + expires_at=expires_at, + retention_policy=retention_policy, + lifecycle_status=lifecycle_status, + retired_at=retired_at, + retired_reason=retired_reason, + key=key, + ), + ) + + async def update( + self, + *, + id: str | None = None, + external_id: str | None = None, + bot_id: str | None = None, + session_id: str | None = None, + metadata: dict[str, Any] | None = None, + relationships: list[dict[str, Any]] | None = None, + expires_at: datetime | str | None = None, + retention_policy: str | None = None, + lifecycle_status: str | None = None, + retired_at: datetime | str | None = None, + retired_reason: str | None = None, + ) -> dict[str, Any]: + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + None, + lambda: self._sync.update( + id=id, + external_id=external_id, + bot_id=bot_id, + session_id=session_id, + metadata=metadata, + relationships=relationships, + expires_at=expires_at, + retention_policy=retention_policy, + lifecycle_status=lifecycle_status, + retired_at=retired_at, + retired_reason=retired_reason, + ), + ) + def snapshot(self, label: str | None = None) -> str: return self._sync.snapshot(label) diff --git a/python/src/lib.rs b/python/src/lib.rs index 771ef8a..1627e58 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -15,7 +15,7 @@ use lance_context_core::serde::CONTENT_TYPE_TEXT; use lance_context_core::{ CompactionConfig, CompactionMetrics, CompactionStats, Context as RustContext, ContextRecord, ContextStore, ContextStoreOptions, DistanceMetric, IdIndexType, LifecycleQueryOptions, - RecordFilters, Relationship, RetrieveResult, SearchResult, LIFECYCLE_ACTIVE, + RecordFilters, RecordPatch, Relationship, RetrieveResult, SearchResult, LIFECYCLE_ACTIVE, }; const DEFAULT_BINARY_CONTENT_TYPE: &str = "application/octet-stream"; @@ -298,6 +298,154 @@ impl Context { Ok(()) } + #[allow(clippy::too_many_arguments)] + #[pyo3(signature = (role, content, data_type = None, embedding = None, bot_id = None, session_id = None, external_id = None, metadata_json = None, expires_at = None, retention_policy = None, lifecycle_status = None, retired_at = None, retired_reason = None, relationships_json = None, key = "external_id"))] + fn upsert( + &mut self, + py: Python<'_>, + role: &str, + content: &Bound<'_, PyAny>, + data_type: Option<&str>, + embedding: Option>, + bot_id: Option, + session_id: Option, + external_id: Option, + metadata_json: Option, + expires_at: Option, + retention_policy: Option, + lifecycle_status: Option, + retired_at: Option, + retired_reason: Option, + relationships_json: Option, + key: &str, + ) -> PyResult { + if key != "external_id" { + return Err(PyRuntimeError::new_err(format!( + "upsert key '{key}' is not supported; use 'external_id'" + ))); + } + if external_id.as_deref().is_none_or(str::is_empty) { + return Err(PyRuntimeError::new_err( + "upsert requires external_id".to_string(), + )); + } + + let lifecycle = LifecycleFields { + expires_at: parse_optional_datetime(expires_at, "expires_at")?, + retention_policy, + lifecycle_status, + retired_at: parse_optional_datetime(retired_at, "retired_at")?, + retired_reason, + supersedes_id: None, + superseded_by_id: None, + }; + let prepared = self.prepare_record( + content, + RecordInput { + role: role.to_string(), + data_type: data_type.map(str::to_string), + embedding, + bot_id, + session_id, + external_id, + metadata_json, + relationships: relationships_from_json(relationships_json)?, + lifecycle, + }, + 1, + )?; + + let result = py.allow_threads(|| { + self.runtime + .block_on(self.store.upsert_by_external_id(prepared.record.clone())) + }); + let result = result.map_err(to_py_err)?; + self.inner.add( + &prepared.role, + &prepared.inner_content, + prepared.data_type.as_deref(), + ); + + let dict = PyDict::new(py); + dict.set_item("inserted", result.inserted)?; + dict.set_item("replaced_id", result.replaced_id)?; + dict.set_item("version", result.version)?; + dict.set_item("record", record_to_py(py, result.record)?)?; + Ok(dict.into_pyobject(py)?.unbind().into()) + } + + #[allow(clippy::too_many_arguments)] + #[pyo3(signature = (id = None, external_id = None, bot_id = None, session_id = None, metadata_json = None, relationships_json = None, expires_at = None, retention_policy = None, lifecycle_status = None, retired_at = None, retired_reason = None))] + fn update( + &mut self, + py: Python<'_>, + id: Option, + external_id: Option, + bot_id: Option, + session_id: Option, + metadata_json: Option, + relationships_json: Option, + expires_at: Option, + retention_policy: Option, + lifecycle_status: Option, + retired_at: Option, + retired_reason: Option, + ) -> PyResult { + let patch = RecordPatch { + bot_id, + session_id, + state_metadata: None, + metadata: metadata_from_json(metadata_json)?, + relationships: relationships_patch_from_json(relationships_json)?, + expires_at: parse_optional_datetime(expires_at, "expires_at")?, + retention_policy, + lifecycle_status, + retired_at: parse_optional_datetime(retired_at, "retired_at")?, + retired_reason, + }; + if patch.is_empty() { + return Err(PyRuntimeError::new_err( + "update requires at least one patch field", + )); + } + + let result = match (id, external_id) { + (Some(id), None) => py.allow_threads(|| { + self.runtime + .block_on(self.store.update_by_id(&id, patch)) + .map_err(to_py_err) + }), + (None, Some(external_id)) => py.allow_threads(|| { + self.runtime + .block_on(self.store.update_by_external_id(&external_id, patch)) + .map_err(to_py_err) + }), + (None, None) => Err(PyRuntimeError::new_err( + "update() requires either id or external_id", + )), + (Some(_), Some(_)) => Err(PyRuntimeError::new_err( + "update() accepts only one of id or external_id", + )), + }?; + + let dict = PyDict::new(py); + match result { + Some(result) => { + dict.set_item("updated", true)?; + dict.set_item("replaced_id", Some(result.replaced_id))?; + dict.set_item("version", result.version)?; + dict.set_item("record", record_to_py(py, result.record)?)?; + } + None => { + dict.set_item("updated", false)?; + dict.set_item("replaced_id", Option::::None)?; + dict.set_item("version", self.store.version())?; + dict.set_item("record", Option::::None)?; + } + } + Ok(dict.into_pyobject(py)?.unbind().into()) + } + #[pyo3(signature = (records))] fn add_many(&mut self, py: Python<'_>, records: &Bound<'_, PyAny>) -> PyResult<()> { let mut prepared = Vec::new(); @@ -729,6 +877,12 @@ fn optional_item<'py>(dict: &Bound<'py, PyDict>, key: &str) -> PyResult) -> PyResult>> { + value + .map(|value| relationships_from_json(Some(value))) + .transpose() +} + fn parse_optional_datetime( value: Option, field_name: &str, diff --git a/python/tests/test_persistence.py b/python/tests/test_persistence.py index badcdc1..57fb3d2 100644 --- a/python/tests/test_persistence.py +++ b/python/tests/test_persistence.py @@ -472,6 +472,91 @@ def test_supersedes_pointer_hides_old_record_by_default(tmp_path: Path) -> None: assert history[1]["supersedes_id"] == old["id"] +def test_upsert_by_external_id_inserts_then_replaces_visible_record( + tmp_path: Path, +) -> None: + uri = tmp_path / "context.lance" + ctx = Context.create(str(uri)) + external_id = "doc-123#chunk-1" + + inserted = ctx.upsert( + "user", + "old value", + embedding=_embedding(0.0), + external_id=external_id, + ) + assert inserted["inserted"] is True + assert inserted["replaced_id"] is None + old_id = inserted["record"]["id"] + + replaced = ctx.upsert( + "user", + "new value", + embedding=_embedding(1.0), + external_id=external_id, + metadata={"revision": 2}, + ) + assert replaced["inserted"] is False + assert replaced["replaced_id"] == old_id + assert replaced["record"]["external_id"] == external_id + assert replaced["record"]["supersedes_id"] == old_id + + assert ctx.get(external_id=external_id)["text"] == "new value" # type: ignore[index] + assert [record["text"] for record in ctx.list()] == ["new value"] + assert [record["text"] for record in ctx.search(_embedding(0.0), limit=10)] == [ + "new value" + ] + + history = ctx.list(include_retired=True) + assert [record["text"] for record in history] == ["old value", "new value"] + + +def test_update_by_external_id_patches_mutable_fields_and_preserves_payload( + tmp_path: Path, +) -> None: + uri = tmp_path / "context.lance" + ctx = Context.create(str(uri)) + external_id = "doc-123#chunk-1" + + ctx.add( + "user", + "stable content", + embedding=_embedding(0.0), + external_id=external_id, + metadata={"revision": 1}, + ) + original = ctx.get(external_id=external_id) + assert original is not None + + updated = ctx.update( + external_id=external_id, + metadata={"revision": 2, "confidence": 0.9}, + relationships=[{"target_id": "doc-123", "relation": "derived_from"}], + ) + + assert updated["updated"] is True + assert updated["replaced_id"] == original["id"] + assert updated["record"]["id"] != original["id"] + assert updated["record"]["external_id"] == external_id + assert updated["record"]["text"] == "stable content" + assert updated["record"]["metadata"] == {"revision": 2, "confidence": 0.9} + assert updated["record"]["relationships"] == [ + {"target_id": "doc-123", "relation": "derived_from", "weight": None} + ] + assert updated["record"]["supersedes_id"] == original["id"] + + visible = ctx.get(external_id=external_id) + assert visible is not None + assert visible["id"] == updated["record"]["id"] + assert [record["text"] for record in ctx.list()] == ["stable content"] + + history = ctx.list(include_retired=True) + assert {record["id"] for record in history} == { + original["id"], + updated["record"]["id"], + } + + def test_image_round_trip(tmp_path: Path) -> None: Image = pytest.importorskip("PIL.Image") uri = tmp_path / "context.lance" diff --git a/python/tests/test_search.py b/python/tests/test_search.py index a527849..9ce7c90 100644 --- a/python/tests/test_search.py +++ b/python/tests/test_search.py @@ -34,6 +34,8 @@ def __init__(self) -> None: self.related_calls: list[tuple[str, str | None, int | None, bool, bool]] = [] self.get_calls: list[tuple[str | None, str | None]] = [] self.delete_calls: list[tuple[str | None, str | None]] = [] + self.upsert_calls: list[dict[str, Any]] = [] + self.update_calls: list[dict[str, Any]] = [] self.lifecycle_add_calls: list[dict[str, Any]] = [] self.relationship_add_calls: list[str | None] = [] self.add_calls: list[ @@ -94,6 +96,141 @@ def add( ) self.relationship_add_calls.append(relationships_json) + def upsert( + self, + role: str, + content: Any, + data_type: str | None, + embedding: list[float] | None, + bot_id: str | None, + session_id: str | None, + external_id: str | None, + metadata_json: str | None, + expires_at: str | None = None, + retention_policy: str | None = None, + lifecycle_status: str | None = None, + retired_at: str | None = None, + retired_reason: str | None = None, + relationships_json: str | None = None, + key: str = "external_id", + ): + self.upsert_calls.append( + { + "role": role, + "content": content, + "data_type": data_type, + "embedding": embedding, + "bot_id": bot_id, + "session_id": session_id, + "external_id": external_id, + "metadata_json": metadata_json, + "expires_at": expires_at, + "retention_policy": retention_policy, + "lifecycle_status": lifecycle_status, + "retired_at": retired_at, + "retired_reason": retired_reason, + "relationships_json": relationships_json, + "key": key, + } + ) + return { + "inserted": False, + "replaced_id": "old-id", + "version": 7, + "record": { + "id": "new-id", + "external_id": external_id, + "run_id": "run-2", + "bot_id": bot_id, + "session_id": session_id, + "role": role, + "content_type": data_type or "text/plain", + "text_payload": content if isinstance(content, str) else None, + "binary_payload": None, + "embedding": embedding, + "created_at": "2024-01-03T12:00:00Z", + "state_metadata": None, + "metadata": json.loads(metadata_json) if metadata_json else None, + "relationships": ( + json.loads(relationships_json) if relationships_json else [] + ), + "expires_at": expires_at, + "retention_policy": retention_policy, + "lifecycle_status": lifecycle_status or "active", + "retired_at": retired_at, + "retired_reason": retired_reason, + "supersedes_id": "old-id", + "superseded_by_id": None, + }, + } + + def update( + self, + id: str | None, + external_id: str | None, + bot_id: str | None, + session_id: str | None, + metadata_json: str | None, + relationships_json: str | None, + expires_at: str | None, + retention_policy: str | None, + lifecycle_status: str | None, + retired_at: str | None, + retired_reason: str | None, + ): + self.update_calls.append( + { + "id": id, + "external_id": external_id, + "bot_id": bot_id, + "session_id": session_id, + "metadata_json": metadata_json, + "relationships_json": relationships_json, + "expires_at": expires_at, + "retention_policy": retention_policy, + "lifecycle_status": lifecycle_status, + "retired_at": retired_at, + "retired_reason": retired_reason, + } + ) + if id == "missing" or external_id == "missing": + return { + "updated": False, + "replaced_id": None, + "version": 7, + "record": None, + } + return { + "updated": True, + "replaced_id": "old-id", + "version": 8, + "record": { + "id": "new-id", + "external_id": external_id, + "run_id": "run-2", + "bot_id": bot_id, + "session_id": session_id, + "role": "user", + "content_type": "text/plain", + "text_payload": "stable content", + "binary_payload": None, + "embedding": None, + "created_at": "2024-01-03T12:00:00Z", + "state_metadata": None, + "metadata": json.loads(metadata_json) if metadata_json else None, + "relationships": ( + json.loads(relationships_json) if relationships_json else [] + ), + "expires_at": expires_at, + "retention_policy": retention_policy, + "lifecycle_status": lifecycle_status or "active", + "retired_at": retired_at, + "retired_reason": retired_reason, + "supersedes_id": "old-id", + "superseded_by_id": None, + }, + } + def get(self, id: str | None, external_id: str | None): self.get_calls.append((id, external_id)) if id == "rec-1" or external_id == "source-1": @@ -924,6 +1061,133 @@ def test_context_add_rejects_naive_lifecycle_datetime(): ctx.add("user", "hello", expires_at=datetime(2026, 7, 1)) +def test_context_upsert_requires_external_id_and_supported_key(): + ctx = Context.__new__(Context) + dummy = DummyInner() + ctx._inner = dummy # type: ignore[attr-defined] + + with pytest.raises(ValueError, match="external_id"): + ctx.upsert("user", "hello") + with pytest.raises(ValueError, match="Only key='external_id'"): + ctx.upsert("user", "hello", external_id="source-1", key="id") + assert dummy.upsert_calls == [] + + +def test_context_upsert_returns_operation_metadata_and_record(): + ctx = Context.__new__(Context) + dummy = DummyInner() + ctx._inner = dummy # type: ignore[attr-defined] + + result = ctx.upsert( + "user", + "new value", + embedding=[0.1, 0.2], + external_id="source-1", + metadata={"revision": 2}, + relationships=[{"target_id": "doc-1", "relation": "updates"}], + expires_at=datetime(2026, 7, 1, tzinfo=timezone.utc), + ) + + assert dummy.upsert_calls == [ + { + "role": "user", + "content": "new value", + "data_type": None, + "embedding": [0.1, 0.2], + "bot_id": None, + "session_id": None, + "external_id": "source-1", + "metadata_json": '{"revision":2}', + "expires_at": "2026-07-01T00:00:00Z", + "retention_policy": None, + "lifecycle_status": None, + "retired_at": None, + "retired_reason": None, + "relationships_json": '[{"relation":"updates","target_id":"doc-1"}]', + "key": "external_id", + } + ] + assert result["inserted"] is False + assert result["replaced_id"] == "old-id" + assert result["version"] == 7 + assert result["record"]["id"] == "new-id" + assert result["record"]["text"] == "new value" + assert result["record"]["metadata"] == {"revision": 2} + assert result["record"]["supersedes_id"] == "old-id" + assert isinstance(result["record"]["created_at"], datetime) + + +def test_context_update_requires_identifier_and_patch(): + ctx = Context.__new__(Context) + dummy = DummyInner() + ctx._inner = dummy # type: ignore[attr-defined] + + with pytest.raises(ValueError, match="exactly one"): + ctx.update(metadata={"revision": 2}) + with pytest.raises(ValueError, match="exactly one"): + ctx.update(id="rec-1", external_id="source-1", metadata={"revision": 2}) + with pytest.raises(ValueError, match="at least one patch field"): + ctx.update(external_id="source-1") + assert dummy.update_calls == [] + + +def test_context_update_returns_operation_metadata_and_record(): + ctx = Context.__new__(Context) + dummy = DummyInner() + ctx._inner = dummy # type: ignore[attr-defined] + + result = ctx.update( + external_id="source-1", + bot_id="bot", + session_id="sess", + metadata={"revision": 2}, + relationships=[{"target_id": "doc-1", "relation": "updates"}], + expires_at=datetime(2026, 7, 1, tzinfo=timezone.utc), + lifecycle_status="active", + ) + + assert dummy.update_calls == [ + { + "id": None, + "external_id": "source-1", + "bot_id": "bot", + "session_id": "sess", + "metadata_json": '{"revision":2}', + "relationships_json": '[{"relation":"updates","target_id":"doc-1"}]', + "expires_at": "2026-07-01T00:00:00Z", + "retention_policy": None, + "lifecycle_status": "active", + "retired_at": None, + "retired_reason": None, + } + ] + assert result["updated"] is True + assert result["replaced_id"] == "old-id" + assert result["version"] == 8 + assert result["record"]["id"] == "new-id" + assert result["record"]["text"] == "stable content" + assert result["record"]["metadata"] == {"revision": 2} + assert result["record"]["relationships"] == [ + {"target_id": "doc-1", "relation": "updates"} + ] + assert result["record"]["supersedes_id"] == "old-id" + + +def test_context_update_missing_record_returns_not_updated(): + ctx = Context.__new__(Context) + dummy = DummyInner() + ctx._inner = dummy # type: ignore[attr-defined] + + result = ctx.update(external_id="missing", metadata={"revision": 2}) + + assert result == { + "updated": False, + "replaced_id": None, + "version": 7, + "record": None, + } + + def test_context_add_many_normalizes_records(): ctx = Context.__new__(Context) dummy = DummyInner()