@@ -597,30 +597,57 @@ def _fast_dask_nonzero(mask: Array) -> tuple[tuple[Array, ...], Array]:
597597 - matching array of shape (nan, ) which is to be used by _fast_dask_mask to reorder
598598 the output.
599599 """
600- import dask
601- import dask .array as da
600+ from dask .base import tokenize
601+ from dask .core import flatten
602+ from dask .highlevelgraph import HighLevelGraph
603+
604+ try :
605+ from dask .base import List , Task , TaskRef
606+ except ImportError :
607+ List = lambda x : x # type: ignore[misc,assignment] # noqa: E731
608+ Task = lambda _ , f , * args : (f , * args ) # type: ignore[misc,assignment] # noqa: E731
609+ TaskRef = lambda x : x # type: ignore[misc,assignment] # noqa: E731
602610
603611 # 1. Apply np.nonzero() to each chunk independently and add the
604612 # coordinates of the top-left corner of the chunk to the output
605613 chunk_offsets : list [list [int ]] = [
606614 [0 , * np .cumsum (c [:- 1 ]).tolist ()] for c in mask .chunks
607615 ]
608- f = dask .delayed (_fast_dask_nonzero_chunk , pure = True )
609- delayeds = [
610- f (chunk , chunk_offset )
611- for chunk , chunk_offset in zip (
612- mask .to_delayed ().reshape (- 1 ),
613- itertools .product (* chunk_offsets ),
616+
617+ tok = tokenize (mask )
618+ name1 = f"fast_nonzero-{ tok } "
619+ layer1 = {}
620+ for i , (key , offset ) in enumerate (
621+ zip (flatten (mask .__dask_keys__ ()), itertools .product (* chunk_offsets ))
622+ ):
623+ layer1 [name1 , 0 , i ] = Task (
624+ (name1 , 0 , i ), _fast_dask_nonzero_chunk , TaskRef (key ), offset
614625 )
615- ]
616- # 2. rechunk to a single chunk (needed for sorting)
617- rechunked = dask .delayed (np .concatenate , pure = True )(delayeds , axis = 1 )
618- nz = da .from_delayed (
619- rechunked ,
626+ hlg1 = HighLevelGraph .from_collections (name1 , layer1 , dependencies = [mask ]) # type: ignore[list-item]
627+
628+ # 2. Rechunk to a single chunk
629+ name2 = f"fast_nonzero_rechunk-{ tok } "
630+ layer2 = {
631+ (name2 , 0 , 0 ): Task (
632+ (name2 , 0 , 0 ),
633+ partial (np .concatenate , axis = 1 ),
634+ List ([TaskRef (key ) for key in layer1 ]),
635+ )
636+ }
637+ hlg2 = HighLevelGraph (
638+ {** hlg1 .layers , name2 : layer2 },
639+ {** hlg1 .dependencies , name2 : {name1 }},
640+ )
641+
642+ nz = Array (
643+ dask = hlg2 ,
644+ name = name2 ,
620645 shape = (mask .ndim , math .nan ),
646+ chunks = ((mask .ndim ,), (math .nan ,)),
621647 dtype = int ,
622648 meta = np .array ([[]], dtype = int ),
623649 )
650+
624651 # 3. Get the order in which np.nonzero() would have returned the output
625652 sort_indices = nz [::- 1 , :].map_blocks (
626653 np .lexsort ,
@@ -651,31 +678,55 @@ def _fast_dask_mask(a: Array, mask: Array, sort_indices: Array) -> Array:
651678 Applying this function to multiple identically shaped **and chunked**
652679 arrays with the same mask will return objects in the same order.
653680 """
654- import dask
655- import dask .array as da
681+
682+ from dask .base import tokenize
683+ from dask .core import flatten
684+ from dask .highlevelgraph import HighLevelGraph
685+
686+ try :
687+ from dask .base import List , Task , TaskRef
688+ except ImportError :
689+ List = lambda x : x # type: ignore[misc,assignment] # noqa: E731
690+ Task = lambda _ , f , * args : (f , * args ) # type: ignore[misc,assignment] # noqa: E731
691+ TaskRef = lambda x : x # type: ignore[misc,assignment] # noqa: E731
656692
657693	# 1. Apply a[mask] to each chunk independently
658- f = dask .delayed (operator .getitem , pure = True )
659- delayeds = [
660- f (a_i , mask_i )
661- for a_i , mask_i in zip (
662- a .to_delayed ().reshape (- 1 ),
663- mask .to_delayed ().reshape (- 1 ),
664- )
665- ]
666- # 2. rechunk to a single chunk (needed by a[b], where a has shape=(nan, )
667- # and b is an integer array
668- rechunked = dask .delayed (np .concatenate , pure = True )(delayeds )
694+ # The reported shape and chunks are bogus here!
695+ masked = a .map_blocks (operator .getitem , mask , dtype = a .dtype , meta = a ._meta ) # type: ignore[arg-type]
696+
697+ # 2. rechunk to a single chunk
669698 # 3. Sort the results to match a[mask]
670- sorted = dask .delayed (operator .getitem , pure = True )(rechunked , sort_indices )
671- return da .from_delayed (
672- sorted ,
699+ tok = tokenize (masked )
700+ name = f"fast_mask_merge-{ tok } "
701+ layer = {
702+ (name , 0 ): Task (
703+ (name , 0 ),
704+ _fast_dask_mask_chunk ,
705+ List ([TaskRef (key ) for key in flatten (masked .__dask_keys__ ())]),
706+ TaskRef (sort_indices .__dask_keys__ ()[0 ]),
707+ )
708+ }
709+ hlg = HighLevelGraph .from_collections (
710+ name ,
711+ layer ,
712+ dependencies = [masked , sort_indices ], # type: ignore[list-item]
713+ )
714+
715+ return Array (
716+ dask = hlg ,
717+ name = name ,
673718 shape = (math .nan ,),
719+ chunks = ((math .nan ,),),
674720 dtype = a .dtype ,
675721 meta = np .array ([], dtype = a .dtype ),
676722 )
677723
678724
725+ def _fast_dask_mask_chunk (chunks : list [np .ndarray ], order : np .ndarray ) -> np .ndarray :
726+ masked = np .concatenate (chunks )
727+ return masked [order ]
728+
729+
679730def _build_dataframe (
680731 column_names : list [str ], index_names : list [str ], * args : np .ndarray
681732) -> pd .DataFrame :
0 commit comments