|
| 1 | +import copy |
1 | 2 | import importlib |
2 | 3 | import os |
3 | 4 |
|
@@ -459,6 +460,41 @@ def test_predictions(self): |
459 | 460 | atol=accuracy_strict, |
460 | 461 | ) |
461 | 462 |
|
| 463 | + def test_model_copying(self): |
| 464 | + parameters, network, data_handler, tester = mala.Tester.load_run( |
| 465 | + "Be_model", path=data_path |
| 466 | + ) |
| 467 | + data_handler.add_snapshot( |
| 468 | + "Be_snapshot3.in.npy", |
| 469 | + data_path, |
| 470 | + "Be_snapshot3.out.npy", |
| 471 | + data_path, |
| 472 | + "te", |
| 473 | + ) |
| 474 | + data_handler.prepare_data(reparametrize_scaler=False) |
| 475 | + |
| 476 | + actual_ldos, predicted_ldos = tester.predict_targets(0) |
| 477 | + ldos_calculator = data_handler.target_calculator |
| 478 | + ldos_calculator.read_additional_calculation_data( |
| 479 | + os.path.join(data_path, "Be_snapshot3.out"), "espresso-out" |
| 480 | + ) |
| 481 | + band_energy_1 = ldos_calculator.get_band_energy(predicted_ldos) |
| 482 | + |
| 483 | + copied_network = copy.deepcopy(network) |
| 484 | + tester.network = copied_network |
| 485 | + |
| 486 | + actual_ldos, predicted_ldos = tester.predict_targets(0) |
| 487 | + ldos_calculator = data_handler.target_calculator |
| 488 | + ldos_calculator.read_additional_calculation_data( |
| 489 | + os.path.join(data_path, "Be_snapshot3.out"), "espresso-out" |
| 490 | + ) |
| 491 | + band_energy_2 = ldos_calculator.get_band_energy(predicted_ldos) |
| 492 | + assert np.isclose( |
| 493 | + band_energy_1, |
| 494 | + band_energy_2, |
| 495 | + atol=accuracy_strict, |
| 496 | + ) |
| 497 | + |
462 | 498 | @pytest.mark.skipif( |
463 | 499 | importlib.util.find_spec("total_energy") is None |
464 | 500 | or importlib.util.find_spec("lammps") is None, |
|
0 commit comments