Skip to content

Commit 5ede7ea

Browse files
authored
Merge pull request #139 from OpenBioSim/fix_config_comparison
Fix restraint config comparison
2 parents c9dd8f6 + 570429b commit 5ede7ea

2 files changed

Lines changed: 48 additions & 35 deletions

File tree

src/somd2/config/_config.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -688,6 +688,11 @@ def as_dict(self, sire_compatible=False):
688688
if value is None and sire_compatible:
689689
d[attr_l] = False
690690

691+
# Don't include lambda_schedule_name or perturbed_system_file in the dictionary,
692+
# since these are just helper attributes.
693+
d.pop("_lambda_schedule_name", None)
694+
d.pop("_perturbed_system_file", None)
695+
691696
# Handle the lambda schedule separately so that we can use simplified
692697
# keyword options.
693698

@@ -716,7 +721,6 @@ def as_dict(self, sire_compatible=False):
716721
and self._perturbed_system_file is not None
717722
):
718723
d["perturbed_system"] = str(self._perturbed_system_file)
719-
d.pop("perturbed_system_file", None)
720724

721725
return d
722726

src/somd2/runner/_base.py

Lines changed: 43 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1397,7 +1397,6 @@ def _compare_configs(config1, config2):
13971397
"frame_frequency",
13981398
"save_velocities",
13991399
"perturbed_system",
1400-
"perturbed_system_file",
14011400
"platform",
14021401
"max_threads",
14031402
"max_gpus",
@@ -1434,51 +1433,61 @@ def _compare_configs(config1, config2):
14341433
# Standard schedules are stored as strings, so we can compare these directly.
14351434
if v1 == v2:
14361435
continue
1437-
else:
1438-
try:
1439-
v1 = _Config._from_hex(v1)
1440-
except Exception as e:
1436+
try:
1437+
v1 = _Config._from_hex(v1)
1438+
except Exception as e:
1439+
raise ValueError(
1440+
f"Unable to deserialise lambda schedule from config1: {str(e)}"
1441+
)
1442+
try:
1443+
v2 = _Config._from_hex(v2)
1444+
except Exception as e:
1445+
raise ValueError(
1446+
f"Unable to deserialise lambda schedule from config2: {str(e)}"
1447+
)
1448+
if v1 != v2:
1449+
raise ValueError(
1450+
f"{key} has changed since the last run. This is not "
1451+
"allowed when using the restart option."
1452+
)
1453+
continue
1454+
1455+
# Restraints are stored as a list of hexadecimal strings of serialised objects.
1456+
# We need to deserialise them before comparison.
1457+
elif key == "restraints":
1458+
if v1 and v2:
1459+
if len(v1) != len(v2):
14411460
raise ValueError(
1442-
f"Unable to deserialise lambda schedule from config1: {str(e)}"
1461+
f"Number of restraints has changed since the last run "
1462+
f"({len(v1)} vs {len(v2)}). This is not allowed when "
1463+
"using the restart option."
14431464
)
1465+
# Deserialise all restraints from both configs.
14441466
try:
1445-
v2 = _Config._from_hex(v2)
1467+
deserialized_v1 = [_Config._from_hex(r) for r in v1]
14461468
except Exception as e:
14471469
raise ValueError(
1448-
f"Unable to deserialise lambda schedule from config2: {str(e)}"
1470+
f"Unable to deserialise restraint from config1: {str(e)}"
14491471
)
1450-
if v1 != v2:
1472+
try:
1473+
deserialized_v2 = [_Config._from_hex(r) for r in v2]
1474+
except Exception as e:
14511475
raise ValueError(
1452-
f"{key} has changed since the last run. This is not "
1453-
"allowed when using the restart option."
1476+
f"Unable to deserialise restraint from config2: {str(e)}"
14541477
)
1455-
else:
1456-
continue
1457-
1458-
# Restraints are stored as a list of hexadecimal strings of serialised objects.
1459-
# We need to deserialise them before comparison.
1460-
elif key == "restraints":
1461-
if v1 and v2:
1462-
for r1, r2 in zip(v1, v2):
1463-
try:
1464-
r1 = _Config._from_hex(r1)
1465-
except Exception as e:
1466-
raise ValueError(
1467-
f"Unable to deserialise restraint from config1: {str(e)}"
1468-
)
1469-
try:
1470-
r2 = _Config._from_hex(r2)
1471-
except Exception as e:
1472-
raise ValueError(
1473-
f"Unable to deserialise restraint from config2: {str(e)}"
1474-
)
1475-
if r1 != r2:
1478+
# Match each restraint in v1 against v2, regardless of order.
1479+
unmatched = list(deserialized_v2)
1480+
for r1 in deserialized_v1:
1481+
for i, r2 in enumerate(unmatched):
1482+
if r1 == r2:
1483+
unmatched.pop(i)
1484+
break
1485+
else:
14761486
raise ValueError(
14771487
f"{key} has changed since the last run. This is not "
14781488
"allowed when using the restart option."
14791489
)
1480-
else:
1481-
continue
1490+
continue
14821491

14831492
# Convert GeneralUnits to strings for comparison.
14841493
if isinstance(v1, _GeneralUnit):

0 commit comments

Comments
 (0)