44from dataclasses import dataclass
55import functools
66import math
7+ import numpy as np
78import os
89from pathlib import Path
910import random
@@ -34,6 +35,9 @@ def zipf_weights(size):
3435class DistributionGenerator :
3536 root3 = math .sqrt (3 )
3637
38+ def __init__ (self ):
39+ self .np_gen = np .random .default_rng ()
40+
def uniform(self, low, high) -> float:
    """
    Draw one value uniformly from the interval between ``low`` and ``high``.

    :param low: one end of the interval; anything ``float()`` accepts
    :param high: the other end of the interval; anything ``float()`` accepts
    :return: a random float lying between the two bounds
    """
    lower_bound = float(low)
    upper_bound = float(high)
    return random.uniform(lower_bound, upper_bound)
3943
@@ -45,6 +49,9 @@ def uniform_ms(self, mean, sd) -> float:
def normal(self, mean, sd) -> float:
    """
    Draw one value from a normal (Gaussian) distribution.

    :param mean: centre of the distribution; anything ``float()`` accepts
    :param sd: standard deviation; anything ``float()`` accepts
    :return: a random float
    """
    mu = float(mean)
    sigma = float(sd)
    return random.normalvariate(mu, sigma)
4751
def lognormal(self, logmean, logsd) -> float:
    """
    Draw one value from a log-normal distribution.

    :param logmean: mean of the logarithm of the value; anything
        ``float()`` accepts
    :param logsd: standard deviation of the logarithm of the value;
        anything ``float()`` accepts
    :return: a random positive float
    """
    mu = float(logmean)
    sigma = float(logsd)
    return random.lognormvariate(mu, sigma)
54+
def choice(self, a):
    """
    Pick one element of ``a`` uniformly at random.

    :param a: a non-empty sequence of candidates
    :return: the picked element, or its ``"value"`` entry when the element
        is a plain ``dict`` containing that key
    """
    picked = random.choice(a)
    # Exact type check on purpose: only plain dicts (not subclasses)
    # are unwrapped.
    if type(picked) is not dict:
        return picked
    if "value" not in picked:
        return picked
    return picked["value"]
@@ -55,9 +62,68 @@ def zipf_choice(self, a, n=None):
5562 c = random .choices (a , weights = zipf_weights (n ))[0 ]
5663 return c ["value" ] if type (c ) is dict and "value" in c else c
5764
def weighted_choice(self, a: list[dict]) -> object:
    """
    Choice weighted by the count in the original dataset.

    :param a: a list of dicts, each with a ``value`` key holding the value
        to be returned and a ``count`` key holding the number of times
        that value was found in the original dataset. Entries whose
        ``count`` is missing or falsy (``0``, ``None``) are never chosen;
        an entry without a ``value`` key yields ``None``.
    :return: the ``value`` of one randomly selected entry
    :raises IndexError: if no entry has a usable count
    """
    candidates = [
        (vc.get("value"), vc["count"])
        for vc in a
        if vc.get("count", 0)
    ]
    if not candidates:
        # random.choices would fail on an empty population with an opaque
        # "list index out of range"; keep the exception type, say why.
        raise IndexError(
            "weighted_choice needs at least one entry with a non-zero count"
        )
    values, counts = zip(*candidates)
    return random.choices(values, weights=counts)[0]
81+
def constant(self, value):
    """Return ``value`` unchanged, whatever it is."""
    return value
6084
85+ def multivariate_normal_np (self , cov ):
86+ rank = int (cov ["rank" ])
87+ mean = [
88+ float (cov [f"m{ i } " ])
89+ for i in range (rank )
90+ ]
91+ covs = [
92+ [
93+ float (cov [f"c{ i } _{ j } " ] if i <= j else cov [f"c{ j } _{ i } " ])
94+ for i in range (rank )
95+ ]
96+ for j in range (rank )
97+ ]
98+ return self .np_gen .multivariate_normal (mean , covs )
99+
100+ def multivariate_normal (self , cov ):
101+ """
102+ Produce a list of values pulled from a multivariate distribution.
103+
104+ :param cov: A dict with various keys: ``rank`` is the number of
105+ output values, ``m0``, ``m1``, ... are the means of the
106+ distributions (``rank`` of them). ``c0_0``, ``c0_1``, ``c1_1``, ...
107+ are the covariates, ``cN_M`` is the covariate of the ``N``th and
108+ ``M``th varaibles, with 0 <= ``N`` <= ``M`` < ``rank``.
109+ :return: list of ``rank`` floating point values
110+ """
111+ return self .multivariate_normal_np (cov ).tolist ()
112+
113+ def multivariate_lognormal (self , cov ):
114+ """
115+ Produce a list of values pulled from a multivariate distribution.
116+
117+ :param cov: A dict with various keys: ``rank`` is the number of
118+ output values, ``m0``, ``m1``, ... are the means of the
119+ distributions (``rank`` of them). ``c0_0``, ``c0_1``, ``c1_1``, ...
120+ are the covariates, ``cN_M`` is the covariate of the ``N``th and
121+ ``M``th varaibles, with 0 <= ``N`` <= ``M`` < ``rank``. These
122+ are all the means and covariants of the logs of the data.
123+ :return: list of ``rank`` floating point values
124+ """
125+ return np .exp (self .multivariate_normal_np (cov )).tolist ()
126+
61127
62128class TableGenerator (ABC ):
63129 """Abstract base class for table generator classes."""
0 commit comments