@@ -85,8 +85,6 @@ def add_snapshot(
8585 snapshot_type = snapshot_type ,
8686 )
8787
88- if snapshot_type != "numpy" :
89- raise Exception ("Snapshot type must be numpy for LDOS alignment" )
9088
9189 def align_ldos_to_ref (
9290 self ,
@@ -136,6 +134,9 @@ def align_ldos_to_ref(
136134 comm = get_comm ()
137135 rank = comm .rank
138136 size = comm .size
137+ import sys
138+ print (
139+ "[mala.LDOSAligner.align_ldos_to_ref] Warning: MPI support is experimental. Use with caution." )
139140 else :
140141 comm = None
141142 rank = 0
@@ -146,13 +147,23 @@ def align_ldos_to_ref(
146147 snapshot_ref = self .parameters .snapshot_directories_list [
147148 reference_index
148149 ]
149- ldos_ref = np .load (
150- os .path .join (
151- snapshot_ref .output_npy_directory ,
152- snapshot_ref .output_npy_file ,
153- ),
154- mmap_mode = "r" ,
155- )
150+ if snapshot_ref .snapshot_type == 'numpy' :
151+ ldos_ref = self .target_calculator .read_from_numpy_file (
152+ os .path .join (
153+ snapshot_ref .output_npy_directory ,
154+ snapshot_ref .output_npy_file ,
155+ )
156+ )
157+ elif snapshot_ref .snapshot_type == 'openpmd' :
158+ ldos_ref = self .target_calculator .read_from_openpmd_file (
159+ os .path .join (
160+ snapshot_ref .output_npy_directory ,
161+ snapshot_ref .output_npy_file ,
162+ )
163+ )
164+ else :
165+ raise Exception (
166+ "Unknown snapshot type '{snapshot_ref.snapshot_type}'." )
156167
157168 # get the mean
158169 n_target = ldos_ref .shape [- 1 ]
@@ -205,13 +216,23 @@ def align_ldos_to_ref(
205216 for idx in local_snapshots :
206217 snapshot = self .parameters .snapshot_directories_list [idx ]
207218 print (f"Aligning snapshot { idx + 1 } of { N_snapshots } " )
208- ldos = np .load (
209- os .path .join (
210- snapshot .output_npy_directory ,
211- snapshot .output_npy_file ,
212- ),
213- mmap_mode = "r" ,
214- )
219+ if snapshot_ref .snapshot_type == 'numpy' :
220+ ldos = self .target_calculator .read_from_numpy_file (
221+ os .path .join (
222+ snapshot .output_npy_directory ,
223+ snapshot .output_npy_file ,
224+ )
225+ )
226+ elif snapshot_ref .snapshot_type == 'openpmd' :
227+ ldos = self .target_calculator .read_from_openpmd_file (
228+ os .path .join (
229+ snapshot .output_npy_directory ,
230+ snapshot .output_npy_file ,
231+ )
232+ )
233+ else :
234+ raise Exception (
235+ "Unknown snapshot type '{snapshot_ref.snapshot_type}'." )
215236
216237 # get the mean
217238 nx = ldos .shape [0 ]
@@ -279,30 +300,34 @@ def align_ldos_to_ref(
279300 snapshot .output_npy_directory , save_path_ext
280301 )
281302 save_name = snapshot .output_npy_file
282-
283- stripped_output_file_name = snapshot .output_npy_file .replace (
284- ".out" , ""
285- )
286- ldos_shift_info_save_name = stripped_output_file_name .replace (
287- ".npy" , ".ldos_shift.info.json"
288- )
289-
303+ target_name = os .path .join (save_path , save_name )
290304 os .makedirs (save_path , exist_ok = True )
291305
292- if "*" in save_name :
293- save_name = save_name .replace ("*" , str (idx ))
294- ldos_shift_info_save_name .replace ("*" , str (idx ))
306+ self .ldos_parameters .ldos_gridsize = ldos_shifted .shape [- 1 ]
295307
296- target_name = os .path .join (save_path , save_name )
308+ if snapshot .snapshot_type == 'numpy' :
309+ stripped_output_file_name = snapshot .output_npy_file .replace (
310+ ".out" , ""
311+ )
312+ ldos_shift_info_save_name = stripped_output_file_name .replace (
313+ ".npy" , ".ldos_shift.info.json"
314+ )
297315
298- self . target_calculator . write_to_numpy_file (
299- target_name , ldos_shifted
300- )
316+ if "*" in save_name :
317+ save_name = save_name . replace ( "*" , str ( idx ))
318+ ldos_shift_info_save_name . replace ( "*" , str ( idx ) )
301319
302- with open (
303- os .path .join (save_path , ldos_shift_info_save_name ), "w"
304- ) as f :
305- json .dump (ldos_shift_info , f , indent = 2 )
320+ self .target_calculator .write_to_numpy_file (
321+ target_name , ldos_shifted
322+ )
323+
324+ with open (
325+ os .path .join (save_path , ldos_shift_info_save_name ), "w"
326+ ) as f :
327+ json .dump (ldos_shift_info , f , indent = 2 )
328+ else :
329+ self .target_calculator .write_to_openpmd_file (
330+ target_name , ldos_shifted , additional_attributes = ldos_shift_info )
306331
307332 barrier ()
308333
0 commit comments