diff --git a/rust/lance/Cargo.toml b/rust/lance/Cargo.toml index 440c3fb301a..06211c0724a 100644 --- a/rust/lance/Cargo.toml +++ b/rust/lance/Cargo.toml @@ -187,6 +187,10 @@ harness = false name = "scan" harness = false +[[bench]] +name = "count_pushdown" +harness = false + [[bench]] name = "vector_index" harness = false diff --git a/rust/lance/benches/count_pushdown.rs b/rust/lance/benches/count_pushdown.rs new file mode 100644 index 00000000000..4f633d489dc --- /dev/null +++ b/rust/lance/benches/count_pushdown.rs @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Benchmarks for `COUNT(*)` via the scanner aggregate plan (the path the +//! `count_pushdown` rule rewrites into `CountFromMaskExec`). +//! +//! The dataset uses stable row ids, multiple fragments, and scattered +//! cross-fragment deletions, with a BTree scalar index on the filter column. +//! Run on two revisions to compare (e.g. before/after a change to the rule): +//! +//! ```text +//! cargo bench -p lance --bench count_pushdown +//! ``` + +use std::sync::Arc; + +use arrow_array::types::UInt32Type; +use criterion::{Criterion, criterion_group, criterion_main}; +use lance::Dataset; +use lance::dataset::WriteParams; +use lance::index::DatasetIndexExt; +use lance_core::utils::tempfile::TempStrDir; +use lance_datagen::{BatchCount, RowCount, array, gen_batch}; +use lance_index::IndexType; +use lance_index::scalar::ScalarIndexParams; +#[cfg(target_os = "linux")] +use lance_testing::pprof::{Output, PProfProfiler}; + +const ROWS_PER_FRAGMENT: usize = 100_000; +const NUM_FRAGMENTS: usize = 50; +const TOTAL_ROWS: u32 = (ROWS_PER_FRAGMENT * NUM_FRAGMENTS) as u32; // 5,000,000 + +struct Fixture { + _datadir: TempStrDir, + dataset: Arc, +} + +impl Fixture { + async fn open() -> Self { + let datadir = TempStrDir::default(); + // `value` steps 0..TOTAL_ROWS, so `value < k` selects exactly k rows + // (before deletions) and gives precise control over selectivity. + let reader = gen_batch() + .col("value", array::step::()) + .into_reader_rows( + RowCount::from(ROWS_PER_FRAGMENT as u64), + BatchCount::from(NUM_FRAGMENTS as u32), + ); + let mut dataset = Dataset::write( + reader, + datadir.as_str(), + Some(WriteParams { + max_rows_per_file: ROWS_PER_FRAGMENT, + enable_stable_row_ids: true, + ..Default::default() + }), + ) + .await + .unwrap(); + + // Scatter deletions across every fragment (~1%) to exercise the + // deletion mask in stable-id space. + dataset.delete("value % 100 = 0").await.unwrap(); + + dataset + .create_index( + &["value"], + IndexType::BTree, + None, + &ScalarIndexParams::default(), + true, + ) + .await + .unwrap(); + + Self { + _datadir: datadir, + dataset: Arc::new(dataset), + } + } +} + +async fn count_unfiltered(dataset: &Dataset) -> u64 { + dataset.scan().count_rows().await.unwrap() +} + +async fn count_filtered(dataset: &Dataset, filter: &str) -> u64 { + let mut scanner = dataset.scan(); + scanner.filter(filter).unwrap(); + scanner.count_rows().await.unwrap() +} + +fn bench_count(c: &mut Criterion) { + let rt = tokio::runtime::Runtime::new().unwrap(); + let fixture = rt.block_on(Fixture::open()); + let ds = &fixture.dataset; + + c.bench_function("count_unfiltered", |b| { + b.iter(|| rt.block_on(count_unfiltered(ds))) + }); + + // ~1% of rows match. + let filter_1pct = format!("value < {}", TOTAL_ROWS / 100); + c.bench_function("count_filtered_1pct", |b| { + b.iter(|| rt.block_on(count_filtered(ds, &filter_1pct))) + }); + + // ~50% of rows match. + let filter_50pct = format!("value < {}", TOTAL_ROWS / 2); + c.bench_function("count_filtered_50pct", |b| { + b.iter(|| rt.block_on(count_filtered(ds, &filter_50pct))) + }); +} + +#[cfg(target_os = "linux")] +criterion_group!( + name = benches; + config = Criterion::default().significance_level(0.1).sample_size(10) + .with_profiler(PProfProfiler::new(100, Output::Flamegraph(None))); + targets = bench_count); + +#[cfg(not(target_os = "linux"))] +criterion_group!( + name = benches; + config = Criterion::default().significance_level(0.1).sample_size(10); + targets = bench_count); + +criterion_main!(benches); diff --git a/rust/lance/src/dataset/tests/dataset_aggregate.rs b/rust/lance/src/dataset/tests/dataset_aggregate.rs index d10a3e42769..8d45cda98e2 100644 --- a/rust/lance/src/dataset/tests/dataset_aggregate.rs +++ b/rust/lance/src/dataset/tests/dataset_aggregate.rs @@ -1301,6 +1301,66 @@ async fn test_scanner_count_rows_with_indexed_filter() { ); } +#[tokio::test] +async fn test_scanner_count_rows_with_indexed_filter_stable_row_ids() { + // Indexed-filter count under stable row ids, with deletions in both + // fragments. The rule fires and the cross-fragment count stays correct. + let tmp = tempdir().unwrap(); + let uri = tmp.path().to_str().unwrap(); + let mut ds = gen_batch() + .col("x", array::step::()) + .col("y", array::step_custom::(0, 2)) + .col("category", array::cycle::(vec![1, 2, 3])) + .into_dataset_with_params( + uri, + FragmentCount::from(2), + FragmentRowCount::from(50), + Some(crate::dataset::WriteParams { + max_rows_per_file: 50, + enable_stable_row_ids: true, + ..Default::default() + }), + ) + .await + .unwrap(); + ds.create_index( + &["x"], + IndexType::BTree, + None, + &ScalarIndexParams::default(), + true, + ) + .await + .unwrap(); + // Delete one row from each fragment (x=10 in frag 0, x=70 in frag 1). + ds.delete("x = 10 OR x = 70").await.unwrap(); + + let mut scanner = ds.scan(); + scanner.filter("x < 100").unwrap(); + scanner + .aggregate(AggregateExpr::builder().count_star().build()) + .unwrap(); + let plan = scanner.create_plan().await.unwrap(); + + assert_plan_node_equals( + plan.clone(), + "AggregateExec: mode=Final, gby=[], aggr=[count(Int32(1))] + CountFromMask + ScalarIndexQuery: query=[x < 100]@x_idx(BTree)", + ) + .await + .unwrap(); + + let stream = execute_plan(plan, LanceExecutionOptions::default()).unwrap(); + let batches: Vec = stream.try_collect().await.unwrap(); + assert_eq!(batches.len(), 1); + // 100 rows match `x < 100`, minus the two deletions. + assert_eq!( + batches[0].column(0).as_primitive::().value(0), + 98, + ); +} + #[tokio::test] async fn test_scanner_count_rows_with_partial_index_coverage() { // Index covers the first two fragments, then a third fragment is diff --git a/rust/lance/src/io/exec/count_from_mask.rs b/rust/lance/src/io/exec/count_from_mask.rs index df0478ce208..0b7aeb11111 100644 --- a/rust/lance/src/io/exec/count_from_mask.rs +++ b/rust/lance/src/io/exec/count_from_mask.rs @@ -31,7 +31,7 @@ use datafusion::physical_plan::{ }; use datafusion_physical_expr::EquivalenceProperties; use datafusion_physical_expr::aggregate::AggregateFunctionExpr; -use futures::{StreamExt, TryStreamExt}; +use futures::{StreamExt, TryStreamExt, stream}; use lance_core::{Error, Result}; use lance_select::{RowAddrMask, RowAddrSelection, RowAddrTreeMap}; use lance_table::format::Fragment; @@ -40,6 +40,7 @@ use tracing::instrument; use super::utils::InstrumentedRecordBatchStreamAdapter; use crate::Dataset; +use crate::dataset::rowids::load_row_id_sequences; use crate::index::prefilter::DatasetPreFilter; /// An execution node that computes a `COUNT(*)`-style aggregate from an @@ -234,36 +235,16 @@ impl CountFromMaskExec { Ok(count) } - #[instrument(name = "count_from_mask", skip_all, level = "debug")] - async fn do_execute( - dataset: Arc, - aggregate_funcs_len: usize, - prefilter_input: Option>, - restrict_to_fragments: Option, - context: Arc, - schema: SchemaRef, - ) -> Result { - let prefilter = match prefilter_input { - None => None, - Some(input) => Some(Self::load_prefilter(input, context.clone()).await?), - }; - - // Anchor the deletion mask against either every dataset fragment or - // the caller-supplied restricted subset. - let dataset_fragments: RoaringBitmap = - dataset.fragments().iter().map(|f| f.id as u32).collect(); - let fragments_covered = match restrict_to_fragments { - Some(restrict) => dataset_fragments & restrict, - None => dataset_fragments, - }; - - // Build the fragments allow list as concrete `[0..physical_rows)` - // ranges rather than `Full` markers. `Full` interacts poorly with - // `BlockList` subtraction — `RowAddrTreeMap::Sub` materializes a - // `RoaringBitmap::full()` (2^32 rows) per fragment when a `Full` - // entry gets a partial block subtracted from it, which inflates - // counts and is expensive. Concrete ranges avoid that path entirely - // and keep `len()` exact at every combine step. + /// Row-address-space fragments-allow list: concrete `[0..physical_rows)` + /// ranges per covered fragment. + /// + /// Concrete ranges, not `Full` markers: subtracting a `BlockList` from a + /// `Full` entry materializes a `RoaringBitmap::full()` (2^32) per fragment, + /// which is slow and throws off `len()`. + fn address_fragments_allow( + dataset: &Dataset, + fragments_covered: &RoaringBitmap, + ) -> Result { let frag_map: HashMap = dataset .fragments() .iter() @@ -287,16 +268,118 @@ impl CountFromMaskExec { bitmap.insert_range(0u32..(physical as u32)); fragments_allow.insert_bitmap(frag_id, bitmap); } + Ok(fragments_allow) + } + + /// Live (non-deleted) row count of the covered fragments, from fragment + /// metadata. Used for an unfiltered count: no prefilter to intersect, so no + /// need to build the stable-id universe. + async fn count_live_rows(dataset: &Dataset, fragments_covered: &RoaringBitmap) -> Result { + let frags = dataset + .get_fragments() + .into_iter() + .filter(|f| fragments_covered.contains(f.id() as u32)); + let counts = stream::iter(frags) + .map(|f| async move { f.count_rows(None).await }) + .buffer_unordered(dataset.object_store.as_ref().io_parallelism()) + .try_collect::>() + .await?; + Ok(counts.iter().sum::() as i64) + } + + /// Count universe in stable-id space: live stable row ids whose current home + /// is in `fragments_covered`. Staying in stable-id space lets it intersect + /// the index prefilter directly; deletions are already folded in, so the + /// caller passes no separate deletion mask. + async fn stable_id_universe( + dataset: &Arc, + fragments_covered: RoaringBitmap, + ) -> Result { + // create_restricted_deletion_mask gives a live-id allow list restricted + // to `fragments_covered`. It returns None only with no deletions and full + // coverage — then the universe is every stable id, loaded below. + if let Some(fut) = DatasetPreFilter::create_restricted_deletion_mask( + dataset.clone(), + fragments_covered.clone(), + ) { + let mask = fut.await?; + return mask.allow_list().cloned().ok_or_else(|| { + Error::internal( + "CountFromMaskExec: stable-row-id deletion mask must be an AllowList" + .to_string(), + ) + }); + } + Self::load_stable_id_universe(dataset, &fragments_covered).await + } + + /// Every stable row id in the covered fragments, from their row-id sequences + /// (metadata, not column data). Only used with no deletions and full coverage. + async fn load_stable_id_universe( + dataset: &Dataset, + fragments_covered: &RoaringBitmap, + ) -> Result { + let frags: Vec = dataset + .fragments() + .iter() + .filter(|f| fragments_covered.contains(f.id as u32)) + .cloned() + .collect(); + let mut sequences = load_row_id_sequences(dataset, &frags); + let mut universe = RowAddrTreeMap::new(); + while let Some((_frag_id, sequence)) = sequences.try_next().await? { + universe |= RowAddrTreeMap::from(sequence.as_ref()); + } + Ok(universe) + } + + #[instrument(name = "count_from_mask", skip_all, level = "debug")] + async fn do_execute( + dataset: Arc, + aggregate_funcs_len: usize, + prefilter_input: Option>, + restrict_to_fragments: Option, + context: Arc, + schema: SchemaRef, + ) -> Result { + let prefilter = match prefilter_input { + None => None, + Some(input) => Some(Self::load_prefilter(input, context.clone()).await?), + }; - // Load the deletion mask for the covered fragments. - let deletion_mask = - match DatasetPreFilter::create_deletion_mask(dataset.clone(), fragments_covered) { - Some(fut) => Some(fut.await?), - None => None, - }; + // Anchor the deletion mask against either every dataset fragment or + // the caller-supplied restricted subset. + let dataset_fragments: RoaringBitmap = + dataset.fragments().iter().map(|f| f.id as u32).collect(); + let fragments_covered = match restrict_to_fragments { + Some(restrict) => dataset_fragments & restrict, + None => dataset_fragments, + }; - let combined = Self::combine_masks(fragments_allow, prefilter, deletion_mask); - let count = Self::count_from_mask(&combined, dataset.as_ref())?; + // Under stable row ids the prefilter and deletion masks are in stable-id + // space, so the universe must be too (see `stable_id_universe`); the + // default path builds it in row-address space. + let count = if dataset.manifest.uses_stable_row_ids() { + match prefilter { + // No prefilter: just the live row count, from metadata. + None => Self::count_live_rows(&dataset, &fragments_covered).await?, + Some(prefilter) => { + let universe = Self::stable_id_universe(&dataset, fragments_covered).await?; + let combined = Self::combine_masks(universe, Some(prefilter), None); + Self::count_from_mask(&combined, dataset.as_ref())? + } + } + } else { + let fragments_allow = Self::address_fragments_allow(&dataset, &fragments_covered)?; + // Load the deletion mask for the covered fragments. + let deletion_mask = + match DatasetPreFilter::create_deletion_mask(dataset.clone(), fragments_covered) { + Some(fut) => Some(fut.await?), + None => None, + }; + let combined = Self::combine_masks(fragments_allow, prefilter, deletion_mask); + Self::count_from_mask(&combined, dataset.as_ref())? + }; // Every aggregate is the same non-distinct COUNT shape — emit the // count once per output column. diff --git a/rust/lance/src/io/exec/count_pushdown.rs b/rust/lance/src/io/exec/count_pushdown.rs index 3a3f442aa3e..75a9725b593 100644 --- a/rust/lance/src/io/exec/count_pushdown.rs +++ b/rust/lance/src/io/exec/count_pushdown.rs @@ -146,22 +146,9 @@ fn try_rewrite(agg: &AggregateExec) -> DFResult>> return Ok(None); }; - // Stable-row-id mode: `DatasetPreFilter::create_deletion_mask` produces - // an AllowList in stable-id space, but `CountFromMaskExec` builds its - // fragments-allow list in row-address space. ANDing across the two - // yields a silently wrong count (rows in fragments > 0 are dropped - // because their stable ids and row addresses share a fragment-id bucket - // only by accident). Until the exec can reconcile the two id spaces, - // refuse to fire — but warn so we notice the lost optimization - // opportunity. - if filtered_read.dataset().manifest().uses_stable_row_ids() { - warn!( - "count_pushdown: skipped because the dataset uses stable row ids; \ - the count will be computed via a full scan. Reconciling the two id spaces \ - would let this query be answered from index metadata." - ); - return Ok(None); - } + // Stable-row-id mode is handled inside `CountFromMaskExec::do_execute`, + // which builds the count universe in stable-id space (matching the + // prefilter and deletion masks) rather than row-address space. let options = filtered_read.options(); // A refine filter is a residual the index couldn't fully evaluate — it @@ -668,7 +655,8 @@ mod tests { } #[tokio::test] - async fn rule_skips_with_stable_row_ids() { + async fn rule_fires_with_stable_row_ids() { + // Unfiltered count, stable row ids, with a deletion. use crate::dataset::WriteParams; let tmp = TempStrDir::default(); let mut dataset = gen_batch() @@ -692,8 +680,58 @@ mod tests { let (plan, count) = run_count(&mut scanner).await; assert_eq!(count, 19); assert!( - !plan_contains_pushdown(&plan), - "rule must not fire under stable row IDs, got plan: {}", + plan_contains_pushdown(&plan), + "rule should fire under stable row IDs, got plan: {}", + displayable(plan.as_ref()).indent(true) + ); + } + + #[tokio::test] + async fn rule_fires_with_stable_row_ids_and_filter() { + // Indexed filter, stable row ids, deletions spread across fragments -- + // the case the pre-fix code got wrong (dropped rows in fragments > 0). + use crate::dataset::WriteParams; + let tmp = TempStrDir::default(); + let mut dataset = gen_batch() + .col("ordered", lance_datagen::array::step::()) + .into_dataset_with_params( + tmp.as_str(), + FragmentCount::from(3), + FragmentRowCount::from(10), + Some(WriteParams { + max_rows_per_file: 10, + enable_stable_row_ids: true, + ..Default::default() + }), + ) + .await + .unwrap(); + dataset + .create_index( + &["ordered"], + IndexType::BTree, + None, + &ScalarIndexParams::default(), + true, + ) + .await + .unwrap(); + // Delete one row from fragment 1 and one from fragment 2. + dataset + .delete("ordered = 15 OR ordered = 25") + .await + .unwrap(); + let dataset = Arc::new(dataset); + + let mut scanner = dataset.scan(); + // Matches every row across all three fragments; with the two deletions + // the live count is 28. + scanner.filter("ordered >= 0").unwrap(); + let (plan, count) = run_count(&mut scanner).await; + assert_eq!(count, 28); + assert!( + plan_contains_pushdown(&plan), + "rule should fire under stable row IDs with a filter, got plan: {}", displayable(plan.as_ref()).indent(true) ); }