2929 from simpeg_drivers .driver import InversionDriver
3030
3131
32+ MODEL_TYPES = [
33+ "starting" ,
34+ "reference" ,
35+ "lower_bound" ,
36+ "upper_bound" ,
37+ "conductivity" ,
38+ "alpha_s" ,
39+ "length_scale_x" ,
40+ "length_scale_y" ,
41+ "length_scale_z" ,
42+ "gradient_dip" ,
43+ "gradient_direction" ,
44+ "s_norm" ,
45+ "x_norm" ,
46+ "y_norm" ,
47+ "z_norm" ,
48+ ]
49+
50+
3251class InversionModelCollection :
3352 """
3453 Collection of inversion models.
@@ -39,50 +58,41 @@ class InversionModelCollection:
3958
4059 """
4160
42- model_types = [
43- "starting" ,
44- "reference" ,
45- "lower_bound" ,
46- "upper_bound" ,
47- "conductivity" ,
48- "alpha_s" ,
49- "length_scale_x" ,
50- "length_scale_y" ,
51- "length_scale_z" ,
52- "s_norm" ,
53- "x_norm" ,
54- "y_norm" ,
55- "z_norm" ,
56- ]
57-
5861 def __init__ (self , driver : InversionDriver ):
5962 """
6063 :param driver: Parental InversionDriver class.
6164 """
6265 self ._active_cells : np .ndarray | None = None
6366 self ._driver = driver
6467 self .is_sigma = self .driver .params .physical_property == "conductivity"
65- self . is_vector = (
68+ is_vector = (
6669 True if self .driver .params .inversion_type == "magnetic vector" else False
6770 )
68- self .n_blocks = (
69- 3 if self .driver .params .inversion_type == "magnetic vector" else 1
71+ self ._starting = InversionModel (driver , "starting" , is_vector = is_vector )
72+ self ._reference = InversionModel (driver , "reference" , is_vector = is_vector )
73+ self ._lower_bound = InversionModel (driver , "lower_bound" , is_vector = is_vector )
74+ self ._upper_bound = InversionModel (driver , "upper_bound" , is_vector = is_vector )
75+ self ._conductivity = InversionModel (driver , "conductivity" , is_vector = is_vector )
76+ self ._alpha_s = InversionModel (driver , "alpha_s" , is_vector = is_vector )
77+ self ._length_scale_x = InversionModel (
78+ driver , "length_scale_x" , is_vector = is_vector
79+ )
80+ self ._length_scale_y = InversionModel (
81+ driver , "length_scale_y" , is_vector = is_vector
82+ )
83+ self ._length_scale_z = InversionModel (
84+ driver , "length_scale_z" , is_vector = is_vector
85+ )
86+ self ._gradient_dip = InversionModel (
87+ driver , "gradient_dip" , trim_active_cells = False
7088 )
71- self ._starting = InversionModel (driver , "starting" )
72- self ._reference = InversionModel (driver , "reference" )
73- self ._lower_bound = InversionModel (driver , "lower_bound" )
74- self ._upper_bound = InversionModel (driver , "upper_bound" )
75- self ._conductivity = InversionModel (driver , "conductivity" )
76- self ._alpha_s = InversionModel (driver , "alpha_s" )
77- self ._length_scale_x = InversionModel (driver , "length_scale_x" )
78- self ._length_scale_y = InversionModel (driver , "length_scale_y" )
79- self ._length_scale_z = InversionModel (driver , "length_scale_z" )
80- self ._gradient_dip = InversionModel (driver , "gradient_dip" )
81- self ._gradient_direction = InversionModel (driver , "gradient_direction" )
82- self ._s_norm = InversionModel (driver , "s_norm" )
83- self ._x_norm = InversionModel (driver , "x_norm" )
84- self ._y_norm = InversionModel (driver , "y_norm" )
85- self ._z_norm = InversionModel (driver , "z_norm" )
89+ self ._gradient_direction = InversionModel (
90+ driver , "gradient_direction" , trim_active_cells = False
91+ )
92+ self ._s_norm = InversionModel (driver , "s_norm" , is_vector = is_vector )
93+ self ._x_norm = InversionModel (driver , "x_norm" , is_vector = is_vector )
94+ self ._y_norm = InversionModel (driver , "y_norm" , is_vector = is_vector )
95+ self ._z_norm = InversionModel (driver , "z_norm" , is_vector = is_vector )
8696
8797 @property
8898 def n_active (self ) -> int :
@@ -307,7 +317,7 @@ def z_norm(self) -> np.ndarray | None:
307317 def _model_method_wrapper (self , method , name = None , ** kwargs ):
308318 """wraps individual model's specific method and applies in loop over model types."""
309319 returned_items = {}
310- for mtype in self . model_types :
320+ for mtype in MODEL_TYPES :
311321 model = getattr (self , f"_{ mtype } " )
312322 if model .model is not None :
313323 f = getattr (model , method )
@@ -364,43 +374,24 @@ class InversionModel:
364374 remove_air: Use active cells vector to remove air cells from model.
365375 """
366376
367- model_types = [
368- "starting" ,
369- "reference" ,
370- "lower_bound" ,
371- "upper_bound" ,
372- "conductivity" ,
373- "alpha_s" ,
374- "length_scale_x" ,
375- "length_scale_y" ,
376- "length_scale_z" ,
377- "gradient_dip" ,
378- "gradient_direction" ,
379- "s_norm" ,
380- "x_norm" ,
381- "y_norm" ,
382- "z_norm" ,
383- ]
384-
385377 def __init__ (
386378 self ,
387379 driver : InversionDriver ,
388380 model_type : str ,
381+ is_vector : bool = False ,
382+ trim_active_cells : bool = True ,
389383 ):
390384 """
391385 :param driver: InversionDriver object.
392- :param model_type: Type of inversion model, can be any of "starting", "reference",
393- "lower_bound", "upper_bound".
386+ :param model_type: Type of inversion model, can be any of MODEL_TYPES.
387+ :param is_vector: If True, model is a vector.
388+ :param trim_active_cells: If True, remove air cells from model.
394389 """
395390 self .driver = driver
396391 self .model_type = model_type
397392 self .model : np .ndarray | None = None
398- self .is_vector = (
399- True if self .driver .params .inversion_type == "magnetic vector" else False
400- )
401- self .n_blocks = (
402- 3 if self .driver .params .inversion_type == "magnetic vector" else 1
403- )
393+ self .is_vector = is_vector
394+ self .trim_active_cells = trim_active_cells
404395 self ._initialize ()
405396
406397 def _initialize (self ):
@@ -452,7 +443,7 @@ def _initialize(self):
452443 and self .is_vector
453444 and model .shape [0 ] == self .driver .inversion_mesh .n_cells
454445 ):
455- model = np .tile (model , self .n_blocks )
446+ model = np .tile (model , 3 if self .is_vector else 1 )
456447
457448 if model is not None :
458449 self .model = mkvc (model )
@@ -461,8 +452,8 @@ def _initialize(self):
461452 def remove_air (self , active_cells ):
462453 """Use active cells vector to remove air cells from model"""
463454
464- if self .model is not None :
465- self .model = self .model [np .tile (active_cells , self .n_blocks )]
455+ if self .model is not None and self . trim_active_cells :
456+ self .model = self .model [np .tile (active_cells , 3 if self .is_vector else 1 )]
466457
467458 def permute_2_octree (self ) -> np .ndarray | None :
468459 """
@@ -604,7 +595,7 @@ def model_type(self):
604595
605596 @model_type .setter
606597 def model_type (self , v ):
607- if v not in self . model_types :
608- msg = f"Invalid model_type: { v } . Must be one of { (* self . model_types ,)} ."
598+ if v not in MODEL_TYPES :
599+ msg = f"Invalid model_type: { v } . Must be one of { (* MODEL_TYPES ,)} ."
609600 raise ValueError (msg )
610601 self ._model_type = v
0 commit comments