Skip to content

Commit 93c3b4f

Browse files
committed
Manually build hlg
1 parent 0c26dcc commit 93c3b4f

1 file changed

Lines changed: 80 additions & 29 deletions

File tree

recursive_diff/core.py

Lines changed: 80 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -597,30 +597,57 @@ def _fast_dask_nonzero(mask: Array) -> tuple[tuple[Array, ...], Array]:
597597
- matching array of shape (nan, ) which is to be used by _fast_dask_mask to reorder
598598
the output.
599599
"""
600-
import dask
601-
import dask.array as da
600+
from dask.base import tokenize
601+
from dask.core import flatten
602+
from dask.highlevelgraph import HighLevelGraph
603+
604+
try:
605+
from dask.base import List, Task, TaskRef
606+
except ImportError:
607+
List = lambda x: x # type: ignore[misc,assignment] # noqa: E731
608+
Task = lambda _, f, *args: (f, *args) # type: ignore[misc,assignment] # noqa: E731
609+
TaskRef = lambda x: x # type: ignore[misc,assignment] # noqa: E731
602610

603611
# 1. Apply np.nonzero() to each chunk independently and add the
604612
# coordinates of the top-left corner of the chunk to the output
605613
chunk_offsets: list[list[int]] = [
606614
[0, *np.cumsum(c[:-1]).tolist()] for c in mask.chunks
607615
]
608-
f = dask.delayed(_fast_dask_nonzero_chunk, pure=True)
609-
delayeds = [
610-
f(chunk, chunk_offset)
611-
for chunk, chunk_offset in zip(
612-
mask.to_delayed().reshape(-1),
613-
itertools.product(*chunk_offsets),
616+
617+
tok = tokenize(mask)
618+
name1 = f"fast_nonzero-{tok}"
619+
layer1 = {}
620+
for i, (key, offset) in enumerate(
621+
zip(flatten(mask.__dask_keys__()), itertools.product(*chunk_offsets))
622+
):
623+
layer1[name1, 0, i] = Task(
624+
(name1, 0, i), _fast_dask_nonzero_chunk, TaskRef(key), offset
614625
)
615-
]
616-
# 2. rechunk to a single chunk (needed for sorting)
617-
rechunked = dask.delayed(np.concatenate, pure=True)(delayeds, axis=1)
618-
nz = da.from_delayed(
619-
rechunked,
626+
hlg1 = HighLevelGraph.from_collections(name1, layer1, dependencies=[mask]) # type: ignore[list-item]
627+
628+
# 2. Rechunk to a single chunk
629+
name2 = f"fast_nonzero_rechunk-{tok}"
630+
layer2 = {
631+
(name2, 0, 0): Task(
632+
(name2, 0, 0),
633+
partial(np.concatenate, axis=1),
634+
List([TaskRef(key) for key in layer1]),
635+
)
636+
}
637+
hlg2 = HighLevelGraph(
638+
{**hlg1.layers, name2: layer2},
639+
{**hlg1.dependencies, name2: {name1}},
640+
)
641+
642+
nz = Array(
643+
dask=hlg2,
644+
name=name2,
620645
shape=(mask.ndim, math.nan),
646+
chunks=((mask.ndim,), (math.nan,)),
621647
dtype=int,
622648
meta=np.array([[]], dtype=int),
623649
)
650+
624651
# 3. Get the order in which np.nonzero() would have returned the output
625652
sort_indices = nz[::-1, :].map_blocks(
626653
np.lexsort,
@@ -651,31 +678,55 @@ def _fast_dask_mask(a: Array, mask: Array, sort_indices: Array) -> Array:
651678
Applying this function to multiple identically shaped **and chunked**
652679
arrays with the same mask will return objects in the same order.
653680
"""
654-
import dask
655-
import dask.array as da
681+
682+
from dask.base import tokenize
683+
from dask.core import flatten
684+
from dask.highlevelgraph import HighLevelGraph
685+
686+
try:
687+
from dask.base import List, Task, TaskRef
688+
except ImportError:
689+
List = lambda x: x # type: ignore[misc,assignment] # noqa: E731
690+
Task = lambda _, f, *args: (f, *args) # type: ignore[misc,assignment] # noqa: E731
691+
TaskRef = lambda x: x # type: ignore[misc,assignment] # noqa: E731
656692

657693
# 1. Apply a[mask] to each chunk independently
658-
f = dask.delayed(operator.getitem, pure=True)
659-
delayeds = [
660-
f(a_i, mask_i)
661-
for a_i, mask_i in zip(
662-
a.to_delayed().reshape(-1),
663-
mask.to_delayed().reshape(-1),
664-
)
665-
]
666-
# 2. rechunk to a single chunk (needed by a[b], where a has shape=(nan, )
667-
# and b is an integer array
668-
rechunked = dask.delayed(np.concatenate, pure=True)(delayeds)
694+
# The reported shape and chunks are bogus here!
695+
masked = a.map_blocks(operator.getitem, mask, dtype=a.dtype, meta=a._meta) # type: ignore[arg-type]
696+
697+
# 2. rechunk to a single chunk
669698
# 3. Sort the results to match a[mask]
670-
sorted = dask.delayed(operator.getitem, pure=True)(rechunked, sort_indices)
671-
return da.from_delayed(
672-
sorted,
699+
tok = tokenize(masked)
700+
name = f"fast_mask_merge-{tok}"
701+
layer = {
702+
(name, 0): Task(
703+
(name, 0),
704+
_fast_dask_mask_chunk,
705+
List([TaskRef(key) for key in flatten(masked.__dask_keys__())]),
706+
TaskRef(sort_indices.__dask_keys__()[0]),
707+
)
708+
}
709+
hlg = HighLevelGraph.from_collections(
710+
name,
711+
layer,
712+
dependencies=[masked, sort_indices], # type: ignore[list-item]
713+
)
714+
715+
return Array(
716+
dask=hlg,
717+
name=name,
673718
shape=(math.nan,),
719+
chunks=((math.nan,),),
674720
dtype=a.dtype,
675721
meta=np.array([], dtype=a.dtype),
676722
)
677723

678724

725+
def _fast_dask_mask_chunk(chunks: list[np.ndarray], order: np.ndarray) -> np.ndarray:
726+
masked = np.concatenate(chunks)
727+
return masked[order]
728+
729+
679730
def _build_dataframe(
680731
column_names: list[str], index_names: list[str], *args: np.ndarray
681732
) -> pd.DataFrame:

0 commit comments

Comments
 (0)