Skip to content

Commit de402f5

Browse files
committed
Perform GCMC before dynamics so energies correspond to correct state.
1 parent 222c0b7 commit de402f5

2 files changed

Lines changed: 63 additions & 52 deletions

File tree

src/somd2/runner/_repex.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1242,6 +1242,22 @@ def _run_block(
12421242
# Get the dynamics object (and GCMC sampler).
12431243
dynamics, gcmc_sampler = self._dynamics_cache.get(index)
12441244

1245+
# Perform the GCMC move before dynamics so that the energies
1246+
# computed during dynamics are consistent with the state used
1247+
# for replica exchange mixing.
1248+
if gcmc_sampler is not None and is_gcmc:
1249+
gcmc_sampler.push()
1250+
try:
1251+
_logger.info(f"Performing GCMC move at {_lam_sym} = {lam:.5f}")
1252+
gcmc_sampler.move(dynamics.context())
1253+
finally:
1254+
gcmc_sampler.pop()
1255+
1256+
# Write ghost residues immediately after the GCMC move so the
1257+
# ghost state and frame (saved during dynamics) are consistent.
1258+
if write_gcmc_ghosts:
1259+
gcmc_sampler.write_ghost_residues()
1260+
12451261
_logger.info(f"Running dynamics at {_lam_sym} = {lam:.5f}")
12461262

12471263
# Draw new velocities from the Maxwell-Boltzmann distribution.
@@ -1272,27 +1288,10 @@ def _run_block(
12721288
)
12731289

12741290
if gcmc_sampler is not None:
1275-
# Write ghost residues before the GCMC move so the ghost state
1276-
# is consistent with the saved frame (which is also captured
1277-
# before the GCMC move).
1278-
if write_gcmc_ghosts:
1279-
gcmc_sampler.write_ghost_residues()
1280-
1281-
if is_gcmc:
1282-
# Push the PyCUDA context on top of the stack.
1283-
gcmc_sampler.push()
1284-
try:
1285-
# Perform the GCMC move.
1286-
_logger.info(f"Performing GCMC move at {_lam_sym} = {lam:.5f}")
1287-
gcmc_sampler.move(dynamics.context())
1288-
finally:
1289-
# Remove the PyCUDA context from the stack.
1290-
gcmc_sampler.pop()
1291-
12921291
# Save the GCMC state.
12931292
self._dynamics_cache.save_gcmc_state(index)
12941293

1295-
# Save the OpenMM state after any GCMC move so the context is consistent.
1294+
# Save the OpenMM state.
12961295
self._dynamics_cache.save_openmm_state(index)
12971296

12981297
# Get the energy at each lambda value.

src/somd2/runner/_runner.py

Lines changed: 46 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -722,6 +722,29 @@ def generate_lam_vals(lambda_base, increment=0.001):
722722

723723
# Loop until we reach the runtime.
724724
while runtime < checkpoint_frequency:
725+
# Perform a GCMC move before dynamics so the ghost
726+
# state is consistent with the energies computed
727+
# during dynamics.
728+
_logger.info(
729+
f"Performing GCMC move at {_lam_sym} = {lambda_value:.5f}"
730+
)
731+
gcmc_sampler.push()
732+
try:
733+
gcmc_sampler.move(dynamics.context())
734+
finally:
735+
gcmc_sampler.pop()
736+
737+
# Write ghost residues immediately after the GCMC
738+
# move if a frame will be saved in the upcoming
739+
# dynamics block.
740+
if (
741+
save_frames
742+
and runtime + self._config.energy_frequency
743+
>= next_frame
744+
):
745+
gcmc_sampler.write_ghost_residues()
746+
next_frame += self._config.frame_frequency
747+
725748
# Run the dynamics in blocks of the GCMC frequency.
726749
dynamics.run(
727750
self._config.gcmc_frequency,
@@ -748,23 +771,6 @@ def generate_lam_vals(lambda_base, increment=0.001):
748771
# Update the runtime.
749772
runtime += self._config.energy_frequency
750773

751-
# If a frame is saved, write the ghost residue indices
752-
# before the GCMC move so the ghost state is consistent
753-
# with the saved frame.
754-
if save_frames and runtime >= next_frame:
755-
gcmc_sampler.write_ghost_residues()
756-
next_frame += self._config.frame_frequency
757-
758-
# Perform a GCMC move.
759-
_logger.info(
760-
f"Performing GCMC move at {_lam_sym} = {lambda_value:.5f}"
761-
)
762-
gcmc_sampler.push()
763-
try:
764-
gcmc_sampler.move(dynamics.context())
765-
finally:
766-
gcmc_sampler.pop()
767-
768774
else:
769775
dynamics.run(
770776
checkpoint_frequency,
@@ -948,7 +954,29 @@ def generate_lam_vals(lambda_base, increment=0.001):
948954
next_frame = self._config.frame_frequency
949955

950956
# Loop until we reach the runtime.
951-
while runtime <= time:
957+
while runtime < time:
958+
# Perform a GCMC move before dynamics so the ghost
959+
# state is consistent with the energies computed
960+
# during dynamics.
961+
_logger.info(
962+
f"Performing GCMC move at {_lam_sym} = {lambda_value:.5f}"
963+
)
964+
gcmc_sampler.push()
965+
try:
966+
gcmc_sampler.move(dynamics.context())
967+
finally:
968+
gcmc_sampler.pop()
969+
970+
# Write ghost residues immediately after the GCMC
971+
# move if a frame will be saved in the upcoming
972+
# dynamics block.
973+
if (
974+
save_frames
975+
and runtime + self._config.energy_frequency >= next_frame
976+
):
977+
gcmc_sampler.write_ghost_residues()
978+
next_frame += self._config.frame_frequency
979+
952980
# Run the dynamics in blocks of the GCMC frequency.
953981
dynamics.run(
954982
self._config.gcmc_frequency,
@@ -963,24 +991,8 @@ def generate_lam_vals(lambda_base, increment=0.001):
963991
save_crash_report=self._config.save_crash_report,
964992
)
965993

966-
# Perform a GCMC move.
967-
_logger.info(
968-
f"Performing GCMC move at {_lam_sym} = {lambda_value:.5f}"
969-
)
970-
gcmc_sampler.push()
971-
try:
972-
gcmc_sampler.move(dynamics.context())
973-
finally:
974-
gcmc_sampler.pop()
975-
976994
# Update the runtime.
977995
runtime += self._config.energy_frequency
978-
979-
# If a frame is saved, then we need to save current indices
980-
# of the ghost water residues.
981-
if save_frames and runtime >= next_frame:
982-
gcmc_sampler.write_ghost_residues()
983-
next_frame += self._config.frame_frequency
984996
else:
985997
dynamics.run(
986998
time,

0 commit comments

Comments
 (0)