Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,16 @@ hits = ctx.search(
)
service_context = ctx.related("service://service-a", relation="describes")

# Hybrid retrieval combines lexical recall, vector recall, and existing filters
# over the same context records.
hybrid_hits = ctx.retrieve(
text="service-a runbook owner",
vector=runbook_embedding,
limit=5,
filters={"tenant": "example-org", "scope": "team"},
)
print(hybrid_hits[0]["matched_channels"], hybrid_hits[0]["score"])

from PIL import Image
image = Image.new("RGB", (2, 2), color="teal")
ctx.add("assistant", image)
Expand Down
50 changes: 50 additions & 0 deletions crates/lance-context-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ pub trait ContextStoreApi {
include_relationships: bool,
) -> impl Future<Output = ContextResult<Vec<SearchResultDto>>> + Send;

/// Runs the hybrid retrieval described by `request`.
///
/// The returned future resolves to one [`RetrieveResultDto`] per hit, or to a
/// context error if the implementation cannot serve the request.
fn retrieve(
&self,
request: &RetrieveRequest,
) -> impl Future<Output = ContextResult<Vec<RetrieveResultDto>>> + Send;

fn version(&self) -> u64;

fn checkout(&mut self, version: u64) -> impl Future<Output = ContextResult<()>> + Send;
Expand Down Expand Up @@ -244,6 +249,47 @@ pub struct SearchResponse {
pub results: Vec<SearchResultDto>,
}

// ---------------------------------------------------------------------------
// Hybrid retrieval
// ---------------------------------------------------------------------------

/// Request payload for hybrid retrieval.
///
/// Every field has a serde default, so an empty JSON object deserializes
/// successfully; optional fields that are `None` are left out when serializing.
#[derive(Debug, Serialize, Deserialize)]
pub struct RetrieveRequest {
/// Optional text query. Omitted from the serialized form when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub text: Option<String>,
/// Optional query vector. Omitted from the serialized form when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub vector: Option<Vec<f32>>,
/// Result limit; falls back to `default_search_limit()` when absent from the input.
#[serde(default = "default_search_limit")]
pub limit: usize,
/// Optional filters as raw JSON. Omitted from the serialized form when `None`.
// NOTE(review): the core store parses this with `RecordFilters::from_json_value`,
// which rejects anything that is not a JSON object - confirm other backends agree.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub filters: Option<Value>,
/// Lifecycle flag forwarded to the store query; defaults to `false`.
#[serde(default)]
pub include_expired: bool,
/// Lifecycle flag forwarded to the store query; defaults to `false`.
#[serde(default)]
pub include_retired: bool,
/// Whether returned records keep their relationships; defaults to `false`.
#[serde(default)]
pub include_relationships: bool,
/// Fusion strategy name; falls back to `default_retrieve_fusion()` when absent.
#[serde(default = "default_retrieve_fusion")]
pub fusion: String,
}

/// A single hybrid-retrieval hit in its wire (DTO) form.
#[derive(Debug, Serialize, Deserialize)]
pub struct RetrieveResultDto {
/// The record that matched the request.
pub record: RecordDto,
/// Score assigned to this hit.
// NOTE(review): presumably the fused ranking score - confirm the scale and ordering.
pub score: f32,
/// Vector-channel distance, when available. Omitted from the serialized form when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub vector_distance: Option<f32>,
/// Text-channel score, when available. Omitted from the serialized form when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub text_score: Option<f32>,
/// Names of the channels that matched. Omitted from the serialized form when empty,
/// and defaults to an empty list when absent from the input.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub matched_channels: Vec<String>,
}

/// Response envelope for a retrieve call: the list of hits under `results`.
#[derive(Debug, Serialize, Deserialize)]
pub struct RetrieveResponse {
/// Hits in the order they were produced by the store.
pub results: Vec<RetrieveResultDto>,
}

// ---------------------------------------------------------------------------
// Versioning
// ---------------------------------------------------------------------------
Expand Down Expand Up @@ -320,6 +366,10 @@ fn default_search_limit() -> usize {
10
}

/// Serde default for `RetrieveRequest::fusion`: the `"rrf"` strategy name.
fn default_retrieve_fusion() -> String {
    String::from("rrf")
}

fn serialize_base64_opt<S>(data: &Option<Vec<u8>>, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
Expand Down
23 changes: 23 additions & 0 deletions crates/lance-context-client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,15 @@ impl ContextStoreApi for RemoteContextStore {
Ok(resp.results)
}

/// Forwards the request to the remote service for this store's context and
/// returns the hits from the response envelope.
///
/// Client-side failures are converted with `to_ctx_err`.
async fn retrieve(&self, request: &RetrieveRequest) -> ContextResult<Vec<RetrieveResultDto>> {
    let outcome = self.client.retrieve(&self.context_name, request).await;
    let envelope = outcome.map_err(to_ctx_err)?;
    Ok(envelope.results)
}

fn version(&self) -> u64 {
self.cached_version
}
Expand Down Expand Up @@ -273,6 +282,20 @@ impl ContextClient {
Self::handle_response(resp).await
}

/// Sends `req` as a JSON `POST` to `/contexts/{name}/retrieve`.
///
/// # Errors
///
/// Propagates the error from sending the request, and whatever
/// `handle_response` reports for the received response.
pub async fn retrieve(
    &self,
    name: &str,
    req: &RetrieveRequest,
) -> Result<RetrieveResponse, ClientError> {
    let path = format!("/contexts/{}/retrieve", name);
    let pending = self.http.post(self.url(&path)).json(req);
    let response = pending.send().await?;
    Self::handle_response(response).await
}

pub async fn get_version(&self, name: &str) -> Result<VersionResponse, ClientError> {
let resp = self
.http
Expand Down
51 changes: 48 additions & 3 deletions crates/lance-context-core/src/api_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@ use uuid::Uuid;

use lance_context_api::{
AddRecordRequest, AddRecordsResponse, CompactRequest, CompactResponse, CompactStatsResponse,
ContextError, ContextResult, ContextStoreApi, RecordDto, RelationshipDto, SearchResultDto,
StateMetadataDto,
ContextError, ContextResult, ContextStoreApi, RecordDto, RelationshipDto, RetrieveRequest,
RetrieveResultDto, SearchResultDto, StateMetadataDto,
};

use crate::record::{ContextRecord, Relationship, StateMetadata, LIFECYCLE_ACTIVE};
use crate::record::{
ContextRecord, LifecycleQueryOptions, RecordFilters, Relationship, StateMetadata,
LIFECYCLE_ACTIVE,
};
use crate::store::{CompactionConfig, ContextStore};

impl ContextStoreApi for ContextStore {
Expand Down Expand Up @@ -102,6 +105,48 @@ impl ContextStoreApi for ContextStore {
.collect())
}

/// Runs hybrid retrieval against this store and converts each hit into its DTO.
///
/// # Errors
///
/// Returns `ContextError::InvalidRequest` when `request.fusion` is anything other
/// than `"rrf"`, or when `request.filters` is rejected by
/// `RecordFilters::from_json_value`. Failures of the store query itself are
/// converted with `to_ctx_err`.
async fn retrieve(&self, request: &RetrieveRequest) -> ContextResult<Vec<RetrieveResultDto>> {
    if request.fusion != "rrf" {
        let message = "retrieve fusion currently supports only 'rrf'".to_string();
        return Err(ContextError::InvalidRequest(message));
    }

    // `from_json_value` consumes its argument, so the borrowed filters are cloned.
    let parsed_filters = match &request.filters {
        Some(raw) => Some(
            RecordFilters::from_json_value(raw.clone()).map_err(ContextError::InvalidRequest)?,
        ),
        None => None,
    };
    let lifecycle = LifecycleQueryOptions::new(request.include_expired, request.include_retired);

    let hits = self
        .retrieve_filtered_with_options(
            request.text.as_deref(),
            request.vector.as_deref(),
            Some(request.limit),
            parsed_filters.as_ref(),
            lifecycle,
        )
        .await
        .map_err(to_ctx_err)?;

    // Relationships are dropped from every record unless the caller asked for them.
    let strip_relationships = !request.include_relationships;
    let dtos: Vec<RetrieveResultDto> = hits
        .into_iter()
        .map(|hit| {
            let mut record = hit.record;
            if strip_relationships {
                record.relationships.clear();
            }
            RetrieveResultDto {
                record: record_to_dto(record),
                score: hit.score,
                vector_distance: hit.vector_distance,
                text_score: hit.text_score,
                matched_channels: hit.matched_channels,
            }
        })
        .collect();
    Ok(dtos)
}

fn version(&self) -> u64 {
ContextStore::version(self)
}
Expand Down
2 changes: 1 addition & 1 deletion crates/lance-context-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ mod store;
pub use context::{Context, ContextEntry, Snapshot};
pub use record::{
ContextRecord, LifecycleQueryOptions, MetadataFilter, RecordFilters, Relationship,
SearchResult, StateMetadata, LIFECYCLE_ACTIVE, LIFECYCLE_CONTRADICTED,
RetrieveResult, SearchResult, StateMetadata, LIFECYCLE_ACTIVE, LIFECYCLE_CONTRADICTED,
};
pub use store::{
CompactionConfig, CompactionStats, ContextStore, ContextStoreOptions, IdIndexType,
Expand Down
116 changes: 100 additions & 16 deletions crates/lance-context-core/src/record.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,16 @@ pub struct SearchResult {
pub distance: f32,
}

/// Result returned from hybrid retrieval over context records.
#[derive(Debug, Clone)]
pub struct RetrieveResult {
/// The record that matched.
pub record: ContextRecord,
/// Score assigned to this hit.
// NOTE(review): presumably the fused ranking score - confirm the scale and ordering.
pub score: f32,
/// Vector-channel distance, when one is available for this hit.
pub vector_distance: Option<f32>,
/// Text-channel score, when one is available for this hit.
pub text_score: Option<f32>,
/// Names of the channels that matched this record.
pub matched_channels: Vec<String>,
}

/// Metadata matching operation for filtered retrieval.
#[derive(Debug, Clone, PartialEq)]
pub enum MetadataFilter {
Expand All @@ -149,6 +159,42 @@ pub struct RecordFilters {
}

impl RecordFilters {
/// Builds a `RecordFilters` from a JSON object of filter entries.
///
/// Recognized keys:
/// - `bot_id`, `session_id`, `role`, `content_type`: a string, or `null` for no filter.
/// - `created_at`: an object of bounds, handled by `apply_created_at_filter`.
/// - `created_at_start` / `created_after` / `created_at_gte`: lower timestamp bound.
/// - `created_at_end` / `created_before` / `created_at_lte`: upper timestamp bound.
///
/// Any other key becomes a metadata filter: an object whose only key is
/// `contains` maps to `MetadataFilter::Contains`, everything else to
/// `MetadataFilter::Equals`.
///
/// # Errors
///
/// Returns a message when `value` is not a JSON object or when any entry
/// fails its own validation.
pub fn from_json_value(value: Value) -> Result<Self, String> {
    let entries = match value {
        Value::Object(entries) => entries,
        _ => return Err("filters must be a JSON object".to_string()),
    };

    let mut parsed = RecordFilters::default();
    for (name, raw) in entries {
        match name.as_str() {
            "bot_id" => parsed.bot_id = filter_string(&name, raw)?,
            "session_id" => parsed.session_id = filter_string(&name, raw)?,
            "role" => parsed.role = filter_string(&name, raw)?,
            "content_type" => parsed.content_type = filter_string(&name, raw)?,
            "created_at" => apply_created_at_filter(&mut parsed, raw)?,
            "created_at_start" | "created_after" | "created_at_gte" => {
                let lower = parse_filter_datetime(&name, &raw)?;
                parsed.created_at_start = Some(lower);
            }
            "created_at_end" | "created_before" | "created_at_lte" => {
                let upper = parse_filter_datetime(&name, &raw)?;
                parsed.created_at_end = Some(upper);
            }
            _ => {
                let condition = match raw {
                    // A single-key object is a `contains` filter only if that key
                    // is literally "contains"; otherwise it is compared as-is.
                    Value::Object(mut spec) if spec.len() == 1 => {
                        match spec.remove("contains") {
                            Some(needle) => MetadataFilter::Contains(needle),
                            None => MetadataFilter::Equals(Value::Object(spec)),
                        }
                    }
                    other => MetadataFilter::Equals(other),
                };
                parsed.metadata.insert(name, condition);
            }
        }
    }

    Ok(parsed)
}

#[must_use]
pub fn is_empty(&self) -> bool {
self.bot_id.is_none()
Expand Down Expand Up @@ -218,6 +264,47 @@ impl RecordFilters {
}
}

/// Reads a string-valued filter: `null` means "no filter", a JSON string is
/// taken as the filter value, and any other JSON type is rejected.
///
/// `name` is only used to label the error message.
fn filter_string(name: &str, value: Value) -> Result<Option<String>, String> {
    if value.is_null() {
        return Ok(None);
    }
    if let Value::String(text) = value {
        return Ok(Some(text));
    }
    Err(format!("filter '{name}' must be a string or null"))
}

/// Applies a nested `created_at` filter object to `filters`.
///
/// Supported operators are `gte` / `start` / `after` for the lower bound and
/// `lte` / `end` / `before` for the upper bound; each value is parsed by
/// `parse_filter_datetime`.
///
/// # Errors
///
/// Returns a message when `value` is not a JSON object, when an operator is
/// not one of the supported names, or when a bound fails to parse. Bound
/// errors are labelled `created_at.<operator>` so the caller can tell which
/// nested entry was wrong.
fn apply_created_at_filter(filters: &mut RecordFilters, value: Value) -> Result<(), String> {
    let Value::Object(bounds) = value else {
        return Err("filter 'created_at' must be an object with gte/lte bounds".to_string());
    };

    for (operator, raw) in bounds {
        match operator.as_str() {
            "gte" | "start" | "after" => {
                // Qualify the name: a bare 'gte' in an error message does not say
                // which filter the operator belongs to.
                let label = format!("created_at.{operator}");
                filters.created_at_start = Some(parse_filter_datetime(&label, &raw)?);
            }
            "lte" | "end" | "before" => {
                let label = format!("created_at.{operator}");
                filters.created_at_end = Some(parse_filter_datetime(&label, &raw)?);
            }
            other => {
                return Err(format!("unsupported created_at filter operator '{other}'"));
            }
        }
    }

    Ok(())
}

/// Parses a filter value as an RFC 3339 timestamp and converts it to UTC.
///
/// `name` is only used to label error messages.
///
/// # Errors
///
/// Returns a message naming the filter when `value` is not a JSON string, or
/// when the string is rejected by `DateTime::parse_from_rfc3339`.
fn parse_filter_datetime(name: &str, value: &Value) -> Result<DateTime<Utc>, String> {
    let Some(raw) = value.as_str() else {
        return Err(format!(
            "filter '{name}' must be an ISO-8601 timestamp string"
        ));
    };
    DateTime::parse_from_rfc3339(raw)
        .map(|parsed| parsed.with_timezone(&Utc))
        // The bare chrono error does not say which filter was malformed, so
        // include the filter name and the rejected input alongside it.
        .map_err(|err| format!("filter '{name}' has an invalid timestamp '{raw}': {err}"))
}

fn metadata_contains(value: &Value, expected: &Value) -> bool {
match (value, expected) {
(Value::Array(items), expected) => items.iter().any(|item| item == expected),
Expand Down Expand Up @@ -264,22 +351,19 @@ mod tests {

#[test]
fn filters_match_builtin_fields_timestamps_and_metadata() {
let mut filters = RecordFilters {
bot_id: Some("support-bot".to_string()),
session_id: Some("incident-1".to_string()),
role: Some("assistant".to_string()),
content_type: Some("text/plain".to_string()),
created_at_start: Some(Utc.with_ymd_and_hms(2026, 6, 9, 2, 0, 0).unwrap()),
created_at_end: Some(Utc.with_ymd_and_hms(2026, 6, 9, 4, 0, 0).unwrap()),
metadata: HashMap::new(),
};
filters
.metadata
.insert("scope".to_string(), MetadataFilter::Equals(json!("team")));
filters.metadata.insert(
"tags".to_string(),
MetadataFilter::Contains(json!("runbook")),
);
let mut filters = RecordFilters::from_json_value(json!({
"bot_id": "support-bot",
"session_id": "incident-1",
"role": "assistant",
"content_type": "text/plain",
"created_at": {
"gte": "2026-06-09T02:00:00Z",
"lte": "2026-06-09T04:00:00Z"
},
"scope": "team",
"tags": {"contains": "runbook"}
}))
.unwrap();

assert!(filters.matches(&record()));

Expand Down
Loading
Loading