Skip to content

Commit d1918f4

Browse files
authored
Upgrade jaxdf to 0.3.0, fix FD accuracy bug (#224) (#246)
* Upgrade jaxdf to 0.3.0, add regression test for FD accuracy bug (#224) * Move imports to top of test file * Fix JAX array in Domain static metadata for CBS normalization * Relax BLI sensor tolerance for float32 FFT platform differences
1 parent bf8a0aa commit d1918f4

6 files changed

Lines changed: 65 additions & 33 deletions

File tree

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,19 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1313
### Fixed
1414

1515
- Tracer leak in OnGrid/FourierSeries laplacian_with_pml when using helmholtz_solver with checkpoint=False
16+
- FiniteDifferences with non-default accuracy no longer causes pytree mismatch in time-domain simulation (#224)
1617

1718
### Changed
1819

1920
- Migrated from Poetry to uv for dependency management and builds
2021
- Minimum Python version bumped to 3.11
2122
- Upgraded plumkdocs to >=1.0.0 and mkdocstrings to >=1.0.0
23+
- Upgraded jaxdf dependency to >=0.3.0
2224

2325
## [0.2.1] - 2024-09-17
2426

2527
### Changed
28+
2629
- Upgraded `jaxdf` dependency
2730

2831
## [0.2.0] - 2023-12-18

jwave/acoustics/time_harmonic.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing import Union
1818

1919
import jax
20+
import numpy as np
2021
from jax import numpy as jnp
2122
from jax.lax import while_loop
2223
from jax.scipy.sparse.linalg import bicgstab, gmres
@@ -164,12 +165,12 @@ def _cbs_norm_units(medium, omega, k0, src):
164165
# Store conversion variables
165166
domain = medium.domain
166167
_conversion = {
167-
"dx": jnp.mean(jnp.asarray(domain.dx)),
168+
"dx": float(np.mean(domain.dx)),
168169
"omega": omega,
169170
}
170171

171172
# Set discretization to 1
172-
dx = tuple(map(lambda x: x / _conversion["dx"], domain.dx))
173+
dx = tuple(float(x / _conversion["dx"]) for x in domain.dx)
173174
domain = Domain(domain.N, dx)
174175

175176
# set omega to 1
@@ -197,7 +198,7 @@ def _cbs_norm_units(medium, omega, k0, src):
197198

198199
def _cbs_unnorm_units(field, conversion):
199200
domain = field.domain
200-
dx = tuple(map(lambda x: x * conversion["dx"], domain.dx))
201+
dx = tuple(float(x * conversion["dx"]) for x in domain.dx)
201202
domain = Domain(domain.N, dx)
202203

203204
return FourierSeries(field.params, domain)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ license = "LGPL-3.0-only"
1313
keywords = ["jax", "acoustics", "simulation", "ultrasound", "differentiable-programming"]
1414
requires-python = ">=3.11"
1515
dependencies = [
16-
"jaxdf>=0.2.8",
16+
"jaxdf>=0.3.0",
1717
"matplotlib>=3.0.0",
1818
]
1919
classifiers = [

tests/acoustics/test_simulate_wave_propagation.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33

44
from jax import numpy as jnp
55

6-
from jwave.acoustics import simulate_wave_propagation
7-
from jwave.geometry import Domain, FourierSeries, Medium, TimeAxis
6+
from jwave import FiniteDifferences
7+
from jwave.acoustics import TimeWavePropagationSettings, simulate_wave_propagation
8+
from jwave.geometry import Domain, FourierSeries, Medium, TimeAxis, circ_mask
89
from jwave.logger import logger, set_logging_level
910

1011

@@ -40,5 +41,23 @@ def test_correct_call():
4041
assert "Starting simulation using FourierSeries code" in log_contents
4142

4243

44+
def test_fd_nondefault_accuracy():
45+
"""Regression test for jwave#224: FD fields with accuracy != 8
46+
must not cause pytree mismatch in lax.scan."""
47+
domain = Domain((64, 64), (1e-3, 1e-3))
48+
p0_arr = 5.0 * circ_mask(domain.N, 3, (32, 32))
49+
p0 = FiniteDifferences(
50+
jnp.expand_dims(p0_arr, -1), domain, accuracy=4)
51+
sound_speed = FiniteDifferences(
52+
jnp.expand_dims(jnp.ones(domain.N) * 1500.0, -1), domain, accuracy=4)
53+
medium = Medium(domain, sound_speed=sound_speed, pml_size=0)
54+
time_axis = TimeAxis.from_medium(medium, cfl=0.1)
55+
time_axis.t_end = 2e-6
56+
settings = TimeWavePropagationSettings(smooth_initial=False)
57+
58+
p = simulate_wave_propagation(medium, time_axis, p0=p0, settings=settings)
59+
assert p is not None
60+
61+
4362
if __name__ == "__main__":
4463
test_correct_call()

tests/test_off_grid_sensors.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,11 @@ def test_sensor(nx, ny, nz):
6767
s3d = BLISensors((x + 0.25, y + 0.3, z + 0.1), (nx, ny, nz))
6868
domain3d = Domain((nx, ny, nz), (1, 1, 1))
6969
# Check ones in ones out.
70+
# rtol=1e-4: BLI uses float32 3D FFT interpolation, which has limited
71+
# precision that varies across platforms (different BLAS/FFT backends).
7072
p3d = FourierSeries(np.ones((nx, ny, nz)), domain3d)
7173
y = s3d(p3d, None, None)
72-
assert (np.all(np.isclose(y, 1)))
74+
assert (np.all(np.isclose(y, 1, rtol=1e-4)))
7375

7476
# Check zeros in zeros out
7577
p3d = FourierSeries(np.zeros((nx, ny, nz)), domain3d)

uv.lock

Lines changed: 33 additions & 26 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)