Skip to content

Commit 8ac4c6d

Browse files
committed
Fix the save directive. Bring back re-sorting
(cherry picked from commit 79b4d99b00ecf2659636a560e6cd444b4c8ddb9b)
1 parent b85287b commit 8ac4c6d

3 files changed

Lines changed: 21 additions & 16 deletions

File tree

simpeg/directives/_save_geoh5.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from simpeg.maps import IdentityMap
1212

1313
from geoh5py.data import NumericData
14+
from geoh5py.data.data_type import ReferencedValueMapType
1415
from geoh5py.groups.property_group import GroupTypeEnum
1516
from geoh5py.groups import UIJsonGroup
1617
from geoh5py.objects import ObjectBase
@@ -518,12 +519,12 @@ def __init__(
518519
pgi_reg: PGIsmallness,
519520
unit_map: dict,
520521
physical_properties: list[str],
521-
value_map: dict[int, str] | None = None,
522+
reference_type: ReferencedValueMapType | None = None,
522523
**kwargs,
523524
):
524525
self.pgi_reg = pgi_reg
525526
self.unit_map: dict = unit_map
526-
self.value_map = value_map
527+
self.reference_type = reference_type
527528
self.physical_properties = physical_properties
528529
super().__init__(h5_object, **kwargs)
529530

@@ -534,14 +535,14 @@ def get_values(self, values: list[np.ndarray] | None):
534535

535536
modellist = self.pgi_reg.wiresmap * values
536537
model = np.c_[[a * b for a, b in zip(self.pgi_reg.maplist, modellist)]].T
537-
membership = self.pgi_reg.gmm.predict(model)
538+
membership = self.pgi_reg.gmm._estimate_log_prob(model).argmax(axis=1)
538539
return membership
539540

540541
def write(self, iteration: int, values: list[np.ndarray] = None):
541542
"""
542543
Method to write the reference model with data map.
543544
"""
544-
petro_model = self.get_values(values) + 1
545+
petro_model = self.get_values(values)
545546
petro_model = self.apply_transformations(petro_model).flatten()
546547
channel_name, base_name = self.get_names("petrophysics", "", iteration)
547548
with fetch_active_workspace(self._geoh5, mode="r+") as w_s:
@@ -552,17 +553,21 @@ def write(self, iteration: int, values: list[np.ndarray] = None):
552553
"association": self.association,
553554
"values": petro_model,
554555
"type": "REFERENCED",
555-
"value_map": self.value_map,
556556
}
557557
}
558558
)
559559

560-
means = self.pgi_reg.gmm.means_
561-
for ii, phys_prop in enumerate(self.physical_properties):
562-
data.add_data_map(
563-
f"Mean {phys_prop}",
564-
{
565-
ind: f"{mean:.3e}"
566-
for ind, mean in zip(self.unit_map, means[:, ii])
567-
},
568-
)
560+
if self.reference_type is not None:
561+
data.entity_type.value_map = self.reference_type.value_map
562+
data.entity_type.color_map = self.reference_type.color_map
563+
564+
# TODO: Add the means of the transformed models
565+
# means = self.pgi_reg.gmm.means_
566+
# for ii, phys_prop in enumerate(self.physical_properties):
567+
# data.add_data_map(
568+
# f"Mean {phys_prop}",
569+
# {
570+
# ind: f"{mean:.3e}"
571+
# for ind, mean in zip(self.unit_map, means[:, ii])
572+
# },
573+
# )

simpeg/regularization/pgi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def __init__(
200200
**kwargs,
201201
):
202202
self.gmmref = copy.deepcopy(gmmref)
203-
# self.gmmref.order_clusters_GM_weight()
203+
self.gmmref.order_clusters_GM_weight()
204204
self.approx_gradient = approx_gradient
205205
self.approx_eval = approx_eval
206206
self.approx_hessian = approx_hessian

simpeg/utils/pgi_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -982,7 +982,7 @@ def update_gmm_with_priors(self, debug=False):
982982
"""
983983

984984
self.compute_clusters_precisions()
985-
# self.order_cluster()
985+
self.order_cluster()
986986

987987
if debug:
988988
print("before update means: ", self.means_)

0 commit comments

Comments
 (0)