Skip to content

Commit 2dd13e8

Browse files
committed
Refactor kWaveGrid initialization in at_circular_piston_3D example
- Updated the grid creation process to utilize the new `from_domain` method for better clarity and efficiency. - Adjusted sensor mask and data reshaping to align with the new grid structure. - Ensured consistent usage of grid properties throughout the example for improved readability and maintainability.
1 parent d882889 commit 2dd13e8

2 files changed

Lines changed: 16 additions & 22 deletions

File tree

examples/at_circular_piston_3D/at_circular_piston_3D.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -48,20 +48,9 @@
4848
# GRID
4949
# --------------------
5050

51-
# calculate the grid spacing based on the PPW and F0
52-
dx: float = c0 / (ppw * source_f0) # [m]
53-
54-
# compute the size of the grid
55-
# is round_even needed?
56-
Nx: int = round_even(axial_size / dx)
57-
Ny: int = round_even(lateral_size / dx)
58-
Nz: int = Ny
59-
60-
grid_size_points = Vector([Nx, Ny, Nz])
61-
grid_spacing_meters = Vector([dx, dx, dx])
62-
63-
# create the k-space grid
64-
kgrid = kWaveGrid(grid_size_points, grid_spacing_meters)
51+
kgrid = kWaveGrid.from_domain(
52+
dimensions=np.array([axial_size, lateral_size, lateral_size]), frequency=source_f0, sound_speed_min=c0, points_per_wavelength=ppw
53+
)
6554

6655
# compute points per temporal period
6756
ppp: int = round(ppw / cfl)
@@ -113,8 +102,8 @@
113102
sensor = kSensor()
114103

115104
# set sensor mask to record central plane, not including the source point
116-
sensor.mask = np.zeros((Nx, Ny, Nz), dtype=bool)
117-
sensor.mask[1:, :, Nz // 2] = True
105+
sensor.mask = np.zeros(kgrid.N, dtype=bool)
106+
sensor.mask[1:, :, kgrid.Nz // 2] = True
118107

119108
# record the pressure
120109
sensor.record = ["p"]
@@ -143,10 +132,10 @@
143132
amp, _, _ = extract_amp_phase(sensor_data["p"].T, 1.0 / kgrid.dt, source_f0, dim=1, fft_padding=1, window="Rectangular")
144133

145134
# reshape data
146-
amp = np.reshape(amp, (Nx - 1, Ny), order="F")
135+
amp = np.reshape(amp, (kgrid.Nx - 1, kgrid.Ny), order="F")
147136

148137
# extract pressure on axis
149-
amp_on_axis = amp[:, Ny // 2]
138+
amp_on_axis = amp[:, kgrid.Ny // 2]
150139

151140
# define axis vectors for plotting
152141
x_vec = kgrid.x_vec[1:, :] - kgrid.x_vec[0]
@@ -161,7 +150,7 @@
161150

162151
# define radius and axis
163152
a: float = source_diam / 2.0
164-
x_max: float = (Nx - 1) * dx
153+
x_max: float = (kgrid.Nx - 1) * kgrid.dx
165154
delta_x: float = x_max / 10000.0
166155
x_ref: float = np.arange(0.0, x_max + delta_x, delta_x, dtype=float)
167156

@@ -194,7 +183,7 @@
194183
# plot the source mask (pml is outside the grid in this example)
195184
fig2, ax2 = plt.subplots(1, 1)
196185
ax2.pcolormesh(
197-
1e3 * np.squeeze(kgrid.y_vec), 1e3 * np.squeeze(kgrid.x_vec), np.flip(source.p_mask[:, :, Nz // 2], axis=0), shading="nearest"
186+
1e3 * np.squeeze(kgrid.y_vec), 1e3 * np.squeeze(kgrid.x_vec), np.flip(source.p_mask[:, :, kgrid.Nz // 2], axis=0), shading="nearest"
198187
)
199188
ax2.set(xlabel="y [mm]", ylabel="x [mm]", title="Source Mask")
200189

kwave/kgrid.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ def __init__(self, N, spacing):
3737
N, spacing = np.atleast_1d(N), np.atleast_1d(spacing) # if inputs are lists
3838
assert N.ndim == 1 and spacing.ndim == 1 # ensure no multidimensional lists
3939
assert (1 <= N.size <= 3) and (1 <= spacing.size <= 3) # ensure valid dimensionality
40+
if spacing.size == 1:
41+
spacing = spacing * np.ones(N.size)
4042
assert N.size == spacing.size, "Size list N and spacing list do not have the same size."
4143

4244
self.N = N.astype(int) #: grid size in each dimension [grid points]
@@ -725,11 +727,14 @@ def from_geometry(cls, domain_size, min_element_width, points_per_wavelength=10,
725727
# Ensure at least points_per_wavelength points across the smallest element
726728
grid_spacing = min_element_width / points_per_wavelength
727729

730+
# Create a list of grid spacings with the same length as domain_size
731+
grid_spacing_list = [grid_spacing] * domain_size.size
732+
728733
# Calculate grid size
729734
N = np.ceil(domain_size / grid_spacing).astype(int)
730735

731736
# Create grid instance
732-
grid = cls(N=N, spacing=grid_spacing)
737+
grid = cls(N=N, spacing=grid_spacing_list)
733738

734739
# Note: Time parameters are left as "auto"
735740
# The user can set them later using makeTime method
@@ -764,7 +769,7 @@ def from_domain(cls, dimensions, frequency, sound_speed_min, sound_speed_max=Non
764769
# Use sound_speed_min for sound_speed_max if not provided
765770
if sound_speed_max is None:
766771
sound_speed_max = sound_speed_min
767-
if sound_speed_max > 0:
772+
if not sound_speed_max > 0:
768773
raise ValueError("Sound speed must be positive")
769774

770775
# Calculate wavelength

0 commit comments

Comments
 (0)