diff --git a/Cargo.lock b/Cargo.lock index 8966e47..1e7718c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1119,6 +1119,58 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31b698c5f9a010f6573133b09e0de5408834d0c82f8d7475a89fc1867a71cd90" +dependencies = [ + "axum-core", + "bytes", + "form_urlencoded", + "futures-util", + "http 1.4.1", + "http-body 1.0.1", + "http-body-util", + "hyper 1.10.1", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "serde_core", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" +dependencies = [ + "bytes", + "futures-core", + "http 1.4.1", + "http-body 1.0.1", + "http-body-util", + "mime", + "pin-project-lite", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "backon" version = "1.6.0" @@ -4539,6 +4591,7 @@ dependencies = [ "http 1.4.1", "http-body 1.0.1", "httparse", + "httpdate", "itoa", "pin-project-lite", "smallvec", @@ -5318,9 +5371,33 @@ dependencies = [ name = "lance-context" version = "0.3.0" dependencies = [ + "lance-context-api", + "lance-context-client", "lance-context-core", ] +[[package]] +name = "lance-context-api" +version = "0.2.4" +dependencies = [ + "base64", + "chrono", + "serde", + "serde_json", + "thiserror 2.0.18", +] + +[[package]] +name = "lance-context-client" +version = "0.2.4" +dependencies = [ + "lance-context-api", + "reqwest 0.12.28", + "serde", + "serde_json", + "thiserror 2.0.18", +] + [[package]] name = "lance-context-core" version = "0.3.0" @@ -5331,6 +5408,7 @@ dependencies = [ "chrono", "futures", "lance 7.0.0", + "lance-context-api", "lance-graph", "lance-index 7.0.0", "lance-namespace 7.0.0", @@ -5348,12 +5426,30 @@ name = "lance-context-python" version = "0.3.0" dependencies = [ "chrono", - "lance-context", + "lance-context-core", "pyo3", "serde_json", "tokio", ] +[[package]] +name = "lance-context-server" +version = "0.2.4" +dependencies = [ + "axum", + "chrono", + "clap", + "lance-context-api", + "lance-context-core", + "serde", + "serde_json", + "tokio", + "tower-http", + "tracing", + "tracing-subscriber", + "uuid", +] + [[package]] name = "lance-core" version = "1.0.4" @@ -6529,6 +6625,12 @@ dependencies = [ "regex-automata", ] +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + [[package]] name = "matrixmultiply" version = "0.3.10" @@ -9109,6 +9211,17 @@ dependencies = [ "zmij", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457" +dependencies = [ + "itoa", + "serde", + "serde_core", +] + [[package]] name = "serde_repr" version = "0.1.20" @@ -10147,6 +10260,7 @@ dependencies = [ "tokio", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -10169,6 +10283,7 @@ dependencies = [ "tower", "tower-layer", "tower-service", + "tracing", "url", ] diff --git a/Cargo.toml b/Cargo.toml index 6c5fe2e..8a816d6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,9 @@ members = [ "crates/lance-context-core", "crates/lance-context", + "crates/lance-context-api", + "crates/lance-context-server", + "crates/lance-context-client", "python", ] resolver = "2" diff --git a/crates/lance-context-api/Cargo.toml b/crates/lance-context-api/Cargo.toml new file mode 100644 index 0000000..e5bc5c6 --- /dev/null +++ b/crates/lance-context-api/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "lance-context-api" +version = "0.2.4" +edition = "2021" +license = "Apache-2.0" +authors = ["Lance Devs "] +repository = "https://github.com/lancedb/lance-context" +description = "Shared request/response types for the lance-context REST API" +keywords = ["context", "lance", "api"] + +[dependencies] +base64 = "0.22" +chrono = { version = "0.4", default-features = false, features = ["clock", "serde"] } +serde = { version = "1", features = ["derive"] } +serde_json = "1" +thiserror = "2" diff --git a/crates/lance-context-api/src/lib.rs b/crates/lance-context-api/src/lib.rs new file mode 100644 index 0000000..7177b52 --- /dev/null +++ b/crates/lance-context-api/src/lib.rs @@ -0,0 +1,330 @@ +use base64::{engine::general_purpose::STANDARD as BASE64, Engine}; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::future::Future; + +// --------------------------------------------------------------------------- +// Unified error +// --------------------------------------------------------------------------- + +#[derive(Debug, thiserror::Error)] +pub enum ContextError { + #[error("{0}")] + NotFound(String), + #[error("{0}")] + AlreadyExists(String), + #[error("{0}")] + InvalidRequest(String), + #[error("{0}")] + Internal(String), + #[error("Compaction already in progress")] + CompactionInProgress, +} + +pub type ContextResult = Result; + +// --------------------------------------------------------------------------- +// Unified trait +// --------------------------------------------------------------------------- + +pub trait ContextStoreApi { + fn add( + &mut self, + records: &[AddRecordRequest], + ) -> impl Future> + Send; + + fn get(&self, id: &str) -> impl Future>> + Send; + + fn list( + &self, + limit: Option, + offset: Option, + ) -> impl Future>> + Send; + + fn search( + &self, + query: &[f32], + limit: Option, + ) -> impl Future>> + Send; + + fn version(&self) -> u64; + + fn checkout(&mut self, version: u64) -> impl Future> + Send; + + fn compact( + &mut self, + options: Option, + ) -> impl Future> + Send; + + fn compaction_stats(&self) -> impl Future> + Send; +} + +// --------------------------------------------------------------------------- +// Context lifecycle +// --------------------------------------------------------------------------- + +#[derive(Debug, Serialize, Deserialize)] +pub struct CreateContextRequest { + pub name: String, + #[serde(default)] + pub storage_options: Option>, + #[serde(default)] + pub id_index_type: Option, + #[serde(default)] + pub blob_columns: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ContextInfo { + pub name: String, + pub uri: String, + pub version: u64, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ListContextsResponse { + pub contexts: Vec, +} + +// --------------------------------------------------------------------------- +// Records +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StateMetadataDto { + #[serde(default, skip_serializing_if = "Option::is_none")] + pub step: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub active_plan_id: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tokens_used: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub custom: Option, +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct AddRecordRequest { + #[serde(default = "default_role")] + pub role: String, + #[serde(default = "default_content_type")] + pub content_type: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub text_payload: Option, + #[serde( + default, + skip_serializing_if = "Option::is_none", + serialize_with = "serialize_base64_opt", + deserialize_with = "deserialize_base64_opt" + )] + pub binary_payload: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub embedding: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub bot_id: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub session_id: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub external_id: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub state_metadata: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub metadata: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub expires_at: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub retention_policy: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub supersedes_id: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct AddRecordsRequest { + pub records: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct AddRecordsResponse { + pub version: u64, + pub ids: Vec, + pub count: usize, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RecordDto { + pub id: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub external_id: Option, + pub run_id: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub bot_id: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub session_id: Option, + pub created_at: DateTime, + pub role: String, + pub content_type: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub text_payload: Option, + #[serde( + default, + skip_serializing_if = "Option::is_none", + serialize_with = "serialize_base64_opt", + deserialize_with = "deserialize_base64_opt" + )] + pub binary_payload: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub embedding: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub state_metadata: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub metadata: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub expires_at: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub retention_policy: Option, + pub lifecycle_status: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub retired_at: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub retired_reason: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub supersedes_id: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub superseded_by_id: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ListRecordsResponse { + pub records: Vec, +} + +// --------------------------------------------------------------------------- +// Single record lookup +// --------------------------------------------------------------------------- + +#[derive(Debug, Serialize, Deserialize)] +pub struct GetRecordResponse { + pub record: Option, +} + +// --------------------------------------------------------------------------- +// Search +// --------------------------------------------------------------------------- + +#[derive(Debug, Serialize, Deserialize)] +pub struct SearchRequest { + pub query: Vec, + #[serde(default = "default_search_limit")] + pub limit: usize, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SearchResultDto { + pub record: RecordDto, + pub distance: f32, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SearchResponse { + pub results: Vec, +} + +// --------------------------------------------------------------------------- +// Versioning +// --------------------------------------------------------------------------- + +#[derive(Debug, Serialize, Deserialize)] +pub struct VersionResponse { + pub version: u64, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct CheckoutRequest { + pub version: u64, +} + +// --------------------------------------------------------------------------- +// Compaction +// --------------------------------------------------------------------------- + +#[derive(Debug, Default, Serialize, Deserialize)] +pub struct CompactRequest { + #[serde(default, skip_serializing_if = "Option::is_none")] + pub target_rows_per_fragment: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub materialize_deletions: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct CompactResponse { + pub fragments_removed: usize, + pub fragments_added: usize, + pub files_removed: usize, + pub files_added: usize, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct CompactStatsResponse { + pub total_fragments: usize, + pub is_compacting: bool, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub last_compaction: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub last_error: Option, + pub total_compactions: u64, +} + +// --------------------------------------------------------------------------- +// Error +// --------------------------------------------------------------------------- + +#[derive(Debug, Serialize, Deserialize)] +pub struct ErrorBody { + pub code: String, + pub message: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ErrorResponse { + pub error: ErrorBody, +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +fn default_content_type() -> String { + "text/plain".to_string() +} + +fn default_role() -> String { + "user".to_string() +} + +fn default_search_limit() -> usize { + 10 +} + +fn serialize_base64_opt(data: &Option>, serializer: S) -> Result +where + S: serde::Serializer, +{ + match data { + Some(bytes) => serializer.serialize_some(&BASE64.encode(bytes)), + None => serializer.serialize_none(), + } +} + +fn deserialize_base64_opt<'de, D>(deserializer: D) -> Result>, D::Error> +where + D: serde::Deserializer<'de>, +{ + let opt: Option = Option::deserialize(deserializer)?; + match opt { + Some(s) => BASE64 + .decode(&s) + .map(Some) + .map_err(serde::de::Error::custom), + None => Ok(None), + } +} diff --git a/crates/lance-context-client/Cargo.toml b/crates/lance-context-client/Cargo.toml new file mode 100644 index 0000000..f23a827 --- /dev/null +++ b/crates/lance-context-client/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "lance-context-client" +version = "0.2.4" +edition = "2021" +license = "Apache-2.0" +authors = ["Lance Devs "] +repository = "https://github.com/lancedb/lance-context" +description = "Rust client for the lance-context REST API" +keywords = ["context", "lance", "client", "api"] + +[dependencies] +lance-context-api = { path = "../lance-context-api" } +reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] } +serde = { version = "1", features = ["derive"] } +serde_json = "1" +thiserror = "2" diff --git a/crates/lance-context-client/src/error.rs b/crates/lance-context-client/src/error.rs new file mode 100644 index 0000000..0ea36b1 --- /dev/null +++ b/crates/lance-context-client/src/error.rs @@ -0,0 +1,12 @@ +#[derive(Debug, thiserror::Error)] +pub enum ClientError { + #[error("HTTP error: {0}")] + Http(#[from] reqwest::Error), + + #[error("API error ({status}): [{code}] {message}")] + Api { + status: u16, + code: String, + message: String, + }, +} diff --git a/crates/lance-context-client/src/lib.rs b/crates/lance-context-client/src/lib.rs new file mode 100644 index 0000000..b4c57cb --- /dev/null +++ b/crates/lance-context-client/src/lib.rs @@ -0,0 +1,345 @@ +use lance_context_api::*; +use reqwest::Client; + +mod error; +pub use error::ClientError; + +pub struct ContextClient { + base_url: String, + http: Client, +} + +pub struct RemoteContextStore { + client: ContextClient, + context_name: String, + cached_version: u64, +} + +impl RemoteContextStore { + pub async fn connect(base_url: &str, context_name: &str) -> Result { + let client = ContextClient::new(base_url); + let info = client.get_context(context_name).await?; + Ok(Self { + client, + context_name: context_name.to_string(), + cached_version: info.version, + }) + } + + pub async fn connect_or_create( + base_url: &str, + req: &CreateContextRequest, + ) -> Result { + let client = ContextClient::new(base_url); + let info = match client.get_context(&req.name).await { + Ok(info) => info, + Err(ClientError::Api { status: 404, .. }) => client.create_context(req).await?, + Err(e) => return Err(e), + }; + Ok(Self { + client, + context_name: req.name.clone(), + cached_version: info.version, + }) + } +} + +impl ContextStoreApi for RemoteContextStore { + async fn add(&mut self, records: &[AddRecordRequest]) -> ContextResult { + let req = AddRecordsRequest { + records: records.to_vec(), + }; + let resp = self + .client + .add_records(&self.context_name, &req) + .await + .map_err(to_ctx_err)?; + self.cached_version = resp.version; + Ok(resp) + } + + async fn get(&self, id: &str) -> ContextResult> { + let resp = self + .client + .get_record(&self.context_name, id) + .await + .map_err(to_ctx_err)?; + Ok(resp.record) + } + + async fn list( + &self, + limit: Option, + offset: Option, + ) -> ContextResult> { + let resp = self + .client + .list_records(&self.context_name, limit, offset) + .await + .map_err(to_ctx_err)?; + Ok(resp.records) + } + + async fn search( + &self, + query: &[f32], + limit: Option, + ) -> ContextResult> { + let req = SearchRequest { + query: query.to_vec(), + limit: limit.unwrap_or(10), + }; + let resp = self + .client + .search(&self.context_name, &req) + .await + .map_err(to_ctx_err)?; + Ok(resp.results) + } + + fn version(&self) -> u64 { + self.cached_version + } + + async fn checkout(&mut self, version: u64) -> ContextResult<()> { + let req = CheckoutRequest { version }; + let resp = self + .client + .checkout(&self.context_name, &req) + .await + .map_err(to_ctx_err)?; + self.cached_version = resp.version; + Ok(()) + } + + async fn compact(&mut self, options: Option) -> ContextResult { + let req = options.unwrap_or_default(); + let resp = self + .client + .compact(&self.context_name, &req) + .await + .map_err(to_ctx_err)?; + Ok(resp) + } + + async fn compaction_stats(&self) -> ContextResult { + self.client + .compact_stats(&self.context_name) + .await + .map_err(to_ctx_err) + } +} + +fn to_ctx_err(err: ClientError) -> ContextError { + match err { + ClientError::Api { + status: 404, + message, + .. + } => ContextError::NotFound(message), + ClientError::Api { + status: 409, + code, + message, + } => { + if code == "COMPACTION_IN_PROGRESS" { + ContextError::CompactionInProgress + } else { + ContextError::AlreadyExists(message) + } + } + ClientError::Api { + status: 400, + message, + .. + } => ContextError::InvalidRequest(message), + ClientError::Api { message, .. } => ContextError::Internal(message), + ClientError::Http(e) => ContextError::Internal(e.to_string()), + } +} + +// --- Low-level client (still available for context lifecycle management) --- + +impl ContextClient { + pub fn new(base_url: &str) -> Self { + Self { + base_url: base_url.trim_end_matches('/').to_string(), + http: Client::new(), + } + } + + fn url(&self, path: &str) -> String { + format!("{}/api/v1{}", self.base_url, path) + } + + pub async fn create_context( + &self, + req: &CreateContextRequest, + ) -> Result { + let resp = self + .http + .post(self.url("/contexts")) + .json(req) + .send() + .await?; + Self::handle_response(resp).await + } + + pub async fn list_contexts(&self) -> Result { + let resp = self.http.get(self.url("/contexts")).send().await?; + Self::handle_response(resp).await + } + + pub async fn get_context(&self, name: &str) -> Result { + let resp = self + .http + .get(self.url(&format!("/contexts/{}", name))) + .send() + .await?; + Self::handle_response(resp).await + } + + pub async fn delete_context(&self, name: &str) -> Result<(), ClientError> { + let resp = self + .http + .delete(self.url(&format!("/contexts/{}", name))) + .send() + .await?; + if resp.status().is_success() { + Ok(()) + } else { + Err(Self::extract_error(resp).await) + } + } + + pub async fn add_records( + &self, + name: &str, + req: &AddRecordsRequest, + ) -> Result { + let resp = self + .http + .post(self.url(&format!("/contexts/{}/records", name))) + .json(req) + .send() + .await?; + Self::handle_response(resp).await + } + + pub async fn get_record(&self, name: &str, id: &str) -> Result { + let resp = self + .http + .get(self.url(&format!("/contexts/{}/records/{}", name, id))) + .send() + .await?; + Self::handle_response(resp).await + } + + pub async fn list_records( + &self, + name: &str, + limit: Option, + offset: Option, + ) -> Result { + let mut url = self.url(&format!("/contexts/{}/records", name)); + let mut params = Vec::new(); + if let Some(l) = limit { + params.push(format!("limit={}", l)); + } + if let Some(o) = offset { + params.push(format!("offset={}", o)); + } + if !params.is_empty() { + url = format!("{}?{}", url, params.join("&")); + } + + let resp = self.http.get(&url).send().await?; + Self::handle_response(resp).await + } + + pub async fn search( + &self, + name: &str, + req: &SearchRequest, + ) -> Result { + let resp = self + .http + .post(self.url(&format!("/contexts/{}/search", name))) + .json(req) + .send() + .await?; + Self::handle_response(resp).await + } + + pub async fn get_version(&self, name: &str) -> Result { + let resp = self + .http + .get(self.url(&format!("/contexts/{}/version", name))) + .send() + .await?; + Self::handle_response(resp).await + } + + pub async fn checkout( + &self, + name: &str, + req: &CheckoutRequest, + ) -> Result { + let resp = self + .http + .post(self.url(&format!("/contexts/{}/checkout", name))) + .json(req) + .send() + .await?; + Self::handle_response(resp).await + } + + pub async fn compact( + &self, + name: &str, + req: &CompactRequest, + ) -> Result { + let resp = self + .http + .post(self.url(&format!("/contexts/{}/compact", name))) + .json(req) + .send() + .await?; + Self::handle_response(resp).await + } + + pub async fn compact_stats(&self, name: &str) -> Result { + let resp = self + .http + .get(self.url(&format!("/contexts/{}/compact/stats", name))) + .send() + .await?; + Self::handle_response(resp).await + } + + async fn handle_response( + resp: reqwest::Response, + ) -> Result { + if resp.status().is_success() { + Ok(resp.json::().await?) + } else { + Err(Self::extract_error(resp).await) + } + } + + async fn extract_error(resp: reqwest::Response) -> ClientError { + let status = resp.status().as_u16(); + match resp.json::().await { + Ok(err_resp) => ClientError::Api { + status, + code: err_resp.error.code, + message: err_resp.error.message, + }, + Err(_) => ClientError::Api { + status, + code: "UNKNOWN".to_string(), + message: "Failed to parse error response".to_string(), + }, + } + } +} diff --git a/crates/lance-context-core/Cargo.toml b/crates/lance-context-core/Cargo.toml index f2ba763..e0de20a 100644 --- a/crates/lance-context-core/Cargo.toml +++ b/crates/lance-context-core/Cargo.toml @@ -16,6 +16,7 @@ arrow-ipc = "58" arrow-schema = "58" chrono = { version = "0.4", default-features = false, features = ["clock"] } lance = "7.0.0" +lance-context-api = { path = "../lance-context-api" } lance-index = "7.0.0" lance-namespace = "7.0.0" lancedb = "0.30.0" diff --git a/crates/lance-context-core/src/api_impl.rs b/crates/lance-context-core/src/api_impl.rs new file mode 100644 index 0000000..1e59f2f --- /dev/null +++ b/crates/lance-context-core/src/api_impl.rs @@ -0,0 +1,180 @@ +use chrono::Utc; +use uuid::Uuid; + +use lance_context_api::{ + AddRecordRequest, AddRecordsResponse, CompactRequest, CompactResponse, CompactStatsResponse, + ContextError, ContextResult, ContextStoreApi, RecordDto, SearchResultDto, StateMetadataDto, +}; + +use crate::record::{ContextRecord, StateMetadata, LIFECYCLE_ACTIVE}; +use crate::store::{CompactionConfig, ContextStore}; + +impl ContextStoreApi for ContextStore { + async fn add(&mut self, records: &[AddRecordRequest]) -> ContextResult { + let run_id = Uuid::new_v4().to_string(); + let mut ids = Vec::with_capacity(records.len()); + let mut core_records = Vec::with_capacity(records.len()); + + for r in records { + let id = Uuid::new_v4().to_string(); + ids.push(id.clone()); + core_records.push(ContextRecord { + id, + external_id: r.external_id.clone(), + run_id: run_id.clone(), + bot_id: r.bot_id.clone(), + session_id: r.session_id.clone(), + created_at: Utc::now(), + role: r.role.clone(), + state_metadata: r.state_metadata.as_ref().map(|sm| StateMetadata { + step: sm.step, + active_plan_id: sm.active_plan_id.clone(), + tokens_used: sm.tokens_used, + custom: sm.custom.clone(), + }), + metadata: r.metadata.clone(), + expires_at: r.expires_at, + retention_policy: r.retention_policy.clone(), + lifecycle_status: LIFECYCLE_ACTIVE.to_string(), + retired_at: None, + retired_reason: None, + supersedes_id: r.supersedes_id.clone(), + superseded_by_id: None, + content_type: r.content_type.clone(), + text_payload: r.text_payload.clone(), + binary_payload: r.binary_payload.clone(), + embedding: r.embedding.clone(), + }); + } + + let count = core_records.len(); + let version = self.add(&core_records).await.map_err(to_ctx_err)?; + Ok(AddRecordsResponse { + version, + ids, + count, + }) + } + + async fn get(&self, id: &str) -> ContextResult> { + let record = ContextStore::get(self, id).await.map_err(to_ctx_err)?; + Ok(record.map(record_to_dto)) + } + + async fn list( + &self, + limit: Option, + offset: Option, + ) -> ContextResult> { + let records = ContextStore::list(self, limit, offset) + .await + .map_err(to_ctx_err)?; + Ok(records.into_iter().map(record_to_dto).collect()) + } + + async fn search( + &self, + query: &[f32], + limit: Option, + ) -> ContextResult> { + let results = ContextStore::search(self, query, limit) + .await + .map_err(to_ctx_err)?; + Ok(results + .into_iter() + .map(|sr| SearchResultDto { + record: record_to_dto(sr.record), + distance: sr.distance, + }) + .collect()) + } + + fn version(&self) -> u64 { + ContextStore::version(self) + } + + async fn checkout(&mut self, version: u64) -> ContextResult<()> { + ContextStore::checkout(self, version) + .await + .map_err(to_ctx_err) + } + + async fn compact(&mut self, options: Option) -> ContextResult { + let config = options.map(|req| { + let mut c = CompactionConfig::default(); + if let Some(v) = req.target_rows_per_fragment { + c.target_rows_per_fragment = v; + } + if let Some(v) = req.materialize_deletions { + c.materialize_deletions = v; + } + c + }); + + let metrics = ContextStore::compact(self, config) + .await + .map_err(to_ctx_err)?; + Ok(CompactResponse { + fragments_removed: metrics.fragments_removed, + fragments_added: metrics.fragments_added, + files_removed: metrics.files_removed, + files_added: metrics.files_added, + }) + } + + async fn compaction_stats(&self) -> ContextResult { + let stats = ContextStore::compaction_stats(self) + .await + .map_err(to_ctx_err)?; + Ok(CompactStatsResponse { + total_fragments: stats.total_fragments, + is_compacting: stats.is_compacting, + last_compaction: stats.last_compaction, + last_error: stats.last_error, + total_compactions: stats.total_compactions, + }) + } +} + +fn record_to_dto(r: ContextRecord) -> RecordDto { + RecordDto { + id: r.id, + external_id: r.external_id, + run_id: r.run_id, + bot_id: r.bot_id, + session_id: r.session_id, + created_at: r.created_at, + role: r.role, + content_type: r.content_type, + text_payload: r.text_payload, + binary_payload: r.binary_payload, + embedding: r.embedding, + state_metadata: r.state_metadata.map(|sm| StateMetadataDto { + step: sm.step, + active_plan_id: sm.active_plan_id, + tokens_used: sm.tokens_used, + custom: sm.custom, + }), + metadata: r.metadata, + expires_at: r.expires_at, + retention_policy: r.retention_policy, + lifecycle_status: r.lifecycle_status, + retired_at: r.retired_at, + retired_reason: r.retired_reason, + supersedes_id: r.supersedes_id, + superseded_by_id: r.superseded_by_id, + } +} + +fn to_ctx_err(err: lance::Error) -> ContextError { + let msg = err.to_string(); + if msg.contains("already in progress") { + ContextError::CompactionInProgress + } else if msg.contains("not found") || msg.contains("DatasetNotFound") { + ContextError::NotFound(msg) + } else if msg.contains("Invalid") { + ContextError::InvalidRequest(msg) + } else { + ContextError::Internal(msg) + } +} diff --git a/crates/lance-context-core/src/lib.rs b/crates/lance-context-core/src/lib.rs index cf3b98f..6f9d628 100644 --- a/crates/lance-context-core/src/lib.rs +++ b/crates/lance-context-core/src/lib.rs @@ -1,6 +1,7 @@ //! Core types for the lance-context storage layer. #![recursion_limit = "256"] +mod api_impl; mod context; mod record; pub mod serde; diff --git a/crates/lance-context-core/src/store.rs b/crates/lance-context-core/src/store.rs index 9d7b973..88c1630 100644 --- a/crates/lance-context-core/src/store.rs +++ b/crates/lance-context-core/src/store.rs @@ -391,6 +391,21 @@ impl ContextStore { Ok(()) } + /// Retrieve a single record by its unique ID. + pub async fn get(&self, id: &str) -> LanceResult> { + let escaped_id = id.replace('\'', "''"); + let mut scanner = self.dataset.scan(); + scanner.filter(&format!("id = '{}'", escaped_id))?; + scanner.limit(Some(1), None)?; + + let mut stream = scanner.try_into_stream().await?; + if let Some(batch) = stream.try_next().await? { + let records = batch_to_records(&batch)?; + return Ok(records.into_iter().next()); + } + Ok(None) + } + /// List all records in the dataset. pub async fn list( &self, diff --git a/crates/lance-context-server/Cargo.toml b/crates/lance-context-server/Cargo.toml new file mode 100644 index 0000000..d54c7f5 --- /dev/null +++ b/crates/lance-context-server/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "lance-context-server" +version = "0.2.4" +edition = "2021" +license = "Apache-2.0" +authors = ["Lance Devs "] +repository = "https://github.com/lancedb/lance-context" +description = "REST API server for lance-context" +keywords = ["context", "lance", "server", "api"] + +[[bin]] +name = "lance-context-server" +path = "src/main.rs" + +[dependencies] +lance-context-core = { path = "../lance-context-core" } +lance-context-api = { path = "../lance-context-api" } +axum = { version = "0.8", features = ["json"] } +chrono = { version = "0.4", default-features = false, features = ["clock"] } +clap = { version = "4", features = ["derive"] } +serde = { version = "1", features = ["derive"] } +serde_json = "1" +tokio = { version = "1", features = ["rt-multi-thread", "macros", "signal"] } +tower-http = { version = "0.6", features = ["cors", "trace"] } +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } +uuid = { version = "1", features = ["v4"] } diff --git a/crates/lance-context-server/src/config.rs b/crates/lance-context-server/src/config.rs new file mode 100644 index 0000000..e645a5d --- /dev/null +++ b/crates/lance-context-server/src/config.rs @@ -0,0 +1,15 @@ +use clap::Parser; + +#[derive(Debug, Clone, Parser)] +#[command(name = "lance-context-server")] +#[command(about = "REST API server for lance-context")] +pub struct ServerConfig { + #[arg(long, default_value = "0.0.0.0")] + pub host: String, + + #[arg(long, default_value = "3000")] + pub port: u16, + + #[arg(long, default_value = "./lance-data")] + pub data_dir: String, +} diff --git a/crates/lance-context-server/src/error.rs b/crates/lance-context-server/src/error.rs new file mode 100644 index 0000000..eac91b0 --- /dev/null +++ b/crates/lance-context-server/src/error.rs @@ -0,0 +1,53 @@ +use axum::http::StatusCode; +use axum::response::{IntoResponse, Response}; +use axum::Json; +use lance_context_api::{ErrorBody, ErrorResponse}; + +#[derive(Debug)] +pub enum AppError { + NotFound(String), + AlreadyExists(String), + InvalidRequest(String), + Internal(String), + CompactionInProgress, +} + +impl AppError { + pub fn from_lance(err: impl std::fmt::Display) -> Self { + let msg = err.to_string(); + if msg.contains("already in progress") { + AppError::CompactionInProgress + } else if msg.contains("not found") || msg.contains("DatasetNotFound") { + AppError::NotFound(msg) + } else if msg.contains("Invalid") { + AppError::InvalidRequest(msg) + } else { + AppError::Internal(msg) + } + } +} + +impl IntoResponse for AppError { + fn into_response(self) -> Response { + let (status, code, message) = match self { + AppError::NotFound(msg) => (StatusCode::NOT_FOUND, "NOT_FOUND", msg), + AppError::AlreadyExists(msg) => (StatusCode::CONFLICT, "ALREADY_EXISTS", msg), + AppError::InvalidRequest(msg) => (StatusCode::BAD_REQUEST, "INVALID_REQUEST", msg), + AppError::Internal(msg) => (StatusCode::INTERNAL_SERVER_ERROR, "INTERNAL", msg), + AppError::CompactionInProgress => ( + StatusCode::CONFLICT, + "COMPACTION_IN_PROGRESS", + "Compaction already in progress".to_string(), + ), + }; + + let body = ErrorResponse { + error: ErrorBody { + code: code.to_string(), + message, + }, + }; + + (status, Json(body)).into_response() + } +} diff --git a/crates/lance-context-server/src/main.rs b/crates/lance-context-server/src/main.rs new file mode 100644 index 0000000..286a26e --- /dev/null +++ b/crates/lance-context-server/src/main.rs @@ -0,0 +1,56 @@ +mod config; +mod error; +mod routes; +mod state; + +use std::sync::Arc; + +use clap::Parser; +use tokio::net::TcpListener; +use tower_http::cors::CorsLayer; +use tower_http::trace::TraceLayer; +use tracing_subscriber::EnvFilter; + +use crate::config::ServerConfig; +use crate::state::AppState; + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env().add_directive("info".parse().unwrap())) + .init(); + + let config = ServerConfig::parse(); + let addr = format!("{}:{}", config.host, config.port); + + if let Err(e) = std::fs::create_dir_all(&config.data_dir) { + tracing::error!( + "Failed to create data directory '{}': {}", + config.data_dir, + e + ); + std::process::exit(1); + } + + let state = Arc::new(AppState::new(config)); + + let app = routes::router() + .with_state(state) + .layer(TraceLayer::new_for_http()) + .layer(CorsLayer::permissive()); + + tracing::info!("Starting lance-context-server on {}", addr); + + let listener = TcpListener::bind(&addr).await.unwrap(); + axum::serve(listener, app) + .with_graceful_shutdown(shutdown_signal()) + .await + .unwrap(); +} + +async fn shutdown_signal() { + tokio::signal::ctrl_c() + .await + .expect("failed to install Ctrl+C handler"); + tracing::info!("Shutting down"); +} diff --git a/crates/lance-context-server/src/routes/compact.rs b/crates/lance-context-server/src/routes/compact.rs new file mode 100644 index 0000000..cfd7588 --- /dev/null +++ b/crates/lance-context-server/src/routes/compact.rs @@ -0,0 +1,71 @@ +use std::sync::Arc; + +use axum::extract::{Path, State}; +use axum::Json; +use lance_context_api::{CompactRequest, CompactResponse, CompactStatsResponse}; +use lance_context_core::CompactionConfig; + +use crate::error::AppError; +use crate::state::AppState; + +pub async fn compact( + State(state): State>, + Path(name): Path, + Json(req): Json, +) -> Result, AppError> { + let stores = state.stores.read().await; + let store_lock = stores + .get(&name) + .ok_or_else(|| AppError::NotFound(format!("Context '{}' does not exist", name)))? + .clone(); + drop(stores); + + let config = if req.target_rows_per_fragment.is_some() || req.materialize_deletions.is_some() { + let mut c = CompactionConfig::default(); + if let Some(v) = req.target_rows_per_fragment { + c.target_rows_per_fragment = v; + } + if let Some(v) = req.materialize_deletions { + c.materialize_deletions = v; + } + Some(c) + } else { + None + }; + + let mut store = store_lock.write().await; + let metrics = store.compact(config).await.map_err(AppError::from_lance)?; + + Ok(Json(CompactResponse { + fragments_removed: metrics.fragments_removed, + fragments_added: metrics.fragments_added, + files_removed: metrics.files_removed, + files_added: metrics.files_added, + })) +} + +pub async fn compact_stats( + State(state): State>, + Path(name): Path, +) -> Result, AppError> { + let stores = state.stores.read().await; + let store_lock = stores + .get(&name) + .ok_or_else(|| AppError::NotFound(format!("Context '{}' does not exist", name)))? + .clone(); + drop(stores); + + let store = store_lock.read().await; + let stats = store + .compaction_stats() + .await + .map_err(AppError::from_lance)?; + + Ok(Json(CompactStatsResponse { + total_fragments: stats.total_fragments, + is_compacting: stats.is_compacting, + last_compaction: stats.last_compaction, + last_error: stats.last_error, + total_compactions: stats.total_compactions, + })) +} diff --git a/crates/lance-context-server/src/routes/contexts.rs b/crates/lance-context-server/src/routes/contexts.rs new file mode 100644 index 0000000..bc42fe1 --- /dev/null +++ b/crates/lance-context-server/src/routes/contexts.rs @@ -0,0 +1,118 @@ +use std::collections::HashSet; +use std::sync::Arc; + +use axum::extract::{Path, State}; +use axum::Json; +use lance_context_api::{ContextInfo, CreateContextRequest, ListContextsResponse}; +use lance_context_core::{ContextStore, ContextStoreOptions, IdIndexType}; +use tokio::sync::RwLock; + +use crate::error::AppError; +use crate::state::AppState; + +pub async fn create_context( + State(state): State>, + Json(req): Json, +) -> Result<(axum::http::StatusCode, Json), AppError> { + let stores = state.stores.read().await; + if stores.contains_key(&req.name) { + return Err(AppError::AlreadyExists(format!( + "Context '{}' already exists", + req.name + ))); + } + drop(stores); + + let id_index_type = match req.id_index_type.as_deref() { + Some("btree") => IdIndexType::BTree, + Some("zonemap") => IdIndexType::ZoneMap, + Some("none") | None => IdIndexType::None, + Some(other) => { + return Err(AppError::InvalidRequest(format!( + "Invalid id_index_type: '{}'. Must be 'none', 'zonemap', or 'btree'", + other + ))); + } + }; + + let blob_columns: HashSet = req.blob_columns.unwrap_or_default().into_iter().collect(); + + let uri = state.context_uri(&req.name); + let options = ContextStoreOptions { + storage_options: req.storage_options, + blob_columns, + id_index_type, + ..Default::default() + }; + + let store = ContextStore::open_with_options(&uri, options) + .await + .map_err(AppError::from_lance)?; + + let version = store.version(); + + let mut stores = state.stores.write().await; + stores.insert(req.name.clone(), Arc::new(RwLock::new(store))); + + Ok(( + axum::http::StatusCode::CREATED, + Json(ContextInfo { + name: req.name, + uri, + version, + }), + )) +} + +pub async fn list_contexts(State(state): State>) -> Json { + let stores = state.stores.read().await; + let mut contexts = Vec::with_capacity(stores.len()); + + for (name, store_lock) in stores.iter() { + let store = store_lock.read().await; + contexts.push(ContextInfo { + name: name.clone(), + uri: state.context_uri(name), + version: store.version(), + }); + } + + Json(ListContextsResponse { contexts }) +} + +pub async fn get_context( + State(state): State>, + Path(name): Path, +) -> Result, AppError> { + let stores = state.stores.read().await; + let store_lock = stores + .get(&name) + .ok_or_else(|| AppError::NotFound(format!("Context '{}' does not exist", name)))?; + let store = store_lock.read().await; + + Ok(Json(ContextInfo { + name: name.clone(), + uri: state.context_uri(&name), + version: store.version(), + })) +} + +pub async fn delete_context( + State(state): State>, + Path(name): Path, +) -> Result { + let mut stores = state.stores.write().await; + if stores.remove(&name).is_none() { + return Err(AppError::NotFound(format!( + "Context '{}' does not exist", + name + ))); + } + + let uri = state.context_uri(&name); + if let Err(e) = tokio::fs::remove_dir_all(&uri).await { + tracing::warn!("Failed to remove context data at {}: {}", uri, e); + } + + Ok(axum::http::StatusCode::NO_CONTENT) +} diff --git a/crates/lance-context-server/src/routes/health.rs b/crates/lance-context-server/src/routes/health.rs new file mode 100644 index 0000000..c3c0281 --- /dev/null +++ b/crates/lance-context-server/src/routes/health.rs @@ -0,0 +1,10 @@ +use axum::Json; + +#[derive(serde::Serialize)] +pub struct HealthResponse { + pub status: &'static str, +} + +pub async fn health_check() -> Json { + Json(HealthResponse { status: "ok" }) +} diff --git a/crates/lance-context-server/src/routes/mod.rs b/crates/lance-context-server/src/routes/mod.rs new file mode 100644 index 0000000..2337ac8 --- /dev/null +++ b/crates/lance-context-server/src/routes/mod.rs @@ -0,0 +1,45 @@ +pub mod compact; +pub mod contexts; +pub mod health; +pub mod records; +pub mod search; +pub mod versions; + +use std::sync::Arc; + +use axum::routing::{delete, get, post}; +use axum::Router; + +use crate::state::AppState; + +pub fn router() -> Router> { + Router::new() + .route("/api/v1/health", get(health::health_check)) + .route("/api/v1/contexts", post(contexts::create_context)) + .route("/api/v1/contexts", get(contexts::list_contexts)) + .route("/api/v1/contexts/{name}", get(contexts::get_context)) + .route("/api/v1/contexts/{name}", delete(contexts::delete_context)) + .route( + "/api/v1/contexts/{name}/records", + post(records::add_records), + ) + .route( + "/api/v1/contexts/{name}/records", + get(records::list_records), + ) + .route( + "/api/v1/contexts/{name}/records/{id}", + get(records::get_record), + ) + .route("/api/v1/contexts/{name}/search", post(search::search)) + .route( + "/api/v1/contexts/{name}/version", + get(versions::get_version), + ) + .route("/api/v1/contexts/{name}/checkout", post(versions::checkout)) + .route("/api/v1/contexts/{name}/compact", post(compact::compact)) + .route( + "/api/v1/contexts/{name}/compact/stats", + get(compact::compact_stats), + ) +} diff --git a/crates/lance-context-server/src/routes/records.rs b/crates/lance-context-server/src/routes/records.rs new file mode 100644 index 0000000..f565883 --- /dev/null +++ b/crates/lance-context-server/src/routes/records.rs @@ -0,0 +1,163 @@ +use std::sync::Arc; + +use axum::extract::{Path, Query, State}; +use axum::Json; +use chrono::Utc; +use lance_context_api::{ + AddRecordsRequest, AddRecordsResponse, GetRecordResponse, ListRecordsResponse, RecordDto, + StateMetadataDto, +}; +use lance_context_core::{ContextRecord, StateMetadata, LIFECYCLE_ACTIVE}; +use uuid::Uuid; + +use crate::error::AppError; +use crate::state::AppState; + +pub async fn add_records( + State(state): State>, + Path(name): Path, + Json(req): Json, +) -> Result<(axum::http::StatusCode, Json), AppError> { + if req.records.is_empty() { + return Err(AppError::InvalidRequest( + "records array must not be empty".to_string(), + )); + } + + let stores = state.stores.read().await; + let store_lock = stores + .get(&name) + .ok_or_else(|| AppError::NotFound(format!("Context '{}' does not exist", name)))? + .clone(); + drop(stores); + + let run_id = Uuid::new_v4().to_string(); + let mut ids = Vec::with_capacity(req.records.len()); + let mut core_records = Vec::with_capacity(req.records.len()); + + for r in &req.records { + let id = Uuid::new_v4().to_string(); + ids.push(id.clone()); + core_records.push(ContextRecord { + id, + external_id: r.external_id.clone(), + run_id: run_id.clone(), + bot_id: r.bot_id.clone(), + session_id: r.session_id.clone(), + created_at: Utc::now(), + role: r.role.clone(), + state_metadata: r.state_metadata.as_ref().map(|sm| StateMetadata { + step: sm.step, + active_plan_id: sm.active_plan_id.clone(), + tokens_used: sm.tokens_used, + custom: sm.custom.clone(), + }), + metadata: r.metadata.clone(), + expires_at: r.expires_at, + retention_policy: r.retention_policy.clone(), + lifecycle_status: LIFECYCLE_ACTIVE.to_string(), + retired_at: None, + retired_reason: None, + supersedes_id: r.supersedes_id.clone(), + superseded_by_id: None, + content_type: r.content_type.clone(), + text_payload: r.text_payload.clone(), + binary_payload: r.binary_payload.clone(), + embedding: r.embedding.clone(), + }); + } + + let count = core_records.len(); + let mut store = store_lock.write().await; + let version = store + .add(&core_records) + .await + .map_err(AppError::from_lance)?; + + Ok(( + axum::http::StatusCode::CREATED, + Json(AddRecordsResponse { + version, + ids, + count, + }), + )) +} + +pub async fn get_record( + State(state): State>, + Path((name, id)): Path<(String, String)>, +) -> Result, AppError> { + let stores = state.stores.read().await; + let store_lock = stores + .get(&name) + .ok_or_else(|| AppError::NotFound(format!("Context '{}' does not exist", name)))? + .clone(); + drop(stores); + + let store = store_lock.read().await; + let record = store.get(&id).await.map_err(AppError::from_lance)?; + + Ok(Json(GetRecordResponse { + record: record.map(record_to_dto), + })) +} + +#[derive(serde::Deserialize)] +pub struct ListParams { + pub limit: Option, + pub offset: Option, +} + +pub async fn list_records( + State(state): State>, + Path(name): Path, + Query(params): Query, +) -> Result, AppError> { + let stores = state.stores.read().await; + let store_lock = stores + .get(&name) + .ok_or_else(|| AppError::NotFound(format!("Context '{}' does not exist", name)))? + .clone(); + drop(stores); + + let store = store_lock.read().await; + let records = store + .list(params.limit, params.offset) + .await + .map_err(AppError::from_lance)?; + + let dtos: Vec = records.into_iter().map(record_to_dto).collect(); + + Ok(Json(ListRecordsResponse { records: dtos })) +} + +pub fn record_to_dto(r: ContextRecord) -> RecordDto { + RecordDto { + id: r.id, + external_id: r.external_id, + run_id: r.run_id, + bot_id: r.bot_id, + session_id: r.session_id, + created_at: r.created_at, + role: r.role, + content_type: r.content_type, + text_payload: r.text_payload, + binary_payload: r.binary_payload, + embedding: r.embedding, + state_metadata: r.state_metadata.map(|sm| StateMetadataDto { + step: sm.step, + active_plan_id: sm.active_plan_id, + tokens_used: sm.tokens_used, + custom: sm.custom, + }), + metadata: r.metadata, + expires_at: r.expires_at, + retention_policy: r.retention_policy, + lifecycle_status: r.lifecycle_status, + retired_at: r.retired_at, + retired_reason: r.retired_reason, + supersedes_id: r.supersedes_id, + superseded_by_id: r.superseded_by_id, + } +} diff --git a/crates/lance-context-server/src/routes/search.rs b/crates/lance-context-server/src/routes/search.rs new file mode 100644 index 0000000..455f2b4 --- /dev/null +++ b/crates/lance-context-server/src/routes/search.rs @@ -0,0 +1,38 @@ +use std::sync::Arc; + +use axum::extract::{Path, State}; +use axum::Json; +use lance_context_api::{SearchRequest, SearchResponse, SearchResultDto}; + +use crate::error::AppError; +use crate::routes::records::record_to_dto; +use crate::state::AppState; + +pub async fn search( + State(state): State>, + Path(name): Path, + Json(req): Json, +) -> Result, AppError> { + let stores = state.stores.read().await; + let store_lock = stores + .get(&name) + .ok_or_else(|| AppError::NotFound(format!("Context '{}' does not exist", name)))? + .clone(); + drop(stores); + + let store = store_lock.read().await; + let results = store + .search(&req.query, Some(req.limit)) + .await + .map_err(AppError::from_lance)?; + + let dtos: Vec = results + .into_iter() + .map(|sr| SearchResultDto { + record: record_to_dto(sr.record), + distance: sr.distance, + }) + .collect(); + + Ok(Json(SearchResponse { results: dtos })) +} diff --git a/crates/lance-context-server/src/routes/versions.rs b/crates/lance-context-server/src/routes/versions.rs new file mode 100644 index 0000000..76af8aa --- /dev/null +++ b/crates/lance-context-server/src/routes/versions.rs @@ -0,0 +1,48 @@ +use std::sync::Arc; + +use axum::extract::{Path, State}; +use axum::Json; +use lance_context_api::{CheckoutRequest, VersionResponse}; + +use crate::error::AppError; +use crate::state::AppState; + +pub async fn get_version( + State(state): State>, + Path(name): Path, +) -> Result, AppError> { + let stores = state.stores.read().await; + let store_lock = stores + .get(&name) + .ok_or_else(|| AppError::NotFound(format!("Context '{}' does not exist", name)))? + .clone(); + drop(stores); + + let store = store_lock.read().await; + Ok(Json(VersionResponse { + version: store.version(), + })) +} + +pub async fn checkout( + State(state): State>, + Path(name): Path, + Json(req): Json, +) -> Result, AppError> { + let stores = state.stores.read().await; + let store_lock = stores + .get(&name) + .ok_or_else(|| AppError::NotFound(format!("Context '{}' does not exist", name)))? + .clone(); + drop(stores); + + let mut store = store_lock.write().await; + store + .checkout(req.version) + .await + .map_err(AppError::from_lance)?; + + Ok(Json(VersionResponse { + version: store.version(), + })) +} diff --git a/crates/lance-context-server/src/state.rs b/crates/lance-context-server/src/state.rs new file mode 100644 index 0000000..853fd3e --- /dev/null +++ b/crates/lance-context-server/src/state.rs @@ -0,0 +1,29 @@ +use std::collections::HashMap; +use std::path::PathBuf; +use std::sync::Arc; + +use lance_context_core::ContextStore; +use tokio::sync::RwLock; + +use crate::config::ServerConfig; + +pub struct AppState { + pub stores: RwLock>>>, + pub base_path: PathBuf, +} + +impl AppState { + pub fn new(config: ServerConfig) -> Self { + Self { + stores: RwLock::new(HashMap::new()), + base_path: PathBuf::from(&config.data_dir), + } + } + + pub fn context_uri(&self, name: &str) -> String { + self.base_path + .join(format!("{}.lance", name)) + .to_string_lossy() + .to_string() + } +} diff --git a/crates/lance-context/Cargo.toml b/crates/lance-context/Cargo.toml index 612cc65..c6be7b5 100644 --- a/crates/lance-context/Cargo.toml +++ b/crates/lance-context/Cargo.toml @@ -5,8 +5,14 @@ edition = "2021" license = "Apache-2.0" authors = ["Lance Devs "] repository = "https://github.com/lancedb/lance-context" -description = "Public re-export crate for lance-context bindings" +description = "Multimodal, versioned context storage for agentic workflows" readme = "README.md" +[features] +default = [] +remote = ["lance-context-client"] + [dependencies] lance-context-core = { version = "0.3.0", path = "../lance-context-core" } +lance-context-api = { version = "0.2.4", path = "../lance-context-api" } +lance-context-client = { version = "0.2.4", path = "../lance-context-client", optional = true } diff --git a/crates/lance-context/src/lib.rs b/crates/lance-context/src/lib.rs index 1264030..87eef8b 100644 --- a/crates/lance-context/src/lib.rs +++ b/crates/lance-context/src/lib.rs @@ -1 +1,20 @@ -pub use lance_context_core::*; +#![recursion_limit = "256"] + +// Explicit re-exports from core (no glob to avoid recursion depth overflow) +pub use lance_context_core::serde; +pub use lance_context_core::{ + CompactionConfig, CompactionMetrics, CompactionStats, Context, ContextEntry, ContextRecord, + ContextStoreOptions, IdIndexType, LifecycleQueryOptions, MetadataFilter, RecordFilters, + SearchResult, Snapshot, StateMetadata, LIFECYCLE_ACTIVE, LIFECYCLE_CONTRADICTED, +}; + +pub use lance_context_api::{ + AddRecordRequest, AddRecordsResponse, CompactRequest, CompactResponse, CompactStatsResponse, + ContextError, ContextResult, ContextStoreApi, RecordDto, SearchResultDto, +}; + +#[cfg(feature = "remote")] +pub use lance_context_client::{ClientError, RemoteContextStore}; + +mod unified; +pub use unified::ContextStore; diff --git a/crates/lance-context/src/unified.rs b/crates/lance-context/src/unified.rs new file mode 100644 index 0000000..0680d2c --- /dev/null +++ b/crates/lance-context/src/unified.rs @@ -0,0 +1,147 @@ +use std::collections::HashSet; + +use lance_context_api::{ + AddRecordRequest, AddRecordsResponse, CompactRequest, CompactResponse, CompactStatsResponse, + ContextError, ContextResult, ContextStoreApi, RecordDto, SearchResultDto, +}; +use lance_context_core::{ContextStore as LocalStore, ContextStoreOptions, IdIndexType}; + +#[cfg(feature = "remote")] +use lance_context_client::RemoteContextStore; + +pub enum ContextStore { + Local(Box), + #[cfg(feature = "remote")] + Remote(RemoteContextStore), +} + +impl ContextStore { + pub async fn open(uri: &str) -> Result { + let store = LocalStore::open(uri) + .await + .map_err(|e| ContextError::Internal(e.to_string()))?; + Ok(Self::Local(Box::new(store))) + } + + pub async fn open_with_options( + uri: &str, + storage_options: Option>, + id_index_type: Option<&str>, + blob_columns: Option>, + ) -> Result { + let id_idx = match id_index_type { + Some("btree") => IdIndexType::BTree, + Some("zonemap") => IdIndexType::ZoneMap, + Some("none") | None => IdIndexType::None, + Some(other) => { + return Err(ContextError::InvalidRequest(format!( + "Invalid id_index_type: '{other}'" + ))); + } + }; + let options = ContextStoreOptions { + storage_options, + blob_columns: blob_columns + .unwrap_or_default() + .into_iter() + .collect::>(), + id_index_type: id_idx, + ..Default::default() + }; + let store = LocalStore::open_with_options(uri, options) + .await + .map_err(|e| ContextError::Internal(e.to_string()))?; + Ok(Self::Local(Box::new(store))) + } + + #[cfg(feature = "remote")] + pub async fn connect(base_url: &str, context_name: &str) -> Result { + let store = RemoteContextStore::connect(base_url, context_name) + .await + .map_err(|e| ContextError::Internal(e.to_string()))?; + Ok(Self::Remote(store)) + } + + #[cfg(feature = "remote")] + pub async fn connect_or_create( + base_url: &str, + req: &lance_context_api::CreateContextRequest, + ) -> Result { + let store = RemoteContextStore::connect_or_create(base_url, req) + .await + .map_err(|e| ContextError::Internal(e.to_string()))?; + Ok(Self::Remote(store)) + } +} + +macro_rules! dispatch_mut { + ($self:expr, $method:ident $(, $arg:expr)*) => { + match $self { + ContextStore::Local(s) => ContextStoreApi::$method(s.as_mut() $(, $arg)*).await, + #[cfg(feature = "remote")] + ContextStore::Remote(s) => ContextStoreApi::$method(s $(, $arg)*).await, + } + }; +} + +macro_rules! dispatch_ref { + ($self:expr, $method:ident $(, $arg:expr)*) => { + match $self { + ContextStore::Local(s) => ContextStoreApi::$method(s.as_ref() $(, $arg)*).await, + #[cfg(feature = "remote")] + ContextStore::Remote(s) => ContextStoreApi::$method(s $(, $arg)*).await, + } + }; +} + +macro_rules! dispatch_sync { + ($self:expr, $method:ident $(, $arg:expr)*) => { + match $self { + ContextStore::Local(s) => ContextStoreApi::$method(s.as_ref() $(, $arg)*), + #[cfg(feature = "remote")] + ContextStore::Remote(s) => ContextStoreApi::$method(s $(, $arg)*), + } + }; +} + +impl ContextStoreApi for ContextStore { + async fn add(&mut self, records: &[AddRecordRequest]) -> ContextResult { + dispatch_mut!(self, add, records) + } + + async fn get(&self, id: &str) -> ContextResult> { + dispatch_ref!(self, get, id) + } + + async fn list( + &self, + limit: Option, + offset: Option, + ) -> ContextResult> { + dispatch_ref!(self, list, limit, offset) + } + + async fn search( + &self, + query: &[f32], + limit: Option, + ) -> ContextResult> { + dispatch_ref!(self, search, query, limit) + } + + fn version(&self) -> u64 { + dispatch_sync!(self, version) + } + + async fn checkout(&mut self, version: u64) -> ContextResult<()> { + dispatch_mut!(self, checkout, version) + } + + async fn compact(&mut self, options: Option) -> ContextResult { + dispatch_mut!(self, compact, options) + } + + async fn compaction_stats(&self) -> ContextResult { + dispatch_ref!(self, compaction_stats) + } +} diff --git a/python/Cargo.toml b/python/Cargo.toml index 1115b66..2c21006 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -11,7 +11,7 @@ crate-type = ["cdylib"] [dependencies] chrono = { version = "0.4", default-features = false, features = ["clock"] } -lance-context = { path = "../crates/lance-context" } +lance-context-core = { path = "../crates/lance-context-core" } pyo3 = { version = "0.25", features = ["extension-module", "abi3-py39", "py-clone"] } serde_json = "1" tokio = { version = "1", features = ["rt-multi-thread"] } diff --git a/python/python/lance_context/api.py b/python/python/lance_context/api.py index 069675c..ce32883 100644 --- a/python/python/lance_context/api.py +++ b/python/python/lance_context/api.py @@ -744,6 +744,15 @@ async def search( ), ) + async def get( + self, *, id: str | None = None, external_id: str | None = None + ) -> dict[str, Any] | None: + """Asynchronously retrieve a single context record by id or external_id.""" + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + None, lambda: self._sync.get(id=id, external_id=external_id) + ) + async def list( self, limit: int | None = None, diff --git a/python/src/lib.rs b/python/src/lib.rs index ebdf5cb..71295f9 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -11,8 +11,8 @@ use pyo3::IntoPyObject; use serde_json::Value; use tokio::runtime::Runtime; -use lance_context::serde::CONTENT_TYPE_TEXT; -use lance_context::{ +use lance_context_core::serde::CONTENT_TYPE_TEXT; +use lance_context_core::{ CompactionConfig, CompactionMetrics, CompactionStats, Context as RustContext, ContextRecord, ContextStore, ContextStoreOptions, IdIndexType, LifecycleQueryOptions, MetadataFilter, RecordFilters, SearchResult, LIFECYCLE_ACTIVE,