Skip to content

Commit 2e47223

Browse files
authored
Merge pull request #142 from OpenBioSim/feature_force_groups
Feature force groups
2 parents 306102c + 8b1e152 commit 2e47223

4 files changed

Lines changed: 97 additions & 75 deletions

File tree

src/somd2/config/_config.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,8 @@ def __init__(
164164
pert_file=None,
165165
auto_fix_minimise=True,
166166
save_crash_report=False,
167-
save_energy_components=False,
167+
save_energy_components=True,
168+
save_xml=False,
168169
page_size=None,
169170
timeout="300 s",
170171
):
@@ -506,7 +507,11 @@ def __init__(
506507
507508
save_energy_components: bool
508509
Whether to save the energy contribution for each force when checkpointing.
509-
This is useful when debugging crashes.
510+
511+
save_xml: bool
512+
Whether to write an XML file for the OpenMM system to the output
513+
directory on startup. This can be useful for debugging or for
514+
use with other tools that can read OpenMM XML files.
510515
511516
page_size: int
512517
The page size for trajectory handling in megabytes. If None, then Sire
@@ -607,6 +612,7 @@ def __init__(
607612
self.auto_fix_minimise = auto_fix_minimise
608613
self.save_crash_report = save_crash_report
609614
self.save_energy_components = save_energy_components
615+
self.save_xml = save_xml
610616
self.timeout = timeout
611617
self.num_energy_neighbours = num_energy_neighbours
612618
self.null_energy = null_energy
@@ -2419,6 +2425,16 @@ def save_energy_components(self, save_energy_components):
24192425
raise ValueError("'save_energy_components' must be of type 'bool'")
24202426
self._save_energy_components = save_energy_components
24212427

2428+
@property
2429+
def save_xml(self):
2430+
return self._save_xml
2431+
2432+
@save_xml.setter
2433+
def save_xml(self, save_xml):
2434+
if not isinstance(save_xml, bool):
2435+
raise ValueError("'save_xml' must be of type 'bool'")
2436+
self._save_xml = save_xml
2437+
24222438
@property
24232439
def page_size(self):
24242440
return self._page_size

src/somd2/runner/_base.py

Lines changed: 32 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -555,9 +555,6 @@ def __init__(self, system, config):
555555
self._config.checkpoint_frequency / self._config.energy_frequency
556556
)
557557

558-
# Zero the energy sample.
559-
self._nrg_sample = 0
560-
561558
# GCMC specific validation.
562559
if self._config.gcmc:
563560
if self._config.platform not in ["cuda", "opencl"]:
@@ -1197,10 +1194,11 @@ def increment_filename(base_filename, suffix):
11971194
filenames["trajectory"] = str(output_directory / f"traj_{lam}.dcd")
11981195
filenames["trajectory_chunk"] = str(output_directory / f"traj_{lam}_")
11991196
filenames["energy_components"] = str(
1200-
output_directory / f"energy_components_{lam}.txt"
1197+
output_directory / f"energy_components_{lam}.csv"
12011198
)
12021199
filenames["gcmc_ghosts"] = str(output_directory / f"gcmc_ghosts_{lam}.txt")
12031200
filenames["sampler_stats"] = str(output_directory / f"sampler_stats_{lam}.pkl")
1201+
filenames["xml"] = str(output_directory / f"system_{lam}.xml")
12041202
if restart:
12051203
filenames["config"] = str(
12061204
output_directory / increment_filename("config", "yaml")
@@ -2024,9 +2022,10 @@ def _backup_checkpoint(self, index):
20242022

20252023
return index, None
20262024

2027-
def _save_energy_components(self, index, context):
2025+
def _save_energy_components(self, index, context, time_ns):
20282026
"""
2029-
Internal function to save the energy components for each force group to file.
2027+
Internal function to save the energy components for each force group to a
2028+
CSV file.
20302029
20312030
Parameters
20322031
----------
@@ -2036,44 +2035,38 @@ def _save_energy_components(self, index, context):
20362035
20372036
context : openmm.Context
20382037
The current OpenMM context.
2038+
2039+
time_ns : float
2040+
The current simulation time in nanoseconds.
20392041
"""
20402042

2041-
from copy import deepcopy
2043+
import csv as _csv
20422044
import openmm
20432045

2044-
# Get the current context and system.
2045-
system = deepcopy(context.getSystem())
2046-
2047-
# Add each force to a unique group.
2048-
for i, f in enumerate(system.getForces()):
2049-
f.setForceGroup(i)
2050-
2051-
# Create a new context.
2052-
new_context = openmm.Context(system, deepcopy(context.getIntegrator()))
2053-
new_context.setPositions(context.getState(getPositions=True).getPositions())
2054-
2055-
header = f"{'# Sample':>10}"
2056-
record = f"{self._nrg_sample:>10}"
2057-
2058-
# Process the records.
2059-
for i, f in enumerate(system.getForces()):
2060-
state = new_context.getState(getEnergy=True, groups={i})
2061-
name = f.getName()
2062-
name_len = len(name)
2063-
header += f"{f.getName():>{name_len + 2}}"
2064-
record += f"{state.getPotentialEnergy().value_in_unit(openmm.unit.kilocalories_per_mole):>{name_len + 2}.2f}"
2065-
2066-
# Write to file.
2067-
if self._nrg_sample == 0:
2068-
with open(self._filenames[index]["energy_components"], "w") as f:
2069-
f.write(header + "\n")
2070-
f.write(record + "\n")
2071-
else:
2072-
with open(self._filenames[index]["energy_components"], "a") as f:
2073-
f.write(record + "\n")
2046+
filepath = self._filenames[index]["energy_components"]
2047+
file_exists = _Path(filepath).exists()
2048+
2049+
# Use the named force groups already assigned by sire_to_openmm_system,
2050+
# sorted alphabetically for a consistent column order across runs.
2051+
energies = {}
2052+
for name, grp in sorted(context._force_group_map.items()):
2053+
state = context.getState(getEnergy=True, groups=(1 << grp))
2054+
energies[name] = state.getPotentialEnergy().value_in_unit(
2055+
openmm.unit.kilocalories_per_mole
2056+
)
2057+
2058+
columns = ["time"] + list(energies.keys())
2059+
row = {"time": round(time_ns, 6)} | {
2060+
name: round(nrg, 4) for name, nrg in energies.items()
2061+
}
20742062

2075-
# Increment the sample number.
2076-
self._nrg_sample += 1
2063+
with open(filepath, "a", newline="") as f:
2064+
writer = _csv.DictWriter(f, fieldnames=columns)
2065+
if not file_exists:
2066+
# Write a comment line with units before the header.
2067+
f.write("# time: ns, energy: kcal/mol\n")
2068+
writer.writeheader()
2069+
writer.writerow(row)
20772070

20782071
def _restore_backup_files(self):
20792072
"""

src/somd2/runner/_repex.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def __init__(
5353
gcmc_kwargs=None,
5454
output_directory=None,
5555
perturbed_system=None,
56+
xml_filenames=None,
5657
):
5758
"""
5859
Constructor.
@@ -84,6 +85,10 @@ def __init__(
8485
perturbed_system: :class: `System <sire.system.System>`
8586
The perturbed end-state system used to seed starting coordinates for
8687
lambda > 0.5 replicas. If None, the perturbed state is not used.
88+
89+
xml_filenames: list of str
90+
A list of file paths for the OpenMM XML output, one per replica.
91+
If None, XML files are not written.
8792
"""
8893

8994
# Warn if the number of replicas is not a multiple of the number of GPUs.
@@ -117,6 +122,7 @@ def __init__(
117122
gcmc_kwargs=gcmc_kwargs,
118123
output_directory=output_directory,
119124
perturbed_system=perturbed_system,
125+
xml_filenames=xml_filenames,
120126
)
121127

122128
def __setstate__(self, state):
@@ -168,6 +174,7 @@ def _create_dynamics(
168174
gcmc_kwargs=None,
169175
output_directory=None,
170176
perturbed_system=None,
177+
xml_filenames=None,
171178
):
172179
"""
173180
Create the dynamics objects.
@@ -199,6 +206,10 @@ def _create_dynamics(
199206
perturbed_system: :class: `System <sire.system.System>`
200207
The perturbed end-state system used to seed starting coordinates for
201208
lambda > 0.5 replicas. If None, the perturbed state is not used.
209+
210+
xml_filenames: list of str
211+
A list of file paths for the OpenMM XML output, one per replica.
212+
If None, XML files are not written.
202213
"""
203214

204215
from math import floor
@@ -315,6 +326,13 @@ def _create_dynamics(
315326
# Append the dynamics object.
316327
self._dynamics.append(dynamics)
317328

329+
# Write the OpenMM XML file to the output directory.
330+
if xml_filenames is not None:
331+
_logger.info(
332+
f"Writing OpenMM XML for lambda {lam:.5f} on device {device}"
333+
)
334+
dynamics.to_xml(xml_filenames[i])
335+
318336
# Track memory footprint for this device.
319337
info = device_mem[device]
320338
info["count"] += 1
@@ -740,6 +758,11 @@ def __init__(self, system, config):
740758

741759
# Create the dynamics cache.
742760
if not self._is_restart:
761+
xml_filenames = (
762+
[self._filenames[i]["xml"] for i in range(len(self._lambda_values))]
763+
if self._config.save_xml
764+
else None
765+
)
743766
self._dynamics_cache = DynamicsCache(
744767
self._system,
745768
self._lambda_values,
@@ -749,6 +772,7 @@ def __init__(self, system, config):
749772
gcmc_kwargs=self._gcmc_kwargs,
750773
perturbed_system=self._perturbed_system,
751774
output_directory=self._config.output_directory,
775+
xml_filenames=xml_filenames,
752776
)
753777
else:
754778
_logger.debug("Restarting from file")
@@ -1397,11 +1421,6 @@ def _run_block(
13971421
energies = dynamics._current_energy_array()
13981422

13991423
except Exception as e:
1400-
try:
1401-
# Save the energy components for debugging purposes.
1402-
self._save_energy_components(index, dynamics.context())
1403-
except:
1404-
pass
14051424
return False, index, e
14061425

14071426
# Return the index and the energies.
@@ -1647,11 +1666,6 @@ def _equilibrate(self, index):
16471666
)
16481667

16491668
except Exception as e:
1650-
try:
1651-
# Save the energy components for debugging purposes.
1652-
self._save_energy_components(index, dynamics.context())
1653-
except:
1654-
pass
16551669
return False, index, e
16561670

16571671
return True, index, None
@@ -1767,6 +1781,12 @@ def _checkpoint(self, index, lambdas, block, num_blocks, is_final_block=False):
17671781
# Commit the current system.
17681782
system = dynamics.commit()
17691783

1784+
# Save the energy contribution for each force.
1785+
if self._config.save_energy_components:
1786+
self._save_energy_components(
1787+
index, dynamics.context(), system.time().to("ns")
1788+
)
1789+
17701790
# If performing GCMC, then we need to flag the ghost waters.
17711791
if gcmc_sampler is not None:
17721792
system = gcmc_sampler._flag_ghost_waters(system)

src/somd2/runner/_runner.py

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -592,10 +592,6 @@ def generate_lam_vals(lambda_base, increment=0.001):
592592
system.set_time(_sr.u("0ps"))
593593

594594
except Exception as e:
595-
try:
596-
self._save_energy_components(index, dynamics.context())
597-
except:
598-
pass
599595
raise RuntimeError(f"Equilibration failed: {e}")
600596

601597
# Work out the lambda values for finite-difference gradient analysis.
@@ -645,6 +641,11 @@ def generate_lam_vals(lambda_base, increment=0.001):
645641
# Create the dynamics object.
646642
dynamics = system.dynamics(**dynamics_kwargs)
647643

644+
# Write the OpenMM XML file to the output directory.
645+
if self._config.save_xml and not is_restart:
646+
_logger.info(f"Writing OpenMM XML for {_lam_sym} = {lambda_value:.5f}")
647+
dynamics.to_xml(self._filenames[index]["xml"])
648+
648649
# Reset the GCMC sampler. This resets the sampling statistics and clears
649650
# the associated OpenMM forces.
650651
if gcmc_sampler is not None:
@@ -914,23 +915,21 @@ def generate_lam_vals(lambda_base, increment=0.001):
914915
save_crash_report=self._config.save_crash_report,
915916
)
916917
except Exception as e:
917-
try:
918-
self._save_energy_components(index, dynamics.context())
919-
except:
920-
pass
921918
raise RuntimeError(
922919
f"Dynamics block {block + 1} for {_lam_sym} = {lambda_value:.5f} failed: {e}"
923920
)
924921

925922
# Checkpoint.
926923
try:
927-
# Save the energy contribution for each force.
928-
if self._config.save_energy_components:
929-
self._save_energy_components(index, dynamics.context())
930-
931924
# Commit the current system.
932925
system = dynamics.commit()
933926

927+
# Save the energy contribution for each force.
928+
if self._config.save_energy_components:
929+
self._save_energy_components(
930+
index, dynamics.context(), system.time().to("ns")
931+
)
932+
934933
# If performing GCMC, then we need to flag the ghost waters.
935934
if gcmc_sampler is not None:
936935
system = gcmc_sampler._flag_ghost_waters(system)
@@ -1056,13 +1055,15 @@ def generate_lam_vals(lambda_base, increment=0.001):
10561055
save_crash_report=self._config.save_crash_report,
10571056
)
10581057

1059-
# Save the energy contribution for each force.
1060-
if self._config.save_energy_components:
1061-
self._save_energy_components(index, dynamics.context())
1062-
10631058
# Commit the current system.
10641059
system = dynamics.commit()
10651060

1061+
# Save the energy contribution for each force.
1062+
if self._config.save_energy_components:
1063+
self._save_energy_components(
1064+
index, dynamics.context(), system.time().to("ns")
1065+
)
1066+
10661067
# Record the end time.
10671068
block_end = _timer()
10681069

@@ -1101,10 +1102,6 @@ def generate_lam_vals(lambda_base, increment=0.001):
11011102
f"{_lam_sym} = {lambda_value:.5f} complete, speed = {speed:.2f} ns day-1"
11021103
)
11031104
except Exception as e:
1104-
try:
1105-
self._save_energy_components(index, dynamics.context())
1106-
except:
1107-
pass
11081105
raise RuntimeError(
11091106
f"Final dynamics block for {lam_sym} = {lambda_value:.5f} failed: {e}"
11101107
)
@@ -1232,10 +1229,6 @@ def generate_lam_vals(lambda_base, increment=0.001):
12321229
save_crash_report=self._config.save_crash_report,
12331230
)
12341231
except Exception as e:
1235-
try:
1236-
self._save_energy_components(index, dynamics.context())
1237-
except:
1238-
pass
12391232
raise RuntimeError(
12401233
f"Dynamics for {_lam_sym} = {lambda_value:.5f} failed: {e}"
12411234
)

0 commit comments

Comments
 (0)