Skip to content

Commit 946b83b

Browse files
committed
Refactor grid initialization in at_focused_bowl_3D example
- Updated the grid creation to use the new `from_domain` method for improved clarity and efficiency. - Adjusted sensor mask and data reshaping to utilize properties from the kWaveGrid instance. - Enhanced consistency in grid property usage throughout the example for better readability and maintainability.
1 parent 2dd13e8 commit 946b83b

2 files changed

Lines changed: 28 additions & 25 deletions

File tree

examples/at_focused_bowl_3D/at_focused_bowl_3D.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -57,18 +57,9 @@
5757
# --------------------
5858

5959
# calculate the grid spacing based on the PPW and F0
60-
dx: float = c0 / (ppw * source_f0) # [m]
61-
62-
# compute the size of the grid
63-
Nx: int = round_even(axial_size / dx) + source_x_offset
64-
Ny: int = round_even(lateral_size / dx)
65-
Nz: int = Ny
66-
67-
grid_size_points = Vector([Nx, Ny, Nz])
68-
grid_spacing_meters = Vector([dx, dx, dx])
69-
70-
# create the k-space grid
71-
kgrid = kWaveGrid(grid_size_points, grid_spacing_meters)
60+
kgrid = kWaveGrid.from_domain(
61+
dimensions=[axial_size, lateral_size, lateral_size], frequency=source_f0, sound_speed_min=c0, points_per_wavelength=ppw
62+
)
7263

7364
# compute points per temporal period
7465
ppp: int = round(ppw / cfl)
@@ -124,8 +115,8 @@
124115
sensor = kSensor()
125116

126117
# set sensor mask to record central plane, not including the source point
127-
sensor.mask = np.zeros((Nx, Ny, Nz), dtype=bool)
128-
sensor.mask[(source_x_offset + 1) : -1, :, Nz // 2] = True
118+
sensor.mask = np.zeros(kgrid.N, dtype=bool)
119+
sensor.mask[(source_x_offset + 1) : -1, :, kgrid.Nz // 2] = True
129120

130121
# record the pressure
131122
sensor.record = ["p"]
@@ -155,10 +146,10 @@
155146
amp, _, _ = extract_amp_phase(sensor_data["p"].T, 1.0 / kgrid.dt, source_f0, dim=1, fft_padding=1, window="Rectangular")
156147

157148
# reshape data
158-
amp = np.reshape(amp, (Nx - (source_x_offset + 2), Ny), order="F")
149+
amp = np.reshape(amp, (kgrid.Nx - (source_x_offset + 2), kgrid.Ny), order="F")
159150

160151
# extract pressure on axis
161-
amp_on_axis = amp[:, Ny // 2]
152+
amp_on_axis = amp[:, kgrid.Ny // 2]
162153

163154
# define axis vectors for plotting
164155
x_vec = kgrid.x_vec[(source_x_offset + 1) : -1, :] - kgrid.x_vec[source_x_offset]
@@ -172,7 +163,7 @@
172163
knumber = 2.0 * np.pi * source_f0 / c0
173164

174165
# define axis
175-
x_max = Nx * dx
166+
x_max = kgrid.x_max
176167
delta_x = x_max / 10000.0
177168
x_ref = np.arange(0.0, x_max + delta_x, delta_x)
178169

@@ -210,14 +201,14 @@
210201
ax2a.pcolormesh(
211202
1e3 * np.squeeze(kgrid.y_vec),
212203
1e3 * np.squeeze(kgrid.x_vec),
213-
np.flip(source.p_mask[:, :, Nz // 2], axis=0),
204+
np.flip(source.p_mask[:, :, kgrid.Nz // 2], axis=0),
214205
shading="nearest",
215206
)
216207
ax2a.set(xlabel="y [mm]", ylabel="x [mm]", title="Source Mask")
217208
ax2b.pcolormesh(
218209
1e3 * np.squeeze(kgrid.y_vec),
219210
1e3 * np.squeeze(kgrid.x_vec),
220-
np.flip(grid_weights[:, :, Nz // 2], axis=0),
211+
np.flip(grid_weights[:, :, kgrid.Nz // 2], axis=0),
221212
shading="nearest",
222213
)
223214
ax2b.set(xlabel="y [mm]", ylabel="x [mm]", title="Off-Grid Source Weights")

kwave/kgrid.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -703,12 +703,12 @@ def k_dtt(self, dtt_type): # Not tested for correctness!
703703
return k, M
704704

705705
@classmethod
706-
def from_geometry(cls, domain_size, min_element_width, points_per_wavelength=10, cfl=None):
706+
def from_geometry(cls, dimensions, min_element_width, points_per_wavelength=10):
707707
"""
708708
Create a kWaveGrid based on domain dimensions and the smallest resolvable geometry element.
709709
710710
Args:
711-
domain_size: List or array of physical domain sizes [m]
711+
dimensions: List or array of physical domain sizes [m]
712712
min_element_width: Width of the smallest resolvable geometry element [m]
713713
points_per_wavelength: Number of points per wavelength (default=10)
714714
cfl: CFL number (default=cls.CFL_DEFAULT)
@@ -717,8 +717,8 @@ def from_geometry(cls, domain_size, min_element_width, points_per_wavelength=10,
717717
kWaveGrid instance with appropriate grid size and spacing
718718
"""
719719
# Validate input parameters
720-
domain_size = np.atleast_1d(domain_size)
721-
if not np.all(domain_size > 0):
720+
dimensions = np.atleast_1d(dimensions)
721+
if not np.all(dimensions > 0):
722722
raise ValueError("Domain dimensions must be positive")
723723
if not min_element_width > 0:
724724
raise ValueError("Minimum element width must be positive")
@@ -728,10 +728,10 @@ def from_geometry(cls, domain_size, min_element_width, points_per_wavelength=10,
728728
grid_spacing = min_element_width / points_per_wavelength
729729

730730
# Create a list of grid spacings with the same length as domain_size
731-
grid_spacing_list = [grid_spacing] * domain_size.size
731+
grid_spacing_list = [grid_spacing] * dimensions.size
732732

733733
# Calculate grid size
734-
N = np.ceil(domain_size / grid_spacing).astype(int)
734+
N = np.ceil(dimensions / grid_spacing).astype(int)
735735

736736
# Create grid instance
737737
grid = cls(N=N, spacing=grid_spacing_list)
@@ -785,3 +785,15 @@ def from_domain(cls, dimensions, frequency, sound_speed_min, sound_speed_max=Non
785785
grid = cls(N=N, spacing=grid_spacing)
786786

787787
return grid
788+
789+
@property
790+
def x_max(self):
791+
return self.x_vec[-1] - self.x_vec[0]
792+
793+
@property
794+
def y_max(self):
795+
return self.y_vec[-1] - self.y_vec[0]
796+
797+
@property
798+
def z_max(self):
799+
return self.z_vec[-1] - self.z_vec[0]

0 commit comments

Comments
 (0)