
Commit 4c96f3d

aschaffer and manopapad authored
Tall-Skinny SVD (TSSVD) (nv-legate#702)
* Preliminary TSSVD impl.
* Replaced diagflat w/ diag.
* Fixes, handling permutations due to svals order, prelim tests.
* More testing, removed fake complexification trickling down from eig().
* Made testing more robust and comprehensive.
* Error reporting on a.shape.
* Error reporting on singular matrix.
* Test with non-singular matrices only.
* Documented the non-singular matrix constraint.
* Added profiling TSSVD example.
* Update cupynumeric/linalg/linalg.py: dox update on input constraints.
* Update cupynumeric/linalg/linalg.py: dox update on reference.
* Fixed some dox mishaps, following GitHub 'Commit Suggestion'.
* Addressed review on moving argsort.
* Addressed review on to-real truncation.
* Addressed review on .conj().
* Addressed review on comment typo.
* Addressed review on raising LinAlgError on a singular matrix.
* Addressed review on multiple GPU/CPU dox.
* Addressed review on using eigh() rather than eig().
* Reverted eigh() to eig(), due to failures in the former.
* Branches off batched vs. unbatched matmul, to benefit tall-skinny.
* Fix for surpassing memory bounds in the unbatched case.
* Cleaned up print() in linalg.py.
* Added batched vs. unbatched matmul test.
* Enlarged shape for TSSVD example.
* More ts_matmul() usage.
* Fix for eager path of ts_matmul().
* More efficient permutation and enforced single-GPU argsort().
* Reverted TSSVD example to avoid OOM in CI.
* Addressed review on removing 'full matrices' note in dox.
* Addressed review on using cupynumeric primitives in examples.
* Addressed review on adding comment and assertion in mapper.cc.
* Addressed review on tailoring example for benchmarking.
* Fixed output allocation.

Co-authored-by: Manolis Papadakis <manopapad@gmail.com>
1 parent 038ffe2 commit 4c96f3d

10 files changed, with 674 additions and 22 deletions
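
For orientation (editorial note, not part of the commit): the method rests on standard SVD algebra. Writing the SVD of a tall-skinny A gives

    A = U \Sigma V^{H}
    \;\Longrightarrow\;
    A^{H} A = V \Sigma^{2} V^{H},
    \qquad
    U = A V \Sigma^{-1},

so the (N, N) Gram matrix A^H A is small enough to eigendecompose on a single processor, after which U follows from tall-skinny products only.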


cupynumeric/_thunk/deferred.py

Lines changed: 49 additions & 0 deletions
@@ -3916,3 +3916,52 @@ def stencil_hint(
         legate_runtime.prefetch_bloated_instances(
             self.base, low_offsets, high_offsets, False
         )
+
+    @auto_convert("rhs1_thunk", "rhs2_thunk")
+    def ts_matmul(
+        self,
+        rhs1_thunk: Any,
+        rhs2_thunk: Any) -> Any:
+
+        lhs_thunk: NumPyThunk = self
+
+        # Clear output array
+        lhs_thunk.fill(np.array(0, dtype=lhs_thunk.dtype))
+        lhs = lhs_thunk.base  # type: ignore
+
+        rhs1 = rhs1_thunk.base
+        rhs2 = rhs2_thunk.base
+
+        m = lhs.shape[0]
+        n = lhs.shape[1]
+        k = rhs1.shape[1]
+        unbatched = 1
+
+        assert m == rhs1.shape[0]
+        assert n == rhs2.shape[1]
+        assert k == rhs2.shape[0]
+        lhs = lhs.promote(1, k)
+        rhs1 = rhs1.promote(2, n)
+        rhs2 = rhs2.promote(0, m)
+
+        task = legate_runtime.create_auto_task(
+            self.library, CuPyNumericOpCode.MATMUL
+        )
+        p_lhs = task.add_reduction(lhs, ReductionOpKind.ADD)
+        p_rhs1 = task.add_input(rhs1)
+        p_rhs2 = task.add_input(rhs2)
+        #
+        # specify unbatched matrix multiplication:
+        #
+        task.add_scalar_arg(unbatched, ty.uint32)
+
+        task.add_constraint(align(p_lhs, p_rhs1))
+        task.add_constraint(align(p_lhs, p_rhs2))
+        #
+        # additional constraints:
+        #
+        # task.add_constraint(broadcast(p_rhs1, (0,)))
+        # task.add_constraint(broadcast(p_rhs2, (1,)))
+        task.add_constraint(broadcast(p_lhs))
+        #
+        task.execute()
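
To make the promote() choreography concrete, here is a minimal NumPy sketch of the (m, k, n) iteration space that ts_matmul() sets up (editorial illustration; the shapes mirror the code above, but the sketch itself is not part of the commit):

    import numpy as np

    # lhs (m, n), rhs1 (m, k) and rhs2 (k, n) are all promoted into one
    # common (m, k, n) space; every point contributes one product term,
    # and the ReductionOpKind.ADD reduction folds the k axis back into lhs.
    m, k, n = 1000, 8, 4
    rhs1 = np.random.rand(m, k)
    rhs2 = np.random.rand(k, n)

    terms = rhs1[:, :, None] * rhs2[None, :, :]  # broadcast to (m, k, n)
    lhs = terms.sum(axis=1)                      # ADD-reduce over k

    assert np.allclose(lhs, rhs1 @ rhs2)

The runtime is then free to partition this shared space across processors and combine the per-partition partial sums through the ADD reduction registered on p_lhs.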

cupynumeric/_thunk/eager.py

Lines changed: 6 additions & 0 deletions
@@ -2128,3 +2128,9 @@ def stencil_hint(
     ) -> None:
         if self.deferred is not None:
             self.deferred.stencil_hint(low_offsets, high_offsets)
+
+    def ts_matmul(
+        self,
+        rhs1_thunk: Any,
+        rhs2_thunk: Any) -> Any:
+        np.matmul(rhs1_thunk.array, rhs2_thunk.array, out=self.array)

cupynumeric/_thunk/thunk.py

Lines changed: 8 additions & 0 deletions
@@ -1651,3 +1651,11 @@ def stencil_hint(
         high_offsets: tuple[int, ...],
     ) -> None:
         ...
+
+    @abstractmethod
+    def ts_matmul(
+        self,
+        rhs1_thunk: Any,
+        rhs2_thunk: Any
+    ) -> Any:
+        ...

cupynumeric/linalg/linalg.py

Lines changed: 143 additions & 2 deletions
@@ -32,11 +32,21 @@
     normalize_axis_tuple,
 )

-from legate.core import get_machine
+from cupynumeric.config import CuPyNumericOpCode
+from legate.core import (
+    get_machine,
+    get_legate_runtime,
+    ReductionOpKind,
+    align,
+    broadcast,
+)

 from .._array.util import add_boilerplate, convert_to_cupynumeric_ndarray
 from .._module import dot, empty_like, eye, matmul, ndarray
-from .._module.creation_shape import zeros, zeros_like
+from .._module.array_rearrange import flip
+from .._module.creation_shape import empty, zeros, zeros_like
+from .._module.creation_matrices import diag
+from .._module.ssc_sorting import argsort
 from .._ufunc.math import add, sqrt as _sqrt
 from ._exception import LinAlgError

@@ -1571,3 +1581,134 @@ def expm(a: ndarray, method: str = "pade") -> ndarray:
         mdeg, s = expm_func(a[idx], output[idx])

     return output
+
+
+@add_boilerplate("a")
+def tssvd(a: ndarray) -> tuple[ndarray, ...]:
+    """
+    Tall-skinny (TS) Singular Value Decomposition.
+
+    Parameters
+    ----------
+    a : (M, N) array_like
+        Array-like, of dimension 2.
+
+    Returns
+    -------
+    u : (M, N) array_like
+        Unitary array(s).
+    s : (N,) array_like
+        The singular values, sorted in descending order.
+    vh : (N, N) array_like
+        Unitary array(s).
+
+    Raises
+    ------
+    LinAlgError
+        If the input matrix is singular or the TS-SVD computation
+        does not converge.
+
+    Notes
+    -----
+    This routine is only efficient if ``M >> N``. In particular, it
+    assumes that an ``(N, N)`` matrix fits within a single processor's
+    memory.
+
+    Implements the algorithm described in [1]_.
+
+    Requires ``a.T @ a`` to be non-singular; equivalently, the input
+    matrix must have full column rank.
+
+    See Also
+    --------
+    numpy.linalg.svd
+
+    Availability
+    ------------
+    Multiple GPUs, Multiple CPUs
+
+    References
+    ----------
+    .. [1] https://stanford.edu/~rezab/classes/cme323/S22/notes/L17/cme323_lec17.pdf
+    """
+    if a.ndim != 2 or a.size <= 1:
+        raise ValueError(f"Invalid input shape for tssvd: {a.shape}")
+
+    m_info = get_machine()
+    num_PEs = m_info.count()
+
+    # A.T*A:
+    #
+    # unbatched way (there's a bug resulting in 0-matrix, it seems):
+    # {
+    m = a.shape[0]
+    n = a.shape[1]
+
+    # TODO: Gramian API:
+    #
+    a2 = empty(shape=(n, n), dtype=a.dtype)
+    ah = a.transpose().conj()
+    a2._thunk.ts_matmul(ah._thunk, a._thunk)
+    # }
+    #
+    # batched way (slower, but passes):
+    #
+    # a2 = matmul(a.transpose().conj(), a)
+
+    # eigen-vals, eigen-vecs of A.T*A:
+    #
+    ew, ev = eigh(a2)
+
+    if any(abs(ew) <= np.finfo(a.dtype).eps):
+        raise LinAlgError("Singular matrix. Method cannot be applied.")
+
+    # svals = map sqrt ew
+    #
+    svals = _sqrt(ew)
+
+    # bring to standard form, i.e., decreasing singular values:
+    #
+    # generate the index permutation, pi,
+    # via sort-by-key, decreasingly:
+    #
+    d_indices = zeros(shape=(n,), dtype=np.int64)
+    with m_info[0]:  # enforce single-PE argsort
+        d_indices = argsort(svals)
+    #
+    # reverse:
+    #
+    # d_indices = d_indices[::-1]  # Error: not implemented
+    d_indices = flip(d_indices)
+
+    # V.T:
+    #
+    vt = ev.transpose().conj()
+
+    reciprocal_svals = 1.0 / svals
+    Sinv = diag(reciprocal_svals)
+
+    # U = A*V*inv(S):
+    #
+    # B = matmul(ev, Sinv)
+    # u = matmul(a, B)
+
+    B = empty(shape=(n, n), dtype=a.dtype)
+    B._thunk.ts_matmul(ev._thunk, Sinv._thunk)
+
+    u = empty(shape=(m, n), dtype=a.dtype)
+    u._thunk.ts_matmul(a._thunk, B._thunk)
+
+    # re-arrange svals decreasingly:
+    #
+    svals = svals[d_indices]
+
+    # permute columns of U with pi:
+    #
+    # u = u[:, d_indices]
+    u = matmul(u, eye(u.shape[1])[d_indices].T)
+
+    # permute rows of V.T with pi:
+    #
+    vt = vt[d_indices]
+
+    return u, svals, vt
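
As a cross-check of tssvd()'s algebra, here is a NumPy-only sketch under the same full-column-rank assumption (editorial reference code, not part of the commit; it uses numpy.linalg.eigh where the distributed code goes through cupynumeric's eigensolver):

    import numpy as np

    rng = np.random.default_rng(0)
    a = rng.standard_normal((1000, 8))     # tall-skinny, full column rank

    # Eigendecomposition of the small (n, n) Gram matrix yields V and S^2:
    ew, ev = np.linalg.eigh(a.conj().T @ a)

    # Sort-by-key, descending, mirroring the argsort + flip above:
    order = np.argsort(np.sqrt(ew))[::-1]
    s = np.sqrt(ew)[order]
    v = ev[:, order]

    # U = A V S^{-1}: only tall-skinny products, never an (M, M) object.
    u = (a @ v) / s

    assert np.allclose((u * s) @ v.conj().T, a)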

examples/tssvd.py

Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
+#!/usr/bin/env python
+
+# Copyright 2025 NVIDIA Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import argparse
+import re
+
+from benchmark import *
+
+import cupynumeric as num
+import numpy as np
+
+
+def check_result(a, u, s, vh):
+    print("Checking result...")
+
+    # (u * s) @ vh
+    a2 = num.matmul(u * s, vh)
+    print("PASS!" if num.allclose(a, a2) else "FAIL!")
+
+
+# make a random, real, full-column-rank (m, n) matrix, with m > n:
+#
+def make_random_matrix(
+    m: int, n: int, scale: float = 10.0,
+    dtype_=np.dtype("float64")) -> np.ndarray:
+    num.random.seed(6174)
+
+    mat = scale * num.random.rand(m, n)
+
+    mat = mat.astype(dtype_)
+
+    # strictly diagonally dominant:
+    #
+    for i in range(n):
+        mat[i, i] = 1.0 + num.sum(num.abs(mat[i, :]))
+
+    return mat
+
+
+def run_tssvd(m, n, perform_check, timing):
+    A = make_random_matrix(m, n)
+
+    timer.start()
+    u, s, vh = num.linalg.tssvd(A)
+    total = timer.stop()
+
+    if perform_check:
+        check_result(A, u, s, vh)
+
+    if timing:
+        print(f"TSSVD elapsed time: {total:.3f} ms")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-t",
+        "--time",
+        dest="timing",
+        action="store_true",
+        help="perform timing",
+    )
+    parser.add_argument(
+        "-m",
+        "--rows",
+        type=int,
+        default=10,
+        dest="m",
+        help="number of rows in the matrix",
+    )
+    parser.add_argument(
+        "-n",
+        "--cols",
+        type=int,
+        default=10,
+        dest="n",
+        help="number of cols in the matrix",
+    )
+    parser.add_argument(
+        "--check",
+        dest="check",
+        action="store_true",
+        help="compare result to numpy",
+    )
+    args, num, timer = parse_args(parser)
+
+    run_benchmark(
+        run_tssvd,
+        args.benchmark,
+        "TSSVD",
+        (
+            args.m,
+            args.n,
+            args.check,
+            args.timing,
+        ),
+    )
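
For reference, an invocation along these lines should drive the example; the legate launcher flags are assumptions about a typical multi-GPU setup, not part of the commit:

    legate --gpus 2 examples/tssvd.py -m 1000000 -n 32 --check --time

Here -m and -n pick a strongly rectangular shape, --check reconstructs (u * s) @ vh and compares it to the input, and --time prints the elapsed milliseconds.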
