Skip to content

Commit 68480f6

Browse files
authored
Bugfix/ensemble libe specs attrs passthrough (#1264)
* it turns out that values set by validators are still considered "unset". So for updating purposes for libE_specs, we want to exclude fields that are still set to their defaults * starting to create unit test * finish up unit test * platform_specs sometimes seems to be at risk of disappearing when we convert LibeSpecs to dict, so lets save it and reinsert
1 parent 10de371 commit 68480f6

2 files changed

Lines changed: 33 additions & 2 deletions

File tree

libensemble/ensemble.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,8 +326,14 @@ def libE_specs(self, new_specs):
326326
return
327327

328328
# Cast new libE_specs temporarily to dict
329-
if not isinstance(new_specs, dict):
330-
new_specs = specs_dump(new_specs, by_alias=True, exclude_none=True, exclude_unset=True)
329+
if not isinstance(new_specs, dict): # exclude_defaults should only be enabled with Pydantic v2
330+
platform_specs_set = False
331+
if new_specs.platform_specs != {}: # bugginess across Pydantic versions for recursively casting to dict
332+
platform_specs_set = True
333+
platform_specs = new_specs.platform_specs
334+
new_specs = specs_dump(new_specs, exclude_none=True, exclude_defaults=True)
335+
if platform_specs_set:
336+
new_specs["platform_specs"] = specs_dump(platform_specs, exclude_none=True)
331337

332338
# Unset "comms" if we already have a libE_specs that contains that field, that came from parse_args
333339
if new_specs.get("comms") and hasattr(self._libE_specs, "comms") and self.parsed:

libensemble/tests/unit_tests/test_ensemble.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,10 +166,35 @@ def test_flakey_workflow():
166166
assert not flag, "should've caught input errors"
167167

168168

169+
def test_ensemble_specs_update_libE_specs():
170+
"""Test that libE_specs is updated as expected with .attribute setting"""
171+
from libensemble.ensemble import Ensemble
172+
from libensemble.resources.platforms import PerlmutterGPU
173+
from libensemble.specs import LibeSpecs
174+
175+
platform_specs = PerlmutterGPU()
176+
177+
ensemble = Ensemble(
178+
libE_specs=LibeSpecs(comms="local", nworkers=4),
179+
)
180+
181+
ensemble.libE_specs = LibeSpecs(
182+
num_resource_sets=ensemble.nworkers - 1,
183+
resource_info={"gpus_on_node": 4},
184+
use_workflow_dir=True,
185+
platform_specs=platform_specs,
186+
)
187+
188+
assert ensemble.libE_specs.num_resource_sets == ensemble.nworkers - 1
189+
assert len(str(ensemble.libE_specs.workflow_dir_path)) > 1
190+
assert ensemble.libE_specs.platform_specs == specs_dump(platform_specs, exclude_none=True)
191+
192+
169193
if __name__ == "__main__":
170194
test_ensemble_init()
171195
test_ensemble_parse_args_false()
172196
test_from_files()
173197
test_bad_func_loads()
174198
test_full_workflow()
175199
test_flakey_workflow()
200+
test_ensemble_specs_update_libE_specs()

0 commit comments

Comments
 (0)