Skip to content

Commit 0c26dcc

Browse files
committed
dask.distributed: do not rechunk
1 parent 51a47f6 commit 0c26dcc

1 file changed

Lines changed: 127 additions & 26 deletions

File tree

recursive_diff/core.py

Lines changed: 127 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66

77
from __future__ import annotations
88

9+
import itertools
910
import math
11+
import operator
1012
import re
1113
import typing
1214
from collections.abc import Callable, Collection, Generator, Hashable
@@ -422,10 +424,15 @@ def _diff_dataarrays(
422424
lhs, rhs = xarray.align(lhs, rhs, join="inner")
423425

424426
is_dask = lhs.chunks is not None or rhs.chunks is not None
425-
if is_dask and lhs.chunks is None:
426-
lhs = lhs.chunk(dict(zip(rhs.dims, rhs.chunks))) # type: ignore[arg-type]
427-
elif is_dask and rhs.chunks is None:
428-
rhs = rhs.chunk(dict(zip(lhs.dims, lhs.chunks))) # type: ignore[arg-type]
427+
if is_dask:
428+
import dask.array as da
429+
430+
# Ensure that both lhs and rhs are Dask arrays and that they
431+
# have aligned chunks
432+
lhs_data, rhs_data = da.broadcast_arrays(lhs.data, rhs.data)
433+
lhs = lhs.copy(deep=False, data=lhs_data)
434+
rhs = rhs.copy(deep=False, data=rhs_data)
435+
assert lhs.chunks == rhs.chunks
429436

430437
# Generate a bit-mask of the differences
431438
# For Dask-backed arrays, this operation is delayed.
@@ -477,33 +484,37 @@ def _diff_dataarrays(
477484
# non-brief dim, with potentially repeated indices
478485
# All of the arrays will have the same size, which is the number of differences.
479486
# For Dask-backed arrays, this whole operation is delayed.
487+
diffs_idx: tuple[np.ndarray | Array, ...]
480488

481489
if brief_axes:
482490
diffs_count = mask.astype(int).sum(axis=tuple(brief_axes))
483491
mask = diffs_count > 0
484-
if mask.ndim:
485-
diffs_count = diffs_count[mask]
492+
if is_dask:
493+
assert isinstance(mask, Array)
494+
# a[mask] is very slow in Dask for 2+ dimensional arrays because it needs to
495+
# preserve the order of the returned elements, so it involves rechunking. Under
496+
# the assumption that the number of differences is << the number of total
497+
# elements, filter each chunk independently and then full-sort the results by
498+
# index.
499+
diffs_idx, sort_indices = _fast_dask_nonzero(mask)
500+
if brief_axes:
501+
if mask.ndim:
502+
diffs_count = _fast_dask_mask(diffs_count, mask, sort_indices)
503+
else:
504+
diffs_lhs = _fast_dask_mask(lhs.data, mask, sort_indices)
505+
diffs_rhs = _fast_dask_mask(rhs.data, mask, sort_indices)
486506
else:
487-
diffs_lhs = lhs.data[mask]
488-
diffs_rhs = rhs.data[mask]
489-
490-
diffs_idx = []
491-
for axis, size in enumerate(mask.shape):
492-
idx_shape = (1,) * axis + (-1,) + (1,) * (mask.ndim - axis - 1)
493-
if is_dask:
494-
import dask.array as da
495-
496-
assert isinstance(mask, da.Array)
497-
idx = da.arange(size, chunks=mask.chunks[axis])
498-
idx = idx.reshape(idx_shape)
499-
idx = da.broadcast_to(idx, mask.shape, chunks=mask.chunks)
507+
assert isinstance(mask, (np.ndarray, np.generic))
508+
if brief_axes:
509+
if mask.ndim:
510+
diffs_idx = np.nonzero(mask)
511+
diffs_count = diffs_count[mask]
512+
else:
513+
diffs_idx = ()
500514
else:
501-
idx = np.arange(size)
502-
idx = idx.reshape(idx_shape)
503-
idx = np.broadcast_to(idx, mask.shape)
504-
505-
idx = idx[mask]
506-
diffs_idx.append(idx)
515+
diffs_idx = np.nonzero(mask)
516+
diffs_lhs = lhs.data[mask]
517+
diffs_rhs = rhs.data[mask]
507518

508519
msg_prefix = "".join(f"[{elem}]" for elem in path)
509520

@@ -536,7 +547,7 @@ def _diff_dataarrays(
536547

537548
rel_delta = da.map_blocks(_rel_delta, diffs_lhs, diffs_rhs, dtype=float)
538549
else:
539-
rel_delta = _rel_delta(diffs_lhs, diffs_rhs)
550+
rel_delta = _rel_delta(diffs_lhs, diffs_rhs) # type: ignore[arg-type]
540551
args = (diffs_lhs, diffs_rhs, abs_delta, rel_delta, *diffs_coords)
541552
build_df = partial(
542553
_build_dataframe,
@@ -575,6 +586,96 @@ def _diff_dataarrays(
575586
yield from pp_func(*args)
576587

577588

589+
def _fast_dask_nonzero(mask: Array) -> tuple[tuple[Array, ...], Array]:
    """Variant of da.nonzero(mask), which is much faster when the number of
    nonzero elements is much smaller than the total.

    Plain ``mask.nonzero()`` in Dask must preserve element order across chunks,
    which forces a rechunk; here each chunk is scanned independently and the
    concatenated results are sorted back into index order afterwards.

    Returns:

    - tuple of arrays of shape (nan, ), one array per axis, one point per nonzero
      element, just like da.nonzero(mask)
    - matching array of shape (nan, ) which is to be used by _fast_dask_mask to reorder
      the output.

    (nan in the shapes because the number of matches is unknown until compute time.)
    """
    import dask
    import dask.array as da

    # 1. Apply np.nonzero() to each chunk independently and add the
    # coordinates of the top-left corner of the chunk to the output
    # (cumulative sum of chunk sizes along each axis gives those corners).
    chunk_offsets: list[list[int]] = [
        [0, *np.cumsum(c[:-1]).tolist()] for c in mask.chunks
    ]
    f = dask.delayed(_fast_dask_nonzero_chunk, pure=True)
    # to_delayed().reshape(-1) flattens chunks in row-major order, which matches
    # the iteration order of itertools.product over the per-axis offsets.
    delayeds = [
        f(chunk, chunk_offset)
        for chunk, chunk_offset in zip(
            mask.to_delayed().reshape(-1),
            itertools.product(*chunk_offsets),
        )
    ]
    # 2. rechunk to a single chunk (needed for sorting)
    rechunked = dask.delayed(np.concatenate, pure=True)(delayeds, axis=1)
    nz = da.from_delayed(
        rechunked,
        shape=(mask.ndim, math.nan),
        dtype=int,
        meta=np.array([[]], dtype=int),
    )
    # 3. Get the order in which np.nonzero() would have returned the output.
    # lexsort keys run last-to-first, hence the axis reversal nz[::-1, :].
    sort_indices = nz[::-1, :].map_blocks(
        np.lexsort,
        dtype=int,
        meta=np.array([], dtype=int),
        drop_axis=0,
    )
    # 4. Reorder: fancy-index the (points, ndim) transpose by sort_indices,
    # then transpose back to (ndim, points).
    nz_sorted = nz.T.map_blocks(
        operator.getitem,
        sort_indices,
        dtype=int,
        meta=np.array([[]], dtype=int),
    ).T
    return tuple(nz_sorted), sort_indices
639+
640+
641+
def _fast_dask_nonzero_chunk(
642+
mask_chunk: np.ndarray, offset: tuple[int, ...]
643+
) -> np.ndarray:
644+
nz_indices = np.stack(np.nonzero(mask_chunk))
645+
return nz_indices + np.array(offset)[:, None]
646+
647+
648+
def _fast_dask_mask(a: Array, mask: Array, sort_indices: Array) -> Array:
    """Variant of a[mask], which is much faster when the number of
    True points in the mask is much smaller than the total.

    Applying this function to multiple identically shaped **and chunked**
    arrays with the same mask will return objects in the same order.

    :param a:
        Dask array to filter
    :param mask:
        boolean Dask array; assumed to have the same shape and chunks as ``a``
    :param sort_indices:
        second output of ``_fast_dask_nonzero`` for the same ``mask``, used to
        restore the element order that plain ``a[mask]`` would have produced
    :returns:
        1-dimensional Dask array of unknown length (shape ``(nan,)``) with the
        elements of ``a`` where ``mask`` is True
    """
    import dask
    import dask.array as da

    delayed_getitem = dask.delayed(operator.getitem, pure=True)
    # 1. Apply a[mask] to each chunk independently
    delayeds = [
        delayed_getitem(a_i, mask_i)
        for a_i, mask_i in zip(
            a.to_delayed().reshape(-1),
            mask.to_delayed().reshape(-1),
        )
    ]
    # 2. rechunk to a single chunk (needed by a[b], where a has shape=(nan, )
    # and b is an integer array)
    rechunked = dask.delayed(np.concatenate, pure=True)(delayeds)
    # 3. Sort the results to match a[mask]
    # (renamed from `sorted` so as not to shadow the builtin)
    reordered = delayed_getitem(rechunked, sort_indices)
    return da.from_delayed(
        reordered,
        shape=(math.nan,),
        dtype=a.dtype,
        meta=np.array([], dtype=a.dtype),
    )
677+
678+
578679
def _build_dataframe(
579680
column_names: list[str], index_names: list[str], *args: np.ndarray
580681
) -> pd.DataFrame:

0 commit comments

Comments
 (0)