Skip to content

Commit 175d645

Browse files
committed
[GEOPY-2049] augment tests
1 parent af97e14 commit 175d645

2 files changed

Lines changed: 120 additions & 59 deletions

File tree

simpeg_drivers/uijson.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,12 @@ def verify_and_update_version(cls, value: str) -> str:
3434
input_version = cls.comparable_version(value)
3535
if input_version != package_version:
3636
logger.warning(
37-
"Provided ui.json file version %s does not match the the current"
38-
"simpeg-drivers version %s. This may lead to unpredictable"
39-
"behavior.",
37+
"Provided ui.json file version '%s' does not match the current "
38+
"simpeg-drivers version '%s'. This may lead to unpredictable behavior.",
4039
value,
4140
simpeg_drivers.__version__,
4241
)
43-
return input_version
42+
return value
4443

4544
@staticmethod
4645
def comparable_version(value: str) -> str:
@@ -54,10 +53,17 @@ def comparable_version(value: str) -> str:
5453
For example, if the version is "0.2.0+local", it will return "0.2.0".
5554
"""
5655
version = Version(value)
57-
version.post = None
58-
if version.pre is not None and version.pre[0] == "rc": # pylint: disable=unsubscriptable-object
59-
version.pre = None
60-
return version.public
56+
57+
# Extract the base version (major.minor.patch)
58+
base_version = version.base_version
59+
60+
# If it's not an RC, keep any pre-release info (alpha/beta)
61+
if version.pre is not None and version.pre[0] != "rc": # pylint: disable=unsubscriptable-object
62+
# Recreate version with pre-release but no post or local
63+
return f"{base_version}{version.pre[0]}{version.pre[1]}"
64+
65+
# No pre-release info or it's an RC, return just the base version
66+
return base_version
6167

6268
@classmethod
6369
def write_default(cls):

tests/uijson_test.py

Lines changed: 106 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from typing import ClassVar
1515

1616
import numpy as np
17+
import pytest
1718
from geoh5py import Workspace
1819
from geoh5py.ui_json.annotations import Deprecated
1920
from packaging.version import Version
@@ -35,12 +36,23 @@ def _current_version() -> Version:
3536
return Version(simpeg_drivers.__version__)
3637

3738

38-
def test_version_warning(tmp_path, caplog):
39-
workspace = Workspace.create(tmp_path / "test.geoh5")
39+
@pytest.fixture(name="workspace")
40+
def workspace_fixture(tmp_path):
41+
"""Create a workspace for testing."""
42+
return Workspace.create(tmp_path / "test.geoh5")
4043

41-
with caplog.at_level(logging.WARNING):
42-
_ = SimPEGDriversUIJson(
43-
version="0.2.0",
44+
45+
@pytest.fixture(name="simpeg_uijson_factory")
46+
def simpeg_uijson_factory_fixture(workspace):
47+
"""Create a SimPEGDriversUIJson object with configurable version."""
48+
49+
def _create_uijson(version: str | None = None, **kwargs):
50+
"""Create a SimPEGDriversUIJson with the given version and custom fields."""
51+
if version is None:
52+
version = _current_version().public
53+
54+
return SimPEGDriversUIJson(
55+
version=version,
4456
title="My app",
4557
icon="",
4658
documentation="",
@@ -49,8 +61,87 @@ def test_version_warning(tmp_path, caplog):
4961
monitoring_directory="",
5062
conda_environment="my-app",
5163
workspace_geoh5="",
64+
**kwargs,
65+
)
66+
67+
return _create_uijson
68+
69+
70+
@pytest.mark.parametrize(
71+
"version_input,expected",
72+
[
73+
# Normal version
74+
("1.2.3", "1.2.3"),
75+
# Post-release version
76+
("1.2.3.post1", "1.2.3"),
77+
# RC pre-release version
78+
("1.2.3rc1", "1.2.3"),
79+
# Alpha pre-release version (should not normalize)
80+
("1.2.3a1", "1.2.3a1"),
81+
# Beta pre-release version (should not normalize)
82+
("1.2.3b1", "1.2.3b1"),
83+
# Local version
84+
("1.2.3+local", "1.2.3"),
85+
# Combined cases
86+
("1.2.3rc1.post2+local", "1.2.3"),
87+
],
88+
)
89+
def test_comparable_version(version_input, expected):
90+
"""Test the comparable_version method of SimPEGDriversUIJson."""
91+
assert SimPEGDriversUIJson.comparable_version(version_input) == expected
92+
93+
94+
@pytest.mark.parametrize(
95+
"version_input,package_version,should_warn",
96+
[
97+
# Different version (should warn)
98+
("1.0.0", "2.0.0", True),
99+
# Same version (should not warn)
100+
("2.0.0", "2.0.0", False),
101+
# Post-release variant (should not warn)
102+
("2.0.0.post1", "2.0.0", False),
103+
("2.0.0", "2.0.0.post1", False),
104+
# RC variant (should not warn)
105+
("2.0.0rc1", "2.0.0", False),
106+
("2.0.0", "2.0.0rc1", False),
107+
("2.0.0rc1", "2.0.0rc2", False),
108+
# differ by the pre-release number, non RC (should warn)
109+
("2.0.0a1", "2.0.0a2", True),
110+
("2.0.0b1", "2.0.0b2", True),
111+
("2.0.0a1", "2.0.0", True),
112+
("2.0.0", "2.0.0a1", True),
113+
("2.0.0a1", "2.0.0b1", True),
114+
("2.0.0b1", "2.0.0a1", True),
115+
("2.0.0rc1", "2.0.0b1", True),
116+
("2.0.0b1", "2.0.0rc1", True),
117+
# same normalized versions (should not warn)
118+
("2.0.0-beta.1", "2.0.0b1", False),
119+
("2.0.0b1", "2.0.0-beta.1", False),
120+
],
121+
)
122+
def test_version_warning(
123+
monkeypatch,
124+
caplog,
125+
simpeg_uijson_factory,
126+
version_input,
127+
package_version,
128+
should_warn,
129+
):
130+
"""Test version warning behavior with mocked package version."""
131+
# Mock the package version
132+
monkeypatch.setattr(simpeg_drivers, "__version__", package_version)
133+
134+
with caplog.at_level(logging.WARNING):
135+
caplog.clear()
136+
_ = simpeg_uijson_factory(version=version_input)
137+
138+
warning_message = f"version '{version_input}' does not match the current simpeg-drivers version"
139+
warning_found = any(
140+
warning_message in record.message for record in caplog.records
52141
)
53142

143+
assert warning_found == should_warn
144+
54145

55146
def test_write_default(tmp_path):
56147
default_path = tmp_path / "default.ui.json"
@@ -75,70 +166,34 @@ class MyUIJson(SimPEGDriversUIJson):
75166
with open(default_path, encoding="utf-8") as f:
76167
data = json.load(f)
77168

78-
assert Version(data["version"]) == Version(_current_version().public)
79-
169+
# Use comparable_version for comparison to handle pre/post-release versions
170+
assert SimPEGDriversUIJson.comparable_version(
171+
data["version"]
172+
) == SimPEGDriversUIJson.comparable_version(simpeg_drivers.__version__)
80173

81-
def test_deprecations(tmp_path, caplog):
82-
workspace = Workspace.create(tmp_path / "test.geoh5")
83174

175+
def test_deprecations(caplog, simpeg_uijson_factory):
84176
class MyUIJson(SimPEGDriversUIJson):
85177
my_param: Deprecated
86178

87179
with caplog.at_level(logging.WARNING):
88-
_ = MyUIJson(
89-
version=_current_version().public,
90-
title="My app",
91-
icon="",
92-
documentation="",
93-
geoh5=str(workspace.h5file),
94-
run_command="myapp.driver",
95-
monitoring_directory="",
96-
conda_environment="my-app",
97-
workspace_geoh5="",
98-
my_param="whoopsie",
99-
)
180+
_ = MyUIJson(**simpeg_uijson_factory().model_dump(), my_param="whoopsie")
100181
assert "Skipping deprecated field: my_param." in caplog.text
101182

102183

103-
def test_pydantic_deprecation(tmp_path):
104-
workspace = Workspace.create(tmp_path / "test.geoh5")
105-
184+
def test_pydantic_deprecation(simpeg_uijson_factory):
106185
class MyUIJson(SimPEGDriversUIJson):
107186
my_param: str = Field(deprecated="Use my_param2 instead.", exclude=True)
108187

109-
uijson = MyUIJson(
110-
version=_current_version().public,
111-
title="My app",
112-
icon="",
113-
documentation="",
114-
geoh5=str(workspace.h5file),
115-
run_command="myapp.driver",
116-
monitoring_directory="",
117-
conda_environment="my-app",
118-
workspace_geoh5="",
119-
my_param="whoopsie",
120-
)
188+
uijson = MyUIJson(**simpeg_uijson_factory(my_param="whoopsie").model_dump())
121189
assert "my_param" not in uijson.model_dump()
122190

123191

124-
def test_alias(tmp_path):
125-
workspace = Workspace.create(tmp_path / "test.geoh5")
126-
192+
def test_alias(simpeg_uijson_factory):
127193
class MyUIJson(SimPEGDriversUIJson):
128194
my_param: str = Field(validation_alias=AliasChoices("my_param", "myParam"))
129195

130-
uijson = MyUIJson(
131-
version=_current_version().public,
132-
title="My app",
133-
icon="",
134-
documentation="",
135-
geoh5=str(workspace.h5file),
136-
run_command="myapp.driver",
137-
monitoring_directory="",
138-
conda_environment="my-app",
139-
workspace_geoh5="",
140-
myParam="hello",
141-
)
196+
uijson = MyUIJson(**simpeg_uijson_factory(myParam="hello").model_dump())
142197
assert uijson.my_param == "hello"
143198
assert "myParam" not in uijson.model_fields_set
144199
assert "my_param" in uijson.model_dump()

0 commit comments

Comments
 (0)