|
6 | 6 |
|
7 | 7 | from __future__ import annotations |
8 | 8 |
|
| 9 | +import itertools |
9 | 10 | import math |
| 11 | +import operator |
10 | 12 | import re |
11 | 13 | import typing |
12 | 14 | from collections.abc import Callable, Collection, Generator, Hashable |
@@ -422,10 +424,15 @@ def _diff_dataarrays( |
422 | 424 | lhs, rhs = xarray.align(lhs, rhs, join="inner") |
423 | 425 |
|
424 | 426 | is_dask = lhs.chunks is not None or rhs.chunks is not None |
425 | | - if is_dask and lhs.chunks is None: |
426 | | - lhs = lhs.chunk(dict(zip(rhs.dims, rhs.chunks))) # type: ignore[arg-type] |
427 | | - elif is_dask and rhs.chunks is None: |
428 | | - rhs = rhs.chunk(dict(zip(lhs.dims, lhs.chunks))) # type: ignore[arg-type] |
| 427 | + if is_dask: |
| 428 | + import dask.array as da |
| 429 | + |
| 430 | + # Ensure that both lhs and rhs are Dask arrays and that they |
| 431 | + # have aligned chunks |
| 432 | + lhs_data, rhs_data = da.broadcast_arrays(lhs.data, rhs.data) |
| 433 | + lhs = lhs.copy(deep=False, data=lhs_data) |
| 434 | + rhs = rhs.copy(deep=False, data=rhs_data) |
| 435 | + assert lhs.chunks == rhs.chunks |
429 | 436 |
|
430 | 437 | # Generate a bit-mask of the differences |
431 | 438 | # For Dask-backed arrays, this operation is delayed. |
@@ -483,33 +490,37 @@ def _diff_dataarrays( |
483 | 490 | # non-brief dim, with potentially repeated indices |
484 | 491 | # All of the arrays will have the same size, which is the number of differences. |
485 | 492 | # For Dask-backed arrays, this whole operation is delayed. |
| 493 | + diffs_idx: tuple[np.ndarray | Array, ...] |
486 | 494 |
|
487 | 495 | if brief_axes: |
488 | 496 | diffs_count = mask.astype(int).sum(axis=tuple(brief_axes)) |
489 | 497 | mask = diffs_count > 0 |
490 | | - if mask.ndim: |
491 | | - diffs_count = diffs_count[mask] |
| 498 | + if is_dask: |
| 499 | + assert isinstance(mask, Array) |
| 500 | + # a[mask] is very slow in Dask for 2+ dimensional arrays because it needs to |
| 501 | + # preserve the order of the returned elements, so it involves rechunking. Under |
| 502 | + # the assumption that the number of differences is << the number of total |
| 503 | + # elements, filter each chunk independently and then full-sort the results by |
| 504 | + # index. |
| 505 | + diffs_idx, sort_indices = _fast_dask_nonzero(mask) |
| 506 | + if brief_axes: |
| 507 | + if mask.ndim: |
| 508 | + diffs_count = _fast_dask_mask(diffs_count, mask, sort_indices) |
| 509 | + else: |
| 510 | + diffs_lhs = _fast_dask_mask(lhs.data, mask, sort_indices) |
| 511 | + diffs_rhs = _fast_dask_mask(rhs.data, mask, sort_indices) |
492 | 512 | else: |
493 | | - diffs_lhs = lhs.data[mask] |
494 | | - diffs_rhs = rhs.data[mask] |
495 | | - |
496 | | - diffs_idx = [] |
497 | | - for axis, size in enumerate(mask.shape): |
498 | | - idx_shape = (1,) * axis + (-1,) + (1,) * (mask.ndim - axis - 1) |
499 | | - if is_dask: |
500 | | - import dask.array as da |
501 | | - |
502 | | - assert isinstance(mask, da.Array) |
503 | | - idx = da.arange(size, chunks=mask.chunks[axis]) |
504 | | - idx = idx.reshape(idx_shape) |
505 | | - idx = da.broadcast_to(idx, mask.shape, chunks=mask.chunks) |
| 513 | + assert isinstance(mask, (np.ndarray, np.generic)) |
| 514 | + if brief_axes: |
| 515 | + if mask.ndim: |
| 516 | + diffs_idx = np.nonzero(mask) |
| 517 | + diffs_count = diffs_count[mask] |
| 518 | + else: |
| 519 | + diffs_idx = np.array([], dtype=int) |
506 | 520 | else: |
507 | | - idx = np.arange(size) |
508 | | - idx = idx.reshape(idx_shape) |
509 | | - idx = np.broadcast_to(idx, mask.shape) |
510 | | - |
511 | | - idx = idx[mask] |
512 | | - diffs_idx.append(idx) |
| 521 | + diffs_idx = np.nonzero(mask) |
| 522 | + diffs_lhs = lhs.data[mask] |
| 523 | + diffs_rhs = rhs.data[mask] |
513 | 524 |
|
514 | 525 | msg_prefix = "".join(f"[{elem}]" for elem in path) |
515 | 526 |
|
@@ -542,7 +553,7 @@ def _diff_dataarrays( |
542 | 553 |
|
543 | 554 | rel_delta = da.map_blocks(_rel_delta, diffs_lhs, diffs_rhs, dtype=float) |
544 | 555 | else: |
545 | | - rel_delta = _rel_delta(diffs_lhs, diffs_rhs) |
| 556 | + rel_delta = _rel_delta(diffs_lhs, diffs_rhs) # type: ignore[arg-type] |
546 | 557 | args = (diffs_lhs, diffs_rhs, abs_delta, rel_delta, *diffs_coords) |
547 | 558 | build_df = partial( |
548 | 559 | _build_dataframe, |
@@ -581,6 +592,88 @@ def _diff_dataarrays( |
581 | 592 | yield from pp_func(*args) |
582 | 593 |
|
583 | 594 |
|
def _fast_dask_nonzero(mask: Array) -> tuple[tuple[Array, ...], Array]:
    """Faster replacement for ``da.nonzero(mask)`` for sparse masks.

    Rather than computing the nonzero indices globally (which is slow on Dask
    for 2+ dimensional arrays), run ``np.nonzero`` on every chunk
    independently, shift the chunk-local indices by each chunk's offset, and
    finally sort the concatenated result so that the ordering matches what
    ``da.nonzero`` would produce.

    Returns

    - tuple of single-chunk arrays of shape (mask.ndim, number of differences),
      ordered as it would be returned by da.nonzero(mask)
    - single-chunk array of shape (number of differences, ) which is to be used
      by _fast_dask_mask to reorder the output.
    """
    import dask
    import dask.array as da

    # Global starting index of every chunk along each axis
    offsets_per_axis: list[list[int]] = [
        [0, *np.cumsum(sizes[:-1]).tolist()] for sizes in mask.chunks
    ]
    nonzero_chunk = dask.delayed(_fast_dask_nonzero_chunk, pure=True)
    # mask.to_delayed().reshape(-1) and itertools.product both iterate the
    # chunk grid in C order, so chunks and offsets line up.
    per_chunk = [
        nonzero_chunk(block, block_offset)
        for block, block_offset in zip(
            mask.to_delayed().reshape(-1),
            itertools.product(*offsets_per_axis),
        )
    ]
    concatenated = dask.delayed(np.concatenate, pure=True)(per_chunk, axis=1)
    nz = da.from_delayed(
        concatenated,
        shape=(mask.ndim, math.nan),
        dtype=int,
        meta=np.array([[]], dtype=int),
    )
    # np.lexsort expects the primary sort key last, hence the axis-0 flip
    sort_indices = nz[::-1, :].map_blocks(
        np.lexsort,
        dtype=int,
        meta=np.array([], dtype=int),
        drop_axis=0,
    )

    # Reorder the (n_differences, ndim) rows by the lexsort permutation
    nz_sorted = nz.T.map_blocks(
        operator.getitem,
        sort_indices,
        dtype=int,
        meta=np.array([[]], dtype=int),
    ).T
    return tuple(nz_sorted), sort_indices
| 641 | + |
| 642 | + |
| 643 | +def _fast_dask_nonzero_chunk( |
| 644 | + mask_chunk: np.ndarray, offset: tuple[int, ...] |
| 645 | +) -> np.ndarray: |
| 646 | + nz_indices = np.stack(np.nonzero(mask_chunk)) |
| 647 | + return nz_indices + np.array(offset)[:, None] |
| 648 | + |
| 649 | + |
def _fast_dask_mask(a: Array, mask: Array, sort_indices: Array) -> Array:
    """Variant of ``a[mask]`` that filters each chunk independently.

    Much faster on Dask for 2+ dimensional arrays than plain boolean indexing,
    because it does not need rechunking. The chunk-local results are
    concatenated and then reordered with ``sort_indices`` (as produced by
    ``_fast_dask_nonzero``), so applying this function to multiple identically
    shaped **and chunked** arrays with the same mask will return objects in
    the same order.

    Parameters
    ----------
    a:
        Array to filter; assumed to have the same shape and chunks as ``mask``.
    mask:
        Boolean array selecting the elements to keep.
    sort_indices:
        Permutation array returned by ``_fast_dask_nonzero(mask)``.

    Returns
    -------
    Single-chunk 1D array of the selected elements of ``a``.
    """
    import dask
    import dask.array as da

    getitem = dask.delayed(operator.getitem, pure=True)
    # Filter each chunk independently; embarrassingly parallel.
    filtered_chunks = [
        getitem(a_i, mask_i)
        for a_i, mask_i in zip(
            a.to_delayed().reshape(-1),
            mask.to_delayed().reshape(-1),
        )
    ]
    concatenated = dask.delayed(np.concatenate, pure=True)(filtered_chunks)
    # Renamed from `sorted` to avoid shadowing the builtin of the same name.
    reordered = dask.delayed(operator.getitem, pure=True)(concatenated, sort_indices)
    return da.from_delayed(
        reordered,
        shape=(math.nan,),
        dtype=a.dtype,
        meta=np.array([], dtype=a.dtype),
    )
| 675 | + |
| 676 | + |
584 | 677 | def _build_dataframe( |
585 | 678 | column_names: list[str], index_names: list[str], *args: np.ndarray |
586 | 679 | ) -> pd.DataFrame: |
|
0 commit comments