Skip to content

Commit 5ccec12

Browse files
tim-bandTim Band
andauthored
Entity-Attribute Values (#58)
* null-partitioned grouped lognormal plus sampled and suppressed * VARCHAR(N) generators truncate results * Updated health_data documentation * #59 Foreign Keys to ignored tables supported Co-authored-by: Tim Band <t.b@ucl>
1 parent 55e4b95 commit 5ccec12

15 files changed

Lines changed: 1706 additions & 731 deletions

Dockerfile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
FROM python:3.13.3-alpine3.21
1+
FROM python:3.13.3-alpine3.22
22
RUN apk add bash poetry
33
WORKDIR /app
44
ADD . /app
@@ -11,4 +11,4 @@ SHELL ["/bin/bash", "-c"]
1111
# The redirect to /dev/null seems to help shellingham detect bash!
1212
RUN poetry run datafaker --install-completion > /dev/null
1313
WORKDIR /data
14-
CMD ["poetry", "--directory=/app", "shell"]
14+
CMD ["bash", "-c", "source $(poetry -C /app env info --path)/bin/activate;bash"]

datafaker/base.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,38 @@ def zipf_weights(size):
3131
for n in range(1, size + 1)
3232
]
3333

34+
def merge_with_constants(xs: list, constants_at: dict[int, any]):
35+
"""
36+
Merge a list of items with other items that must be placed at certain indices.
37+
:param constants_at: A map of indices to objects that must be placed at
38+
those indices.
39+
:param xs: Items that fill in the gaps left by ``constants_at``.
40+
:return: ``xs`` with ``constants_at`` inserted at the appropriate
41+
points. If there are not enough elements in ``xs`` to fill in the gaps
42+
in ``constants_at``, the elements of ``constants_at`` after the gap
43+
are dropped.
44+
"""
45+
outi = 0
46+
xi = 0
47+
constant_count = len(constants_at)
48+
while constant_count != 0:
49+
if outi in constants_at:
50+
yield constants_at[outi]
51+
constant_count -= 1
52+
else:
53+
if xi == len(xs):
54+
return
55+
yield xs[xi]
56+
xi += 1
57+
outi += 1
58+
for x in xs[xi:]:
59+
yield x
60+
61+
62+
class NothingToGenerateException(Exception):
63+
def __init__(self, message):
64+
super().__init__(message)
65+
3466

3567
class DistributionGenerator:
3668
root3 = math.sqrt(3)
@@ -84,6 +116,8 @@ def constant(self, value):
84116

85117
def multivariate_normal_np(self, cov):
86118
rank = int(cov["rank"])
119+
if rank == 0:
120+
return np.empty(shape=(0,))
87121
mean = [
88122
float(cov[f"m{i}"])
89123
for i in range(rank)
@@ -97,6 +131,48 @@ def multivariate_normal_np(self, cov):
97131
]
98132
return self.np_gen.multivariate_normal(mean, covs)
99133

134+
def _select_group(self, alts: list[dict[str, any]]):
135+
"""
136+
Choose one of the ``alts`` weighted by their ``"count"`` elements.
137+
"""
138+
total = 0
139+
for alt in alts:
140+
if alt["count"] < 0:
141+
logger.warning("Alternative count is %d, but should not be negative", alt["count"])
142+
else:
143+
total += alt["count"]
144+
if total == 0:
145+
raise NothingToGenerateException("No counts in any alternative")
146+
choice = random.randrange(total)
147+
for alt in alts:
148+
choice -= alt["count"]
149+
if choice < 0:
150+
return alt
151+
raise Exception("Internal error: ran out of choices in _select_group")
152+
153+
def _find_constants(self, result: dict[str, any]):
154+
"""
155+
Find all keys ``kN``, returning a dictionary of ``N: kNN``.
156+
157+
This can be passed into ``merge_with_constants`` as the
158+
``constants_at`` argument.
159+
"""
160+
out: dict[int, any] = {}
161+
for k, v in result.items():
162+
if k.startswith("k") and k[1:].isnumeric():
163+
out[int(k[1:])] = v
164+
return out
165+
166+
PERMITTED_SUBGENS = {
167+
"multivariate_lognormal",
168+
"multivariate_normal",
169+
"grouped_multivariate_lognormal",
170+
"grouped_multivariate_normal",
171+
"constant",
172+
"weighted_choice",
173+
"with_constants_at",
174+
}
175+
100176
def multivariate_normal(self, cov):
101177
"""
102178
Produce a list of values pulled from a multivariate distribution.
@@ -124,6 +200,69 @@ def multivariate_lognormal(self, cov):
124200
"""
125201
return np.exp(self.multivariate_normal_np(cov)).tolist()
126202

203+
def grouped_multivariate_normal(self, covs):
204+
cov = self._select_group(covs)
205+
logger.debug("Multivariate normal group selected: %s", cov)
206+
constants = self._find_constants(cov)
207+
nums = self.multivariate_normal(cov)
208+
return list(merge_with_constants(nums, constants))
209+
210+
def grouped_multivariate_lognormal(self, covs):
211+
cov = self._select_group(covs)
212+
logger.debug("Multivariate lognormal group selected: %s", cov)
213+
constants = self._find_constants(cov)
214+
nums = np.exp(self.multivariate_normal_np(cov)).tolist()
215+
return list(merge_with_constants(nums, constants))
216+
217+
def _check_generator_name(self, name: str) -> None:
218+
if name not in self.PERMITTED_SUBGENS:
219+
raise Exception("%s is not a permitted generator", name)
220+
221+
def alternatives(self, alternative_configs: list[dict[str, any]], counts: list[int] | None):
222+
"""
223+
A generator that picks between other generators.
224+
225+
:param alternative_configs: List of alternative generators.
226+
Each alternative has the following keys: "count" -- a weight for
227+
how often to use this alternative; "name" -- which generator
228+
for this partition, for example "composite"; "params" -- the
229+
parameters for this alternative.
230+
:return: list of values
231+
"""
232+
if counts is not None:
233+
while True:
234+
count = self._select_group(counts)
235+
alt = alternative_configs[count["index"]]
236+
name = alt["name"]
237+
self._check_generator_name(name)
238+
try:
239+
return getattr(self, name)(**alt["params"])
240+
except NothingToGenerateException:
241+
# Prevent this alternative from being chosen again
242+
count["count"] = 0
243+
alt = self._select_group(alternative_configs)
244+
name = alt["name"]
245+
self._check_generator_name(name)
246+
return getattr(self, name)(**alt["params"])
247+
248+
def with_constants_at(self, constants_at: list[int], subgen: str, params: dict[str, any]):
249+
if subgen not in self.PERMITTED_SUBGENS:
250+
logger.error(
251+
"subgenerator %s is not a valid name. Valid names are %s.",
252+
subgen,
253+
self.PERMITTED_SUBGENS,
254+
)
255+
subout = getattr(self, subgen)(**params)
256+
logger.debug("Merging constants %s", constants_at)
257+
return list(merge_with_constants(subout, constants_at))
258+
259+
def truncated_string(self, subgen_fn, params, length):
260+
""" Calls ``subgen_fn(**params)`` and truncates the results to ``length``. """
261+
result = subgen_fn(**params)
262+
if result is None:
263+
return None
264+
return result[:length]
265+
127266

128267
class TableGenerator(ABC):
129268
"""Abstract base class for table generator classes."""

0 commit comments

Comments
 (0)