Skip to content

Commit 7eb13a0

Browse files
committed
fix tests to handle title, run_command as instance attributes
1 parent 84773e4 commit 7eb13a0

1 file changed

Lines changed: 7 additions & 12 deletions

File tree

tests/dataclass_test.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,7 @@ def test_dataclass_valid_values(tmp_path):
6262
valid_parameters = get_params_dict(tmp_path / f"{__name__}.geoh5")
6363
model = BaseData(**valid_parameters)
6464
output_params = model.model_dump()
65-
assert all(k not in output_params for k in ["title", "run_command"])
66-
assert len(output_params) == len(valid_parameters) - 2
65+
assert len(output_params) == len(valid_parameters)
6766

6867
for k, v in output_params.items():
6968
assert valid_parameters[k] == v
@@ -84,12 +83,13 @@ def test_dataclass_invalid_values(tmp_path):
8483
try:
8584
BaseData(**invalid_params)
8685
except ValidationError as e:
87-
assert len(e.errors()) == 3 # type: ignore
86+
assert len(e.errors()) == 4 # type: ignore
8887
error_params = [error["loc"][0] for error in e.errors()] # type: ignore
8988
error_types = [error["type"] for error in e.errors()] # type: ignore
9089
for error_param in [
9190
"monitoring_directory",
9291
"geoh5",
92+
"title",
9393
]:
9494
assert error_param in error_params
9595
for error_type in ["string_type", "path_type", "is_instance_of"]:
@@ -102,9 +102,7 @@ def test_dataclass_input_file(tmp_path):
102102
model = BaseData.build(ifile)
103103

104104
assert model.geoh5.h5file == tmp_path / f"{__name__}.geoh5"
105-
assert model.flatten() == {
106-
k: v for k, v in valid_parameters.items() if k not in ["title", "run_command"]
107-
}
105+
assert model.flatten() == valid_parameters
108106
assert model._input_file == ifile # pylint: disable=protected-access
109107

110108

@@ -215,18 +213,15 @@ class NestedModel(BaseData):
215213

216214
assert isinstance(model.group, GroupParams)
217215
assert model.group.value == "test"
218-
assert model.flatten() == {
219-
k: v for k, v in valid_params.items() if k not in ["title", "run_command"]
220-
}
221-
216+
assert model.flatten() == valid_params
222217
assert model.group.options.group_type == "multi"
223218

224219

225220
def test_params_construction(tmp_path):
226221
params = BaseData(geoh5=Workspace(tmp_path / "test.geoh5"))
227222
assert BaseData.default_ui_json is None
228-
assert BaseData.title == "Base Data"
229-
assert BaseData.run_command == "geoapps_utils.driver.driver"
223+
assert params.title == "Base Data"
224+
assert params.run_command == "geoapps_utils.driver.driver"
230225
assert str(params.geoh5.h5file) == str(tmp_path / "test.geoh5")
231226

232227

0 commit comments

Comments
 (0)