Skip to content

Commit 4ff544e

Browse files
waltsimsclaude
andcommitted
Add agent-first CLI for k-Wave simulations
Structured JSON I/O at every step, deterministic exit codes, file-backed session model. Commands: session init/status/reset, phantom generate, sensor define, plan, run. Acceptance test: new_api_ivp_2D.py runs entirely via CLI with exact numerical parity to the Python API. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 785c0bd commit 4ff544e

16 files changed

Lines changed: 960 additions & 3 deletions

File tree

kwave/cli/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
def __getattr__(name):
2+
if name == "Session":
3+
from kwave.cli.session import Session
4+
5+
return Session
6+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
7+
8+
9+
__all__ = ["Session"]

kwave/cli/commands/__init__.py

Whitespace-only changes.

kwave/cli/commands/phantom.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
"""Phantom generation and loading commands."""
2+
3+
import click
4+
import numpy as np
5+
6+
from kwave.cli.main import pass_session
7+
from kwave.cli.schema import CLIError, CLIResponse, ValidationError, json_command
8+
9+
10+
@click.group("phantom")
11+
def phantom():
12+
"""Define the simulation phantom (medium + initial pressure)."""
13+
pass
14+
15+
16+
@phantom.command()
17+
@click.option("--type", "phantom_type", required=True, type=click.Choice(["disc", "spherical", "layered"]))
18+
@click.option("--grid-size", required=True, help="Grid dimensions, e.g. 128,128")
19+
@click.option("--spacing", required=True, type=float, help="Grid spacing in meters, e.g. 0.1e-3")
20+
@click.option("--sound-speed", type=float, default=1500, help="Medium sound speed (m/s)")
21+
@click.option("--density", type=float, default=1000, help="Medium density (kg/m^3)")
22+
@click.option("--disc-center", default=None, help="Disc center, e.g. 64,64")
23+
@click.option("--disc-radius", type=int, default=5, help="Disc radius in grid points")
24+
@pass_session
25+
@json_command("phantom.generate")
26+
def generate(sess, phantom_type, grid_size, spacing, sound_speed, density, disc_center, disc_radius):
27+
"""Generate an analytical phantom."""
28+
sess.load()
29+
30+
grid_n = tuple(int(x) for x in grid_size.split(","))
31+
ndim = len(grid_n)
32+
grid_spacing = (spacing,) * ndim
33+
34+
if phantom_type == "disc":
35+
if ndim != 2:
36+
raise ValidationError(
37+
CLIError(
38+
code="DISC_REQUIRES_2D",
39+
field="grid_size",
40+
value=grid_size,
41+
constraint="disc phantom requires 2D grid",
42+
suggestion="Use --grid-size Nx,Ny (two dimensions)",
43+
)
44+
)
45+
from kwave.data import Vector
46+
from kwave.utils.mapgen import make_disc
47+
48+
if disc_center is None:
49+
center = Vector([n // 2 for n in grid_n])
50+
else:
51+
center = Vector([int(x) for x in disc_center.split(",")])
52+
53+
p0 = make_disc(Vector(list(grid_n)), center, disc_radius).astype(float)
54+
55+
elif phantom_type == "spherical":
56+
# Simple spherical inclusion centered in grid
57+
center = np.array([n // 2 for n in grid_n])
58+
coords = np.mgrid[tuple(slice(0, n) for n in grid_n)]
59+
dist = np.sqrt(sum((c - cn) ** 2 for c, cn in zip(coords, center)))
60+
p0 = (dist <= disc_radius).astype(float)
61+
62+
elif phantom_type == "layered":
63+
p0 = np.zeros(grid_n)
64+
layer_pos = grid_n[0] // 4
65+
p0[layer_pos, ...] = 1.0
66+
67+
# Save arrays
68+
p0_path = sess.save_array("p0", p0)
69+
70+
# Update session
71+
sess.update(
72+
"grid",
73+
{
74+
"N": list(grid_n),
75+
"spacing": list(grid_spacing),
76+
"sound_speed_for_time": sound_speed,
77+
},
78+
)
79+
sess.update(
80+
"medium",
81+
{
82+
"sound_speed": sound_speed,
83+
"density": density,
84+
},
85+
)
86+
sess.update(
87+
"source",
88+
{
89+
"type": "initial-pressure",
90+
"p0_path": p0_path,
91+
},
92+
)
93+
94+
return CLIResponse(
95+
result={
96+
"phantom_type": phantom_type,
97+
"grid_size": list(grid_n),
98+
"spacing": list(grid_spacing),
99+
"p0_shape": list(p0.shape),
100+
"p0_max": float(p0.max()),
101+
"sound_speed": sound_speed,
102+
"density": density,
103+
},
104+
derived={
105+
"ndim": ndim,
106+
"grid_points": int(np.prod(grid_n)),
107+
},
108+
)

kwave/cli/commands/plan.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
"""Plan command: derive full simulation config, validate, estimate cost."""
2+
3+
import click
4+
import numpy as np
5+
6+
from kwave.cli.main import pass_session
7+
from kwave.cli.schema import CLIResponse, SessionError, json_command
8+
9+
10+
@click.command("plan")
11+
@pass_session
12+
@json_command("plan")
13+
def plan(sess):
14+
"""Derive full simulation config and validate before running."""
15+
sess.load()
16+
17+
# Check completeness
18+
comp = sess._completeness()
19+
missing = [k for k, v in comp.items() if not v]
20+
if missing:
21+
raise SessionError(f"Cannot plan: missing {', '.join(missing)}. Complete setup first.")
22+
23+
kgrid = sess.make_grid()
24+
medium = sess.make_medium()
25+
26+
g = sess.state["grid"]
27+
grid_n = tuple(g["N"])
28+
spacing = tuple(g["spacing"])
29+
ndim = len(grid_n)
30+
grid_points = int(np.prod(grid_n))
31+
32+
# Time stepping
33+
dt = float(kgrid.dt)
34+
Nt = int(kgrid.Nt)
35+
36+
# PPW check
37+
c_min = float(np.min(medium.sound_speed)) if hasattr(medium.sound_speed, "__len__") else float(medium.sound_speed)
38+
max_spacing = max(spacing)
39+
# For IVP problems, use grid-based wavelength estimate
40+
min_wavelength = c_min * dt * Nt / 2 # rough estimate
41+
ppw = c_min / (max_spacing * (1 / (2 * dt))) # Nyquist-based PPW
42+
43+
# CFL
44+
c_max = float(np.max(medium.sound_speed)) if hasattr(medium.sound_speed, "__len__") else float(medium.sound_speed)
45+
cfl = c_max * dt / min(spacing)
46+
47+
# Memory estimate: ~(3 + 2*ndim) fields of float64
48+
n_fields = 3 + 2 * ndim
49+
memory_bytes = grid_points * n_fields * 8
50+
memory_mb = memory_bytes / (1024 * 1024)
51+
52+
# Runtime estimate
53+
cost_per_point_step_ns = 50 # ~50ns per grid point per step on CPU
54+
estimated_runtime_s = grid_points * Nt * cost_per_point_step_ns / 1e9
55+
56+
# PML
57+
pml_size = 20
58+
59+
# Warnings
60+
warnings = []
61+
if ppw < 4:
62+
warnings.append(
63+
{
64+
"code": "LOW_PPW",
65+
"detail": f"PPW={ppw:.1f} is below recommended minimum of 4",
66+
"suggestion": "Increase grid resolution or reduce maximum frequency",
67+
}
68+
)
69+
if cfl > 0.5:
70+
warnings.append(
71+
{
72+
"code": "HIGH_CFL",
73+
"detail": f"CFL={cfl:.3f} exceeds 0.5, simulation may be unstable",
74+
"suggestion": "Reduce time step or increase grid spacing",
75+
}
76+
)
77+
78+
result = {
79+
"grid": {
80+
"N": list(grid_n),
81+
"spacing": list(spacing),
82+
"ndim": ndim,
83+
"dt": dt,
84+
"Nt": Nt,
85+
},
86+
"pml": {"size": pml_size},
87+
"medium": {
88+
"sound_speed": c_min if c_min == c_max else f"{c_min}-{c_max}",
89+
},
90+
"source": sess.state["source"],
91+
"sensor": sess.state["sensor"],
92+
"backend": "python",
93+
"device": "cpu",
94+
}
95+
96+
derived = {
97+
"ppw": round(ppw, 2),
98+
"cfl": round(cfl, 4),
99+
"grid_points": grid_points,
100+
"estimated_memory_mb": round(memory_mb, 1),
101+
"estimated_runtime_s": round(estimated_runtime_s, 1),
102+
}
103+
104+
return CLIResponse(
105+
result=result,
106+
derived=derived,
107+
warnings=warnings,
108+
)

kwave/cli/commands/run.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
"""Run command: execute simulation with structured JSON progress."""
2+
3+
import json
4+
import sys
5+
import time
6+
7+
import click
8+
import numpy as np
9+
10+
from kwave.cli.main import pass_session
11+
from kwave.cli.schema import CLIResponse, SessionError, json_command
12+
13+
14+
def _emit_event(event: dict):
15+
"""Write a JSON event to stdout and flush."""
16+
click.echo(json.dumps(event, default=str))
17+
18+
19+
@click.command("run")
20+
@click.option("--backend", default="python", type=click.Choice(["python", "cpp"]))
21+
@click.option("--device", default="cpu", type=click.Choice(["cpu", "gpu"]))
22+
@pass_session
23+
@json_command("run")
24+
def run(sess, backend, device):
25+
"""Execute the simulation."""
26+
sess.load()
27+
28+
# Check completeness
29+
comp = sess._completeness()
30+
missing = [k for k, v in comp.items() if not v]
31+
if missing:
32+
raise SessionError(f"Cannot run: missing {', '.join(missing)}. Complete setup first.")
33+
34+
kgrid = sess.make_grid()
35+
medium = sess.make_medium()
36+
source = sess.make_source()
37+
sensor = sess.make_sensor()
38+
39+
Nt = int(kgrid.Nt)
40+
41+
_emit_event({"event": "started", "backend": backend, "device": device, "Nt": Nt})
42+
43+
t_start = time.time()
44+
last_pct = -5 # emit at most every 5%
45+
46+
def progress_callback(step, total):
47+
nonlocal last_pct
48+
pct = round(100 * step / total, 1)
49+
if pct - last_pct >= 5 or step == total:
50+
last_pct = pct
51+
_emit_event(
52+
{
53+
"event": "progress",
54+
"step": step,
55+
"total": total,
56+
"pct": pct,
57+
"elapsed_s": round(time.time() - t_start, 2),
58+
}
59+
)
60+
61+
from kwave.kspaceFirstOrder import kspaceFirstOrder
62+
63+
result = kspaceFirstOrder(
64+
kgrid,
65+
medium,
66+
source,
67+
sensor,
68+
backend=backend,
69+
device=device,
70+
quiet=True,
71+
progress_callback=progress_callback,
72+
)
73+
74+
elapsed = round(time.time() - t_start, 2)
75+
76+
# Save results
77+
result_info = {}
78+
for key, val in result.items():
79+
if isinstance(val, np.ndarray):
80+
path = sess.save_array(f"result_{key}", val)
81+
result_info[key] = {"shape": list(val.shape), "path": path}
82+
else:
83+
result_info[key] = val
84+
85+
sess.update("result_path", str(sess.data_dir))
86+
87+
_emit_event({"event": "completed", "elapsed_s": elapsed, "output_keys": list(result.keys())})
88+
89+
return CLIResponse(
90+
result={
91+
"elapsed_s": elapsed,
92+
"outputs": result_info,
93+
},
94+
)

kwave/cli/commands/sensor.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
"""Sensor definition command."""
2+
3+
import click
4+
5+
from kwave.cli.main import pass_session
6+
from kwave.cli.schema import CLIResponse, json_command
7+
8+
9+
@click.group("sensor")
10+
def sensor():
11+
"""Define sensor configuration."""
12+
pass
13+
14+
15+
@sensor.command()
16+
@click.option("--mask", required=True, help="Sensor mask: 'full-grid' or path to .npy file")
17+
@click.option("--record", default="p,p_final", help="Comma-separated fields to record, e.g. p,p_final,ux")
18+
@pass_session
19+
@json_command("sensor.define")
20+
def define(sess, mask, record):
21+
"""Define what and where to record."""
22+
sess.load()
23+
24+
record_fields = [r.strip() for r in record.split(",")]
25+
26+
sensor_config = {"record": record_fields}
27+
if mask == "full-grid":
28+
sensor_config["mask_type"] = "full-grid"
29+
else:
30+
sensor_config["mask_type"] = "file"
31+
sensor_config["mask_path"] = mask
32+
33+
sess.update("sensor", sensor_config)
34+
35+
return CLIResponse(
36+
result={
37+
"mask_type": sensor_config["mask_type"],
38+
"record": record_fields,
39+
}
40+
)

0 commit comments

Comments
 (0)