Commit d913fa9

dask.distributed: do not rechunk
1 parent: ccba3dd

3 files changed

Lines changed: 186 additions & 31 deletions

doc/whats-new.rst

Lines changed: 2 additions & 2 deletions
@@ -9,8 +9,8 @@ v2.1.0 (unreleased)
   of text messages for all differences in numpy, pandas, and xarray objects
 - New function :func:`display_diffs` that displays differences
   in Jupyter notebooks
-- Fixed issue that would cause excessive RAM usage when comparing Dask arrays with
-  2+ dimensions using a distributed scheduler
+- Fixed issues that would cause slowdowns and excessive RAM usage when comparing Dask
+  arrays with 2+ dimensions using a distributed scheduler
 - Added support for P2P rechunk in Dask distributed
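A usage illustration only, not part of the commit: the scenario behind the two changelog entries above is comparing differently chunked 2+ dimensional DataArrays under a distributed scheduler. Array sizes, chunk sizes, and the in-process Client are arbitrary choices for this sketch.

import numpy as np
import xarray
from dask.distributed import Client
from recursive_diff import recursive_diff

client = Client(processes=False)  # in-process distributed scheduler, for illustration

a = np.zeros((500, 400))
b = a.copy()
b[3, 7] = 1.0

lhs = xarray.DataArray(a, dims=["x", "y"]).chunk({"x": 100, "y": 100})
rhs = xarray.DataArray(b, dims=["x", "y"]).chunk({"x": 250, "y": 50})  # deliberately different chunks

# Yields a single difference at x=3, y=7; after this commit the comparison
# avoids the rechunking and excessive RAM usage described above.
for diff in recursive_diff(lhs, rhs):
    print(diff)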

recursive_diff/core.py

Lines changed: 119 additions & 26 deletions
@@ -6,7 +6,9 @@

 from __future__ import annotations

+import itertools
 import math
+import operator
 import re
 import typing
 from collections.abc import Callable, Collection, Generator, Hashable
@@ -422,10 +424,15 @@ def _diff_dataarrays(
     lhs, rhs = xarray.align(lhs, rhs, join="inner")

     is_dask = lhs.chunks is not None or rhs.chunks is not None
-    if is_dask and lhs.chunks is None:
-        lhs = lhs.chunk(dict(zip(rhs.dims, rhs.chunks)))  # type: ignore[arg-type]
-    elif is_dask and rhs.chunks is None:
-        rhs = rhs.chunk(dict(zip(lhs.dims, lhs.chunks)))  # type: ignore[arg-type]
+    if is_dask:
+        import dask.array as da
+
+        # Ensure that both lhs and rhs are Dask arrays and that they
+        # have aligned chunks
+        lhs_data, rhs_data = da.broadcast_arrays(lhs.data, rhs.data)
+        lhs = lhs.copy(deep=False, data=lhs_data)
+        rhs = rhs.copy(deep=False, data=rhs_data)
+        assert lhs.chunks == rhs.chunks

     # Generate a bit-mask of the differences
     # For Dask-backed arrays, this operation is delayed.
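For context, a minimal standalone sketch, not from the commit, of the chunk alignment that da.broadcast_arrays performs in the hunk above: a plain NumPy operand is converted to a Dask array and both operands end up with identical chunks, so the bit-mask computed next is chunk-aligned from the start.

import dask.array as da
import numpy as np

lhs = np.arange(12).reshape(3, 4)             # NumPy-backed operand
rhs = da.ones((3, 4), chunks=(2, 2))          # Dask-backed operand

lhs_b, rhs_b = da.broadcast_arrays(lhs, rhs)  # both are now Dask arrays
assert lhs_b.chunks == rhs_b.chunks           # chunks are aligned
mask = lhs_b != rhs_b                         # delayed bit-mask, no rechunking step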
@@ -483,33 +490,37 @@ def _diff_dataarrays(
     # non-brief dim, with potentially repeated indices
     # All of the arrays will have the same size, which is the number of differences.
     # For Dask-backed arrays, this whole operation is delayed.
+    diffs_idx: tuple[np.ndarray | Array, ...]

     if brief_axes:
         diffs_count = mask.astype(int).sum(axis=tuple(brief_axes))
         mask = diffs_count > 0
-        if mask.ndim:
-            diffs_count = diffs_count[mask]
+    if is_dask:
+        assert isinstance(mask, Array)
+        # a[mask] is very slow in Dask for 2+ dimensional arrays because it needs to
+        # preserve the order of the returned elements, so it involves rechunking. Under
+        # the assumption that the number of differences is << the number of total
+        # elements, filter each chunk independently and then full-sort the results by
+        # index.
+        diffs_idx, sort_indices = _fast_dask_nonzero(mask)
+        if brief_axes:
+            if mask.ndim:
+                diffs_count = _fast_dask_mask(diffs_count, mask, sort_indices)
+        else:
+            diffs_lhs = _fast_dask_mask(lhs.data, mask, sort_indices)
+            diffs_rhs = _fast_dask_mask(rhs.data, mask, sort_indices)
     else:
-        diffs_lhs = lhs.data[mask]
-        diffs_rhs = rhs.data[mask]
-
-    diffs_idx = []
-    for axis, size in enumerate(mask.shape):
-        idx_shape = (1,) * axis + (-1,) + (1,) * (mask.ndim - axis - 1)
-        if is_dask:
-            import dask.array as da
-
-            assert isinstance(mask, da.Array)
-            idx = da.arange(size, chunks=mask.chunks[axis])
-            idx = idx.reshape(idx_shape)
-            idx = da.broadcast_to(idx, mask.shape, chunks=mask.chunks)
+        assert isinstance(mask, (np.ndarray, np.generic))
+        if brief_axes:
+            if mask.ndim:
+                diffs_idx = np.nonzero(mask)
+                diffs_count = diffs_count[mask]
+            else:
+                diffs_idx = np.array([], dtype=int)
         else:
-            idx = np.arange(size)
-            idx = idx.reshape(idx_shape)
-            idx = np.broadcast_to(idx, mask.shape)
-
-        idx = idx[mask]
-        diffs_idx.append(idx)
+            diffs_idx = np.nonzero(mask)
+            diffs_lhs = lhs.data[mask]
+            diffs_rhs = rhs.data[mask]

     msg_prefix = "".join(f"[{elem}]" for elem in path)

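The comment above ("filter each chunk independently and then full-sort the results by index") leans on np.lexsort treating its last key as the primary one. A NumPy-only illustration of that reordering step, not part of the commit:

import numpy as np

# (row, col) indices of differences as they arrive, grouped chunk by chunk,
# i.e. not yet in C order
nz = np.array(
    [[0, 2, 1, 0],   # row index of each difference
     [3, 0, 1, 0]],  # column index of each difference
)
# Reversing the key order makes the row indices the primary sort key
order = np.lexsort(nz[::-1, :])
print(nz[:, order])  # [[0 0 1 2]
                     #  [0 3 1 0]]  i.e. (0,0), (0,3), (1,1), (2,0): C order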
@@ -542,7 +553,7 @@ def _diff_dataarrays(

         rel_delta = da.map_blocks(_rel_delta, diffs_lhs, diffs_rhs, dtype=float)
     else:
-        rel_delta = _rel_delta(diffs_lhs, diffs_rhs)
+        rel_delta = _rel_delta(diffs_lhs, diffs_rhs)  # type: ignore[arg-type]
     args = (diffs_lhs, diffs_rhs, abs_delta, rel_delta, *diffs_coords)
     build_df = partial(
         _build_dataframe,
@@ -581,6 +592,88 @@ def _diff_dataarrays(
     yield from pp_func(*args)


+def _fast_dask_nonzero(mask: Array) -> tuple[tuple[Array, ...], Array]:
+    """Variant of da.nonzero(mask), which is much faster when the number of
+    nonzero elements is much smaller than the total.
+
+    Returns
+
+    - tuple of single-chunk arrays of shape (mask.ndim, number of differences),
+      ordered as it would be returned by da.nonzero(mask)
+    - single-chunk array of shape (number of differences, ) which is to be used
+      by _fast_dask_mask to reorder the output.
+    """
+    import dask
+    import dask.array as da
+
+    chunk_offsets: list[list[int]] = [
+        [0, *np.cumsum(c[:-1]).tolist()] for c in mask.chunks
+    ]
+    f = dask.delayed(_fast_dask_nonzero_chunk, pure=True)
+    delayeds = [
+        f(chunk, chunk_offset)
+        for chunk, chunk_offset in zip(
+            mask.to_delayed().reshape(-1),
+            itertools.product(*chunk_offsets),
+        )
+    ]
+    rechunked = dask.delayed(np.concatenate, pure=True)(delayeds, axis=1)
+    nz = da.from_delayed(
+        rechunked,
+        shape=(mask.ndim, math.nan),
+        dtype=int,
+        meta=np.array([[]], dtype=int),
+    )
+    sort_indices = nz[::-1, :].map_blocks(
+        np.lexsort,
+        dtype=int,
+        meta=np.array([], dtype=int),
+        drop_axis=0,
+    )
+
+    nz_sorted = nz.T.map_blocks(
+        operator.getitem,
+        sort_indices,
+        dtype=int,
+        meta=np.array([[]], dtype=int),
+    ).T
+    return tuple(nz_sorted), sort_indices
+
+
+def _fast_dask_nonzero_chunk(
+    mask_chunk: np.ndarray, offset: tuple[int, ...]
+) -> np.ndarray:
+    nz_indices = np.stack(np.nonzero(mask_chunk))
+    return nz_indices + np.array(offset)[:, None]
+
+
+def _fast_dask_mask(a: Array, mask: Array, sort_indices: Array) -> Array:
+    """Variant of a[mask], which does not preserve the order of the returned elements,
+    which is much faster on Dask for 2+ dimensional arrays because it does not need
+    rechunking. Applying this function to multiple identically shaped **and chunked**
+    arrays with the same mask will return objects in the same order.
+    """
+    import dask
+    import dask.array as da
+
+    f = dask.delayed(operator.getitem, pure=True)
+    delayeds = [
+        f(a_i, mask_i)
+        for a_i, mask_i in zip(
+            a.to_delayed().reshape(-1),
+            mask.to_delayed().reshape(-1),
+        )
+    ]
+    rechunked = dask.delayed(np.concatenate, pure=True)(delayeds)
+    sorted = dask.delayed(operator.getitem, pure=True)(rechunked, sort_indices)
+    return da.from_delayed(
+        sorted,
+        shape=(math.nan,),
+        dtype=a.dtype,
+        meta=np.array([], dtype=a.dtype),
+    )
+
+
 def _build_dataframe(
     column_names: list[str], index_names: list[str], *args: np.ndarray
 ) -> pd.DataFrame:

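A small end-to-end sketch, not part of the commit, of how the two helpers added above compose. It assumes they can be imported from recursive_diff.core exactly as defined in this diff; data and chunks are arbitrary example values.

import dask.array as da
import numpy as np

from recursive_diff.core import _fast_dask_mask, _fast_dask_nonzero

data = da.from_array(np.arange(12).reshape(3, 4), chunks=(2, 2))
mask = data % 5 == 0  # True at (0, 0), (1, 1) and (2, 2)

idx, sort_indices = _fast_dask_nonzero(mask)        # per-dimension indices, C order
values = _fast_dask_mask(data, mask, sort_indices)  # masked values, same order

print([i.compute().tolist() for i in idx])  # [[0, 1, 2], [0, 1, 2]]
print(values.compute().tolist())            # [0, 5, 10]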
recursive_diff/tests/test_recursive_diff.py

Lines changed: 65 additions & 3 deletions
@@ -83,9 +83,10 @@ def __repr__(self):
         return f"Square({self.side})"


-def check(lhs, rhs, *expect, **kwargs):
-    expect = sorted(expect)
-    actual = sorted(recursive_diff(lhs, rhs, **kwargs))
+def check(lhs, rhs, *expect, order=False, **kwargs):
+    f = list if order else sorted
+    expect = f(expect)
+    actual = f(recursive_diff(lhs, rhs, **kwargs))
     assert actual == expect

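For clarity, the new order flag only switches the comparison between expected and actual messages from sorted to positional; a tiny illustration, not part of the commit, of what "f = list if order else sorted" amounts to:

expect = ("b", "a")
actual = ["a", "b"]
assert sorted(expect) == sorted(actual)  # order=False (default): accepted
assert list(expect) != actual            # order=True: same messages, wrong order, rejected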
@@ -1165,6 +1166,67 @@ def test_dask_dataarray(chunk_lhs, chunk_rhs):
     check(lhs, rhs, "[data][x=2]: c != d")


+@requires_dask
+@pytest.mark.parametrize(
+    "chunk_lhs,chunk_rhs",
+    [
+        (None, None),
+        (None, -1),
+        (None, 2),
+        ({"x": 3, "y": 1}, {"x": 2, "y": 2}),
+    ],
+)
+def test_dask_dataarray_2d(chunk_lhs, chunk_rhs):
+    lhs = xarray.DataArray([[0, 1, 2], [3, 4, 5]], dims=["x", "y"])
+    rhs = xarray.DataArray([[0, 1, 2], [3, 4, 6]], dims=["x", "y"])
+    if chunk_lhs:
+        lhs = lhs.chunk(chunk_lhs)
+    if chunk_rhs:
+        rhs = rhs.chunk(chunk_rhs)
+
+    check(lhs, rhs, "[data][x=1, y=2]: 5 != 6 (abs: 1.0e+00, rel: 2.0e-01)")
+
+
+def test_dask_dataarray_ordered(chunk):
+    """Test that differences are reported in C order and that the order is not
+    influenced by Dask chunks.
+    """
+    lhs = xarray.DataArray(np.arange(2 * 3 * 4).reshape(2, 3, 4), dims=["x", "y", "z"])
+    rhs = lhs + 1
+    if chunk:
+        lhs = lhs.chunk({"x": 2, "y": 2, "z": 3})
+        rhs = rhs.chunk({"x": 2, "y": 2, "z": 3})
+    check(
+        lhs,
+        rhs,
+        "[data][x=0, y=0, z=0]: 0 != 1 (abs: 1.0e+00, rel: nan)",
+        "[data][x=0, y=0, z=1]: 1 != 2 (abs: 1.0e+00, rel: 1.0e+00)",
+        "[data][x=0, y=0, z=2]: 2 != 3 (abs: 1.0e+00, rel: 5.0e-01)",
+        "[data][x=0, y=0, z=3]: 3 != 4 (abs: 1.0e+00, rel: 3.3e-01)",
+        "[data][x=0, y=1, z=0]: 4 != 5 (abs: 1.0e+00, rel: 2.5e-01)",
+        "[data][x=0, y=1, z=1]: 5 != 6 (abs: 1.0e+00, rel: 2.0e-01)",
+        "[data][x=0, y=1, z=2]: 6 != 7 (abs: 1.0e+00, rel: 1.7e-01)",
+        "[data][x=0, y=1, z=3]: 7 != 8 (abs: 1.0e+00, rel: 1.4e-01)",
+        "[data][x=0, y=2, z=0]: 8 != 9 (abs: 1.0e+00, rel: 1.2e-01)",
+        "[data][x=0, y=2, z=1]: 9 != 10 (abs: 1.0e+00, rel: 1.1e-01)",
+        "[data][x=0, y=2, z=2]: 10 != 11 (abs: 1.0e+00, rel: 1.0e-01)",
+        "[data][x=0, y=2, z=3]: 11 != 12 (abs: 1.0e+00, rel: 9.1e-02)",
+        "[data][x=1, y=0, z=0]: 12 != 13 (abs: 1.0e+00, rel: 8.3e-02)",
+        "[data][x=1, y=0, z=1]: 13 != 14 (abs: 1.0e+00, rel: 7.7e-02)",
+        "[data][x=1, y=0, z=2]: 14 != 15 (abs: 1.0e+00, rel: 7.1e-02)",
+        "[data][x=1, y=0, z=3]: 15 != 16 (abs: 1.0e+00, rel: 6.7e-02)",
+        "[data][x=1, y=1, z=0]: 16 != 17 (abs: 1.0e+00, rel: 6.2e-02)",
+        "[data][x=1, y=1, z=1]: 17 != 18 (abs: 1.0e+00, rel: 5.9e-02)",
+        "[data][x=1, y=1, z=2]: 18 != 19 (abs: 1.0e+00, rel: 5.6e-02)",
+        "[data][x=1, y=1, z=3]: 19 != 20 (abs: 1.0e+00, rel: 5.3e-02)",
+        "[data][x=1, y=2, z=0]: 20 != 21 (abs: 1.0e+00, rel: 5.0e-02)",
+        "[data][x=1, y=2, z=1]: 21 != 22 (abs: 1.0e+00, rel: 4.8e-02)",
+        "[data][x=1, y=2, z=2]: 22 != 23 (abs: 1.0e+00, rel: 4.5e-02)",
+        "[data][x=1, y=2, z=3]: 23 != 24 (abs: 1.0e+00, rel: 4.3e-02)",
+        order=True,
+    )
+
+
 @requires_dask
 def test_dask_dataarray_discards_data():
     """Test that chunked Dask datasets are loaded into memory and then
