Skip to content

Commit f733134

Browse files
committed
feat(torchsim): remove metatrain dependency
- Remove support for loading .ckpt files directly - Remove 'pet-mad' shortcut (requires metatrain) - Update documentation to point users to upet for model export - Remove metatrain from optional dependencies Metatrain checkpoints should now be exported using upet before using with torchsim: pip install upet && upet export model.ckpt Addresses review comment from @Luthaf on PR metatensor#167
1 parent d7019ab commit f733134

4 files changed

Lines changed: 16 additions & 42 deletions

File tree

python/metatomic_torchsim/README.md

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,6 @@ use in TorchSim molecular dynamics and other simulation workflows.
1111
pip install metatomic-torchsim
1212
```
1313

14-
To use metatrain checkpoints (`.ckpt` files):
15-
16-
```bash
17-
pip install metatomic-torchsim[metatrain]
18-
```
19-
2014
For universal potential models, see
2115
[upet](https://github.com/lab-cosmo/upet).
2216

@@ -25,12 +19,9 @@ For universal potential models, see
2519
```python
2620
from metatomic_torchsim import MetatomicModel
2721

28-
# From a saved .pt model
22+
# From a saved .pt model (exported with upet)
2923
model = MetatomicModel("model.pt", device="cuda")
3024

31-
# From a metatrain checkpoint (requires metatrain extra)
32-
model = MetatomicModel("model.ckpt", device="cuda")
33-
3425
# Use with TorchSim
3526
output = model(sim_state)
3627
energy = output["energy"]
@@ -40,3 +31,4 @@ stress = output["stress"]
4031

4132
For full documentation, see the
4233
[torch-sim engine page](https://docs.metatensor.org/metatomic/latest/engines/torch-sim.html).
34+

python/metatomic_torchsim/metatomic_torchsim/_model.py

Lines changed: 13 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,9 @@ def __init__(
7979
"""Initialize the metatomic model wrapper.
8080
8181
:param model: Model to use. Accepts a file path to a ``.pt`` saved
82-
model, a ``.ckpt`` metatrain checkpoint (requires ``metatrain``), the
83-
string ``"pet-mad"`` (shortcut for the PET-MAD model, requires
84-
``metatrain``), a Python :py:class:`AtomisticModel` instance, or a
85-
TorchScript :py:class:`torch.jit.RecursiveScriptModule`.
82+
metatomic model (exported with ``upet``), a Python
83+
:py:class:`AtomisticModel` instance, or a TorchScript
84+
:py:class:`torch.jit.RecursiveScriptModule`.
8685
:param extensions_directory: Directory containing compiled TorchScript
8786
extensions required by the model, if any.
8887
:param device: Torch device for evaluation. When ``None``, the best
@@ -96,22 +95,19 @@ def __init__(
9695

9796
self._check_consistency = check_consistency
9897

99-
# Load the model, following the same patterns as ase_calculator.py
100-
if isinstance(model, str) and model == "pet-mad":
101-
model = self._load_metatrain_model(
102-
"https://huggingface.co/lab-cosmo/pet-mad/resolve/v1.1.0/"
103-
"models/pet-mad-v1.1.0.ckpt"
104-
)
105-
elif isinstance(model, (str, bytes, pathlib.PurePath)):
98+
# Load the model from a file path or AtomisticModel instance
99+
if isinstance(model, (str, bytes, pathlib.PurePath)):
106100
model_path = str(model)
107101
if model_path.endswith(".ckpt"):
108-
model = self._load_metatrain_model(model_path)
109-
else:
110-
if not os.path.exists(model_path):
111-
raise ValueError(f"given model path '{model_path}' does not exist")
112-
model = load_atomistic_model(
113-
model_path, extensions_directory=extensions_directory
102+
raise ValueError(
103+
".ckpt files are not supported. Please export your metatrain "
104+
"model using upet: pip install upet && upet export model.ckpt"
114105
)
106+
if not os.path.exists(model_path):
107+
raise ValueError(f"given model path '{model_path}' does not exist")
108+
model = load_atomistic_model(
109+
model_path, extensions_directory=extensions_directory
110+
)
115111
elif isinstance(model, torch.jit.RecursiveScriptModule):
116112
if model.original_name != "AtomisticModel":
117113
raise TypeError(
@@ -162,18 +158,6 @@ def __init__(
162158
},
163159
)
164160

165-
@staticmethod
166-
def _load_metatrain_model(path: str) -> AtomisticModel:
167-
"""Load a metatrain checkpoint and export it as an AtomisticModel."""
168-
try:
169-
from metatrain.utils.io import load_model
170-
except ImportError as exc:
171-
raise ImportError(
172-
"metatrain is required to load .ckpt files or use the 'pet-mad' "
173-
"shortcut: pip install metatrain"
174-
) from exc
175-
176-
return load_model(path).export()
177161

178162
def forward(self, state: "ts.SimState") -> Dict[str, torch.Tensor]:
179163
"""Compute energies, forces, and stresses for the given simulation state.

python/metatomic_torchsim/pyproject.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@ dependencies = [
3636
]
3737

3838
[project.optional-dependencies]
39-
metatrain = ["metatrain"]
40-
4139
[project.urls]
4240
homepage = "https://docs.metatensor.org/metatomic/latest/engines/torch-sim.html"
4341
documentation = "https://docs.metatensor.org/metatomic/latest/engines/torch-sim.html"

python/metatomic_torchsim/tests/test_torchsim.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Tests for the MetatomicModel TorchSim wrapper.
22
3-
Uses the metatomic-lj-test model so that tests run without needing metatrain or
3+
Uses the metatomic-lj-test model so that tests run without needing
44
downloading large model files.
55
"""
66

0 commit comments

Comments
 (0)