Skip to content

Commit 4a64798

Browse files
committed
finish gravity uijson and round trip tests to file and back
1 parent 4ad0623 commit 4a64798

4 files changed

Lines changed: 107 additions & 109 deletions

File tree

simpeg_drivers-assets/uijson/gravity_inversion.ui.json

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,7 @@
1919
"{6a057fdc-b355-11e3-95be-fd84a7ffcb88}",
2020
"{f26feba3-aded-494b-b9e9-b2bbcbe298e1}",
2121
"{48f5054a-1c5c-4ca4-9048-80f36dc60a06}",
22-
"{b020a277-90e2-4cd7-84d6-612ee3f25051}",
23-
"{b54f6be6-0eb5-4a4e-887a-ba9d276f9a83}",
24-
"{5ffa3816-358d-4cdd-9b7d-e1f7f5543e05}"
22+
"{b020a277-90e2-4cd7-84d6-612ee3f25051}"
2523
],
2624
"value": ""
2725
},
@@ -770,6 +768,11 @@
770768
"parent": "data_object",
771769
"isValue": true,
772770
"property": "",
771+
"dataType": "Float",
772+
"association": [
773+
"Vertex",
774+
"Cell"
775+
],
773776
"value": 1,
774777
"min": 1,
775778
"max": 1000,

simpeg_drivers/potential_fields/gravity/uijson.py

Lines changed: 49 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,55 @@
2525
from geoh5py.ui_json.ui_json import BaseUIJson
2626

2727
from simpeg_drivers import assets_path
28-
from simpeg_drivers.uijson import BaseInversionUIJson, SimPEGDriversUIJson
28+
from simpeg_drivers.uijson import SimPEGDriversUIJson
2929

3030

3131
class GravityForwardUIJson(SimPEGDriversUIJson):
3232
"""Gravity Forward UIJson."""
3333

34-
default_ui_json: ClassVar[Path] = assets_path / "gravity_forward.ui.json"
34+
default_ui_json: ClassVar[Path] = assets_path() / "gravity_forward.ui.json"
3535

3636
inversion_type: str
37-
forward_only: str
37+
forward_only: bool
38+
data_object: ObjectForm
39+
z_from_topo: BoolForm
40+
receivers_radar_drape: DataForm
41+
receivers_offset_z: FloatForm
42+
gps_receivers_offset: str
43+
gz_channel_bool: BoolForm
44+
gx_channel_bool: BoolForm
45+
gy_channel_bool: BoolForm
46+
guv_channel_bool: BoolForm
47+
gxy_channel_bool: BoolForm
48+
gxx_channel_bool: BoolForm
49+
gyy_channel_bool: BoolForm
50+
gzz_channel_bool: BoolForm
51+
gxz_channel_bool: BoolForm
52+
gyz_channel_bool: BoolForm
53+
mesh: ObjectForm
54+
starting_model: DataForm
55+
topography_object: ObjectForm
56+
topography: DataForm
57+
active_model: DataForm
58+
output_tile_files: bool
59+
parallelized: BoolForm
60+
n_cpu: IntegerForm
61+
tile_spatial: DataForm
62+
max_chunk_size: IntegerForm
63+
chunk_by_rows: BoolForm
64+
out_group: GroupForm
65+
ga_group: str
66+
generate_sweep: BoolForm
67+
distributed_workers: str
68+
69+
70+
class GravityInversionUIJson(SimPEGDriversUIJson):
71+
"""Gravity Inversion UIJson."""
72+
73+
default_ui_json: ClassVar[Path] = assets_path() / "gravity_inversionforward.ui.json"
74+
75+
inversion_type: str
76+
forward_only: bool
3877
data_object: ObjectForm
3978
gz_channel: DataForm
4079
gz_uncertainty: DataForm
@@ -85,34 +124,22 @@ class GravityForwardUIJson(SimPEGDriversUIJson):
85124
initial_beta: FloatForm
86125
coolingFactor: FloatForm
87126
coolingRate: IntegerForm
88-
max_global_iteration: IntegerForm
127+
max_global_iterations: IntegerForm
89128
max_line_search_iterations: IntegerForm
90129
max_cg_iterations: IntegerForm
91130
tol_cg: FloatForm
92131
f_min_change: FloatForm
93132
sens_wts_threshold: FloatForm
94133
every_iteration_bool: BoolForm
95-
save_sensitivity: BoolForm
134+
save_sensitivities: BoolForm
96135
parallelized: BoolForm
97136
n_cpu: IntegerForm
98-
tile_spatial: IntegerForm
137+
tile_spatial: DataForm
99138
store_sensitivities: ChoiceForm
100139
max_ram: str
101140
max_chunk_size: IntegerForm
141+
chunk_by_rows: BoolForm
102142
out_group: GroupForm
103-
104-
105-
class GravityInversionUIJson(SimPEGDriversUIJson):
106-
default_ui_json: ClassVar[Path] = assets_path / "gravity_inversion.ui.json"
107-
title: str = "Gravity Inversion"
108-
inversion_type: str = "gravity"
109-
110-
gx_channel_bool: bool = False
111-
gy_channel_bool: bool = False
112-
gz_channel_bool: bool = True
113-
gxx_channel_bool: bool = False
114-
gxy_channel_bool: bool = False
115-
gxz_channel_bool: bool = False
116-
gyy_channel_bool: bool = False
117-
gyz_channel_bool: bool = False
118-
gzz_channel_bool: bool = False
143+
ga_group: str
144+
generate_sweep: BoolForm
145+
distributed_workers: str

simpeg_drivers/uijson.py

Lines changed: 12 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,27 @@
1111
import json
1212
import logging
1313
from pathlib import Path
14-
from uuid import UUID
14+
from typing import Annotated
1515

1616
from geoh5py.groups import SimPEGGroup, UIJsonGroup
17+
from geoh5py.shared.validators import empty_string_to_none, none_to_empty_string
1718
from geoh5py.ui_json.enforcers import UUIDEnforcer
1819
from geoh5py.ui_json.ui_json import BaseUIJson
19-
from pydantic import field_validator
20+
from pydantic import BeforeValidator, PlainSerializer, field_validator
2021

2122
import simpeg_drivers
2223

2324

2425
logger = logging.getLogger(__name__)
2526

2627

28+
OptionalPath = Annotated[
29+
Path | None,
30+
BeforeValidator(empty_string_to_none),
31+
PlainSerializer(none_to_empty_string),
32+
]
33+
34+
2735
class SimPEGDriversUIJson(BaseUIJson):
2836
version: str = simpeg_drivers.__version__
2937
title: str
@@ -32,8 +40,8 @@ class SimPEGDriversUIJson(BaseUIJson):
3240
conda_environment: str
3341
run_command: str
3442
geoh5: Path | None
35-
monitoring_directory: Path
36-
workspace_geoh5: Path
43+
monitoring_directory: OptionalPath
44+
workspace_geoh5: OptionalPath
3745

3846
@field_validator("version", mode="before")
3947
@classmethod
@@ -61,80 +69,3 @@ def write_default(cls):
6169
data = uijson.model_dump_json(indent=4)
6270
with open(cls.default_ui_json, "w", encoding="utf-8") as file:
6371
file.write(data)
64-
65-
66-
class CoreUIJson(SimPEGDriversUIJson):
67-
"""
68-
Core class for ui.json data.
69-
"""
70-
71-
run_command: str
72-
conda_environment: str
73-
forward_only: bool
74-
data_object: UUID
75-
mesh: UUID | None
76-
starting_model: float | UUID
77-
topography_object: UUID | None
78-
topography: float | UUID | None
79-
active_model: UUID | None
80-
tile_spatial: int
81-
parallelized: bool
82-
n_cpu: int | None
83-
max_chunk_size: int
84-
save_sensitivities: bool
85-
out_group: SimPEGGroup | UIJsonGroup | None
86-
generate_sweep: bool
87-
88-
89-
class BaseInversionUIJson(BaseUIJson):
90-
"""
91-
Base class for inversion ui.json data.
92-
"""
93-
94-
reference_model: float | UUID | None
95-
lower_bound: float | UUID | None
96-
upper_bound: float | UUID | None
97-
98-
alpha_s: float | UUID | None
99-
length_scale_x: float | UUID
100-
length_scale_y: float | UUID
101-
length_scale_z: float | UUID
102-
103-
s_norm: float | UUID | None
104-
x_norm: float | UUID
105-
y_norm: float | UUID | None
106-
z_norm: float | UUID
107-
gradient_type: str
108-
max_irls_iterations: int
109-
starting_chi_factor: float
110-
111-
chi_factor: float
112-
auto_scale_misfits: bool
113-
initial_beta_ratio: float | None
114-
initial_beta: float | None
115-
coolingFactor: float
116-
117-
coolingRate: float
118-
max_global_iterations: int
119-
max_line_search_iterations: int
120-
max_cg_iterations: int
121-
tol_cg: float
122-
f_min_change: float
123-
124-
sens_wts_threshold: float
125-
every_iteration_bool: bool
126-
save_sensitivities: str
127-
128-
beta_tol: float
129-
prctile: float
130-
coolEps_q: bool
131-
coolEpsFact: float
132-
beta_search: bool
133-
134-
chunk_by_rows: bool
135-
output_tile_files: bool
136-
inversion_style: str
137-
max_ram: float | None
138-
ga_group: SimPEGGroup | None
139-
distributed_workers: int | None
140-
no_data_value: float | None

tests/uijson_test.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,14 @@
1313
from pathlib import Path
1414
from typing import ClassVar
1515

16+
import numpy as np
1617
from geoh5py import Workspace
1718

18-
from simpeg_drivers.uijson import CoreUIJson
19+
from simpeg_drivers.params import ActiveCellsOptions
20+
from simpeg_drivers.potential_fields.gravity.params import GravityInversionOptions
21+
from simpeg_drivers.potential_fields.gravity.uijson import GravityInversionUIJson
22+
from simpeg_drivers.uijson import SimPEGDriversUIJson
23+
from simpeg_drivers.utils.testing import setup_inversion_workspace
1924

2025

2126
logger = logging.getLogger(__name__)
@@ -25,7 +30,7 @@ def test_version_warning(tmp_path, caplog):
2530
workspace = Workspace(tmp_path / "test.geoh5")
2631

2732
with caplog.at_level(logging.WARNING):
28-
_ = CoreUIJson(
33+
_ = SimPEGDriversUIJson(
2934
version="0.2.0",
3035
title="My app",
3136
geoh5=str(workspace.h5file),
@@ -50,7 +55,7 @@ def test_write_default(tmp_path):
5055
with open(default_path, "w", encoding="utf-8") as f:
5156
json.dump(data, f, ensure_ascii=False, indent=4)
5257

53-
class MyUIJson(CoreUIJson):
58+
class MyUIJson(SimPEGDriversUIJson):
5459
default_ui_json: ClassVar[Path] = default_path
5560
version: str = "0.2.0"
5661

@@ -60,3 +65,35 @@ class MyUIJson(CoreUIJson):
6065
data = json.load(f)
6166

6267
assert data["version"] == "0.3.0-alpha.1"
68+
69+
70+
def test_gravity_uijson(tmp_path):
71+
geoh5, _, starting_model, survey, topography = setup_inversion_workspace(
72+
tmp_path, background=0.0, anomaly=0.75, inversion_type="gravity"
73+
)
74+
75+
opts = GravityInversionOptions(
76+
geoh5=geoh5,
77+
data_object=survey,
78+
gz_channel=survey.add_data({"gz": {"values": np.ones(survey.n_vertices)}}),
79+
gz_uncertainty=survey.add_data(
80+
{"gz_unc": {"values": np.ones(survey.n_vertices)}}
81+
),
82+
mesh=starting_model.parent,
83+
starting_model=starting_model,
84+
active_cells=ActiveCellsOptions(
85+
topography_object=topography,
86+
),
87+
)
88+
params_uijson_path = tmp_path / "from_params.ui.json"
89+
opts.write_ui_json(params_uijson_path)
90+
91+
uijson = GravityInversionUIJson.read(params_uijson_path)
92+
uijson_path = tmp_path / "from_uijson.ui.json"
93+
uijson.write(uijson_path)
94+
with open(params_uijson_path, encoding="utf-8") as f:
95+
params_data = json.load(f)
96+
with open(uijson_path, encoding="utf-8") as f:
97+
uijson_data = json.load(f)
98+
99+
assert uijson_data == params_data

0 commit comments

Comments
 (0)