Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions crates/lance-context-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@ pub trait ContextStoreApi {
records: &[AddRecordRequest],
) -> impl Future<Output = ContextResult<AddRecordsResponse>> + Send;

fn upsert(
&mut self,
request: &UpsertRecordRequest,
) -> impl Future<Output = ContextResult<UpsertRecordResponse>> + Send;

fn update(
&mut self,
request: &UpdateRecordRequest,
) -> impl Future<Output = ContextResult<UpdateRecordResponse>> + Send;

fn get(&self, id: &str) -> impl Future<Output = ContextResult<Option<RecordDto>>> + Send;

fn get_by_external_id(
Expand Down Expand Up @@ -194,6 +204,82 @@ pub struct AddRecordsResponse {
pub count: usize,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct UpsertRecordRequest {
pub record: AddRecordRequest,
#[serde(default = "default_upsert_key")]
pub key: String,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct UpsertRecordResponse {
pub version: u64,
pub inserted: bool,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub replaced_id: Option<String>,
pub record: RecordDto,
}

#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RecordPatchDto {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub bot_id: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub session_id: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub state_metadata: Option<StateMetadataDto>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub metadata: Option<Value>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub relationships: Option<Vec<RelationshipDto>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub expires_at: Option<DateTime<Utc>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub retention_policy: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub lifecycle_status: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub retired_at: Option<DateTime<Utc>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub retired_reason: Option<String>,
}

impl RecordPatchDto {
#[must_use]
pub fn is_empty(&self) -> bool {
self.bot_id.is_none()
&& self.session_id.is_none()
&& self.state_metadata.is_none()
&& self.metadata.is_none()
&& self.relationships.is_none()
&& self.expires_at.is_none()
&& self.retention_policy.is_none()
&& self.lifecycle_status.is_none()
&& self.retired_at.is_none()
&& self.retired_reason.is_none()
}
}

#[derive(Debug, Serialize, Deserialize)]
pub struct UpdateRecordRequest {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub external_id: Option<String>,
#[serde(default)]
pub patch: RecordPatchDto,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct UpdateRecordResponse {
pub version: u64,
pub updated: bool,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub replaced_id: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub record: Option<RecordDto>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecordDto {
pub id: String,
Expand Down Expand Up @@ -396,6 +482,10 @@ fn default_role() -> String {
"user".to_string()
}

fn default_upsert_key() -> String {
"external_id".to_string()
}

fn default_search_limit() -> usize {
10
}
Expand Down
54 changes: 54 additions & 0 deletions crates/lance-context-client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,32 @@ impl ContextStoreApi for RemoteContextStore {
Ok(resp)
}

async fn upsert(
&mut self,
request: &UpsertRecordRequest,
) -> ContextResult<UpsertRecordResponse> {
let resp = self
.client
.upsert_record(&self.context_name, request)
.await
.map_err(to_ctx_err)?;
self.cached_version = resp.version;
Ok(resp)
}

async fn update(
&mut self,
request: &UpdateRecordRequest,
) -> ContextResult<UpdateRecordResponse> {
let resp = self
.client
.update_record(&self.context_name, request)
.await
.map_err(to_ctx_err)?;
self.cached_version = resp.version;
Ok(resp)
}

async fn get(&self, id: &str) -> ContextResult<Option<RecordDto>> {
let resp = self
.client
Expand Down Expand Up @@ -292,6 +318,34 @@ impl ContextClient {
Self::handle_response(resp).await
}

pub async fn upsert_record(
&self,
name: &str,
req: &UpsertRecordRequest,
) -> Result<UpsertRecordResponse, ClientError> {
let resp = self
.http
.put(self.url(&format!("/contexts/{}/records", name)))
.json(req)
.send()
.await?;
Self::handle_response(resp).await
}

pub async fn update_record(
&self,
name: &str,
req: &UpdateRecordRequest,
) -> Result<UpdateRecordResponse, ClientError> {
let resp = self
.http
.patch(self.url(&format!("/contexts/{}/records", name)))
.json(req)
.send()
.await?;
Self::handle_response(resp).await
}

pub async fn get_record(&self, name: &str, id: &str) -> Result<GetRecordResponse, ClientError> {
let resp = self
.http
Expand Down
185 changes: 149 additions & 36 deletions crates/lance-context-core/src/api_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ use uuid::Uuid;

use lance_context_api::{
AddRecordRequest, AddRecordsResponse, CompactRequest, CompactResponse, CompactStatsResponse,
ContextError, ContextResult, ContextStoreApi, DeleteRecordResponse, RecordDto, RelationshipDto,
RetrieveRequest, RetrieveResultDto, SearchResultDto, StateMetadataDto,
ContextError, ContextResult, ContextStoreApi, DeleteRecordResponse, RecordDto, RecordPatchDto,
RelationshipDto, RetrieveRequest, RetrieveResultDto, SearchResultDto, StateMetadataDto,
UpdateRecordRequest, UpdateRecordResponse, UpsertRecordRequest, UpsertRecordResponse,
};

use crate::record::{
ContextRecord, LifecycleQueryOptions, RecordFilters, Relationship, StateMetadata,
ContextRecord, LifecycleQueryOptions, RecordFilters, RecordPatch, Relationship, StateMetadata,
LIFECYCLE_ACTIVE,
};
use crate::store::{CompactionConfig, ContextStore};
Expand All @@ -22,39 +23,7 @@ impl ContextStoreApi for ContextStore {
for r in records {
let id = Uuid::new_v4().to_string();
ids.push(id.clone());
core_records.push(ContextRecord {
id,
external_id: r.external_id.clone(),
run_id: run_id.clone(),
bot_id: r.bot_id.clone(),
session_id: r.session_id.clone(),
created_at: Utc::now(),
role: r.role.clone(),
state_metadata: r.state_metadata.as_ref().map(|sm| StateMetadata {
step: sm.step,
active_plan_id: sm.active_plan_id.clone(),
tokens_used: sm.tokens_used,
custom: sm.custom.clone(),
}),
metadata: r.metadata.clone(),
relationships: r
.relationships
.iter()
.cloned()
.map(dto_to_relationship)
.collect(),
expires_at: r.expires_at,
retention_policy: r.retention_policy.clone(),
lifecycle_status: LIFECYCLE_ACTIVE.to_string(),
retired_at: None,
retired_reason: None,
supersedes_id: r.supersedes_id.clone(),
superseded_by_id: None,
content_type: r.content_type.clone(),
text_payload: r.text_payload.clone(),
binary_payload: r.binary_payload.clone(),
embedding: r.embedding.clone(),
});
core_records.push(record_from_add_request(r, id, run_id.clone()));
}

let count = core_records.len();
Expand All @@ -66,6 +35,88 @@ impl ContextStoreApi for ContextStore {
})
}

async fn upsert(
&mut self,
request: &UpsertRecordRequest,
) -> ContextResult<UpsertRecordResponse> {
if request.key != "external_id" {
return Err(ContextError::InvalidRequest(format!(
"upsert key '{}' is not supported; use 'external_id'",
request.key
)));
}
if request
.record
.external_id
.as_deref()
.is_none_or(str::is_empty)
{
return Err(ContextError::InvalidRequest(
"upsert requires record.external_id".to_string(),
));
}

let record = record_from_add_request(
&request.record,
Uuid::new_v4().to_string(),
Uuid::new_v4().to_string(),
);
let result = ContextStore::upsert_by_external_id(self, record)
.await
.map_err(to_ctx_err)?;
Ok(UpsertRecordResponse {
version: result.version,
inserted: result.inserted,
replaced_id: result.replaced_id,
record: record_to_dto(result.record),
})
}

async fn update(
&mut self,
request: &UpdateRecordRequest,
) -> ContextResult<UpdateRecordResponse> {
if request.patch.is_empty() {
return Err(ContextError::InvalidRequest(
"update requires at least one patch field".to_string(),
));
}

let patch = patch_from_dto(&request.patch);
let result = match (&request.id, &request.external_id) {
(Some(id), None) => ContextStore::update_by_id(self, id, patch).await,
(None, Some(external_id)) => {
ContextStore::update_by_external_id(self, external_id, patch).await
}
(None, None) => {
return Err(ContextError::InvalidRequest(
"update requires either id or external_id".to_string(),
));
}
(Some(_), Some(_)) => {
return Err(ContextError::InvalidRequest(
"update accepts only one of id or external_id".to_string(),
));
}
}
.map_err(to_ctx_err)?;

Ok(match result {
Some(result) => UpdateRecordResponse {
version: result.version,
updated: true,
replaced_id: Some(result.replaced_id),
record: Some(record_to_dto(result.record)),
},
None => UpdateRecordResponse {
version: ContextStore::version(self),
updated: false,
replaced_id: None,
record: None,
},
})
}

async fn get(&self, id: &str) -> ContextResult<Option<RecordDto>> {
let record = ContextStore::get(self, id).await.map_err(to_ctx_err)?;
Ok(record.map(record_to_dto))
Expand Down Expand Up @@ -256,6 +307,68 @@ fn relationship_to_dto(r: Relationship) -> RelationshipDto {
}
}

fn patch_from_dto(patch: &RecordPatchDto) -> RecordPatch {
RecordPatch {
bot_id: patch.bot_id.clone(),
session_id: patch.session_id.clone(),
state_metadata: patch.state_metadata.as_ref().map(|sm| StateMetadata {
step: sm.step,
active_plan_id: sm.active_plan_id.clone(),
tokens_used: sm.tokens_used,
custom: sm.custom.clone(),
}),
metadata: patch.metadata.clone(),
relationships: patch.relationships.as_ref().map(|relationships| {
relationships
.iter()
.cloned()
.map(dto_to_relationship)
.collect()
}),
expires_at: patch.expires_at,
retention_policy: patch.retention_policy.clone(),
lifecycle_status: patch.lifecycle_status.clone(),
retired_at: patch.retired_at,
retired_reason: patch.retired_reason.clone(),
}
}

fn record_from_add_request(r: &AddRecordRequest, id: String, run_id: String) -> ContextRecord {
ContextRecord {
id,
external_id: r.external_id.clone(),
run_id,
bot_id: r.bot_id.clone(),
session_id: r.session_id.clone(),
created_at: Utc::now(),
role: r.role.clone(),
state_metadata: r.state_metadata.as_ref().map(|sm| StateMetadata {
step: sm.step,
active_plan_id: sm.active_plan_id.clone(),
tokens_used: sm.tokens_used,
custom: sm.custom.clone(),
}),
metadata: r.metadata.clone(),
relationships: r
.relationships
.iter()
.cloned()
.map(dto_to_relationship)
.collect(),
expires_at: r.expires_at,
retention_policy: r.retention_policy.clone(),
lifecycle_status: LIFECYCLE_ACTIVE.to_string(),
retired_at: None,
retired_reason: None,
supersedes_id: r.supersedes_id.clone(),
superseded_by_id: None,
content_type: r.content_type.clone(),
text_payload: r.text_payload.clone(),
binary_payload: r.binary_payload.clone(),
embedding: r.embedding.clone(),
}
}

fn record_to_dto(r: ContextRecord) -> RecordDto {
RecordDto {
id: r.id,
Expand Down
Loading
Loading