Skip to content

Commit 5058e72

Browse files
committed
docs
1 parent 5bda8eb commit 5058e72

1 file changed

Lines changed: 17 additions & 9 deletions

File tree

recursive_diff/core.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -596,16 +596,18 @@ def _fast_dask_nonzero(mask: Array) -> tuple[tuple[Array, ...], Array]:
596596
"""Variant of da.nonzero(mask), which is much faster when the number of
597597
nonzero elements is much smaller than the total.
598598
599-
Returns
599+
Returns:
600600
601-
- tuple of single-chunk arrays of shape (mask.ndim, number of differences),
602-
ordered as it would be returned by da.nonzero(mask)
603-
- single-chunk array of shape (number of differences, ) which is to be used
604-
by _fast_dask_mask to reorder the output.
601+
- tuple of arrays of shape (nan, ), one array per axis, one point per nonzero
602+
element, just like da.nonzero(mask)
603+
- matching array of shape (nan, ) which is to be used by _fast_dask_mask to reorder
604+
the output.
605605
"""
606606
import dask
607607
import dask.array as da
608608

609+
# 1. Apply np.nonzero() to each chunk independently and add the
610+
# coordinates of the top-left corner of the chunk to the output
609611
chunk_offsets: list[list[int]] = [
610612
[0, *np.cumsum(c[:-1]).tolist()] for c in mask.chunks
611613
]
@@ -617,20 +619,22 @@ def _fast_dask_nonzero(mask: Array) -> tuple[tuple[Array, ...], Array]:
617619
itertools.product(*chunk_offsets),
618620
)
619621
]
622+
# 2. rechunk to a single chunk (needed for sorting)
620623
rechunked = dask.delayed(np.concatenate, pure=True)(delayeds, axis=1)
621624
nz = da.from_delayed(
622625
rechunked,
623626
shape=(mask.ndim, math.nan),
624627
dtype=int,
625628
meta=np.array([[]], dtype=int),
626629
)
630+
# 3. Get the order in which np.nonzero() would have returned the output
627631
sort_indices = nz[::-1, :].map_blocks(
628632
np.lexsort,
629633
dtype=int,
630634
meta=np.array([], dtype=int),
631635
drop_axis=0,
632636
)
633-
637+
# 4. Reorder
634638
nz_sorted = nz.T.map_blocks(
635639
operator.getitem,
636640
sort_indices,
@@ -648,14 +652,15 @@ def _fast_dask_nonzero_chunk(
648652

649653

650654
def _fast_dask_mask(a: Array, mask: Array, sort_indices: Array) -> Array:
651-
"""Variant of a[mask], which does not preserve the order of the returned elements,
652-
which is much faster on Dask for 2+ dimensions arrays because it does not need
653-
rechunnking. Applying this function to multiple identically shaped **and chunked**
655+
"""Variant of a[mask], which is much faster when the number of
656+
True points in the mask is much smaller than the total.
657+
Applying this function to multiple identically shaped **and chunked**
654658
arrays with the same mask will return objects in the same order.
655659
"""
656660
import dask
657661
import dask.array as da
658662

663+
# 1. Apply a[mask] to each chunk independelty
659664
f = dask.delayed(operator.getitem, pure=True)
660665
delayeds = [
661666
f(a_i, mask_i)
@@ -664,7 +669,10 @@ def _fast_dask_mask(a: Array, mask: Array, sort_indices: Array) -> Array:
664669
mask.to_delayed().reshape(-1),
665670
)
666671
]
672+
# 2. rechunk to a single chunk (needed by a[b], where a has shape=(nan, )
673+
# and b is an integer array
667674
rechunked = dask.delayed(np.concatenate, pure=True)(delayeds)
675+
# 3. Sort the results to match a[mask]
668676
sorted = dask.delayed(operator.getitem, pure=True)(rechunked, sort_indices)
669677
return da.from_delayed(
670678
sorted,

0 commit comments

Comments
 (0)