fix(python): handle empty filtered shards in ShardedBatchSampler #7366
Draft
xushiyan wants to merge 1 commit into
ShardedBatchSampler._shard_scan crashed on the filtered path instead of yielding an empty stream in two cases:

- Global zero-match: nothing accumulates, so the final pa.Table.from_batches([]) raised ValueError.
- Empty per-rank shard: a sparse filter with world_size > 1 leaves some ranks with no rows after round-robin, so batch.take([]) built a null-typed index array and raised ArrowNotImplementedError.

Compute take indices against the full batch (range(rows_to_skip, n, N)) and carry rows_to_skip from the full batch length so the round-robin offset stays correct when a filtered scan splits matches across many small batches/fragments; skip batches where this rank owns no row; and make the final flush conditional. Empty shards now yield an empty iterator, matching the unfiltered _sample_all contract.

Adds regression coverage for the empty-shard, zero-match, and cross-fragment carryover cases (parametrized over randomize).
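The round-robin carryover described above can be sketched in plain Python, with lists standing in for Arrow batches. This is a minimal illustration under assumptions, not the code in this PR: the function name `shard_round_robin` is hypothetical, and the carry formula `(rows_to_skip - n) % world_size` is one way to derive the next offset from the full batch length.

```python
from typing import Iterable, Iterator, List, Sequence


def shard_round_robin(
    batches: Iterable[Sequence[int]], rank: int, world_size: int
) -> Iterator[List[int]]:
    """Yield this rank's rows per batch, keeping one global round-robin.

    Hypothetical helper for illustration; lists stand in for Arrow batches.
    """
    # Local index (within the next batch) of the first row this rank owns.
    rows_to_skip = rank
    for batch in batches:
        n = len(batch)
        # Indices are computed against the full batch, as in
        # range(rows_to_skip, n, N) from the description above.
        indices = range(rows_to_skip, n, world_size)
        # Carry the offset from the full batch length n (assumed formula),
        # so tiny or empty batches do not reset the rotation.
        rows_to_skip = (rows_to_skip - n) % world_size
        if len(indices) == 0:
            # This rank owns no row here: skip rather than take nothing.
            continue
        yield [batch[i] for i in indices]


# Matches split across small fragments, including an empty one.
fragments = [[0, 1, 2], [3], [], [4, 5, 6, 7]]
per_rank = [list(shard_round_robin(fragments, r, 3)) for r in range(3)]
```

With three ranks, each rank receives the rows whose global position is congruent to its rank modulo 3, regardless of how the fragments are split. A rank that owns nothing simply gets an empty iterator.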
Summary
- Fix `ShardedBatchSampler._shard_scan` so filtered scans with zero matches or empty per-rank shards yield an empty stream instead of raising.
- Carry `rows_to_skip` from the full batch length, not a sliced view.
- Skip `batch.take([])` when a rank owns no rows in a batch (PyArrow has no `array_take(int64, null)` kernel).

Problem

When `ShardedBatchSampler` is called with a `filter`, the filtered path can crash in two cases:

- Global zero-match: `pa.Table.from_batches([])` raises `ValueError`.
- A sparse filter with `world_size > 1` leaves some ranks with zero rows -> `batch.take([])` raises `ArrowNotImplementedError`.

The unfiltered path already handles emptiness; this aligns the filtered path with that contract.
Test plan
- `test_sharded_batch_sampler_empty_filtered_shard` (randomize on/off)
- `test_sharded_batch_sampler_filtered_carryover_across_fragments`
- `uv run pytest python/tests/test_sampler.py` (full sampler suite)

Found during distributed training work with sparse segment filters.