Skip to content

Commit ed143a0

Browse files
authored
Merge pull request #82 from OpenBioSim/fix_memory_footprint
Add per-device repex memory footprint error and warning
2 parents c04695e + 62e4b08 commit ed143a0

2 files changed

Lines changed: 127 additions & 0 deletions

File tree

environment.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@ dependencies:
1111
- loch
1212
- loguru
1313
- numba
14+
- nvidia-ml-py
1415
- versioningit

src/somd2/runner/_repex.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,21 +200,46 @@ def _create_dynamics(
200200
is not used.
201201
"""
202202

203+
from math import floor
204+
203205
# Copy the dynamics keyword arguments.
204206
dynamics_kwargs = dynamics_kwargs.copy()
205207

208+
# Store the number of replicas.
209+
num_replicas = len(lambdas)
210+
206211
# Copy the GCMC keyword arguments.
207212
if gcmc_kwargs is not None:
208213
gcmc_kwargs = gcmc_kwargs.copy()
209214

210215
# Initialise the dynamics object list.
211216
self._dynamics = []
212217

218+
# A set of visited device indices.
219+
devices = set()
220+
221+
# Determine whether there is a remainder in the number of replicas.
222+
remainder = num_replicas % num_gpus
223+
224+
# Store the number of contexts for each device. The last device will
225+
# have remainder contexts, while all others have
226+
contexts_per_device = num_replicas * [floor(num_replicas / num_gpus)]
227+
228+
# Set the last device to have the remainder contexts.
229+
contexts_per_device[-1] = remainder
230+
213231
# Create the dynamics objects in serial.
214232
for i, (lam, scale) in enumerate(zip(lambdas, rest2_scale_factors)):
215233
# Work out the device index.
216234
device = i % num_gpus
217235

236+
# If we've not seen this device before then get the memory statistics
237+
# prior to creating the dynamics object and GCMC sampler.
238+
if device not in devices:
239+
used_mem_before, free_mem_before, total_mem = self._check_device_memory(
240+
device
241+
)
242+
218243
# This is a restart, get the system for this replica.
219244
if isinstance(system, list):
220245
mols = system[i]
@@ -284,6 +309,39 @@ def _create_dynamics(
284309
# Append the dynamics object.
285310
self._dynamics.append(dynamics)
286311

312+
# Check the memory footprint for this device.
313+
if not device in devices:
314+
# Add the device to the set of visited devices.
315+
devices.add(device)
316+
317+
# Get the current memory usage.
318+
used_mem, free_mem, total_mem = self._check_device_memory(device)
319+
320+
# Work out the memory used by this dynamics object and GCMC sampler.
321+
mem_used = used_mem - used_mem_before
322+
323+
# Work out the estimate for all replicas on this device.
324+
est_total = mem_used * contexts_per_device[device]
325+
326+
# If this exceeds the total memory, raise an error.
327+
if est_total > total_mem:
328+
msg = (
329+
f"Not enough memory on device {device} for all assigned replicas. "
330+
f"Estimated memory usage: {est_total / 1e9:.2f} GB, "
331+
f"Available memory: {total_mem / 1e9:.2f} GB."
332+
)
333+
_logger.error(msg)
334+
raise MemoryError(msg)
335+
336+
# If there's less than 20% free memory, raise a warning.
337+
elif ((total_mem - est_total) / total_mem) < 0.2:
338+
_logger.warning(
339+
f"Device {device} will have less than 20% free memory "
340+
f"after creating all assigned replicas. "
341+
f"{est_total / 1e9:.2f} GB, "
342+
f"Available memory: {total_mem / 1e9:.2f} GB."
343+
)
344+
287345
_logger.info(
288346
f"Created dynamics object for lambda {lam:.5f} on device {device}"
289347
)
@@ -447,6 +505,30 @@ def get_swaps(self):
447505
"""
448506
return self._num_swaps
449507

508+
def _check_device_memory(self, index):
509+
"""
510+
Check the memory usage of the specified CUDA device.
511+
512+
Parameters
513+
----------
514+
515+
index: int
516+
The index of the CUDA device.
517+
"""
518+
from pynvml import (
519+
nvmlInit,
520+
nvmlShutdown,
521+
nvmlDeviceGetHandleByIndex,
522+
nvmlDeviceGetMemoryInfo,
523+
)
524+
525+
nvmlInit()
526+
handle = nvmlDeviceGetHandleByIndex(index)
527+
info = nvmlDeviceGetMemoryInfo(handle)
528+
result = (info.used, info.free, info.total)
529+
nvmlShutdown()
530+
return result
531+
450532

451533
class RepexRunner(_RunnerBase):
452534
"""
@@ -1143,6 +1225,50 @@ def _minimise(self, index):
11431225
# Minimise.
11441226
dynamics.minimise(timeout=self._config.timeout)
11451227

1228+
# If we're not equilibrating and the production constraints will change,
1229+
# then we need to rebuild the context.
1230+
if not self._is_equilibration:
1231+
constraints_changed = (
1232+
self._initial_constraint != self._config.constraint
1233+
) or (
1234+
self._initial_perturbable_constraint
1235+
!= self._config.perturbable_constraint
1236+
)
1237+
1238+
if constraints_changed:
1239+
# Commit the current system.
1240+
system = dynamics.commit()
1241+
1242+
# Delete the dynamics object.
1243+
self._dynamics_cache.delete(index)
1244+
1245+
# Work out the device index.
1246+
device = index % self._num_gpus
1247+
1248+
# Copy the dynamics keyword arguments.
1249+
dynamics_kwargs = self._dynamics_kwargs.copy()
1250+
1251+
# Overload the device and lambda value.
1252+
dynamics_kwargs["device"] = device
1253+
dynamics_kwargs["lambda_value"] = self._lambda_values[index]
1254+
dynamics_kwargs["rest2_scale"] = self._rest2_scale_factors[index]
1255+
1256+
# Create the production dynamics object.
1257+
dynamics = system.dynamics(**dynamics_kwargs)
1258+
1259+
# Reset the GCMC water state. The dynamics object is created from
1260+
# the original Sire system, so the water state in the context does
1261+
# not match the current GCMC water state.
1262+
if gcmc_sampler is not None:
1263+
self._reset_gcmc_sampler(gcmc_sampler, dynamics)
1264+
1265+
# Set the new dynamics object.
1266+
self._dynamics_cache.set(index, dynamics)
1267+
1268+
_logger.info(
1269+
f"Created dynamics object for {_lam_sym} = {self._lambda_values[index]:.5f}"
1270+
)
1271+
11461272
except Exception as e:
11471273
return False, index, e
11481274

0 commit comments

Comments
 (0)