Skip to content

Commit dd20732

Browse files
committed
Using backend more for spherical harmonics
1 parent c527e6d commit dd20732

1 file changed

Lines changed: 19 additions & 15 deletions

File tree

pyrecest/distributions/hypersphere_subset/spherical_harmonics_distribution_complex.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import warnings
22

3-
import numpy as np
43
import scipy
54

65
# pylint: disable=redefined-builtin,no-name-in-module,no-member
@@ -17,6 +16,7 @@
1716
full,
1817
imag,
1918
isnan,
19+
zeros_like,
2020
linalg,
2121
pi,
2222
real,
@@ -25,6 +25,10 @@
2525
sin,
2626
sqrt,
2727
zeros,
28+
cos,
29+
sin,
30+
meshgrid,
31+
deg2rad,
2832
)
2933

3034
# pylint: disable=E0611
@@ -235,12 +239,12 @@ def _get_dh_grid_cartesian(degree):
235239
)
236240
grid = dummy.expand(grid="DH", extend=False)
237241
lats, lons = grid.lats(), grid.lons()
238-
lon_mesh, lat_mesh = np.meshgrid(lons, lats)
239-
theta = np.radians(90.0 - lat_mesh) # colatitude in radians
240-
phi = np.radians(lon_mesh) # azimuth in radians
241-
x_c = np.sin(theta) * np.cos(phi)
242-
y_c = np.sin(theta) * np.sin(phi)
243-
z_c = np.cos(theta)
242+
lon_mesh, lat_mesh = meshgrid(lons, lats)
243+
theta = deg2rad(90.0 - lat_mesh) # colatitude in radians
244+
phi = deg2rad(lon_mesh) # azimuth in radians
245+
x_c = sin(theta) * cos(phi)
246+
y_c = sin(theta) * sin(phi)
247+
z_c = cos(theta)
244248
return x_c.ravel(), y_c.ravel(), z_c.ravel(), theta.shape
245249

246250
def _eval_on_grid(self, target_degree=None):
@@ -310,11 +314,11 @@ def from_distribution_numerical_fast(dist, degree, transformation="identity"):
310314
x_c, y_c, z_c, grid_shape = (
311315
SphericalHarmonicsDistributionComplex._get_dh_grid_cartesian(degree)
312316
)
313-
xs = np.column_stack([x_c, y_c, z_c])
314-
fvals = np.asarray(dist.pdf(xs), dtype=float).reshape(grid_shape)
317+
xs = column_stack([x_c, y_c, z_c])
318+
fvals = array(dist.pdf(xs), dtype=float).reshape(grid_shape)
315319

316320
if transformation == "sqrt":
317-
fvals = np.sqrt(np.maximum(fvals, 0.0))
321+
fvals = sqrt(max(fvals, 0.0))
318322

319323
return SphericalHarmonicsDistributionComplex._fit_from_grid(
320324
fvals, degree, transformation
@@ -336,10 +340,10 @@ def convolve(self, other):
336340

337341
if self.transformation == "identity" and other.transformation == "identity":
338342
# Direct frequency-domain formula: h_{l,m} = sqrt(4π/(2l+1)) * f_{l,m} * g_{l,0}
339-
h_lm = np.zeros_like(self.coeff_mat)
343+
h_lm = zeros_like(self.coeff_mat)
340344
for l in range(degree + 1):
341345
factor = (
342-
np.sqrt(4.0 * np.pi / (2 * l + 1))
346+
sqrt(4.0 * pi / (2 * l + 1))
343347
* other.coeff_mat[l, l]
344348
)
345349
for m in range(-l, l + 1):
@@ -376,16 +380,16 @@ def _grid_to_coeff(grid_vals):
376380
q_lm = _grid_to_coeff(q_grid)
377381

378382
# Convolution formula on the identity coefficients
379-
r_lm = np.zeros_like(p_lm)
383+
r_lm = zeros_like(p_lm)
380384
for l in range(degree + 1):
381-
factor = np.sqrt(4.0 * np.pi / (2 * l + 1)) * q_lm[l, l]
385+
factor = sqrt(4.0 * pi / (2 * l + 1)) * q_lm[l, l]
382386
for m in range(-l, l + 1):
383387
r_lm[l, l + m] = factor * p_lm[l, l + m]
384388

385389
# Evaluate r on the standard DH grid, take sqrt, refit
386390
r_shd_id = SphericalHarmonicsDistributionComplex(r_lm, "identity")
387391
r_grid = r_shd_id._eval_on_grid()
388-
sqrt_r_grid = np.sqrt(np.maximum(r_grid, 0.0))
392+
sqrt_r_grid = sqrt(max(r_grid, 0.0))
389393

390394
return SphericalHarmonicsDistributionComplex._fit_from_grid(
391395
sqrt_r_grid, degree, "sqrt"

0 commit comments

Comments
 (0)