Skip to content

Commit 4f01084

Browse files
committed
fea: add first pass at simulations under influence of electric field per Modern Theory of Polarization
1 parent 8e92014 commit 4f01084

4 files changed

Lines changed: 560 additions & 13 deletions

File tree

tests/models/test_polarization.py

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
"""Tests for the polarization electric-field correction model."""
2+
3+
import pytest
4+
import torch
5+
6+
import torch_sim as ts
7+
from tests.conftest import DEVICE, DTYPE
8+
from torch_sim.models.interface import ModelInterface, SerialSumModel
9+
from torch_sim.models.polarization import UniformPolarizationModel
10+
11+
12+
class DummyPolarResponseModel(ModelInterface):
13+
def __init__(
14+
self,
15+
*,
16+
include_born_effective_charges: bool = True,
17+
include_polarizability: bool = True,
18+
include_total_polarization: bool = True,
19+
device: torch.device = DEVICE,
20+
dtype: torch.dtype = DTYPE,
21+
) -> None:
22+
super().__init__()
23+
self.include_born_effective_charges = include_born_effective_charges
24+
self.include_polarizability = include_polarizability
25+
self.include_total_polarization = include_total_polarization
26+
self._device = device
27+
self._dtype = dtype
28+
self._compute_forces = True
29+
self._compute_stress = True
30+
31+
def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]:
32+
del kwargs
33+
energy = torch.arange(
34+
1, state.n_systems + 1, device=state.device, dtype=state.dtype
35+
)
36+
forces = (
37+
torch.arange(state.n_atoms * 3, device=state.device, dtype=state.dtype)
38+
.reshape(state.n_atoms, 3)
39+
.div(10.0)
40+
)
41+
stress = (
42+
torch.arange(state.n_systems * 9, device=state.device, dtype=state.dtype)
43+
.reshape(state.n_systems, 3, 3)
44+
.div(100.0)
45+
)
46+
polarization = (
47+
torch.arange(state.n_systems * 3, device=state.device, dtype=state.dtype)
48+
.reshape(state.n_systems, 3)
49+
.add(0.5)
50+
)
51+
output: dict[str, torch.Tensor] = {
52+
"energy": energy,
53+
"forces": forces,
54+
"stress": stress,
55+
}
56+
if self.include_total_polarization:
57+
output["total_polarization"] = polarization
58+
if self.include_polarizability:
59+
diag = torch.tensor([1.0, 2.0, 3.0], device=state.device, dtype=state.dtype)
60+
output["polarizability"] = torch.diag_embed(diag.repeat(state.n_systems, 1))
61+
if self.include_born_effective_charges:
62+
born_effective_charges = torch.zeros(
63+
state.n_atoms, 3, 3, device=state.device, dtype=state.dtype
64+
)
65+
born_effective_charges[:, 0, 0] = 1.0
66+
born_effective_charges[:, 1, 1] = 2.0
67+
born_effective_charges[:, 2, 2] = 3.0
68+
output["born_effective_charges"] = born_effective_charges
69+
return output
70+
71+
72+
def test_polarization_model_requires_external_e_field(
73+
si_double_sim_state: ts.SimState,
74+
) -> None:
75+
base_model = DummyPolarResponseModel()
76+
combined_model = SerialSumModel(
77+
base_model,
78+
UniformPolarizationModel(device=DEVICE, dtype=DTYPE),
79+
)
80+
81+
with pytest.raises(ValueError, match="external_E_field"):
82+
combined_model(si_double_sim_state)
83+
84+
85+
def test_polarization_model_applies_linear_response_corrections(
86+
si_double_sim_state: ts.SimState,
87+
) -> None:
88+
base_model = DummyPolarResponseModel()
89+
combined_model = SerialSumModel(
90+
base_model,
91+
UniformPolarizationModel(device=DEVICE, dtype=DTYPE),
92+
)
93+
field = torch.tensor(
94+
[[0.2, -0.1, 0.05], [-0.3, 0.4, 0.1]],
95+
device=DEVICE,
96+
dtype=DTYPE,
97+
)
98+
state = ts.SimState.from_state(si_double_sim_state, external_E_field=field)
99+
100+
base_output = base_model(state)
101+
combined_output = combined_model(state)
102+
expected_polarization = base_output["total_polarization"] + torch.einsum(
103+
"sij,sj->si", base_output["polarizability"], field
104+
)
105+
expected_energy = base_output["energy"] - torch.einsum(
106+
"si,si->s", field, base_output["total_polarization"]
107+
)
108+
expected_energy = expected_energy - 0.5 * torch.einsum(
109+
"si,sij,sj->s", field, base_output["polarizability"], field
110+
)
111+
expected_forces = base_output["forces"] + torch.einsum(
112+
"imn,im->in",
113+
base_output["born_effective_charges"],
114+
field[state.system_idx],
115+
)
116+
117+
torch.testing.assert_close(combined_output["energy"], expected_energy)
118+
torch.testing.assert_close(combined_output["forces"], expected_forces)
119+
torch.testing.assert_close(
120+
combined_output["total_polarization"], expected_polarization
121+
)
122+
torch.testing.assert_close(combined_output["stress"], base_output["stress"])
123+
124+
125+
def test_polarization_model_returns_additive_total_polarization_delta(
126+
si_double_sim_state: ts.SimState,
127+
) -> None:
128+
base_model = DummyPolarResponseModel()
129+
combined_model = SerialSumModel(
130+
base_model,
131+
UniformPolarizationModel(device=DEVICE, dtype=DTYPE),
132+
)
133+
field = torch.tensor([[0.1, 0.0, 0.0], [0.0, -0.2, 0.3]], device=DEVICE, dtype=DTYPE)
134+
state = ts.SimState.from_state(si_double_sim_state, external_E_field=field)
135+
136+
base_output = base_model(state)
137+
combined_output = combined_model(state)
138+
expected_total_polarization = base_output["total_polarization"] + torch.einsum(
139+
"sij,sj->si", base_output["polarizability"], field
140+
)
141+
142+
torch.testing.assert_close(
143+
combined_output["total_polarization"], expected_total_polarization
144+
)
145+
serialized_state = state.clone()
146+
serialized_state.store_model_extras(base_output)
147+
correction_output = UniformPolarizationModel(device=DEVICE, dtype=DTYPE)(
148+
serialized_state
149+
)
150+
torch.testing.assert_close(
151+
correction_output["total_polarization"],
152+
expected_total_polarization,
153+
)
154+
155+
156+
def test_polarization_model_requires_born_effective_charges_for_force_correction(
157+
si_double_sim_state: ts.SimState,
158+
) -> None:
159+
base_model = DummyPolarResponseModel(include_born_effective_charges=False)
160+
combined_model = SerialSumModel(
161+
base_model,
162+
UniformPolarizationModel(device=DEVICE, dtype=DTYPE),
163+
)
164+
state = ts.SimState.from_state(
165+
si_double_sim_state,
166+
external_E_field=torch.ones(
167+
si_double_sim_state.n_systems, 3, device=DEVICE, dtype=DTYPE
168+
),
169+
)
170+
171+
with pytest.raises(ValueError, match="born_effective_charges"):
172+
combined_model(state)
173+
174+
175+
def test_polarization_model_requires_total_polarization(
176+
si_double_sim_state: ts.SimState,
177+
) -> None:
178+
base_model = DummyPolarResponseModel(include_total_polarization=False)
179+
combined_model = SerialSumModel(
180+
base_model,
181+
UniformPolarizationModel(device=DEVICE, dtype=DTYPE),
182+
)
183+
state = ts.SimState.from_state(
184+
si_double_sim_state,
185+
external_E_field=torch.ones(
186+
si_double_sim_state.n_systems, 3, device=DEVICE, dtype=DTYPE
187+
),
188+
)
189+
190+
with pytest.raises(ValueError, match="total_polarization"):
191+
combined_model(state)
192+
193+
194+
def test_polarization_model_rejects_non_uniform_field_shape(
195+
si_double_sim_state: ts.SimState,
196+
) -> None:
197+
state = ts.SimState.from_state(
198+
si_double_sim_state,
199+
external_E_field=torch.zeros(
200+
si_double_sim_state.n_systems, 3, device=DEVICE, dtype=DTYPE
201+
),
202+
)
203+
state._system_extras["external_E_field"] = torch.zeros( # noqa: SLF001
204+
state.n_atoms, 3, device=DEVICE, dtype=DTYPE
205+
)
206+
model = UniformPolarizationModel(device=DEVICE, dtype=DTYPE)
207+
208+
with pytest.raises(ValueError, match="shape \\(n_systems, 3\\)"):
209+
model(state)

tests/models/test_sum_model.py

Lines changed: 142 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,79 @@
66
import torch_sim as ts
77
from tests.conftest import DEVICE, DTYPE
88
from tests.models.conftest import make_validate_model_outputs_test
9-
from torch_sim.models.interface import SumModel
9+
from torch_sim.models.interface import ModelInterface, SerialSumModel, SumModel
1010
from torch_sim.models.lennard_jones import LennardJonesModel
1111
from torch_sim.models.morse import MorseModel
1212

1313

14+
class ExtraProducerModel(ModelInterface):
15+
def __init__(self, device: torch.device = DEVICE, dtype: torch.dtype = DTYPE) -> None:
16+
super().__init__()
17+
self._device = device
18+
self._dtype = dtype
19+
self._compute_stress = False
20+
self._compute_forces = False
21+
22+
def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]:
23+
del kwargs
24+
latent = state.positions[:, 0] + 2.0
25+
return {
26+
"energy": torch.ones(state.n_systems, device=state.device, dtype=state.dtype),
27+
"latent": latent,
28+
}
29+
30+
31+
class ExtraConsumerModel(ModelInterface):
32+
seen_latent: torch.Tensor | None
33+
34+
def __init__(self, device: torch.device = DEVICE, dtype: torch.dtype = DTYPE) -> None:
35+
super().__init__()
36+
self._device = device
37+
self._dtype = dtype
38+
self._compute_stress = False
39+
self._compute_forces = False
40+
self.seen_latent = None
41+
42+
def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]:
43+
del kwargs
44+
self.seen_latent = state.latent.clone()
45+
energy = torch.zeros(state.n_systems, device=state.device, dtype=state.dtype)
46+
energy.scatter_add_(0, state.system_idx, state.latent)
47+
return {"energy": energy}
48+
49+
50+
class OverwriteExtrasModel(ModelInterface):
51+
def __init__(
52+
self,
53+
value: float,
54+
device: torch.device = DEVICE,
55+
dtype: torch.dtype = DTYPE,
56+
) -> None:
57+
super().__init__()
58+
self.value = value
59+
self._device = device
60+
self._dtype = dtype
61+
self._compute_stress = False
62+
self._compute_forces = False
63+
64+
def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]:
65+
del kwargs
66+
return {
67+
"energy": torch.full(
68+
(state.n_systems,),
69+
self.value,
70+
device=state.device,
71+
dtype=state.dtype,
72+
),
73+
"label": torch.full(
74+
(state.n_systems, 3),
75+
self.value,
76+
device=state.device,
77+
dtype=state.dtype,
78+
),
79+
}
80+
81+
1482
@pytest.fixture
1583
def lj_model_a() -> LennardJonesModel:
1684
return LennardJonesModel(
@@ -43,9 +111,19 @@ def sum_model(lj_model_a: LennardJonesModel, morse_model: MorseModel) -> SumMode
43111
return SumModel(lj_model_a, morse_model)
44112

45113

114+
@pytest.fixture
115+
def serial_sum_model(
116+
lj_model_a: LennardJonesModel, morse_model: MorseModel
117+
) -> SerialSumModel:
118+
return SerialSumModel(lj_model_a, morse_model)
119+
120+
46121
test_sum_model_outputs = make_validate_model_outputs_test(
47122
model_fixture_name="sum_model", device=DEVICE, dtype=DTYPE
48123
)
124+
test_serial_sum_model_outputs = make_validate_model_outputs_test(
125+
model_fixture_name="serial_sum_model", device=DEVICE, dtype=DTYPE
126+
)
49127

50128

51129
def test_sum_model_requires_two_models(lj_model_a: LennardJonesModel) -> None:
@@ -102,3 +180,66 @@ def test_sum_model_retain_graph(
102180
assert lj_model_a.retain_graph is True
103181
assert morse_model.retain_graph is True
104182
assert sm.retain_graph is True
183+
184+
185+
def test_serial_sum_model_matches_parallel_sum_for_independent_models(
186+
lj_model_a: LennardJonesModel,
187+
morse_model: MorseModel,
188+
si_sim_state: ts.SimState,
189+
) -> None:
190+
sum_out = SumModel(lj_model_a, morse_model)(si_sim_state)
191+
serial_out = SerialSumModel(lj_model_a, morse_model)(si_sim_state)
192+
torch.testing.assert_close(serial_out["energy"], sum_out["energy"])
193+
torch.testing.assert_close(serial_out["forces"], sum_out["forces"])
194+
torch.testing.assert_close(serial_out["stress"], sum_out["stress"])
195+
196+
197+
def test_serial_sum_model_exposes_extras_to_later_models(
198+
si_double_sim_state: ts.SimState,
199+
) -> None:
200+
producer = ExtraProducerModel()
201+
consumer = ExtraConsumerModel()
202+
serial_model = SerialSumModel(producer, consumer)
203+
state = si_double_sim_state.clone()
204+
expected_latent = state.positions[:, 0] + 2.0
205+
expected_energy = torch.ones(state.n_systems, device=state.device, dtype=state.dtype)
206+
expected_energy = expected_energy.scatter_add(
207+
0,
208+
state.system_idx,
209+
expected_latent,
210+
)
211+
212+
output = serial_model(state)
213+
214+
assert consumer.seen_latent is not None
215+
torch.testing.assert_close(consumer.seen_latent, expected_latent)
216+
torch.testing.assert_close(output["latent"], expected_latent)
217+
torch.testing.assert_close(output["energy"], expected_energy)
218+
assert not state.has_extras("latent")
219+
220+
221+
def test_serial_sum_model_overwrites_noncanonical_outputs(
222+
si_double_sim_state: ts.SimState,
223+
) -> None:
224+
model = SerialSumModel(OverwriteExtrasModel(1.0), OverwriteExtrasModel(2.0))
225+
226+
output = model(si_double_sim_state)
227+
228+
torch.testing.assert_close(
229+
output["energy"],
230+
torch.full(
231+
(si_double_sim_state.n_systems,),
232+
3.0,
233+
device=si_double_sim_state.device,
234+
dtype=si_double_sim_state.dtype,
235+
),
236+
)
237+
torch.testing.assert_close(
238+
output["label"],
239+
torch.full(
240+
(si_double_sim_state.n_systems, 3),
241+
2.0,
242+
device=si_double_sim_state.device,
243+
dtype=si_double_sim_state.dtype,
244+
),
245+
)

0 commit comments

Comments
 (0)