77from kwave .cli .schema import CLIError , CLIResponse , ValidationError , json_command
88
99
10+ def _parse_int_tuple (s : str ) -> tuple [int , ...]:
11+ return tuple (int (x ) for x in s .split ("," ))
12+
13+
14+ def _resolve_scalar_or_path (value : str , name : str , sess ) -> dict :
15+ """Parse a CLI value as scalar float or .npy path. Returns {name_scalar, name_path} dict."""
16+ if value .endswith (".npy" ):
17+ arr = np .load (value )
18+ path = sess .save_array (name , arr )
19+ return {f"{ name } _scalar" : None , f"{ name } _path" : path }
20+ return {f"{ name } _scalar" : float (value ), f"{ name } _path" : None }
21+
22+
1023@click .group ("phantom" )
1124def phantom ():
1225 """Define the simulation phantom (medium + initial pressure)."""
@@ -25,60 +38,23 @@ def load(sess, grid_size, spacing, sound_speed, density, cfl):
2538 """Load medium properties from scalar values or .npy files."""
2639 sess .load ()
2740
28- grid_n = tuple ( int ( x ) for x in grid_size . split ( "," ) )
41+ grid_n = _parse_int_tuple ( grid_size )
2942 ndim = len (grid_n )
3043 grid_spacing = (spacing ,) * ndim
3144
32- medium_state = {}
33-
34- # Sound speed: scalar or .npy path
35- if sound_speed .endswith (".npy" ):
36- arr = np .load (sound_speed )
37- path = sess .save_array ("sound_speed" , arr )
38- medium_state ["sound_speed_path" ] = path
39- medium_state ["sound_speed_scalar" ] = None
40- # Store path so makeTime gets the full array (not just max)
41- sound_speed_for_time_path = path
42- sound_speed_for_time_scalar = None
43- else :
44- medium_state ["sound_speed_scalar" ] = float (sound_speed )
45- medium_state ["sound_speed_path" ] = None
46- sound_speed_for_time_path = None
47- sound_speed_for_time_scalar = float (sound_speed )
48-
49- # Density: scalar, .npy path, or None
45+ medium_state = _resolve_scalar_or_path (sound_speed , "sound_speed" , sess )
5046 if density is not None :
51- if density .endswith (".npy" ):
52- arr = np .load (density )
53- path = sess .save_array ("density" , arr )
54- medium_state ["density_path" ] = path
55- medium_state ["density_scalar" ] = None
56- else :
57- medium_state ["density_scalar" ] = float (density )
58- medium_state ["density_path" ] = None
59-
60- grid_state = {
61- "N" : list (grid_n ),
62- "spacing" : list (grid_spacing ),
63- "sound_speed_for_time_path" : sound_speed_for_time_path ,
64- "sound_speed_for_time_scalar" : sound_speed_for_time_scalar ,
65- }
47+ medium_state .update (_resolve_scalar_or_path (density , "density" , sess ))
48+
49+ grid_state = {"N" : list (grid_n ), "spacing" : list (grid_spacing )}
6650 if cfl is not None :
6751 grid_state ["cfl" ] = cfl
6852
69- sess .update ("grid" , grid_state )
70- sess .update ("medium" , medium_state )
53+ sess .update_many ({"grid" : grid_state , "medium" : medium_state })
7154
7255 return CLIResponse (
73- result = {
74- "grid_size" : list (grid_n ),
75- "spacing" : list (grid_spacing ),
76- "medium" : medium_state ,
77- },
78- derived = {
79- "ndim" : ndim ,
80- "grid_points" : int (np .prod (grid_n )),
81- },
56+ result = {"grid_size" : list (grid_n ), "spacing" : list (grid_spacing ), "medium" : medium_state },
57+ derived = {"ndim" : ndim , "grid_points" : int (np .prod (grid_n ))},
8258 )
8359
8460
@@ -96,7 +72,7 @@ def generate(sess, phantom_type, grid_size, spacing, sound_speed, density, disc_
9672 """Generate an analytical phantom."""
9773 sess .load ()
9874
99- grid_n = tuple ( int ( x ) for x in grid_size . split ( "," ) )
75+ grid_n = _parse_int_tuple ( grid_size )
10076 ndim = len (grid_n )
10177 grid_spacing = (spacing ,) * ndim
10278
@@ -117,12 +93,11 @@ def generate(sess, phantom_type, grid_size, spacing, sound_speed, density, disc_
11793 if disc_center is None :
11894 center = Vector ([n // 2 for n in grid_n ])
11995 else :
120- center = Vector ([ int ( x ) for x in disc_center . split ( "," )] )
96+ center = Vector (_parse_int_tuple ( disc_center ) )
12197
12298 p0 = make_disc (Vector (list (grid_n )), center , disc_radius ).astype (float )
12399
124100 elif phantom_type == "spherical" :
125- # Simple spherical inclusion centered in grid
126101 center = np .array ([n // 2 for n in grid_n ])
127102 coords = np .mgrid [tuple (slice (0 , n ) for n in grid_n )]
128103 dist = np .sqrt (sum ((c - cn ) ** 2 for c , cn in zip (coords , center )))
@@ -133,32 +108,14 @@ def generate(sess, phantom_type, grid_size, spacing, sound_speed, density, disc_
133108 layer_pos = grid_n [0 ] // 4
134109 p0 [layer_pos , ...] = 1.0
135110
136- # Save arrays
137111 p0_path = sess .save_array ("p0" , p0 )
138112
139- # Update session
140- sess .update (
141- "grid" ,
113+ sess .update_many (
142114 {
143- "N" : list (grid_n ),
144- "spacing" : list (grid_spacing ),
145- "sound_speed_for_time_scalar" : sound_speed ,
146- "sound_speed_for_time_path" : None ,
147- },
148- )
149- sess .update (
150- "medium" ,
151- {
152- "sound_speed" : sound_speed ,
153- "density" : density ,
154- },
155- )
156- sess .update (
157- "source" ,
158- {
159- "type" : "initial-pressure" ,
160- "p0_path" : p0_path ,
161- },
115+ "grid" : {"N" : list (grid_n ), "spacing" : list (grid_spacing )},
116+ "medium" : {"sound_speed_scalar" : sound_speed , "sound_speed_path" : None , "density_scalar" : density , "density_path" : None },
117+ "source" : {"type" : "initial-pressure" , "p0_path" : p0_path },
118+ }
162119 )
163120
164121 return CLIResponse (
@@ -171,8 +128,5 @@ def generate(sess, phantom_type, grid_size, spacing, sound_speed, density, disc_
171128 "sound_speed" : sound_speed ,
172129 "density" : density ,
173130 },
174- derived = {
175- "ndim" : ndim ,
176- "grid_points" : int (np .prod (grid_n )),
177- },
131+ derived = {"ndim" : ndim , "grid_points" : int (np .prod (grid_n ))},
178132 )
0 commit comments