Skip to content

Commit a5f140d

Browse files
Address code review: use .copy() for sort output, use ones() in tests
Agent-Logs-Url: https://github.com/FlorianPfaff/PyRecEst/sessions/31c921c4-c991-44f5-89c7-ee929e5d4d07 Co-authored-by: FlorianPfaff <6773539+FlorianPfaff@users.noreply.github.com>
1 parent 5d985ef commit a5f140d

2 files changed

Lines changed: 5 additions & 4 deletions

File tree

pyrecest/distributions/hypersphere_subset/complex_bingham_distribution.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ def _perturb_eigenvalues(eigenvalues):
257257
258258
Mirrors MATLAB's makeSureEigenvaluesAreNotTooClose.
259259
"""
260-
lam = sort(eigenvalues)[::-1]
260+
lam = sort(eigenvalues)[::-1].copy()
261261
diffs = diff(lam) # non-positive for sorted-descending
262262
diffs = minimum(diffs, -0.01) # enforce gap >= 0.01
263263
lam[1:] = lam[0] + cumsum(diffs)
@@ -314,7 +314,7 @@ def grad_log_c(lam):
314314
log_c0 = ComplexBinghamDistribution.log_norm(B_diag)
315315
grad = empty(d)
316316
for i in range(d):
317-
lam_p = array(lam)
317+
lam_p = lam.copy()
318318
lam_p[i] += eps
319319
log_cp = ComplexBinghamDistribution.log_norm(
320320
diag(array(lam_p, dtype=complex128))

pyrecest/tests/distributions/test_complex_bingham_distribution.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
log,
1313
mean,
1414
pi,
15+
ones,
1516
random,
1617
real,
1718
sort,
@@ -115,7 +116,7 @@ def test_sample_unit_norm(self):
115116
random.seed(42)
116117
S = self.cB2.sample(100)
117118
norms = linalg.norm(S, axis=0)
118-
npt.assert_allclose(norms, [1.0] * 100, atol=1e-12)
119+
npt.assert_allclose(norms, ones(100), atol=1e-12)
119120

120121
@unittest.skipIf(
121122
pyrecest.backend.__backend_name__ == "jax", # pylint: disable=no-member
@@ -126,7 +127,7 @@ def test_sample_3d_unit_norm(self):
126127
random.seed(7)
127128
S = self.cB3.sample(50)
128129
norms = linalg.norm(S, axis=0)
129-
npt.assert_allclose(norms, [1.0] * 50, atol=1e-12)
130+
npt.assert_allclose(norms, ones(50), atol=1e-12)
130131

131132
def test_log_norm_2d_analytic(self):
132133
a = 3.0

0 commit comments

Comments
 (0)