Skip to content

Commit 795f377

Browse files
committed
Add serialisation for sampler statistics.
1 parent 88a76f1 commit 795f377

4 files changed

Lines changed: 134 additions & 1 deletion

File tree

src/somd2/runner/_base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,6 +1197,7 @@ def increment_filename(base_filename, suffix):
11971197
output_directory / f"energy_components_{lam}.txt"
11981198
)
11991199
filenames["gcmc_ghosts"] = str(output_directory / f"gcmc_ghosts_{lam}.txt")
1200+
filenames["sampler_stats"] = str(output_directory / f"sampler_stats_{lam}.pkl")
12001201
if restart:
12011202
filenames["config"] = str(
12021203
output_directory / increment_filename("config", "yaml")

src/somd2/runner/_repex.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ def __init__(
106106
self._openmm_states = [None] * len(lambdas)
107107
self._gcmc_samplers = [None] * len(lambdas)
108108
self._gcmc_states = [None] * len(lambdas)
109+
self._gcmc_stats = [None] * len(lambdas)
110+
self._terminal_flip_stats = [[0, 0]] * len(lambdas)
109111
self._num_proposed = _np.matrix(_np.zeros((len(lambdas), len(lambdas))))
110112
self._num_accepted = _np.matrix(_np.zeros((len(lambdas), len(lambdas))))
111113
self._num_swaps = _np.matrix(_np.zeros((len(lambdas), len(lambdas))))
@@ -130,6 +132,14 @@ def __setstate__(self, state):
130132
for key, value in state.items():
131133
setattr(self, key, value)
132134

135+
# Provide defaults for attributes added after the initial release,
136+
# so that old checkpoint files can still be loaded.
137+
n = len(self._lambdas)
138+
if not hasattr(self, "_gcmc_stats"):
139+
self._gcmc_stats = [None] * n
140+
if not hasattr(self, "_terminal_flip_stats"):
141+
self._terminal_flip_stats = [[0, 0]] * n
142+
133143
def __getstate__(self):
134144
"""
135145
Get the state of the object.
@@ -145,6 +155,8 @@ def __getstate__(self):
145155
# Don't pickle the GCMC samplers since they need to be recreated.
146156
"_gcmc_samplers": len(self._gcmc_samplers) * [None],
147157
"_gcmc_states": self._gcmc_states,
158+
"_gcmc_stats": self._gcmc_stats,
159+
"_terminal_flip_stats": self._terminal_flip_stats,
148160
"_num_proposed": self._num_proposed,
149161
"_num_accepted": self._num_accepted,
150162
"_num_swaps": self._num_swaps,
@@ -823,7 +835,7 @@ def __init__(self, system, config):
823835
state = self._dynamics_cache._states[i]
824836
dynamics.context().setState(self._dynamics_cache._openmm_states[state])
825837

826-
# Reset the GCMC water state.
838+
# Reset the GCMC water state and restore statistics.
827839
if gcmc_sampler is not None:
828840
gcmc_sampler.push()
829841
try:
@@ -834,6 +846,13 @@ def __init__(self, system, config):
834846
)
835847
finally:
836848
gcmc_sampler.pop()
849+
if self._dynamics_cache._gcmc_stats[i] is not None:
850+
gcmc_sampler.restore_stats(self._dynamics_cache._gcmc_stats[i])
851+
852+
# Restore terminal flip sampler statistics.
853+
if self._terminal_flip_samplers is not None:
854+
attempted, accepted = self._dynamics_cache._terminal_flip_stats[i]
855+
self._terminal_flip_samplers[i].reset(attempted, accepted)
837856

838857
# Conversion factor for reduced potential.
839858
kT = (_sr.units.k_boltz * self._config.temperature).to(_sr.units.kcal_per_mol)
@@ -1190,6 +1209,7 @@ def run(self):
11901209

11911210
# Pickle the dynamics cache.
11921211
_logger.info("Saving replica exchange state")
1212+
self._save_sampler_stats()
11931213
with open(self._repex_state, "wb") as f:
11941214
_pickle.dump(self._dynamics_cache, f)
11951215

@@ -1211,6 +1231,11 @@ def run(self):
12111231

12121232
# Pickle final state of the dynamics cache.
12131233
_logger.info("Saving final replica exchange state")
1234+
if self._terminal_flip_samplers is not None:
1235+
self._dynamics_cache._terminal_flip_stats = [
1236+
[s.num_attempted, s.num_accepted]
1237+
for s in self._terminal_flip_samplers
1238+
]
12141239
with open(self._repex_state, "wb") as f:
12151240
_pickle.dump(self._dynamics_cache, f)
12161241

@@ -1842,6 +1867,21 @@ def _mix_replicas(num_replicas, energy_matrix, proposed, accepted):
18421867

18431868
return states
18441869

1870+
def _save_sampler_stats(self):
1871+
"""
1872+
Save GCMC and terminal flip sampler statistics to the dynamics cache
1873+
prior to pickling.
1874+
"""
1875+
for i in range(len(self._lambda_values)):
1876+
_, gcmc_sampler = self._dynamics_cache.get(i)
1877+
if gcmc_sampler is not None:
1878+
self._dynamics_cache._gcmc_stats[i] = gcmc_sampler.get_stats()
1879+
1880+
if self._terminal_flip_samplers is not None:
1881+
self._dynamics_cache._terminal_flip_stats = [
1882+
[s.num_attempted, s.num_accepted] for s in self._terminal_flip_samplers
1883+
]
1884+
18451885
def _save_transition_matrix(self):
18461886
"""
18471887
Internal method to save the replica exchange transition matrix.

src/somd2/runner/_runner.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -695,6 +695,16 @@ def generate_lam_vals(lambda_base, increment=0.001):
695695
finally:
696696
gcmc_sampler.pop()
697697

698+
# Restore sampler statistics from a previous run.
699+
if self._is_restart:
700+
stats = self._load_sampler_stats(index)
701+
if stats is not None:
702+
if gcmc_sampler is not None and "gcmc" in stats:
703+
gcmc_sampler.restore_stats(stats["gcmc"])
704+
if terminal_flip_sampler is not None and "terminal_flip" in stats:
705+
attempted, accepted = stats["terminal_flip"]
706+
terminal_flip_sampler.reset(attempted, accepted)
707+
698708
# Set the number of neighbours used for the energy calculation.
699709
# If not None, then we add one to account for the extra windows
700710
# used for finite-difference gradient analysis.
@@ -924,6 +934,11 @@ def generate_lam_vals(lambda_base, increment=0.001):
924934
if error is not None:
925935
raise error
926936

937+
# Save sampler statistics alongside the checkpoint.
938+
self._save_sampler_stats(
939+
index, gcmc_sampler, terminal_flip_sampler
940+
)
941+
927942
# Delete all trajectory frames from the Sire system within the
928943
# dynamics object.
929944
dynamics._d._sire_mols.delete_all_frames()
@@ -1213,12 +1228,73 @@ def generate_lam_vals(lambda_base, increment=0.001):
12131228
_logger.error(msg)
12141229
raise RuntimeError(msg)
12151230

1231+
# Save sampler statistics alongside the final checkpoint.
1232+
self._save_sampler_stats(index, gcmc_sampler, terminal_flip_sampler)
1233+
12161234
_logger.success(
12171235
f"{_lam_sym} = {lambda_value:.5f} complete, speed = {speed:.2f} ns day-1"
12181236
)
12191237

12201238
return time
12211239

1240+
def _save_sampler_stats(self, index, gcmc_sampler, terminal_flip_sampler):
1241+
"""
1242+
Save GCMC and terminal flip sampler statistics to a pickle file.
1243+
1244+
Parameters
1245+
----------
1246+
1247+
index : int
1248+
The index of the lambda value.
1249+
1250+
gcmc_sampler : GCMCSampler or None
1251+
The GCMC sampler for this replica.
1252+
1253+
terminal_flip_sampler : TerminalFlipSampler or None
1254+
The terminal flip sampler for this replica.
1255+
"""
1256+
import pickle as _pickle
1257+
1258+
stats = {}
1259+
if gcmc_sampler is not None:
1260+
stats["gcmc"] = gcmc_sampler.get_stats()
1261+
if terminal_flip_sampler is not None:
1262+
stats["terminal_flip"] = [
1263+
terminal_flip_sampler.num_attempted,
1264+
terminal_flip_sampler.num_accepted,
1265+
]
1266+
with open(self._filenames[index]["sampler_stats"], "wb") as f:
1267+
_pickle.dump(stats, f)
1268+
1269+
def _load_sampler_stats(self, index):
1270+
"""
1271+
Load sampler statistics from a pickle file.
1272+
1273+
Parameters
1274+
----------
1275+
1276+
index : int
1277+
The index of the lambda value.
1278+
1279+
Returns
1280+
-------
1281+
1282+
dict or None
1283+
The sampler statistics, or None if the file does not exist.
1284+
"""
1285+
import pickle as _pickle
1286+
from pathlib import Path as _Path
1287+
1288+
path = _Path(self._filenames[index]["sampler_stats"])
1289+
if not path.exists():
1290+
return None
1291+
try:
1292+
with open(path, "rb") as f:
1293+
return _pickle.load(f)
1294+
except Exception as e:
1295+
_logger.warning(f"Could not load sampler stats for index {index}: {e}")
1296+
return None
1297+
12221298
def _minimisation(
12231299
self,
12241300
system,

src/somd2/runner/_samplers/_terminal_flip.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,3 +539,19 @@ def acceptance_rate(self):
539539
if self._num_attempted == 0:
540540
return 0.0
541541
return self._num_accepted / self._num_attempted
542+
543+
def reset(self, num_attempted=0, num_accepted=0):
544+
"""
545+
Reset the move counters.
546+
547+
Parameters
548+
----------
549+
550+
num_attempted : int
551+
Value to restore ``num_attempted`` to. Defaults to 0.
552+
553+
num_accepted : int
554+
Value to restore ``num_accepted`` to. Defaults to 0.
555+
"""
556+
self._num_attempted = num_attempted
557+
self._num_accepted = num_accepted

0 commit comments

Comments
 (0)