|
6 | 6 |
|
7 | 7 | from __future__ import annotations |
8 | 8 |
|
| 9 | +import itertools |
9 | 10 | import math |
| 11 | +import operator |
10 | 12 | import re |
11 | 13 | import typing |
12 | 14 | from collections.abc import Callable, Collection, Generator, Hashable |
@@ -422,10 +424,15 @@ def _diff_dataarrays( |
422 | 424 | lhs, rhs = xarray.align(lhs, rhs, join="inner") |
423 | 425 |
|
424 | 426 | is_dask = lhs.chunks is not None or rhs.chunks is not None |
425 | | - if is_dask and lhs.chunks is None: |
426 | | - lhs = lhs.chunk(dict(zip(rhs.dims, rhs.chunks))) # type: ignore[arg-type] |
427 | | - elif is_dask and rhs.chunks is None: |
428 | | - rhs = rhs.chunk(dict(zip(lhs.dims, lhs.chunks))) # type: ignore[arg-type] |
| 427 | + if is_dask: |
| 428 | + import dask.array as da |
| 429 | + |
| 430 | + # Ensure that both lhs and rhs are Dask arrays and that they |
| 431 | + # have aligned chunks |
| 432 | + lhs_data, rhs_data = da.broadcast_arrays(lhs.data, rhs.data) |
| 433 | + lhs = lhs.copy(deep=False, data=lhs_data) |
| 434 | + rhs = rhs.copy(deep=False, data=rhs_data) |
| 435 | + assert lhs.chunks == rhs.chunks |
429 | 436 |
|
430 | 437 | # Generate a bit-mask of the differences |
431 | 438 | # For Dask-backed arrays, this operation is delayed. |
@@ -477,33 +484,37 @@ def _diff_dataarrays( |
477 | 484 | # non-brief dim, with potentially repeated indices |
478 | 485 | # All of the arrays will have the same size, which is the number of differences. |
479 | 486 | # For Dask-backed arrays, this whole operation is delayed. |
| 487 | + diffs_idx: tuple[np.ndarray | Array, ...] |
480 | 488 |
|
481 | 489 | if brief_axes: |
482 | 490 | diffs_count = mask.astype(int).sum(axis=tuple(brief_axes)) |
483 | 491 | mask = diffs_count > 0 |
484 | | - if mask.ndim: |
485 | | - diffs_count = diffs_count[mask] |
| 492 | + if is_dask: |
| 493 | + assert isinstance(mask, Array) |
| 494 | + # a[mask] is very slow in Dask for 2+ dimensional arrays because it needs to |
| 495 | + # preserve the order of the returned elements, so it involves rechunking. Under |
| 496 | + # the assumption that the number of differences is << the number of total |
| 497 | + # elements, filter each chunk independently and then full-sort the results by |
| 498 | + # index. |
| 499 | + diffs_idx, sort_indices = _fast_dask_nonzero(mask) |
| 500 | + if brief_axes: |
| 501 | + if mask.ndim: |
| 502 | + diffs_count = _fast_dask_mask(diffs_count, mask, sort_indices) |
| 503 | + else: |
| 504 | + diffs_lhs = _fast_dask_mask(lhs.data, mask, sort_indices) |
| 505 | + diffs_rhs = _fast_dask_mask(rhs.data, mask, sort_indices) |
486 | 506 | else: |
487 | | - diffs_lhs = lhs.data[mask] |
488 | | - diffs_rhs = rhs.data[mask] |
489 | | - |
490 | | - diffs_idx = [] |
491 | | - for axis, size in enumerate(mask.shape): |
492 | | - idx_shape = (1,) * axis + (-1,) + (1,) * (mask.ndim - axis - 1) |
493 | | - if is_dask: |
494 | | - import dask.array as da |
495 | | - |
496 | | - assert isinstance(mask, da.Array) |
497 | | - idx = da.arange(size, chunks=mask.chunks[axis]) |
498 | | - idx = idx.reshape(idx_shape) |
499 | | - idx = da.broadcast_to(idx, mask.shape, chunks=mask.chunks) |
| 507 | + assert isinstance(mask, (np.ndarray, np.generic)) |
| 508 | + if brief_axes: |
| 509 | + if mask.ndim: |
| 510 | + diffs_idx = np.nonzero(mask) |
| 511 | + diffs_count = diffs_count[mask] |
| 512 | + else: |
| 513 | + diffs_idx = () |
500 | 514 | else: |
501 | | - idx = np.arange(size) |
502 | | - idx = idx.reshape(idx_shape) |
503 | | - idx = np.broadcast_to(idx, mask.shape) |
504 | | - |
505 | | - idx = idx[mask] |
506 | | - diffs_idx.append(idx) |
| 515 | + diffs_idx = np.nonzero(mask) |
| 516 | + diffs_lhs = lhs.data[mask] |
| 517 | + diffs_rhs = rhs.data[mask] |
507 | 518 |
|
508 | 519 | msg_prefix = "".join(f"[{elem}]" for elem in path) |
509 | 520 |
|
@@ -536,7 +547,7 @@ def _diff_dataarrays( |
536 | 547 |
|
537 | 548 | rel_delta = da.map_blocks(_rel_delta, diffs_lhs, diffs_rhs, dtype=float) |
538 | 549 | else: |
539 | | - rel_delta = _rel_delta(diffs_lhs, diffs_rhs) |
| 550 | + rel_delta = _rel_delta(diffs_lhs, diffs_rhs) # type: ignore[arg-type] |
540 | 551 | args = (diffs_lhs, diffs_rhs, abs_delta, rel_delta, *diffs_coords) |
541 | 552 | build_df = partial( |
542 | 553 | _build_dataframe, |
@@ -575,6 +586,96 @@ def _diff_dataarrays( |
575 | 586 | yield from pp_func(*args) |
576 | 587 |
|
577 | 588 |
|
def _fast_dask_nonzero(mask: Array) -> tuple[tuple[Array, ...], Array]:
    """Variant of da.nonzero(mask), which is much faster when the number of
    nonzero elements is much smaller than the total.

    Each chunk of ``mask`` is filtered independently and the per-chunk results
    are concatenated and sorted afterwards, instead of letting Dask rechunk
    the whole array up front to preserve element order.

    Returns:

    - tuple of arrays of shape (nan, ), one array per axis, one point per nonzero
      element, just like da.nonzero(mask)
    - matching array of shape (nan, ) which is to be used by _fast_dask_mask to reorder
      the output.
    """
    import dask
    import dask.array as da

    # 1. Apply np.nonzero() to each chunk independently and add the
    # coordinates of the top-left corner of the chunk to the output.
    # chunk_offsets[axis][i] is the global index of the first element of the
    # i-th chunk along that axis (cumulative sum of preceding chunk sizes).
    chunk_offsets: list[list[int]] = [
        [0, *np.cumsum(c[:-1]).tolist()] for c in mask.chunks
    ]
    f = dask.delayed(_fast_dask_nonzero_chunk, pure=True)
    # NOTE: relies on mask.to_delayed().reshape(-1) enumerating chunks in
    # C (row-major) order, matching itertools.product over the per-axis offsets.
    delayeds = [
        f(chunk, chunk_offset)
        for chunk, chunk_offset in zip(
            mask.to_delayed().reshape(-1),
            itertools.product(*chunk_offsets),
        )
    ]
    # 2. rechunk to a single chunk (needed for sorting); math.nan marks the
    # size of the second axis as unknown until compute time
    rechunked = dask.delayed(np.concatenate, pure=True)(delayeds, axis=1)
    nz = da.from_delayed(
        rechunked,
        shape=(mask.ndim, math.nan),
        dtype=int,
        meta=np.array([[]], dtype=int),
    )
    # 3. Get the order in which np.nonzero() would have returned the output.
    # np.lexsort treats the *last* row as the primary key, so the axes are
    # reversed to sort primarily by axis 0 — i.e. C order, like np.nonzero.
    sort_indices = nz[::-1, :].map_blocks(
        np.lexsort,
        dtype=int,
        meta=np.array([], dtype=int),
        drop_axis=0,
    )
    # 4. Reorder the per-chunk coordinates into np.nonzero() order
    nz_sorted = nz.T.map_blocks(
        operator.getitem,
        sort_indices,
        dtype=int,
        meta=np.array([[]], dtype=int),
    ).T
    return tuple(nz_sorted), sort_indices
| 639 | + |
| 640 | + |
| 641 | +def _fast_dask_nonzero_chunk( |
| 642 | + mask_chunk: np.ndarray, offset: tuple[int, ...] |
| 643 | +) -> np.ndarray: |
| 644 | + nz_indices = np.stack(np.nonzero(mask_chunk)) |
| 645 | + return nz_indices + np.array(offset)[:, None] |
| 646 | + |
| 647 | + |
def _fast_dask_mask(a: Array, mask: Array, sort_indices: Array) -> Array:
    """Variant of a[mask], which is much faster when the number of
    True points in the mask is much smaller than the total.
    Applying this function to multiple identically shaped **and chunked**
    arrays with the same mask will return objects in the same order.

    Parameters:

    - a: Dask array to filter; must have the same shape and chunks as ``mask``
      (each chunk of ``a`` is paired positionally with a chunk of ``mask``).
    - mask: boolean Dask array selecting the elements to keep.
    - sort_indices: reordering array produced by _fast_dask_nonzero for the
      same ``mask``; restores the element order that plain a[mask] yields.

    Returns a 1-dimensional Dask array of unknown size (shape (nan, ))
    holding the selected elements, in the same order as a[mask].
    """
    import dask
    import dask.array as da

    # 1. Apply a[mask] to each chunk independently, which avoids the
    # expensive rechunking that a global a[mask] would trigger
    getitem = dask.delayed(operator.getitem, pure=True)
    filtered_chunks = [
        getitem(a_i, mask_i)
        for a_i, mask_i in zip(
            a.to_delayed().reshape(-1),
            mask.to_delayed().reshape(-1),
        )
    ]
    # 2. rechunk to a single chunk (needed by a[b], where a has shape=(nan, )
    # and b is an integer array)
    rechunked = dask.delayed(np.concatenate, pure=True)(filtered_chunks)
    # 3. Sort the results to match a[mask]
    # (renamed from `sorted`, which shadowed the builtin)
    reordered = dask.delayed(operator.getitem, pure=True)(rechunked, sort_indices)
    return da.from_delayed(
        reordered,
        shape=(math.nan,),
        dtype=a.dtype,
        meta=np.array([], dtype=a.dtype),
    )
| 677 | + |
| 678 | + |
578 | 679 | def _build_dataframe( |
579 | 680 | column_names: list[str], index_names: list[str], *args: np.ndarray |
580 | 681 | ) -> pd.DataFrame: |
|
0 commit comments