@@ -596,16 +596,18 @@ def _fast_dask_nonzero(mask: Array) -> tuple[tuple[Array, ...], Array]:
596596 """Variant of da.nonzero(mask), which is much faster when the number of
597597 nonzero elements is much smaller than the total.
598598
599- Returns
599+ Returns:
600600
601- - tuple of single-chunk arrays of shape (mask.ndim, number of differences),
602- ordered as it would be returned by da.nonzero(mask)
603- - single-chunk array of shape (number of differences , ) which is to be used
604- by _fast_dask_mask to reorder the output.
601+ - tuple of arrays of shape (nan, ), one array per axis, one point per nonzero
602+ element, just like da.nonzero(mask)
603+ - matching array of shape (nan , ) which is to be used by _fast_dask_mask to reorder
604+ the output.
605605 """
606606 import dask
607607 import dask .array as da
608608
609+ # 1. Apply np.nonzero() to each chunk independently and add the
610+ # coordinates of the top-left corner of the chunk to the output
609611 chunk_offsets : list [list [int ]] = [
610612 [0 , * np .cumsum (c [:- 1 ]).tolist ()] for c in mask .chunks
611613 ]
@@ -617,20 +619,22 @@ def _fast_dask_nonzero(mask: Array) -> tuple[tuple[Array, ...], Array]:
617619 itertools .product (* chunk_offsets ),
618620 )
619621 ]
622+ # 2. rechunk to a single chunk (needed for sorting)
620623 rechunked = dask .delayed (np .concatenate , pure = True )(delayeds , axis = 1 )
621624 nz = da .from_delayed (
622625 rechunked ,
623626 shape = (mask .ndim , math .nan ),
624627 dtype = int ,
625628 meta = np .array ([[]], dtype = int ),
626629 )
630+ # 3. Get the order in which np.nonzero() would have returned the output
627631 sort_indices = nz [::- 1 , :].map_blocks (
628632 np .lexsort ,
629633 dtype = int ,
630634 meta = np .array ([], dtype = int ),
631635 drop_axis = 0 ,
632636 )
633-
637+ # 4. Reorder
634638 nz_sorted = nz .T .map_blocks (
635639 operator .getitem ,
636640 sort_indices ,
@@ -648,14 +652,15 @@ def _fast_dask_nonzero_chunk(
648652
649653
650654def _fast_dask_mask (a : Array , mask : Array , sort_indices : Array ) -> Array :
651- """Variant of a[mask], which does not preserve the order of the returned elements,
652- which is much faster on Dask for 2+ dimensions arrays because it does not need
653- rechunnking. Applying this function to multiple identically shaped **and chunked**
655+ """Variant of a[mask], which is much faster when the number of
656+ True points in the mask is much smaller than the total.
657+ Applying this function to multiple identically shaped **and chunked**
654658 arrays with the same mask will return objects in the same order.
655659 """
656660 import dask
657661 import dask .array as da
658662
663+ # 1. Apply a[mask] to each chunk independelty
659664 f = dask .delayed (operator .getitem , pure = True )
660665 delayeds = [
661666 f (a_i , mask_i )
@@ -664,7 +669,10 @@ def _fast_dask_mask(a: Array, mask: Array, sort_indices: Array) -> Array:
664669 mask .to_delayed ().reshape (- 1 ),
665670 )
666671 ]
672+ # 2. rechunk to a single chunk (needed by a[b], where a has shape=(nan, )
673+ # and b is an integer array
667674 rechunked = dask .delayed (np .concatenate , pure = True )(delayeds )
675+ # 3. Sort the results to match a[mask]
668676 sorted = dask .delayed (operator .getitem , pure = True )(rechunked , sort_indices )
669677 return da .from_delayed (
670678 sorted ,
0 commit comments