Skip to content

Commit 29a033d

Browse files
authored
Merge pull request #141 from OpenBioSim/fix_sire_419
Set _pre_run_state attribute in Sire dynamics
2 parents 5481736 + 727269c commit 29a033d

4 files changed

Lines changed: 193 additions & 51 deletions

File tree

src/somd2/config/_config.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ def __init__(
162162
overwrite=False,
163163
somd1_compatibility=False,
164164
pert_file=None,
165+
auto_fix_minimise=True,
165166
save_crash_report=False,
166167
save_energy_components=False,
167168
page_size=None,
@@ -496,6 +497,10 @@ def __init__(
496497
The path to a SOMD1 perturbation file to apply to the reference system.
497498
When set, this will automatically set 'somd1_compatibility' to True.
498499
500+
auto_fix_minimise: bool
501+
Whether to attempt to automatically recover from simulation instabilities
502+
by minimising and restarting. Defaults to True.
503+
499504
save_crash_report: bool
500505
Whether to save a crash report if the simulation crashes.
501506
@@ -599,6 +604,7 @@ def __init__(
599604
self.taylor_power = taylor_power
600605
self.somd1_compatibility = somd1_compatibility
601606
self.pert_file = pert_file
607+
self.auto_fix_minimise = auto_fix_minimise
602608
self.save_crash_report = save_crash_report
603609
self.save_energy_components = save_energy_components
604610
self.timeout = timeout
@@ -2383,6 +2389,16 @@ def pert_file(self, pert_file):
23832389

23842390
self._pert_file = pert_file
23852391

2392+
@property
2393+
def auto_fix_minimise(self):
2394+
return self._auto_fix_minimise
2395+
2396+
@auto_fix_minimise.setter
2397+
def auto_fix_minimise(self, auto_fix_minimise):
2398+
if not isinstance(auto_fix_minimise, bool):
2399+
raise ValueError("'auto_fix_minimise' must be of type 'bool'")
2400+
self._auto_fix_minimise = auto_fix_minimise
2401+
23862402
@property
23872403
def save_crash_report(self):
23882404
return self._save_crash_report

src/somd2/runner/_base.py

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1820,26 +1820,33 @@ def _checkpoint(
18201820
# Get the lambda value.
18211821
lam = self._lambda_values[index]
18221822

1823+
# -1 is the sentinel for a post-equilibration checkpoint. No
1824+
# energies are collected during equilibration, so skip all
1825+
# parquet-related work in this case.
1826+
is_post_equilibration = block == -1
1827+
18231828
# Get the energy trajectory.
1824-
df = system.energy_trajectory(to_alchemlyb=True, energy_unit="kT")
1829+
if not is_post_equilibration:
1830+
df = system.energy_trajectory(to_alchemlyb=True, energy_unit="kT")
18251831

18261832
# Set the lambda values at which energies were sampled.
18271833
if lambda_energy is None:
18281834
lambda_energy = self._lambda_values
18291835

18301836
# Create the metadata.
1831-
metadata = {
1832-
"attrs": df.attrs,
1833-
"somd2 version": __version__,
1834-
"sire version": f"{_sire_version}+{_sire_revisionid}",
1835-
"lambda": f"{lam:.5f}",
1836-
"speed": speed,
1837-
"temperature": str(self._config.temperature.value()),
1838-
}
1839-
1840-
# Add the lambda gradient if available.
1841-
if lambda_grad is not None:
1842-
metadata["lambda_grad"] = [f"{v:.5f}" for v in lambda_grad]
1837+
if not is_post_equilibration:
1838+
metadata = {
1839+
"attrs": df.attrs,
1840+
"somd2 version": __version__,
1841+
"sire version": f"{_sire_version}+{_sire_revisionid}",
1842+
"lambda": f"{lam:.5f}",
1843+
"speed": speed,
1844+
"temperature": str(self._config.temperature.value()),
1845+
}
1846+
1847+
# Add the lambda gradient if available.
1848+
if lambda_grad is not None:
1849+
metadata["lambda_grad"] = [f"{v:.5f}" for v in lambda_grad]
18431850

18441851
if is_final_block:
18451852
# Save the end-state GCMC topologies for trajectory analysis and visualisation.
@@ -1930,7 +1937,7 @@ def _checkpoint(
19301937

19311938
else:
19321939
# Update the starting block if necessary.
1933-
if block == 0:
1940+
if block <= 0:
19341941
block = self._start_block
19351942

19361943
# Save the current trajectory chunk to file.
@@ -1958,18 +1965,20 @@ def _checkpoint(
19581965
# Stream the checkpoint to file.
19591966
_sr.stream.save(system, self._filenames[index]["checkpoint"])
19601967

1961-
# Create the parquet file name.
1962-
filename = self._filenames[index]["energy_traj"]
1963-
1964-
# Create the parquet file.
1965-
if block == self._start_block:
1966-
_dataframe_to_parquet(df, metadata=metadata, filename=filename)
1967-
# Append to the parquet file.
1968-
else:
1969-
_parquet_append(
1970-
filename,
1971-
df.iloc[-self._energy_per_block :],
1972-
)
1968+
# Skip parquet creation for post-equilibration checkpoints.
1969+
if not is_post_equilibration:
1970+
# Create the parquet file name.
1971+
filename = self._filenames[index]["energy_traj"]
1972+
1973+
# Create the parquet file.
1974+
if block == self._start_block:
1975+
_dataframe_to_parquet(df, metadata=metadata, filename=filename)
1976+
# Append to the parquet file.
1977+
else:
1978+
_parquet_append(
1979+
filename,
1980+
df.iloc[-self._energy_per_block :],
1981+
)
19731982

19741983
except Exception as e:
19751984
return index, e

src/somd2/runner/_repex.py

Lines changed: 70 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -979,6 +979,39 @@ def run(self):
979979
_logger.error("Equilibration cancelled. Exiting.")
980980
_sys.exit(1)
981981

982+
# Write a checkpoint immediately after equilibration so that a restart
983+
# after an early production crash doesn't need to re-equilibrate.
984+
if self._is_equilibration and not self._is_restart:
985+
lock = _FileLock(self._lock_file)
986+
with lock.acquire(timeout=self._config.timeout.to("seconds")):
987+
for j in range(num_checkpoint_batches):
988+
replicas = replica_list[
989+
j * num_checkpoint_workers : (j + 1) * num_checkpoint_workers
990+
]
991+
with ThreadPoolExecutor(
992+
max_workers=num_checkpoint_workers
993+
) as executor:
994+
try:
995+
for index, error in executor.map(
996+
self._checkpoint,
997+
replicas,
998+
repeat(self._lambda_values),
999+
repeat(-1),
1000+
repeat(cycles),
1001+
):
1002+
if error is not None:
1003+
msg = (
1004+
f"Post-equilibration checkpoint failed for {_lam_sym} = "
1005+
f"{self._lambda_values[index]:.5f}:\n{error}"
1006+
)
1007+
_logger.error(msg)
1008+
raise error
1009+
except KeyboardInterrupt:
1010+
_logger.error(
1011+
"Post-equilibration checkpoint cancelled. Exiting."
1012+
)
1013+
_sys.exit(1)
1014+
9821015
# Current block number.
9831016
block = self._start_block
9841017

@@ -1149,6 +1182,13 @@ def run(self):
11491182
)
11501183
self._dynamics_cache.mix_states()
11511184

1185+
# Snapshot the pre-run state for crash recovery.
1186+
if self._config.auto_fix_minimise:
1187+
for i, state in enumerate(self._dynamics_cache.get_states()):
1188+
self._dynamics_cache._dynamics[
1189+
i
1190+
]._d._pre_run_state = self._dynamics_cache._openmm_states[state]
1191+
11521192
# This is a checkpoint cycle.
11531193
if is_checkpoint:
11541194
# Update the block number.
@@ -1278,6 +1318,12 @@ def _run_block(
12781318
# Get the dynamics object (and GCMC sampler).
12791319
dynamics, gcmc_sampler = self._dynamics_cache.get(index)
12801320

1321+
# Track whether any MC move changed the context positions so we
1322+
# can update _pre_run_state once at the end. Only needed when
1323+
# crash recovery is enabled.
1324+
needs_pre_run_snapshot = False
1325+
auto_fix_minimise = self._config.auto_fix_minimise
1326+
12811327
# Perform the GCMC move before dynamics so that the energies
12821328
# computed during dynamics are consistent with the state used
12831329
# for replica exchange mixing.
@@ -1289,6 +1335,9 @@ def _run_block(
12891335
finally:
12901336
gcmc_sampler.pop()
12911337

1338+
if auto_fix_minimise:
1339+
needs_pre_run_snapshot = True
1340+
12921341
# Write ghost residues immediately after the GCMC move so the
12931342
# ghost state and frame (saved during dynamics) are consistent.
12941343
if write_gcmc_ghosts:
@@ -1297,7 +1346,16 @@ def _run_block(
12971346
# Perform a terminal flip move before dynamics if requested.
12981347
if self._terminal_flip_samplers is not None and is_terminal_flip:
12991348
_logger.info(f"Performing terminal flip move at {_lam_sym} = {lam:.5f}")
1300-
self._terminal_flip_samplers[index].move(dynamics.context())
1349+
if self._terminal_flip_samplers[index].move(dynamics.context()):
1350+
if auto_fix_minimise:
1351+
needs_pre_run_snapshot = True
1352+
1353+
# Snapshot the context state for crash recovery if any MC move
1354+
# changed positions.
1355+
if needs_pre_run_snapshot:
1356+
dynamics._d._pre_run_state = dynamics.context().getState(
1357+
getPositions=True, getVelocities=True
1358+
)
13011359

13021360
_logger.info(f"Running dynamics at {_lam_sym} = {lam:.5f}")
13031361

@@ -1313,7 +1371,7 @@ def _run_block(
13131371
lambda_windows=lambdas,
13141372
rest2_scale_factors=self._rest2_scale_factors,
13151373
save_velocities=self._config.save_velocities,
1316-
auto_fix_minimise=True,
1374+
auto_fix_minimise=self._config.auto_fix_minimise,
13171375
num_energy_neighbours=self._config.num_energy_neighbours,
13181376
null_energy=self._config.null_energy,
13191377
save_crash_report=self._config.save_crash_report,
@@ -1544,7 +1602,7 @@ def _equilibrate(self, index):
15441602
energy_frequency=0,
15451603
frame_frequency=0,
15461604
save_velocities=False,
1547-
auto_fix_minimise=True,
1605+
auto_fix_minimise=self._config.auto_fix_minimise,
15481606
save_crash_report=self._config.save_crash_report,
15491607
)
15501608

@@ -1728,10 +1786,15 @@ def _checkpoint(self, index, lambdas, block, num_blocks, is_final_block=False):
17281786
# dynamics object.
17291787
dynamics._d._sire_mols.delete_all_frames()
17301788

1731-
_logger.info(
1732-
f"Finished block {block + 1} of {self._start_block + num_blocks} "
1733-
f"for {_lam_sym} = {lam:.5f}"
1734-
)
1789+
if block == -1:
1790+
_logger.info(
1791+
f"Writing post-equilibration checkpoint for {_lam_sym} = {lam:.5f}"
1792+
)
1793+
else:
1794+
_logger.info(
1795+
f"Finished block {block + 1} of {self._start_block + num_blocks} "
1796+
f"for {_lam_sym} = {lam:.5f}"
1797+
)
17351798

17361799
# Log the number of waters within the GCMC sampling volume.
17371800
if gcmc_sampler is not None:

0 commit comments

Comments
 (0)