From eb2d94a87f07d2d5e73899ae6d725708c995911e Mon Sep 17 00:00:00 2001 From: Will Jones Date: Wed, 17 Jun 2026 17:00:02 -0700 Subject: [PATCH] feat(scalar): implement CacheCodec for FM index The FM index ignored the session cache entirely: every open re-parsed each partition's metadata and the only memoization was the in-memory `Arc`, which a serializable backend cannot persist. This adds two cache entries that round-trip through the stable cache codec, so an opened FM index survives in a node-agnostic, restart-surviving backend: - `FMIndexPartitionState`: one per partition; the skeleton needed to rebuild a `LazyFMIndex` without re-reading metadata (huffman codes, tree topology, c_table, row ids, doc starts, sampled SA, per-node prefix ranks). - `WaveletNodeWords`: one per wavelet node; the node's bitvector words, loaded lazily and shared so neighbouring blocks come from a single read. `load_partition` now fetches the skeleton via `get_or_insert_with_key`, and `LazyRankBitVec` loads whole-node words through the cache. No on-disk format change. The in-memory `Arc` memoization remains the fast in-session layer. Closes #7277 --- rust/lance-index/protos-cache/cache.proto | 24 + rust/lance-index/src/scalar/fmindex.rs | 914 ++++++++++++++++------ 2 files changed, 686 insertions(+), 252 deletions(-) diff --git a/rust/lance-index/protos-cache/cache.proto b/rust/lance-index/protos-cache/cache.proto index b24a27055d7..63237f25bf0 100644 --- a/rust/lance-index/protos-cache/cache.proto +++ b/rust/lance-index/protos-cache/cache.proto @@ -106,6 +106,30 @@ message RangeToFile { string path = 4; } +// Header for a serialized FM-index partition skeleton (`FMIndexPartitionState`). +// +// Followed, in order, by raw blobs: the C table, Huffman codes, wavelet tree +// topology, document row ids, document start positions, sampled suffix array, +// and the flattened per-block prefix ranks (split back into nodes by +// `nodes[i].num_blocks`). The lazily-loaded BWT/SA word blocks themselves are +// *not* part of this entry — they are cached per wavelet node separately. +message FmIndexStateHeader { + uint64 bwt_len = 1; + uint32 alphabet_size = 2; + // Per wavelet node, in node-id order. A node's blocks occupy + // `num_blocks` consecutive rows starting after all earlier nodes' blocks, so + // the block row offset is the running sum of earlier `num_blocks`. + repeated FmIndexNodeMeta nodes = 3; +} + +// Per-wavelet-node metadata within an `FmIndexStateHeader`. +message FmIndexNodeMeta { + // Number of word blocks this node occupies (also its prefix-rank count). + uint32 num_blocks = 1; + // Number of bits in this node's bitvector. + uint64 bit_len = 2; +} + // --------------------------------------------------------------------------- // Vector indices (IVF partitions) // --------------------------------------------------------------------------- diff --git a/rust/lance-index/src/scalar/fmindex.rs b/rust/lance-index/src/scalar/fmindex.rs index aed1136535a..228348257fd 100644 --- a/rust/lance-index/src/scalar/fmindex.rs +++ b/rust/lance-index/src/scalar/fmindex.rs @@ -29,7 +29,9 @@ use arrow_schema::{DataType, Field}; use async_trait::async_trait; use datafusion::execution::SendableRecordBatchStream; use futures::StreamExt; -use lance_core::cache::LanceCache; +use lance_core::cache::{ + CacheCodec, CacheCodecImpl, CacheEntryReader, CacheEntryWriter, CacheKey, LanceCache, +}; use lance_core::deepsize::DeepSizeOf; use lance_core::{Error, ROW_ADDR, Result}; use roaring::RoaringBitmap; @@ -398,11 +400,265 @@ fn build_suffix_array(text: &[u8]) -> Vec { const BLOCK_BITS: usize = BLOCK_WORDS * 64; +// ── Cache entries ──────────────────────────────────────────────────────────── +// +// An opened FM-index is memoized in two granularities, both serializable through +// the stable cache codec so they can live in a node-agnostic, restart-surviving +// backend: +// +// - One `FMIndexPartitionState` per partition: the skeleton needed to rebuild a +// `LazyFMIndex` without re-reading any metadata (Huffman codes, tree +// topology, c_table, row ids, doc starts, sampled SA, and per-node prefix +// ranks). +// - One `WaveletNodeWords` per wavelet node: that node's bitvector words, +// loaded lazily and shared so neighbouring blocks come from a single read. + +/// One wavelet node's bitvector, as the packed `u64` words that back its rank +/// queries. The lazily-loaded bulk of an FM index; cached per node. +#[derive(Debug)] +struct WaveletNodeWords(Vec); + +impl DeepSizeOf for WaveletNodeWords { + fn deep_size_of_children(&self, _ctx: &mut lance_core::deepsize::Context) -> usize { + self.0.len() * 8 + } +} + +impl CacheCodecImpl for WaveletNodeWords { + const TYPE_ID: &'static str = "lance.scalar.FMIndexWaveletNode"; + const CURRENT_VERSION: u32 = 1; + + fn serialize(&self, w: &mut CacheEntryWriter<'_>) -> Result<()> { + w.write_raw(&FMIndex::u64_to_bytes(&self.0)) + } + + fn deserialize(r: &mut CacheEntryReader<'_>) -> Result { + let bytes = r.read_raw()?; + Ok(Self( + bytes + .chunks_exact(8) + .map(|c| u64::from_le_bytes(c.try_into().unwrap())) + .collect(), + )) + } +} + +/// Per-wavelet-node metadata held in an [`FMIndexPartitionState`]. A node's +/// blocks occupy `prefix_ranks.len()` consecutive rows in the partition file. +#[derive(Debug, Clone)] +struct NodeMeta { + /// Prefix rank at each block boundary; one entry per block. + prefix_ranks: Vec, + /// Number of bits in this node's bitvector. + bit_len: usize, +} + +/// Skeleton of an FM-index partition: everything required to rebuild a +/// [`LazyFMIndex`] without re-reading partition metadata. Excludes the wavelet +/// node words (cached per node) and the `IndexReader` (reconstructed on load). +#[derive(Debug)] +struct FMIndexPartitionState { + huffman_codes: [HuffmanCode; 256], + children: Vec<(WaveletChild, WaveletChild)>, + c_table: Vec, + row_ids: Vec, + doc_start_positions: Vec, + sa_samples: Vec, + bwt_len: usize, + alphabet_size: usize, + nodes: Vec, +} + +impl DeepSizeOf for FMIndexPartitionState { + fn deep_size_of_children(&self, _ctx: &mut lance_core::deepsize::Context) -> usize { + let codes: usize = self + .huffman_codes + .iter() + .map(|c| c.node_path.len() * 8) + .sum(); + let nodes: usize = self + .nodes + .iter() + .map(|n| n.prefix_ranks.len() * 8 + std::mem::size_of::()) + .sum(); + codes + + self.children.len() * std::mem::size_of::<(WaveletChild, WaveletChild)>() + + self.c_table.len() * std::mem::size_of::() + + self.row_ids.len() * 8 + + self.doc_start_positions.len() * 8 + + self.sa_samples.len() * 8 + + nodes + } +} + +impl CacheCodecImpl for FMIndexPartitionState { + const TYPE_ID: &'static str = "lance.scalar.FMIndexState"; + const CURRENT_VERSION: u32 = 1; + + fn serialize(&self, w: &mut CacheEntryWriter<'_>) -> Result<()> { + let header = crate::cache_pb::FmIndexStateHeader { + bwt_len: self.bwt_len as u64, + alphabet_size: self.alphabet_size as u32, + nodes: self + .nodes + .iter() + .map(|n| crate::cache_pb::FmIndexNodeMeta { + num_blocks: n.prefix_ranks.len() as u32, + bit_len: n.bit_len as u64, + }) + .collect(), + }; + w.write_header(&header)?; + + w.write_raw(&serialize_c_table(&self.c_table))?; + w.write_raw(&serialize_huffman_codes(&self.huffman_codes))?; + w.write_raw(&serialize_tree_topology(&self.children))?; + w.write_raw(&FMIndex::u64_to_bytes(&self.row_ids))?; + w.write_raw(&FMIndex::u64_to_bytes(&self.doc_start_positions))?; + w.write_raw(&FMIndex::u64_to_bytes(&self.sa_samples))?; + + let prefix_ranks: Vec = self + .nodes + .iter() + .flat_map(|n| n.prefix_ranks.iter().copied()) + .collect(); + w.write_raw(&FMIndex::u64_to_bytes(&prefix_ranks))?; + Ok(()) + } + + fn deserialize(r: &mut CacheEntryReader<'_>) -> Result { + let header: crate::cache_pb::FmIndexStateHeader = r.read_header()?; + + let c_table = FMIndex::deserialize_c_table(&r.read_raw()?); + let huffman_codes = FMIndex::deserialize_huffman_codes(&r.read_raw()?); + let children = FMIndex::deserialize_tree_topology(&r.read_raw()?); + let row_ids = bytes_to_u64(&r.read_raw()?); + let doc_start_positions = bytes_to_u64(&r.read_raw()?); + let sa_samples = bytes_to_u64(&r.read_raw()?); + let mut prefix_ranks = bytes_to_u64(&r.read_raw()?).into_iter(); + + let nodes = header + .nodes + .iter() + .map(|n| NodeMeta { + prefix_ranks: prefix_ranks.by_ref().take(n.num_blocks as usize).collect(), + bit_len: n.bit_len as usize, + }) + .collect(); + + Ok(Self { + huffman_codes, + children, + c_table, + row_ids, + doc_start_positions, + sa_samples, + bwt_len: header.bwt_len as usize, + alphabet_size: header.alphabet_size as usize, + nodes, + }) + } +} + +fn bytes_to_u64(bytes: &[u8]) -> Vec { + bytes + .chunks_exact(8) + .map(|c| u64::from_le_bytes(c.try_into().unwrap())) + .collect() +} + +fn serialize_huffman_codes(codes: &[HuffmanCode; 256]) -> Vec { + let mut buf = Vec::new(); + for code in codes { + buf.extend_from_slice(&code.bits.to_le_bytes()); + buf.push(code.length); + buf.extend_from_slice(&(code.node_path.len() as u16).to_le_bytes()); + for &nid in &code.node_path { + buf.extend_from_slice(&(nid as u32).to_le_bytes()); + } + } + buf +} + +fn serialize_tree_topology(children: &[(WaveletChild, WaveletChild)]) -> Vec { + let mut buf = Vec::new(); + buf.extend_from_slice(&(children.len() as u32).to_le_bytes()); + for (left, right) in children { + for child in [left, right] { + match child { + WaveletChild::Node(id) => { + buf.push(0); + buf.extend_from_slice(&(*id as u32).to_le_bytes()); + } + WaveletChild::Leaf(b) => { + buf.push(1); + buf.extend_from_slice(&(*b as u32).to_le_bytes()); + } + } + } + } + buf +} + +fn serialize_c_table(c_table: &[usize]) -> Vec { + c_table + .iter() + .flat_map(|&v| (v as u64).to_le_bytes()) + .collect() +} + +/// Cache key for a partition's [`FMIndexPartitionState`]. The cache is already +/// per-partition namespaced by the caller, so a constant key suffices. +struct FMIndexStateKey; + +impl CacheKey for FMIndexStateKey { + type ValueType = FMIndexPartitionState; + + fn key(&self) -> std::borrow::Cow<'_, str> { + "state".into() + } + + fn type_name() -> &'static str { + "FMIndexState" + } + + fn codec() -> Option { + Some(CacheCodec::from_impl::()) + } +} + +/// Cache key for a single wavelet node's [`WaveletNodeWords`], within a +/// partition-namespaced cache. +struct WaveletNodeKey { + node_id: u32, +} + +impl CacheKey for WaveletNodeKey { + type ValueType = WaveletNodeWords; + + fn key(&self) -> std::borrow::Cow<'_, str> { + format!("node-{}", self.node_id).into() + } + + fn type_name() -> &'static str { + "FMIndexWaveletNode" + } + + fn codec() -> Option { + Some(CacheCodec::from_impl::()) + } +} + struct LazyRankBitVec { prefix_ranks: Vec, - blocks: Vec>>, + /// The node's full bitvector words, loaded as a single cached unit on first + /// access. Shared with the cache so the words back rank/access zero-copy. + words: OnceLock>, + cache: LanceCache, reader: Arc, + node_id: u32, block_row_offset: usize, + num_blocks: usize, len: usize, } @@ -417,58 +673,77 @@ impl std::fmt::Debug for LazyRankBitVec { impl LazyRankBitVec { fn new( prefix_ranks: Vec, - num_blocks: usize, + cache: LanceCache, reader: Arc, - offset: usize, + node_id: u32, + block_row_offset: usize, + num_blocks: usize, len: usize, ) -> Self { Self { prefix_ranks, - blocks: (0..num_blocks).map(|_| OnceLock::new()).collect(), + words: OnceLock::new(), + cache, reader, - block_row_offset: offset, + node_id, + block_row_offset, + num_blocks, len, } } - /// Pre-load all blocks into memory. Call this before sync rank/access operations - /// to avoid the need for `block_in_place` during queries. - async fn load_all_blocks(&self) -> Result<()> { - for (idx, lock) in self.blocks.iter().enumerate() { - if lock.get().is_none() { - let words = self.load_block(idx).await?; - let _ = lock.set(words); - } + /// Pre-load this node's words into memory. Call before sync rank/access + /// operations to avoid `block_in_place` during queries. + async fn prewarm(&self) -> Result<()> { + if self.words.get().is_none() { + let _ = self.words.set(self.load_words().await?); } Ok(()) } #[inline] - fn ensure_block(&self, idx: usize) -> &[u64] { - self.blocks[idx].get_or_init(|| { - tokio::task::block_in_place(|| { - tokio::runtime::Handle::current().block_on(self.load_block(idx)) + fn ensure_words(&self) -> &[u64] { + self.words + .get_or_init(|| { + tokio::task::block_in_place(|| { + tokio::runtime::Handle::current().block_on(self.load_words()) + }) + .unwrap_or_else(|e| panic!("FM-Index node load failed: {e}")) }) - .unwrap_or_else(|e| panic!("FM-Index block load failed: {e}")) - }) - } - - async fn load_block(&self, idx: usize) -> Result> { - let row = self.block_row_offset + idx; - let batch = self - .reader - .read_range(row..row + 1, Some(&["words"])) - .await?; - let col = batch - .column(0) - .as_any() - .downcast_ref::() - .ok_or_else(|| Error::invalid_input("expected LargeBinary words column"))?; - Ok(col - .value(0) - .chunks_exact(8) - .map(|c| u64::from_le_bytes(c.try_into().unwrap())) - .collect()) + .0 + .as_slice() + } + + /// Load (or recompute on a cache miss) this node's words as a single cache + /// entry, so neighbouring blocks share one read. + async fn load_words(&self) -> Result> { + let reader = self.reader.clone(); + let start = self.block_row_offset; + let end = start + self.num_blocks; + self.cache + .get_or_insert_with_key( + WaveletNodeKey { + node_id: self.node_id, + }, + || async move { + let batch = reader.read_range(start..end, Some(&["words"])).await?; + let col = batch + .column(0) + .as_any() + .downcast_ref::() + .ok_or_else(|| Error::invalid_input("expected LargeBinary words column"))?; + let mut words = Vec::new(); + for i in 0..batch.num_rows() { + words.extend( + col.value(i) + .chunks_exact(8) + .map(|c| u64::from_le_bytes(c.try_into().unwrap())), + ); + } + Ok(WaveletNodeWords(words)) + }, + ) + .await } #[inline] @@ -482,14 +757,15 @@ impl LazyRankBitVec { return self.prefix_ranks[bi] as usize; } let mut count = self.prefix_ranks[bi] as usize; - let block = self.ensure_block(bi); + let words = self.ensure_words(); + let block_start = bi * BLOCK_WORDS; let wi = local / 64; let bit = local % 64; - for w in &block[..wi] { + for w in &words[block_start..block_start + wi] { count += w.count_ones() as usize; } if bit > 0 { - count += (block[wi] & ((1u64 << bit) - 1)).count_ones() as usize; + count += (words[block_start + wi] & ((1u64 << bit) - 1)).count_ones() as usize; } count } @@ -501,19 +777,12 @@ impl LazyRankBitVec { #[inline] fn get(&self, pos: usize) -> bool { - let bi = pos / BLOCK_BITS; - let local = pos % BLOCK_BITS; - let block = self.ensure_block(bi); - (block[local / 64] >> (local % 64)) & 1 != 0 + let words = self.ensure_words(); + (words[pos / 64] >> (pos % 64)) & 1 != 0 } fn deep_size(&self) -> usize { - let loaded: usize = self - .blocks - .iter() - .filter_map(|b| b.get()) - .map(|w| w.len() * 8) - .sum(); + let loaded = self.words.get().map(|w| w.0.len() * 8).unwrap_or(0); self.prefix_ranks.len() * 8 + loaded } } @@ -534,10 +803,10 @@ impl std::fmt::Debug for LazyHuffmanWaveletTree { } impl LazyHuffmanWaveletTree { - /// Pre-load all wavelet tree blocks into memory. + /// Pre-load all wavelet node words into memory. async fn load_all(&self) -> Result<()> { for node in &self.nodes { - node.load_all_blocks().await?; + node.prewarm().await?; } Ok(()) } @@ -797,16 +1066,7 @@ impl FMIndex { } fn serialize_huffman_codes(&self) -> Vec { - let mut buf = Vec::new(); - for code in &self.wavelet.codes { - buf.extend_from_slice(&code.bits.to_le_bytes()); - buf.push(code.length); - buf.extend_from_slice(&(code.node_path.len() as u16).to_le_bytes()); - for &nid in &code.node_path { - buf.extend_from_slice(&(nid as u32).to_le_bytes()); - } - } - buf + serialize_huffman_codes(&self.wavelet.codes) } fn deserialize_huffman_codes(data: &[u8]) -> [HuffmanCode; 256] { @@ -834,23 +1094,7 @@ impl FMIndex { } fn serialize_tree_topology(&self) -> Vec { - let mut buf = Vec::new(); - buf.extend_from_slice(&(self.wavelet.children.len() as u32).to_le_bytes()); - for (left, right) in &self.wavelet.children { - for child in [left, right] { - match child { - WaveletChild::Node(id) => { - buf.push(0); - buf.extend_from_slice(&(*id as u32).to_le_bytes()); - } - WaveletChild::Leaf(b) => { - buf.push(1); - buf.extend_from_slice(&(*b as u32).to_le_bytes()); - } - } - } - } - buf + serialize_tree_topology(&self.wavelet.children) } fn deserialize_tree_topology(data: &[u8]) -> Vec<(WaveletChild, WaveletChild)> { @@ -878,10 +1122,7 @@ impl FMIndex { } fn serialize_c_table(&self) -> Vec { - self.c_table - .iter() - .flat_map(|&v| (v as u64).to_le_bytes()) - .collect() + serialize_c_table(&self.c_table) } fn deserialize_c_table(data: &[u8]) -> Vec { @@ -1042,118 +1283,42 @@ impl LazyFMIndex { result } - #[allow(clippy::too_many_arguments)] - async fn from_reader( + /// Rebuild a lazy index from a cached partition skeleton. The wavelet node + /// words remain lazy: they are loaded (through `cache`) from `reader` on + /// first access. No I/O happens here. + fn from_state( + state: &FMIndexPartitionState, reader: Arc, - num_bwt_nodes: usize, - huffman_codes: [HuffmanCode; 256], - children: Vec<(WaveletChild, WaveletChild)>, - c_table: Vec, - bwt_len: usize, - total_wavelet_rows: usize, - num_sa_blocks: usize, - sa_samples_len: usize, - row_ids: Vec, - doc_start_positions: Vec, - ) -> Result { - use arrow_array::UInt64Array; - - let meta = reader - .read_range( - 0..total_wavelet_rows, - Some(&["node_id", "prefix_rank", "bit_len"]), - ) - .await?; - let nid_col = meta - .column_by_name("node_id") - .unwrap() - .as_any() - .downcast_ref::() - .unwrap(); - let pr_col = meta - .column_by_name("prefix_rank") - .unwrap() - .as_any() - .downcast_ref::() - .unwrap(); - let bl_col = meta - .column_by_name("bit_len") - .unwrap() - .as_any() - .downcast_ref::() - .unwrap(); - - struct NM { - prs: Vec, - offset: usize, - blen: usize, - } - let mut nms: Vec = (0..num_bwt_nodes) - .map(|_| NM { - prs: Vec::new(), - offset: 0, - blen: 0, - }) - .collect(); - for row in 0..meta.num_rows() { - let nid = nid_col.value(row) as usize; - if nid >= num_bwt_nodes { - continue; - } - let nm = &mut nms[nid]; - if nm.prs.is_empty() { - nm.offset = row; - } - nm.prs.push(pr_col.value(row)); - nm.blen = bl_col.value(row) as usize; - } - - let mut bwt_nodes = Vec::with_capacity(num_bwt_nodes); - for nm in &nms { + cache: LanceCache, + ) -> Self { + let mut bwt_nodes = Vec::with_capacity(state.nodes.len()); + let mut block_row_offset = 0usize; + for (node_id, nm) in state.nodes.iter().enumerate() { + let num_blocks = nm.prefix_ranks.len(); bwt_nodes.push(LazyRankBitVec::new( - nm.prs.clone(), - nm.prs.len(), + nm.prefix_ranks.clone(), + cache.clone(), reader.clone(), - nm.offset, - nm.blen, + node_id as u32, + block_row_offset, + num_blocks, + nm.bit_len, )); + block_row_offset += num_blocks; } let wavelet = LazyHuffmanWaveletTree { nodes: bwt_nodes, - codes: huffman_codes, - children, - len: bwt_len, + codes: state.huffman_codes.clone(), + children: state.children.clone(), + len: state.bwt_len, }; - - // Read SA samples from packed binary blocks - let mut sa_samples = Vec::with_capacity(sa_samples_len); - let sa_batch = reader - .read_range( - total_wavelet_rows..total_wavelet_rows + num_sa_blocks, - Some(&["words"]), - ) - .await?; - let words_col = sa_batch - .column_by_name("words") - .unwrap() - .as_any() - .downcast_ref::() - .unwrap(); - for i in 0..sa_batch.num_rows() { - let raw = words_col.value(i); - for chunk in raw.chunks_exact(8) { - sa_samples.push(u64::from_le_bytes(chunk.try_into().unwrap())); - } - } - sa_samples.truncate(sa_samples_len); - - Ok(Self { + Self { wavelet, - row_ids, - sa_samples, - doc_start_positions, - c_table, - }) + row_ids: state.row_ids.clone(), + sa_samples: state.sa_samples.clone(), + doc_start_positions: state.doc_start_positions.clone(), + c_table: state.c_table.clone(), + } } fn deep_size(&self) -> usize { @@ -1190,76 +1355,30 @@ impl FMIndexScalarIndex { store: &dyn IndexStore, filename: &str, pid: u64, + index_cache: &LanceCache, ) -> Result { let reader = store.open_index_file(filename).await?; - let md = &reader.schema().metadata; - - let parse = |key: &str| -> Result { - md.get(key) - .ok_or_else(|| Error::invalid_input(format!("missing {key}")))? - .parse() - .map_err(|e| Error::invalid_input(format!("invalid {key}: {e}"))) + let cache = index_cache.with_key_prefix(&format!("part-{pid}")); + + // The skeleton (everything but the lazily-loaded wavelet node words) is + // cached so a reload skips the per-partition metadata reads. + let state = { + let reader = reader.clone(); + cache + .get_or_insert_with_key(FMIndexStateKey, || async move { + build_partition_state(reader.as_ref()).await + }) + .await? }; - let num_bwt_nodes = parse("num_bwt_nodes")?; - let bwt_len = parse("bwt_len")?; - let num_sa_blocks = parse("num_sa_blocks")?; - let sa_samples_len = parse("sa_samples_len")?; - let total_wavelet_rows = parse("total_wavelet_rows")?; - - let c_table = FMIndex::deserialize_c_table(&hex_decode( - md.get("c_table") - .ok_or_else(|| Error::invalid_input("missing c_table"))?, - )?); - let huffman_codes = FMIndex::deserialize_huffman_codes(&hex_decode( - md.get("huffman_codes") - .ok_or_else(|| Error::invalid_input("missing huffman_codes"))?, - )?); - let children = FMIndex::deserialize_tree_topology(&hex_decode( - md.get("tree_topology") - .ok_or_else(|| Error::invalid_input("missing tree_topology"))?, - )?); - - // row_ids and doc_start_positions stored in metadata (small) - let row_ids_hex = md - .get("row_ids") - .ok_or_else(|| Error::invalid_input("missing row_ids"))?; - let row_ids_bytes = hex_decode(row_ids_hex)?; - let row_ids: Vec = row_ids_bytes - .chunks_exact(8) - .map(|c| u64::from_le_bytes(c.try_into().unwrap())) - .collect(); - - let doc_starts_hex = md - .get("doc_start_positions") - .ok_or_else(|| Error::invalid_input("missing doc_start_positions"))?; - let doc_starts_bytes = hex_decode(doc_starts_hex)?; - let doc_start_positions: Vec = doc_starts_bytes - .chunks_exact(8) - .map(|c| u64::from_le_bytes(c.try_into().unwrap())) - .collect(); - - let fm = Box::pin(LazyFMIndex::from_reader( - reader, - num_bwt_nodes, - huffman_codes, - children, - c_table, - bwt_len, - total_wavelet_rows, - num_sa_blocks, - sa_samples_len, - row_ids, - doc_start_positions, - )) - .await?; + let fm = LazyFMIndex::from_state(&state, reader, cache); Ok(FMIndexPartition { id: pid, fm }) } async fn load( store: Arc, _fri: Option>, - _cache: &LanceCache, + cache: &LanceCache, ) -> Result> { let files = store.list_files_with_sizes().await?; let mut pfiles: Vec<(u64, String)> = Vec::new(); @@ -1280,13 +1399,124 @@ impl FMIndexScalarIndex { let mut parts = Vec::with_capacity(pfiles.len()); for (id, name) in &pfiles { parts.push(Arc::new( - Self::load_partition(store.as_ref(), name, *id).await?, + Self::load_partition(store.as_ref(), name, *id, cache).await?, )); } Ok(Arc::new(Self { partitions: parts })) } } +/// Read a partition file's skeleton: the metadata blobs plus the per-node prefix +/// ranks and sampled suffix array. The wavelet node words are left out — they +/// are loaded lazily per node. +async fn build_partition_state( + reader: &dyn crate::scalar::IndexReader, +) -> Result { + use arrow_array::UInt64Array; + + let md = &reader.schema().metadata; + let parse = |key: &str| -> Result { + md.get(key) + .ok_or_else(|| Error::invalid_input(format!("missing {key}")))? + .parse() + .map_err(|e| Error::invalid_input(format!("invalid {key}: {e}"))) + }; + let num_bwt_nodes = parse("num_bwt_nodes")?; + let bwt_len = parse("bwt_len")?; + let num_sa_blocks = parse("num_sa_blocks")?; + let sa_samples_len = parse("sa_samples_len")?; + let total_wavelet_rows = parse("total_wavelet_rows")?; + let alphabet_size = parse("alphabet_size").unwrap_or(256); + + let get_blob = |key: &str| -> Result> { + hex_decode( + md.get(key) + .ok_or_else(|| Error::invalid_input(format!("missing {key}")))?, + ) + }; + let c_table = FMIndex::deserialize_c_table(&get_blob("c_table")?); + let huffman_codes = FMIndex::deserialize_huffman_codes(&get_blob("huffman_codes")?); + let children = FMIndex::deserialize_tree_topology(&get_blob("tree_topology")?); + let row_ids = bytes_to_u64(&get_blob("row_ids")?); + let doc_start_positions = bytes_to_u64(&get_blob("doc_start_positions")?); + + // Per-node prefix ranks and bit lengths, in node-id order. + let meta = reader + .read_range( + 0..total_wavelet_rows, + Some(&["node_id", "prefix_rank", "bit_len"]), + ) + .await?; + let nid_col = meta + .column_by_name("node_id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let pr_col = meta + .column_by_name("prefix_rank") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let bl_col = meta + .column_by_name("bit_len") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let mut nodes: Vec = (0..num_bwt_nodes) + .map(|_| NodeMeta { + prefix_ranks: Vec::new(), + bit_len: 0, + }) + .collect(); + for row in 0..meta.num_rows() { + let nid = nid_col.value(row) as usize; + if nid >= num_bwt_nodes { + continue; + } + nodes[nid].prefix_ranks.push(pr_col.value(row)); + nodes[nid].bit_len = bl_col.value(row) as usize; + } + + // Sampled suffix array, packed as binary blocks after the wavelet rows. + let mut sa_samples = Vec::with_capacity(sa_samples_len); + let sa_batch = reader + .read_range( + total_wavelet_rows..total_wavelet_rows + num_sa_blocks, + Some(&["words"]), + ) + .await?; + let words_col = sa_batch + .column_by_name("words") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..sa_batch.num_rows() { + sa_samples.extend( + words_col + .value(i) + .chunks_exact(8) + .map(|c| u64::from_le_bytes(c.try_into().unwrap())), + ); + } + sa_samples.truncate(sa_samples_len); + + Ok(FMIndexPartitionState { + huffman_codes, + children, + c_table, + row_ids, + doc_start_positions, + sa_samples, + bwt_len, + alphabet_size, + nodes, + }) +} + #[async_trait] impl Index for FMIndexScalarIndex { fn as_any(&self) -> &dyn std::any::Any { @@ -1978,10 +2208,14 @@ mod tests { .unwrap(); // Load - let part = - FMIndexScalarIndex::load_partition(store.as_ref(), &fmindex_partition_path(0), 0) - .await - .unwrap(); + let part = FMIndexScalarIndex::load_partition( + store.as_ref(), + &fmindex_partition_path(0), + 0, + &LanceCache::no_cache(), + ) + .await + .unwrap(); // Verify search results match let r = part.fm.search(b"hello"); @@ -2213,4 +2447,180 @@ mod tests { assert!(stats["total_bwt_len"].as_u64().unwrap() > 0); }); } + + /// A single wavelet node's words survive a codec round-trip unchanged. + #[test] + fn test_wavelet_node_words_codec_roundtrip() { + use lance_core::cache::CacheDecode; + + let words = vec![0u64, 1, 2, 0xDEAD_BEEF_CAFE_F00D, u64::MAX]; + let codec = CacheCodec::from_impl::(); + let any: Arc = Arc::new(WaveletNodeWords(words.clone())); + let mut buf = Vec::new(); + codec.serialize(&any, &mut buf).unwrap(); + + let decoded = match codec.deserialize(&bytes::Bytes::from(buf)) { + CacheDecode::Hit(any) => any.downcast::().unwrap(), + CacheDecode::Miss(reason) => panic!("unexpected cache miss: {reason:?}"), + }; + assert_eq!(decoded.0, words); + } + + /// An `FMIndexPartitionState` survives a codec round-trip carrying everything + /// needed to rebuild a searchable index. + #[tokio::test(flavor = "multi_thread")] + async fn test_partition_state_codec_roundtrip() { + use lance_core::cache::CacheDecode; + + let texts: Vec<(u64, &[u8])> = vec![ + (10, b"hello world foo"), + (20, b"hello rust bar"), + (30, b"goodbye world baz"), + ]; + let fm = FMIndex::build(&texts).unwrap(); + + let tempdir = tempfile::tempdir().unwrap(); + let index_dir = Path::from_filesystem_path(tempdir.path()).unwrap(); + let store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + index_dir, + Arc::new(LanceCache::no_cache()), + )); + write_fmindex(&fm, store.as_ref(), &fmindex_partition_path(0)) + .await + .unwrap(); + + let reader = store + .open_index_file(&fmindex_partition_path(0)) + .await + .unwrap(); + let state = build_partition_state(reader.as_ref()).await.unwrap(); + let row_ids = state.row_ids.clone(); + let sa_samples = state.sa_samples.clone(); + let bwt_len = state.bwt_len; + let block_counts: Vec = state.nodes.iter().map(|n| n.prefix_ranks.len()).collect(); + + let codec = CacheCodec::from_impl::(); + let any: Arc = Arc::new(state); + let mut buf = Vec::new(); + codec.serialize(&any, &mut buf).unwrap(); + let decoded = match codec.deserialize(&bytes::Bytes::from(buf)) { + CacheDecode::Hit(any) => any.downcast::().unwrap(), + CacheDecode::Miss(reason) => panic!("unexpected cache miss: {reason:?}"), + }; + + assert_eq!(decoded.row_ids, row_ids); + assert_eq!(decoded.sa_samples, sa_samples); + assert_eq!(decoded.bwt_len, bwt_len); + assert_eq!( + decoded + .nodes + .iter() + .map(|n| n.prefix_ranks.len()) + .collect::>(), + block_counts + ); + + // Rebuild from the decoded skeleton and confirm search still works. + let lazy = LazyFMIndex::from_state(&decoded, reader, LanceCache::no_cache()); + lazy.prewarm().await.unwrap(); + let mut hits = lazy.search_row_addrs(b"hello"); + hits.sort_unstable(); + assert_eq!(hits, vec![10, 20]); + let mut hits = lazy.search_row_addrs(b"world"); + hits.sort_unstable(); + assert_eq!(hits, vec![10, 30]); + assert!(lazy.search_row_addrs(b"missing").is_empty()); + } + + /// Loading through a real cache populates the per-partition skeleton and the + /// per-wavelet-node entries, and a second load reuses them with identical + /// results. + #[tokio::test(flavor = "multi_thread")] + async fn test_load_populates_and_reuses_cache() { + async fn search_count(index: &FMIndexScalarIndex, pattern: &str) -> usize { + match index + .search( + &TextQuery::StringContains(pattern.to_string()), + &crate::metrics::NoOpMetricsCollector, + ) + .await + .unwrap() + { + SearchResult::Exact(set) => set.len().unwrap() as usize, + _ => panic!("expected exact result"), + } + } + + let texts: Vec<(u64, Vec)> = (0..20) + .map(|i| (i, format!("doc {i} hello world").into_bytes())) + .collect(); + + let tempdir = tempfile::tempdir().unwrap(); + let index_dir = Path::from_filesystem_path(tempdir.path()).unwrap(); + let store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + index_dir, + Arc::new(LanceCache::no_cache()), + )); + write_partitioned_fmindex(&texts, store.as_ref()) + .await + .unwrap(); + + let cache = LanceCache::with_capacity(64 * 1024 * 1024); + + // First load + search populates the skeleton and wavelet-node entries. + let index = FMIndexScalarIndex::load(store.clone(), None, &cache) + .await + .unwrap(); + assert_eq!(search_count(&index, "hello world").await, 20); + + let part_cache = cache.with_key_prefix("part-0"); + assert!( + part_cache.get_with_key(&FMIndexStateKey).await.is_some(), + "skeleton should be cached after load" + ); + assert!( + part_cache + .get_with_key(&WaveletNodeKey { node_id: 0 }) + .await + .is_some(), + "wavelet node words should be cached after search" + ); + + // Second load reuses the cached entries; results are identical. + let index2 = FMIndexScalarIndex::load(store, None, &cache).await.unwrap(); + assert_eq!(search_count(&index2, "hello world").await, 20); + assert_eq!(search_count(&index2, "doc 7 ").await, 1); + } + + /// An empty index (no documents, no wavelet nodes, no SA blocks) round-trips + /// through the skeleton cache without error. + #[tokio::test(flavor = "multi_thread")] + async fn test_load_empty_index_through_cache() { + let tempdir = tempfile::tempdir().unwrap(); + let index_dir = Path::from_filesystem_path(tempdir.path()).unwrap(); + let store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + index_dir, + Arc::new(LanceCache::no_cache()), + )); + write_partitioned_fmindex(&[], store.as_ref()) + .await + .unwrap(); + + let cache = LanceCache::with_capacity(1024 * 1024); + let index = FMIndexScalarIndex::load(store, None, &cache).await.unwrap(); + match index + .search( + &TextQuery::StringContains("anything".to_string()), + &crate::metrics::NoOpMetricsCollector, + ) + .await + .unwrap() + { + SearchResult::Exact(set) => assert_eq!(set.len(), Some(0)), + _ => panic!("expected exact result"), + } + } }