Skip to content

Commit f6ee44b

Browse files
Fixed wrong imports
1 parent 65ebdb8 commit f6ee44b

2 files changed

Lines changed: 102 additions & 80 deletions

File tree

test/align_ldos_test.py

Lines changed: 90 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
import numpy as np
44

55
import mala
6-
from mala.datahandling.data_repo import data_path_be
6+
from mala.datahandling.data_repo import data_path_be, data_path_bao
7+
8+
accuracy_band_energy = 1
9+
accuracy_strict = 1e-16
710

811

912
class TestSplitting:
@@ -96,9 +99,9 @@ def test_ldos_alignment(self):
9699
# initialize and add snapshots to workflow
97100
ldos_aligner = mala.LDOSAligner(parameters)
98101
ldos_aligner.clear_data()
99-
ldos_aligner.add_snapshot("Be_snapshot0.out.npy", data_path)
100-
ldos_aligner.add_snapshot("Be_snapshot1.out.npy", data_path)
101-
ldos_aligner.add_snapshot("Be_snapshot2.out.npy", data_path)
102+
ldos_aligner.add_snapshot("Be_snapshot0.out.npy", data_path_be)
103+
ldos_aligner.add_snapshot("Be_snapshot1.out.npy", data_path_be)
104+
ldos_aligner.add_snapshot("Be_snapshot2.out.npy", data_path_be)
102105

103106
# align and cut the snapshots from the left and right-hand sides
104107
ldos_aligner.align_ldos_to_ref(
@@ -107,6 +110,7 @@ def test_ldos_alignment(self):
107110

108111
try:
109112
import openpmd_api
113+
110114
use_openpmd = True
111115
except ImportError:
112116
use_openpmd = False
@@ -115,24 +119,95 @@ def test_ldos_alignment(self):
115119
# initialize and add snapshots to workflow
116120
ldos_aligner = mala.LDOSAligner(parameters)
117121
ldos_aligner.clear_data()
118-
ldos_aligner.add_snapshot("Be_snapshot0.out.h5",
119-
data_path, snapshot_type='openpmd')
120-
ldos_aligner.add_snapshot("Be_snapshot1.out.h5",
121-
data_path, snapshot_type='openpmd')
122-
ldos_aligner.add_snapshot("Be_snapshot2.out.h5",
123-
data_path, snapshot_type='openpmd')
122+
ldos_aligner.add_snapshot(
123+
"Be_snapshot0.out.h5", data_path_be, snapshot_type="openpmd"
124+
)
125+
ldos_aligner.add_snapshot(
126+
"Be_snapshot1.out.h5", data_path_be, snapshot_type="openpmd"
127+
)
128+
ldos_aligner.add_snapshot(
129+
"Be_snapshot2.out.h5", data_path_be, snapshot_type="openpmd"
130+
)
124131

125132
# align and cut the snapshots from the left and right-hand sides
126133
ldos_aligner.align_ldos_to_ref(
127-
left_truncate=True, right_truncate_value=11, number_of_electrons=4
134+
left_truncate=True,
135+
right_truncate_value=11,
136+
number_of_electrons=4,
128137
)
129138

130139
parameters = mala.Parameters()
131140
data_handler = mala.DataHandler(parameters)
132141
for i in range(1, 4):
133-
data_openpmd = data_handler.target_calculator.read_from_openpmd_file(
134-
f"{data_path}/aligned/Be_snapshot0.out.h5")
135-
data_numpy = data_handler.target_calculator.read_from_numpy_file(
136-
f"{data_path}/aligned/Be_snapshot0.out.npy")
142+
data_openpmd = (
143+
data_handler.target_calculator.read_from_openpmd_file(
144+
f"{data_path_be}/aligned/Be_snapshot0.out.h5"
145+
)
146+
)
147+
data_numpy = (
148+
data_handler.target_calculator.read_from_numpy_file(
149+
f"{data_path_be}/aligned/Be_snapshot0.out.npy"
150+
)
151+
)
137152
if not np.allclose(data_numpy, data_openpmd):
138153
raise Exception("Inconsistency in snapshot", i)
154+
155+
def test_ldos_splitting_multiple_elements(self):
156+
"""
157+
Test that the LDOS splitting works both on LDOS and DOS level.
158+
159+
We compute the band energy with splitted and unsplitted DOS and
160+
compare both to splitted LDOS band energy.
161+
"""
162+
params = mala.Parameters()
163+
164+
params.targets.ldos_gridsize = [12, 13, 14, 28]
165+
params.targets.ldos_gridspacing_ev = [0.5, 0.5, 0.5, 0.5]
166+
params.targets.ldos_gridoffset_ev = [-19, -10.5, -4.5, 3.5]
167+
params.targets.pseudopotential_path = "."
168+
169+
dos_calculator = mala.DOS(params)
170+
dos_calculator.read_additional_calculation_data(
171+
os.path.join(data_path_bao, "BaO_snapshot0.out")
172+
)
173+
dos_calculator.read_from_qe_out(
174+
os.path.join(data_path_bao, "BaO_snapshot0.out"),
175+
smearing_factor=[2, 2, 2, 2],
176+
)
177+
178+
params2 = mala.Parameters()
179+
params2.targets.ldos_gridsize = 73
180+
params2.targets.ldos_gridspacing_ev = 0.5
181+
params2.targets.ldos_gridoffset_ev = -19
182+
params2.targets.pseudopotential_path = "."
183+
dos_calculator_unsplitted = mala.DOS(params2)
184+
dos_calculator_unsplitted.read_additional_calculation_data(
185+
os.path.join(data_path_bao, "BaO_snapshot0.out")
186+
)
187+
188+
dos_calculator_unsplitted.read_from_qe_out(
189+
os.path.join(data_path_bao, "BaO_snapshot0.out"),
190+
smearing_factor=2,
191+
)
192+
193+
params3 = mala.Parameters()
194+
params3.targets.ldos_gridsize = [12, 13, 14, 28]
195+
params3.targets.ldos_gridspacing_ev = [0.5, 0.5, 0.5, 0.5]
196+
params3.targets.ldos_gridoffset_ev = [-19, -10.5, -4.5, 3.5]
197+
198+
ldos_calculator = mala.LDOS.from_numpy_file(
199+
params3, os.path.join(data_path_bao, "BaO_snapshot0.out.npy")
200+
)
201+
ldos_calculator.read_additional_calculation_data(
202+
os.path.join(data_path_bao, "BaO_snapshot0.info.json")
203+
)
204+
assert np.isclose(
205+
dos_calculator.band_energy,
206+
ldos_calculator.band_energy,
207+
atol=accuracy_strict,
208+
)
209+
assert np.isclose(
210+
dos_calculator.band_energy,
211+
dos_calculator_unsplitted.band_energy,
212+
atol=accuracy_band_energy,
213+
)

test/workflow_test.py

Lines changed: 12 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,13 @@ def test_preprocessing(self):
120120
"correct_output_shape": (48, 48, 48, 64),
121121
}
122122

123+
# Multiple elements only work with LAMMPS.
124+
systems = (
125+
["Be", "BaO"]
126+
if importlib.util.find_spec("lammps") is not None
127+
else ["Be"]
128+
)
129+
123130
for system in ["Be", "BaO"]:
124131
configuration = (
125132
configuration_be if system == "Be" else configuration_bao
@@ -314,66 +321,6 @@ def test_postprocessing_from_dos(self):
314321
atol=accuracy_band_energy,
315322
)
316323

317-
def test_ldos_splitting(self):
318-
"""
319-
Test that the LDOS splitting works both on LDOS and DOS level.
320-
321-
We compute the band energy with splitted and unsplitted DOS and
322-
compare both to splitted LDOS band energy.
323-
"""
324-
params = mala.Parameters()
325-
326-
params.targets.ldos_gridsize = [12, 13, 14, 28]
327-
params.targets.ldos_gridspacing_ev = [0.5, 0.5, 0.5, 0.5]
328-
params.targets.ldos_gridoffset_ev = [-19, -10.5, -4.5, 3.5]
329-
params.targets.pseudopotential_path = "."
330-
331-
dos_calculator = mala.DOS(params)
332-
dos_calculator.read_additional_calculation_data(
333-
os.path.join(data_path_bao, "BaO_snapshot0.out")
334-
)
335-
dos_calculator.read_from_qe_out(
336-
os.path.join(data_path_bao, "BaO_snapshot0.out"),
337-
smearing_factor=[2, 2, 2, 2],
338-
)
339-
340-
params2 = mala.Parameters()
341-
params2.targets.ldos_gridsize = 73
342-
params2.targets.ldos_gridspacing_ev = 0.5
343-
params2.targets.ldos_gridoffset_ev = -19
344-
params2.targets.pseudopotential_path = "."
345-
dos_calculator_unsplitted = mala.DOS(params2)
346-
dos_calculator_unsplitted.read_additional_calculation_data(
347-
os.path.join(data_path_bao, "BaO_snapshot0.out")
348-
)
349-
350-
dos_calculator_unsplitted.read_from_qe_out(
351-
os.path.join(data_path_bao, "BaO_snapshot0.out"),
352-
smearing_factor=2,
353-
)
354-
355-
params3 = mala.Parameters()
356-
params3.targets.ldos_gridsize = [12, 13, 14, 28]
357-
params3.targets.ldos_gridspacing_ev = [0.5, 0.5, 0.5, 0.5]
358-
params3.targets.ldos_gridoffset_ev = [-19, -10.5, -4.5, 3.5]
359-
360-
ldos_calculator = mala.LDOS.from_numpy_file(
361-
params3, os.path.join(data_path_bao, "BaO_snapshot0.out.npy")
362-
)
363-
ldos_calculator.read_additional_calculation_data(
364-
os.path.join(data_path_bao, "BaO_snapshot0.info.json")
365-
)
366-
assert np.isclose(
367-
dos_calculator.band_energy,
368-
ldos_calculator.band_energy,
369-
atol=accuracy_strict,
370-
)
371-
assert np.isclose(
372-
dos_calculator.band_energy,
373-
dos_calculator_unsplitted.band_energy,
374-
atol=accuracy_band_energy,
375-
)
376-
377324
def test_postprocessing(self):
378325
"""
379326
Test whether MALA can postprocess data (from LDOS)
@@ -671,13 +618,13 @@ def test_predictions(self):
671618

672619
def test_model_copying(self):
673620
parameters, network, data_handler, tester = mala.Tester.load_run(
674-
"Be_model", path=data_path
621+
"Be_model", path=data_path_be
675622
)
676623
data_handler.add_snapshot(
677624
"Be_snapshot3.in.npy",
678-
data_path,
625+
data_path_be,
679626
"Be_snapshot3.out.npy",
680-
data_path,
627+
data_path_be,
681628
"te",
682629
)
683630
parameters.manual_seed = 123456252
@@ -686,7 +633,7 @@ def test_model_copying(self):
686633
actual_ldos, predicted_ldos = tester.predict_targets(0)
687634
ldos_calculator = data_handler.target_calculator
688635
ldos_calculator.read_additional_calculation_data(
689-
os.path.join(data_path, "Be_snapshot3.out"), "espresso-out"
636+
os.path.join(data_path_be, "Be_snapshot3.out"), "espresso-out"
690637
)
691638
band_energy_1 = ldos_calculator.get_band_energy(predicted_ldos)
692639

@@ -696,7 +643,7 @@ def test_model_copying(self):
696643
actual_ldos, predicted_ldos = tester.predict_targets(0)
697644
ldos_calculator = data_handler.target_calculator
698645
ldos_calculator.read_additional_calculation_data(
699-
os.path.join(data_path, "Be_snapshot3.out"), "espresso-out"
646+
os.path.join(data_path_be, "Be_snapshot3.out"), "espresso-out"
700647
)
701648
band_energy_2 = ldos_calculator.get_band_energy(predicted_ldos)
702649
print(

0 commit comments

Comments
 (0)