@@ -200,21 +200,46 @@ def _create_dynamics(
200200 is not used.
201201 """
202202
203+ from math import floor
204+
203205 # Copy the dynamics keyword arguments.
204206 dynamics_kwargs = dynamics_kwargs .copy ()
205207
208+ # Store the number of replicas.
209+ num_replicas = len (lambdas )
210+
206211 # Copy the GCMC keyword arguments.
207212 if gcmc_kwargs is not None :
208213 gcmc_kwargs = gcmc_kwargs .copy ()
209214
210215 # Initialise the dynamics object list.
211216 self ._dynamics = []
212217
218+ # A set of visited device indices.
219+ devices = set ()
220+
221+ # Determine whether there is a remainder in the number of replicas.
222+ remainder = num_replicas % num_gpus
223+
224+ # Store the number of contexts for each device. The last device will
225+ # have the remainder contexts, while all others have floor(num_replicas / num_gpus).
226+ contexts_per_device = num_replicas * [floor (num_replicas / num_gpus )]
227+
228+ # Set the last device to have the remainder contexts.
229+ contexts_per_device [- 1 ] = remainder
230+
213231 # Create the dynamics objects in serial.
214232 for i , (lam , scale ) in enumerate (zip (lambdas , rest2_scale_factors )):
215233 # Work out the device index.
216234 device = i % num_gpus
217235
236+ # If we've not seen this device before then get the memory statistics
237+ # prior to creating the dynamics object and GCMC sampler.
238+ if device not in devices :
239+ used_mem_before , free_mem_before , total_mem = self ._check_device_memory (
240+ device
241+ )
242+
218243 # This is a restart, get the system for this replica.
219244 if isinstance (system , list ):
220245 mols = system [i ]
@@ -284,6 +309,39 @@ def _create_dynamics(
284309 # Append the dynamics object.
285310 self ._dynamics .append (dynamics )
286311
312+ # Check the memory footprint for this device.
313+ if not device in devices :
314+ # Add the device to the set of visited devices.
315+ devices .add (device )
316+
317+ # Get the current memory usage.
318+ used_mem , free_mem , total_mem = self ._check_device_memory (device )
319+
320+ # Work out the memory used by this dynamics object and GCMC sampler.
321+ mem_used = used_mem - used_mem_before
322+
323+ # Work out the estimate for all replicas on this device.
324+ est_total = mem_used * contexts_per_device [device ]
325+
326+ # If this exceeds the total memory, raise an error.
327+ if est_total > total_mem :
328+ msg = (
329+ f"Not enough memory on device { device } for all assigned replicas. "
330+ f"Estimated memory usage: { est_total / 1e9 :.2f} GB, "
331+ f"Available memory: { total_mem / 1e9 :.2f} GB."
332+ )
333+ _logger .error (msg )
334+ raise MemoryError (msg )
335+
336+ # If there's less than 20% free memory, raise a warning.
337+ elif ((total_mem - est_total ) / total_mem ) < 0.2 :
338+ _logger .warning (
339+ f"Device { device } will have less than 20% free memory "
340+ f"after creating all assigned replicas. "
341+ f"{ est_total / 1e9 :.2f} GB, "
342+ f"Available memory: { total_mem / 1e9 :.2f} GB."
343+ )
344+
287345 _logger .info (
288346 f"Created dynamics object for lambda { lam :.5f} on device { device } "
289347 )
@@ -447,6 +505,30 @@ def get_swaps(self):
447505 """
448506 return self ._num_swaps
449507
def _check_device_memory(self, index):
    """
    Check the memory usage of the specified CUDA device.

    Parameters
    ----------

    index: int
        The index of the CUDA device.

    Returns
    -------

    result: tuple
        The ``used``, ``free`` and ``total`` memory fields reported by
        NVML for the device, in that order.
    """
    from pynvml import (
        nvmlInit,
        nvmlShutdown,
        nvmlDeviceGetHandleByIndex,
        nvmlDeviceGetMemoryInfo,
    )

    nvmlInit()
    try:
        handle = nvmlDeviceGetHandleByIndex(index)
        info = nvmlDeviceGetMemoryInfo(handle)
        return (info.used, info.free, info.total)
    finally:
        # Always pair nvmlInit with nvmlShutdown, even when the device
        # query raises (e.g. an invalid index), so that the NVML
        # initialisation is not leaked.
        nvmlShutdown()
531+
450532
451533class RepexRunner (_RunnerBase ):
452534 """
@@ -1143,6 +1225,50 @@ def _minimise(self, index):
11431225 # Minimise.
11441226 dynamics .minimise (timeout = self ._config .timeout )
11451227
1228+ # If we're not equilibrating and the production constraints will change,
1229+ # then we need to rebuild the context.
1230+ if not self ._is_equilibration :
1231+ constraints_changed = (
1232+ self ._initial_constraint != self ._config .constraint
1233+ ) or (
1234+ self ._initial_perturbable_constraint
1235+ != self ._config .perturbable_constraint
1236+ )
1237+
1238+ if constraints_changed :
1239+ # Commit the current system.
1240+ system = dynamics .commit ()
1241+
1242+ # Delete the dynamics object.
1243+ self ._dynamics_cache .delete (index )
1244+
1245+ # Work out the device index.
1246+ device = index % self ._num_gpus
1247+
1248+ # Copy the dynamics keyword arguments.
1249+ dynamics_kwargs = self ._dynamics_kwargs .copy ()
1250+
1251+ # Overload the device and lambda value.
1252+ dynamics_kwargs ["device" ] = device
1253+ dynamics_kwargs ["lambda_value" ] = self ._lambda_values [index ]
1254+ dynamics_kwargs ["rest2_scale" ] = self ._rest2_scale_factors [index ]
1255+
1256+ # Create the production dynamics object.
1257+ dynamics = system .dynamics (** dynamics_kwargs )
1258+
1259+ # Reset the GCMC water state. The dynamics object is created from
1260+ # the original Sire system, so the water state in the context does
1261+ # not match the current GCMC water state.
1262+ if gcmc_sampler is not None :
1263+ self ._reset_gcmc_sampler (gcmc_sampler , dynamics )
1264+
1265+ # Set the new dynamics object.
1266+ self ._dynamics_cache .set (index , dynamics )
1267+
1268+ _logger .info (
1269+ f"Created dynamics object for { _lam_sym } = { self ._lambda_values [index ]:.5f} "
1270+ )
1271+
11461272 except Exception as e :
11471273 return False , index , e
11481274
0 commit comments