3131 HAS_NVALCHEMIOPS = False
3232
3333
34- def _compute_requested_neighbors (
35- systems : List [System ],
36- requested_options : List [NeighborListOptions ],
37- check_consistency = False ,
38- ) -> List [System ]:
39- """
40- Compute all neighbor lists requested by ``model`` and store them inside the systems.
41- """
42- can_use_nvalchemi = HAS_NVALCHEMIOPS and all (
43- system .device .type == "cuda" for system in systems
44- )
45-
46- if can_use_nvalchemi :
47- full_nl_options = []
48- half_nl_options = []
49- for options in requested_options :
50- if options .full_list :
51- full_nl_options .append (options )
52- else :
53- half_nl_options .append (options )
54-
55- # Do the full neighbor lists with nvalchemi, and the rest with vesin
56- systems = _compute_requested_neighbors_nvalchemi (
57- systems = systems ,
58- requested_options = full_nl_options ,
59- )
60- systems = _compute_requested_neighbors_vesin (
61- systems = systems ,
62- requested_options = half_nl_options ,
63- check_consistency = check_consistency ,
34+ class AllNeighborsCalculator :
35+ def __init__ (
36+ self ,
37+ requested_options : List [NeighborListOptions ],
38+ check_consistency = False ,
39+ ):
40+ self .check_consistency = check_consistency
41+ self ._full_nl_options = [
42+ options for options in requested_options if options .full_list
43+ ]
44+ self ._full_vesin_calculators = [
45+ vesin .metatomic .NeighborList (
46+ options = options ,
47+ length_unit = "angstrom" ,
48+ check_consistency = check_consistency ,
49+ )
50+ for options in requested_options
51+ if options .full_list
52+ ]
53+ self ._half_vesin_calculators = [
54+ vesin .metatomic .NeighborList (
55+ options = options ,
56+ length_unit = "angstrom" ,
57+ check_consistency = check_consistency ,
58+ )
59+ for options in requested_options
60+ if not options .full_list
61+ ]
62+
63+ def compute (self , systems : List [System ]) -> List [System ]:
64+ assert isinstance (systems , list )
65+ assert isinstance (systems [0 ], torch .ScriptObject )
66+
67+ can_use_nvalchemi = HAS_NVALCHEMIOPS and all (
68+ system .device .type == "cuda" for system in systems
6469 )
65- else :
70+
71+ if can_use_nvalchemi :
72+ # Do the full neighbor lists with nvalchemi
73+ systems = _compute_requested_neighbors_nvalchemi (
74+ systems = systems ,
75+ requested_options = self ._full_nl_options ,
76+ )
77+ else :
78+ systems = _compute_requested_neighbors_vesin (
79+ systems = systems ,
80+ calculators = self ._full_vesin_calculators ,
81+ )
82+
83+ # always compute the half neighbor lists with vesin
6684 systems = _compute_requested_neighbors_vesin (
6785 systems = systems ,
68- requested_options = requested_options ,
69- check_consistency = check_consistency ,
86+ calculators = self ._half_vesin_calculators ,
7087 )
7188
72- return systems
89+ return systems
7390
7491
7592def _compute_requested_neighbors_vesin (
7693 systems : List [System ],
77- requested_options : List [NeighborListOptions ],
78- check_consistency = False ,
94+ calculators : List [vesin .metatomic .NeighborList ],
7995) -> List [System ]:
80- """
81- Compute all neighbor lists requested by ``model`` and store them inside the systems,
82- using vesin.
83- """
84-
8596 system_devices = []
8697 moved_systems = []
8798 for system in systems :
@@ -91,12 +102,13 @@ def _compute_requested_neighbors_vesin(
91102 else :
92103 moved_systems .append (system )
93104
94- vesin .metatomic .compute_requested_neighbors_from_options (
95- systems = moved_systems ,
96- system_length_unit = "angstrom" ,
97- options = requested_options ,
98- check_consistency = check_consistency ,
99- )
105+ for calculator in calculators :
106+ calculator .add_neighbor_list (
107+ systems = moved_systems ,
108+ # if we have more than one system, we can no keep the data as a reference
109+ # to memory allocated in the calculator and we need to make a copy
110+ copy = len (systems ) > 1 ,
111+ )
100112
101113 systems = []
102114 for system , device in zip (moved_systems , system_devices , strict = True ):
@@ -142,6 +154,7 @@ def _compute_requested_neighbors_nvalchemi(systems, requested_options):
142154 "cell_shift_c" ,
143155 ],
144156 values = torch .hstack ([P , S ]),
157+ assume_unique = True ,
145158 ),
146159 components = [
147160 Labels ("xyz" , torch .tensor ([[0 ], [1 ], [2 ]], device = system .device ))
0 commit comments