@@ -385,59 +385,72 @@ def make_fake_data() -> FakeData:
385385def test_rmsynth () -> None :
386386 """Test the NUFFT RM-synthesis routine agaist DFT."""
387387 fake_data = make_fake_data ()
388- tick = time ()
389- FDFcube , lam0Sq_m2 = do_rmsynth_planes (
390- dataQ = fake_data .stokes_q ,
391- dataU = fake_data .stokes_u ,
392- lambdaSqArr_m2 = fake_data .lsq ,
393- phiArr_radm2 = fake_data .phis ,
394- weightArr = fake_data .weights ,
395- lam0Sq_m2 = fake_data .lsq_0 ,
396- )
397- tock = time ()
398- logger .info (f"Time taken for NUFFT: { (tock - tick )* 1000 :0.2f} ms" )
399-
400- tick = time ()
401- FDFcube_old , lam0Sq_m2_old = do_rmsynth_planes_old (
402- dataQ = fake_data .stokes_q ,
403- dataU = fake_data .stokes_u ,
404- lambdaSqArr_m2 = fake_data .lsq ,
405- phiArr_radm2 = fake_data .phis ,
406- weightArr = fake_data .weights ,
407- lam0Sq_m2 = fake_data .lsq_0 ,
408- )
409- tock = time ()
410- logger .info (f"Time taken for DFT: { (tock - tick )* 1000 :0.2f} ms" )
388+ for eps in [1e-4 , 1e-5 , 1e-6 , 1e-8 ]:
389+ tick = time ()
390+ FDFcube , lam0Sq_m2 = do_rmsynth_planes (
391+ dataQ = fake_data .stokes_q ,
392+ dataU = fake_data .stokes_u ,
393+ lambdaSqArr_m2 = fake_data .lsq ,
394+ phiArr_radm2 = fake_data .phis ,
395+ weightArr = fake_data .weights ,
396+ lam0Sq_m2 = fake_data .lsq_0 ,
397+ eps = eps ,
398+ )
399+ tock = time ()
400+ logger .info (f"Time taken for NUFFT: { (tock - tick )* 1000 :0.2f} ms" )
401+
402+ tick = time ()
403+ FDFcube_old , lam0Sq_m2_old = do_rmsynth_planes_old (
404+ dataQ = fake_data .stokes_q ,
405+ dataU = fake_data .stokes_u ,
406+ lambdaSqArr_m2 = fake_data .lsq ,
407+ phiArr_radm2 = fake_data .phis ,
408+ weightArr = fake_data .weights ,
409+ lam0Sq_m2 = fake_data .lsq_0 ,
410+ )
411+ tock = time ()
412+ logger .info (f"Time taken for DFT: { (tock - tick )* 1000 :0.2f} ms" )
411413
412- assert np .allclose (FDFcube , FDFcube_old , rtol = 1e-6 , atol = 1e-6 )
413- assert np .allclose (lam0Sq_m2 , lam0Sq_m2_old , rtol = 1e-6 , atol = 1e-6 )
414+ # fiNUFFT can't go below 1e-8 in precision!
415+ if eps == 1e-8 :
416+ assert np .allclose (FDFcube , FDFcube_old , rtol = eps * 10 , atol = eps * 10 )
417+ assert np .allclose (lam0Sq_m2 , lam0Sq_m2_old , rtol = eps * 10 , atol = eps * 10 )
418+ else :
419+ assert np .allclose (FDFcube , FDFcube_old , rtol = eps , atol = eps )
420+ assert np .allclose (lam0Sq_m2 , lam0Sq_m2_old , rtol = eps , atol = eps )
414421
415422
416423def test_rmsf ():
417424 """Test the NUFFT RMSF routine agaist DFT."""
418425 fake_data = make_fake_data ()
419- tick = time ()
420- RMSFcube , phi2Arr , fwhmRMSFArr , statArr , _ = get_rmsf_planes (
421- lambdaSqArr_m2 = fake_data .lsq ,
422- phiArr_radm2 = fake_data .phis ,
423- weightArr = fake_data .weights ,
424- lam0Sq_m2 = fake_data .lsq_0 ,
425- )
426- tock = time ()
427- logger .info (f"Time taken for NUFFT: { (tock - tick )* 1000 :0.2f} ms" )
428-
429- tick = time ()
430- RMSFcube_old , phi2Arr_old , fwhmRMSFArr_old , statArr_old = get_rmsf_planes_old (
431- lambdaSqArr_m2 = fake_data .lsq ,
432- phiArr_radm2 = fake_data .phis ,
433- weightArr = fake_data .weights ,
434- lam0Sq_m2 = fake_data .lsq_0 ,
435- )
436- tock = time ()
437- logger .info (f"Time taken for DFT: { (tock - tick )* 1000 :0.2f} ms" )
438-
439- for new , old in zip (
440- [RMSFcube , phi2Arr , fwhmRMSFArr , statArr ],
441- [RMSFcube_old , phi2Arr_old , fwhmRMSFArr_old , statArr_old ],
442- ):
443- assert np .allclose (new , old , rtol = 1e-6 , atol = 1e-6 )
426+ for eps in [1e-4 , 1e-5 , 1e-6 , 1e-8 ]:
427+ tick = time ()
428+ RMSFcube , phi2Arr , fwhmRMSFArr , statArr , _ = get_rmsf_planes (
429+ lambdaSqArr_m2 = fake_data .lsq ,
430+ phiArr_radm2 = fake_data .phis ,
431+ weightArr = fake_data .weights ,
432+ lam0Sq_m2 = fake_data .lsq_0 ,
433+ eps = eps ,
434+ )
435+ tock = time ()
436+ logger .info (f"Time taken for NUFFT: { (tock - tick )* 1000 :0.2f} ms" )
437+
438+ tick = time ()
439+ RMSFcube_old , phi2Arr_old , fwhmRMSFArr_old , statArr_old = get_rmsf_planes_old (
440+ lambdaSqArr_m2 = fake_data .lsq ,
441+ phiArr_radm2 = fake_data .phis ,
442+ weightArr = fake_data .weights ,
443+ lam0Sq_m2 = fake_data .lsq_0 ,
444+ )
445+ tock = time ()
446+ logger .info (f"Time taken for DFT: { (tock - tick )* 1000 :0.2f} ms" )
447+
448+ # fiNUFFT can't go below 1e-8 in precision!
449+ for new , old in zip (
450+ [RMSFcube , phi2Arr , fwhmRMSFArr , statArr ],
451+ [RMSFcube_old , phi2Arr_old , fwhmRMSFArr_old , statArr_old ],
452+ ):
453+ if eps == 1e-8 :
454+ assert np .allclose (new , old , rtol = eps * 10 , atol = eps * 10 )
455+ else :
456+ assert np .allclose (new , old , rtol = eps , atol = eps )
0 commit comments