Skip to content

Commit 04f3df2

Browse files
Address code review: remove test coupling, fix log(0) warning, clean up mask evaluation
Agent-Logs-Url: https://github.com/FlorianPfaff/PyRecEst/sessions/3f0bd299-8203-4bff-91e8-b8e62be085a8 Co-authored-by: FlorianPfaff <6773539+FlorianPfaff@users.noreply.github.com>
1 parent 30426d2 commit 04f3df2

2 files changed

Lines changed: 7 additions & 25 deletions

File tree

pyrecest/distributions/hypersphere_subset/complex_watson_distribution.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,10 @@ def log_norm(D, kappa):
7272

7373
# Asymptotic formula for high kappa
7474
# log C ~ log(2) + D*log(pi) + (1-D)*log(kappa) + kappa
75+
# log_c_high is evaluated for all kappa before masking; clip to avoid log(0) warning
7576
log_c_high = (
76-
math.log(2) + D * math.log(math.pi) + (1 - D) * np.log(kappa + 1e-300) + kappa
77+
math.log(2) + D * math.log(math.pi)
78+
+ (1 - D) * np.log(np.maximum(kappa, 1e-300)) + kappa
7779
)
7880

7981
# Intermediate formula (Mardia1999 Eq. 3):
@@ -322,13 +324,14 @@ def _sample_diagonal_complex_bingham_magnitudes(Lambda, D):
322324
Lambda_pos = Lambda[: D - 1] # first D-1 (positive) eigenvalues
323325

324326
# Precompute for the truncated exponential inverse CDF
325-
temp1 = np.where(Lambda_pos >= 0.03, -1.0 / np.where(Lambda_pos >= 0.03, Lambda_pos, 1.0), 0.0)
326-
temp2 = np.where(Lambda_pos >= 0.03, 1.0 - np.exp(-Lambda_pos), 0.0)
327+
large = Lambda_pos >= 0.03
328+
safe_lambda = np.where(large, Lambda_pos, 1.0)
329+
temp1 = np.where(large, -1.0 / safe_lambda, 0.0)
330+
temp2 = np.where(large, 1.0 - np.exp(-Lambda_pos), 0.0)
327331

328332
s = np.zeros(D)
329333
while True:
330334
U = np.random.rand(D - 1)
331-
large = Lambda_pos >= 0.03
332335
if np.any(large):
333336
s[: D - 1][large] = temp1[large] * np.log(1.0 - U[large] * temp2[large])
334337
if np.any(~large):

pyrecest/tests/distributions/test_complex_watson_distribution.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,6 @@
66
from pyrecest.distributions.hypersphere_subset.complex_watson_distribution import (
77
ComplexWatsonDistribution,
88
)
9-
from pyrecest.distributions.hypersphere_subset.bayesian_complex_watson_mixture_model import (
10-
_simplex_integral,
11-
)
129

1310

1411
def _random_unit_vector(D, rng=None):
@@ -149,24 +146,6 @@ def test_fit_recovers_parameters(self):
149146
self.assertAlmostEqual(dist_hat.kappa, kappa, delta=2.0)
150147

151148

152-
class TestSimplexIntegral(unittest.TestCase):
153-
def test_D1(self):
154-
self.assertAlmostEqual(_simplex_integral(np.array([3.0])), np.exp(3.0))
155-
156-
def test_D2_known(self):
157-
# int_0^1 exp(a*t + b*(1-t)) dt = (exp(a) - exp(b)) / (a - b)
158-
a, b = 2.0, 1.0
159-
expected = (np.exp(a) - np.exp(b)) / (a - b)
160-
result = _simplex_integral(np.array([a, b]))
161-
self.assertAlmostEqual(result, expected, places=8)
162-
163-
def test_D3_nonnegative(self):
164-
result = _simplex_integral(np.array([2.0, 1.0, 0.0]))
165-
self.assertGreater(result, 0.0)
166-
# Known value from direct integration: exp(2)/2 - exp(1) + 1/2
167-
expected = np.exp(2) / 2 - np.exp(1) + 0.5
168-
self.assertAlmostEqual(result, expected, places=5)
169-
170149

171150
if __name__ == "__main__":
172151
unittest.main()

0 commit comments

Comments
 (0)