Skip to content

Commit 55e4b95

Browse files
authored
Multi column (#56)
* test_prompts added for configure-generators * Refactored GeneratorCmd to allow multi-column generators: Fixes #54 * configure-generators merge and unmerge commands * multivariate normal and lognormal generator * Added (univariate) lognormal generator * Weighted choice generator
1 parent 0379322 commit 55e4b95

6 files changed

Lines changed: 1017 additions & 209 deletions

File tree

datafaker/base.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from dataclasses import dataclass
55
import functools
66
import math
7+
import numpy as np
78
import os
89
from pathlib import Path
910
import random
@@ -34,6 +35,9 @@ def zipf_weights(size):
3435
class DistributionGenerator:
3536
root3 = math.sqrt(3)
3637

38+
def __init__(self):
39+
self.np_gen = np.random.default_rng()
40+
3741
def uniform(self, low, high) -> float:
3842
return random.uniform(float(low), float(high))
3943

@@ -45,6 +49,9 @@ def uniform_ms(self, mean, sd) -> float:
4549
def normal(self, mean, sd) -> float:
4650
return random.normalvariate(float(mean), float(sd))
4751

52+
def lognormal(self, logmean, logsd) -> float:
53+
return random.lognormvariate(float(logmean), float(logsd))
54+
4855
def choice(self, a):
4956
c = random.choice(a)
5057
return c["value"] if type(c) is dict and "value" in c else c
@@ -55,9 +62,68 @@ def zipf_choice(self, a, n=None):
5562
c = random.choices(a, weights=zipf_weights(n))[0]
5663
return c["value"] if type(c) is dict and "value" in c else c
5764

65+
def weighted_choice(self, a: list[dict[str, any]]) -> list[any]:
66+
"""
67+
Choice weighted by the count in the original dataset.
68+
:param a: a list of dicts, each with a ``value`` key
69+
holding the value to be returned and a ``count`` key holding the
70+
number of that value found in the original dataset
71+
"""
72+
vs = []
73+
counts = []
74+
for vc in a:
75+
count = vc.get("count", 0)
76+
if count:
77+
counts.append(count)
78+
vs.append(vc.get("value", None))
79+
c = random.choices(vs, weights=counts)[0]
80+
return c
81+
5882
def constant(self, value):
5983
return value
6084

85+
def multivariate_normal_np(self, cov):
86+
rank = int(cov["rank"])
87+
mean = [
88+
float(cov[f"m{i}"])
89+
for i in range(rank)
90+
]
91+
covs = [
92+
[
93+
float(cov[f"c{i}_{j}"] if i <= j else cov[f"c{j}_{i}"])
94+
for i in range(rank)
95+
]
96+
for j in range(rank)
97+
]
98+
return self.np_gen.multivariate_normal(mean, covs)
99+
100+
def multivariate_normal(self, cov):
101+
"""
102+
Produce a list of values pulled from a multivariate distribution.
103+
104+
:param cov: A dict with various keys: ``rank`` is the number of
105+
output values, ``m0``, ``m1``, ... are the means of the
106+
distributions (``rank`` of them). ``c0_0``, ``c0_1``, ``c1_1``, ...
107+
are the covariates, ``cN_M`` is the covariate of the ``N``th and
108+
``M``th varaibles, with 0 <= ``N`` <= ``M`` < ``rank``.
109+
:return: list of ``rank`` floating point values
110+
"""
111+
return self.multivariate_normal_np(cov).tolist()
112+
113+
def multivariate_lognormal(self, cov):
114+
"""
115+
Produce a list of values pulled from a multivariate distribution.
116+
117+
:param cov: A dict with various keys: ``rank`` is the number of
118+
output values, ``m0``, ``m1``, ... are the means of the
119+
distributions (``rank`` of them). ``c0_0``, ``c0_1``, ``c1_1``, ...
120+
are the covariates, ``cN_M`` is the covariate of the ``N``th and
121+
``M``th varaibles, with 0 <= ``N`` <= ``M`` < ``rank``. These
122+
are all the means and covariants of the logs of the data.
123+
:return: list of ``rank`` floating point values
124+
"""
125+
return np.exp(self.multivariate_normal_np(cov)).tolist()
126+
61127

62128
class TableGenerator(ABC):
63129
"""Abstract base class for table generator classes."""

0 commit comments

Comments
 (0)