|
5 | 5 |
|
6 | 6 | import numpy as np |
7 | 7 | from scipy.sparse import csc_matrix, csr_matrix |
| 8 | +from simpeg.regularization import PGIsmallness |
| 9 | + |
8 | 10 | from .directives import InversionDirective |
9 | 11 | from simpeg.maps import IdentityMap |
10 | 12 |
|
11 | 13 | from geoh5py.data import NumericData |
| 14 | +from geoh5py.data.data_type import ReferencedValueMapType |
12 | 15 | from geoh5py.groups.property_group import GroupTypeEnum |
13 | 16 | from geoh5py.groups import UIJsonGroup |
14 | 17 | from geoh5py.objects import ObjectBase |
@@ -503,3 +506,72 @@ def get_names( |
503 | 506 | base_name = "LP models" |
504 | 507 |
|
505 | 508 | return channel_name, base_name |
| 509 | + |
| 510 | + |
| 511 | +class SavePGIModel(SaveArrayGeoH5): |
| 512 | + """ |
| 513 | + Save the model as a property group in the geoh5 file |
| 514 | + """ |
| 515 | + |
| 516 | + def __init__( |
| 517 | + self, |
| 518 | + h5_object: ObjectBase, |
| 519 | + pgi_regularization: PGIsmallness, |
| 520 | + unit_map: dict, |
| 521 | + physical_properties: list[str], |
| 522 | + reference_type: ReferencedValueMapType | None = None, |
| 523 | + **kwargs, |
| 524 | + ): |
| 525 | + self.pgi_regularization = pgi_regularization |
| 526 | + self.unit_map: dict = unit_map |
| 527 | + self.reference_type = reference_type |
| 528 | + self.physical_properties = physical_properties |
| 529 | + super().__init__(h5_object, **kwargs) |
| 530 | + |
| 531 | + def get_values(self, values: list[np.ndarray] | None): |
| 532 | + |
| 533 | + if values is None: |
| 534 | + values = self.invProb.model |
| 535 | + |
| 536 | + modellist = self.pgi_regularization.wiresmap * values |
| 537 | + model = np.c_[ |
| 538 | + [a * b for a, b in zip(self.pgi_regularization.maplist, modellist)] |
| 539 | + ].T |
| 540 | + membership = self.pgi_regularization.gmm._estimate_log_prob(model).argmax( |
| 541 | + axis=1 |
| 542 | + ) |
| 543 | + return membership |
| 544 | + |
| 545 | + def write(self, iteration: int, values: list[np.ndarray] | None = None): |
| 546 | + """ |
| 547 | + Method to write the reference model with data map. |
| 548 | + """ |
| 549 | + petro_model = self.get_values(values) |
| 550 | + petro_model = self.apply_transformations(petro_model).flatten() |
| 551 | + channel_name, base_name = self.get_names("petrophysics", "", iteration) |
| 552 | + with fetch_active_workspace(self._geoh5, mode="r+") as w_s: |
| 553 | + h5_object = w_s.get_entity(self.h5_object)[0] |
| 554 | + data = h5_object.add_data( |
| 555 | + { |
| 556 | + channel_name: { |
| 557 | + "association": self.association, |
| 558 | + "values": petro_model, |
| 559 | + "type": "referenced", |
| 560 | + } |
| 561 | + } |
| 562 | + ) |
| 563 | + |
| 564 | + if self.reference_type is not None: |
| 565 | + data.entity_type.value_map = self.reference_type.value_map |
| 566 | + data.entity_type.color_map = self.reference_type.color_map |
| 567 | + |
| 568 | + # TODO: Add the means of the transformed models |
| 569 | + # means = self.pgi_regularization.gmm.means_ |
| 570 | + # for ii, phys_prop in enumerate(self.physical_properties): |
| 571 | + # data.add_data_map( |
| 572 | + # f"Mean {phys_prop}", |
| 573 | + # { |
| 574 | + # ind: f"{mean:.3e}" |
| 575 | + # for ind, mean in zip(self.unit_map, means[:, ii]) |
| 576 | + # }, |
| 577 | + # ) |
0 commit comments