Skip to content

Commit c690234

Browse files
Merge pull request mala-project#658 from elcorto/feature-small-improvements
Various small improvements
2 parents c897ceb + e577384 commit c690234

7 files changed

Lines changed: 50 additions & 42 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ cython_debug/
153153
*.out
154154
*.npy
155155
*.pkl
156+
*.pk
156157
*.pth
157158
*.json
158159

examples/advanced/ex01_checkpoint_training.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,10 @@ def initial_setup():
5757
data_handler.output_dimension,
5858
]
5959

60-
test_network = mala.Network(parameters)
61-
test_trainer = mala.Trainer(parameters, test_network, data_handler)
60+
network = mala.Network(parameters)
61+
trainer = mala.Trainer(parameters, network, data_handler)
6262

63-
return parameters, test_network, data_handler, test_trainer
63+
return parameters, network, data_handler, trainer
6464

6565

6666
if mala.Trainer.run_exists("ex01_checkpoint", path="./"):

examples/basic/ex01_train_network.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@
101101
100,
102102
data_handler.output_dimension,
103103
]
104-
test_network = mala.Network(parameters)
104+
network = mala.Network(parameters)
105105

106106
####################
107107
# 4. TRAINING THE NETWORK
@@ -113,9 +113,9 @@
113113
# side the model. This makes inference easier.
114114
####################
115115

116-
test_trainer = mala.Trainer(parameters, test_network, data_handler)
117-
test_trainer.train_network()
116+
trainer = mala.Trainer(parameters, network, data_handler)
117+
trainer.train_network()
118118
additional_calculation_data = os.path.join(data_path, "Be_snapshot0.out")
119-
test_trainer.save_run(
119+
trainer.save_run(
120120
"Be_model", additional_calculation_data=additional_calculation_data
121121
)

examples/basic/ex05_run_predictions.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from importlib.util import find_spec
23

34
from ase.io import read
45
import mala
@@ -39,7 +40,11 @@
3940
ldos_calculator: mala.LDOS = predictor.target_calculator
4041
ldos_calculator.read_from_array(ldos)
4142
printout("Predicted band energy: ", ldos_calculator.band_energy)
43+
4244
# If the total energy module is installed, the total energy can also be
4345
# calculated.
44-
# parameters.targets.pseudopotential_path = data_path
45-
# printout("Predicted total energy", ldos_calculator.total_energy)
46+
if find_spec("total_energy"):
47+
parameters.targets.pseudopotential_path = data_path
48+
printout("Predicted total energy", ldos_calculator.total_energy)
49+
else:
50+
print("total_energy module not found, skipping total energy calculation")

examples/clean.sh

Lines changed: 22 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,26 @@
11
#!/bin/sh
22

3+
set -u
4+
35
# Remove artifact files that some example scripts write.
46

5-
cd basic
6-
rm -rvf \
7-
*.pth \
8-
*.pkl \
9-
*.db \
10-
*.pw* \
11-
__pycache__ \
12-
*.cube \
13-
ex10_vis \
14-
*.tmp \
15-
*.npy \
16-
*.json \
17-
*.zip \
18-
Be_snapshot* \
19-
lammps*.tmp
20-
cd ..
21-
cd advanced
22-
rm -rvf \
23-
*.pth \
24-
*.pkl \
25-
*.db \
26-
*.pw* \
27-
__pycache__ \
28-
*.cube \
29-
ex10_vis \
30-
*.tmp \
31-
*.npy \
32-
*.json \
33-
*.zip \
34-
Be_snapshot* \
35-
lammps*.tmp
36-
cd ..
7+
here=$(dirname $(readlink -f $0))
8+
for dir in $here $(find $here -mindepth 1 -type d | grep -vE '\.ruff_cache|__pycache__'); do
9+
cd $dir
10+
echo "cleaning: $(pwd)"
11+
rm -rvf \
12+
*.pth \
13+
*.pkl \
14+
*.pk \
15+
*.db \
16+
*.pw* \
17+
*.cube \
18+
ex10_vis \
19+
*.tmp \
20+
*.npy \
21+
*.json \
22+
*.zip \
23+
Be_snapshot* \
24+
lammps*.tmp \
25+
mala_vis
26+
done

pyproject.toml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,18 @@
11
[tool.black]
22
line-length = 79
33

4+
[tool.ruff]
5+
lint.ignore = [
6+
"E731", # Do not assign a lambda expression, use a def
7+
]
8+
9+
[tool.mypy]
10+
ignore_missing_imports = true
11+
implicit_optional = true
12+
follow_imports = "skip"
13+
allow_untyped_calls = true
14+
allow_untyped_defs = true
15+
416
[build-system]
517
requires = ["setuptools"]
618
build-backend = "setuptools.build_meta"

test/clean.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22

33
# Remove artifact files that some example scripts write.
44

5-
rm -rv *.pth *.pkl ex09.db *.pw* __pycache__ *.cube ex10_vis *.tmp *.npy *.json *.h5 *.bp
5+
rm -rv *.pth *.pkl *.db *.pw* __pycache__ *.cube *_vis *.tmp *.npy *.json *.h5 *.bp *.db *.zip

0 commit comments

Comments
 (0)