Skip to content

Commit 6a10930

Browse files
committed
Minor cleanup
1 parent 1e8b0b0 commit 6a10930

3 files changed

Lines changed: 27 additions & 15 deletions

File tree

pyrecest/distributions/circle/wrapped_normal_distribution.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,15 +78,21 @@ def pdf(self, xs):
7878

7979
for i in range(n_inputs):
8080
old_result = 0.0
81-
result[i] = exp(x[i] * x[i] * tmp)
81+
xi = x[i]
82+
if hasattr(xi, "item"):
83+
xi = xi.item()
84+
result[i] = exp(xi * xi * tmp)
8285

8386
for k in range(1, max_iterations + 1):
84-
xp = x[i] + 2 * pi * k
85-
xm = x[i] - 2 * pi * k
87+
xp = xi + 2 * pi * k
88+
xm = xi - 2 * pi * k
8689
tp = xp * xp * tmp
8790
tm = xm * xm * tmp
8891
old_result = result[i]
89-
result[i] += (exp(tp) + exp(tm)).squeeze()
92+
addendum = exp(tp) + exp(tm)
93+
if hasattr(addendum, "item"):
94+
addendum = addendum.item()
95+
result[i] += addendum
9096

9197
if result[i] == old_result:
9298
break

pyrecest/distributions/hypersphere_subset/abstract_hypersphere_subset_distribution.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,9 @@ def integrand(*phis):
256256
# Applying the multiplicative factors for each additional dimension
257257
for i in range(2, dim + 1):
258258
result *= sin(phis[i - 1]) ** (i - 1)
259-
return result
259+
if hasattr(result, "item"):
260+
return result.item()
261+
return float(result)
260262

261263
int_result, _ = nquad(integrand, integration_boundaries)
262264

pyrecest/distributions/hypersphere_subset/spherical_harmonics_distribution_complex.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -160,19 +160,23 @@ def from_function_via_integral_sph(fun, degree, transformation="identity"):
160160

161161
coeff_mat = full((degree + 1, 2 * degree + 1), float("NaN"), dtype=complex128)
162162

163+
def _to_scalar(val):
164+
if hasattr(val, "item"):
165+
return val.item()
166+
return float(val)
167+
168+
def _fun_scalar(phi, theta):
169+
return _to_scalar(fun_with_trans(array(phi), array(theta)))
170+
163171
def real_part(phi, theta, n, m):
164-
return real(
165-
fun_with_trans(array(phi), array(theta))
166-
* conj(array(sph_harm_y(n, m, theta, phi)))
167-
* sin(theta)
168-
)
172+
val = _fun_scalar(phi, theta)
173+
val = val * conj(sph_harm_y(n, m, theta, phi)) * sin(theta)
174+
return float(real(val))
169175

170176
def imag_part(phi, theta, n, m):
171-
return imag(
172-
fun_with_trans(array(phi), array(theta))
173-
* conj(array(sph_harm_y(n, m, theta, phi)))
174-
* sin(theta)
175-
)
177+
val = _fun_scalar(phi, theta)
178+
val = val * conj(sph_harm_y(n, m, theta, phi)) * sin(theta)
179+
return float(imag(val))
176180

177181
for n in range(degree + 1): # Use n instead of l to comply with PEP 8
178182
for m in range(-n, n + 1):

0 commit comments

Comments
 (0)