@@ -904,12 +904,12 @@ def run(self):
904904 frac = 1.0
905905 checkpoint_frequency = self ._config .energy_frequency
906906
907- # Store the number of repex cycles per block.
908- cycles_per_checkpoint = int ( frac )
907+ # Store the number of repex cycles per block (may be fractional) .
908+ cycles_per_checkpoint = frac
909909
910910 # Otherwise, we don't checkpoint.
911911 else :
912- cycles_per_checkpoint = cycles
912+ cycles_per_checkpoint = float ( cycles )
913913 num_blocks = 1
914914 rem = 0
915915
@@ -991,9 +991,13 @@ def run(self):
991991 else :
992992 cycles_per_gcmc = cycles + 1
993993
994+ # Initialise the threshold for the next checkpoint cycle. This is a float
995+ # to handle non-integer ratios between the checkpoint and energy frequencies.
996+ next_checkpoint = cycles_per_checkpoint
997+
994998 # Perform the replica exchange simulation.
995999 for i in range (cycles ):
996- _logger .info (f"Running dynamics for cycle { i + 1 } of { cycles } " )
1000+ _logger .info (f"Running dynamics for cycle { i + 1 } of { cycles } " )
9971001
9981002 # Log the states. This is the replica index for the state (positions
9991003 # and velocities) used to seed each replica for the current cycle.
@@ -1007,14 +1011,15 @@ def run(self):
10071011 # Clear the results list.
10081012 results = []
10091013
1010- # Whether to checkpoint.
1011- is_checkpoint = i > 0 and i % cycles_per_checkpoint == 0
1014+ # Whether to checkpoint. Use a float threshold to correctly handle
1015+ # non-integer ratios between the checkpoint and energy frequencies.
1016+ is_checkpoint = (i + 1 ) >= next_checkpoint - 1e-10
10121017
10131018 # Whether to perform a GCMC move before the dynamics block.
1014- is_gcmc = i % cycles_per_gcmc == 0
1019+ is_gcmc = ( i + 1 ) % cycles_per_gcmc == 0
10151020
10161021 # Whether a frame is saved at the end of the cycle.
1017- write_gcmc_ghosts = i > 0 and i % cycles_per_frame == 0
1022+ write_gcmc_ghosts = ( i + 1 ) % cycles_per_frame == 0
10181023
10191024 # Run a dynamics block for each replica, making sure only each GPU is only
10201025 # oversubscribed by a factor of self._config.oversubscription_factor.
@@ -1119,6 +1124,9 @@ def run(self):
11191124 # Update the block number.
11201125 block += 1
11211126
1127+ # Advance the checkpoint threshold.
1128+ next_checkpoint += cycles_per_checkpoint
1129+
11221130 # Guard the repex state and transition matrix saving with a file lock.
11231131 lock = _FileLock (self ._lock_file )
11241132 with lock .acquire (timeout = self ._config .timeout .to ("seconds" )):
@@ -1248,6 +1256,14 @@ def _run_block(
12481256 # Remove the PyCUDA context from the stack.
12491257 gcmc_sampler .pop ()
12501258
1259+ # A frame was saved at the end of the last cycle, so write
1260+ # the current ghost water residue indices to file. This is
1261+ # done here, immediately after the GCMC move, since the
1262+ # sampler state is only updated during GCMC moves and waters
1263+ # may have moved in/out of the GCMC sphere during dynamics.
1264+ if write_gcmc_ghosts :
1265+ gcmc_sampler .write_ghost_residues ()
1266+
12511267 # Run the dynamics.
12521268 dynamics .run (
12531269 self ._config .energy_frequency ,
@@ -1277,18 +1293,9 @@ def _run_block(
12771293 # Save the GCMC state.
12781294 if gcmc_sampler is not None :
12791295 self ._dynamics_cache .save_gcmc_state (index )
1280- # The frame frequency was hit, so write the indices of the
1281- # current ghost water residues to file.
1282- if write_gcmc_ghosts :
1283- gcmc_sampler .write_ghost_residues ()
12841296
12851297 # Get the energy at each lambda value.
1286- energies = (
1287- dynamics ._d .energy_trajectory ()
1288- .to_pandas (to_alchemlyb = True , energy_unit = "kcal/mol" )
1289- .iloc [- 1 , :]
1290- .to_numpy ()
1291- )
1298+ energies = dynamics ._current_energy_array ()
12921299
12931300 except Exception as e :
12941301 try :
@@ -1681,7 +1688,7 @@ def _checkpoint(self, index, lambdas, block, num_blocks, is_final_block=False):
16811688 dynamics ._d ._sire_mols .delete_all_frames ()
16821689
16831690 _logger .info (
1684- f"Finished block { block + 1 } of { self ._start_block + num_blocks } "
1691+ f"Finished block { block + 1 } of { self ._start_block + num_blocks } "
16851692 f"for { _lam_sym } = { lam :.5f} "
16861693 )
16871694
0 commit comments