Skip to content

Commit 785c0bd

Browse files
authored
Merge pull request #676 from waltsims/c-order-migration
Migrate solver internals from F-order to C-order
2 parents 8ab6822 + ada220a commit 785c0bd

12 files changed

Lines changed: 652 additions & 123 deletions

File tree

docs/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
This project is a Python implementation of v1.4.0 of the [MATLAB toolbox k-Wave](http://www.k-wave.org/) as well as an
99
interface to the pre-compiled v1.3 of k-Wave simulation binaries, which support NVIDIA sm 5.0 (Maxwell) to sm 9.0a (Hopper) GPUs.
1010

11-
**New in v0.6.0:** Unified `kspaceFirstOrder()` API with a pure NumPy/CuPy solver. See the [API guide](https://k-wave-python.readthedocs.io/en/latest/get_started/new_api.html).
11+
**New in v0.6.0:** Unified `kspaceFirstOrder()` API with a pure NumPy/CuPy solver. See the [API guide](https://k-wave-python.readthedocs.io/en/latest/get_started/new_api.html). The `kspaceFirstOrder()` API is experimental and may change before v1.0.0.
1212

1313
## Mission
1414

kwave/kspaceFirstOrder.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def _expand_for_pml_outside(kgrid, medium, source, sensor, pml_size):
6868
return expanded_kgrid, expanded_medium, expanded_source, expanded_sensor
6969

7070

71-
_FULL_GRID_SUFFIXES = ("_final", "_max", "_min", "_rms")
71+
_FULL_GRID_SUFFIXES = ("_final", "_max", "_min", "_rms", "_max_all", "_min_all", "_rms_all")
7272

7373

7474
def _strip_pml(result, pml_size, ndim):
@@ -156,6 +156,10 @@ def kspaceFirstOrder(
156156
Returns:
157157
dict: Recorded sensor data keyed by field name (e.g.
158158
``"p"``, ``"p_final"``, ``"ux"``, ``"uy"``).
159+
160+
All time-series are ``(n_sensor, Nt)`` with sensor points in
161+
C-flattened order. Use :func:`reshape_to_grid` to recover spatial
162+
structure for full-grid masks.
159163
"""
160164
if device not in ("cpu", "gpu"):
161165
raise ValueError(f"device must be 'cpu' or 'gpu', got {device!r}")
@@ -181,7 +185,7 @@ def kspaceFirstOrder(
181185
from kwave.utils.filters import smooth
182186

183187
source = copy.copy(source)
184-
source.p0 = smooth(np.asarray(source.p0, dtype=float).reshape(tuple(int(n) for n in kgrid.N), order="F"), restore_max=True)
188+
source.p0 = smooth(np.asarray(source.p0, dtype=float).reshape(tuple(int(n) for n in kgrid.N)), restore_max=True)
185189

186190
# --- Backend dispatch ---
187191

@@ -225,7 +229,7 @@ def kspaceFirstOrder(
225229
from kwave.utils.conversion import cart2grid
226230

227231
sensor = copy.copy(sensor)
228-
sensor.mask, _, _ = cart2grid(kgrid, np.asarray(sensor.mask))
232+
sensor.mask, _, _ = cart2grid(kgrid, np.asarray(sensor.mask), order="C")
229233

230234
cpp_sim = CppSimulation(kgrid, medium, source, sensor, pml_size=pml_size, pml_alpha=pml_alpha, use_sg=use_sg)
231235
if save_only:
@@ -244,3 +248,27 @@ def kspaceFirstOrder(
244248
result = _strip_pml(result, pml_size, kgrid.dim)
245249

246250
return result
251+
252+
253+
def reshape_to_grid(data, grid_shape):
    """Reshape flat sensor data to the spatial grid shape.

    Convenience helper for full-grid sensor masks, where ``n_sensor``
    equals the total number of grid points. The reshape uses NumPy's
    default (C) ordering.

    Args:
        data: sensor array — ``(n_sensor, Nt)`` time-series or
            ``(n_sensor,)`` aggregate.
        grid_shape: grid dimensions, e.g. ``(Nx, Ny)``. A bare int is
            accepted for a 1-D grid.

    Returns:
        For time-series: ``(*grid_shape, Nt)``.
        For aggregates: ``grid_shape``.
        Input with any other number of dimensions is returned unchanged
        (as an ndarray).

    Raises:
        ValueError: if the number of sensor points cannot be reshaped to
            ``grid_shape``.
    """
    data = np.asarray(data)
    if data.ndim not in (1, 2):
        return data

    # Normalise once so ints, tuples, lists and integer arrays all behave the
    # same on both the 1-D and the 2-D path.
    shape = tuple(int(n) for n in np.atleast_1d(grid_shape).ravel())

    try:
        # Trailing axes (the time axis, if any) are carried through untouched.
        return data.reshape(shape + data.shape[1:])
    except ValueError as exc:
        raise ValueError(
            f"cannot reshape sensor data with {data.shape[0]} points to grid shape {shape}"
        ) from exc

kwave/solvers/cpp_simulation.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,62 @@ def run(self, *, device="cpu", num_threads=None, device_num=None, quiet=False, d
5656
data_dir = os.path.dirname(input_file)
5757
try:
5858
self._execute(input_file, output_file, device=device, num_threads=num_threads, device_num=device_num, quiet=quiet, debug=debug)
59-
return self._parse_output(output_file)
59+
result = self._parse_output(output_file)
60+
result = self._fix_output_order(result)
61+
return result
6062
finally:
6163
if cleanup:
6264
try:
6365
shutil.rmtree(data_dir)
6466
except OSError as exc:
6567
warnings.warn(f"Could not clean up temp directory {data_dir!r}: {exc}", RuntimeWarning, stacklevel=2)
6668

69+
_FULL_GRID_SUFFIXES = ("_final", "_max", "_min", "_rms", "_max_all", "_min_all", "_rms_all")
70+
71+
def _fix_output_order(self, result):
    """Convert C++ output from F-order to C-order.

    The C++ binary writes arrays in Fortran order. HDF5/h5py reads them
    with reversed dimensions. Full-grid fields are fixed via transpose and
    per-sensor data is reordered from F-indexed to C-indexed rows.

    Args:
        result: dict of parsed outputs keyed by field name. Values that
            are not ``np.ndarray`` are left untouched.

    Returns:
        The same ``result`` dict, updated in place.
    """
    ndim = self.ndim
    grid_shape = tuple(int(n) for n in self.kgrid.N)
    suffixes = tuple(self._FULL_GRID_SUFFIXES)

    # 1. Transpose full-grid fields from reversed F-order to C-order.
    #    Only values are replaced, so iterating over the dict stays safe.
    transposed = set()
    for key, val in result.items():
        if isinstance(val, np.ndarray) and val.ndim == ndim and key.endswith(suffixes):
            result[key] = val.T  # reverses every axis
            transposed.add(key)

    # F- and C-flattening coincide on a 1-D grid: nothing to reorder.
    if ndim < 2:
        return result

    # 2. Reorder per-sensor rows from F-indexed to C-indexed.
    if self.sensor is None or self.sensor.mask is None:
        mask = np.ones(grid_shape, dtype=bool)
    else:
        mask = np.asarray(self.sensor.mask, dtype=bool).reshape(grid_shape)

    n_sensor = int(mask.sum())
    if n_sensor == 0:
        return result

    # f_nz: F-flat indices of the sensor points, ascending (the incoming row order).
    # f_equiv: F-flat index of each sensor point, listed in C-flat order.
    # perm[i] is therefore the incoming row holding the i-th C-ordered point.
    f_nz = np.where(mask.ravel(order="F"))[0]
    c_nz = np.where(mask.ravel())[0]
    f_equiv = np.ravel_multi_index(np.unravel_index(c_nz, grid_shape), grid_shape, order="F")
    perm = np.searchsorted(f_nz, f_equiv)

    for key, val in result.items():
        # Skip only fields that really were full-grid arrays. A per-sensor
        # aggregate such as a 1-D "p_max" shares a suffix with full-grid
        # fields but is still indexed by sensor point and must be reordered.
        # NOTE(review): assumes such aggregates arrive 1-D from _parse_output — confirm.
        if key in transposed or not isinstance(val, np.ndarray):
            continue
        if val.ndim in (1, 2) and val.shape[0] == n_sensor:
            result[key] = val[perm]

    return result
114+
67115
# -- HDF5 serialization --
68116

69117
def _write_hdf5(self, filepath):

kwave/solvers/kspace_solver.py

Lines changed: 69 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,12 @@ def _to_cpu(x):
3434
def _expand_to_grid(val, grid_shape, xp, name="parameter"):
3535
if val is None:
3636
raise ValueError(f"Missing required parameter: {name}")
37-
arr = xp.array(val, dtype=float).flatten(order="F")
37+
arr = xp.array(val, dtype=float).ravel()
3838
grid_size = int(np.prod(grid_shape))
3939
if arr.size == 1:
4040
return xp.full(grid_shape, float(arr[0]), dtype=float)
4141
if arr.size == grid_size:
42-
return arr.reshape(grid_shape, order="F")
42+
return arr.reshape(grid_shape)
4343
raise ValueError(f"{name} size {arr.size} incompatible with grid size {grid_size}")
4444

4545

@@ -48,16 +48,16 @@ def _build_source_op(mask_raw, signal_raw, mode, scale, *, xp, grid_shape, grid_
4848
4949
Returns a callable (t, field) → field that injects scaled source values.
5050
"""
51-
mask = xp.array(mask_raw, dtype=bool).flatten(order="F")
51+
mask = xp.array(mask_raw, dtype=bool).ravel()
5252
if mask.size == 1:
53-
mask = xp.full(grid_shape, bool(mask[0]), dtype=bool).flatten(order="F")
53+
mask = xp.full(grid_shape, bool(mask[0]), dtype=bool).ravel()
5454
n_src = int(xp.sum(mask))
5555

56-
signal_arr = xp.array(signal_raw, dtype=float, order="F")
56+
signal_arr = xp.array(signal_raw, dtype=float)
5757
if signal_arr.ndim == 1:
5858
signal = signal_arr.reshape(1, -1)
5959
else:
60-
signal = signal_arr.reshape(-1, signal_arr.shape[-1], order="F") if signal_arr.ndim > 2 else signal_arr
60+
signal = signal_arr.reshape(-1, signal_arr.shape[-1]) if signal_arr.ndim > 2 else signal_arr
6161

6262
scaled = signal * xp.atleast_1d(xp.asarray(scale))[:, None]
6363
signal_len = scaled.shape[1]
@@ -70,9 +70,9 @@ def get_val(t):
7070
def dirichlet(t, field):
7171
if t >= signal_len:
7272
return field
73-
flat = field.flatten(order="F") # copy — mutation is intentional
73+
flat = field.flatten() # copy — mutation is intentional
7474
flat[mask] = get_val(t)
75-
return flat.reshape(grid_shape, order="F")
75+
return flat.reshape(grid_shape)
7676

7777
# Pre-allocate buffer to avoid per-step allocation
7878
_src_buf = xp.zeros(grid_size, dtype=float)
@@ -82,15 +82,15 @@ def additive_kspace(t, field):
8282
return field
8383
_src_buf[:] = 0
8484
_src_buf[mask] = get_val(t)
85-
src = _src_buf.reshape(grid_shape, order="F")
85+
src = _src_buf.reshape(grid_shape)
8686
return field + diff_fn(src, source_kappa)
8787

8888
def additive_no_correction(t, field):
8989
if t >= signal_len:
9090
return field
9191
_src_buf[:] = 0
9292
_src_buf[mask] = get_val(t)
93-
return field + _src_buf.reshape(grid_shape, order="F")
93+
return field + _src_buf.reshape(grid_shape)
9494

9595
ops = {"dirichlet": dirichlet, "additive": additive_kspace, "additive-no-correction": additive_no_correction}
9696
if mode not in ops:
@@ -210,19 +210,19 @@ def _is_cartesian(arr):
210210

211211
if mask_raw is None:
212212
self.n_sensor_points = grid_numel
213-
self._extract = lambda f: f.flatten(order="F")
213+
self._extract = lambda f: f.ravel()
214214
else:
215215
mask_arr = np.asarray(mask_raw, dtype=float)
216216
# Check Cartesian first to avoid ambiguity when size == grid_numel
217217
if _is_cartesian(mask_arr):
218218
self._setup_cartesian_extract(mask_arr)
219219
elif _is_binary(mask_arr):
220-
bmask = xp.array(mask_arr, dtype=bool).flatten(order="F")
220+
bmask = xp.array(mask_arr, dtype=bool).ravel()
221221
if bmask.size == 1:
222222
bmask = xp.full(grid_numel, bool(bmask[0]), dtype=bool)
223223
self.n_sensor_points = int(xp.sum(bmask))
224224
idx = xp.where(bmask)[0]
225-
self._extract = lambda f, _i=idx: f.flatten(order="F")[_i]
225+
self._extract = lambda f, _i=idx: f.ravel()[_i]
226226
else:
227227
raise ValueError(
228228
f"Sensor mask shape {mask_arr.shape} is neither binary " f"(numel={grid_numel}) nor Cartesian ({self.ndim}, N_points)"
@@ -289,7 +289,7 @@ def _setup_cartesian_extract(self, cart_pos):
289289
x_vec, cart_x = axis_coords[0], cart.flatten()
290290

291291
def _extract_1d_interp(f):
292-
return xp.asarray(np.interp(cart_x, x_vec, _to_cpu(f).flatten(order="F")))
292+
return xp.asarray(np.interp(cart_x, x_vec, _to_cpu(f).ravel()))
293293

294294
self._extract = _extract_1d_interp
295295
else:
@@ -298,8 +298,8 @@ def _extract_1d_interp(f):
298298
int_idx = np.clip(np.floor(frac_idx).astype(int), 0, np.array(self.grid_shape)[:, None] - 2)
299299
local = frac_idx - int_idx
300300

301-
# F-order strides and 2^ndim corner enumeration
302-
strides = np.cumprod([1] + list(self.grid_shape[:-1]))
301+
# C-order strides and 2^ndim corner enumeration
302+
strides = np.cumprod([1] + list(self.grid_shape[:0:-1]))[::-1]
303303
n_corners = 2**self.ndim
304304
corner_indices = np.zeros((self.n_sensor_points, n_corners), dtype=int)
305305
corner_weights = np.ones((self.n_sensor_points, n_corners))
@@ -313,7 +313,7 @@ def _extract_1d_interp(f):
313313
corner_weights = xp.array(corner_weights)
314314

315315
def _extract_bilinear(f):
316-
return (f.flatten(order="F")[corner_indices] * corner_weights).sum(axis=1)
316+
return (f.ravel()[corner_indices] * corner_weights).sum(axis=1)
317317

318318
self._extract = _extract_bilinear
319319

@@ -460,15 +460,15 @@ def _setup_source_operators(self):
460460
grid_size = int(np.prod(self.grid_shape))
461461

462462
def _expand_mask(mask_raw):
463-
mask = xp.array(mask_raw, dtype=bool).flatten(order="F")
463+
mask = xp.array(mask_raw, dtype=bool).ravel()
464464
if mask.size == 1:
465-
mask = xp.full(self.grid_shape, bool(mask[0]), dtype=bool).flatten(order="F")
465+
mask = xp.full(self.grid_shape, bool(mask[0]), dtype=bool).ravel()
466466
return mask
467467

468468
def source_scale(mask_raw, c0):
469469
"""Get per-source-point sound speed values."""
470470
mask = _expand_mask(mask_raw)
471-
c0_flat = c0.flatten(order="F")
471+
c0_flat = c0.ravel()
472472
n_src = int(xp.sum(mask))
473473
return c0_flat[mask] if c0_flat.size > 1 else xp.full(n_src, float(c0_flat))
474474

@@ -566,7 +566,7 @@ def _setup_fields(self):
566566
if self.smooth_p0 and self.ndim >= 2:
567567
from kwave.utils.filters import smooth
568568

569-
# p0 is F-order from _expand_to_grid; smooth() is order-agnostic (uses FFT on shape)
569+
# smooth() is order-agnostic (uses FFT on shape)
570570
p0 = xp.asarray(smooth(_to_cpu(p0), restore_max=True))
571571
self._p0_initial = p0
572572
else:
@@ -779,6 +779,53 @@ def create_simulation(kgrid, medium, source, sensor, device="auto", smooth_p0=Fa
779779
)
780780

781781

782+
def _f_to_c_source_reorder(source, grid_shape):
    """Reorder multi-row source signals from MATLAB F-flat to C-flat mask order.

    MATLAB sends source signal rows ordered by F-flattened mask indices.
    The solver uses C-flat ordering internally. For single-row (uniform)
    sources, no reordering is needed.

    Args:
        source: dict that may hold ``"p_mask"``/``"p"`` and
            ``"u_mask"``/``"ux"``/``"uy"``/``"uz"`` entries.
        grid_shape: grid dimensions; values are cast to ``int`` so
            float-valued sizes are accepted.

    Returns:
        ``source`` itself for grids with fewer than two dimensions;
        otherwise a shallow copy in which every signal with at least two
        dimensions and one row per mask point has its rows permuted to
        C-flat order. The caller's dict is never mutated.
    """
    grid_shape = tuple(int(n) for n in grid_shape)
    if len(grid_shape) < 2:
        # F- and C-flattening coincide on a 1-D grid.
        return source
    source = dict(source)  # shallow copy — don't mutate caller's dict

    for mask_key, signal_keys in (("p_mask", ("p",)), ("u_mask", ("ux", "uy", "uz"))):
        mask_raw = source.get(mask_key)
        if mask_raw is None:
            continue
        mask = np.asarray(mask_raw, dtype=bool)
        if mask.size == 0:
            continue
        if mask.size == 1:
            # The solver expands a scalar mask to the whole grid, so per-point
            # signals for it are ordered by grid point and need reordering too.
            mask_grid = np.full(grid_shape, bool(mask.ravel()[0]), dtype=bool)
        else:
            mask_grid = mask.reshape(grid_shape)
        n_src = int(mask_grid.sum())
        if n_src < 2:
            continue

        # Build F→C permutation for mask points:
        # f_nz is the incoming (F-flat, ascending) row order; f_equiv is the
        # F-flat index of each mask point listed in C-flat order, so perm[i]
        # is the incoming row that belongs to the i-th C-ordered point.
        f_nz = np.where(mask_grid.ravel(order="F"))[0]
        c_nz = np.where(mask_grid.ravel())[0]
        f_equiv = np.ravel_multi_index(np.unravel_index(c_nz, grid_shape), grid_shape, order="F")
        perm = np.searchsorted(f_nz, f_equiv)

        for sig_key in signal_keys:
            sig = source.get(sig_key)
            if sig is None:
                continue
            sig = np.asarray(sig)
            # Only signals with one row per source point carry a point ordering.
            if sig.ndim >= 2 and sig.shape[0] == n_src:
                source[sig_key] = sig[perm]

    return source
821+
822+
782823
def simulate_from_dicts(kgrid, medium, source, sensor, device="auto", smooth_p0=False):
783-
"""MATLAB interop entry point."""
824+
"""MATLAB interop entry point.
825+
826+
Reorders multi-row source signals from MATLAB's F-flat mask ordering
827+
to the solver's C-flat ordering before running the simulation.
828+
"""
829+
grid_shape = tuple(kgrid[k] for k in ["Nx", "Ny", "Nz"] if k in kgrid)
830+
source = _f_to_c_source_reorder(source, grid_shape)
784831
return create_simulation(kgrid, medium, source, sensor, device, smooth_p0=smooth_p0).run()

0 commit comments

Comments
 (0)