@@ -31,6 +31,38 @@ def zipf_weights(size):
3131 for n in range (1 , size + 1 )
3232 ]
3333
34+ def merge_with_constants (xs : list , constants_at : dict [int , any ]):
35+ """
36+ Merge a list of items with other items that must be placed at certain indices.
37+ :param constants_at: A map of indices to objects that must be placed at
38+ those indices.
39+ :param xs: Items that fill in the gaps left by ``constants_at``.
40+ :return: ``xs`` with ``constants_at`` inserted at the appropriate
41+ points. If there are not enough elements in ``xs`` to fill in the gaps
42+ in ``constants_at``, the elements of ``constants_at`` after the gap
43+ are dropped.
44+ """
45+ outi = 0
46+ xi = 0
47+ constant_count = len (constants_at )
48+ while constant_count != 0 :
49+ if outi in constants_at :
50+ yield constants_at [outi ]
51+ constant_count -= 1
52+ else :
53+ if xi == len (xs ):
54+ return
55+ yield xs [xi ]
56+ xi += 1
57+ outi += 1
58+ for x in xs [xi :]:
59+ yield x
60+
61+
62+ class NothingToGenerateException (Exception ):
63+ def __init__ (self , message ):
64+ super ().__init__ (message )
65+
3466
3567class DistributionGenerator :
3668 root3 = math .sqrt (3 )
@@ -84,6 +116,8 @@ def constant(self, value):
84116
85117 def multivariate_normal_np (self , cov ):
86118 rank = int (cov ["rank" ])
119+ if rank == 0 :
120+ return np .empty (shape = (0 ,))
87121 mean = [
88122 float (cov [f"m{ i } " ])
89123 for i in range (rank )
@@ -97,6 +131,48 @@ def multivariate_normal_np(self, cov):
97131 ]
98132 return self .np_gen .multivariate_normal (mean , covs )
99133
134+ def _select_group (self , alts : list [dict [str , any ]]):
135+ """
136+ Choose one of the ``alts`` weighted by their ``"count"`` elements.
137+ """
138+ total = 0
139+ for alt in alts :
140+ if alt ["count" ] < 0 :
141+ logger .warning ("Alternative count is %d, but should not be negative" , alt ["count" ])
142+ else :
143+ total += alt ["count" ]
144+ if total == 0 :
145+ raise NothingToGenerateException ("No counts in any alternative" )
146+ choice = random .randrange (total )
147+ for alt in alts :
148+ choice -= alt ["count" ]
149+ if choice < 0 :
150+ return alt
151+ raise Exception ("Internal error: ran out of choices in _select_group" )
152+
153+ def _find_constants (self , result : dict [str , any ]):
154+ """
155+ Find all keys ``kN``, returning a dictionary of ``N: kNN``.
156+
157+ This can be passed into ``merge_with_constants`` as the
158+ ``constants_at`` argument.
159+ """
160+ out : dict [int , any ] = {}
161+ for k , v in result .items ():
162+ if k .startswith ("k" ) and k [1 :].isnumeric ():
163+ out [int (k [1 :])] = v
164+ return out
165+
166+ PERMITTED_SUBGENS = {
167+ "multivariate_lognormal" ,
168+ "multivariate_normal" ,
169+ "grouped_multivariate_lognormal" ,
170+ "grouped_multivariate_normal" ,
171+ "constant" ,
172+ "weighted_choice" ,
173+ "with_constants_at" ,
174+ }
175+
100176 def multivariate_normal (self , cov ):
101177 """
102178 Produce a list of values pulled from a multivariate distribution.
@@ -124,6 +200,69 @@ def multivariate_lognormal(self, cov):
124200 """
125201 return np .exp (self .multivariate_normal_np (cov )).tolist ()
126202
203+ def grouped_multivariate_normal (self , covs ):
204+ cov = self ._select_group (covs )
205+ logger .debug ("Multivariate normal group selected: %s" , cov )
206+ constants = self ._find_constants (cov )
207+ nums = self .multivariate_normal (cov )
208+ return list (merge_with_constants (nums , constants ))
209+
210+ def grouped_multivariate_lognormal (self , covs ):
211+ cov = self ._select_group (covs )
212+ logger .debug ("Multivariate lognormal group selected: %s" , cov )
213+ constants = self ._find_constants (cov )
214+ nums = np .exp (self .multivariate_normal_np (cov )).tolist ()
215+ return list (merge_with_constants (nums , constants ))
216+
217+ def _check_generator_name (self , name : str ) -> None :
218+ if name not in self .PERMITTED_SUBGENS :
219+ raise Exception ("%s is not a permitted generator" , name )
220+
221+ def alternatives (self , alternative_configs : list [dict [str , any ]], counts : list [int ] | None ):
222+ """
223+ A generator that picks between other generators.
224+
225+ :param alternative_configs: List of alternative generators.
226+ Each alternative has the following keys: "count" -- a weight for
227+ how often to use this alternative; "name" -- which generator
228+ for this partition, for example "composite"; "params" -- the
229+ parameters for this alternative.
230+ :return: list of values
231+ """
232+ if counts is not None :
233+ while True :
234+ count = self ._select_group (counts )
235+ alt = alternative_configs [count ["index" ]]
236+ name = alt ["name" ]
237+ self ._check_generator_name (name )
238+ try :
239+ return getattr (self , name )(** alt ["params" ])
240+ except NothingToGenerateException :
241+ # Prevent this alternative from being chosen again
242+ count ["count" ] = 0
243+ alt = self ._select_group (alternative_configs )
244+ name = alt ["name" ]
245+ self ._check_generator_name (name )
246+ return getattr (self , name )(** alt ["params" ])
247+
248+ def with_constants_at (self , constants_at : list [int ], subgen : str , params : dict [str , any ]):
249+ if subgen not in self .PERMITTED_SUBGENS :
250+ logger .error (
251+ "subgenerator %s is not a valid name. Valid names are %s." ,
252+ subgen ,
253+ self .PERMITTED_SUBGENS ,
254+ )
255+ subout = getattr (self , subgen )(** params )
256+ logger .debug ("Merging constants %s" , constants_at )
257+ return list (merge_with_constants (subout , constants_at ))
258+
259+ def truncated_string (self , subgen_fn , params , length ):
260+ """ Calls ``subgen_fn(**params)`` and truncates the results to ``length``. """
261+ result = subgen_fn (** params )
262+ if result is None :
263+ return None
264+ return result [:length ]
265+
127266
128267class TableGenerator (ABC ):
129268 """Abstract base class for table generator classes."""
0 commit comments