Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions rust/lance/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,10 @@ harness = false
name = "scan"
harness = false

[[bench]]
name = "count_pushdown"
harness = false

[[bench]]
name = "vector_index"
harness = false
Expand Down
128 changes: 128 additions & 0 deletions rust/lance/benches/count_pushdown.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

//! Benchmarks for `COUNT(*)` via the scanner aggregate plan (the path the
//! `count_pushdown` rule rewrites into `CountFromMaskExec`).
//!
//! The dataset uses stable row ids, multiple fragments, and scattered
//! cross-fragment deletions, with a BTree scalar index on the filter column.
//! Run on two revisions to compare (e.g. before/after a change to the rule):
//!
//! ```text
//! cargo bench -p lance --bench count_pushdown
//! ```

use std::sync::Arc;

use arrow_array::types::UInt32Type;
use criterion::{Criterion, criterion_group, criterion_main};
use lance::Dataset;
use lance::dataset::WriteParams;
use lance::index::DatasetIndexExt;
use lance_core::utils::tempfile::TempStrDir;
use lance_datagen::{BatchCount, RowCount, array, gen_batch};
use lance_index::IndexType;
use lance_index::scalar::ScalarIndexParams;
#[cfg(target_os = "linux")]
use lance_testing::pprof::{Output, PProfProfiler};

// Rows generated per batch; also passed as `max_rows_per_file` when writing,
// so each data file is capped at this many rows.
const ROWS_PER_FRAGMENT: usize = 100_000;
// Number of batches generated for the benchmark dataset.
const NUM_FRAGMENTS: usize = 50;
// Total rows generated before any deletions; the filter thresholds in the
// benchmarks are derived from this value.
const TOTAL_ROWS: u32 = (ROWS_PER_FRAGMENT * NUM_FRAGMENTS) as u32; // 5,000,000

/// Benchmark fixture: the dataset under test together with the temporary
/// directory it was written into.
struct Fixture {
// Never read; kept so the temporary directory value lives as long as the
// fixture (presumably the directory is cleaned up on drop — TODO confirm).
_datadir: TempStrDir,
// The prepared dataset, shared behind an `Arc`.
dataset: Arc<Dataset>,
}

impl Fixture {
    /// Builds the benchmark dataset in a fresh temporary directory.
    ///
    /// Steps, in order: write `NUM_FRAGMENTS` batches of `ROWS_PER_FRAGMENT`
    /// rows with stable row ids enabled, delete every row whose `value` is a
    /// multiple of 100, then create a BTree scalar index on `value`.
    ///
    /// Panics if any of those steps fails.
    async fn open() -> Self {
        let datadir = TempStrDir::default();

        let write_params = WriteParams {
            max_rows_per_file: ROWS_PER_FRAGMENT,
            enable_stable_row_ids: true,
            ..Default::default()
        };

        // `value` counts up from 0 across the whole dataset, so a predicate
        // `value < k` picks exactly k rows (before deletions). That makes the
        // selectivity of each benchmark filter easy to dial in.
        let batches = gen_batch()
            .col("value", array::step::<UInt32Type>())
            .into_reader_rows(
                RowCount::from(ROWS_PER_FRAGMENT as u64),
                BatchCount::from(NUM_FRAGMENTS as u32),
            );

        let mut dataset = Dataset::write(batches, datadir.as_str(), Some(write_params))
            .await
            .unwrap();

        // Remove roughly 1% of rows, spread over every fragment, so the
        // deletion mask in stable-id space is actually exercised.
        dataset.delete("value % 100 = 0").await.unwrap();

        let index_params = ScalarIndexParams::default();
        dataset
            .create_index(&["value"], IndexType::BTree, None, &index_params, true)
            .await
            .unwrap();

        Self {
            _datadir: datadir,
            dataset: Arc::new(dataset),
        }
    }
}

/// Counts every row visible through a plain scan of `dataset`.
///
/// Panics if the count fails.
async fn count_unfiltered(dataset: &Dataset) -> u64 {
    let total = dataset.scan().count_rows().await;
    total.unwrap()
}

/// Counts the rows of `dataset` that satisfy the SQL predicate `filter`.
///
/// Panics if the predicate is rejected by the scanner or the count fails.
async fn count_filtered(dataset: &Dataset, filter: &str) -> u64 {
    let mut scan = dataset.scan();
    scan.filter(filter).unwrap();
    let matched = scan.count_rows().await;
    matched.unwrap()
}

/// Registers the count benchmarks: one unfiltered count and two filtered
/// counts at different selectivities, all against the same fixture dataset.
fn bench_count(c: &mut Criterion) {
    let runtime = tokio::runtime::Runtime::new().unwrap();
    // The fixture is built once, outside of any measured closure.
    let fixture = runtime.block_on(Fixture::open());
    let dataset = &fixture.dataset;

    c.bench_function("count_unfiltered", |bencher| {
        bencher.iter(|| runtime.block_on(count_unfiltered(dataset)))
    });

    // (benchmark id, exclusive upper bound on `value`):
    // the first case matches ~1% of rows, the second ~50%.
    let filtered_cases = [
        ("count_filtered_1pct", TOTAL_ROWS / 100),
        ("count_filtered_50pct", TOTAL_ROWS / 2),
    ];
    for (bench_id, upper_bound) in filtered_cases {
        let predicate = format!("value < {}", upper_bound);
        c.bench_function(bench_id, |bencher| {
            bencher.iter(|| runtime.block_on(count_filtered(dataset, &predicate)))
        });
    }
}

// Linux build: same Criterion settings as the non-Linux group below, with a
// pprof profiler attached that emits a flamegraph.
// NOTE(review): the `100` passed to `PProfProfiler::new` is presumably the
// sampling frequency — confirm against `lance_testing::pprof`.
#[cfg(target_os = "linux")]
criterion_group!(
name = benches;
config = Criterion::default().significance_level(0.1).sample_size(10)
.with_profiler(PProfProfiler::new(100, Output::Flamegraph(None)));
targets = bench_count);

// Every other platform: identical significance level and sample size, but no
// profiler (the pprof import is only compiled on Linux).
#[cfg(not(target_os = "linux"))]
criterion_group!(
name = benches;
config = Criterion::default().significance_level(0.1).sample_size(10);
targets = bench_count);

// Entry point for the bench binary; runs whichever `benches` group was compiled.
criterion_main!(benches);
60 changes: 60 additions & 0 deletions rust/lance/src/dataset/tests/dataset_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1301,6 +1301,66 @@ async fn test_scanner_count_rows_with_indexed_filter() {
);
}

#[tokio::test]
async fn test_scanner_count_rows_with_indexed_filter_stable_row_ids() {
    // Count through an indexed filter on a dataset that uses stable row ids
    // and has a deleted row in each of its two fragments. The plan must be
    // rewritten to `CountFromMask`, and the count across fragments must still
    // come out right.
    let write_params = crate::dataset::WriteParams {
        max_rows_per_file: 50,
        enable_stable_row_ids: true,
        ..Default::default()
    };
    let tmp_dir = tempdir().unwrap();
    let dataset_uri = tmp_dir.path().to_str().unwrap();
    let mut dataset = gen_batch()
        .col("x", array::step::<Int64Type>())
        .col("y", array::step_custom::<Int64Type>(0, 2))
        .col("category", array::cycle::<Int64Type>(vec![1, 2, 3]))
        .into_dataset_with_params(
            dataset_uri,
            FragmentCount::from(2),
            FragmentRowCount::from(50),
            Some(write_params),
        )
        .await
        .unwrap();

    let index_params = ScalarIndexParams::default();
    dataset
        .create_index(&["x"], IndexType::BTree, None, &index_params, true)
        .await
        .unwrap();

    // One deletion per fragment: x=10 lives in fragment 0, x=70 in fragment 1.
    dataset.delete("x = 10 OR x = 70").await.unwrap();

    let count_star = AggregateExpr::builder().count_star().build();
    let mut scanner = dataset.scan();
    scanner.filter("x < 100").unwrap();
    scanner.aggregate(count_star).unwrap();
    let plan = scanner.create_plan().await.unwrap();

    let expected_plan = "AggregateExec: mode=Final, gby=[], aggr=[count(Int32(1))]
CountFromMask
ScalarIndexQuery: query=[x < 100]@x_idx(BTree)";
    assert_plan_node_equals(plan.clone(), expected_plan)
        .await
        .unwrap();

    let batches: Vec<RecordBatch> = execute_plan(plan, LanceExecutionOptions::default())
        .unwrap()
        .try_collect()
        .await
        .unwrap();
    assert_eq!(batches.len(), 1);

    // All 100 rows satisfy `x < 100`; two of them were deleted.
    let counted = batches[0].column(0).as_primitive::<Int64Type>().value(0);
    assert_eq!(counted, 98);
}

#[tokio::test]
async fn test_scanner_count_rows_with_partial_index_coverage() {
// Index covers the first two fragments, then a third fragment is
Expand Down
161 changes: 122 additions & 39 deletions rust/lance/src/io/exec/count_from_mask.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use datafusion::physical_plan::{
};
use datafusion_physical_expr::EquivalenceProperties;
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
use futures::{StreamExt, TryStreamExt};
use futures::{StreamExt, TryStreamExt, stream};
use lance_core::{Error, Result};
use lance_select::{RowAddrMask, RowAddrSelection, RowAddrTreeMap};
use lance_table::format::Fragment;
Expand All @@ -40,6 +40,7 @@ use tracing::instrument;

use super::utils::InstrumentedRecordBatchStreamAdapter;
use crate::Dataset;
use crate::dataset::rowids::load_row_id_sequences;
use crate::index::prefilter::DatasetPreFilter;

/// An execution node that computes a `COUNT(*)`-style aggregate from an
Expand Down Expand Up @@ -234,36 +235,16 @@ impl CountFromMaskExec {
Ok(count)
}

#[instrument(name = "count_from_mask", skip_all, level = "debug")]
async fn do_execute(
dataset: Arc<Dataset>,
aggregate_funcs_len: usize,
prefilter_input: Option<Arc<dyn ExecutionPlan>>,
restrict_to_fragments: Option<RoaringBitmap>,
context: Arc<datafusion::execution::context::TaskContext>,
schema: SchemaRef,
) -> Result<RecordBatch> {
let prefilter = match prefilter_input {
None => None,
Some(input) => Some(Self::load_prefilter(input, context.clone()).await?),
};

// Anchor the deletion mask against either every dataset fragment or
// the caller-supplied restricted subset.
let dataset_fragments: RoaringBitmap =
dataset.fragments().iter().map(|f| f.id as u32).collect();
let fragments_covered = match restrict_to_fragments {
Some(restrict) => dataset_fragments & restrict,
None => dataset_fragments,
};

// Build the fragments allow list as concrete `[0..physical_rows)`
// ranges rather than `Full` markers. `Full` interacts poorly with
// `BlockList` subtraction — `RowAddrTreeMap::Sub` materializes a
// `RoaringBitmap::full()` (2^32 rows) per fragment when a `Full`
// entry gets a partial block subtracted from it, which inflates
// counts and is expensive. Concrete ranges avoid that path entirely
// and keep `len()` exact at every combine step.
/// Row-address-space fragments-allow list: concrete `[0..physical_rows)`
/// ranges per covered fragment.
///
/// Concrete ranges, not `Full` markers: subtracting a `BlockList` from a
/// `Full` entry materializes a `RoaringBitmap::full()` (2^32) per fragment,
/// which is slow and throws off `len()`.
fn address_fragments_allow(
dataset: &Dataset,
fragments_covered: &RoaringBitmap,
) -> Result<RowAddrTreeMap> {
let frag_map: HashMap<u32, &Fragment> = dataset
.fragments()
.iter()
Expand All @@ -287,16 +268,118 @@ impl CountFromMaskExec {
bitmap.insert_range(0u32..(physical as u32));
fragments_allow.insert_bitmap(frag_id, bitmap);
}
Ok(fragments_allow)
}

/// Total number of live (non-deleted) rows across the covered fragments,
/// taken from fragment metadata.
///
/// This serves the unfiltered count: with no prefilter to intersect there is
/// no reason to materialize the stable-id universe. Per-fragment counts are
/// fetched concurrently, bounded by the object store's I/O parallelism.
async fn count_live_rows(dataset: &Dataset, fragments_covered: &RoaringBitmap) -> Result<i64> {
    let covered = dataset
        .get_fragments()
        .into_iter()
        .filter(|fragment| fragments_covered.contains(fragment.id() as u32));
    let per_fragment: Vec<usize> = stream::iter(covered)
        .map(|fragment| async move { fragment.count_rows(None).await })
        .buffer_unordered(dataset.object_store.as_ref().io_parallelism())
        .try_collect()
        .await?;
    let live_rows: usize = per_fragment.iter().sum();
    Ok(live_rows as i64)
}

/// Count universe in stable-id space: live stable row ids whose current home
/// is in `fragments_covered`. Staying in stable-id space lets it intersect
/// the index prefilter directly; deletions are already folded in, so the
/// caller passes no separate deletion mask.
async fn stable_id_universe(
dataset: &Arc<Dataset>,
fragments_covered: RoaringBitmap,
) -> Result<RowAddrTreeMap> {
// create_restricted_deletion_mask gives a live-id allow list restricted
// to `fragments_covered`. It returns None only with no deletions and full
// coverage — then the universe is every stable id, loaded below.
if let Some(fut) = DatasetPreFilter::create_restricted_deletion_mask(
dataset.clone(),
fragments_covered.clone(),
) {
let mask = fut.await?;
return mask.allow_list().cloned().ok_or_else(|| {
Error::internal(
"CountFromMaskExec: stable-row-id deletion mask must be an AllowList"
.to_string(),
)
});
}
Self::load_stable_id_universe(dataset, &fragments_covered).await
}

/// Collects every stable row id of the covered fragments by reading their
/// row-id sequences (metadata rather than column data).
///
/// Only reached when there are no deletions and coverage is complete.
async fn load_stable_id_universe(
    dataset: &Dataset,
    fragments_covered: &RoaringBitmap,
) -> Result<RowAddrTreeMap> {
    let covered: Vec<Fragment> = dataset
        .fragments()
        .iter()
        .filter(|fragment| fragments_covered.contains(fragment.id as u32))
        .cloned()
        .collect();

    let mut universe = RowAddrTreeMap::new();
    let mut pending = load_row_id_sequences(dataset, &covered);
    loop {
        // Stop on end of stream; propagate the first load error.
        let Some((_frag_id, sequence)) = pending.try_next().await? else {
            break;
        };
        universe |= RowAddrTreeMap::from(sequence.as_ref());
    }
    Ok(universe)
}

#[instrument(name = "count_from_mask", skip_all, level = "debug")]
async fn do_execute(
dataset: Arc<Dataset>,
aggregate_funcs_len: usize,
prefilter_input: Option<Arc<dyn ExecutionPlan>>,
restrict_to_fragments: Option<RoaringBitmap>,
context: Arc<datafusion::execution::context::TaskContext>,
schema: SchemaRef,
) -> Result<RecordBatch> {
let prefilter = match prefilter_input {
None => None,
Some(input) => Some(Self::load_prefilter(input, context.clone()).await?),
};

// Load the deletion mask for the covered fragments.
let deletion_mask =
match DatasetPreFilter::create_deletion_mask(dataset.clone(), fragments_covered) {
Some(fut) => Some(fut.await?),
None => None,
};
// Anchor the deletion mask against either every dataset fragment or
// the caller-supplied restricted subset.
let dataset_fragments: RoaringBitmap =
dataset.fragments().iter().map(|f| f.id as u32).collect();
let fragments_covered = match restrict_to_fragments {
Some(restrict) => dataset_fragments & restrict,
None => dataset_fragments,
};

let combined = Self::combine_masks(fragments_allow, prefilter, deletion_mask);
let count = Self::count_from_mask(&combined, dataset.as_ref())?;
// Under stable row ids the prefilter and deletion masks are in stable-id
// space, so the universe must be too (see `stable_id_universe`); the
// default path builds it in row-address space.
let count = if dataset.manifest.uses_stable_row_ids() {
match prefilter {
// No prefilter: just the live row count, from metadata.
None => Self::count_live_rows(&dataset, &fragments_covered).await?,
Some(prefilter) => {
let universe = Self::stable_id_universe(&dataset, fragments_covered).await?;
let combined = Self::combine_masks(universe, Some(prefilter), None);
Self::count_from_mask(&combined, dataset.as_ref())?
}
}
} else {
let fragments_allow = Self::address_fragments_allow(&dataset, &fragments_covered)?;
// Load the deletion mask for the covered fragments.
let deletion_mask =
match DatasetPreFilter::create_deletion_mask(dataset.clone(), fragments_covered) {
Some(fut) => Some(fut.await?),
None => None,
};
let combined = Self::combine_masks(fragments_allow, prefilter, deletion_mask);
Self::count_from_mask(&combined, dataset.as_ref())?
};

// Every aggregate is the same non-distinct COUNT shape — emit the
// count once per output column.
Expand Down
Loading
Loading