Skip to content

Commit 88a76f1

Browse files
committed
Use a unique TerminalFlipSampler per replica.
1 parent 3745398 commit 88a76f1

1 file changed

Lines changed: 22 additions & 19 deletions

File tree

src/somd2/runner/_repex.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -854,19 +854,22 @@ def __init__(self, system, config):
854854
else:
855855
self._start_block = 0
856856

857-
# Create the terminal flip sampler (if terminal groups were detected).
857+
# Create a terminal flip sampler per replica (if terminal groups were detected).
858858
if self._terminal_groups:
859859
from ._samplers import TerminalFlipSampler
860860

861-
self._terminal_flip_sampler = TerminalFlipSampler(
862-
self._terminal_groups,
863-
float(self._config.temperature.value()),
864-
)
861+
self._terminal_flip_samplers = [
862+
TerminalFlipSampler(
863+
self._terminal_groups,
864+
float(self._config.temperature.value()),
865+
)
866+
for _ in self._lambda_values
867+
]
865868
_logger.info(
866-
f"Terminal flip sampler ready ({len(self._terminal_groups)} group(s))"
869+
f"Terminal flip samplers ready ({len(self._terminal_groups)} group(s))"
867870
)
868871
else:
869-
self._terminal_flip_sampler = None
872+
self._terminal_flip_samplers = None
870873

871874
from threading import Lock
872875

@@ -1018,7 +1021,7 @@ def run(self):
10181021
# Work out the number of cycles per terminal flip move.
10191022
if (
10201023
self._config.terminal_flip_frequency is not None
1021-
and self._terminal_flip_sampler is not None
1024+
and self._terminal_flip_samplers is not None
10221025
):
10231026
cycles_per_flip = max(
10241027
1,
@@ -1163,15 +1166,6 @@ def run(self):
11631166
)
11641167
self._dynamics_cache.mix_states()
11651168

1166-
# Log terminal flip acceptance rate at each cycle.
1167-
if self._terminal_flip_sampler is not None:
1168-
_logger.info(
1169-
f"Terminal flip acceptance rate: "
1170-
f"{self._terminal_flip_sampler.acceptance_rate:.3f} "
1171-
f"({self._terminal_flip_sampler.num_accepted}/"
1172-
f"{self._terminal_flip_sampler.num_attempted})"
1173-
)
1174-
11751169
# This is a checkpoint cycle.
11761170
if is_checkpoint:
11771171
# Update the block number.
@@ -1312,9 +1306,9 @@ def _run_block(
13121306
gcmc_sampler.write_ghost_residues()
13131307

13141308
# Perform a terminal flip move before dynamics if requested.
1315-
if self._terminal_flip_sampler is not None and is_terminal_flip:
1309+
if self._terminal_flip_samplers is not None and is_terminal_flip:
13161310
_logger.info(f"Performing terminal flip move at {_lam_sym} = {lam:.5f}")
1317-
self._terminal_flip_sampler.move(dynamics.context())
1311+
self._terminal_flip_samplers[index].move(dynamics.context())
13181312

13191313
_logger.info(f"Running dynamics at {_lam_sym} = {lam:.5f}")
13201314

@@ -1770,6 +1764,15 @@ def _checkpoint(self, index, lambdas, block, num_blocks, is_final_block=False):
17701764
# Remove the PyCUDA context from the stack.
17711765
gcmc_sampler.pop()
17721766

1767+
# Log terminal flip acceptance rate for this replica.
1768+
if self._terminal_flip_samplers is not None:
1769+
sampler = self._terminal_flip_samplers[index]
1770+
_logger.info(
1771+
f"Terminal flip acceptance rate at {_lam_sym} = {lam:.5f}: "
1772+
f"{sampler.acceptance_rate:.3f} "
1773+
f"({sampler.num_accepted}/{sampler.num_attempted})"
1774+
)
1775+
17731776
if is_final_block:
17741777
_logger.success(f"{_lam_sym} = {lam:.5f} complete")
17751778

0 commit comments

Comments
 (0)