Skip to content

Commit 7ca197b

Browse files
authored
Merge pull request #262 from FlorianPfaff/feature/hypertoroidal_grid
Added HypertoroidalGridDistribution
2 parents 54dbb46 + 53fb3fd commit 7ca197b

8 files changed

Lines changed: 368 additions & 6 deletions

File tree

pyrecest/_backend/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ def get_backend_name():
234234
"sqrtm",
235235
"svd",
236236
"matrix_rank",
237+
"block_diag", # For PyRecEst
237238
],
238239
"random": [
239240
"choice",

pyrecest/_backend/jax/linalg.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
expm,
1919
sqrtm,
2020
polar,
21+
block_diag, # For PyRecEst
2122
)
2223

2324
unsupported_functions = [

pyrecest/_backend/numpy/linalg.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
norm,
1515
svd,
1616
)
17-
from scipy.linalg import expm
17+
from scipy.linalg import (
18+
expm,
19+
block_diag, # For PyRecEst
20+
)
1821

1922
from .._shared_numpy.linalg import (
2023
fractional_matrix_power,

pyrecest/_backend/pytorch/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,6 @@
6363
isinf,
6464
deg2rad,
6565
argsort,
66-
max,
67-
min,
6866
roll,
6967
dstack,
7068
vmap,
@@ -357,11 +355,12 @@ def shape(val):
357355
return val.shape
358356

359357

360-
def amax(a, axis=None):
358+
def max(a, axis=None):
361359
if axis is None:
362360
return _torch.max(array(a))
363361
return _torch.max(array(a), dim=axis).values
364362

363+
amax=max
365364

366365
def maximum(a, b):
367366
return _torch.max(array(a), array(b))
@@ -811,10 +810,11 @@ def sort(a, axis=-1):
811810
return sorted_a
812811

813812

814-
def amin(a, axis=-1):
813+
def min(a, axis=-1):
815814
(values, _) = _torch.min(a, dim=axis)
816815
return values
817816

817+
amin = min
818818

819819
def take(a, indices, axis=0):
820820
if not _torch.is_tensor(indices):

pyrecest/_backend/pytorch/linalg.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
solve,
1616
)
1717
from torch.linalg import matrix_exp as expm
18+
from torch import block_diag # For PyRecEst
1819

1920
from .._backend_config import np_atol as atol
2021
from ..numpy import linalg as _gsnplinalg

pyrecest/distributions/abstract_grid_distribution.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
# pylint: disable=redefined-builtin,no-name-in-module,no-member
88
from pyrecest.backend import abs, any, mean
9+
from math import prod
910

1011
from .abstract_distribution_type import AbstractDistributionType
1112

@@ -24,7 +25,10 @@ def __init__(
2425
assert (
2526
not grid_type == "custom" or grid is not None
2627
) # if grid_type is custom, grid needs to be given
27-
assert grid is None or grid.shape == () or grid.shape[0] == grid_values.shape[0]
28+
assert (
29+
# Use builtin prod because .shape is a tuple of ints
30+
grid is None or grid.shape == () or grid.shape[0] == prod(grid_values.shape)
31+
)
2832
assert (
2933
grid is None or grid.shape == () or grid.ndim == 1 or grid.shape[1] == dim
3034
)

0 commit comments

Comments
 (0)