@@ -53,6 +53,7 @@ def __init__(
5353 gcmc_kwargs = None ,
5454 output_directory = None ,
5555 perturbed_system = None ,
56+ xml_filenames = None ,
5657 ):
5758 """
5859 Constructor.
@@ -84,6 +85,10 @@ def __init__(
8485 perturbed_system: :class: `System <sire.system.System>`
8586 The perturbed end-state system used to seed starting coordinates for
8687 lambda > 0.5 replicas. If None, the perturbed state is not used.
88+
89+ xml_filenames: list of str
90+ A list of file paths for the OpenMM XML output, one per replica.
91+ If None, XML files are not written.
8792 """
8893
8994 # Warn if the number of replicas is not a multiple of the number of GPUs.
@@ -117,6 +122,7 @@ def __init__(
117122 gcmc_kwargs = gcmc_kwargs ,
118123 output_directory = output_directory ,
119124 perturbed_system = perturbed_system ,
125+ xml_filenames = xml_filenames ,
120126 )
121127
122128 def __setstate__ (self , state ):
@@ -168,6 +174,7 @@ def _create_dynamics(
168174 gcmc_kwargs = None ,
169175 output_directory = None ,
170176 perturbed_system = None ,
177+ xml_filenames = None ,
171178 ):
172179 """
173180 Create the dynamics objects.
@@ -199,6 +206,10 @@ def _create_dynamics(
199206 perturbed_system: :class: `System <sire.system.System>`
200207 The perturbed end-state system used to seed starting coordinates for
201208 lambda > 0.5 replicas. If None, the perturbed state is not used.
209+
210+ xml_filenames: list of str
211+ A list of file paths for the OpenMM XML output, one per replica.
212+ If None, XML files are not written.
202213 """
203214
204215 from math import floor
@@ -315,6 +326,13 @@ def _create_dynamics(
315326 # Append the dynamics object.
316327 self ._dynamics .append (dynamics )
317328
329+ # Write the OpenMM XML file to the output directory.
330+ if xml_filenames is not None :
331+ _logger .info (
332+ f"Writing OpenMM XML for lambda { lam :.5f} on device { device } "
333+ )
334+ dynamics .to_xml (xml_filenames [i ])
335+
318336 # Track memory footprint for this device.
319337 info = device_mem [device ]
320338 info ["count" ] += 1
@@ -740,6 +758,11 @@ def __init__(self, system, config):
740758
741759 # Create the dynamics cache.
742760 if not self ._is_restart :
761+ xml_filenames = (
762+ [self ._filenames [i ]["xml" ] for i in range (len (self ._lambda_values ))]
763+ if self ._config .save_xml
764+ else None
765+ )
743766 self ._dynamics_cache = DynamicsCache (
744767 self ._system ,
745768 self ._lambda_values ,
@@ -749,6 +772,7 @@ def __init__(self, system, config):
749772 gcmc_kwargs = self ._gcmc_kwargs ,
750773 perturbed_system = self ._perturbed_system ,
751774 output_directory = self ._config .output_directory ,
775+ xml_filenames = xml_filenames ,
752776 )
753777 else :
754778 _logger .debug ("Restarting from file" )
0 commit comments