Skip to content

Commit f092e06

Browse files
Implemented caching for CG/W3j symbols
1 parent 06322f4 commit f092e06

2 files changed

Lines changed: 54 additions & 17 deletions

File tree

mala/common/parameters.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -585,9 +585,6 @@ def __init__(self):
585585

586586
# Other value could be "wigner3j".
587587
self.ace_coupling_coefficients_type = "clebsch_gordan"
588-
589-
# TODO: Implement a check in the ace.py class, so that symbols
590-
# are recomputed if a larger lmax is requested.
591588
self.ace_coupling_coefficients_maximum_l = 12
592589

593590
@property

mala/descriptors/ace.py

Lines changed: 54 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""ACE descriptor class."""
22

3+
import glob
34
import os
45
import itertools
56
import sys
@@ -807,18 +808,40 @@ def __init_wigner_3j(self, lmax):
807808
Parameters
808809
----------
809810
lmax : int
810-
Maximum l value. ACE_DOCS_MISSING: What is l?
811+
Maximum l value for precomputation.
811812
812813
Returns
813814
-------
814815
wigner : dict
815816
Dictionary of all Wigner 3j coefficients to be used at a given
816817
value of lmax.
817818
"""
819+
# We only recompute the coefficients if there is no file we can use.
820+
# A file we can use can either mean one for the exact same lmax or
821+
# one for a higher lmax.
822+
raw_list = glob.glob(
823+
os.path.join(
824+
os.path.dirname(os.path.abspath(__file__)),
825+
"wig_*.pkl",
826+
)
827+
)
828+
lmaxes = [
829+
int(os.path.basename(x.split("_")[1].split(".")[0]))
830+
for x in raw_list
831+
]
832+
if len(lmaxes) > 0:
833+
loaded_lmax = max(lmax, min(lmaxes))
834+
else:
835+
loaded_lmax = lmax
836+
837+
# We try to load the file and recompute, if necessary.
838+
file_name = os.path.join(
839+
os.path.dirname(os.path.abspath(__file__)),
840+
"wig_" + str(loaded_lmax) + ".pkl",
841+
)
842+
818843
try:
819-
with open(
820-
"%s/wig.pkl" % os.path.dirname(os.path.abspath(__file__)), "rb"
821-
) as readinwig:
844+
with open(file_name, "rb") as readinwig:
822845
cg = pickle.load(readinwig)
823846
except FileNotFoundError:
824847
cg = {}
@@ -839,9 +862,7 @@ def __init_wigner_3j(self, lmax):
839862
cg[key] = self._wigner_3j(
840863
l1, m1, l2, m2, l3, m3
841864
)
842-
with open(
843-
"%s/wig.pkl" % os.path.dirname(os.path.abspath(__file__)), "wb"
844-
) as writewig:
865+
with open(file_name, "wb") as writewig:
845866
pickle.dump(cg, writewig)
846867
return cg
847868

@@ -859,18 +880,39 @@ def __init_clebsch_gordan(self, lmax):
859880
Parameters
860881
----------
861882
lmax : int
862-
Maximum l value. ACE_DOCS_MISSING: What is l?
883+
Maximum l value for precomputation.
863884
864885
Returns
865886
-------
866887
cg : dict
867888
Dictionary of all Clebsch-Gordan coefficients to be used at a given
868889
value of lmax.
869890
"""
891+
# We only recompute the coefficients if there is no file we can use.
892+
# A file we can use can either mean one for the exact same lmax or
893+
# one for a higher lmax.
894+
raw_list = glob.glob(
895+
os.path.join(
896+
os.path.dirname(os.path.abspath(__file__)),
897+
"cg_*.pkl",
898+
)
899+
)
900+
lmaxes = [
901+
int(os.path.basename(x.split("_")[1].split(".")[0]))
902+
for x in raw_list
903+
]
904+
if len(lmaxes) > 0:
905+
loaded_lmax = max(lmax, min(lmaxes))
906+
else:
907+
loaded_lmax = lmax
908+
909+
# We try to load the file and recompute, if necessary.
910+
file_name = os.path.join(
911+
os.path.dirname(os.path.abspath(__file__)),
912+
"cg_" + str(loaded_lmax) + ".pkl",
913+
)
870914
try:
871-
with open(
872-
"%s/cg.pkl" % os.path.dirname(os.path.abspath(__file__)), "rb"
873-
) as readincg:
915+
with open(file_name, "rb") as readincg:
874916
cg = pickle.load(readincg)
875917
except FileNotFoundError:
876918
cg = {}
@@ -891,9 +933,7 @@ def __init_clebsch_gordan(self, lmax):
891933
cg[key] = self._clebsch_gordan(
892934
l1, m1, l2, m2, l3, m3
893935
)
894-
with open(
895-
"%s/cg.pkl" % os.path.dirname(os.path.abspath(__file__)), "wb"
896-
) as writecg:
936+
with open(file_name, "wb") as writecg:
897937
pickle.dump(cg, writecg)
898938
# pickle.dump(cg,'cg.pkl')
899939
return cg

0 commit comments

Comments
 (0)