Skip to content

Commit d26ab30

Browse files
waltsimsclaude
andcommitted
Simplify CLI: normalize schema, deduplicate, remove dead code
- Normalize medium state to unified {field}_scalar/{field}_path schema (phantom generate now uses same format as phantom load) - Remove bogus PPW check from plan (Nyquist-based PPW was a false positive for all valid CFL values) - Remove dead code: unused min_wavelength, unused sys import, unused Optional - Add Session.assert_ready() and Session.update_many() - Rename _completeness() to completeness() (public API) - Extract _invoke helper to tests/test_cli/conftest.py - Shrink e2e test grid from 128x128 to 48x48 - Extract _parse_int_tuple and _resolve_scalar_or_path in phantom.py Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 945a042 commit d26ab30

8 files changed

Lines changed: 166 additions & 287 deletions

File tree

kwave/cli/commands/phantom.py

Lines changed: 29 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,19 @@
77
from kwave.cli.schema import CLIError, CLIResponse, ValidationError, json_command
88

99

10+
def _parse_int_tuple(s: str) -> tuple[int, ...]:
11+
return tuple(int(x) for x in s.split(","))
12+
13+
14+
def _resolve_scalar_or_path(value: str, name: str, sess) -> dict:
15+
"""Parse a CLI value as scalar float or .npy path. Returns {name_scalar, name_path} dict."""
16+
if value.endswith(".npy"):
17+
arr = np.load(value)
18+
path = sess.save_array(name, arr)
19+
return {f"{name}_scalar": None, f"{name}_path": path}
20+
return {f"{name}_scalar": float(value), f"{name}_path": None}
21+
22+
1023
@click.group("phantom")
1124
def phantom():
1225
"""Define the simulation phantom (medium + initial pressure)."""
@@ -25,60 +38,23 @@ def load(sess, grid_size, spacing, sound_speed, density, cfl):
2538
"""Load medium properties from scalar values or .npy files."""
2639
sess.load()
2740

28-
grid_n = tuple(int(x) for x in grid_size.split(","))
41+
grid_n = _parse_int_tuple(grid_size)
2942
ndim = len(grid_n)
3043
grid_spacing = (spacing,) * ndim
3144

32-
medium_state = {}
33-
34-
# Sound speed: scalar or .npy path
35-
if sound_speed.endswith(".npy"):
36-
arr = np.load(sound_speed)
37-
path = sess.save_array("sound_speed", arr)
38-
medium_state["sound_speed_path"] = path
39-
medium_state["sound_speed_scalar"] = None
40-
# Store path so makeTime gets the full array (not just max)
41-
sound_speed_for_time_path = path
42-
sound_speed_for_time_scalar = None
43-
else:
44-
medium_state["sound_speed_scalar"] = float(sound_speed)
45-
medium_state["sound_speed_path"] = None
46-
sound_speed_for_time_path = None
47-
sound_speed_for_time_scalar = float(sound_speed)
48-
49-
# Density: scalar, .npy path, or None
45+
medium_state = _resolve_scalar_or_path(sound_speed, "sound_speed", sess)
5046
if density is not None:
51-
if density.endswith(".npy"):
52-
arr = np.load(density)
53-
path = sess.save_array("density", arr)
54-
medium_state["density_path"] = path
55-
medium_state["density_scalar"] = None
56-
else:
57-
medium_state["density_scalar"] = float(density)
58-
medium_state["density_path"] = None
59-
60-
grid_state = {
61-
"N": list(grid_n),
62-
"spacing": list(grid_spacing),
63-
"sound_speed_for_time_path": sound_speed_for_time_path,
64-
"sound_speed_for_time_scalar": sound_speed_for_time_scalar,
65-
}
47+
medium_state.update(_resolve_scalar_or_path(density, "density", sess))
48+
49+
grid_state = {"N": list(grid_n), "spacing": list(grid_spacing)}
6650
if cfl is not None:
6751
grid_state["cfl"] = cfl
6852

69-
sess.update("grid", grid_state)
70-
sess.update("medium", medium_state)
53+
sess.update_many({"grid": grid_state, "medium": medium_state})
7154

7255
return CLIResponse(
73-
result={
74-
"grid_size": list(grid_n),
75-
"spacing": list(grid_spacing),
76-
"medium": medium_state,
77-
},
78-
derived={
79-
"ndim": ndim,
80-
"grid_points": int(np.prod(grid_n)),
81-
},
56+
result={"grid_size": list(grid_n), "spacing": list(grid_spacing), "medium": medium_state},
57+
derived={"ndim": ndim, "grid_points": int(np.prod(grid_n))},
8258
)
8359

8460

@@ -96,7 +72,7 @@ def generate(sess, phantom_type, grid_size, spacing, sound_speed, density, disc_
9672
"""Generate an analytical phantom."""
9773
sess.load()
9874

99-
grid_n = tuple(int(x) for x in grid_size.split(","))
75+
grid_n = _parse_int_tuple(grid_size)
10076
ndim = len(grid_n)
10177
grid_spacing = (spacing,) * ndim
10278

@@ -117,12 +93,11 @@ def generate(sess, phantom_type, grid_size, spacing, sound_speed, density, disc_
11793
if disc_center is None:
11894
center = Vector([n // 2 for n in grid_n])
11995
else:
120-
center = Vector([int(x) for x in disc_center.split(",")])
96+
center = Vector(_parse_int_tuple(disc_center))
12197

12298
p0 = make_disc(Vector(list(grid_n)), center, disc_radius).astype(float)
12399

124100
elif phantom_type == "spherical":
125-
# Simple spherical inclusion centered in grid
126101
center = np.array([n // 2 for n in grid_n])
127102
coords = np.mgrid[tuple(slice(0, n) for n in grid_n)]
128103
dist = np.sqrt(sum((c - cn) ** 2 for c, cn in zip(coords, center)))
@@ -133,32 +108,14 @@ def generate(sess, phantom_type, grid_size, spacing, sound_speed, density, disc_
133108
layer_pos = grid_n[0] // 4
134109
p0[layer_pos, ...] = 1.0
135110

136-
# Save arrays
137111
p0_path = sess.save_array("p0", p0)
138112

139-
# Update session
140-
sess.update(
141-
"grid",
113+
sess.update_many(
142114
{
143-
"N": list(grid_n),
144-
"spacing": list(grid_spacing),
145-
"sound_speed_for_time_scalar": sound_speed,
146-
"sound_speed_for_time_path": None,
147-
},
148-
)
149-
sess.update(
150-
"medium",
151-
{
152-
"sound_speed": sound_speed,
153-
"density": density,
154-
},
155-
)
156-
sess.update(
157-
"source",
158-
{
159-
"type": "initial-pressure",
160-
"p0_path": p0_path,
161-
},
115+
"grid": {"N": list(grid_n), "spacing": list(grid_spacing)},
116+
"medium": {"sound_speed_scalar": sound_speed, "sound_speed_path": None, "density_scalar": density, "density_path": None},
117+
"source": {"type": "initial-pressure", "p0_path": p0_path},
118+
}
162119
)
163120

164121
return CLIResponse(
@@ -171,8 +128,5 @@ def generate(sess, phantom_type, grid_size, spacing, sound_speed, density, disc_
171128
"sound_speed": sound_speed,
172129
"density": density,
173130
},
174-
derived={
175-
"ndim": ndim,
176-
"grid_points": int(np.prod(grid_n)),
177-
},
131+
derived={"ndim": ndim, "grid_points": int(np.prod(grid_n))},
178132
)

kwave/cli/commands/plan.py

Lines changed: 7 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55

66
from kwave.cli.main import pass_session
7-
from kwave.cli.schema import CLIResponse, SessionError, json_command
7+
from kwave.cli.schema import CLIResponse, json_command
88

99

1010
@click.command("plan")
@@ -13,59 +13,29 @@
1313
def plan(sess):
1414
"""Derive full simulation config and validate before running."""
1515
sess.load()
16-
17-
# Check completeness
18-
comp = sess._completeness()
19-
missing = [k for k, v in comp.items() if not v]
20-
if missing:
21-
raise SessionError(f"Cannot plan: missing {', '.join(missing)}. Complete setup first.")
16+
sess.assert_ready("plan")
2217

2318
kgrid = sess.make_grid()
2419
medium = sess.make_medium()
2520

26-
g = sess.state["grid"]
27-
grid_n = tuple(g["N"])
28-
spacing = tuple(g["spacing"])
21+
grid_n = tuple(int(n) for n in kgrid.N)
22+
spacing = tuple(float(d) for d in kgrid.spacing)
2923
ndim = len(grid_n)
3024
grid_points = int(np.prod(grid_n))
31-
32-
# Time stepping
3325
dt = float(kgrid.dt)
3426
Nt = int(kgrid.Nt)
3527

36-
# PPW check
37-
c_min = float(np.min(medium.sound_speed)) if hasattr(medium.sound_speed, "__len__") else float(medium.sound_speed)
38-
max_spacing = max(spacing)
39-
# For IVP problems, use grid-based wavelength estimate
40-
min_wavelength = c_min * dt * Nt / 2 # rough estimate
41-
ppw = c_min / (max_spacing * (1 / (2 * dt))) # Nyquist-based PPW
42-
43-
# CFL
4428
c_max = float(np.max(medium.sound_speed)) if hasattr(medium.sound_speed, "__len__") else float(medium.sound_speed)
29+
c_min = float(np.min(medium.sound_speed)) if hasattr(medium.sound_speed, "__len__") else float(medium.sound_speed)
4530
cfl = c_max * dt / min(spacing)
4631

47-
# Memory estimate: ~(3 + 2*ndim) fields of float64
4832
n_fields = 3 + 2 * ndim
49-
memory_bytes = grid_points * n_fields * 8
50-
memory_mb = memory_bytes / (1024 * 1024)
33+
memory_mb = grid_points * n_fields * 8 / (1024 * 1024)
34+
estimated_runtime_s = grid_points * Nt * 50e-9 # ~50ns per grid point per step on CPU
5135

52-
# Runtime estimate
53-
cost_per_point_step_ns = 50 # ~50ns per grid point per step on CPU
54-
estimated_runtime_s = grid_points * Nt * cost_per_point_step_ns / 1e9
55-
56-
# PML
5736
pml_size = 20
5837

59-
# Warnings
6038
warnings = []
61-
if ppw < 4:
62-
warnings.append(
63-
{
64-
"code": "LOW_PPW",
65-
"detail": f"PPW={ppw:.1f} is below recommended minimum of 4",
66-
"suggestion": "Increase grid resolution or reduce maximum frequency",
67-
}
68-
)
6939
if cfl > 0.5:
7040
warnings.append(
7141
{
@@ -94,7 +64,6 @@ def plan(sess):
9464
}
9565

9666
derived = {
97-
"ppw": round(ppw, 2),
9867
"cfl": round(cfl, 4),
9968
"grid_points": grid_points,
10069
"estimated_memory_mb": round(memory_mb, 1),

kwave/cli/commands/run.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
"""Run command: execute simulation with structured JSON progress."""
22

33
import json
4-
import sys
54
import time
65

76
import click
87
import numpy as np
98

109
from kwave.cli.main import pass_session
11-
from kwave.cli.schema import CLIResponse, SessionError, json_command
10+
from kwave.cli.schema import CLIResponse, json_command
1211

1312

1413
def _emit_event(event: dict):
@@ -24,12 +23,7 @@ def _emit_event(event: dict):
2423
def run(sess, backend, device):
2524
"""Execute the simulation."""
2625
sess.load()
27-
28-
# Check completeness
29-
comp = sess._completeness()
30-
missing = [k for k, v in comp.items() if not v]
31-
if missing:
32-
raise SessionError(f"Cannot run: missing {', '.join(missing)}. Complete setup first.")
26+
sess.assert_ready("run")
3327

3428
kgrid = sess.make_grid()
3529
medium = sess.make_medium()

kwave/cli/main.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Agent-first CLI for k-Wave simulations. All commands return structured JSON."""
22

33
from pathlib import Path
4-
from typing import Optional
54

65
import click
76

0 commit comments

Comments
 (0)