@@ -515,14 +515,14 @@ class SavePGIModel(SaveArrayGeoH5):
515515
516516 def __init__ (
517517 self ,
518- h5_object ,
519- pgi_reg : PGIsmallness ,
518+ h5_object : ObjectBase ,
519+ pgi_regularization : PGIsmallness ,
520520 unit_map : dict ,
521521 physical_properties : list [str ],
522522 reference_type : ReferencedValueMapType | None = None ,
523523 ** kwargs ,
524524 ):
525- self .pgi_reg = pgi_reg
525+ self .pgi_regularization = pgi_regularization
526526 self .unit_map : dict = unit_map
527527 self .reference_type = reference_type
528528 self .physical_properties = physical_properties
@@ -533,9 +533,13 @@ def get_values(self, values: list[np.ndarray] | None):
533533 if values is None :
534534 values = self .invProb .model
535535
536- modellist = self .pgi_reg .wiresmap * values
537- model = np .c_ [[a * b for a , b in zip (self .pgi_reg .maplist , modellist )]].T
538- membership = self .pgi_reg .gmm ._estimate_log_prob (model ).argmax (axis = 1 )
536+ modellist = self .pgi_regularization .wiresmap * values
537+ model = np .c_ [
538+ [a * b for a , b in zip (self .pgi_regularization .maplist , modellist )]
539+ ].T
540+ membership = self .pgi_regularization .gmm ._estimate_log_prob (model ).argmax (
541+ axis = 1
542+ )
539543 return membership
540544
541545 def write (self , iteration : int , values : list [np .ndarray ] = None ):
@@ -562,7 +566,7 @@ def write(self, iteration: int, values: list[np.ndarray] = None):
562566 data .entity_type .color_map = self .reference_type .color_map
563567
564568 # TODO: Add the means of the transformed models
565- # means = self.pgi_reg .gmm.means_
569+ # means = self.pgi_regularization .gmm.means_
566570 # for ii, phys_prop in enumerate(self.physical_properties):
567571 # data.add_data_map(
568572 # f"Mean {phys_prop}",
0 commit comments