@@ -54,40 +54,53 @@ JOBTYPE_TO_CONST = {
5454 ' YtW' : SHARP_YtW
5555}
5656
57-
58- def sht (jobtype , geom_info ginfo , alm_info ainfo , double[::1] input ,
57+ def sht (jobtype , geom_info ginfo , alm_info ainfo , double[:, :, ::1] input ,
5958 int spin = 0 , comm = None , add = False ):
6059 cdef void * comm_ptr
6160 cdef int flags = SHARP_DP | (SHARP_ADD if add else 0 )
62- cdef double * palm
63- cdef double * pmap
6461 cdef int r
6562 cdef sharp_jobtype jobtype_i
66- cdef double [::1 ] output_buf
63+ cdef double [:, :, ::1 ] output_buf
64+ cdef int ntrans = input .shape[0 ] * input .shape[1 ]
65+ cdef int i, j
66+
67+ if spin == 0 and input .shape[1 ] != 1 :
68+ raise ValueError (' For spin == 0, we need input.shape[1] == 1' )
69+ elif spin != 0 and input .shape[1 ] != 2 :
70+ raise ValueError (' For spin != 0, we need input.shape[1] == 2' )
71+
72+
73+ cdef size_t[::1 ] ptrbuf = np.empty(2 * ntrans, dtype = np.uintp)
74+ cdef double ** alm_ptrs = < double ** > & ptrbuf[0 ]
75+ cdef double ** map_ptrs = < double ** > & ptrbuf[ntrans]
6776
6877 try :
6978 jobtype_i = JOBTYPE_TO_CONST[jobtype]
7079 except KeyError :
7180 raise ValueError (' jobtype must be one of: %s ' % ' , ' .join(sorted (JOBTYPE_TO_CONST.keys())))
7281
7382 if jobtype_i == SHARP_Y or jobtype_i == SHARP_WY:
74- output = np.empty(ginfo.local_size(), dtype = np.float64)
83+ output = np.empty(( input .shape[ 0 ], input .shape[ 1 ], ginfo.local_size() ), dtype = np.float64)
7584 output_buf = output
76- pmap = & output_buf[0 ]
77- palm = & input [0 ]
85+ for i in range (input .shape[0 ]):
86+ for j in range (input .shape[1 ]):
87+ alm_ptrs[i * input .shape[1 ] + j] = & input [i, j, 0 ]
88+ map_ptrs[i * input .shape[1 ] + j] = & output_buf[i, j, 0 ]
7889 else :
79- output = np.empty(ainfo.local_size(), dtype = np.float64)
90+ output = np.empty(( input .shape[ 0 ], input .shape[ 1 ], ainfo.local_size() ), dtype = np.float64)
8091 output_buf = output
81- pmap = & input [0 ]
82- palm = & output_buf[0 ]
92+ for i in range (input .shape[0 ]):
93+ for j in range (input .shape[1 ]):
94+ alm_ptrs[i * input .shape[1 ] + j] = & output_buf[i, j, 0 ]
95+ map_ptrs[i * input .shape[1 ] + j] = & input [i, j, 0 ]
8396
8497 if comm is None :
8598 with nogil:
8699 sharp_execute (
87100 jobtype_i,
88101 geom_info = ginfo.ginfo, alm_info = ainfo.ainfo,
89- spin = spin, alm = & palm , map = & pmap ,
90- ntrans = 1 , flags = flags, time = NULL , opcnt = NULL )
102+ spin = spin, alm = alm_ptrs , map = map_ptrs ,
103+ ntrans = ntrans , flags = flags, time = NULL , opcnt = NULL )
91104 else :
92105 from mpi4py import MPI
93106 if not isinstance (comm, MPI.Comm):
@@ -97,8 +110,8 @@ def sht(jobtype, geom_info ginfo, alm_info ainfo, double[::1] input,
97110 r = sharp_execute_mpi_maybe (
98111 comm_ptr, jobtype_i,
99112 geom_info = ginfo.ginfo, alm_info = ainfo.ainfo,
100- spin = spin, alm = & palm , map = & pmap ,
101- ntrans = 1 , flags = flags, time = NULL , opcnt = NULL )
113+ spin = spin, alm = alm_ptrs , map = map_ptrs ,
114+ ntrans = ntrans , flags = flags, time = NULL , opcnt = NULL )
102115 if r == SHARP_ERROR_NO_MPI:
103116 raise Exception (' MPI requested, but not available' )
104117
0 commit comments