@@ -34,12 +34,12 @@ def _to_cpu(x):
3434def _expand_to_grid (val , grid_shape , xp , name = "parameter" ):
3535 if val is None :
3636 raise ValueError (f"Missing required parameter: { name } " )
37- arr = xp .array (val , dtype = float ).flatten ( order = "F" )
37+ arr = xp .array (val , dtype = float ).ravel ( )
3838 grid_size = int (np .prod (grid_shape ))
3939 if arr .size == 1 :
4040 return xp .full (grid_shape , float (arr [0 ]), dtype = float )
4141 if arr .size == grid_size :
42- return arr .reshape (grid_shape , order = "F" )
42+ return arr .reshape (grid_shape )
4343 raise ValueError (f"{ name } size { arr .size } incompatible with grid size { grid_size } " )
4444
4545
@@ -48,16 +48,16 @@ def _build_source_op(mask_raw, signal_raw, mode, scale, *, xp, grid_shape, grid_
4848
4949 Returns a callable (t, field) → field that injects scaled source values.
5050 """
51- mask = xp .array (mask_raw , dtype = bool ).flatten ( order = "F" )
51+ mask = xp .array (mask_raw , dtype = bool ).ravel ( )
5252 if mask .size == 1 :
53- mask = xp .full (grid_shape , bool (mask [0 ]), dtype = bool ).flatten ( order = "F" )
53+ mask = xp .full (grid_shape , bool (mask [0 ]), dtype = bool ).ravel ( )
5454 n_src = int (xp .sum (mask ))
5555
56- signal_arr = xp .array (signal_raw , dtype = float , order = "F" )
56+ signal_arr = xp .array (signal_raw , dtype = float )
5757 if signal_arr .ndim == 1 :
5858 signal = signal_arr .reshape (1 , - 1 )
5959 else :
60- signal = signal_arr .reshape (- 1 , signal_arr .shape [- 1 ], order = "F" ) if signal_arr .ndim > 2 else signal_arr
60+ signal = signal_arr .reshape (- 1 , signal_arr .shape [- 1 ]) if signal_arr .ndim > 2 else signal_arr
6161
6262 scaled = signal * xp .atleast_1d (xp .asarray (scale ))[:, None ]
6363 signal_len = scaled .shape [1 ]
@@ -70,9 +70,9 @@ def get_val(t):
7070 def dirichlet (t , field ):
7171 if t >= signal_len :
7272 return field
73- flat = field .flatten (order = "F" ) # copy — mutation is intentional
73+ flat = field .flatten () # copy — mutation is intentional
7474 flat [mask ] = get_val (t )
75- return flat .reshape (grid_shape , order = "F" )
75+ return flat .reshape (grid_shape )
7676
7777 # Pre-allocate buffer to avoid per-step allocation
7878 _src_buf = xp .zeros (grid_size , dtype = float )
@@ -82,15 +82,15 @@ def additive_kspace(t, field):
8282 return field
8383 _src_buf [:] = 0
8484 _src_buf [mask ] = get_val (t )
85- src = _src_buf .reshape (grid_shape , order = "F" )
85+ src = _src_buf .reshape (grid_shape )
8686 return field + diff_fn (src , source_kappa )
8787
8888 def additive_no_correction (t , field ):
8989 if t >= signal_len :
9090 return field
9191 _src_buf [:] = 0
9292 _src_buf [mask ] = get_val (t )
93- return field + _src_buf .reshape (grid_shape , order = "F" )
93+ return field + _src_buf .reshape (grid_shape )
9494
9595 ops = {"dirichlet" : dirichlet , "additive" : additive_kspace , "additive-no-correction" : additive_no_correction }
9696 if mode not in ops :
@@ -210,19 +210,19 @@ def _is_cartesian(arr):
210210
211211 if mask_raw is None :
212212 self .n_sensor_points = grid_numel
213- self ._extract = lambda f : f .flatten ( order = "F" )
213+ self ._extract = lambda f : f .ravel ( )
214214 else :
215215 mask_arr = np .asarray (mask_raw , dtype = float )
216216 # Check Cartesian first to avoid ambiguity when size == grid_numel
217217 if _is_cartesian (mask_arr ):
218218 self ._setup_cartesian_extract (mask_arr )
219219 elif _is_binary (mask_arr ):
220- bmask = xp .array (mask_arr , dtype = bool ).flatten ( order = "F" )
220+ bmask = xp .array (mask_arr , dtype = bool ).ravel ( )
221221 if bmask .size == 1 :
222222 bmask = xp .full (grid_numel , bool (bmask [0 ]), dtype = bool )
223223 self .n_sensor_points = int (xp .sum (bmask ))
224224 idx = xp .where (bmask )[0 ]
225- self ._extract = lambda f , _i = idx : f .flatten ( order = "F" )[_i ]
225+ self ._extract = lambda f , _i = idx : f .ravel ( )[_i ]
226226 else :
227227 raise ValueError (
228228 f"Sensor mask shape { mask_arr .shape } is neither binary " f"(numel={ grid_numel } ) nor Cartesian ({ self .ndim } , N_points)"
@@ -289,7 +289,7 @@ def _setup_cartesian_extract(self, cart_pos):
289289 x_vec , cart_x = axis_coords [0 ], cart .flatten ()
290290
291291 def _extract_1d_interp (f ):
292- return xp .asarray (np .interp (cart_x , x_vec , _to_cpu (f ).flatten ( order = "F" )))
292+ return xp .asarray (np .interp (cart_x , x_vec , _to_cpu (f ).ravel ( )))
293293
294294 self ._extract = _extract_1d_interp
295295 else :
@@ -298,8 +298,8 @@ def _extract_1d_interp(f):
298298 int_idx = np .clip (np .floor (frac_idx ).astype (int ), 0 , np .array (self .grid_shape )[:, None ] - 2 )
299299 local = frac_idx - int_idx
300300
301- # F -order strides and 2^ndim corner enumeration
302- strides = np .cumprod ([1 ] + list (self .grid_shape [:- 1 ]))
301+ # C -order strides and 2^ndim corner enumeration
302+ strides = np .cumprod ([1 ] + list (self .grid_shape [:0 : - 1 ]))[:: - 1 ]
303303 n_corners = 2 ** self .ndim
304304 corner_indices = np .zeros ((self .n_sensor_points , n_corners ), dtype = int )
305305 corner_weights = np .ones ((self .n_sensor_points , n_corners ))
@@ -313,7 +313,7 @@ def _extract_1d_interp(f):
313313 corner_weights = xp .array (corner_weights )
314314
315315 def _extract_bilinear (f ):
316- return (f .flatten ( order = "F" )[corner_indices ] * corner_weights ).sum (axis = 1 )
316+ return (f .ravel ( )[corner_indices ] * corner_weights ).sum (axis = 1 )
317317
318318 self ._extract = _extract_bilinear
319319
@@ -460,15 +460,15 @@ def _setup_source_operators(self):
460460 grid_size = int (np .prod (self .grid_shape ))
461461
462462 def _expand_mask (mask_raw ):
463- mask = xp .array (mask_raw , dtype = bool ).flatten ( order = "F" )
463+ mask = xp .array (mask_raw , dtype = bool ).ravel ( )
464464 if mask .size == 1 :
465- mask = xp .full (self .grid_shape , bool (mask [0 ]), dtype = bool ).flatten ( order = "F" )
465+ mask = xp .full (self .grid_shape , bool (mask [0 ]), dtype = bool ).ravel ( )
466466 return mask
467467
468468 def source_scale (mask_raw , c0 ):
469469 """Get per-source-point sound speed values."""
470470 mask = _expand_mask (mask_raw )
471- c0_flat = c0 .flatten ( order = "F" )
471+ c0_flat = c0 .ravel ( )
472472 n_src = int (xp .sum (mask ))
473473 return c0_flat [mask ] if c0_flat .size > 1 else xp .full (n_src , float (c0_flat ))
474474
@@ -566,7 +566,7 @@ def _setup_fields(self):
566566 if self .smooth_p0 and self .ndim >= 2 :
567567 from kwave .utils .filters import smooth
568568
569- # p0 is F-order from _expand_to_grid; smooth() is order-agnostic (uses FFT on shape)
569+ # smooth() is order-agnostic (uses FFT on shape)
570570 p0 = xp .asarray (smooth (_to_cpu (p0 ), restore_max = True ))
571571 self ._p0_initial = p0
572572 else :
@@ -779,6 +779,53 @@ def create_simulation(kgrid, medium, source, sensor, device="auto", smooth_p0=Fa
779779 )
780780
781781
782+ def _f_to_c_source_reorder (source , grid_shape ):
783+ """Reorder multi-row source signals from MATLAB F-flat to C-flat mask order.
784+
785+ MATLAB sends source signal rows ordered by F-flattened mask indices.
786+ The solver uses C-flat ordering internally. For single-row (uniform)
787+ sources, no reordering is needed.
788+ """
789+ ndim = len (grid_shape )
790+ if ndim < 2 :
791+ return source
792+ source = dict (source ) # shallow copy — don't mutate caller's dict
793+
794+ for mask_key , signal_keys in [("p_mask" , ["p" ]), ("u_mask" , ["ux" , "uy" , "uz" ])]:
795+ mask_raw = source .get (mask_key )
796+ if mask_raw is None :
797+ continue
798+ mask = np .asarray (mask_raw , dtype = bool )
799+ if mask .size <= 1 :
800+ continue
801+ mask_grid = mask .reshape (grid_shape )
802+ n_src = int (mask_grid .sum ())
803+ if n_src < 2 :
804+ continue
805+
806+ # Build F→C permutation for mask points
807+ f_nz = np .where (mask_grid .ravel (order = "F" ))[0 ]
808+ c_nz = np .where (mask_grid .ravel ())[0 ]
809+ f_equiv = np .ravel_multi_index (np .unravel_index (c_nz , grid_shape ), grid_shape , order = "F" )
810+ perm = np .searchsorted (f_nz , f_equiv )
811+
812+ for sig_key in signal_keys :
813+ sig = source .get (sig_key )
814+ if sig is None :
815+ continue
816+ sig = np .asarray (sig )
817+ if sig .ndim >= 2 and sig .shape [0 ] == n_src :
818+ source [sig_key ] = sig [perm ]
819+
820+ return source
821+
822+
782823def simulate_from_dicts (kgrid , medium , source , sensor , device = "auto" , smooth_p0 = False ):
783- """MATLAB interop entry point."""
824+ """MATLAB interop entry point.
825+
826+ Reorders multi-row source signals from MATLAB's F-flat mask ordering
827+ to the solver's C-flat ordering before running the simulation.
828+ """
829+ grid_shape = tuple (kgrid [k ] for k in ["Nx" , "Ny" , "Nz" ] if k in kgrid )
830+ source = _f_to_c_source_reorder (source , grid_shape )
784831 return create_simulation (kgrid , medium , source , sensor , device , smooth_p0 = smooth_p0 ).run ()
0 commit comments