Skip to content

Commit 4f90b05

Browse files
authored
Adding implementation for np.linalg.inv (#1673)
1 parent 781b0c0 commit 4f90b05

2 files changed

Lines changed: 176 additions & 0 deletions

File tree

cupynumeric/linalg/linalg.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from .._array.util import add_boilerplate, convert_to_cupynumeric_ndarray
2828
from .._module import dot, empty_like, eye, matmul, ndarray
2929
from .._module.array_rearrange import flip
30+
from .._module.array_dimension import broadcast_to
3031
from .._module.creation_matrices import diag
3132
from .._module.creation_shape import zeros, zeros_like
3233
from .._module.ssc_searching import where
@@ -1631,6 +1632,61 @@ def make_uv(A: ndarray, b: Any, m: int) -> tuple[ndarray, ndarray]:
16311632
return (U, V)
16321633

16331634

1635+
@add_boilerplate("a")
1636+
def inv(a: ndarray) -> ndarray:
1637+
"""
1638+
Compute theinverse of a matrix.
1639+
1640+
Parameters
1641+
----------
1642+
a : (..., M, M) array_like
1643+
Square matrix or batch of square matrices to be inverted.
1644+
1645+
Returns
1646+
-------
1647+
inv_a : ndarray
1648+
Inverse of the input matrix `a`.
1649+
1650+
Raises
1651+
------
1652+
LinAlgError
1653+
If `a` is non-invertible or not square.
1654+
1655+
Notes
1656+
-----
1657+
Only supports inputs of at least two dimensions.
1658+
Batched inversion supported if input has >2 dimensions.
1659+
1660+
See Also
1661+
--------
1662+
numpy.linalg.inv
1663+
1664+
Availability
1665+
-----------
1666+
Single GPU, Single CPU
1667+
"""
1668+
if len(a.shape) < 2:
1669+
raise LinAlgError(
1670+
f"{len(a.shape)}-dimensional array given. "
1671+
"Array must be at least two-dimensional."
1672+
)
1673+
if a.shape[-2] != a.shape[-1]:
1674+
raise LinAlgError(
1675+
f"Array of shape {a.shape} given. "
1676+
"Last 2 dimensions of the array must be square."
1677+
)
1678+
1679+
n = a.shape[-1]
1680+
eye_shape = a.shape[:-2] + (n, n)
1681+
eye_arr = eye(n, dtype=a.dtype)
1682+
1683+
# If batched, we need to broadcast I to batch it as well
1684+
if a.ndim > 2:
1685+
eye_arr = broadcast_to(eye_arr, eye_shape)
1686+
1687+
return _thunk_solve(a, eye_arr)
1688+
1689+
16341690
@add_boilerplate("a")
16351691
def pinv(a: ndarray, rtol: float = 1e-5) -> ndarray:
16361692
"""

tests/integration/test_inv.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# Copyright 2026 NVIDIA Corporation
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
16+
import numpy as np
17+
import pytest
18+
from utils.comparisons import allclose
19+
20+
from cupynumeric.linalg import LinAlgError
21+
22+
import cupynumeric as num
23+
24+
RTOL = {
25+
np.float32: 1e-1,
26+
np.complex64: 1e-1,
27+
np.float64: 1e-5,
28+
np.complex128: 1e-5,
29+
}
30+
31+
ATOL = {
32+
np.float32: 1e-3,
33+
np.complex64: 1e-3,
34+
np.float64: 1e-8,
35+
np.complex128: 1e-8,
36+
}
37+
38+
39+
def create_random_invertible_matrix(size, dtype, batch_shape=None):
40+
"""Constructs a random invertible matrix of shape (size, size)
41+
or (*batch_shape, size, size) if batch_shape is not None."""
42+
shape = (size, size) if batch_shape is None else batch_shape + (size, size)
43+
rand_matrix = np.random.randn(*shape).astype(dtype)
44+
diag_vals = np.sum(np.abs(rand_matrix), axis=-1)
45+
46+
if batch_shape is None:
47+
np.fill_diagonal(rand_matrix, diag_vals)
48+
else:
49+
total_batch_size = np.prod(batch_shape)
50+
matrices = [
51+
create_random_invertible_matrix(size, dtype)
52+
for _ in range(total_batch_size)
53+
]
54+
rand_matrix = np.stack(matrices).reshape(batch_shape + (size, size))
55+
return rand_matrix
56+
57+
58+
@pytest.mark.parametrize("size", (10, 50))
59+
@pytest.mark.parametrize(
60+
"dtype", (np.float32, np.float64, np.complex64, np.complex128)
61+
)
62+
def test_inv_basic(size, dtype):
63+
"""Test basic inv functionality"""
64+
a = create_random_invertible_matrix(size, dtype)
65+
a_inv_num = num.linalg.inv(a)
66+
a_inv_np = np.linalg.inv(a)
67+
68+
assert allclose(
69+
a_inv_num,
70+
a_inv_np,
71+
rtol=RTOL[dtype],
72+
atol=ATOL[dtype],
73+
check_dtype=False,
74+
)
75+
76+
77+
@pytest.mark.parametrize("batch_shape", ((2,), (5, 1), (4, 2)))
78+
@pytest.mark.parametrize("size", (5, 10))
79+
@pytest.mark.parametrize(
80+
"dtype", (np.float32, np.float64, np.complex64, np.complex128)
81+
)
82+
@pytest.mark.xfail(
83+
reason="Solver implementation for matrices on batch size > 2 returns incorrect results"
84+
)
85+
def test_inv_batch(batch_shape, size, dtype):
86+
"""Test inv functionality with batching"""
87+
a = create_random_invertible_matrix(size, dtype, batch_shape)
88+
a_inv_num = num.linalg.inv(a)
89+
a_inv_np = np.linalg.inv(a)
90+
91+
assert allclose(
92+
a_inv_num,
93+
a_inv_np,
94+
rtol=RTOL[dtype],
95+
atol=ATOL[dtype],
96+
check_dtype=False,
97+
)
98+
99+
100+
def test_require_two_dims():
101+
"""Test that inv raises an error when there are less than two dimensions"""
102+
a = num.random.randn(10).astype(np.float64)
103+
msg = "Array must be at least two-dimensional."
104+
with pytest.raises(LinAlgError, match=msg):
105+
num.linalg.inv(a)
106+
107+
108+
@pytest.mark.parametrize("shape", ((5, 10), (1, 5, 10, 20)))
109+
def test_non_square_matrix_error(shape):
110+
"""Test that inv raises an error for non-square matrices"""
111+
a = num.random.randn(*shape).astype(np.float64)
112+
msg = "Last 2 dimensions of the array must be square."
113+
with pytest.raises(LinAlgError, match=msg):
114+
num.linalg.inv(a)
115+
116+
117+
if __name__ == "__main__":
118+
import sys
119+
120+
sys.exit(pytest.main(sys.argv))

0 commit comments

Comments
 (0)