88import numpy as np
99import numpy .typing as npt
1010import pandas as pd
11+ from lmfit import Parameter
1112from numpy .lib .scimath import sqrt
1213
1314from .. import dispersions
@@ -44,6 +45,14 @@ def _guard_invalid_params(params1, params2):
4445 missing_param_strings = ", " .join (f"{ p } " for p in missing_params )
4546 raise InvalidParameters (f"Invalid parameter(s): { missing_param_strings } " )
4647
48+ @staticmethod
49+ def _hash_params (params : dict | list [dict ]) -> int :
50+ """Creates an single_params_dict or the repeating_params_list."""
51+ if isinstance (params , list ):
52+ return hash (tuple ([self ._hash_params (dictionary ) for dictionary in params ]))
53+ else :
54+ return hash (tuple ([item for _ , item in params .items ()]))
55+
4756 @staticmethod
4857 def _fill_params_dict (template : dict , * args , ** kwargs ) -> dict :
4958 BaseDispersion ._guard_invalid_params (list (kwargs .keys ()), list (template .keys ()))
@@ -56,6 +65,8 @@ def _fill_params_dict(template: dict, *args, **kwargs) -> dict:
5665
5766 for i , val in enumerate (args ):
5867 key = list (template .keys ())[i ]
68+ if isinstance (val , Parameter ):
69+ val = val .value
5970 params [key ] = val
6071 pos_arguments .add (key )
6172
@@ -64,6 +75,8 @@ def _fill_params_dict(template: dict, *args, **kwargs) -> dict:
6475 raise InvalidParameters (
6576 f"Parameter { key } already set by positional argument"
6677 )
78+ if isinstance (value , Parameter ):
79+ value = value .value
6780 params [key ] = value
6881
6982 return params
@@ -80,6 +93,10 @@ def __init__(self, *args, **kwargs):
8093 if self .single_params [param ] is None :
8194 raise InvalidParameters (f"Please specify parameter { param } " )
8295
96+ self .last_lbda = None
97+ self .hash_single_params = None
98+ self .hash_rep_params = None
99+
83100 @abstractmethod
84101 def dielectric_function (self , lbda : npt .ArrayLike ) -> npt .NDArray :
85102 """Calculates the dielectric function in a given wavelength window.
@@ -114,6 +131,39 @@ def get_dielectric(self, lbda: Optional[npt.ArrayLike] = None) -> npt.NDArray:
114131 """Returns the dielectric constant for wavelength 'lbda' default unit (nm)
115132 in the convention ε1 + iε2."""
116133 lbda = self .default_lbda_range if lbda is None else lbda
134+
135+ from .table_epsilon import TableEpsilon
136+ from .table_index import Table
137+
138+ if not isinstance (self , (DispersionSum , IndexDispersionSum )):
139+ if isinstance (self , (TableEpsilon , Table )):
140+ if self .last_lbda is lbda :
141+ return self .cached_diel
142+ else :
143+ self .last_lbda = lbda
144+ self .cached_diel = np .asarray (
145+ self .dielectric_function (lbda ), dtype = np .complex128
146+ )
147+ return self .cached_diel
148+ else :
149+ new_single_hash = self ._hash_params (self .single_params )
150+ new_rep_hash = self ._hash_params (self .rep_params )
151+
152+ if (
153+ self .last_lbda is lbda
154+ and self .hash_single_params == new_single_hash
155+ and self .hash_rep_params == new_rep_hash
156+ ):
157+ return self .cached_diel
158+ else :
159+ self .last_lbda = lbda
160+ self .hash_single_params = new_single_hash
161+ self .hash_rep_params = new_rep_hash
162+ self .cached_diel = np .asarray (
163+ self .dielectric_function (lbda ), dtype = np .complex128
164+ )
165+ return self .cached_diel
166+
117167 return np .asarray (self .dielectric_function (lbda ), dtype = np .complex128 )
118168
119169 def get_refractive_index (self , lbda : Optional [npt .ArrayLike ] = None ) -> npt .NDArray :
0 commit comments