Skip to content

Commit 945a042

Browse files
waltsimsclaude
andcommitted
Add phantom load, source define, and 1D IVP support
- phantom load: accepts .npy files for heterogeneous medium properties - source define: loads custom p0 from .npy file - sensor define: supports file-based masks (sparse sensors) - Grid materializer passes full sound speed array to makeTime (not just max) - Custom CFL support via --cfl flag - Test: ivp_1D_simulation.py replicated via CLI with exact parity Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 4ff544e commit 945a042

5 files changed

Lines changed: 318 additions & 6 deletions

File tree

kwave/cli/commands/phantom.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,75 @@ def phantom():
1313
pass
1414

1515

16+
@phantom.command("load")
17+
@click.option("--grid-size", required=True, help="Grid dimensions, e.g. 512 or 128,128")
18+
@click.option("--spacing", required=True, type=float, help="Grid spacing in meters")
19+
@click.option("--sound-speed", required=True, help="Scalar value (m/s) or path to .npy file")
20+
@click.option("--density", default=None, help="Scalar value (kg/m^3) or path to .npy file")
21+
@click.option("--cfl", type=float, default=None, help="CFL number for time step calculation")
22+
@pass_session
23+
@json_command("phantom.load")
24+
def load(sess, grid_size, spacing, sound_speed, density, cfl):
25+
"""Load medium properties from scalar values or .npy files."""
26+
sess.load()
27+
28+
grid_n = tuple(int(x) for x in grid_size.split(","))
29+
ndim = len(grid_n)
30+
grid_spacing = (spacing,) * ndim
31+
32+
medium_state = {}
33+
34+
# Sound speed: scalar or .npy path
35+
if sound_speed.endswith(".npy"):
36+
arr = np.load(sound_speed)
37+
path = sess.save_array("sound_speed", arr)
38+
medium_state["sound_speed_path"] = path
39+
medium_state["sound_speed_scalar"] = None
40+
# Store path so makeTime gets the full array (not just max)
41+
sound_speed_for_time_path = path
42+
sound_speed_for_time_scalar = None
43+
else:
44+
medium_state["sound_speed_scalar"] = float(sound_speed)
45+
medium_state["sound_speed_path"] = None
46+
sound_speed_for_time_path = None
47+
sound_speed_for_time_scalar = float(sound_speed)
48+
49+
# Density: scalar, .npy path, or None
50+
if density is not None:
51+
if density.endswith(".npy"):
52+
arr = np.load(density)
53+
path = sess.save_array("density", arr)
54+
medium_state["density_path"] = path
55+
medium_state["density_scalar"] = None
56+
else:
57+
medium_state["density_scalar"] = float(density)
58+
medium_state["density_path"] = None
59+
60+
grid_state = {
61+
"N": list(grid_n),
62+
"spacing": list(grid_spacing),
63+
"sound_speed_for_time_path": sound_speed_for_time_path,
64+
"sound_speed_for_time_scalar": sound_speed_for_time_scalar,
65+
}
66+
if cfl is not None:
67+
grid_state["cfl"] = cfl
68+
69+
sess.update("grid", grid_state)
70+
sess.update("medium", medium_state)
71+
72+
return CLIResponse(
73+
result={
74+
"grid_size": list(grid_n),
75+
"spacing": list(grid_spacing),
76+
"medium": medium_state,
77+
},
78+
derived={
79+
"ndim": ndim,
80+
"grid_points": int(np.prod(grid_n)),
81+
},
82+
)
83+
84+
1685
@phantom.command()
1786
@click.option("--type", "phantom_type", required=True, type=click.Choice(["disc", "spherical", "layered"]))
1887
@click.option("--grid-size", required=True, help="Grid dimensions, e.g. 128,128")
@@ -73,7 +142,8 @@ def generate(sess, phantom_type, grid_size, spacing, sound_speed, density, disc_
73142
{
74143
"N": list(grid_n),
75144
"spacing": list(grid_spacing),
76-
"sound_speed_for_time": sound_speed,
145+
"sound_speed_for_time_scalar": sound_speed,
146+
"sound_speed_for_time_path": None,
77147
},
78148
)
79149
sess.update(

kwave/cli/commands/source.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
"""Source definition command."""
2+
3+
import click
4+
import numpy as np
5+
6+
from kwave.cli.main import pass_session
7+
from kwave.cli.schema import CLIResponse, json_command
8+
9+
10+
@click.group("source")
11+
def source():
12+
"""Define simulation source."""
13+
pass
14+
15+
16+
@source.command()
17+
@click.option("--type", "source_type", required=True, type=click.Choice(["initial-pressure"]))
18+
@click.option("--p0-file", required=True, type=click.Path(exists=True), help="Path to .npy file with initial pressure distribution")
19+
@pass_session
20+
@json_command("source.define")
21+
def define(sess, source_type, p0_file):
22+
"""Define source from file."""
23+
sess.load()
24+
25+
p0 = np.load(p0_file)
26+
p0_path = sess.save_array("p0", p0)
27+
28+
sess.update(
29+
"source",
30+
{
31+
"type": source_type,
32+
"p0_path": p0_path,
33+
},
34+
)
35+
36+
return CLIResponse(
37+
result={
38+
"type": source_type,
39+
"p0_shape": list(p0.shape),
40+
"p0_max": float(p0.max()),
41+
"p0_min": float(p0.min()),
42+
}
43+
)

kwave/cli/main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,11 @@ def cli(ctx, session_dir):
2525
from kwave.cli.commands.run import run # noqa: E402
2626
from kwave.cli.commands.sensor import sensor # noqa: E402
2727
from kwave.cli.commands.session_cmd import session # noqa: E402
28+
from kwave.cli.commands.source import source # noqa: E402
2829

2930
cli.add_command(session)
3031
cli.add_command(phantom)
3132
cli.add_command(sensor)
33+
cli.add_command(source)
3234
cli.add_command(plan)
3335
cli.add_command(run)

kwave/cli/session.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,12 +123,24 @@ def make_grid(self):
123123

124124
g = self.state["grid"]
125125
if g is None:
126-
raise SessionError("Grid not defined. Run 'kwave phantom generate' first.")
126+
raise SessionError("Grid not defined. Run 'kwave phantom generate' or 'kwave phantom load' first.")
127127
grid_size = tuple(g["N"])
128128
grid_spacing = tuple(g["spacing"])
129129
kgrid = kWaveGrid(grid_size, grid_spacing)
130-
if g.get("sound_speed_for_time") is not None:
131-
kgrid.makeTime(g["sound_speed_for_time"])
130+
131+
# Resolve sound speed for time stepping (array or scalar)
132+
sound_speed = None
133+
if g.get("sound_speed_for_time_path") is not None:
134+
sound_speed = np.load(g["sound_speed_for_time_path"])
135+
elif g.get("sound_speed_for_time_scalar") is not None:
136+
sound_speed = g["sound_speed_for_time_scalar"]
137+
138+
if sound_speed is not None:
139+
cfl = g.get("cfl")
140+
if cfl is not None:
141+
kgrid.makeTime(sound_speed, cfl=cfl)
142+
else:
143+
kgrid.makeTime(sound_speed)
132144
return kgrid
133145

134146
def make_medium(self):
@@ -137,10 +149,19 @@ def make_medium(self):
137149

138150
m = self.state["medium"]
139151
if m is None:
140-
raise SessionError("Medium not defined. Run 'kwave phantom generate' first.")
152+
raise SessionError("Medium not defined. Run 'kwave phantom generate' or 'kwave phantom load' first.")
141153
kwargs = {}
154+
# Handle fields that can be scalars or .npy paths
142155
for field in ("sound_speed", "density", "alpha_coeff", "alpha_power", "BonA"):
143-
if field in m and m[field] is not None:
156+
# Check for path-based storage (from phantom load)
157+
path_key = f"{field}_path"
158+
scalar_key = f"{field}_scalar"
159+
if path_key in m and m[path_key] is not None:
160+
kwargs[field] = np.load(m[path_key])
161+
elif scalar_key in m and m[scalar_key] is not None:
162+
kwargs[field] = m[scalar_key]
163+
# Fallback: direct scalar value (from phantom generate)
164+
elif field in m and m[field] is not None:
144165
kwargs[field] = m[field]
145166
return kWaveMedium(**kwargs)
146167

tests/test_cli/test_1d_ivp.py

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
"""Test: replicate ivp_1D_simulation.py via CLI commands.
2+
3+
Exercises: 1D grid, heterogeneous medium (array .npy files),
4+
custom p0, sparse sensor mask, custom CFL.
5+
"""
6+
7+
import json
8+
9+
import numpy as np
10+
import pytest
11+
from click.testing import CliRunner
12+
13+
from kwave.cli.main import cli
14+
from kwave.data import Vector
15+
from kwave.kgrid import kWaveGrid
16+
from kwave.kmedium import kWaveMedium
17+
from kwave.ksensor import kSensor
18+
from kwave.ksource import kSource
19+
from kwave.kspaceFirstOrder import kspaceFirstOrder
20+
21+
22+
def _invoke(runner, args, session_dir):
23+
result = runner.invoke(cli, ["--session-dir", str(session_dir)] + args, catch_exceptions=False)
24+
assert result.exit_code == 0, f"Command failed: {args}\n{result.output}"
25+
output = result.output.strip()
26+
try:
27+
return json.loads(output)
28+
except json.JSONDecodeError:
29+
pass
30+
# For run command: find the last top-level JSON object
31+
depth = 0
32+
last_start = None
33+
for i, ch in enumerate(output):
34+
if ch == "{" and depth == 0:
35+
last_start = i
36+
if ch == "{":
37+
depth += 1
38+
elif ch == "}":
39+
depth -= 1
40+
if last_start is not None:
41+
return json.loads(output[last_start:])
42+
raise ValueError(f"Could not parse JSON from output: {output[:200]}")
43+
44+
45+
# --- Build the 1D IVP arrays (same as ivp_1D_simulation.py) ---
46+
47+
Nx = 512
48+
dx = 0.05e-3
49+
50+
51+
def _make_sound_speed():
52+
c = 1500 * np.ones(Nx)
53+
c[: Nx // 3] = 2000
54+
return c
55+
56+
57+
def _make_density():
58+
rho = 1000 * np.ones(Nx)
59+
rho[4 * Nx // 5 :] = 1500
60+
return rho
61+
62+
63+
def _make_p0():
64+
p0 = np.zeros(Nx)
65+
x0, width = 280, 100
66+
pulse = 0.5 * (np.sin(np.arange(width + 1) * np.pi / width - np.pi / 2) + 1)
67+
p0[x0 : x0 + width + 1] = pulse
68+
return p0
69+
70+
71+
def _make_sensor_mask():
72+
mask = np.zeros(Nx)
73+
mask[Nx // 4] = 1
74+
mask[3 * Nx // 4] = 1
75+
return mask
76+
77+
78+
@pytest.fixture
79+
def session_dir(tmp_path):
80+
return tmp_path / "kwave_test_session"
81+
82+
83+
@pytest.fixture
84+
def data_dir(tmp_path):
85+
"""Directory for pre-built .npy files (simulating what an agent would prepare)."""
86+
d = tmp_path / "arrays"
87+
d.mkdir()
88+
np.save(d / "sound_speed.npy", _make_sound_speed())
89+
np.save(d / "density.npy", _make_density())
90+
np.save(d / "p0.npy", _make_p0())
91+
np.save(d / "sensor_mask.npy", _make_sensor_mask())
92+
return d
93+
94+
95+
@pytest.fixture
96+
def runner():
97+
return CliRunner()
98+
99+
100+
class TestCLI1DIVP:
101+
"""Replicate ivp_1D_simulation.py end-to-end via CLI."""
102+
103+
def test_cli_matches_python_api(self, runner, session_dir, data_dir):
104+
# -- CLI flow --
105+
_invoke(runner, ["session", "init"], session_dir)
106+
107+
_invoke(
108+
runner,
109+
[
110+
"phantom",
111+
"load",
112+
"--grid-size",
113+
"512",
114+
"--spacing",
115+
"0.05e-3",
116+
"--sound-speed",
117+
str(data_dir / "sound_speed.npy"),
118+
"--density",
119+
str(data_dir / "density.npy"),
120+
"--cfl",
121+
"0.3",
122+
],
123+
session_dir,
124+
)
125+
126+
_invoke(
127+
runner,
128+
[
129+
"source",
130+
"define",
131+
"--type",
132+
"initial-pressure",
133+
"--p0-file",
134+
str(data_dir / "p0.npy"),
135+
],
136+
session_dir,
137+
)
138+
139+
_invoke(
140+
runner,
141+
[
142+
"sensor",
143+
"define",
144+
"--mask",
145+
str(data_dir / "sensor_mask.npy"),
146+
"--record",
147+
"p",
148+
],
149+
session_dir,
150+
)
151+
152+
plan_resp = _invoke(runner, ["plan"], session_dir)
153+
assert plan_resp["status"] == "ok"
154+
assert plan_resp["result"]["grid"]["N"] == [512]
155+
assert plan_resp["result"]["grid"]["Nt"] > 0
156+
157+
run_resp = _invoke(runner, ["run"], session_dir)
158+
assert run_resp["status"] == "ok"
159+
160+
# Load CLI results
161+
cli_p = np.load(run_resp["result"]["outputs"]["p"]["path"])
162+
163+
# -- Direct Python API (the example) --
164+
sound_speed = _make_sound_speed()
165+
density = _make_density()
166+
kgrid = kWaveGrid(Vector([Nx]), Vector([dx]))
167+
kgrid.makeTime(sound_speed, cfl=0.3)
168+
medium = kWaveMedium(sound_speed=sound_speed, density=density)
169+
source = kSource()
170+
source.p0 = _make_p0()
171+
sensor = kSensor(mask=_make_sensor_mask())
172+
result = kspaceFirstOrder(kgrid, medium, source, sensor, backend="python", quiet=True)
173+
174+
# -- Compare --
175+
assert cli_p.shape == result["p"].shape, f"Shape mismatch: {cli_p.shape} vs {result['p'].shape}"
176+
np.testing.assert_allclose(cli_p, result["p"], rtol=0, atol=0)

0 commit comments

Comments
 (0)