@@ -979,6 +979,39 @@ def run(self):
979979 _logger .error ("Equilibration cancelled. Exiting." )
980980 _sys .exit (1 )
981981
982+ # Write a checkpoint immediately after equilibration so that a restart
983+ # after an early production crash doesn't need to re-equilibrate.
984+ if self ._is_equilibration and not self ._is_restart :
985+ lock = _FileLock (self ._lock_file )
986+ with lock .acquire (timeout = self ._config .timeout .to ("seconds" )):
987+ for j in range (num_checkpoint_batches ):
988+ replicas = replica_list [
989+ j * num_checkpoint_workers : (j + 1 ) * num_checkpoint_workers
990+ ]
991+ with ThreadPoolExecutor (
992+ max_workers = num_checkpoint_workers
993+ ) as executor :
994+ try :
995+ for index , error in executor .map (
996+ self ._checkpoint ,
997+ replicas ,
998+ repeat (self ._lambda_values ),
999+ repeat (- 1 ),
1000+ repeat (cycles ),
1001+ ):
1002+ if error is not None :
1003+ msg = (
1004+ f"Post-equilibration checkpoint failed for { _lam_sym } = "
1005+ f"{ self ._lambda_values [index ]:.5f} :\n { error } "
1006+ )
1007+ _logger .error (msg )
1008+ raise error
1009+ except KeyboardInterrupt :
1010+ _logger .error (
1011+ "Post-equilibration checkpoint cancelled. Exiting."
1012+ )
1013+ _sys .exit (1 )
1014+
9821015 # Current block number.
9831016 block = self ._start_block
9841017
@@ -1149,6 +1182,13 @@ def run(self):
11491182 )
11501183 self ._dynamics_cache .mix_states ()
11511184
1185+ # Snapshot the pre-run state for crash recovery.
1186+ if self ._config .auto_fix_minimise :
1187+ for i , state in enumerate (self ._dynamics_cache .get_states ()):
1188+ self ._dynamics_cache ._dynamics [
1189+ i
1190+ ]._d ._pre_run_state = self ._dynamics_cache ._openmm_states [state ]
1191+
11521192 # This is a checkpoint cycle.
11531193 if is_checkpoint :
11541194 # Update the block number.
@@ -1278,6 +1318,12 @@ def _run_block(
12781318 # Get the dynamics object (and GCMC sampler).
12791319 dynamics , gcmc_sampler = self ._dynamics_cache .get (index )
12801320
1321+ # Track whether any MC move changed the context positions so we
1322+ # can update _pre_run_state once at the end. Only needed when
1323+ # crash recovery is enabled.
1324+ needs_pre_run_snapshot = False
1325+ auto_fix_minimise = self ._config .auto_fix_minimise
1326+
12811327 # Perform the GCMC move before dynamics so that the energies
12821328 # computed during dynamics are consistent with the state used
12831329 # for replica exchange mixing.
@@ -1289,6 +1335,9 @@ def _run_block(
12891335 finally :
12901336 gcmc_sampler .pop ()
12911337
1338+ if auto_fix_minimise :
1339+ needs_pre_run_snapshot = True
1340+
12921341 # Write ghost residues immediately after the GCMC move so the
12931342 # ghost state and frame (saved during dynamics) are consistent.
12941343 if write_gcmc_ghosts :
@@ -1297,7 +1346,16 @@ def _run_block(
12971346 # Perform a terminal flip move before dynamics if requested.
12981347 if self ._terminal_flip_samplers is not None and is_terminal_flip :
12991348 _logger .info (f"Performing terminal flip move at { _lam_sym } = { lam :.5f} " )
1300- self ._terminal_flip_samplers [index ].move (dynamics .context ())
1349+ if self ._terminal_flip_samplers [index ].move (dynamics .context ()):
1350+ if auto_fix_minimise :
1351+ needs_pre_run_snapshot = True
1352+
1353+ # Snapshot the context state for crash recovery if any MC move
1354+ # changed positions.
1355+ if needs_pre_run_snapshot :
1356+ dynamics ._d ._pre_run_state = dynamics .context ().getState (
1357+ getPositions = True , getVelocities = True
1358+ )
13011359
13021360 _logger .info (f"Running dynamics at { _lam_sym } = { lam :.5f} " )
13031361
@@ -1313,7 +1371,7 @@ def _run_block(
13131371 lambda_windows = lambdas ,
13141372 rest2_scale_factors = self ._rest2_scale_factors ,
13151373 save_velocities = self ._config .save_velocities ,
1316- auto_fix_minimise = True ,
1374+ auto_fix_minimise = self . _config . auto_fix_minimise ,
13171375 num_energy_neighbours = self ._config .num_energy_neighbours ,
13181376 null_energy = self ._config .null_energy ,
13191377 save_crash_report = self ._config .save_crash_report ,
@@ -1544,7 +1602,7 @@ def _equilibrate(self, index):
15441602 energy_frequency = 0 ,
15451603 frame_frequency = 0 ,
15461604 save_velocities = False ,
1547- auto_fix_minimise = True ,
1605+ auto_fix_minimise = self . _config . auto_fix_minimise ,
15481606 save_crash_report = self ._config .save_crash_report ,
15491607 )
15501608
@@ -1728,10 +1786,15 @@ def _checkpoint(self, index, lambdas, block, num_blocks, is_final_block=False):
17281786 # dynamics object.
17291787 dynamics ._d ._sire_mols .delete_all_frames ()
17301788
1731- _logger .info (
1732- f"Finished block { block + 1 } of { self ._start_block + num_blocks } "
1733- f"for { _lam_sym } = { lam :.5f} "
1734- )
1789+ if block == - 1 :
1790+ _logger .info (
1791+ f"Writing post-equilibration checkpoint for { _lam_sym } = { lam :.5f} "
1792+ )
1793+ else :
1794+ _logger .info (
1795+ f"Finished block { block + 1 } of { self ._start_block + num_blocks } "
1796+ f"for { _lam_sym } = { lam :.5f} "
1797+ )
17351798
17361799 # Log the number of waters within the GCMC sampling volume.
17371800 if gcmc_sampler is not None :
0 commit comments