Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 27 additions & 10 deletions python/python/lance/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,12 +452,23 @@ def _shard_scan(
filter=filter,
scan_in_order=True,
).to_batches():
batch = batch.slice(rows_to_skip, batch.num_rows - rows_to_skip)
# Take every Nth row
indices = list(range(0, batch.num_rows, self._world_size))
rows_to_skip = (
self._world_size - (batch.num_rows % self._world_size)
) % self._world_size
# Take this rank's rows by their global round-robin position. The
# scan is in order, so concatenating every batch reproduces the full
# filtered row sequence; this rank owns positions rank, rank + N,
# rank + 2N, ... (N = world size). `rows_to_skip` is the offset of
# this rank's next row within the current batch.
indices = list(range(rows_to_skip, batch.num_rows, self._world_size))
# Carry the offset into the next batch from the *full* batch length,
# not a sliced view. A filtered scan can split matches across many
# small batches/fragments, and computing the carry from a truncated
# batch would desync the round-robin and mis-assign rows.
rows_to_skip = (rows_to_skip - batch.num_rows) % self._world_size
# This rank may own no row in this batch (sparse matches with
# world_size > 1). Skip it: batch.take([]) infers a null-typed index
# array and PyArrow has no array_take(int64, null) kernel, so an empty
# take raises instead of yielding nothing.
if not indices:
continue
batch = batch.take(indices)

# Add to our collection
Expand All @@ -481,10 +492,16 @@ def _shard_scan(
accumulated_batches.append(big_batch)
# deliver any batches left over, they will be <= batch
# size but that is ok because we are done
last_batch = (
pa.Table.from_batches(accumulated_batches).combine_chunks().to_batches()[0]
)
yield last_batch
# A filter matching zero rows (or a rank that receives no rows after
# round-robin) must yield an empty stream, not crash. The unfiltered
# _sample_all path already handles emptiness; this mirrors that contract.
if accumulated_batches:
last_batch = (
pa.Table.from_batches(accumulated_batches)
.combine_chunks()
.to_batches()[0]
)
yield last_batch

def _sample_filtered(
self,
Expand Down
78 changes: 78 additions & 0 deletions python/python/tests/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,84 @@ def test_filtered_data_handling(sample_dataset):
assert all(id_val % 4 == 0 for id_val in all_ids), "Should keep rank 0 shard"


def _filtered_ids(sampler, ds, filter):
    """Return every ``id`` value ``sampler`` yields for ``ds`` under ``filter``.

    The sampler is invoked with ``batch_size=1`` and only the ``id`` column
    projected. All batches are collected first, then their ``id`` values are
    flattened into a single list in stream order.
    """
    collected = list(sampler(ds, batch_size=1, columns=["id"], filter=filter))
    ids = []
    for batch in collected:
        ids.extend(batch.column("id").to_pylist())
    return ids


@pytest.mark.parametrize("randomize", [False, True])
def test_sharded_batch_sampler_empty_filtered_shard(tmp_path, randomize):
    """A filtered shard that holds no rows yields nothing instead of raising.

    The filtered path of ShardedBatchSampler._shard_scan can end up with no
    rows for a rank in two ways:

    1. Global zero-match: the filter matches nothing, so no batch is ever
       accumulated and the closing pa.Table.from_batches([]) raises
       ValueError.
    2. Empty per-rank shard: a sparse filter combined with world_size > 1
       leaves some ranks without a row after round-robin, so batch.take([])
       builds a null-typed index array and raises ArrowNotImplementedError.

    In both situations the expected result is an empty iterator, the same
    contract as the unfiltered (_sample_all) path. A rank that does own rows
    must still receive them.
    """
    uri = str(tmp_path / "ds")
    row_count = 6
    table = pa.table(
        {
            "id": list(range(row_count)),
            "seg": [f"s{i % 2}" for i in range(row_count)],
        }
    )
    lance.write_dataset(table, uri)
    ds = lance.dataset(uri)

    # (rank, world_size, filter expression, expected ids), checked in order.
    cases = [
        # 3 rows match seg='s1' (ids 1, 3, 5); rank 3 of 4 gets none.
        (3, 4, "seg IN ('s1')", []),
        # Filter matching nothing -> empty stream, no crash.
        (0, 1, "seg IN ('nope')", []),
        # Non-empty shard unchanged: rank 0 of 4 from [1, 3, 5] -> [1].
        (0, 4, "seg IN ('s1')", [1]),
    ]
    for rank, world_size, predicate, expected in cases:
        sampler = ShardedBatchSampler(
            rank=rank, world_size=world_size, randomize=randomize
        )
        assert _filtered_ids(sampler, ds, predicate) == expected


def test_sharded_batch_sampler_filtered_carryover_across_fragments(tmp_path):
    """The round-robin offset survives a scan made of many tiny batches.

    A filtered scan may deliver its matches in several small
    batches/fragments, so each rank has to carry its position from one batch
    into the next. Deriving that carry from a sliced (truncated) batch rather
    than the full batch length desynchronises the ranks, and rows get dropped
    or assigned twice. The dataset here is written one row at a time so that
    each fragment holds a single row, and every row matches the filter; the
    per-rank shards must therefore partition the global sequence exactly.
    """
    uri = str(tmp_path / "ds")
    num_rows = 8
    for i in range(num_rows):
        # The first write creates the dataset; every later one appends a row.
        write_mode = "append" if i else "overwrite"
        single_row = pa.table({"id": [i], "seg": ["s1"]})
        lance.write_dataset(single_row, uri, mode=write_mode)
    ds = lance.dataset(uri)

    world_size = 4
    shards = {}
    for rank in range(world_size):
        sampler = ShardedBatchSampler(rank=rank, world_size=world_size)
        shards[rank] = _filtered_ids(sampler, ds, "seg IN ('s1')")

    # Ids equal their global position, and rank r owns positions
    # r, r + world_size, ...; the shards must be exactly those strides.
    for rank, ids in shards.items():
        assert ids == list(range(rank, num_rows, world_size))

    # Taken together the shards form a disjoint, complete cover of the rows.
    combined = sorted(i for ids in shards.values() for i in ids)
    assert combined == list(range(num_rows))


def test_randomization_effect():
"""Verify epoch-based randomization behavior."""
# Initialize randomized sampler
Expand Down
Loading