Skip to content

Commit 1688896

Browse files
author
Dag Sverre Seljebotn
committed
Support for multiple transforms and spin-transforms
1 parent 48e2131 commit 1688896

2 files changed

Lines changed: 31 additions & 17 deletions

File tree

python/libsharp/libsharp.pyx

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -54,40 +54,53 @@ JOBTYPE_TO_CONST = {
5454
'YtW': SHARP_YtW
5555
}
5656

57-
58-
def sht(jobtype, geom_info ginfo, alm_info ainfo, double[::1] input,
57+
def sht(jobtype, geom_info ginfo, alm_info ainfo, double[:, :, ::1] input,
5958
int spin=0, comm=None, add=False):
6059
cdef void *comm_ptr
6160
cdef int flags = SHARP_DP | (SHARP_ADD if add else 0)
62-
cdef double *palm
63-
cdef double *pmap
6461
cdef int r
6562
cdef sharp_jobtype jobtype_i
66-
cdef double[::1] output_buf
63+
cdef double[:, :, ::1] output_buf
64+
cdef int ntrans = input.shape[0] * input.shape[1]
65+
cdef int i, j
66+
67+
if spin == 0 and input.shape[1] != 1:
68+
raise ValueError('For spin == 0, we need input.shape[1] == 1')
69+
elif spin != 0 and input.shape[1] != 2:
70+
raise ValueError('For spin != 0, we need input.shape[1] == 2')
71+
72+
73+
cdef size_t[::1] ptrbuf = np.empty(2 * ntrans, dtype=np.uintp)
74+
cdef double **alm_ptrs = <double**>&ptrbuf[0]
75+
cdef double **map_ptrs = <double**>&ptrbuf[ntrans]
6776

6877
try:
6978
jobtype_i = JOBTYPE_TO_CONST[jobtype]
7079
except KeyError:
7180
raise ValueError('jobtype must be one of: %s' % ', '.join(sorted(JOBTYPE_TO_CONST.keys())))
7281

7382
if jobtype_i == SHARP_Y or jobtype_i == SHARP_WY:
74-
output = np.empty(ginfo.local_size(), dtype=np.float64)
83+
output = np.empty((input.shape[0], input.shape[1], ginfo.local_size()), dtype=np.float64)
7584
output_buf = output
76-
pmap = &output_buf[0]
77-
palm = &input[0]
85+
for i in range(input.shape[0]):
86+
for j in range(input.shape[1]):
87+
alm_ptrs[i * input.shape[1] + j] = &input[i, j, 0]
88+
map_ptrs[i * input.shape[1] + j] = &output_buf[i, j, 0]
7889
else:
79-
output = np.empty(ainfo.local_size(), dtype=np.float64)
90+
output = np.empty((input.shape[0], input.shape[1], ainfo.local_size()), dtype=np.float64)
8091
output_buf = output
81-
pmap = &input[0]
82-
palm = &output_buf[0]
92+
for i in range(input.shape[0]):
93+
for j in range(input.shape[1]):
94+
alm_ptrs[i * input.shape[1] + j] = &output_buf[i, j, 0]
95+
map_ptrs[i * input.shape[1] + j] = &input[i, j, 0]
8396

8497
if comm is None:
8598
with nogil:
8699
sharp_execute (
87100
jobtype_i,
88101
geom_info=ginfo.ginfo, alm_info=ainfo.ainfo,
89-
spin=spin, alm=&palm, map=&pmap,
90-
ntrans=1, flags=flags, time=NULL, opcnt=NULL)
102+
spin=spin, alm=alm_ptrs, map=map_ptrs,
103+
ntrans=ntrans, flags=flags, time=NULL, opcnt=NULL)
91104
else:
92105
from mpi4py import MPI
93106
if not isinstance(comm, MPI.Comm):
@@ -97,8 +110,8 @@ def sht(jobtype, geom_info ginfo, alm_info ainfo, double[::1] input,
97110
r = sharp_execute_mpi_maybe (
98111
comm_ptr, jobtype_i,
99112
geom_info=ginfo.ginfo, alm_info=ainfo.ainfo,
100-
spin=spin, alm=&palm, map=&pmap,
101-
ntrans=1, flags=flags, time=NULL, opcnt=NULL)
113+
spin=spin, alm=alm_ptrs, map=map_ptrs,
114+
ntrans=ntrans, flags=flags, time=NULL, opcnt=NULL)
102115
if r == SHARP_ERROR_NO_MPI:
103116
raise Exception('MPI requested, but not available')
104117

python/libsharp/tests/test_sht.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@ def test_basic():
2525
alm[0] = 1
2626

2727

28-
29-
map = libsharp.synthesis(grid, order, alm, comm=MPI.COMM_WORLD)
28+
map = libsharp.synthesis(grid, order, np.repeat(alm[None, None, :], 3, 0), comm=MPI.COMM_WORLD)
29+
assert np.all(map[2, :] == map[1, :]) and np.all(map[1, :] == map[0, :])
30+
map = map[0, 0, :]
3031
if rank == 0:
3132
healpy.mollzoom(map)
3233
from matplotlib.pyplot import show

0 commit comments

Comments
 (0)