Skip to content

Commit b617e42

Browse files
authored
Merge pull request #126 from OpenBioSim/fix_124
Fix issue #124
2 parents e5e6cfd + 0e6a2d8 commit b617e42

7 files changed

Lines changed: 38 additions & 29 deletions

File tree

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ repos:
1414

1515
# Python formatting and linting
1616
- repo: https://github.com/astral-sh/ruff-pre-commit
17-
rev: v0.8.4
17+
rev: v0.15.4
1818
hooks:
1919
# Run the formatter
2020
- id: ruff-format

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,12 @@ pip install -e .
7171
> [!Note]
7272
> Pixi does not run conda post-link scripts, so the `ocl-icd-system`
7373
> symlink needed for OpenCL won't be created automatically. After
74-
> creating the environment, run the following once to fix this:
74+
> creating the environment (or after a pixi update), run the following
75+
> to fix this:
7576
>
7677
> ```bash
7778
> pixi shell
78-
> ln -s /etc/OpenCL/vendors "${CONDA_PREFIX}/etc/OpenCL/vendors/ocl-icd-system"
79+
> ln -sfn /etc/OpenCL/vendors "${CONDA_PREFIX}/etc/OpenCL/vendors/ocl-icd-system"
7980
> ```
8081
8182
### Testing

src/somd2/config/_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1813,7 +1813,7 @@ def gcmc_radius(self, gcmc_radius):
18131813
gcmc_r = _sr.u(gcmc_radius)
18141814
except:
18151815
raise ValueError(
1816-
"Unable to parse 'gcmc_radius' " f"as a Sire GeneralUnit: {gcmc_radius}"
1816+
f"Unable to parse 'gcmc_radius' as a Sire GeneralUnit: {gcmc_radius}"
18171817
)
18181818

18191819
if not gcmc_r.has_same_units(angstrom):

src/somd2/io/_io.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def dataframe_to_parquet(df, metadata, filepath=None, filename=None):
7474
table = table.replace_schema_metadata(combined_meta)
7575
if filename is None:
7676
if "lambda" in metadata and "temperature" in metadata:
77-
filename = f"Lam_{metadata['lambda'].replace('.','')[:5]}_T_{metadata['temperature']}.parquet"
77+
filename = f"Lam_{metadata['lambda'].replace('.', '')[:5]}_T_{metadata['temperature']}.parquet"
7878
else:
7979
filename = "output.parquet"
8080
if not filename.endswith(".parquet"):

src/somd2/runner/_base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -764,6 +764,7 @@ def __init__(self, system, config):
764764
"coulomb_power": self._config.coulomb_power,
765765
"shift_coulomb": str(self._config.shift_coulomb),
766766
"shift_delta": str(self._config.shift_delta),
767+
"rest2_selection": self._config.rest2_selection,
767768
"swap_end_states": self._config.swap_end_states,
768769
"tolerance": self._config.gcmc_tolerance,
769770
"restart": self._is_restart,
@@ -1985,8 +1986,8 @@ def _save_energy_components(self, index, context):
19851986
state = new_context.getState(getEnergy=True, groups={i})
19861987
name = f.getName()
19871988
name_len = len(name)
1988-
header += f"{f.getName():>{name_len+2}}"
1989-
record += f"{state.getPotentialEnergy().value_in_unit(openmm.unit.kilocalories_per_mole):>{name_len+2}.2f}"
1989+
header += f"{f.getName():>{name_len + 2}}"
1990+
record += f"{state.getPotentialEnergy().value_in_unit(openmm.unit.kilocalories_per_mole):>{name_len + 2}.2f}"
19901991

19911992
# Write to file.
19921993
if self._nrg_sample == 0:

src/somd2/runner/_repex.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -904,12 +904,12 @@ def run(self):
904904
frac = 1.0
905905
checkpoint_frequency = self._config.energy_frequency
906906

907-
# Store the number of repex cycles per block.
908-
cycles_per_checkpoint = int(frac)
907+
# Store the number of repex cycles per block (may be fractional).
908+
cycles_per_checkpoint = frac
909909

910910
# Otherwise, we don't checkpoint.
911911
else:
912-
cycles_per_checkpoint = cycles
912+
cycles_per_checkpoint = float(cycles)
913913
num_blocks = 1
914914
rem = 0
915915

@@ -991,9 +991,13 @@ def run(self):
991991
else:
992992
cycles_per_gcmc = cycles + 1
993993

994+
# Initialise the threshold for the next checkpoint cycle. This is a float
995+
# to handle non-integer ratios between the checkpoint and energy frequencies.
996+
next_checkpoint = cycles_per_checkpoint
997+
994998
# Perform the replica exchange simulation.
995999
for i in range(cycles):
996-
_logger.info(f"Running dynamics for cycle {i+1} of {cycles}")
1000+
_logger.info(f"Running dynamics for cycle {i + 1} of {cycles}")
9971001

9981002
# Log the states. This is the replica index for the state (positions
9991003
# and velocities) used to seed each replica for the current cycle.
@@ -1007,14 +1011,15 @@ def run(self):
10071011
# Clear the results list.
10081012
results = []
10091013

1010-
# Whether to checkpoint.
1011-
is_checkpoint = i > 0 and i % cycles_per_checkpoint == 0
1014+
# Whether to checkpoint. Use a float threshold to correctly handle
1015+
# non-integer ratios between the checkpoint and energy frequencies.
1016+
is_checkpoint = (i + 1) >= next_checkpoint - 1e-10
10121017

10131018
# Whether to perform a GCMC move before the dynamics block.
1014-
is_gcmc = i % cycles_per_gcmc == 0
1019+
is_gcmc = (i + 1) % cycles_per_gcmc == 0
10151020

10161021
# Whether a frame is saved at the end of the cycle.
1017-
write_gcmc_ghosts = i > 0 and i % cycles_per_frame == 0
1022+
write_gcmc_ghosts = (i + 1) % cycles_per_frame == 0
10181023

10191024
# Run a dynamics block for each replica, making sure only each GPU is only
10201025
# oversubscribed by a factor of self._config.oversubscription_factor.
@@ -1119,6 +1124,9 @@ def run(self):
11191124
# Update the block number.
11201125
block += 1
11211126

1127+
# Advance the checkpoint threshold.
1128+
next_checkpoint += cycles_per_checkpoint
1129+
11221130
# Guard the repex state and transition matrix saving with a file lock.
11231131
lock = _FileLock(self._lock_file)
11241132
with lock.acquire(timeout=self._config.timeout.to("seconds")):
@@ -1248,6 +1256,14 @@ def _run_block(
12481256
# Remove the PyCUDA context from the stack.
12491257
gcmc_sampler.pop()
12501258

1259+
# A frame was saved at the end of the last cycle, so write
1260+
# the current ghost water residue indices to file. This is
1261+
# done here, immediately after the GCMC move, since the
1262+
# sampler state is only updated during GCMC moves and waters
1263+
# may have moved in/out of the GCMC sphere during dynamics.
1264+
if write_gcmc_ghosts:
1265+
gcmc_sampler.write_ghost_residues()
1266+
12511267
# Run the dynamics.
12521268
dynamics.run(
12531269
self._config.energy_frequency,
@@ -1277,18 +1293,9 @@ def _run_block(
12771293
# Save the GCMC state.
12781294
if gcmc_sampler is not None:
12791295
self._dynamics_cache.save_gcmc_state(index)
1280-
# The frame frequency was hit, so write the indices of the
1281-
# current ghost water residues to file.
1282-
if write_gcmc_ghosts:
1283-
gcmc_sampler.write_ghost_residues()
12841296

12851297
# Get the energy at each lambda value.
1286-
energies = (
1287-
dynamics._d.energy_trajectory()
1288-
.to_pandas(to_alchemlyb=True, energy_unit="kcal/mol")
1289-
.iloc[-1, :]
1290-
.to_numpy()
1291-
)
1298+
energies = dynamics._current_energy_array()
12921299

12931300
except Exception as e:
12941301
try:
@@ -1681,7 +1688,7 @@ def _checkpoint(self, index, lambdas, block, num_blocks, is_final_block=False):
16811688
dynamics._d._sire_mols.delete_all_frames()
16821689

16831690
_logger.info(
1684-
f"Finished block {block+1} of {self._start_block + num_blocks} "
1691+
f"Finished block {block + 1} of {self._start_block + num_blocks} "
16851692
f"for {_lam_sym} = {lam:.5f}"
16861693
)
16871694

src/somd2/runner/_runner.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -748,7 +748,7 @@ def generate_lam_vals(lambda_base, increment=0.001):
748748
except:
749749
pass
750750
raise RuntimeError(
751-
f"Dynamics block {block+1} for {_lam_sym} = {lambda_value:.5f} failed: {e}"
751+
f"Dynamics block {block + 1} for {_lam_sym} = {lambda_value:.5f} failed: {e}"
752752
)
753753

754754
# Checkpoint.
@@ -809,7 +809,7 @@ def generate_lam_vals(lambda_base, increment=0.001):
809809
dynamics._d._sire_mols.delete_all_frames()
810810

811811
_logger.info(
812-
f"Finished block {block+1} of {self._start_block + num_blocks} "
812+
f"Finished block {block + 1} of {self._start_block + num_blocks} "
813813
f"for {_lam_sym} = {lambda_value:.5f}"
814814
)
815815

@@ -884,7 +884,7 @@ def generate_lam_vals(lambda_base, increment=0.001):
884884
dynamics._d._sire_mols.delete_all_frames()
885885

886886
_logger.info(
887-
f"Finished block {block+1} of {self._start_block + num_blocks} "
887+
f"Finished block {block + 1} of {self._start_block + num_blocks} "
888888
f"for {_lam_sym} = {lambda_value:.5f}"
889889
)
890890

0 commit comments

Comments
 (0)