From e377e63849166451ccc736c0f05e70234d418133 Mon Sep 17 00:00:00 2001 From: dcfocus Date: Tue, 9 Jun 2026 06:38:47 +0000 Subject: [PATCH 1/3] feat: add metadata filters for context retrieval --- Cargo.lock | 2 + README.md | 37 +++++ crates/lance-context-core/Cargo.toml | 1 + crates/lance-context-core/src/lib.rs | 3 +- crates/lance-context-core/src/record.rs | 154 ++++++++++++++++++++ crates/lance-context-core/src/store.rs | 81 ++++++++++- python/Cargo.toml | 1 + python/python/lance_context/api.py | 34 ++++- python/src/lib.rs | 133 ++++++++++++++++- python/tests/test_persistence.py | 94 ++++++++++++ python/tests/test_search.py | 185 ++++++++++++++++++++---- 11 files changed, 675 insertions(+), 50 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 262eae2..5088dbd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5336,6 +5336,7 @@ dependencies = [ "lance-namespace 7.0.0", "lancedb", "serde", + "serde_json", "tempfile", "tokio", "tracing", @@ -5349,6 +5350,7 @@ dependencies = [ "chrono", "lance-context", "pyo3", + "serde_json", "tokio", ] diff --git a/README.md b/README.md index d6261d6..dc34411 100644 --- a/README.md +++ b/README.md @@ -58,9 +58,45 @@ ctx.add( "user", "Where should I travel in spring?", external_id="conversation-2026-03-01#turn-1", + metadata={ + "tenant": "example-org", + "scope": "travel-planning", + "source_uri": "chat://conversation-2026-03-01", + "tags": ["travel", "preference"], + }, ) print(ctx.get(external_id="conversation-2026-03-01#turn-1")) +# Scoped recall and provenance-oriented metadata +runbook_embedding = [0.0] * 1536 +ctx.add( + "assistant", + "The runbook owner is the platform team.", + embedding=runbook_embedding, + bot_id="support-bot", + session_id="incident-123", + metadata={ + "tenant": "example-org", + "scope": "team", + "source_uri": "docs://runbooks/service-a", + "tags": ["runbook", "ownership"], + "confidence": 0.92, + }, +) +records = ctx.list( + filters={ + "bot_id": "support-bot", + "session_id": "incident-123", + "scope": "team", + "tags": {"contains": "runbook"}, + } +) +hits = ctx.search( + runbook_embedding, + limit=10, + filters={"tenant": "example-org", "content_type": "text/plain"}, +) + from PIL import Image image = Image.new("RGB", (2, 2), color="teal") ctx.add("assistant", image) @@ -153,6 +189,7 @@ let record = ContextRecord { tokens_used: None, custom: None, }), + metadata: None, content_type: "text/plain".into(), text_payload: Some("hello world".into()), binary_payload: None, diff --git a/crates/lance-context-core/Cargo.toml b/crates/lance-context-core/Cargo.toml index 158c3c3..096c04b 100644 --- a/crates/lance-context-core/Cargo.toml +++ b/crates/lance-context-core/Cargo.toml @@ -21,6 +21,7 @@ lance-namespace = "7.0.0" lancedb = "0.30.0" lance-graph = "0.5.4" serde = { version = "1", features = ["derive"] } +serde_json = "1" futures = "0.3" tokio = { version = "1", features = ["sync", "time"] } tracing = "0.1" diff --git a/crates/lance-context-core/src/lib.rs b/crates/lance-context-core/src/lib.rs index 08b1edf..fd24f82 100644 --- a/crates/lance-context-core/src/lib.rs +++ b/crates/lance-context-core/src/lib.rs @@ -1,4 +1,5 @@ //! Core types for the lance-context storage layer. +#![recursion_limit = "256"] mod context; mod record; @@ -6,7 +7,7 @@ pub mod serde; mod store; pub use context::{Context, ContextEntry, Snapshot}; -pub use record::{ContextRecord, SearchResult, StateMetadata}; +pub use record::{ContextRecord, MetadataFilter, RecordFilters, SearchResult, StateMetadata}; pub use store::{ CompactionConfig, CompactionStats, ContextStore, ContextStoreOptions, IdIndexType, }; diff --git a/crates/lance-context-core/src/record.rs b/crates/lance-context-core/src/record.rs index 95cf27b..3609e9c 100644 --- a/crates/lance-context-core/src/record.rs +++ b/crates/lance-context-core/src/record.rs @@ -1,4 +1,6 @@ use chrono::{DateTime, Utc}; +use serde_json::Value; +use std::collections::HashMap; /// Structured metadata captured alongside each context entry. #[derive(Debug, Clone, Default)] @@ -20,6 +22,7 @@ pub struct ContextRecord { pub created_at: DateTime, pub role: String, pub state_metadata: Option, + pub metadata: Option, pub content_type: String, pub text_payload: Option, pub binary_payload: Option>, @@ -32,3 +35,154 @@ pub struct SearchResult { pub record: ContextRecord, pub distance: f32, } + +/// Metadata matching operation for filtered retrieval. +#[derive(Debug, Clone, PartialEq)] +pub enum MetadataFilter { + Equals(Value), + Contains(Value), +} + +/// Filters applied to records before list pagination or search ranking. +#[derive(Debug, Clone, Default, PartialEq)] +pub struct RecordFilters { + pub bot_id: Option, + pub session_id: Option, + pub role: Option, + pub content_type: Option, + pub created_at_start: Option>, + pub created_at_end: Option>, + pub metadata: HashMap, +} + +impl RecordFilters { + #[must_use] + pub fn is_empty(&self) -> bool { + self.bot_id.is_none() + && self.session_id.is_none() + && self.role.is_none() + && self.content_type.is_none() + && self.created_at_start.is_none() + && self.created_at_end.is_none() + && self.metadata.is_empty() + } + + #[must_use] + pub fn matches(&self, record: &ContextRecord) -> bool { + if self + .bot_id + .as_deref() + .is_some_and(|value| record.bot_id.as_deref() != Some(value)) + { + return false; + } + if self + .session_id + .as_deref() + .is_some_and(|value| record.session_id.as_deref() != Some(value)) + { + return false; + } + if self + .role + .as_deref() + .is_some_and(|value| record.role != value) + { + return false; + } + if self + .content_type + .as_deref() + .is_some_and(|value| record.content_type != value) + { + return false; + } + if self + .created_at_start + .is_some_and(|start| record.created_at < start) + { + return false; + } + if self + .created_at_end + .is_some_and(|end| record.created_at > end) + { + return false; + } + + self.metadata.iter().all(|(key, filter)| { + let Some(Value::Object(metadata)) = &record.metadata else { + return false; + }; + let Some(value) = metadata.get(key) else { + return false; + }; + match filter { + MetadataFilter::Equals(expected) => value == expected, + MetadataFilter::Contains(expected) => metadata_contains(value, expected), + } + }) + } +} + +fn metadata_contains(value: &Value, expected: &Value) -> bool { + match (value, expected) { + (Value::Array(items), expected) => items.iter().any(|item| item == expected), + (Value::String(value), Value::String(expected)) => value.contains(expected), + _ => false, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use chrono::TimeZone; + use serde_json::json; + + fn record() -> ContextRecord { + ContextRecord { + id: "rec-1".to_string(), + external_id: None, + run_id: "run-1".to_string(), + bot_id: Some("support-bot".to_string()), + session_id: Some("incident-1".to_string()), + created_at: Utc.with_ymd_and_hms(2026, 6, 9, 3, 0, 0).unwrap(), + role: "assistant".to_string(), + state_metadata: None, + metadata: Some(json!({ + "scope": "team", + "tags": ["runbook", "ownership"], + "confidence": 0.92 + })), + content_type: "text/plain".to_string(), + text_payload: Some("hello".to_string()), + binary_payload: None, + embedding: None, + } + } + + #[test] + fn filters_match_builtin_fields_timestamps_and_metadata() { + let mut filters = RecordFilters { + bot_id: Some("support-bot".to_string()), + session_id: Some("incident-1".to_string()), + role: Some("assistant".to_string()), + content_type: Some("text/plain".to_string()), + created_at_start: Some(Utc.with_ymd_and_hms(2026, 6, 9, 2, 0, 0).unwrap()), + created_at_end: Some(Utc.with_ymd_and_hms(2026, 6, 9, 4, 0, 0).unwrap()), + metadata: HashMap::new(), + }; + filters + .metadata + .insert("scope".to_string(), MetadataFilter::Equals(json!("team"))); + filters.metadata.insert( + "tags".to_string(), + MetadataFilter::Contains(json!("runbook")), + ); + + assert!(filters.matches(&record())); + + filters.session_id = Some("other".to_string()); + assert!(!filters.matches(&record())); + } +} diff --git a/crates/lance-context-core/src/store.rs b/crates/lance-context-core/src/store.rs index 5f1f6a1..794a9ff 100644 --- a/crates/lance-context-core/src/store.rs +++ b/crates/lance-context-core/src/store.rs @@ -31,7 +31,7 @@ use tokio::task::JoinHandle; use tracing::{error, info, warn}; use uuid::Uuid; -use crate::record::{ContextRecord, SearchResult, StateMetadata}; +use crate::record::{ContextRecord, RecordFilters, SearchResult, StateMetadata}; /// Embedding length used for the semantic index column. const DEFAULT_EMBEDDING_DIM: i32 = 1536; @@ -328,6 +328,16 @@ impl ContextStore { &self, limit: Option, offset: Option, + ) -> LanceResult> { + self.list_filtered(limit, offset, None).await + } + + /// List records matching filters. + pub async fn list_filtered( + &self, + limit: Option, + offset: Option, + filters: Option<&RecordFilters>, ) -> LanceResult> { let scanner = self.lsm_scanner().await?; let mut stream = scanner.try_into_stream().await?; @@ -336,6 +346,10 @@ impl ContextStore { results.extend(batch_to_records(&batch)?); } + if let Some(filters) = filters.filter(|filters| !filters.is_empty()) { + results.retain(|record| filters.matches(record)); + } + if let Some(offset) = offset { results = results.into_iter().skip(offset).collect(); } @@ -371,6 +385,16 @@ impl ContextStore { &self, query: &[f32], limit: Option, + ) -> LanceResult> { + self.search_filtered(query, limit, None).await + } + + /// Perform a nearest-neighbor search over stored embeddings matching filters. + pub async fn search_filtered( + &self, + query: &[f32], + limit: Option, + filters: Option<&RecordFilters>, ) -> LanceResult> { if query.len() != DEFAULT_EMBEDDING_DIM as usize { return Err(ArrowError::InvalidArgumentError(format!( @@ -387,7 +411,7 @@ impl ContextStore { } let mut results: Vec = self - .list(None, None) + .list_filtered(None, None, filters) .await? .into_iter() .filter_map(|record| { @@ -650,10 +674,14 @@ impl ContextStore { /// Lance V1 blob encoding (out-of-line binary buffers). For `text_payload`, /// this also changes the Arrow type from `LargeUtf8` to `LargeBinary`. pub fn schema(blob_columns: &HashSet) -> Schema { - Self::schema_with_options(blob_columns, true) + Self::schema_with_options(blob_columns, true, true) } - fn schema_with_options(blob_columns: &HashSet, include_external_id: bool) -> Schema { + fn schema_with_options( + blob_columns: &HashSet, + include_external_id: bool, + include_metadata: bool, + ) -> Schema { let mut id_metadata = HashMap::new(); id_metadata.insert( "lance-schema:unenforced-primary-key".to_string(), @@ -707,6 +735,11 @@ impl ContextStore { ), true, ), + ]); + if include_metadata { + fields.push(Field::new("metadata", DataType::LargeUtf8, true)); + } + fields.extend([ Field::new("content_type", DataType::Utf8, false), text_field, binary_field, @@ -781,6 +814,18 @@ impl ContextStore { ) .into()); } + let include_metadata = self + .dataset + .schema() + .field_paths() + .iter() + .any(|path| path == "metadata"); + if !include_metadata && entries.iter().any(|entry| entry.metadata.is_some()) { + return Err(ArrowError::InvalidArgumentError( + "metadata requires a context dataset created with metadata support".to_string(), + ) + .into()); + } let mut id_builder = StringBuilder::new(); let mut external_id_builder = StringBuilder::new(); @@ -789,6 +834,7 @@ impl ContextStore { let mut session_id_builder = StringBuilder::new(); let mut created_at_builder = TimestampMicrosecondBuilder::with_capacity(entries.len()); let mut role_builder = StringDictionaryBuilder::::new(); + let mut metadata_builder = LargeStringBuilder::new(); let mut content_type_builder = StringBuilder::new(); let mut binary_builder = LargeBinaryBuilder::new(); @@ -831,6 +877,10 @@ impl ContextStore { session_id_builder.append_option(entry.session_id.as_deref()); created_at_builder.append_value(entry.created_at.timestamp_micros()); role_builder.append(&entry.role)?; + match &entry.metadata { + Some(metadata) => metadata_builder.append_value(metadata.to_string()), + None => metadata_builder.append_null(), + } content_type_builder.append_value(&entry.content_type); if text_is_blob { @@ -924,6 +974,7 @@ impl ContextStore { let session_id_array: ArrayRef = Arc::new(session_id_builder.finish()); let created_at_array: ArrayRef = Arc::new(created_at_builder.finish()); let role_array: ArrayRef = Arc::new(role_builder.finish()); + let metadata_array: ArrayRef = Arc::new(metadata_builder.finish()); let content_type_array: ArrayRef = Arc::new(content_type_builder.finish()); let text_array: ArrayRef = if text_is_blob { Arc::new(text_binary_builder.unwrap().finish()) @@ -937,6 +988,7 @@ impl ContextStore { let schema = Arc::new(Self::schema_with_options( &self.blob_columns, include_external_id, + include_metadata, )); let mut arrays = vec![id_array]; if include_external_id { @@ -949,6 +1001,11 @@ impl ContextStore { created_at_array, role_array, state_array, + ]); + if include_metadata { + arrays.push(metadata_array); + } + arrays.extend([ content_type_array, text_array, binary_array, @@ -981,6 +1038,7 @@ fn batch_to_records(batch: &RecordBatch) -> LanceResult> { let created_at_array = column_as::(batch, "created_at")?; let role_array = column_as::>(batch, "role")?; let state_array = column_as::(batch, "state_metadata")?; + let metadata_array = column_as_optional::(batch, "metadata"); let content_type_array = column_as::(batch, "content_type")?; let binary_array = column_as::(batch, "binary_payload")?; let embedding_array = column_as::(batch, "embedding")?; @@ -1142,6 +1200,19 @@ fn batch_to_records(batch: &RecordBatch) -> LanceResult> { } }); + let metadata = match metadata_array { + Some(arr) if !arr.is_null(row) => { + Some(serde_json::from_str(arr.value(row)).map_err(|err| { + LanceError::from(ArrowError::InvalidArgumentError(format!( + "invalid metadata JSON for record {}: {}", + id_array.value(row), + err + ))) + })?) + } + _ => None, + }; + results.push(ContextRecord { id: id_array.value(row).to_string(), external_id: external_id_array.and_then(|arr| { @@ -1157,6 +1228,7 @@ fn batch_to_records(batch: &RecordBatch) -> LanceResult> { created_at, role, state_metadata, + metadata, content_type: content_type_array.value(row).to_string(), text_payload, binary_payload, @@ -1251,6 +1323,7 @@ mod tests { tokens_used: Some(10), custom: None, }), + metadata: None, content_type: CONTENT_TYPE_TEXT.to_string(), text_payload: Some(format!("payload-{id}")), binary_payload: None, diff --git a/python/Cargo.toml b/python/Cargo.toml index ea3a524..90ca7f5 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -13,4 +13,5 @@ crate-type = ["cdylib"] chrono = { version = "0.4", default-features = false, features = ["clock"] } lance-context = { path = "../crates/lance-context" } pyo3 = { version = "0.25", features = ["extension-module", "abi3-py39", "py-clone"] } +serde_json = "1" tokio = { version = "1", features = ["rt-multi-thread"] } diff --git a/python/python/lance_context/api.py b/python/python/lance_context/api.py index d431def..0bffd7e 100644 --- a/python/python/lance_context/api.py +++ b/python/python/lance_context/api.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json import warnings from datetime import datetime from io import BytesIO @@ -126,6 +127,7 @@ def _normalize_record(raw: dict[str, Any]) -> dict[str, Any]: "embedding": raw.get("embedding"), "created_at": created_at, "state_metadata": raw.get("state_metadata"), + "metadata": raw.get("metadata"), } @@ -145,6 +147,15 @@ def _normalize_search_hit(raw: dict[str, Any]) -> dict[str, Any]: } +def _json_dumps(value: dict[str, Any] | None, name: str) -> str | None: + if value is None: + return None + try: + return json.dumps(value, sort_keys=True, separators=(",", ":")) + except (TypeError, ValueError) as exc: + raise TypeError(f"{name} must be JSON-serializable") from exc + + def _merge_storage_options( storage_options: dict[str, Any] | None, *, @@ -352,6 +363,7 @@ def add( bot_id: str | None = None, session_id: str | None = None, external_id: str | None = None, + metadata: dict[str, Any] | None = None, ) -> None: if content_type is not None and data_type is not None: raise ValueError("Specify only one of content_type or data_type") @@ -366,6 +378,7 @@ def add( bot_id, session_id, external_id, + _json_dumps(metadata, "metadata"), ) def snapshot(self, label: str | None = None) -> str: @@ -378,25 +391,36 @@ def fork(self, branch_name: str) -> Context: def checkout(self, version_id: int | str) -> None: self._inner.checkout(int(version_id)) - def search(self, query: Any, limit: int | None = None) -> list[dict[str, Any]]: + def search( + self, + query: Any, + limit: int | None = None, + filters: dict[str, Any] | None = None, + ) -> list[dict[str, Any]]: vector = _coerce_vector(query) - results = self._inner.search(vector, limit) + results = self._inner.search(vector, limit, _json_dumps(filters, "filters")) return [_normalize_search_hit(item) for item in results] def list( - self, limit: int | None = None, offset: int | None = None + self, + limit: int | None = None, + offset: int | None = None, + filters: dict[str, Any] | None = None, ) -> list[dict[str, Any]]: """Return stored entries. Args: limit: Maximum number of entries to return. If None, returns all. offset: Number of entries to skip before returning results. + filters: Optional equality filters for built-in fields + (bot_id, session_id, role, content_type), created_at range + filters, or metadata fields. Returns: List of entry dicts with keys: id, run_id, role, content_type, - text, binary, embedding, created_at, state_metadata. + text, binary, embedding, created_at, metadata, state_metadata. """ - results = self._inner.list(limit, offset) + results = self._inner.list(limit, offset, _json_dumps(filters, "filters")) return [_normalize_record(item) for item in results] def get( diff --git a/python/src/lib.rs b/python/src/lib.rs index f40a460..e22636e 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -1,17 +1,20 @@ +#![recursion_limit = "256"] + use std::collections::{HashMap, HashSet}; use std::sync::Arc; -use chrono::{SecondsFormat, Utc}; +use chrono::{DateTime, SecondsFormat, Utc}; use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; -use pyo3::types::{PyBytes, PyDict, PyType}; +use pyo3::types::{PyBytes, PyDict, PyModule, PyType}; use pyo3::IntoPyObject; +use serde_json::Value; use tokio::runtime::Runtime; use lance_context::serde::CONTENT_TYPE_TEXT; use lance_context::{ CompactionConfig, CompactionMetrics, CompactionStats, Context as RustContext, ContextRecord, - ContextStore, ContextStoreOptions, IdIndexType, SearchResult, + ContextStore, ContextStoreOptions, IdIndexType, MetadataFilter, RecordFilters, SearchResult, }; const DEFAULT_BINARY_CONTENT_TYPE: &str = "application/octet-stream"; @@ -107,6 +110,99 @@ fn compaction_config_from_dict<'py>( Ok(config) } +fn metadata_from_json(metadata_json: Option) -> PyResult> { + metadata_json + .map(|value| serde_json::from_str(&value).map_err(to_py_err)) + .transpose() +} + +fn filters_from_json(filters_json: Option) -> PyResult> { + let Some(filters_json) = filters_json else { + return Ok(None); + }; + let value: Value = serde_json::from_str(&filters_json).map_err(to_py_err)?; + let Value::Object(object) = value else { + return Err(PyRuntimeError::new_err("filters must be a JSON object")); + }; + + let mut filters = RecordFilters::default(); + for (key, value) in object { + match key.as_str() { + "bot_id" => filters.bot_id = filter_string(key.as_str(), value)?, + "session_id" => filters.session_id = filter_string(key.as_str(), value)?, + "role" => filters.role = filter_string(key.as_str(), value)?, + "content_type" => filters.content_type = filter_string(key.as_str(), value)?, + "created_at" => apply_created_at_filter(&mut filters, value)?, + "created_at_start" | "created_after" | "created_at_gte" => { + filters.created_at_start = Some(parse_filter_datetime(&key, &value)?); + } + "created_at_end" | "created_before" | "created_at_lte" => { + filters.created_at_end = Some(parse_filter_datetime(&key, &value)?); + } + _ => { + let filter = match value { + Value::Object(mut object) + if object.len() == 1 && object.contains_key("contains") => + { + MetadataFilter::Contains(object.remove("contains").unwrap()) + } + value => MetadataFilter::Equals(value), + }; + filters.metadata.insert(key, filter); + } + } + } + + Ok(Some(filters)) +} + +fn filter_string(name: &str, value: Value) -> PyResult> { + match value { + Value::Null => Ok(None), + Value::String(value) => Ok(Some(value)), + _ => Err(PyRuntimeError::new_err(format!( + "filter '{name}' must be a string or null" + ))), + } +} + +fn apply_created_at_filter(filters: &mut RecordFilters, value: Value) -> PyResult<()> { + let Value::Object(object) = value else { + return Err(PyRuntimeError::new_err( + "filter 'created_at' must be an object with gte/lte bounds", + )); + }; + + for (key, value) in object { + match key.as_str() { + "gte" | "start" | "after" => { + filters.created_at_start = Some(parse_filter_datetime(&key, &value)?); + } + "lte" | "end" | "before" => { + filters.created_at_end = Some(parse_filter_datetime(&key, &value)?); + } + other => { + return Err(PyRuntimeError::new_err(format!( + "unsupported created_at filter operator '{other}'" + ))); + } + } + } + + Ok(()) +} + +fn parse_filter_datetime(name: &str, value: &Value) -> PyResult> { + let Some(value) = value.as_str() else { + return Err(PyRuntimeError::new_err(format!( + "filter '{name}' must be an ISO-8601 timestamp string" + ))); + }; + DateTime::parse_from_rfc3339(value) + .map(|value| value.with_timezone(&Utc)) + .map_err(to_py_err) +} + #[pymethods] impl Context { #[classmethod] @@ -172,7 +268,7 @@ impl Context { } #[allow(clippy::too_many_arguments)] - #[pyo3(signature = (role, content, data_type = None, embedding = None, bot_id = None, session_id = None, external_id = None))] + #[pyo3(signature = (role, content, data_type = None, embedding = None, bot_id = None, session_id = None, external_id = None, metadata_json = None))] fn add( &mut self, py: Python<'_>, @@ -183,6 +279,7 @@ impl Context { bot_id: Option, session_id: Option, external_id: Option, + metadata_json: Option, ) -> PyResult<()> { let (content_type, text_payload, binary_payload, inner_content) = match content.extract::<&[u8]>() { @@ -204,6 +301,7 @@ impl Context { }; let record_id = format!("{}-{}", self.run_id, self.inner.entries() + 1); + let metadata = metadata_from_json(metadata_json)?; let record = ContextRecord { id: record_id, external_id, @@ -213,6 +311,7 @@ impl Context { created_at: Utc::now(), role: role.to_string(), state_metadata: None, + metadata, content_type, text_payload, binary_payload, @@ -249,31 +348,38 @@ impl Context { Ok(()) } - #[pyo3(signature = (query, limit = None))] + #[pyo3(signature = (query, limit = None, filters_json = None))] fn search( &self, py: Python<'_>, query: Vec, limit: Option, + filters_json: Option, ) -> PyResult> { - let hits_res = py.allow_threads(|| self.runtime.block_on(self.store.search(&query, limit))); + let filters = filters_from_json(filters_json)?; + let hits_res = py.allow_threads(|| { + self.runtime + .block_on(self.store.search_filtered(&query, limit, filters.as_ref())) + }); let hits = hits_res.map_err(to_py_err)?; hits.into_iter() .map(|hit| search_hit_to_py(py, hit)) .collect() } - #[pyo3(signature = (limit = None, offset = None))] + #[pyo3(signature = (limit = None, offset = None, filters_json = None))] fn list( &self, py: Python<'_>, limit: Option, offset: Option, + filters_json: Option, ) -> PyResult> { + let filters = filters_from_json(filters_json)?; // Release GIL during data retrieval let records = py.allow_threads(|| { self.runtime - .block_on(self.store.list(limit, offset)) + .block_on(self.store.list_filtered(limit, offset, filters.as_ref())) .map_err(to_py_err) })?; @@ -409,6 +515,7 @@ fn record_to_py(py: Python<'_>, record: ContextRecord) -> PyResult { created_at, role, state_metadata, + metadata, content_type, text_payload, binary_payload, @@ -439,6 +546,11 @@ fn record_to_py(py: Python<'_>, record: ContextRecord) -> PyResult { None => py.None().into_pyobject(py)?.unbind(), }; dict.set_item("state_metadata", state_obj)?; + let metadata_obj: PyObject = match metadata { + Some(metadata) => json_value_to_py(py, &metadata)?, + None => py.None().into_pyobject(py)?.unbind(), + }; + dict.set_item("metadata", metadata_obj)?; dict.set_item("content_type", content_type)?; dict.set_item("text_payload", text_payload)?; match binary_payload { @@ -449,6 +561,11 @@ fn record_to_py(py: Python<'_>, record: ContextRecord) -> PyResult { Ok(dict.into_pyobject(py)?.unbind().into()) } +fn json_value_to_py(py: Python<'_>, value: &Value) -> PyResult { + let json = PyModule::import(py, "json")?; + Ok(json.call_method1("loads", (value.to_string(),))?.unbind()) +} + fn to_py_err(err: E) -> PyErr { PyRuntimeError::new_err(err.to_string()) } diff --git a/python/tests/test_persistence.py b/python/tests/test_persistence.py index e94bda0..c9b4db5 100644 --- a/python/tests/test_persistence.py +++ b/python/tests/test_persistence.py @@ -5,6 +5,7 @@ import sys import time import uuid +from datetime import datetime from io import BytesIO from pathlib import Path from typing import Any @@ -156,6 +157,99 @@ def test_text_round_trip(tmp_path: Path) -> None: assert record["content_type"] == "text/plain" +def test_metadata_and_filters_round_trip(tmp_path: Path) -> None: + uri = tmp_path / "context.lance" + ctx = Context.create(str(uri)) + ctx.add( + "assistant", + "The runbook owner is the platform team.", + bot_id="support-bot", + session_id="incident-1", + metadata={ + "tenant": "example-org", + "scope": "team", + "source_uri": "docs://runbooks/service-a", + "tags": ["runbook", "ownership"], + "confidence": 0.92, + }, + ) + ctx.add( + "user", + "What is the owner?", + bot_id="support-bot", + session_id="incident-2", + metadata={"tenant": "example-org", "scope": "personal"}, + ) + + scoped = ctx.list( + filters={ + "bot_id": "support-bot", + "session_id": "incident-1", + "role": "assistant", + "content_type": "text/plain", + "scope": "team", + "tags": {"contains": "runbook"}, + } + ) + + assert len(scoped) == 1 + assert scoped[0]["text"] == "The runbook owner is the platform team." + assert scoped[0]["metadata"] == { + "tenant": "example-org", + "scope": "team", + "source_uri": "docs://runbooks/service-a", + "tags": ["runbook", "ownership"], + "confidence": 0.92, + } + + created_at = scoped[0]["created_at"] + assert isinstance(created_at, datetime) + timestamp_scoped = ctx.list( + filters={ + "created_at": { + "gte": created_at.isoformat(), + "lte": created_at.isoformat(), + } + } + ) + assert [record["id"] for record in timestamp_scoped] == [scoped[0]["id"]] + + +def test_search_applies_filters_before_limit(tmp_path: Path) -> None: + uri = tmp_path / "context.lance" + ctx = Context.create(str(uri)) + near = [0.0] * 1536 + far = [0.0] * 1536 + far[0] = 10.0 + + ctx.add( + "assistant", + "global nearest", + embedding=near, + bot_id="support-bot", + session_id="other", + metadata={"scope": "personal"}, + ) + ctx.add( + "assistant", + "scoped farther", + embedding=far, + bot_id="support-bot", + session_id="incident-1", + metadata={"scope": "team", "tags": ["runbook"]}, + ) + + hits = ctx.search( + near, + limit=1, + filters={"session_id": "incident-1", "tags": {"contains": "runbook"}}, + ) + + assert len(hits) == 1 + assert hits[0]["text"] == "scoped farther" + assert hits[0]["metadata"] == {"scope": "team", "tags": ["runbook"]} + + def test_image_round_trip(tmp_path: Path) -> None: Image = pytest.importorskip("PIL.Image") uri = tmp_path / "context.lance" diff --git a/python/tests/test_search.py b/python/tests/test_search.py index 51dbe1e..9260cd3 100644 --- a/python/tests/test_search.py +++ b/python/tests/test_search.py @@ -1,3 +1,4 @@ +import json from datetime import datetime from typing import Any @@ -12,8 +13,8 @@ class DummyInner: def __init__(self) -> None: - self.search_calls: list[tuple[list[float], int | None]] = [] - self.list_calls: list[tuple[int | None, int | None]] = [] + self.search_calls: list[tuple[list[float], int | None, str | None]] = [] + self.list_calls: list[tuple[int | None, int | None, str | None]] = [] self.get_calls: list[tuple[str | None, str | None]] = [] self.add_calls: list[ tuple[ @@ -24,6 +25,7 @@ def __init__(self) -> None: str | None, str | None, str | None, + str | None, ] ] = [] @@ -36,19 +38,29 @@ def add( bot_id: str | None, session_id: str | None, external_id: str | None, + metadata_json: str | None, ): self.add_calls.append( - (role, content, data_type, embedding, bot_id, session_id, external_id) + ( + role, + content, + data_type, + embedding, + bot_id, + session_id, + external_id, + metadata_json, + ) ) def get(self, id: str | None, external_id: str | None): self.get_calls.append((id, external_id)) if id == "rec-1" or external_id == "source-1": - return self.list(None, None)[0] + return self.list(None, None, None)[0] return None - def search(self, vector: list[float], limit: int | None): - self.search_calls.append((vector, limit)) + def search(self, vector: list[float], limit: int | None, filters_json: str | None): + self.search_calls.append((vector, limit, filters_json)) return [ { "id": "rec-1", @@ -64,11 +76,12 @@ def search(self, vector: list[float], limit: int | None): "distance": 0.12, "created_at": "2024-01-01T12:00:00Z", "state_metadata": {"step": 1}, + "metadata": {"scope": "team", "tags": ["runbook"]}, } ] - def list(self, limit: int | None, offset: int | None): - self.list_calls.append((limit, offset)) + def list(self, limit: int | None, offset: int | None, filters_json: str | None): + self.list_calls.append((limit, offset, filters_json)) return [ { "id": "rec-1", @@ -83,6 +96,7 @@ def list(self, limit: int | None, offset: int | None): "embedding": [0.1, 0.2], "created_at": "2024-01-01T12:00:00Z", "state_metadata": {"step": 1}, + "metadata": {"scope": "team", "tags": ["runbook"]}, }, { "id": "rec-2", @@ -97,6 +111,7 @@ def list(self, limit: int | None, offset: int | None): "embedding": None, "created_at": "2024-01-02T12:00:00Z", "state_metadata": None, + "metadata": None, }, ] @@ -141,13 +156,26 @@ def test_context_search_formats_results(): hits = ctx.search([0.5, 0.4], limit=3) - assert dummy.search_calls == [([0.5, 0.4], 3)] + assert dummy.search_calls == [([0.5, 0.4], 3, None)] assert hits[0]["id"] == "rec-1" assert hits[0]["text"] == "hello" assert hits[0]["binary"] is None + assert hits[0]["metadata"] == {"scope": "team", "tags": ["runbook"]} assert isinstance(hits[0]["created_at"], datetime) +def test_context_search_forwards_filters(): + ctx = Context.__new__(Context) + dummy = DummyInner() + ctx._inner = dummy # type: ignore[attr-defined] + + ctx.search([0.5, 0.4], filters={"bot_id": "support_bot", "scope": "team"}) + + filters_json = dummy.search_calls[0][2] + assert filters_json is not None + assert json.loads(filters_json) == {"bot_id": "support_bot", "scope": "team"} + + def test_normalize_record_without_distance(): result = _normalize_record( { @@ -175,11 +203,12 @@ def test_context_list_returns_entries(): entries = ctx.list(limit=10, offset=5) - assert dummy.list_calls == [(10, 5)] + assert dummy.list_calls == [(10, 5, None)] assert len(entries) == 2 assert entries[0]["id"] == "rec-1" assert entries[0]["text"] == "hello" assert entries[0]["role"] == "user" + assert entries[0]["metadata"] == {"scope": "team", "tags": ["runbook"]} assert "distance" not in entries[0] assert entries[1]["id"] == "rec-2" assert entries[1]["text"] == "world" @@ -237,7 +266,22 @@ def test_context_list_default_args(): ctx.list() - assert dummy.list_calls == [(None, None)] + assert dummy.list_calls == [(None, None, None)] + + +def test_context_list_forwards_filters(): + ctx = Context.__new__(Context) + dummy = DummyInner() + ctx._inner = dummy # type: ignore[attr-defined] + + ctx.list(filters={"role": "user", "tags": {"contains": "runbook"}}) + + filters_json = dummy.list_calls[0][2] + assert filters_json is not None + assert json.loads(filters_json) == { + "role": "user", + "tags": {"contains": "runbook"}, + } def test_context_add_with_embedding(): @@ -248,9 +292,16 @@ def test_context_add_with_embedding(): embedding = [0.1, 0.2, 0.3] ctx.add("user", "hello", embedding=embedding) - role, content, data_type, passed_embedding, bot_id, session_id, external_id = ( - _only_add_call(dummy) - ) + ( + role, + content, + data_type, + passed_embedding, + bot_id, + session_id, + external_id, + metadata_json, + ) = _only_add_call(dummy) assert role == "user" assert content == "hello" assert data_type is None @@ -258,6 +309,7 @@ def test_context_add_with_embedding(): assert bot_id is None assert session_id is None assert external_id is None + assert metadata_json is None def test_context_add_without_embedding(): @@ -267,15 +319,23 @@ def test_context_add_without_embedding(): ctx.add("assistant", "world") - role, content, data_type, passed_embedding, bot_id, session_id, external_id = ( - _only_add_call(dummy) - ) + ( + role, + content, + data_type, + passed_embedding, + bot_id, + session_id, + external_id, + metadata_json, + ) = _only_add_call(dummy) assert role == "assistant" assert content == "world" assert passed_embedding is None assert bot_id is None assert session_id is None assert external_id is None + assert metadata_json is None def test_context_add_with_content_type_and_embedding(): @@ -286,15 +346,23 @@ def test_context_add_with_content_type_and_embedding(): embedding = [0.5, 0.6] ctx.add("system", "prompt", content_type="text/markdown", embedding=embedding) - role, content, data_type, passed_embedding, bot_id, session_id, external_id = ( - _only_add_call(dummy) - ) + ( + role, + content, + data_type, + passed_embedding, + bot_id, + session_id, + external_id, + metadata_json, + ) = _only_add_call(dummy) assert role == "system" assert data_type == "text/markdown" assert passed_embedding == [0.5, 0.6] assert bot_id is None assert session_id is None assert external_id is None + assert metadata_json is None def test_context_add_with_bot_id(): @@ -304,14 +372,22 @@ def test_context_add_with_bot_id(): ctx.add("user", "hello", bot_id="support_bot") - role, content, data_type, passed_embedding, bot_id, session_id, external_id = ( - _only_add_call(dummy) - ) + ( + role, + content, + data_type, + passed_embedding, + bot_id, + session_id, + external_id, + metadata_json, + ) = _only_add_call(dummy) assert role == "user" assert content == "hello" assert bot_id == "support_bot" assert session_id is None assert external_id is None + assert metadata_json is None def test_context_add_with_session_id(): @@ -321,14 +397,22 @@ def test_context_add_with_session_id(): ctx.add("user", "hello", session_id="user_123") - role, content, data_type, passed_embedding, bot_id, session_id, external_id = ( - _only_add_call(dummy) - ) + ( + role, + content, + data_type, + passed_embedding, + bot_id, + session_id, + external_id, + metadata_json, + ) = _only_add_call(dummy) assert role == "user" assert content == "hello" assert bot_id is None assert session_id == "user_123" assert external_id is None + assert metadata_json is None def test_context_add_with_agent_and_session_id(): @@ -338,13 +422,21 @@ def test_context_add_with_agent_and_session_id(): ctx.add("user", "hello", bot_id="sales_bot", session_id="conv_456") - role, content, data_type, passed_embedding, bot_id, session_id, external_id = ( - _only_add_call(dummy) - ) + ( + role, + content, + data_type, + passed_embedding, + bot_id, + session_id, + external_id, + metadata_json, + ) = _only_add_call(dummy) assert role == "user" assert bot_id == "sales_bot" assert session_id == "conv_456" assert external_id is None + assert metadata_json is None def test_context_add_with_all_options(): @@ -360,16 +452,45 @@ def test_context_add_with_all_options(): bot_id="bot", session_id="sess", external_id="doc-1#chunk-1", + metadata={ + "tenant": "example-org", + "scope": "team", + "tags": ["runbook", "ownership"], + "confidence": 0.92, + }, ) - role, content, data_type, passed_embedding, bot_id, session_id, external_id = ( - _only_add_call(dummy) - ) + ( + role, + content, + data_type, + passed_embedding, + bot_id, + session_id, + external_id, + metadata_json, + ) = _only_add_call(dummy) assert role == "user" assert passed_embedding == [0.1, 0.2] assert bot_id == "bot" assert session_id == "sess" assert external_id == "doc-1#chunk-1" + assert metadata_json is not None + assert json.loads(metadata_json) == { + "tenant": "example-org", + "scope": "team", + "tags": ["runbook", "ownership"], + "confidence": 0.92, + } + + +def test_context_add_rejects_non_json_metadata(): + ctx = Context.__new__(Context) + dummy = DummyInner() + ctx._inner = dummy # type: ignore[attr-defined] + + with pytest.raises(TypeError, match="metadata must be JSON-serializable"): + ctx.add("user", "hello", metadata={"bad": object()}) def test_normalize_record_with_agent_and_session_id(): From 5c4411b5e93cb94ed445bd841a4beae76fe9f0e5 Mon Sep 17 00:00:00 2001 From: dcfocus Date: Tue, 9 Jun 2026 17:02:27 +0000 Subject: [PATCH 2/3] style: format python rust bindings --- python/src/lib.rs | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/python/src/lib.rs b/python/src/lib.rs index 8b9777b..dcb6a42 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -318,9 +318,9 @@ impl Context { let mut prepared = Vec::new(); for (index, item) in records.try_iter()?.enumerate() { let item = item?; - let dict = item.downcast::().map_err(|_| { - PyTypeError::new_err(format!("records[{index}] must be a dict")) - })?; + let dict = item + .downcast::() + .map_err(|_| PyTypeError::new_err(format!("records[{index}] must be a dict")))?; prepared.push(self.prepare_record_from_dict(dict, index)?); } @@ -330,8 +330,7 @@ impl Context { let context_records: Vec = prepared.iter().map(|item| item.record.clone()).collect(); - let add_res = - py.allow_threads(|| self.runtime.block_on(self.store.add(&context_records))); + let add_res = py.allow_threads(|| self.runtime.block_on(self.store.add(&context_records))); add_res.map_err(to_py_err)?; for item in prepared { @@ -487,13 +486,10 @@ impl Context { ) -> PyResult { let role = required_item(dict, "role", index)?.extract::()?; let content = required_item(dict, "content", index)?; - let data_type = - optional_item(dict, "data_type")?.map(|value| value.extract::()); - let embedding = - optional_item(dict, "embedding")?.map(|value| value.extract::>()); + let data_type = optional_item(dict, "data_type")?.map(|value| value.extract::()); + let embedding = optional_item(dict, "embedding")?.map(|value| value.extract::>()); let bot_id = optional_item(dict, "bot_id")?.map(|value| value.extract::()); - let session_id = - optional_item(dict, "session_id")?.map(|value| value.extract::()); + let session_id = optional_item(dict, "session_id")?.map(|value| value.extract::()); let external_id = optional_item(dict, "external_id")?.map(|value| value.extract::()); let metadata_json = @@ -582,10 +578,7 @@ fn required_item<'py>( }) } -fn optional_item<'py>( - dict: &Bound<'py, PyDict>, - key: &str, -) -> PyResult>> { +fn optional_item<'py>(dict: &Bound<'py, PyDict>, key: &str) -> PyResult>> { Ok(dict.get_item(key)?.filter(|value| !value.is_none())) } From 32e32e598f6cf1252667551535e74687e4eb899f Mon Sep 17 00:00:00 2001 From: dcfocus Date: Tue, 9 Jun 2026 17:12:04 +0000 Subject: [PATCH 3/3] refactor: reduce record preparation arguments --- python/src/lib.rs | 60 ++++++++++++++++++++++++++++++----------------- 1 file changed, 39 insertions(+), 21 deletions(-) diff --git a/python/src/lib.rs b/python/src/lib.rs index dcb6a42..98fcb87 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -27,6 +27,16 @@ struct PreparedRecord { data_type: Option, } +struct RecordInput { + role: String, + data_type: Option, + embedding: Option>, + bot_id: Option, + session_id: Option, + external_id: Option, + metadata_json: Option, +} + #[pyfunction] fn version() -> &'static str { env!("CARGO_PKG_VERSION") @@ -289,14 +299,16 @@ impl Context { metadata_json: Option, ) -> PyResult<()> { let prepared = self.prepare_record( - role.to_string(), content, - data_type.map(str::to_string), - embedding, - bot_id, - session_id, - external_id, - metadata_json, + RecordInput { + role: role.to_string(), + data_type: data_type.map(str::to_string), + embedding, + bot_id, + session_id, + external_id, + metadata_json, + }, 1, )?; @@ -496,30 +508,36 @@ impl Context { optional_item(dict, "metadata_json")?.map(|value| value.extract::()); self.prepare_record( - role, &content, - data_type.transpose()?, - embedding.transpose()?, - bot_id.transpose()?, - session_id.transpose()?, - external_id.transpose()?, - metadata_json.transpose()?, + RecordInput { + role, + data_type: data_type.transpose()?, + embedding: embedding.transpose()?, + bot_id: bot_id.transpose()?, + session_id: session_id.transpose()?, + external_id: external_id.transpose()?, + metadata_json: metadata_json.transpose()?, + }, index as u64 + 1, ) } fn prepare_record( &self, - role: String, content: &Bound<'_, PyAny>, - data_type: Option, - embedding: Option>, - bot_id: Option, - session_id: Option, - external_id: Option, - metadata_json: Option, + input: RecordInput, offset: u64, ) -> PyResult { + let RecordInput { + role, + data_type, + embedding, + bot_id, + session_id, + external_id, + metadata_json, + } = input; + let (content_type, text_payload, binary_payload, inner_content) = match content.extract::<&[u8]>() { Ok(bytes) => (