-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathbase.py
More file actions
382 lines (303 loc) · 12.7 KB
/
base.py
File metadata and controls
382 lines (303 loc) · 12.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
# '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
# Copyright (c) 2022-2026 Mira Geoscience Ltd. '
# '
# This file is part of geoapps-utils package. '
# '
# geoapps-utils is distributed under the terms and conditions of the MIT License '
# (see LICENSE file at the root of this source code package). '
# '
# '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
from __future__ import annotations
import sys
import tempfile
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, ClassVar, GenericAlias, Self # type: ignore
from geoh5py import Workspace
from geoh5py.groups import UIJsonGroup
from geoh5py.objects import ObjectBase
from geoh5py.ui_json import InputFile, UIJson, monitored_directory_copy
from geoh5py.ui_json.utils import fetch_active_workspace
from pydantic import BaseModel, ConfigDict, ValidationError
from geoapps_utils import assets_path
from geoapps_utils.driver.params import BaseParams
from geoapps_utils.utils.formatters import recursive_flatten
from geoapps_utils.utils.importing import GeoAppsError
from geoapps_utils.utils.logger import get_logger
# Module-level logger built with geoapps_utils' get_logger; the boolean flags
# presumably turn off the level/name prefixes and propagation to parent
# loggers — confirm against get_logger.
logger = get_logger(
    name=__name__,
    propagate=False,
    level_name=False,
    add_name=False,
)
def input_file_deprecation_warning(input_file: InputFile) -> Path:
    """
    Warn the user of future deprecation and get a file path to an existing file.

    A DeprecationWarning is always emitted. If ``input_file.path_name`` is
    None, or does not point to an existing file, the ui.json content is written
    with ``input_file.write_ui_json`` as "temp.ui.json" inside a new directory
    created by ``tempfile.mkdtemp``.

    :param input_file: Deprecated InputFile object.

    :return: Either ``input_file.path_name`` (when it is an existing file) or
        the path of the temporary "temp.ui.json" copy.
    """
    warnings.warn(
        "The use of InputFile will be deprecated in future versions. "
        "Please start using UIJson class instead.",
        DeprecationWarning,
        stacklevel=2,
    )

    if input_file.path_name is None or not Path(input_file.path_name).is_file():
        # The directory must outlive this call because its path is returned,
        # so it is intentionally not removed here.
        temp_path = Path(tempfile.mkdtemp()) / "temp.ui.json"
        input_file.write_ui_json(path=temp_path.parent, name=temp_path.name)
        return temp_path

    return Path(input_file.path_name)
class Driver(ABC):
    """
    Base driver class.

    Subclasses set ``_params_class`` and implement :meth:`run`.

    :param params: Application parameters, an instance of ``_params_class``.
    """

    # TODO: Get rid of BaseParams to have a more robust Driver class.
    _params_class: type[Options]

    def __init__(self, params: Options | BaseParams):
        self._out_group: UIJsonGroup | None = None
        self.params = params

    @property
    def params(self):
        """Application parameters."""
        return self._params

    @params.setter
    def params(self, val: Options | BaseParams):
        """
        Set the application parameters.

        :raises TypeError: If 'val' is not an instance of ``_params_class``.
        """
        if not isinstance(val, self._params_class):
            raise TypeError(
                f"Parameters must be of type {self._params_class}.\n"
                f"Got {type(val)} instead."
            )
        self._params = val

    @property
    def workspace(self):
        """Application workspace (the ``geoh5`` attribute of the parameters)."""
        return self._params.geoh5

    @property
    def out_group(self) -> UIJsonGroup | None:
        """
        Output group of the application.

        While no group is stored on the driver, it is taken from
        ``params.out_group`` (which may itself be None).
        """
        if self._out_group is None:
            self._out_group = self.params.out_group
        return self._out_group

    @property
    def params_class(self):
        """Default parameter class."""
        return self._params_class

    @abstractmethod
    def run(self):
        """Run the application."""

    @classmethod
    def start(
        cls, filepath: str | Path | InputFile | UIJson, mode="r+", **kwargs
    ) -> Self:
        """
        Run application specified by 'filepath' ui.json file.

        :param filepath: Path to a valid ui.json file for the application
            driver, an already loaded UIJson, or a (deprecated) InputFile.
        :param mode: Mode to open the geoh5 file with.
        :param kwargs: Additional keyword arguments for Options class.

        :raises TypeError: If 'filepath' is not a str, Path, UIJson or InputFile.
        :raises GeoAppsError: If the ui.json does not reference a 'geoh5' file.

        :return: The driver instance that was run.
        """
        if isinstance(filepath, InputFile):
            # Deprecated input: resolve it to a file on disk, read back below.
            filepath = input_file_deprecation_warning(filepath)

        ifile = UIJson.read(filepath) if isinstance(filepath, str | Path) else filepath

        if not isinstance(ifile, UIJson):
            raise TypeError(
                "Input file must be a string path, a UIJson or an InputFile object."
            )

        if ifile.geoh5 is None:
            raise GeoAppsError("The application needs a valid 'geoh5' file.")

        params = cls._params_class.build(ifile, **kwargs)

        with params.geoh5.open(mode=mode):
            try:
                logger.info("Initializing application . . .")
                driver = cls(params)
                logger.info("Running application . . .")
                driver.run()
                logger.info("Results saved to %s", params.geoh5.h5file)
            except GeoAppsError as error:
                # Expected application failures are reported as a log message
                # and terminate the process with a non-zero status.
                logger.warning("\n\nApplicationError: %s\n\n", error)
                sys.exit(1)

        return driver

    def add_ui_json(self, entity: ObjectBase):
        """
        Add ui.json as FileData to entity.

        The current ``params.ui_json`` is written into a temporary directory,
        under a file named after ``params.title``, and attached to 'entity'
        with ``entity.add_file``.

        :param entity: Object to add ui.json file to.
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            path = Path(tmpdirname) / self.params.title
            self.params.ui_json.write(path)
            entity.add_file(path)

    def update_monitoring_directory(
        self, entity: ObjectBase, copy_children: bool = True
    ):
        """
        If monitoring directory is active, copy entity to monitoring directory.

        The ui.json is always attached to 'entity' first, even when no
        monitoring directory is configured.

        :param entity: Object being added to monitoring directory.
        :param copy_children: If True, copy all children of the entity to the
            monitoring directory.
        """
        self.add_ui_json(entity)

        if (
            self.params.monitoring_directory is not None
            and Path(self.params.monitoring_directory).is_dir()
        ):
            monitored_directory_copy(
                str(Path(self.params.monitoring_directory).resolve()),
                entity,
                copy_children=copy_children,
            )

    @classmethod
    def get_default_ui_json_path(cls) -> Path | None:
        """
        Get the default ui.json file path for the application.

        :return: ``default_ui_json`` of the parameter class when it is an
            Options subclass, otherwise None.
        """
        if issubclass(cls._params_class, Options):
            return cls._params_class.default_ui_json
        return None

    @classmethod
    def get_default_ui_json(cls) -> UIJson:
        """
        Load the driver's default ui.json template from disk
        with no parameters filled in.

        :raises ValueError: If the parameter class is not an Options subclass.

        :return: The default ui.json configuration.
        """
        if issubclass(cls._params_class, Options):
            return cls._params_class.get_default_ui_json()
        raise ValueError(f"Driver {cls} does not have a default ui.json.")
class Options(BaseModel):
    """
    Core parameters expected by the ui.json file format.

    :param conda_environment: Environment used to run run_command.
    :param geoh5: Current workspace.
    :param monitoring_directory: Path to monitoring directory, where .geoh5 files
        are automatically processed by GA.
    :param out_group: Optional output group; see :meth:`update_out_group_options`.
    :param run_command: Command to run the application through GA.
    :param title: Application title.
    """

    # Frozen (immutable) model; arbitrary types presumably allowed for the
    # non-pydantic Workspace / UIJsonGroup fields.
    model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)

    name: ClassVar[str] = "base"
    default_ui_json: ClassVar[Path | None] = assets_path() / "uijson/base.ui.json"

    title: str = "Base Data"
    run_command: str = "geoapps_utils.base"
    conda_environment: str | None = None
    geoh5: Workspace
    monitoring_directory: str | Path | None = None
    out_group: UIJsonGroup | None = None

    @staticmethod
    def collect_input_from_dict(
        model: type[BaseModel], data: dict[str, Any]
    ) -> dict[str, dict | Any]:
        """
        Recursively replace BaseModel objects with nested dictionary of 'data' values.

        For each field of 'model' annotated with a BaseModel subclass (and not
        already holding a BaseModel instance), the values accepted by that
        sub-model are dumped into a dictionary stored under the field name.
        The keys moved into sub-models are then removed from the top level.

        :param model: BaseModel class to structure data for.
        :param data: Flat dictionary of parameters and values without nesting
            structure. It is copied, not modified.

        :return: Restructured copy of 'data'.
        """
        update = data.copy()
        nested_fields: list[str] = []

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")
            for field, info in model.model_fields.items():
                # Already a BaseModel, no need to nest
                if isinstance(update.get(field, None), BaseModel):
                    continue

                annotation = info.annotation
                # Parametrized generics (e.g. list[int]) may pass the `type`
                # check on some Python versions but are not valid for issubclass.
                if (
                    isinstance(annotation, type)
                    and not isinstance(annotation, GenericAlias)
                    and issubclass(annotation, BaseModel)
                ):
                    # Nest and deal with aliases
                    update = Options.collect_input_from_dict(annotation, update)
                    nested = annotation.model_construct(**update).model_dump(
                        exclude_unset=True
                    )
                    if any(nested):
                        update[field] = nested
                        nested_fields += nested

        # Drop top-level copies of the values that were moved into sub-models.
        for field in nested_fields:
            if field in update:
                del update[field]

        return update

    @classmethod
    def build(
        cls, input_data: InputFile | dict | None | UIJson = None, **kwargs
    ) -> Self:
        """
        Build the options from a dictionary or UIJson.

        :param input_data: Dictionary of parameters and values, a UIJson, or a
            (deprecated) InputFile. None is treated as an empty dictionary.
            A dictionary argument is never modified.
        :param kwargs: Additional parameters, overriding those in 'input_data'.

        :raises TypeError: If the input cannot be converted to a dictionary.
        :raises GeoAppsError: If the collected parameters fail validation.

        :return: Instance of this options class.
        """
        data = input_data if isinstance(input_data, dict | UIJson) else {}

        if isinstance(input_data, InputFile) and input_data.data is not None:
            file_path = input_file_deprecation_warning(input_data)
            data = UIJson.read(file_path)

        if isinstance(data, UIJson):
            data = data.to_params()

        if not isinstance(data, dict):
            raise TypeError("Input data must be a dictionary or UIJson.")

        # Merge into a new dict: updating `data` in place would mutate the
        # dictionary passed in by the caller.
        data = {**data, **kwargs}
        options = cls.collect_input_from_dict(cls, data)  # type: ignore

        try:
            out = cls(**options)
        except ValidationError as errors:
            summary = "\n - ".join(
                f"{'.'.join(str(loc) for loc in error['loc'])}: "
                f"{error['msg']} for value -> {error['input']}"
                for error in errors.errors()
            )
            raise GeoAppsError(
                f"Invalid input data for {cls.__name__}:\n - {summary}"
            ) from errors

        return out

    def _recursive_flatten(self, data: dict[str, Any]) -> dict[str, Any]:
        """
        Recursively flatten nested dictionary.

        Deprecated: delegates to ``geoapps_utils.utils.formatters.recursive_flatten``.
        To be used on output of BaseModel.model_dump.

        :param data: Dictionary of parameters and values.
        """
        logger.warning(
            "Deprecated method: Use geoapps_utils.utils.formatters.recursive_flatten"
        )
        return recursive_flatten(data)

    def flatten(self) -> dict:
        """
        Flatten the parameters to a dictionary.

        The output of ``model_dump`` is passed through ``recursive_flatten``
        and any "input_file" entry is dropped.

        :return: Dictionary of parameters.
        """
        out = recursive_flatten(self.model_dump())
        out.pop("input_file", None)
        return out

    @property
    def input_file(self) -> UIJson:
        """Deprecated alias returning :attr:`ui_json`."""
        warnings.warn(
            "InputFile property is deprecated and will be removed in future versions. "
            "Use `ui_json` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.ui_json

    def serialize(self, mode="python"):
        """
        Return a demoted ui.json dictionary representation of the params data.

        :param mode: Serialization mode forwarded to ``UIJson.model_dump``.
        """
        serialized = self.ui_json.model_dump(
            exclude_unset=True, by_alias=True, mode=mode
        )
        return serialized

    def update_out_group_options(self):
        """
        Serialize current state and save to the out_group options.

        Within ``fetch_active_workspace(self.geoh5, mode="r+")`` the JSON-mode
        serialization is assigned to ``out_group.options`` and the group
        metadata is reset to None.

        :raises ValueError: If no output group is defined.
        """
        if self.out_group is None:
            raise ValueError("No output group defined to save options.")

        with fetch_active_workspace(self.geoh5, mode="r+"):
            self.out_group.options = self.serialize(mode="json")
            self.out_group.metadata = None

    @property
    def ui_json(self) -> UIJson:
        """
        UIJson loaded from the default template and filled with the current
        flattened parameter values.
        """
        ui_json = self.get_default_ui_json()
        ui_json.set_values(**self.flatten())
        return ui_json

    @classmethod
    def get_default_ui_json(cls) -> UIJson:
        """
        Load the default ui.json template from disk
        with no parameters filled in.

        :raises ValueError: If ``default_ui_json`` is None or does not exist.

        :return: The default ui.json configuration.
        """
        if cls.default_ui_json is None or not cls.default_ui_json.exists():
            raise ValueError(f"Driver {cls} does not have a default ui.json.")
        return UIJson.read(cls.default_ui_json)