Skip to content

Commit c25249a

Browse files
authored
Add precision arg for NUFFTs (#122)
* Add eps as param * Add tests * Fix api calls
1 parent de42dd3 commit c25249a

4 files changed

Lines changed: 70 additions & 57 deletions

File tree

RMtools_1D/do_RMsynth_1D.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,6 @@ def run_rmsynth(
318318
phiArr_radm2=phiArr_radm2,
319319
weightArr=weightArr,
320320
nBits=nBits,
321-
verbose=verbose,
322321
log=log,
323322
lam0Sq_m2=0 if super_resolution else None,
324323
)

RMtools_3D/do_RMsynth_3D.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,6 @@ def run_rmsynth(
203203
phiArr_radm2=phiArr_radm2,
204204
weightArr=weightArr,
205205
nBits=32,
206-
verbose=verbose,
207206
lam0Sq_m2=0 if super_resolution else None,
208207
)
209208
else:

RMutils/util_RM.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def do_rmsynth_planes(
9494
weightArr=None,
9595
lam0Sq_m2=None,
9696
nBits=32,
97-
verbose=False,
97+
eps=1e-6,
9898
log=print,
9999
):
100100
"""Perform RM-synthesis on Stokes Q and U cubes (1,2 or 3D). This version
@@ -109,15 +109,15 @@ def do_rmsynth_planes(
109109
lambdaSqArr_m2 ... vector of wavelength^2 values (assending freq order)
110110
phiArr_radm2 ... vector of trial Faraday depth values
111111
weightArr ... vector of weights, default [None] is Uniform (all 1s)
112+
lam0Sq_m2 ... force a reference lambda^2 value [None, calculated]
112113
nBits ... precision of data arrays [32]
113-
verbose ... print feedback during calculation [False]
114+
eps ... NUFFT tolerance [1e-6] (see https://finufft.readthedocs.io/en/latest/python.html#module-finufft)
114115
log ... function to be used to output messages [print]
115116
116117
"""
117118

118119
# Default data types
119120
dtFloat = "float" + str(nBits)
120-
dtComplex = "complex" + str(2 * nBits)
121121

122122
# Set the weight array
123123
if weightArr is None:
@@ -195,7 +195,7 @@ def do_rmsynth_planes(
195195
x=a,
196196
c=np.ascontiguousarray(pCube.T),
197197
s=(phiArr_radm2[::-1] * 2).astype(a.dtype),
198-
eps=1e-8,
198+
eps=eps,
199199
)
200200
* KArr[..., None]
201201
).T
@@ -228,6 +228,7 @@ def get_rmsf_planes(
228228
fitRMSF=False,
229229
fitRMSFreal=False,
230230
nBits=32,
231+
eps=1e-6,
231232
verbose=False,
232233
log=print,
233234
):
@@ -250,6 +251,7 @@ def get_rmsf_planes(
250251
fitRMSF ... fit the main lobe of the RMSF with a Gaussian [False]
251252
fitRMSFreal ... fit RMSF.real, rather than abs(RMSF) [False]
252253
nBits ... precision of data arrays [32]
254+
eps ... NUFFT tolerance [1e-6] (see https://finufft.readthedocs.io/en/latest/python.html#module-finufft)
253255
verbose ... print feedback during calculation [False]
254256
log ... function to be used to output messages [print]
255257
@@ -379,7 +381,7 @@ def get_rmsf_planes(
379381
x=a,
380382
c=np.ascontiguousarray(weightCube.T),
381383
s=(phiArr_radm2[::-1] * 2).astype(a.dtype),
382-
eps=1e-8,
384+
eps=eps,
383385
)
384386
* KArr[..., None]
385387
).T

tests/nufft_test.py

Lines changed: 63 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -385,59 +385,72 @@ def make_fake_data() -> FakeData:
385385
def test_rmsynth() -> None:
386386
"""Test the NUFFT RM-synthesis routine agaist DFT."""
387387
fake_data = make_fake_data()
388-
tick = time()
389-
FDFcube, lam0Sq_m2 = do_rmsynth_planes(
390-
dataQ=fake_data.stokes_q,
391-
dataU=fake_data.stokes_u,
392-
lambdaSqArr_m2=fake_data.lsq,
393-
phiArr_radm2=fake_data.phis,
394-
weightArr=fake_data.weights,
395-
lam0Sq_m2=fake_data.lsq_0,
396-
)
397-
tock = time()
398-
logger.info(f"Time taken for NUFFT: {(tock - tick)*1000:0.2f} ms")
399-
400-
tick = time()
401-
FDFcube_old, lam0Sq_m2_old = do_rmsynth_planes_old(
402-
dataQ=fake_data.stokes_q,
403-
dataU=fake_data.stokes_u,
404-
lambdaSqArr_m2=fake_data.lsq,
405-
phiArr_radm2=fake_data.phis,
406-
weightArr=fake_data.weights,
407-
lam0Sq_m2=fake_data.lsq_0,
408-
)
409-
tock = time()
410-
logger.info(f"Time taken for DFT: {(tock - tick)*1000:0.2f} ms")
388+
for eps in [1e-4, 1e-5, 1e-6, 1e-8]:
389+
tick = time()
390+
FDFcube, lam0Sq_m2 = do_rmsynth_planes(
391+
dataQ=fake_data.stokes_q,
392+
dataU=fake_data.stokes_u,
393+
lambdaSqArr_m2=fake_data.lsq,
394+
phiArr_radm2=fake_data.phis,
395+
weightArr=fake_data.weights,
396+
lam0Sq_m2=fake_data.lsq_0,
397+
eps=eps,
398+
)
399+
tock = time()
400+
logger.info(f"Time taken for NUFFT: {(tock - tick)*1000:0.2f} ms")
401+
402+
tick = time()
403+
FDFcube_old, lam0Sq_m2_old = do_rmsynth_planes_old(
404+
dataQ=fake_data.stokes_q,
405+
dataU=fake_data.stokes_u,
406+
lambdaSqArr_m2=fake_data.lsq,
407+
phiArr_radm2=fake_data.phis,
408+
weightArr=fake_data.weights,
409+
lam0Sq_m2=fake_data.lsq_0,
410+
)
411+
tock = time()
412+
logger.info(f"Time taken for DFT: {(tock - tick)*1000:0.2f} ms")
411413

412-
assert np.allclose(FDFcube, FDFcube_old, rtol=1e-6, atol=1e-6)
413-
assert np.allclose(lam0Sq_m2, lam0Sq_m2_old, rtol=1e-6, atol=1e-6)
414+
# fiNUFFT can't go below 1e-8 in precision!
415+
if eps == 1e-8:
416+
assert np.allclose(FDFcube, FDFcube_old, rtol=eps * 10, atol=eps * 10)
417+
assert np.allclose(lam0Sq_m2, lam0Sq_m2_old, rtol=eps * 10, atol=eps * 10)
418+
else:
419+
assert np.allclose(FDFcube, FDFcube_old, rtol=eps, atol=eps)
420+
assert np.allclose(lam0Sq_m2, lam0Sq_m2_old, rtol=eps, atol=eps)
414421

415422

416423
def test_rmsf():
417424
"""Test the NUFFT RMSF routine agaist DFT."""
418425
fake_data = make_fake_data()
419-
tick = time()
420-
RMSFcube, phi2Arr, fwhmRMSFArr, statArr, _ = get_rmsf_planes(
421-
lambdaSqArr_m2=fake_data.lsq,
422-
phiArr_radm2=fake_data.phis,
423-
weightArr=fake_data.weights,
424-
lam0Sq_m2=fake_data.lsq_0,
425-
)
426-
tock = time()
427-
logger.info(f"Time taken for NUFFT: {(tock - tick)*1000:0.2f} ms")
428-
429-
tick = time()
430-
RMSFcube_old, phi2Arr_old, fwhmRMSFArr_old, statArr_old = get_rmsf_planes_old(
431-
lambdaSqArr_m2=fake_data.lsq,
432-
phiArr_radm2=fake_data.phis,
433-
weightArr=fake_data.weights,
434-
lam0Sq_m2=fake_data.lsq_0,
435-
)
436-
tock = time()
437-
logger.info(f"Time taken for DFT: {(tock - tick)*1000:0.2f} ms")
438-
439-
for new, old in zip(
440-
[RMSFcube, phi2Arr, fwhmRMSFArr, statArr],
441-
[RMSFcube_old, phi2Arr_old, fwhmRMSFArr_old, statArr_old],
442-
):
443-
assert np.allclose(new, old, rtol=1e-6, atol=1e-6)
426+
for eps in [1e-4, 1e-5, 1e-6, 1e-8]:
427+
tick = time()
428+
RMSFcube, phi2Arr, fwhmRMSFArr, statArr, _ = get_rmsf_planes(
429+
lambdaSqArr_m2=fake_data.lsq,
430+
phiArr_radm2=fake_data.phis,
431+
weightArr=fake_data.weights,
432+
lam0Sq_m2=fake_data.lsq_0,
433+
eps=eps,
434+
)
435+
tock = time()
436+
logger.info(f"Time taken for NUFFT: {(tock - tick)*1000:0.2f} ms")
437+
438+
tick = time()
439+
RMSFcube_old, phi2Arr_old, fwhmRMSFArr_old, statArr_old = get_rmsf_planes_old(
440+
lambdaSqArr_m2=fake_data.lsq,
441+
phiArr_radm2=fake_data.phis,
442+
weightArr=fake_data.weights,
443+
lam0Sq_m2=fake_data.lsq_0,
444+
)
445+
tock = time()
446+
logger.info(f"Time taken for DFT: {(tock - tick)*1000:0.2f} ms")
447+
448+
# fiNUFFT can't go below 1e-8 in precision!
449+
for new, old in zip(
450+
[RMSFcube, phi2Arr, fwhmRMSFArr, statArr],
451+
[RMSFcube_old, phi2Arr_old, fwhmRMSFArr_old, statArr_old],
452+
):
453+
if eps == 1e-8:
454+
assert np.allclose(new, old, rtol=eps * 10, atol=eps * 10)
455+
else:
456+
assert np.allclose(new, old, rtol=eps, atol=eps)

0 commit comments

Comments
 (0)