From a7bab9c2fc364e06f3bd241c37056c214a4744cf Mon Sep 17 00:00:00 2001
From: Jaafar Mehrez <32996178+JaafarMehrez@users.noreply.github.com>
Date: Wed, 13 May 2026 08:35:10 +0800
Subject: [PATCH] Add GPUMD backend (https://github.com/brucefan1983/GPUMD)

---
 .trunk/config/ruff.toml              |   5 +
 .trunk/trunk.yaml                    |  30 ++-
 examples/gpumd/metad/ethane-metad.py | 215 ++++++++++++++++++
 examples/gpumd/metad/model.xyz       |  10 +
 examples/gpumd/metad/run.in          |   7 +
 pysages/backends/core.py             |   4 +-
 pysages/backends/gpumd.py            | 338 +++++++++++++++++++++++++++
 7 files changed, 596 insertions(+), 13 deletions(-)
 create mode 100644 .trunk/config/ruff.toml
 create mode 100644 examples/gpumd/metad/ethane-metad.py
 create mode 100644 examples/gpumd/metad/model.xyz
 create mode 100644 examples/gpumd/metad/run.in
 create mode 100644 pysages/backends/gpumd.py

diff --git a/.trunk/config/ruff.toml b/.trunk/config/ruff.toml
new file mode 100644
index 00000000..f5a235cf
--- /dev/null
+++ b/.trunk/config/ruff.toml
@@ -0,0 +1,5 @@
+# Generic, formatter-friendly config.
+select = ["B", "D3", "E", "F"]
+
+# Never enforce `E501` (line length violations). This should be handled by formatters.
+ignore = ["E501"]
diff --git a/.trunk/trunk.yaml b/.trunk/trunk.yaml
index 8dd81899..eb0d7bd4 100644
--- a/.trunk/trunk.yaml
+++ b/.trunk/trunk.yaml
@@ -3,7 +3,7 @@ runtimes:
   enabled:
     - go@1.21.0
     - node@22.16.0
-    - python@3.10.8
+    - python@3.14.4
 actions:
   enabled:
     - trunk-announce
@@ -16,23 +16,29 @@ cli:
 plugins:
   sources:
     - id: trunk
-      ref: v1.7.3
+      ref: v1.9.0
       uri: https://github.com/trunk-io/plugins
 lint:
   enabled:
-    - oxipng@9.1.5
-    - yamllint@1.37.1
-    - cspell@9.2.2
-    - svgo@4.0.0
-    - actionlint@1.7.8
-    - black@25.9.0
+    - bandit@1.9.4
+    - checkov@3.2.528
+    - grype@0.112.0
+    - osv-scanner@2.3.8
+    - ruff@0.15.12
+    - trufflehog@3.95.3
+    - oxipng@10.1.1
+    - yamllint@1.38.0
+    - cspell@9.7.0
+    - svgo@4.0.1
+    - actionlint@1.7.12
+    - black@26.3.1
     - flake8@7.3.0
     - git-diff-check@SYSTEM
-    - gitleaks@8.28.0
+    - gitleaks@8.30.1
     - hadolint@2.14.0
-    - isort@7.0.0
-    - markdownlint@0.45.0
-    - prettier@3.6.2
+    - isort@8.0.1
+    - markdownlint@0.48.0
+    - prettier@3.8.3
     - shellcheck@0.11.0
     - shfmt@3.6.0
     - taplo@0.10.0
diff --git a/examples/gpumd/metad/ethane-metad.py b/examples/gpumd/metad/ethane-metad.py
new file mode 100644
index 00000000..9e4ea03a
--- /dev/null
+++ b/examples/gpumd/metad/ethane-metad.py
@@ -0,0 +1,215 @@
+#!/usr/bin/env python3
+"""
+Well-tempered metadynamics of ethane dihedral angle: PySAGES + GPUMD
+
+Author: Jaafar Mehrez
+(Shanghai Jiao Tong University, Shanghai, China;
+ HPQC Labs, Waterloo, Canada;
+ jaafarmehrez@sjtu.edu.cn, jaafar@hpqc.org)
+
+This script uses *well-tempered* metadynamics to compute the free energy
+surface (FES) along the H-C-C-H dihedral angle of ethane.
+
+For background on ethane conformations, see:
+https://chem.libretexts.org/Courses/Athabasca_University/...
+
+Before running:
+1. Build gpumd.so: cd GPUMD/src && make pygpumd
+2. Ensure gpumd.so is on PYTHONPATH
+3. Have a GPUMD simulation directory with run.in and model.xyz
+
+Usage:
+    python ethane-metad.py
+
+SPDX-License-Identifier: MIT
+"""
+
+import os
+import sys
+import time
+
+import numpy as np
+
+# Ensure the compiled GPUMD module is importable
+_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+_GPUMD_SRC = os.path.join(_SCRIPT_DIR, "GPUMD", "src")
+if os.path.isdir(_GPUMD_SRC) and _GPUMD_SRC not in sys.path:
+    sys.path.insert(0, _GPUMD_SRC)
+
+import gpumd
+import jax
+
+jax.config.update("jax_enable_x64", True)
+
+import pysages
+from pysages.approxfun import compute_mesh
+from pysages.backends.core import SamplingContext
+from pysages.colvars import DihedralAngle
+from pysages.methods import MetaDLogger, Metadynamics
+from pysages.methods.core import Result
+
+# Simulation setup
+SIMULATION_DIR = "/path/to/your/gpumd/simulation"
+RUN_IN_PATH = os.path.join(SIMULATION_DIR, "run.in")
+
+if not os.path.isfile(RUN_IN_PATH):
+    raise FileNotFoundError(
+        f"Cannot find {RUN_IN_PATH}. Please create a GPUMD simulation directory first."
+    )
+
+
+def generate_simulation(**kwargs):
+    """Return a GPUMD simulation object (backend context)."""
+    os.chdir(SIMULATION_DIR)
+    return gpumd.Simulation(RUN_IN_PATH)
+
+
+# ---------------------------------------------------------------------------
+# Collective variable: H-C-C-H dihedral angle
+#
+#   From model.xyz (ethane, 8 atoms):
+#     0 C (carbon 1)
+#     1 H (hydrogen on C1)
+#     2 H (hydrogen on C1)
+#     3 H (hydrogen on C1)
+#     4 C (carbon 2)
+#     5 H (hydrogen on C2)
+#     6 H (hydrogen on C2)
+#     7 H (hydrogen on C2)
+#
+# Dihedral angle: H(1) -- C(0) -- C(4) -- H(5)
+# ---------------------------------------------------------------------------
+
+pi = np.pi
+cvs = [DihedralAngle([1, 0, 4, 5])]
+
+# Well-tempered metadynamics parameters
+
+height = 0.02  # Initial Gaussian height in eV (GPUMD energy unit)
+sigma = [0.3]  # Gaussian width in radians
+stride = 100  # Deposit a hill every 100 steps
+timesteps = 1_000_000  # Total simulation steps
+ngauss = timesteps // stride + 1
+deltaT = 1500.0  # Fictitious temperature in Kelvin (5x 300 K)
+kB = 8.617333262e-5  # Boltzmann constant in eV/K
+
+grid = pysages.Grid(
+    lower=(-pi,),
+    upper=(pi,),
+    shape=(200,),
+    periodic=True,
+)
+method = Metadynamics(
+    cvs,
+    height,
+    sigma,
+    stride,
+    ngauss,
+    grid=grid,
+    deltaT=deltaT,
+    kB=kB,
+)
+hills_file = "hills_ethane_wt.dat"
+callback = MetaDLogger(hills_file, stride)
+
+print("Starting ethane dihedral well-tempered metadynamics...")
+print(" CV: dihedral angle H(1)-C(0)-C(4)-H(5)")
+print(f" Grid: [{-pi:.3f}, {pi:.3f}] rad (periodic)")
+print(f" Hills: height={height} eV, sigma={sigma[0]} rad, stride={stride}")
+print(f" Well-tempered: deltaT={deltaT} K, kB={kB:.4e} eV/K")
+print(f" Total steps: {timesteps} -> ~{ngauss} hills")
+
+sim = generate_simulation()
+tic = time.perf_counter()
+
+sampling_context = SamplingContext(method, lambda: sim, callback)
+with sampling_context:
+    sampling_context.run(timesteps)
+    sampler = sampling_context.sampler
+toc = time.perf_counter()
+print(f"Completed in {toc - tic:0.1f} seconds.")
+
+if hasattr(sampler, "print_timings"):
+    sampler.print_timings()
+
+run_result = Result(
+    method,
+    [sampler.state],
+    None if sampler.callback is None else [sampler.callback],
+    [sampler.take_snapshot()],
+)
+
+T = 300.0
+plot_grid = pysages.Grid(
+    lower=(-pi,),
+    upper=(pi,),
+    shape=(400,),
+    periodic=True,
+)
+xi = compute_mesh(plot_grid)
+result = pysages.analyze(run_result)
+metapotential = result["metapotential"]
+alpha = 1.0 + T / deltaT
+A = -alpha * metapotential(xi)
+A = A - A.min()
+
+output = np.column_stack((xi.flatten() * 180 / pi, A.flatten()))
+np.savetxt(
+    "fes_ethane_wt.dat",
+    output,
+    header="dihedral_angle_deg free_energy_eV",
+    comments="",
+    fmt="%.6f",
+)
+print("Free energy surface saved to fes_ethane_wt.dat")
+
+try:
+    import matplotlib.pyplot as plt
+
+    fig, ax = plt.subplots(figsize=(8, 5))
+    ax.plot(xi.flatten() * 180 / pi, A.flatten(), lw=2, color="steelblue")
+    ax.axhline(y=0, color="gray", ls="--", lw=0.5)
+    ax.annotate(
+        "staggered (min)",
+        xy=(60, 0),
+        xytext=(60, 0.05),
+        ha="center",
+        fontsize=9,
+        color="green",
+    )
+    ax.annotate(
+        "eclipsed (max)",
+        xy=(0, A.max()),
+        xytext=(0, A.max() + 0.02),
+        ha="center",
+        fontsize=9,
+        color="red",
+    )
+
+    ax.set_xlabel(r"Dihedral angle $\phi$ (degrees)")
+    ax.set_ylabel(r"Free energy $\Delta G$ (eV)")
+    ax.set_title("Ethane rotational free energy (well-tempered metadynamics)")
+    ax.set_xlim(-180, 180)
+    ax.set_xticks([-180, -120, -60, 0, 60, 120, 180])
+
+    fig.tight_layout()
+    fig.savefig("fes_ethane_wt.png", dpi=150)
+    print("Plot saved to fes_ethane_wt.png")
+except ImportError:
+    print("matplotlib not available; skipping plot.")
+
+# Sanity check: barrier height
+barrier = A.max() - A.min()
+print(f"\nFES barrier height: {barrier:.4f} eV")
+print(" (Literature value for ethane: ~0.12 eV = ~12 kJ/mol)")
+print(f" Well-tempered scaling factor alpha = {alpha:.3f}")
+
+if barrier < 0.05:
+    print("WARNING: barrier seems very low. Try increasing timesteps or")
+    print(" decreasing hill height/sigma for better resolution.")
+elif barrier > 0.5:
+    print("WARNING: barrier seems very high. Check units or hill parameters.")
+else:
+    print("Barrier height is in a physically reasonable range.")
+
+print("\nDone.")
diff --git a/examples/gpumd/metad/model.xyz b/examples/gpumd/metad/model.xyz
new file mode 100644
index 00000000..738d00b5
--- /dev/null
+++ b/examples/gpumd/metad/model.xyz
@@ -0,0 +1,10 @@
+8
+Lattice="15.0 0.0 0.0 0.0 15.0 0.0 0.0 0.0 15.0" Properties=species:S:1:pos:R:3
+C -2.24100 0.60300 0.00000
+H -1.88500 -0.40500 0.00000
+H -1.88500 1.10800 -0.87400
+H -3.31100 0.60400 0.00000
+C -1.72800 1.32900 1.25700
+H -2.08300 2.33900 1.25600
+H -2.08600 0.82600 2.13100
+H -0.65800 1.32800 1.25800
diff --git a/examples/gpumd/metad/run.in b/examples/gpumd/metad/run.in
new file mode 100644
index 00000000..d2889e93
--- /dev/null
+++ b/examples/gpumd/metad/run.in
@@ -0,0 +1,7 @@
+potential ./nep.txt
+velocity 300
+
+ensemble nvt_lan 300 300 200
+time_step 0.5
+dump_thermo 1000
+dump_position 5000
diff --git a/pysages/backends/core.py b/pysages/backends/core.py
index ad80a776..6df5843c 100644
--- a/pysages/backends/core.py
+++ b/pysages/backends/core.py
@@ -38,6 +38,8 @@ def __init__(
             self._backend_name = "lammps"
         elif module_name.startswith("simtk.openmm") or module_name.startswith("openmm"):
             self._backend_name = "openmm"
+        elif module_name.startswith("gpumd"):
+            self._backend_name = "gpumd"
 
         if self._backend_name is None:
             backends = ", ".join(supported_backends())
@@ -74,4 +76,4 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
 
 
 def supported_backends():
-    return ("ase", "hoomd", "jax-md", "lammps", "openmm")
+    return ("ase", "gpumd", "hoomd", "jax-md", "lammps", "openmm")
diff --git a/pysages/backends/gpumd.py b/pysages/backends/gpumd.py
new file mode 100644
index 00000000..ff7ccdb7
--- /dev/null
+++ b/pysages/backends/gpumd.py
@@ -0,0 +1,338 @@
+# SPDX-License-Identifier: MIT
+# See LICENSE.md and CONTRIBUTORS.md at https://github.com/SSAGESLabs/PySAGES
+#
+# Author: Jaafar Mehrez
+#   (Shanghai Jiao Tong University, Shanghai, China;
+#    HPQC Labs, Waterloo, Canada;
+#    jaafarmehrez@sjtu.edu.cn, jaafar@hpqc.org)
+
+"""
+PySAGES backend for GPUMD.
+
+This backend assumes that GPUMD has been compiled with a pybind11 Python wrapper
+that exposes the simulation data as DLPack capsules and supports a per-step
+Python callback.
+
+The wrapper module is expected to be importable as ``gpumd``.
+"""
+
+import time
+from functools import partial
+
+from jax import jit
+from jax import numpy as np
+from jax.dlpack import from_dlpack
+
+from pysages.backends.core import SamplingContext
+from pysages.backends.snapshot import (
+    Box,
+    HelperMethods,
+    Snapshot,
+    SnapshotMethods,
+    build_data_querier,
+)
+from pysages.backends.snapshot import restore as _restore
+from pysages.typing import Callable, Optional
+from pysages.utils import copy, identity
+
+# The gpumd package does not exist yet; importing it here will raise a clear
+# error when a user tries to use the backend without installing the wrapper.
+try:
+    import gpumd  # noqa: F401
+except ImportError as err:
+    raise ImportError(
+        "The gpumd Python package is required for the GPUMD backend. "
+        "Please build and install the GPUMD pybind11 wrapper."
+    ) from err
+
+
+class Sampler:
+    """
+    GPUMD sampler that connects a PySAGES sampling method to a GPUMD simulation.
+
+    Parameters
+    ----------
+    simulation: gpumd.Simulation
+        The wrapped GPUMD simulation instance.
+
+    sampling_method: pysages.methods.SamplingMethod
+        The enhanced-sampling method to use.
+
+    callback: Optional[Callable]
+        Optional user callback for logging or analysis.
+    """
+
+    # Number of MD steps between box refreshes when the box is not constant.
+    _BOX_REFRESH_STRIDE = 100
+
+    def __init__(
+        self,
+        simulation,
+        sampling_method,
+        callback: Optional[Callable],
+    ):
+        self.simulation = simulation
+        self.callback = callback
+
+        # Build initial snapshot and sampling method
+        initial_snapshot = self._take_snapshot()
+        helpers, restore, bias = build_helpers(simulation, sampling_method)
+        _, initialize, method_update = sampling_method.build(initial_snapshot, helpers)
+
+        self.snapshot = initial_snapshot
+        self.state = initialize()
+        self._restore = restore
+        self._bias = bias
+        self._method_update = method_update
+
+        self._cached_masses = initial_snapshot.vel_mass[1]
+        self._cached_ids = initial_snapshot.ids
+        self._cached_box = initial_snapshot.box
+        self._cached_dt = initial_snapshot.dt
+
+        self._box_is_constant = simulation.is_box_constant()
+
+        self._timings = {
+            "snapshot": 0.0,
+            "update": 0.0,
+            "bias": 0.0,
+            "callback": 0.0,
+            "total": 0.0,
+        }
+        self._timing_count = 0
+
+        simulation.clear_external_bias()
+        simulation.set_step_callback(self._update)
+
+    def _update(self, timestep: int):
+        """
+        Callback executed by GPUMD every timestep.
+        """
+        t0 = time.perf_counter()
+
+        self.snapshot = self._build_snapshot_with_fresh_arrays(timestep)
+        t1 = time.perf_counter()
+
+        self.state = self._method_update(self.snapshot, self.state)
+        t2 = time.perf_counter()
+        self._bias(self.snapshot, self.state)
+        t3 = time.perf_counter()
+        if self.callback is not None:
+            self.callback(self.snapshot, self.state, timestep)
+        t4 = time.perf_counter()
+
+        self._timings["snapshot"] += t1 - t0
+        self._timings["update"] += t2 - t1
+        self._timings["bias"] += t3 - t2
+        self._timings["callback"] += t4 - t3
+        self._timings["total"] += t4 - t0
+        self._timing_count += 1
+
+    def restore(self, prev_snapshot):
+        """Invoke the backend restore helper on the current snapshot and ``prev_snapshot``."""
+        self._restore(self.snapshot, prev_snapshot)
+
+    def take_snapshot(self):
+        """Return a copy of the most recently built snapshot."""
+        return copy(self.snapshot)
+
+    def print_timings(self):
+        """Print the accumulated per-stage wall-clock timings of ``_update``."""
+        n = self._timing_count
+        if n == 0:
+            print("[gpumd backend] No timing data collected yet.")
+            return
+        print("[gpumd backend] Per-step timing breakdown (ms):")
+        print(f" {'Stage':<12} {'Total (ms)':<14} {'Per-step (ms)':<14} {'% of total':<10}")
+        print(" " + "-" * 52)
+        total_ms = self._timings["total"] * 1000.0
+        for key in ["snapshot", "update", "bias", "callback"]:
+            stage_ms = self._timings[key] * 1000.0
+            per_ms = stage_ms / n
+            pct = (stage_ms / total_ms * 100.0) if total_ms > 0 else 0.0
+            print(f" {key:<12} {stage_ms:>10.2f} {per_ms:>10.4f} {pct:>6.2f}")
+        print(" " + "-" * 52)
+        print(f" {'total':<12} {total_ms:>10.2f} {total_ms / n:>10.4f} {'100.00':>6}")
+        print(f" Steps counted: {n}")
+
+    def _build_snapshot_with_fresh_arrays(self, timestep: int):
+        """
+        Rebuild snapshot with fresh DLPack for positions/velocities/forces.
+
+        Constant data (ids, dt, masses) is reused from the cache built during
+        ``__init__``. The simulation box is also cached, but for NPT or
+        change_box runs it may vary. We query ``sim.is_box_constant()`` once
+        during ``__init__``; if the box is constant we never call ``get_box()``
+        again, otherwise we refresh it every ``_BOX_REFRESH_STRIDE`` steps.
+        """
+        sim = self.simulation
+        positions = from_dlpack(sim.get_positions_dlpack()).T
+        velocities = from_dlpack(sim.get_velocities_dlpack()).T
+        forces = from_dlpack(sim.get_forces_dlpack()).T
+        vel_mass = (velocities, self._cached_masses)
+
+        if not self._box_is_constant and timestep % self._BOX_REFRESH_STRIDE == 0:
+            # Rebuild unconditionally: comparing against the cached matrix forces a
+            # host sync and would miss changes that only affect the origin.
+            self._cached_box = self._query_box()
+
+        return Snapshot(
+            positions,
+            vel_mass,
+            forces,
+            self._cached_ids,
+            self._cached_box,
+            self._cached_dt,
+        )
+
+    def _query_box(self):
+        """
+        Build a ``Box`` from the 9 cell entries and origin returned by ``get_box()``.
+        """
+        h, origin = self.simulation.get_box()
+        H = (
+            (h[0], h[1], h[2]),
+            (h[3], h[4], h[5]),
+            (h[6], h[7], h[8]),
+        )
+        return Box(H, origin)
+
+    def _take_snapshot(self):
+        """
+        Construct a full PySAGES Snapshot from the current GPUMD state.
+
+        Called once during ``Sampler.__init__`` to build the initial snapshot
+        and populate the constant-data cache. During normal MD
+        ``_build_snapshot_with_fresh_arrays`` is used instead so that JAX
+        sees updated GPU values each step.
+        """
+        sim = self.simulation
+        positions = from_dlpack(sim.get_positions_dlpack()).T
+        velocities = from_dlpack(sim.get_velocities_dlpack()).T
+        forces = from_dlpack(sim.get_forces_dlpack()).T
+        masses = from_dlpack(sim.get_masses_dlpack())
+        types = from_dlpack(sim.get_types_dlpack())
+        vel_mass = (velocities, masses)
+        ids = np.arange(types.size)
+        box = self._query_box()
+        dt = sim.get_timestep()
+
+        return Snapshot(positions, vel_mass, forces, ids, box, dt)
+
+
+def build_snapshot_methods(context, sampling_method):
+    """
+    Build methods for retrieving snapshot properties in a format useful for
+    collective variable calculations.
+    """
+
+    if sampling_method.requires_box_unwrapping:
+
+        def positions(snapshot):
+            # NOTE(review): shifts every particle to the periodic image closest to
+            # particle 0 using only the diagonal of H; presumably this assumes an
+            # orthorhombic box -- confirm for triclinic cells.
+            pos = snapshot.positions[:, :3]
+            L = np.diag(snapshot.box.H)
+            ref = pos[0]
+            delta = pos - ref
+            images = np.rint(delta / L)
+            return pos - L * images
+
+    else:
+
+        def positions(snapshot):
+            return snapshot.positions
+
+    @jit
+    def indices(snapshot):
+        return snapshot.ids
+
+    @jit
+    def momenta(snapshot):
+        velocities, masses = snapshot.vel_mass
+        # Make masses a column so they broadcast per particle against the
+        # velocities, whether the wrapper exposes them as (N,) or (N, 1).
+        return (masses.reshape(-1, 1) * velocities).flatten()
+
+    @jit
+    def masses(snapshot):
+        return snapshot.vel_mass[1]
+
+    return SnapshotMethods(jit(positions), indices, momenta, masses)
+
+
+def build_helpers(context, sampling_method):
+    """
+    Builds helper methods used for restoring snapshots and biasing a simulation.
+    """
+    # Function-scope import: the helper module is only loaded when helpers are built.
+    from pysages.backends.utils import cupy_helpers
+
+    _, view = cupy_helpers()  # the first returned helper (force sync) is unused here
+
+    def restore_vm(view, snapshot, prev_snapshot):
+        """Copy velocities and masses from ``prev_snapshot`` into ``snapshot``."""
+        velocities = view(snapshot.vel_mass[0])
+        masses = view(snapshot.vel_mass[1])
+        prev_velocities = view(prev_snapshot.vel_mass[0])
+        prev_masses = view(prev_snapshot.vel_mass[1])
+        velocities[:] = prev_velocities
+        masses[:] = prev_masses
+
+    import jax.dlpack
+
+    def bias(snapshot, state):
+        """
+        Adds the computed bias to GPUMD's force_per_atom via a custom CUDA kernel.
+
+        We intentionally skip ``block_until_ready()`` here because the
+        C++ side calls ``cudaDeviceSynchronize()`` after launching the kernel,
+        which waits for all GPU streams (including JAX's) to finish.
+        """
+        if state.bias is None:
+            return
+
+        bias_arr = state.bias
+        if bias_arr.dtype != np.float64:
+            bias_arr = bias_arr.astype(np.float64)
+
+        context.add_aos_bias_to_forces(jax.dlpack.to_dlpack(bias_arr.flatten()))
+
+    def dimensionality():
+        return 3
+
+    snapshot_methods = build_snapshot_methods(context, sampling_method)
+    flags = sampling_method.snapshot_flags
+    restore = partial(_restore, view, restore_vm=restore_vm)
+    helpers = HelperMethods(build_data_querier(snapshot_methods, flags), dimensionality)
+
+    return helpers, restore, bias
+
+
+def bind(sampling_context: SamplingContext, callback: Optional[Callable] = None, **kwargs):
+    """
+    Bind a PySAGES sampling method to a GPUMD simulation.
+
+    Parameters
+    ----------
+    sampling_context: pysages.backends.core.SamplingContext
+        The PySAGES sampling context, whose ``.context`` attribute is a
+        ``gpumd.Simulation`` instance.
+
+    callback: Optional[Callable]
+        User callback executed after the bias has been applied each step.
+
+    Returns
+    -------
+    Sampler
+        The sampler object managing the GPUMD / PySAGES integration.
+    """
+    identity(kwargs)  # reserved for future options
+
+    simulation = sampling_context.context
+    sampling_method = sampling_context.method
+    sampler = Sampler(simulation, sampling_method, callback)
+    sampling_context.run = simulation.run
+
+    return sampler