Skip to content

Commit f5522e1

Browse files
waltsimsclaude
andcommitted
Simplify output to (n_sensor, Nt) everywhere, add reshape_to_grid helper
Keep the simplest possible output contract: all time-series are (n_sensor, Nt) with sensor points in C-flattened order. Aggregates are (n_sensor,). No automatic reshaping. Users who want spatial structure call reshape_to_grid(data, grid_shape) which converts (n_sensor, Nt) → (*grid_shape, Nt). Removes _reshape_sensor_to_grid and all the time-first reshaping logic. Integration tests use _c_to_f_reorder to compare C-flat Python output against F-flat MATLAB references. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent a634930 commit f5522e1

6 files changed

Lines changed: 81 additions & 116 deletions

File tree

kwave/kspaceFirstOrder.py

Lines changed: 23 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -157,10 +157,9 @@ def kspaceFirstOrder(
157157
dict: Recorded sensor data keyed by field name (e.g.
158158
``"p"``, ``"p_final"``, ``"ux"``, ``"uy"``).
159159
160-
Sensor time-series are C-ordered. When the sensor mask covers the
161-
entire grid, time-series fields are returned as
162-
``(Nt, *grid_shape)`` (time-first). For partial masks the shape is
163-
``(n_sensor, Nt)`` with sensor points in C-flattened order.
160+
All time-series are ``(n_sensor, Nt)`` with sensor points in
161+
C-flattened order. Use :func:`reshape_to_grid` to recover spatial
162+
structure for full-grid masks.
164163
"""
165164
if device not in ("cpu", "gpu"):
166165
raise ValueError(f"device must be 'cpu' or 'gpu', got {device!r}")
@@ -178,8 +177,6 @@ def kspaceFirstOrder(
178177

179178
# --- Shared pre-processing (both backends) ---
180179

181-
user_grid_shape = tuple(int(n) for n in kgrid.N)
182-
183180
if not pml_inside:
184181
kgrid, medium, source, sensor = _expand_for_pml_outside(kgrid, medium, source, sensor, pml_size)
185182

@@ -250,35 +247,28 @@ def kspaceFirstOrder(
250247
if not pml_inside:
251248
result = _strip_pml(result, pml_size, kgrid.dim)
252249

253-
result = _reshape_sensor_to_grid(result, sensor, user_grid_shape)
254-
255250
return result
256251

257252

258-
def _reshape_sensor_to_grid(result, sensor, user_grid_shape):
259-
"""Reshape sensor time-series to (Nt, *grid_shape) when the sensor covers the full user grid.
253+
def reshape_to_grid(data, grid_shape):
254+
"""Reshape flat sensor data to grid shape.
260255
261-
After PML expansion the sensor mask is padded with zeros, so we check
262-
against the user's original grid shape (before expansion).
256+
Convenience helper for full-grid sensor masks where ``n_sensor``
257+
equals the total number of grid points.
258+
259+
Args:
260+
data: sensor array — ``(n_sensor, Nt)`` time-series or
261+
``(n_sensor,)`` aggregate.
262+
grid_shape: tuple of grid dimensions, e.g. ``(Nx, Ny)``.
263+
264+
Returns:
265+
For time-series: ``(*grid_shape, Nt)``
266+
For aggregates: ``(*grid_shape)``
263267
"""
264-
user_numel = int(np.prod(user_grid_shape))
265-
266-
mask = getattr(sensor, "mask", None) if sensor is not None else None
267-
if mask is None:
268-
n_sensor = user_numel
269-
elif _is_cartesian_mask(mask, len(user_grid_shape)):
270-
return result
271-
else:
272-
n_sensor = int(np.asarray(mask, dtype=bool).sum())
273-
274-
if n_sensor != user_numel:
275-
return result
276-
277-
for key, val in result.items():
278-
if not isinstance(val, np.ndarray):
279-
continue
280-
if val.ndim == 2 and val.shape[0] == n_sensor:
281-
result[key] = val.T.reshape(-1, *user_grid_shape)
282-
elif val.ndim == 1 and val.shape[0] == n_sensor:
283-
result[key] = val.reshape(user_grid_shape)
284-
return result
268+
data = np.asarray(data)
269+
if data.ndim == 2:
270+
n_sensor, Nt = data.shape
271+
return data.reshape(*grid_shape, Nt)
272+
elif data.ndim == 1:
273+
return data.reshape(grid_shape)
274+
return data

tests/integration/conftest.py

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -33,48 +33,44 @@ def _load(name):
3333
return _load
3434

3535

36-
def _to_matlab_shape(py_val, mat_val):
37-
"""Reshape C-order Python output to match MATLAB F-order reference shape.
36+
def _c_to_f_reorder(py_val, grid_shape):
37+
"""Reorder sensor rows from C-flat to F-flat ordering for MATLAB comparison.
3838
39-
The new API returns full-grid time-series as (Nt, *grid_shape) in C-order.
40-
MATLAB references store them as (n_sensor, Nt) in F-flat order.
41-
Similarly, aggregates are (*grid_shape) vs MATLAB (n_sensor,).
39+
Python (C-order) and MATLAB (F-order) flatten grid points in different
40+
orders. Both produce (n_sensor, Nt) but the row ordering differs.
41+
This builds a permutation to reorder Python rows to match MATLAB.
4242
"""
43-
if py_val.shape == mat_val.shape:
43+
if grid_shape is None or len(grid_shape) < 2:
4444
return py_val
45-
46-
# Time-series ≥3D: (Nt, *grid_shape) → (n_sensor, Nt) with F-order flatten
47-
if py_val.ndim >= 3 and mat_val.ndim == 2:
48-
Nt = py_val.shape[0]
49-
# Move time axis last, then F-order flatten the grid dims
50-
return np.moveaxis(py_val, 0, -1).reshape(-1, Nt, order="F")
51-
52-
# Time-series 1D: (Nt, N) → (N, Nt) — both 2D but transposed
53-
if py_val.ndim == 2 and mat_val.ndim == 2 and py_val.shape == mat_val.shape[::-1]:
54-
return py_val.T
55-
56-
# Aggregates: (*grid_shape) → (n_sensor,) with F-order flatten
57-
if py_val.ndim >= 2 and mat_val.ndim == 1:
58-
return py_val.ravel(order="F")
59-
45+
c_indices = np.arange(int(np.prod(grid_shape)))
46+
f_indices = np.ravel_multi_index(np.unravel_index(c_indices, grid_shape), grid_shape, order="F")
47+
# f_indices[i] = where C-flat point i lands in F-flat order
48+
# We need the inverse: for each F-flat position, which C-flat row?
49+
inv = np.argsort(f_indices)
50+
if py_val.ndim == 2 and py_val.shape[0] == len(c_indices):
51+
return py_val[inv]
52+
if py_val.ndim == 1 and py_val.shape[0] == len(c_indices):
53+
return py_val[inv]
6054
return py_val
6155

6256

63-
def assert_fields_close(result, ref, fields, *, rtol=1e-10, atol=1e-12):
57+
def assert_fields_close(result, ref, fields, *, rtol=1e-10, atol=1e-12, grid_shape=None):
6458
"""Compare Python result dict against MATLAB reference arrays.
6559
6660
Args:
6761
result: dict from kspaceFirstOrder()
6862
ref: dict from scipy.io.loadmat()
6963
fields: list of (python_key, matlab_key) tuples
64+
grid_shape: tuple of grid dims for C→F row reordering (full-grid masks only)
7065
rtol, atol: tolerances passed to np.testing.assert_allclose
7166
"""
7267
for py_key, mat_key in fields:
7368
assert py_key in result, f"Python result missing key '{py_key}'"
7469
assert mat_key in ref, f"MATLAB reference missing key '{mat_key}'"
7570
py_val = np.atleast_1d(np.squeeze(np.asarray(result[py_key])))
7671
mat_val = np.atleast_1d(np.squeeze(np.asarray(ref[mat_key])))
77-
py_val = _to_matlab_shape(py_val, mat_val)
72+
if grid_shape is not None:
73+
py_val = _c_to_f_reorder(py_val, grid_shape)
7874
assert py_val.shape == mat_val.shape, f"Shape mismatch for {py_key}: Python {py_val.shape} vs MATLAB {mat_val.shape}"
7975
np.testing.assert_allclose(
8076
py_val, mat_val, rtol=rtol, atol=atol, err_msg=f"Field '{py_key}' differs from MATLAB reference '{mat_key}'"

tests/integration/test_ivp_2D.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,5 @@ def test_ivp_2D_vs_matlab(load_matlab_ref):
4242
result,
4343
ref,
4444
[("p", "sensor_data_p")],
45+
grid_shape=(128, 128),
4546
)

tests/test_c_order.py

Lines changed: 26 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from kwave.kmedium import kWaveMedium
1010
from kwave.ksensor import kSensor
1111
from kwave.ksource import kSource
12-
from kwave.kspaceFirstOrder import _is_cartesian_mask, _reshape_sensor_to_grid
12+
from kwave.kspaceFirstOrder import reshape_to_grid
1313

1414
# ---------------------------------------------------------------------------
1515
# _fix_output_order (CppSimulation)
@@ -111,55 +111,34 @@ def test_1d_aggregate_reorder(self):
111111

112112

113113
# ---------------------------------------------------------------------------
114-
# _reshape_sensor_to_grid
114+
# reshape_to_grid helper
115115
# ---------------------------------------------------------------------------
116116

117117

118-
class TestReshapeSensorToGrid:
119-
def test_cartesian_mask_unchanged(self):
120-
"""Cartesian sensor masks should not be reshaped."""
121-
sensor = SimpleNamespace(mask=np.array([[0.0, 1e-3], [0.0, 0.0]])) # (2, 2) Cartesian
122-
result = {"p": np.arange(20).reshape(2, 10)}
123-
out = _reshape_sensor_to_grid(result, sensor, (64, 64))
124-
assert out["p"].shape == (2, 10) # unchanged
125-
126-
def test_partial_mask_unchanged(self):
127-
"""Partial binary masks should not be reshaped."""
128-
mask = np.zeros((8, 8), dtype=bool)
129-
mask[0, 0] = True
130-
mask[4, 4] = True
131-
sensor = SimpleNamespace(mask=mask)
132-
result = {"p": np.arange(20).reshape(2, 10)}
133-
out = _reshape_sensor_to_grid(result, sensor, (8, 8))
134-
assert out["p"].shape == (2, 10) # unchanged
135-
136-
def test_full_grid_reshaped(self):
137-
"""Full-grid binary mask is reshaped to (Nt, *grid_shape)."""
138-
sensor = SimpleNamespace(mask=np.ones((4, 6), dtype=bool))
139-
Nt = 5
140-
result = {"p": np.arange(120).reshape(24, Nt)}
141-
out = _reshape_sensor_to_grid(result, sensor, (4, 6))
142-
assert out["p"].shape == (Nt, 4, 6)
143-
144-
def test_aggregate_reshaped(self):
145-
"""1D aggregates reshaped to grid_shape for full-grid masks."""
146-
sensor = SimpleNamespace(mask=np.ones((4, 6), dtype=bool))
147-
result = {"p_max": np.arange(24)}
148-
out = _reshape_sensor_to_grid(result, sensor, (4, 6))
149-
assert out["p_max"].shape == (4, 6)
150-
151-
def test_sensor_none(self):
152-
"""sensor=None means full grid."""
153-
result = {"p": np.arange(80).reshape(16, 5)}
154-
out = _reshape_sensor_to_grid(result, None, (4, 4))
155-
assert out["p"].shape == (5, 4, 4)
156-
157-
def test_non_array_values_unchanged(self):
158-
"""Non-ndarray values in result are passed through."""
159-
sensor = SimpleNamespace(mask=np.ones((4, 4), dtype=bool))
160-
result = {"p": np.arange(80).reshape(16, 5), "metadata": "hello"}
161-
out = _reshape_sensor_to_grid(result, sensor, (4, 4))
162-
assert out["metadata"] == "hello"
118+
class TestReshapeToGrid:
119+
def test_time_series_2d(self):
120+
"""(n_sensor, Nt) → (*grid_shape, Nt)."""
121+
data = np.arange(120).reshape(24, 5)
122+
out = reshape_to_grid(data, (4, 6))
123+
assert out.shape == (4, 6, 5)
124+
125+
def test_aggregate_1d(self):
126+
"""(n_sensor,) → (*grid_shape)."""
127+
data = np.arange(24)
128+
out = reshape_to_grid(data, (4, 6))
129+
assert out.shape == (4, 6)
130+
131+
def test_3d_grid(self):
132+
"""Works with 3D grids."""
133+
data = np.arange(60).reshape(60, 1)
134+
out = reshape_to_grid(data, (3, 4, 5))
135+
assert out.shape == (3, 4, 5, 1)
136+
137+
def test_passthrough_higher_dim(self):
138+
"""Higher-dim arrays pass through unchanged."""
139+
data = np.arange(120).reshape(2, 3, 4, 5)
140+
out = reshape_to_grid(data, (4, 6))
141+
assert out.shape == (2, 3, 4, 5)
163142

164143

165144
# ---------------------------------------------------------------------------

tests/test_kspaceFirstOrder.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,7 @@ def test_python_backend_runs(self, sim_2d):
6161
kgrid, medium, source, sensor = sim_2d
6262
result = kspaceFirstOrder(kgrid, medium, source, sensor, backend="python")
6363
assert "p" in result
64-
# Full-grid sensor → (Nt, *grid_shape) in C-order
65-
assert result["p"].shape == (int(kgrid.Nt), 64, 64)
64+
assert result["p"].shape == (int(sensor.mask.sum()), int(kgrid.Nt))
6665

6766
def test_cpp_save_only(self, sim_2d):
6867
kgrid, medium, source, sensor = sim_2d

tests/test_native_solver.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def test_p0_source(self, grid_2d):
4343
result = kspaceFirstOrder(
4444
grid_2d, kWaveMedium(sound_speed=1500), _p0_source((64, 64)), kSensor(mask=np.ones((64, 64), dtype=bool)), backend="python"
4545
)
46-
assert result["p"].shape == (10, 64, 64)
46+
assert result["p"].shape == (64 * 64, 10)
4747
assert np.max(np.abs(result["p"])) > 0
4848

4949
def test_heterogeneous_medium(self, grid_2d):
@@ -56,7 +56,7 @@ def test_heterogeneous_medium(self, grid_2d):
5656
kSensor(mask=np.ones((64, 64), dtype=bool)),
5757
backend="python",
5858
)
59-
assert result["p"].shape == (10, 64, 64)
59+
assert result["p"].shape == (64 * 64, 10)
6060

6161
def test_absorption(self, grid_2d):
6262
result = kspaceFirstOrder(
@@ -66,7 +66,7 @@ def test_absorption(self, grid_2d):
6666
kSensor(mask=np.ones((64, 64), dtype=bool)),
6767
backend="python",
6868
)
69-
assert result["p"].shape == (10, 64, 64)
69+
assert result["p"].shape == (64 * 64, 10)
7070

7171
def test_pml_auto(self):
7272
kgrid = kWaveGrid(Vector([128, 128]), Vector([0.1e-3, 0.1e-3]))
@@ -84,7 +84,7 @@ def test_record_aggregates(self, grid_2d):
8484
sensor = kSensor(mask=np.ones((64, 64), dtype=bool))
8585
sensor.record = ["p", "p_max", "p_rms"]
8686
result = kspaceFirstOrder(grid_2d, kWaveMedium(sound_speed=1500), _p0_source((64, 64)), sensor, backend="python")
87-
assert result["p_max"].shape == (64, 64)
87+
assert result["p_max"].shape == (64 * 64,)
8888
assert "p_rms" in result
8989

9090

@@ -119,7 +119,7 @@ def test_nonlinearity_bona(self, grid_2d):
119119
kSensor(mask=np.ones((64, 64), dtype=bool)),
120120
backend="python",
121121
)
122-
assert result["p"].shape == (10, 64, 64)
122+
assert result["p"].shape == (64 * 64, 10)
123123

124124
def test_stokes_absorption(self, grid_2d):
125125
result = kspaceFirstOrder(
@@ -129,7 +129,7 @@ def test_stokes_absorption(self, grid_2d):
129129
kSensor(mask=np.ones((64, 64), dtype=bool)),
130130
backend="python",
131131
)
132-
assert result["p"].shape == (10, 64, 64)
132+
assert result["p"].shape == (64 * 64, 10)
133133

134134
def test_dirichlet_pressure_source(self, grid_1d):
135135
source = kSource()
@@ -146,16 +146,16 @@ def test_velocity_recording(self, grid_2d):
146146
sensor = kSensor(mask=np.ones((64, 64), dtype=bool))
147147
sensor.record = ["p", "ux", "uy", "ux_max", "uy_rms", "ux_final", "p_final"]
148148
result = kspaceFirstOrder(grid_2d, kWaveMedium(sound_speed=1500), _p0_source((64, 64)), sensor, backend="python")
149-
assert result["ux"].shape == (10, 64, 64)
150-
assert result["ux_max"].shape == (64, 64)
149+
assert result["ux"].shape == (64 * 64, 10)
150+
assert result["ux_max"].shape == (64 * 64,)
151151
assert "ux_final" in result and "p_final" in result
152152

153153
def test_intensity_recording(self, grid_2d):
154154
sensor = kSensor(mask=np.ones((64, 64), dtype=bool))
155155
sensor.record = ["p", "ux", "uy", "Ix", "Iy", "Ix_avg", "Iy_avg"]
156156
result = kspaceFirstOrder(grid_2d, kWaveMedium(sound_speed=1500), _p0_source((64, 64)), sensor, backend="python")
157-
assert result["Ix"].shape == (10, 64, 64)
158-
assert result["Ix_avg"].shape == (64, 64)
157+
assert result["Ix"].shape == (64 * 64, 10)
158+
assert result["Ix_avg"].shape == (64 * 64,)
159159

160160
def test_record_start_index(self, grid_1d):
161161
source = kSource()
@@ -172,7 +172,7 @@ def test_sensor_none_records_everywhere(self, grid_1d):
172172
source.p0 = np.zeros(64)
173173
source.p0[32] = 1.0
174174
result = kspaceFirstOrder(grid_1d, kWaveMedium(sound_speed=1500), source, None, backend="python", pml_inside=True)
175-
assert result["p"].shape == (20, 64)
175+
assert result["p"].shape == (64, 20)
176176

177177

178178
class TestCppSaveOnly:

0 commit comments

Comments
 (0)