Skip to content

Commit bf8a0aa

Browse files
authored
Differentiation test suite and tracer leak fix (#243)
* Add differentiation test suite for all solvers * Fix tracer leak in OnGrid/FourierSeries laplacian_with_pml
1 parent 7e16754 commit bf8a0aa

3 files changed

Lines changed: 351 additions & 10 deletions

File tree

CHANGELOG.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,14 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66

77
## [Unreleased]
88

9+
### Added
10+
11+
- Differentiation test suite covering all solvers with jit+grad, vmap+grad, and finite-difference checks
12+
13+
### Fixed
14+
15+
- Tracer leak in OnGrid/FourierSeries laplacian_with_pml when using helmholtz_solver with checkpoint=False
16+
917
### Changed
1018

1119
- Migrated from Poetry to uv for dependency management and builds
@@ -15,7 +23,6 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1523
## [0.2.1] - 2024-09-17
1624

1725
### Changed
18-
1926
- Upgraded `jaxdf` dependency
2027

2128
## [0.2.0] - 2023-12-18

jwave/acoustics/operators.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,19 @@ def laplacian_with_pml(u: Continuous,
5656
return sum_over_dims(mod_diag_jacobian)
5757

5858

59-
@operator
59+
def ongrid_laplacian_with_pml_init(u: OnGrid, medium: Medium, omega, *args,
60+
**kwargs):
61+
p = {
62+
"gradient": gradient.default_params(u),
63+
"diag_jacobian": diag_jacobian.default_params(u),
64+
}
65+
rho0 = medium.density
66+
if issubclass(type(rho0), Field):
67+
p["gradient_rho0"] = gradient.default_params(rho0)
68+
return p
69+
70+
71+
@operator(init_params=ongrid_laplacian_with_pml_init)
6072
def laplacian_with_pml(u: OnGrid,
6173
medium: Medium,
6274
*,
@@ -77,18 +89,18 @@ def laplacian_with_pml(u: OnGrid,
7789
pml = u.replace_params(pml_grid)
7890

7991
# Making laplacian
80-
grad_u = gradient(u)
92+
grad_u = gradient(u, params=params["gradient"])
8193
mod_grad_u = grad_u * pml
82-
mod_diag_jacobian = diag_jacobian(mod_grad_u) * pml
94+
mod_diag_jacobian = diag_jacobian(
95+
mod_grad_u, params=params["diag_jacobian"]) * pml
8396
nabla_u = sum_over_dims(mod_diag_jacobian)
8497

8598
# Density term
8699
rho0 = medium.density
87100
if not (issubclass(type(rho0), Field)):
88-
# Assume it is a number
89101
rho_u = 0.0
90102
else:
91-
grad_rho0 = gradient(rho0)
103+
grad_rho0 = gradient(rho0, params=params["gradient_rho0"])
92104
rho_u = sum_over_dims(mod_grad_u * grad_rho0) / rho0
93105

94106
# Put everything together
@@ -161,10 +173,14 @@ def laplacian_with_pml(u: FiniteDifferences,
161173

162174
def fourier_laplacian_with_pml_init(u: FourierSeries, medium: Medium, omega,
163175
*args, **kwargs):
164-
return {
176+
p = {
165177
"pml_on_grid": on_grid_pml_init(u, medium, omega),
166178
"fft_u": gradient.default_params(u),
167179
}
180+
rho0 = medium.density
181+
if issubclass(type(rho0), Field):
182+
p["fft_rho0"] = gradient.default_params(rho0)
183+
return p
168184

169185

170186
@operator(init_params=fourier_laplacian_with_pml_init)
@@ -209,9 +225,6 @@ def laplacian_with_pml(u: FourierSeries,
209225
rho0, FourierSeries
210226
), "rho0 must be a FourierSeries or a number when used with FourierSeries fields"
211227

212-
if not ("fft_rho0" in params.keys()):
213-
params["fft_rho0"] = gradient.default_params(rho0)
214-
215228
grad_rho0 = gradient(rho0, stagger=[0.5], params=params["fft_rho0"])
216229
dx = list(map(lambda x: -x / 2, u.domain.dx))
217230
_ru = shift_operator(mod_grad_u * grad_rho0, dx=dx)

0 commit comments

Comments
 (0)