|
32 | 32 | normalize_axis_tuple, |
33 | 33 | ) |
34 | 34 |
|
35 | | -from legate.core import get_machine |
| 35 | +from cupynumeric.config import CuPyNumericOpCode |
| 36 | +from legate.core import ( |
| 37 | + get_machine, |
| 38 | + get_legate_runtime, |
| 39 | + ReductionOpKind, |
| 40 | + align, |
| 41 | + broadcast, |
| 42 | +) |
36 | 43 |
|
37 | 44 | from .._array.util import add_boilerplate, convert_to_cupynumeric_ndarray |
38 | 45 | from .._module import dot, empty_like, eye, matmul, ndarray |
39 | | -from .._module.creation_shape import zeros, zeros_like |
| 46 | +from .._module.array_rearrange import flip |
| 47 | +from .._module.creation_shape import empty, zeros, zeros_like |
| 48 | +from .._module.creation_matrices import diag |
| 49 | +from .._module.ssc_sorting import argsort |
40 | 50 | from .._ufunc.math import add, sqrt as _sqrt |
41 | 51 | from ._exception import LinAlgError |
42 | 52 |
|
@@ -1571,3 +1581,134 @@ def expm(a: ndarray, method: str = "pade") -> ndarray: |
1571 | 1581 | mdeg, s = expm_func(a[idx], output[idx]) |
1572 | 1582 |
|
1573 | 1583 | return output |
| 1584 | + |
| 1585 | + |
@add_boilerplate("a")
def tssvd(a: ndarray) -> tuple[ndarray, ...]:
    """
    Tall-skinny (TS) Singular Value Decomposition.

    Parameters
    ----------
    a : (M, N) array_like
        Array like, dimension 2.

    Returns
    -------
    u : (M, N) array_like
        Unitary array(s).
    s : (N) array_like
        The singular values, sorted in descending order
    vh : (N, N) array_like
        Unitary array(s).

    Raises
    ------
    ValueError
        If ``a`` is not 2-dimensional or holds at most one element.
    LinAlgError
        If ``a.T.conj() @ a`` has an eigenvalue that is not larger than the
        machine epsilon of ``a.dtype`` (i.e. it is numerically singular),
        in which case the method cannot be applied.

    Notes
    -----
    This routine is only efficient if ``M >> N``. In particular, it assumes
    that an ``(N, N)`` matrix can fit within a single processor memory.

    Implements the algorithm described in [1]_: the Gram matrix
    ``a.T.conj() @ a`` is eigen-decomposed, the singular values are the
    square roots of its eigenvalues, ``vh`` is the conjugate transpose of
    its eigenvectors and ``u`` is recovered as ``a @ v @ inv(diag(s))``.

    Requires ``a.T.conj() @ a`` to not be singular.

    See Also
    --------
    numpy.linalg.svd

    Availability
    ------------
    Multiple GPUs, Multiple CPUs

    References
    ----------
    .. [1] https://stanford.edu/~rezab/classes/cme323/S22/notes/L17/cme323_lec17.pdf
    """
    if a.ndim != 2 or a.size <= 1:
        raise ValueError(f"Invalid input shape for tssvd: {a.shape}")

    m, n = a.shape

    # Gram matrix A^H * A, computed with the dedicated tall-skinny matmul.
    # TODO: replace with a Grammian API once available.
    # NOTE(review): an earlier comment reported that this path could yield
    # a zero matrix; ``matmul(a.transpose().conj(), a)`` is the slower
    # batched fallback if that is still reproducible -- confirm.
    a2 = empty(shape=(n, n), dtype=a.dtype)
    ah = a.transpose().conj()
    a2._thunk.ts_matmul(ah._thunk, a._thunk)

    # Eigenvalues / eigenvectors of A^H * A.
    ew, ev = eigh(a2)

    # The Gram matrix is positive semi-definite, so any eigenvalue that is
    # not clearly positive (including round-off negatives, which would turn
    # into NaN under the square root below) means the method cannot be used.
    if (ew <= np.finfo(a.dtype).eps).any():
        raise LinAlgError("Singular matrix. Method cannot be applied.")

    # Singular values are the square roots of the eigenvalues.
    svals = _sqrt(ew)

    # Index permutation that brings the singular values to the standard
    # (decreasing) order: sort increasingly on a single processor, then
    # reverse. ``flip`` is used because ``[::-1]`` was reported as not
    # implemented for this array.
    with get_machine()[0]:
        d_indices = argsort(svals)
    d_indices = flip(d_indices)

    # V^H, still in the eigen-solver's ordering.
    vt = ev.transpose().conj()

    # U = A * V * inv(S), evaluated as A * (V * inv(S)) so that the large
    # (M, N) operand takes part in a single product.
    Sinv = diag(1.0 / svals)

    B = empty(shape=(n, n), dtype=a.dtype)
    B._thunk.ts_matmul(ev._thunk, Sinv._thunk)

    u = empty(shape=(m, n), dtype=a.dtype)
    u._thunk.ts_matmul(a._thunk, B._thunk)

    # Re-arrange the singular values decreasingly.
    svals = svals[d_indices]

    # Permute the columns of U with the same permutation. This is expressed
    # as a product with a permutation matrix (instead of ``u[:, d_indices]``)
    # and the identity is created in ``a.dtype`` so that ``u`` is not
    # promoted to the default float64.
    u = matmul(u, eye(n, dtype=a.dtype)[d_indices].T)

    # Permute the rows of V^H with the same permutation.
    vt = vt[d_indices]

    return u, svals, vt
0 commit comments