1111from simpeg .maps import IdentityMap
1212
1313from geoh5py .data import NumericData
14+ from geoh5py .data .data_type import ReferencedValueMapType
1415from geoh5py .groups .property_group import GroupTypeEnum
1516from geoh5py .groups import UIJsonGroup
1617from geoh5py .objects import ObjectBase
@@ -518,12 +519,12 @@ def __init__(
518519 pgi_reg : PGIsmallness ,
519520 unit_map : dict ,
520521 physical_properties : list [str ],
521- value_map : dict [ int , str ] | None = None ,
522+ reference_type : ReferencedValueMapType | None = None ,
522523 ** kwargs ,
523524 ):
524525 self .pgi_reg = pgi_reg
525526 self .unit_map : dict = unit_map
526- self .value_map = value_map
527+ self .reference_type = reference_type
527528 self .physical_properties = physical_properties
528529 super ().__init__ (h5_object , ** kwargs )
529530
@@ -534,14 +535,14 @@ def get_values(self, values: list[np.ndarray] | None):
534535
535536 modellist = self .pgi_reg .wiresmap * values
536537 model = np .c_ [[a * b for a , b in zip (self .pgi_reg .maplist , modellist )]].T
537- membership = self .pgi_reg .gmm .predict (model )
538+ membership = self .pgi_reg .gmm ._estimate_log_prob (model ). argmax ( axis = 1 )
538539 return membership
539540
540541 def write (self , iteration : int , values : list [np .ndarray ] = None ):
541542 """
542543 Method to write the reference model with data map.
543544 """
544- petro_model = self .get_values (values ) + 1
545+ petro_model = self .get_values (values )
545546 petro_model = self .apply_transformations (petro_model ).flatten ()
546547 channel_name , base_name = self .get_names ("petrophysics" , "" , iteration )
547548 with fetch_active_workspace (self ._geoh5 , mode = "r+" ) as w_s :
@@ -552,17 +553,21 @@ def write(self, iteration: int, values: list[np.ndarray] = None):
552553 "association" : self .association ,
553554 "values" : petro_model ,
554555 "type" : "REFERENCED" ,
555- "value_map" : self .value_map ,
556556 }
557557 }
558558 )
559559
560- means = self .pgi_reg .gmm .means_
561- for ii , phys_prop in enumerate (self .physical_properties ):
562- data .add_data_map (
563- f"Mean { phys_prop } " ,
564- {
565- ind : f"{ mean :.3e} "
566- for ind , mean in zip (self .unit_map , means [:, ii ])
567- },
568- )
560+ if self .reference_type is not None :
561+ data .entity_type .value_map = self .reference_type .value_map
562+ data .entity_type .color_map = self .reference_type .color_map
563+
564+ # TODO: Add the means of the transformed models
565+ # means = self.pgi_reg.gmm.means_
566+ # for ii, phys_prop in enumerate(self.physical_properties):
567+ # data.add_data_map(
568+ # f"Mean {phys_prop}",
569+ # {
570+ # ind: f"{mean:.3e}"
571+ # for ind, mean in zip(self.unit_map, means[:, ii])
572+ # },
573+ # )
0 commit comments