Skip to content

Commit 903ca67

Browse files
committed
Update to the latest vesin release
1 parent e020f5f commit 903ca67

7 files changed

Lines changed: 155 additions & 132 deletions

File tree

python/metatomic_ase/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def create_version_number(version):
122122
# No dependency on ASE itself until this package is no longer a direct
123123
# dependency of metatomic-torch
124124
# "ase >=3.22.0",
125-
"vesin >=0.5.2,<0.6",
125+
"vesin >=0.5.5,<0.6",
126126
]
127127

128128
# when packaging a sdist for release, we should never use local dependencies

python/metatomic_ase/src/metatomic_ase/_calculator.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
pick_output,
2222
)
2323

24-
from ._neighbors import _compute_requested_neighbors
24+
from ._neighbors import AllNeighborsCalculator
2525

2626

2727
import ase # isort: skip
@@ -317,6 +317,11 @@ def __init__(
317317
"be positive"
318318
)
319319

320+
self._nl_calculators = AllNeighborsCalculator(
321+
requested_options=self._model.requested_neighbor_lists(),
322+
check_consistency=check_consistency,
323+
)
324+
320325
# We do our own check to verify if a property is implemented in `calculate()`,
321326
# so we pretend to be able to compute all properties ASE knows about.
322327
self.implemented_properties = ALL_ASE_PROPERTIES
@@ -398,11 +403,7 @@ def run_model(
398403
systems.append(system)
399404

400405
# Compute the neighbors lists requested by the model
401-
input_systems = _compute_requested_neighbors(
402-
systems=systems,
403-
requested_options=self._model.requested_neighbor_lists(),
404-
check_consistency=self.parameters["check_consistency"],
405-
)
406+
input_systems = self._nl_calculators.compute(systems=systems)
406407

407408
available_outputs = self._model.capabilities().outputs
408409
for key in outputs:
@@ -538,11 +539,7 @@ def calculate(
538539
with record_function("MetatomicCalculator::compute_neighbors"):
539540
# convert from ase.Atoms to metatomic.torch.System
540541
system = System(types, positions, cell, pbc)
541-
input_system = _compute_requested_neighbors(
542-
systems=[system],
543-
requested_options=self._model.requested_neighbor_lists(),
544-
check_consistency=self.parameters["check_consistency"],
545-
)[0]
542+
input_system = self._nl_calculators.compute(systems=[system])[0]
546543

547544
with record_function("MetatomicCalculator::get_model_inputs"):
548545
for name, option in self._model.requested_inputs().items():
@@ -721,11 +718,7 @@ def compute_energy(
721718
systems.append(system)
722719

723720
# Compute the neighbors lists requested by the model
724-
input_systems = _compute_requested_neighbors(
725-
systems=systems,
726-
requested_options=self._model.requested_neighbor_lists(),
727-
check_consistency=self.parameters["check_consistency"],
728-
)
721+
input_systems = self._nl_calculators.compute(systems=systems)
729722

730723
predictions = self._model(
731724
systems=input_systems,

python/metatomic_ase/src/metatomic_ase/_neighbors.py

Lines changed: 60 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -31,57 +31,68 @@
3131
HAS_NVALCHEMIOPS = False
3232

3333

34-
def _compute_requested_neighbors(
35-
systems: List[System],
36-
requested_options: List[NeighborListOptions],
37-
check_consistency=False,
38-
) -> List[System]:
39-
"""
40-
Compute all neighbor lists requested by ``model`` and store them inside the systems.
41-
"""
42-
can_use_nvalchemi = HAS_NVALCHEMIOPS and all(
43-
system.device.type == "cuda" for system in systems
44-
)
45-
46-
if can_use_nvalchemi:
47-
full_nl_options = []
48-
half_nl_options = []
49-
for options in requested_options:
50-
if options.full_list:
51-
full_nl_options.append(options)
52-
else:
53-
half_nl_options.append(options)
54-
55-
# Do the full neighbor lists with nvalchemi, and the rest with vesin
56-
systems = _compute_requested_neighbors_nvalchemi(
57-
systems=systems,
58-
requested_options=full_nl_options,
59-
)
60-
systems = _compute_requested_neighbors_vesin(
61-
systems=systems,
62-
requested_options=half_nl_options,
63-
check_consistency=check_consistency,
34+
class AllNeighborsCalculator:
35+
def __init__(
36+
self,
37+
requested_options: List[NeighborListOptions],
38+
check_consistency=False,
39+
):
40+
self.check_consistency = check_consistency
41+
self._full_nl_options = [
42+
options for options in requested_options if options.full_list
43+
]
44+
self._full_vesin_calculators = [
45+
vesin.metatomic.NeighborList(
46+
options=options,
47+
length_unit="angstrom",
48+
check_consistency=check_consistency,
49+
)
50+
for options in requested_options
51+
if options.full_list
52+
]
53+
self._half_vesin_calculators = [
54+
vesin.metatomic.NeighborList(
55+
options=options,
56+
length_unit="angstrom",
57+
check_consistency=check_consistency,
58+
)
59+
for options in requested_options
60+
if not options.full_list
61+
]
62+
63+
def compute(self, systems: List[System]) -> List[System]:
64+
assert isinstance(systems, list)
65+
assert isinstance(systems[0], torch.ScriptObject)
66+
67+
can_use_nvalchemi = HAS_NVALCHEMIOPS and all(
68+
system.device.type == "cuda" for system in systems
6469
)
65-
else:
70+
71+
if can_use_nvalchemi:
72+
# Do the full neighbor lists with nvalchemi
73+
systems = _compute_requested_neighbors_nvalchemi(
74+
systems=systems,
75+
requested_options=self._full_nl_options,
76+
)
77+
else:
78+
systems = _compute_requested_neighbors_vesin(
79+
systems=systems,
80+
calculators=self._full_vesin_calculators,
81+
)
82+
83+
# always compute the half neighbor lists with vesin
6684
systems = _compute_requested_neighbors_vesin(
6785
systems=systems,
68-
requested_options=requested_options,
69-
check_consistency=check_consistency,
86+
calculators=self._half_vesin_calculators,
7087
)
7188

72-
return systems
89+
return systems
7390

7491

7592
def _compute_requested_neighbors_vesin(
7693
systems: List[System],
77-
requested_options: List[NeighborListOptions],
78-
check_consistency=False,
94+
calculators: List[vesin.metatomic.NeighborList],
7995
) -> List[System]:
80-
"""
81-
Compute all neighbor lists requested by ``model`` and store them inside the systems,
82-
using vesin.
83-
"""
84-
8596
system_devices = []
8697
moved_systems = []
8798
for system in systems:
@@ -91,12 +102,13 @@ def _compute_requested_neighbors_vesin(
91102
else:
92103
moved_systems.append(system)
93104

94-
vesin.metatomic.compute_requested_neighbors_from_options(
95-
systems=moved_systems,
96-
system_length_unit="angstrom",
97-
options=requested_options,
98-
check_consistency=check_consistency,
99-
)
105+
for calculator in calculators:
106+
calculator.add_neighbor_list(
107+
systems=moved_systems,
108+
# if we have more than one system, we can no keep the data as a reference
109+
# to memory allocated in the calculator and we need to make a copy
110+
copy=len(systems) > 1,
111+
)
100112

101113
systems = []
102114
for system, device in zip(moved_systems, system_devices, strict=True):
@@ -142,6 +154,7 @@ def _compute_requested_neighbors_nvalchemi(systems, requested_options):
142154
"cell_shift_c",
143155
],
144156
values=torch.hstack([P, S]),
157+
assume_unique=True,
145158
),
146159
components=[
147160
Labels("xyz", torch.tensor([[0], [1], [2]], device=system.device))

python/metatomic_torchsim/metatomic_torchsim/_model.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
pick_output,
2929
)
3030

31-
from ._neighbors import _compute_requested_neighbors
31+
from ._neighbors import AllNeighborsCalculator
3232

3333

3434
try:
@@ -249,7 +249,6 @@ def __init__(
249249
"be positive"
250250
)
251251

252-
self._requested_neighbor_lists = self._model.requested_neighbor_lists()
253252
self._requested_inputs = self._model.requested_inputs()
254253
if len(self._requested_inputs) != 0:
255254
raise ValueError(
@@ -283,6 +282,11 @@ def __init__(
283282
outputs=run_outputs,
284283
)
285284

285+
self._nl_calculators = AllNeighborsCalculator(
286+
requested_options=self._model.requested_neighbor_lists(),
287+
check_consistency=check_consistency,
288+
)
289+
286290
self.additional_outputs: Dict[str, TensorMap] = {}
287291
"""
288292
Additional outputs computed by :py:meth:`forward` are stored here.
@@ -355,11 +359,7 @@ def forward(self, state: "ts.SimState") -> Dict[str, torch.Tensor]:
355359
)
356360

357361
# Compute neighbor lists
358-
systems = _compute_requested_neighbors(
359-
systems=systems,
360-
requested_options=self._requested_neighbor_lists,
361-
check_consistency=self._check_consistency,
362-
)
362+
systems = self._nl_calculators.compute(systems=systems)
363363

364364
# Run the model (evaluation options precomputed in __init__)
365365
model_outputs = self._model(

0 commit comments

Comments
 (0)