From de133306f80246c00566350e7165be05dd853810 Mon Sep 17 00:00:00 2001 From: xushiyan <2701446+xushiyan@users.noreply.github.com> Date: Tue, 16 Jun 2026 15:39:15 -0700 Subject: [PATCH] fix(python): handle empty filtered shards in ShardedBatchSampler ShardedBatchSampler._shard_scan crashed on the filtered path instead of yielding an empty stream in two cases: - Global zero-match: nothing accumulates, so the final pa.Table.from_batches([]) raised ValueError. - Empty per-rank shard: a sparse filter with world_size > 1 leaves some ranks with no rows after round-robin, so batch.take([]) built a null-typed index array and raised ArrowNotImplementedError. Compute take indices against the full batch (range(rows_to_skip, n, N)) and carry rows_to_skip from the full batch length so the round-robin offset stays correct when a filtered scan splits matches across many small batches/fragments; skip batches where this rank owns no row; and make the final flush conditional. Empty shards now yield an empty iterator, matching the unfiltered _sample_all contract. Adds regression coverage for the empty-shard and zero-match cases (parametrized over randomize) and for cross-fragment carryover. --- python/python/lance/sampler.py | 37 ++++++++++---- python/python/tests/test_sampler.py | 78 +++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+), 10 deletions(-) diff --git a/python/python/lance/sampler.py b/python/python/lance/sampler.py index b7e7230dfc6..40ef16a1f02 100644 --- a/python/python/lance/sampler.py +++ b/python/python/lance/sampler.py @@ -452,12 +452,23 @@ def _shard_scan( filter=filter, scan_in_order=True, ).to_batches(): - batch = batch.slice(rows_to_skip, batch.num_rows - rows_to_skip) - # Take every Nth row - indices = list(range(0, batch.num_rows, self._world_size)) - rows_to_skip = ( - self._world_size - (batch.num_rows % self._world_size) - ) % self._world_size + # Take this rank's rows by their global round-robin position. 
The + # scan is in order, so concatenating every batch reproduces the full + # filtered row sequence; this rank owns positions rank, rank + N, + # rank + 2N, ... (N = world size). `rows_to_skip` is the offset of + # this rank's next row within the current batch. + indices = list(range(rows_to_skip, batch.num_rows, self._world_size)) + # Carry the offset into the next batch from the *full* batch length, + # not a sliced view. A filtered scan can split matches across many + # small batches/fragments, and computing the carry from a truncated + # batch would desync the round-robin and mis-assign rows. + rows_to_skip = (rows_to_skip - batch.num_rows) % self._world_size + # This rank may own no row in this batch (sparse matches with + # world_size > 1). Skip it: batch.take([]) infers a null-typed index + # array and PyArrow has no array_take(int64, null) kernel, so an empty + # take raises instead of yielding nothing. + if not indices: + continue batch = batch.take(indices) # Add to our collection @@ -481,10 +492,16 @@ def _shard_scan( accumulated_batches.append(big_batch) # deliver any batches left over, they will be <= batch # size but that is ok because we are done - last_batch = ( - pa.Table.from_batches(accumulated_batches).combine_chunks().to_batches()[0] - ) - yield last_batch + # A filter matching zero rows (or a rank that receives no rows after + # round-robin) must yield an empty stream, not crash. The unfiltered + # _sample_all path already handles emptiness; this mirrors that contract. 
+ if accumulated_batches: + last_batch = ( + pa.Table.from_batches(accumulated_batches) + .combine_chunks() + .to_batches()[0] + ) + yield last_batch def _sample_filtered( self, diff --git a/python/python/tests/test_sampler.py b/python/python/tests/test_sampler.py index 5136587f154..ca786c29c56 100644 --- a/python/python/tests/test_sampler.py +++ b/python/python/tests/test_sampler.py @@ -180,6 +180,84 @@ def test_filtered_data_handling(sample_dataset): assert all(id_val % 4 == 0 for id_val in all_ids), "Should keep rank 0 shard" +def _filtered_ids(sampler, ds, filter): + batches = list(sampler(ds, batch_size=1, columns=["id"], filter=filter)) + return [row for batch in batches for row in batch.column("id").to_pylist()] + + +@pytest.mark.parametrize("randomize", [False, True]) +def test_sharded_batch_sampler_empty_filtered_shard(tmp_path, randomize): + """Empty filtered shards must yield an empty stream, not crash. + + ShardedBatchSampler._shard_scan has two empty cases on the filtered path: + + 1. Global zero-match: the filter matches no rows, so nothing accumulates and + the final pa.Table.from_batches([]) raises ValueError. + 2. Empty per-rank shard: a sparse filter with world_size > 1 leaves some + ranks with zero rows after round-robin, so batch.take([]) builds a + null-typed index array and raises ArrowNotImplementedError. + + A rank with no matching rows must produce an empty iterator, matching the + unfiltered (_sample_all) path's contract. + """ + uri = str(tmp_path / "ds") + lance.write_dataset( + pa.table({"id": list(range(6)), "seg": [f"s{i % 2}" for i in range(6)]}), + uri, + ) + ds = lance.dataset(uri) + + # 3 rows match seg='s1' (ids 1,3,5); rank 3 of 4 gets none. + sampler = ShardedBatchSampler(rank=3, world_size=4, randomize=randomize) + assert _filtered_ids(sampler, ds, "seg IN ('s1')") == [] + + # Filter matching nothing -> empty stream, no crash. 
+ sampler = ShardedBatchSampler(rank=0, world_size=1, randomize=randomize) + assert _filtered_ids(sampler, ds, "seg IN ('nope')") == [] + + # Non-empty shard unchanged: rank 0 of 4 from [1,3,5] -> [1]. + sampler = ShardedBatchSampler(rank=0, world_size=4, randomize=randomize) + assert _filtered_ids(sampler, ds, "seg IN ('s1')") == [1] + + +def test_sharded_batch_sampler_filtered_carryover_across_fragments(tmp_path): + """Round-robin offset must carry correctly across small filtered batches. + + When a filtered scan splits matches across several small batches/fragments, + a rank's offset has to carry from one batch to the next. If the carry is + computed from a sliced (truncated) batch instead of the full batch length, + ranks desync and rows are dropped or double-assigned. Here every row matches + and each fragment holds a single row, so the per-rank shards must partition + the global sequence exactly. + """ + uri = str(tmp_path / "ds") + num_rows = 8 + for i in range(num_rows): + lance.write_dataset( + pa.table({"id": [i], "seg": ["s1"]}), + uri, + mode="overwrite" if i == 0 else "append", + ) + ds = lance.dataset(uri) + + world_size = 4 + shards = { + rank: _filtered_ids( + ShardedBatchSampler(rank=rank, world_size=world_size), + ds, + "seg IN ('s1')", + ) + for rank in range(world_size) + } + + # Each rank owns global positions rank, rank + world_size, ...; ids equal + # their global position here, so the shards are a disjoint, complete cover. + for rank in range(world_size): + assert shards[rank] == list(range(rank, num_rows, world_size)) + combined = sorted(i for ids in shards.values() for i in ids) + assert combined == list(range(num_rows)) + + def test_randomization_effect(): """Verify epoch-based randomization behavior.""" # Initialize randomized sampler