Skip to content

Commit 7ea51c6

Browse files
authored
Add Electrostatics using extensible extras (#532)
1 parent 8e92014 commit 7ea51c6

4 files changed

Lines changed: 509 additions & 9 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,12 @@ test = [
4343
"physical-validation>=1.0.5",
4444
"platformdirs>=4.0.0",
4545
"psutil>=7.0.0",
46-
"pymatgen>=2025.6.14",
4746
"pytest-cov>=6",
4847
"pytest>=8",
4948
"spglib>=2.6",
50-
"vesin[torch]>=0.5.3",
5149
]
5250
vesin = ["vesin[torch]>=0.5.3"]
53-
io = ["ase>=3.26", "phonopy>=2.37.0", "pymatgen>=2025.6.14"]
51+
io = ["ase>=3.26", "phonopy>=2.37.0", "pymatgen>=2026.3.23"]
5452
symmetry = ["moyopy>=0.7.8"]
5553
mace = ["mace-torch>=0.3.15"]
5654
mattersim = ["mattersim>=1.2.2"]
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
"""Tests for the electrostatics ModelInterface wrappers."""
2+
3+
import traceback # noqa: I001
4+
5+
import pytest
6+
import torch
7+
from ase.build import bulk
8+
9+
import torch_sim as ts
10+
from tests.conftest import DEVICE, DTYPE
11+
from tests.models.conftest import make_validate_model_outputs_test
12+
13+
try:
14+
from torch_sim.models.electrostatics import DSFCoulombModel, EwaldModel, PMEModel
15+
except (ImportError, OSError, RuntimeError):
16+
pytest.skip(
17+
f"nvalchemiops not installed: {traceback.format_exc()}",
18+
allow_module_level=True,
19+
)
20+
21+
22+
def _make_charged_state(
23+
device: torch.device = DEVICE,
24+
dtype: torch.dtype = DTYPE,
25+
) -> ts.SimState:
26+
"""Build a small NaCl-like state with alternating +1/-1 site charges."""
27+
atoms = bulk("NaCl", crystalstructure="rocksalt", a=5.64, cubic=True)
28+
state = ts.io.atoms_to_state(atoms, device, dtype)
29+
n = state.n_atoms
30+
charges = torch.empty(n, dtype=dtype, device=device)
31+
charges[::2] = 1.0
32+
charges[1::2] = -1.0
33+
state._atom_extras["partial_charges"] = charges # noqa: SLF001
34+
return state
35+
36+
37+
@pytest.fixture
38+
def dsf_model() -> DSFCoulombModel:
39+
return DSFCoulombModel(cutoff=8.0, alpha=0.2, device=DEVICE, dtype=DTYPE)
40+
41+
42+
@pytest.fixture
43+
def ewald_model() -> EwaldModel:
44+
return EwaldModel(cutoff=8.0, accuracy=1e-6, device=DEVICE, dtype=DTYPE)
45+
46+
47+
@pytest.fixture
48+
def pme_model() -> PMEModel:
49+
return PMEModel(cutoff=8.0, accuracy=1e-6, device=DEVICE, dtype=DTYPE)
50+
51+
52+
def _add_partial_charges(state: ts.SimState) -> ts.SimState:
53+
"""Inject alternating +/-0.5 site charges into a state."""
54+
n = state.n_atoms
55+
charges = torch.zeros(n, dtype=state.positions.dtype, device=state.device)
56+
charges[::2] = 0.5
57+
charges[1::2] = -0.5
58+
state._atom_extras["partial_charges"] = charges # noqa: SLF001
59+
return state
60+
61+
62+
test_dsf_model_outputs = make_validate_model_outputs_test(
63+
model_fixture_name="dsf_model",
64+
device=DEVICE,
65+
dtype=DTYPE,
66+
state_modifiers=[_add_partial_charges],
67+
)
68+
test_ewald_model_outputs = make_validate_model_outputs_test(
69+
model_fixture_name="ewald_model",
70+
device=DEVICE,
71+
dtype=DTYPE,
72+
state_modifiers=[_add_partial_charges],
73+
)
74+
test_pme_model_outputs = make_validate_model_outputs_test(
75+
model_fixture_name="pme_model",
76+
device=DEVICE,
77+
dtype=DTYPE,
78+
state_modifiers=[_add_partial_charges],
79+
)
80+
81+
82+
def test_dsf_nonzero_energy() -> None:
83+
"""Charged system should produce nonzero electrostatic energy."""
84+
model = DSFCoulombModel(cutoff=8.0, alpha=0.2, device=DEVICE, dtype=DTYPE)
85+
state = _make_charged_state()
86+
out = model(state)
87+
assert out["energy"].abs().item() > 0
88+
89+
90+
def test_ewald_pme_energy_agreement() -> None:
91+
"""Ewald and PME should give the same converged Coulomb energy."""
92+
state = _make_charged_state()
93+
ewald = EwaldModel(cutoff=8.0, accuracy=1e-6, device=DEVICE, dtype=DTYPE)
94+
pme = PMEModel(cutoff=8.0, accuracy=1e-6, device=DEVICE, dtype=DTYPE)
95+
torch.testing.assert_close(
96+
ewald(state)["energy"], pme(state)["energy"], atol=1e-3, rtol=1e-3
97+
)
98+
99+
100+
def test_sum_model_lj_plus_dsf() -> None:
101+
"""LJ + DSF should be additive through SumModel."""
102+
from torch_sim.models.interface import SumModel
103+
from torch_sim.models.lennard_jones import LennardJonesModel
104+
105+
lj = LennardJonesModel(
106+
sigma=2.8, epsilon=0.01, cutoff=7.0, device=DEVICE, dtype=DTYPE
107+
)
108+
dsf = DSFCoulombModel(cutoff=8.0, alpha=0.2, device=DEVICE, dtype=DTYPE)
109+
combined = SumModel(lj, dsf)
110+
state = _make_charged_state()
111+
lj_out = lj(state)
112+
dsf_out = dsf(state)
113+
sum_out = combined(state)
114+
torch.testing.assert_close(sum_out["energy"], lj_out["energy"] + dsf_out["energy"])
115+
torch.testing.assert_close(sum_out["forces"], lj_out["forces"] + dsf_out["forces"])

torch_sim/_duecredit.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,25 @@
22

33
from __future__ import annotations
44

5-
from typing import TYPE_CHECKING, Any
5+
from typing import TYPE_CHECKING, Any, TypeVar
66

77

88
if TYPE_CHECKING:
99
from collections.abc import Callable
1010

11+
_F = TypeVar("_F", bound="Callable[..., Any]")
12+
1113

1214
class InactiveDueCreditCollector:
1315
"""Just a stub at the Collector which would not do anything."""
1416

1517
def _donothing(self, *_args: Any, **_kwargs: Any) -> None:
1618
"""Perform no good and no bad."""
1719

18-
def dcite(
19-
self, *_args: Any, **_kwargs: Any
20-
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
20+
def dcite(self, *_args: Any, **_kwargs: Any) -> Callable[[_F], _F]:
2121
"""If I could cite I would."""
2222

23-
def nondecorating_decorator(func: Callable[..., Any]) -> Callable[..., Any]:
23+
def nondecorating_decorator(func: _F) -> _F:
2424
return func
2525

2626
return nondecorating_decorator
@@ -56,7 +56,7 @@ def _disable_duecredit(exc: Exception) -> None:
5656

5757
def dcite(
5858
doi: str, description: str | None = None, *, path: str | None = None
59-
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
59+
) -> Callable[[_F], _F]:
6060
"""Create a duecredit decorator from a DOI and description."""
6161
kwargs: dict[str, Any] = (
6262
{"description": description} if description is not None else {}

0 commit comments

Comments
 (0)