Skip to content

Commit 57558db

Browse files
committed
Based backend on latest geomstats + fixed errors arising
1 parent 616a605 commit 57558db

22 files changed

Lines changed: 391 additions & 182 deletions

pyrecest/_backend/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# _Backend Folder
1+
# Backend Interface
22

33
This folder contains code from the Geomstats project, adjusted for pyRecEst by Florian Pfaff. The original version of Geomstats is authored by Nina Miolane et al., and is a Python package geared towards Riemannian Geometry in Machine Learning.
44

pyrecest/_backend/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def get_backend_name():
4747
"argmin",
4848
"array",
4949
"array_from_sparse",
50+
"asarray",
5051
"as_dtype",
5152
"assignment",
5253
"assignment_by_sum",
@@ -87,6 +88,7 @@ def get_backend_name():
8788
"get_default_cdtype",
8889
"get_slice",
8990
"greater",
91+
"has_autodiff",
9092
"hsplit",
9193
"hstack",
9294
"imag",
@@ -205,6 +207,7 @@ def get_backend_name():
205207
"jacobian_vec",
206208
"jacobian_and_hessian",
207209
"value_and_grad",
210+
"value_and_jacobian",
208211
"value_jacobian_and_hessian",
209212
],
210213
"linalg": [
@@ -218,9 +221,11 @@ def get_backend_name():
218221
"inv",
219222
"is_single_matrix_pd",
220223
"logm",
224+
"matrix_power",
221225
"norm",
222226
"qr",
223227
"quadratic_assignment",
228+
"polar",
224229
"solve",
225230
"solve_sylvester",
226231
"sqrtm",

pyrecest/_backend/_backend_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
pytorch_atol = 1e-6
22
pytorch_rtol = 1e-5
33

4-
np_atol = 1e-12
5-
np_rtol = 1e-6
4+
np_atol = 1e-8
5+
np_rtol = 1e-5
66

77
jax_atol = 1e-6
88
jax_rtol = 1e-5

pyrecest/_backend/_dtype_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ def _wrapped(x, *args, **kwargs):
321321
cmp_dtype = x.dtype
322322
else:
323323
float_name = dtype_as_str(x.dtype)
324-
cmp_dtype = as_dtype(f"complex{int((float_name[-2:]))*2}")
324+
cmp_dtype = as_dtype(f"complex{int((float_name[-2:])) * 2}")
325325

326326
if out.dtype != cmp_dtype:
327327
return cast(out, cmp_dtype)
@@ -383,7 +383,7 @@ def _np_box_unary_scalar(target=None):
383383
def _decorator(func):
384384
@functools.wraps(func)
385385
def _wrapped(x, *args, **kwargs):
386-
if type(x) is float:
386+
if isinstance(x, float):
387387
return func(x, *args, dtype=_config.DEFAULT_DTYPE, **kwargs)
388388

389389
return func(x, *args, **kwargs)
@@ -407,7 +407,7 @@ def _np_box_binary_scalar(target=None):
407407
def _decorator(func):
408408
@functools.wraps(func)
409409
def _wrapped(x1, x2, *args, **kwargs):
410-
if type(x1) is float:
410+
if isinstance(x1, float):
411411
return func(x1, x2, *args, dtype=_config.DEFAULT_DTYPE, **kwargs)
412412

413413
return func(x1, x2, *args, **kwargs)

pyrecest/_backend/_shared_numpy/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040
def angle(z, deg=False):
4141
out = _np.angle(z, deg=deg)
42-
if type(z) is float:
42+
if isinstance(z, float):
4343
return cast(out, get_default_dtype())
4444

4545
return out
@@ -63,7 +63,9 @@ def real(x):
6363

6464
def arange(start_or_stop, /, stop=None, step=1, dtype=None, **kwargs):
6565
if dtype is None and (
66-
type(stop) is float or type(step) is float or type(start_or_stop) is float
66+
isinstance(stop, float)
67+
or isinstance(step, float)
68+
or isinstance(start_or_stop, float)
6769
):
6870
dtype = get_default_dtype()
6971

pyrecest/_backend/_shared_numpy/_common.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,15 @@ def is_array(x):
7070
return type(x) is _np.ndarray
7171

7272

73-
def to_ndarray(x, to_ndim, axis=0):
74-
x = _np.array(x)
75-
if x.ndim == to_ndim - 1:
73+
def to_ndarray(x, to_ndim, axis=0, dtype=None):
74+
x = _np.asarray(x, dtype=dtype)
75+
76+
if x.ndim > to_ndim:
77+
raise ValueError("The ndim cannot be adapted properly.")
78+
79+
while x.ndim < to_ndim:
7680
x = _np.expand_dims(x, axis=axis)
7781

78-
if x.ndim != 0 and x.ndim < to_ndim:
79-
raise ValueError("The ndim was not adapted properly.")
8082
return x
8183

8284

pyrecest/_backend/_shared_numpy/linalg.py

Lines changed: 54 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,18 @@
88
atol = _common.atol
99

1010

11+
def _transpose(array):
12+
axes = list(range(0, array.ndim))
13+
axes[-2], axes[-1] = axes[-1], axes[-2]
14+
return _np.transpose(array, axes=axes)
15+
16+
1117
def _is_symmetric(x, tol=atol):
12-
new_x = _to_ndarray(x, to_ndim=3)
13-
return (_np.abs(new_x - _np.transpose(new_x, axes=(0, 2, 1))) < tol).all()
18+
return (_np.abs(x - _transpose(x)) < tol).all()
1419

1520

1621
def _is_hermitian(x, tol=atol):
17-
new_x = _to_ndarray(x, to_ndim=3)
18-
return (_np.abs(new_x - _np.conj(_np.transpose(new_x, axes=(0, 2, 1)))) < tol).all()
22+
return (_np.abs(x - _np.conj(_transpose(x))) < tol).all()
1923

2024

2125
_diag_vec = _np.vectorize(_np.diag, signature="(n)->(n,n)")
@@ -26,38 +30,30 @@ def _is_hermitian(x, tol=atol):
2630

2731

2832
def logm(x):
29-
ndim = x.ndim
30-
new_x = _to_ndarray(x, to_ndim=3)
31-
32-
if _is_symmetric(new_x) and new_x.dtype not in [_np.complex64, _np.complex128]:
33-
eigvals, eigvecs = _np.linalg.eigh(new_x)
33+
if _is_symmetric(x) and x.dtype not in [_np.complex64, _np.complex128]:
34+
eigvals, eigvecs = _np.linalg.eigh(x)
3435
if (eigvals > 0).all():
3536
eigvals = _np.log(eigvals)
3637
eigvals = _diag_vec(eigvals)
37-
transp_eigvecs = _np.transpose(eigvecs, axes=(0, 2, 1))
38+
transp_eigvecs = _transpose(eigvecs)
3839
result = _np.matmul(eigvecs, eigvals)
3940
result = _np.matmul(result, transp_eigvecs)
4041
else:
41-
result = _logm_vec(new_x)
42+
result = _logm_vec(x)
4243
else:
43-
result = _logm_vec(new_x)
44+
result = _logm_vec(x)
4445

45-
if ndim == 2:
46-
return result[0]
4746
return result
4847

4948

5049
def solve_sylvester(a, b, q, tol=atol):
5150
if a.shape == b.shape:
52-
axes = (0, 2, 1) if a.ndim == 3 else (1, 0)
53-
if _np.all(_np.isclose(a, b)) and _np.all(
54-
_np.abs(a - _np.transpose(a, axes)) < tol
55-
):
51+
if _np.all(_np.isclose(a, b)) and _np.all(_np.abs(a - _transpose(a)) < tol):
5652
eigvals, eigvecs = _np.linalg.eigh(a)
5753
if _np.all(eigvals >= tol):
58-
tilde_q = _np.transpose(eigvecs, axes) @ q @ eigvecs
54+
tilde_q = _transpose(eigvecs) @ q @ eigvecs
5955
tilde_x = tilde_q / (eigvals[..., :, None] + eigvals[..., None, :])
60-
return eigvecs @ tilde_x @ _np.transpose(eigvecs, axes)
56+
return eigvecs @ tilde_x @ _transpose(eigvecs)
6157

6258
return _np.vectorize(
6359
_scipy.linalg.solve_sylvester, signature="(m,m),(n,n),(m,n)->(m,n)"
@@ -102,4 +98,41 @@ def fractional_matrix_power(A, t):
10298
if A.ndim == 2:
10399
return _scipy.linalg.fractional_matrix_power(A, t)
104100

105-
return _np.stack([_scipy.linalg.fractional_matrix_power(A_, t) for A_ in A])
101+
return _np.stack([_scipy.linalg.fractional_matrix_power(A_, t) for A_ in A])
102+
103+
104+
def polar(*args, **kwargs):
105+
"""Polar decomposition of a matrix."""
106+
return _np.vectorize(
107+
_scipy.linalg.polar, signature="(n,n)->(n,n),(n,n)", excluded=["side"]
108+
)(*args, **kwargs)
109+
110+
111+
def solve(a, b):
112+
"""
113+
Solve a linear matrix equation, or system of linear scalar equations.
114+
115+
Computes the "exact" solution, `x`, of the well-determined, i.e., full
116+
rank, linear matrix equation `ax = b`.
117+
118+
Parameters
119+
----------
120+
a : array-like, shape=[..., M, M]
121+
Coefficient matrix.
122+
b : array-like, shape=[..., M]
123+
Ordinate or "dependent variable" values".
124+
125+
Returns
126+
-------
127+
x : array-like, shape=[..., M]
128+
Solution to the system a x = b.
129+
"""
130+
batch_shape = a.shape[:-2]
131+
if batch_shape:
132+
b = _np.expand_dims(b, axis=-1)
133+
134+
res = _np.linalg.solve(a, b)
135+
if batch_shape:
136+
return res[..., 0]
137+
138+
return res

pyrecest/_backend/autograd/__init__.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
any,
1010
argmax,
1111
argmin,
12+
asarray,
1213
broadcast_arrays,
1314
broadcast_to,
1415
clip,
@@ -61,7 +62,6 @@
6162
take,
6263
tile,
6364
transpose,
64-
trapz,
6565
tril,
6666
tril_indices,
6767
triu,
@@ -72,7 +72,12 @@
7272
where,
7373
zeros_like,
7474
)
75-
from autograd.numpy import trapz as trapezoid
75+
76+
try:
77+
from autograd.numpy import trapezoid
78+
except ImportError:
79+
from autograd.numpy import trapz as trapezoid
80+
7681
from autograd.scipy.special import erf, gamma, polygamma # NOQA
7782

7883
from .._shared_numpy import (
@@ -123,9 +128,11 @@
123128
vec_to_diag,
124129
vectorize,
125130
)
126-
from . import autodiff # NOQA
127-
from . import linalg # NOQA
128-
from . import random # NOQA
131+
from . import (
132+
autodiff, # NOQA
133+
linalg, # NOQA
134+
random, # NOQA
135+
)
129136
from ._common import (
130137
_box_binary_scalar,
131138
_box_unary_scalar,
@@ -153,6 +160,16 @@
153160
empty = _dyn_update_dtype(target=_np.empty)
154161

155162

163+
def has_autodiff():
164+
"""If allows for automatic differentiation.
165+
166+
Returns
167+
-------
168+
has_autodiff : bool
169+
"""
170+
return True
171+
172+
156173
def imag(x):
157174
out = _np.imag(x)
158175
if is_array(x):

pyrecest/_backend/autograd/autodiff.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def wrapped_grad_func(i, ans, *args, **kwargs):
115115
)
116116
else:
117117
raise NotImplementedError(
118-
"custom_gradient is not yet implemented " "for more than 3 gradients."
118+
"custom_gradient is not yet implemented for more than 3 gradients."
119119
)
120120

121121
return wrapped_function
@@ -200,7 +200,7 @@ def _value_and_jacobian_op(fun, x):
200200
return ans, _np.reshape(_np.stack(grads), jacobian_shape)
201201

202202

203-
def _value_and_jacobian(fun, point_ndim=1):
203+
def value_and_jacobian(fun, point_ndim=1):
204204
def _value_and_jacobian_vec(x):
205205
if x.ndim == point_ndim:
206206
return _value_and_jacobian_op(fun)(x)
@@ -335,7 +335,7 @@ def jacobian_and_hessian(func, func_out_ndim=None):
335335
Function that returns func's jacobian and
336336
func's hessian values at its inputs args.
337337
"""
338-
return _value_and_jacobian(jacobian_vec(func))
338+
return value_and_jacobian(jacobian_vec(func))
339339

340340

341341
def value_jacobian_and_hessian(func, func_out_ndim=None):
@@ -355,7 +355,7 @@ def value_jacobian_and_hessian(func, func_out_ndim=None):
355355

356356
def _cached_value_and_jacobian(fun, return_cached=False):
357357
def _jac(x):
358-
ans, jac = _value_and_jacobian(fun)(x)
358+
ans, jac = value_and_jacobian(fun)(x)
359359
if not return_cached:
360360
cache.append(ans)
361361
return jac

pyrecest/_backend/autograd/linalg.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,24 @@
1212
eigh,
1313
eigvalsh,
1414
inv,
15+
matrix_power,
1516
matrix_rank,
1617
norm,
17-
solve,
1818
svd,
1919
)
2020
from autograd.scipy.linalg import expm
2121
from scipy.optimize import quadratic_assignment as _quadratic_assignment
2222

23-
from .._shared_numpy.linalg import fractional_matrix_power, is_single_matrix_pd
23+
from .._shared_numpy.linalg import (
24+
fractional_matrix_power,
25+
is_single_matrix_pd,
26+
polar,
27+
qr,
28+
solve,
29+
solve_sylvester,
30+
sqrtm,
31+
)
2432
from .._shared_numpy.linalg import logm as _logm
25-
from .._shared_numpy.linalg import qr, solve_sylvester, sqrtm
2633

2734

2835
def _adjoint(_ans, x, fn):

0 commit comments

Comments
 (0)