Skip to content

Commit 4b52b85

Browse files
Merge pull request mala-project#653 from franzpoeschel/align-ldos-openpmd
Align LDOS: Read from and write to openPMD
2 parents ab58c79 + 8a2671e commit 4b52b85

3 files changed

Lines changed: 143 additions & 35 deletions

File tree

examples/advanced/ex09_align_ldos.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import os
22

3+
import numpy
4+
35
import mala
46

57
from mala.datahandling.data_repo import data_path
@@ -30,3 +32,31 @@
3032
ldos_aligner.align_ldos_to_ref(
3133
left_truncate=True, right_truncate_value=11, number_of_electrons=4
3234
)
35+
36+
# The same thing works also with openPMD-based data
37+
38+
try:
39+
import openpmd_api
40+
use_openpmd = True
41+
except ImportError:
42+
use_openpmd = False
43+
44+
if use_openpmd:
45+
# initialize and add snapshots to workflow
46+
ldos_aligner = mala.LDOSAligner(parameters)
47+
ldos_aligner.clear_data()
48+
ldos_aligner.add_snapshot("Be_snapshot0.out.h5",
49+
data_path, snapshot_type='openpmd')
50+
ldos_aligner.add_snapshot("Be_snapshot1.out.h5",
51+
data_path, snapshot_type='openpmd')
52+
ldos_aligner.add_snapshot("Be_snapshot2.out.h5",
53+
data_path, snapshot_type='openpmd')
54+
55+
# align and cut the snapshots from the left and right-hand sides
56+
ldos_aligner.align_ldos_to_ref(
57+
left_truncate=True, right_truncate_value=11, number_of_electrons=4
58+
)
59+
60+
# A test that checks for data equivalence between the Numpy-based
61+
# and openPMD-based implementations can be found under
62+
# test/align_ldos_test.py.

mala/datahandling/ldos_aligner.py

Lines changed: 60 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,6 @@ def add_snapshot(
8585
snapshot_type=snapshot_type,
8686
)
8787

88-
if snapshot_type != "numpy":
89-
raise Exception("Snapshot type must be numpy for LDOS alignment")
9088

9189
def align_ldos_to_ref(
9290
self,
@@ -136,6 +134,9 @@ def align_ldos_to_ref(
136134
comm = get_comm()
137135
rank = comm.rank
138136
size = comm.size
137+
import sys
138+
print(
139+
"[mala.LDOSAligner.align_ldos_to_ref] Warning: MPI support is experimental. Use with caution.")
139140
else:
140141
comm = None
141142
rank = 0
@@ -146,13 +147,23 @@ def align_ldos_to_ref(
146147
snapshot_ref = self.parameters.snapshot_directories_list[
147148
reference_index
148149
]
149-
ldos_ref = np.load(
150-
os.path.join(
151-
snapshot_ref.output_npy_directory,
152-
snapshot_ref.output_npy_file,
153-
),
154-
mmap_mode="r",
155-
)
150+
if snapshot_ref.snapshot_type == 'numpy':
151+
ldos_ref = self.target_calculator.read_from_numpy_file(
152+
os.path.join(
153+
snapshot_ref.output_npy_directory,
154+
snapshot_ref.output_npy_file,
155+
)
156+
)
157+
elif snapshot_ref.snapshot_type == 'openpmd':
158+
ldos_ref = self.target_calculator.read_from_openpmd_file(
159+
os.path.join(
160+
snapshot_ref.output_npy_directory,
161+
snapshot_ref.output_npy_file,
162+
)
163+
)
164+
else:
165+
raise Exception(
166+
"Unknown snapshot type '{snapshot_ref.snapshot_type}'.")
156167

157168
# get the mean
158169
n_target = ldos_ref.shape[-1]
@@ -205,13 +216,23 @@ def align_ldos_to_ref(
205216
for idx in local_snapshots:
206217
snapshot = self.parameters.snapshot_directories_list[idx]
207218
print(f"Aligning snapshot {idx+1} of {N_snapshots}")
208-
ldos = np.load(
209-
os.path.join(
210-
snapshot.output_npy_directory,
211-
snapshot.output_npy_file,
212-
),
213-
mmap_mode="r",
214-
)
219+
if snapshot_ref.snapshot_type == 'numpy':
220+
ldos = self.target_calculator.read_from_numpy_file(
221+
os.path.join(
222+
snapshot.output_npy_directory,
223+
snapshot.output_npy_file,
224+
)
225+
)
226+
elif snapshot_ref.snapshot_type == 'openpmd':
227+
ldos = self.target_calculator.read_from_openpmd_file(
228+
os.path.join(
229+
snapshot.output_npy_directory,
230+
snapshot.output_npy_file,
231+
)
232+
)
233+
else:
234+
raise Exception(
235+
"Unknown snapshot type '{snapshot_ref.snapshot_type}'.")
215236

216237
# get the mean
217238
nx = ldos.shape[0]
@@ -279,30 +300,34 @@ def align_ldos_to_ref(
279300
snapshot.output_npy_directory, save_path_ext
280301
)
281302
save_name = snapshot.output_npy_file
282-
283-
stripped_output_file_name = snapshot.output_npy_file.replace(
284-
".out", ""
285-
)
286-
ldos_shift_info_save_name = stripped_output_file_name.replace(
287-
".npy", ".ldos_shift.info.json"
288-
)
289-
303+
target_name = os.path.join(save_path, save_name)
290304
os.makedirs(save_path, exist_ok=True)
291305

292-
if "*" in save_name:
293-
save_name = save_name.replace("*", str(idx))
294-
ldos_shift_info_save_name.replace("*", str(idx))
306+
self.ldos_parameters.ldos_gridsize = ldos_shifted.shape[-1]
295307

296-
target_name = os.path.join(save_path, save_name)
308+
if snapshot.snapshot_type == 'numpy':
309+
stripped_output_file_name = snapshot.output_npy_file.replace(
310+
".out", ""
311+
)
312+
ldos_shift_info_save_name = stripped_output_file_name.replace(
313+
".npy", ".ldos_shift.info.json"
314+
)
297315

298-
self.target_calculator.write_to_numpy_file(
299-
target_name, ldos_shifted
300-
)
316+
if "*" in save_name:
317+
save_name = save_name.replace("*", str(idx))
318+
ldos_shift_info_save_name.replace("*", str(idx))
301319

302-
with open(
303-
os.path.join(save_path, ldos_shift_info_save_name), "w"
304-
) as f:
305-
json.dump(ldos_shift_info, f, indent=2)
320+
self.target_calculator.write_to_numpy_file(
321+
target_name, ldos_shifted
322+
)
323+
324+
with open(
325+
os.path.join(save_path, ldos_shift_info_save_name), "w"
326+
) as f:
327+
json.dump(ldos_shift_info, f, indent=2)
328+
else:
329+
self.target_calculator.write_to_openpmd_file(
330+
target_name, ldos_shifted, additional_attributes=ldos_shift_info)
306331

307332
barrier()
308333

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,56 @@ def test_dos_splitting(self):
8383
)
8484
)
8585
)
86+
87+
# This test is based on examples/advanced/ex09_align_ldos.py, but adds a
88+
# test to check that the aligned data is equivalent between the numpy-based
89+
# and the openPMD-based implementations.
90+
def test_ldos_alignment(self):
91+
parameters = mala.Parameters()
92+
parameters.targets.ldos_gridoffset_ev = -5
93+
parameters.targets.ldos_gridsize = 11
94+
parameters.targets.ldos_gridspacing_ev = 2.5
95+
96+
# initialize and add snapshots to workflow
97+
ldos_aligner = mala.LDOSAligner(parameters)
98+
ldos_aligner.clear_data()
99+
ldos_aligner.add_snapshot("Be_snapshot0.out.npy", data_path)
100+
ldos_aligner.add_snapshot("Be_snapshot1.out.npy", data_path)
101+
ldos_aligner.add_snapshot("Be_snapshot2.out.npy", data_path)
102+
103+
# align and cut the snapshots from the left and right-hand sides
104+
ldos_aligner.align_ldos_to_ref(
105+
left_truncate=True, right_truncate_value=11, number_of_electrons=4
106+
)
107+
108+
try:
109+
import openpmd_api
110+
use_openpmd = True
111+
except ImportError:
112+
use_openpmd = False
113+
114+
if use_openpmd:
115+
# initialize and add snapshots to workflow
116+
ldos_aligner = mala.LDOSAligner(parameters)
117+
ldos_aligner.clear_data()
118+
ldos_aligner.add_snapshot("Be_snapshot0.out.h5",
119+
data_path, snapshot_type='openpmd')
120+
ldos_aligner.add_snapshot("Be_snapshot1.out.h5",
121+
data_path, snapshot_type='openpmd')
122+
ldos_aligner.add_snapshot("Be_snapshot2.out.h5",
123+
data_path, snapshot_type='openpmd')
124+
125+
# align and cut the snapshots from the left and right-hand sides
126+
ldos_aligner.align_ldos_to_ref(
127+
left_truncate=True, right_truncate_value=11, number_of_electrons=4
128+
)
129+
130+
parameters = mala.Parameters()
131+
data_handler = mala.DataHandler(parameters)
132+
for i in range(1, 4):
133+
data_openpmd = data_handler.target_calculator.read_from_openpmd_file(
134+
f"{data_path}/aligned/Be_snapshot0.out.h5")
135+
data_numpy = data_handler.target_calculator.read_from_numpy_file(
136+
f"{data_path}/aligned/Be_snapshot0.out.npy")
137+
if not np.allclose(data_numpy, data_openpmd):
138+
raise Exception("Inconsistency in snapshot", i)

0 commit comments

Comments
 (0)