@@ -106,6 +106,8 @@ def __init__(
106106 self ._openmm_states = [None ] * len (lambdas )
107107 self ._gcmc_samplers = [None ] * len (lambdas )
108108 self ._gcmc_states = [None ] * len (lambdas )
109+ self ._gcmc_stats = [None ] * len (lambdas )
110+ self ._terminal_flip_stats = [[0 , 0 ]] * len (lambdas )
109111 self ._num_proposed = _np .matrix (_np .zeros ((len (lambdas ), len (lambdas ))))
110112 self ._num_accepted = _np .matrix (_np .zeros ((len (lambdas ), len (lambdas ))))
111113 self ._num_swaps = _np .matrix (_np .zeros ((len (lambdas ), len (lambdas ))))
@@ -130,6 +132,14 @@ def __setstate__(self, state):
130132 for key , value in state .items ():
131133 setattr (self , key , value )
132134
135+ # Provide defaults for attributes added after the initial release,
136+ # so that old checkpoint files can still be loaded.
137+ n = len (self ._lambdas )
138+ if not hasattr (self , "_gcmc_stats" ):
139+ self ._gcmc_stats = [None ] * n
140+ if not hasattr (self , "_terminal_flip_stats" ):
141+ self ._terminal_flip_stats = [[0 , 0 ]] * n
142+
133143 def __getstate__ (self ):
134144 """
135145 Get the state of the object.
@@ -145,6 +155,8 @@ def __getstate__(self):
145155 # Don't pickle the GCMC samplers since they need to be recreated.
146156 "_gcmc_samplers" : len (self ._gcmc_samplers ) * [None ],
147157 "_gcmc_states" : self ._gcmc_states ,
158+ "_gcmc_stats" : self ._gcmc_stats ,
159+ "_terminal_flip_stats" : self ._terminal_flip_stats ,
148160 "_num_proposed" : self ._num_proposed ,
149161 "_num_accepted" : self ._num_accepted ,
150162 "_num_swaps" : self ._num_swaps ,
@@ -823,7 +835,7 @@ def __init__(self, system, config):
823835 state = self ._dynamics_cache ._states [i ]
824836 dynamics .context ().setState (self ._dynamics_cache ._openmm_states [state ])
825837
826- # Reset the GCMC water state.
838+ # Reset the GCMC water state and restore statistics .
827839 if gcmc_sampler is not None :
828840 gcmc_sampler .push ()
829841 try :
@@ -834,6 +846,13 @@ def __init__(self, system, config):
834846 )
835847 finally :
836848 gcmc_sampler .pop ()
849+ if self ._dynamics_cache ._gcmc_stats [i ] is not None :
850+ gcmc_sampler .restore_stats (self ._dynamics_cache ._gcmc_stats [i ])
851+
852+ # Restore terminal flip sampler statistics.
853+ if self ._terminal_flip_samplers is not None :
854+ attempted , accepted = self ._dynamics_cache ._terminal_flip_stats [i ]
855+ self ._terminal_flip_samplers [i ].reset (attempted , accepted )
837856
838857 # Conversion factor for reduced potential.
839858 kT = (_sr .units .k_boltz * self ._config .temperature ).to (_sr .units .kcal_per_mol )
@@ -1190,6 +1209,7 @@ def run(self):
11901209
11911210 # Pickle the dynamics cache.
11921211 _logger .info ("Saving replica exchange state" )
1212+ self ._save_sampler_stats ()
11931213 with open (self ._repex_state , "wb" ) as f :
11941214 _pickle .dump (self ._dynamics_cache , f )
11951215
@@ -1211,6 +1231,11 @@ def run(self):
12111231
12121232 # Pickle final state of the dynamics cache.
12131233 _logger .info ("Saving final replica exchange state" )
1234+ if self ._terminal_flip_samplers is not None :
1235+ self ._dynamics_cache ._terminal_flip_stats = [
1236+ [s .num_attempted , s .num_accepted ]
1237+ for s in self ._terminal_flip_samplers
1238+ ]
12141239 with open (self ._repex_state , "wb" ) as f :
12151240 _pickle .dump (self ._dynamics_cache , f )
12161241
@@ -1842,6 +1867,21 @@ def _mix_replicas(num_replicas, energy_matrix, proposed, accepted):
18421867
18431868 return states
18441869
1870+ def _save_sampler_stats (self ):
1871+ """
1872+ Save GCMC and terminal flip sampler statistics to the dynamics cache
1873+ prior to pickling.
1874+ """
1875+ for i in range (len (self ._lambda_values )):
1876+ _ , gcmc_sampler = self ._dynamics_cache .get (i )
1877+ if gcmc_sampler is not None :
1878+ self ._dynamics_cache ._gcmc_stats [i ] = gcmc_sampler .get_stats ()
1879+
1880+ if self ._terminal_flip_samplers is not None :
1881+ self ._dynamics_cache ._terminal_flip_stats = [
1882+ [s .num_attempted , s .num_accepted ] for s in self ._terminal_flip_samplers
1883+ ]
1884+
18451885 def _save_transition_matrix (self ):
18461886 """
18471887 Internal method to save the replica exchange transition matrix.
0 commit comments