Skip to content

Commit a4aaca1

Browse files
Added test
1 parent cb704fb commit a4aaca1

1 file changed

Lines changed: 36 additions & 0 deletions

File tree

test/workflow_test.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
import importlib
23
import os
34

@@ -459,6 +460,41 @@ def test_predictions(self):
459460
atol=accuracy_strict,
460461
)
461462

463+
def test_model_copying(self):
464+
parameters, network, data_handler, tester = mala.Tester.load_run(
465+
"Be_model", path=data_path
466+
)
467+
data_handler.add_snapshot(
468+
"Be_snapshot3.in.npy",
469+
data_path,
470+
"Be_snapshot3.out.npy",
471+
data_path,
472+
"te",
473+
)
474+
data_handler.prepare_data(reparametrize_scaler=False)
475+
476+
actual_ldos, predicted_ldos = tester.predict_targets(0)
477+
ldos_calculator = data_handler.target_calculator
478+
ldos_calculator.read_additional_calculation_data(
479+
os.path.join(data_path, "Be_snapshot3.out"), "espresso-out"
480+
)
481+
band_energy_1 = ldos_calculator.get_band_energy(predicted_ldos)
482+
483+
copied_network = copy.deepcopy(network)
484+
tester.network = copied_network
485+
486+
actual_ldos, predicted_ldos = tester.predict_targets(0)
487+
ldos_calculator = data_handler.target_calculator
488+
ldos_calculator.read_additional_calculation_data(
489+
os.path.join(data_path, "Be_snapshot3.out"), "espresso-out"
490+
)
491+
band_energy_2 = ldos_calculator.get_band_energy(predicted_ldos)
492+
assert np.isclose(
493+
band_energy_1,
494+
band_energy_2,
495+
atol=accuracy_strict,
496+
)
497+
462498
@pytest.mark.skipif(
463499
importlib.util.find_spec("total_energy") is None
464500
or importlib.util.find_spec("lammps") is None,

0 commit comments

Comments
 (0)