Skip to content

Commit 840b72b

Browse files
tim-bandTim Band
andauthored
sampled and suppressed choice generators (#48)
* sampled and suppressed choice generators * Fixed problems found trying this out for real. --------- Co-authored-by: Tim Band <t.b@ucl>
1 parent 9ba03f5 commit 840b72b

5 files changed

Lines changed: 229 additions & 39 deletions

File tree

datafaker/generators.py

Lines changed: 114 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from abc import ABC, abstractmethod
66
from collections.abc import Mapping
77
import decimal
8-
import logging
98
import math
109
import mimesis
1110
import mimesis.locales
@@ -42,6 +41,15 @@ class Generator(ABC):
4241
def function_name(self) -> str:
4342
""" The name of the generator function to put into df.py. """
4443

44+
def name(self) -> str:
45+
"""
46+
The name of the generator.
47+
48+
Usually the same as the function name, but can be different to distinguish
49+
between generators that have the same function but different queries.
50+
"""
51+
return self.function_name()
52+
4553
@abstractmethod
4654
def nominal_kwargs(self) -> dict[str, str]:
4755
"""
@@ -642,30 +650,59 @@ def zipf_distribution(total, bins):
642650

643651

644652
class ChoiceGenerator(Generator):
645-
def __init__(self, table_name, column_name, values, counts):
653+
def __init__(
654+
self,
655+
table_name,
656+
column_name,
657+
values,
658+
counts,
659+
sample_count = None,
660+
suppress_count = 0,
661+
):
646662
super().__init__()
647663
self.table_name = table_name
648664
self.column_name = column_name
649665
self.values = values
650666
estimated_counts = self.get_estimated_counts(counts)
651667
self._fit = fit_from_buckets(counts, estimated_counts)
668+
if suppress_count == 0:
669+
if sample_count is None:
670+
self._query = f"SELECT {column_name} AS value FROM {table_name} GROUP BY value ORDER BY COUNT({column_name}) DESC"
671+
self._comment = f"All the values that appear in column {column_name} of table {table_name}"
672+
self._annotation = None
673+
else:
674+
self._query = f"SELECT value FROM (SELECT {column_name} AS value FROM {table_name} ORDER BY RANDOM() LIMIT {sample_count}) AS _inner GROUP BY value"
675+
self._comment = f"The values that appear in column {column_name} of a random sample of {sample_count} rows of table {table_name}"
676+
self._annotation = "sampled"
677+
else:
678+
if sample_count is None:
679+
self._query = f"SELECT value FROM (SELECT {column_name} AS value, COUNT({column_name}) AS count FROM {table_name} GROUP BY value) AS _inner WHERE {suppress_count} < count"
680+
self._comment = f"All the values that appear in column {column_name} of table {table_name} more than {suppress_count} times"
681+
self._annotation = "suppressed"
682+
else:
683+
self._query = f"SELECT value FROM (SELECT value, COUNT(value) AS count FROM (SELECT {column_name} AS value FROM {table_name} ORDER BY RANDOM() LIMIT {sample_count}) AS _inner GROUP BY value) AS _inner WHERE {suppress_count} < count"
684+
self._comment = f"The values that appear more than {suppress_count} times in column {column_name}, out of a random sample of {sample_count} rows of table {table_name}"
685+
self._annotation = "sampled and suppressed"
652686
def nominal_kwargs(self):
653687
return {
654688
"a": f'SRC_STATS["auto__{self.table_name}__{self.column_name}"]["results"]',
655689
}
690+
def name(self):
691+
n = super().name()
692+
if self._annotation is None:
693+
return n
694+
return f"{n} [{self._annotation}]"
656695
def actual_kwargs(self):
657696
return {
658697
"a": self.values,
659698
}
660699
def custom_queries(self) -> dict[str, dict[str, str]]:
661700
qs = super().custom_queries()
662-
t = self.table_name
663-
c = self.column_name
664701
return {
665702
**qs,
666-
f"auto__{t}__{c}": {
667-
"query": f"SELECT {c} AS value FROM {t} GROUP BY value ORDER BY COUNT({c}) DESC",
668-
"comment": f"All the values that appear in column {c} of table {t}",
703+
f"auto__{self.table_name}__{self.column_name}": {
704+
"query": self._query,
705+
"comment": self._comment,
669706
}
670707
}
671708
def fit(self, default=None):
@@ -708,9 +745,12 @@ class ChoiceGeneratorFactory(GeneratorFactory):
708745
"""
709746
All generators that want an average and standard deviation.
710747
"""
748+
SAMPLE_COUNT = MAXIMUM_CHOICES
749+
SUPPRESS_COUNT = 5
711750
def get_generators(self, column: Column, engine: Engine):
712751
column_name = column.name
713752
table_name = column.table.name
753+
generators = []
714754
with engine.connect() as connection:
715755
results = connection.execute(
716756
text("SELECT {column} AS v, COUNT({column}) AS f FROM {table} GROUP BY v ORDER BY f DESC LIMIT {limit}".format(
@@ -719,26 +759,73 @@ def get_generators(self, column: Column, engine: Engine):
719759
limit=MAXIMUM_CHOICES+1,
720760
))
721761
)
722-
if results is None or MAXIMUM_CHOICES < results.rowcount:
723-
return []
724-
values = [] # The values found
725-
counts = [] # The number or each value
726-
total = 0 # total number of non-NULL results
727-
for result in results:
728-
c = result.f
729-
if c != 0:
730-
total += c
731-
counts.append(c)
732-
v = result.v
733-
if type(v) is decimal.Decimal:
734-
v = float(v)
735-
values.append(v)
736-
if not counts:
737-
return []
738-
return [
739-
ZipfChoiceGenerator(table_name, column_name, values, counts),
740-
UniformChoiceGenerator(table_name, column_name, values, counts),
741-
]
762+
if results is not None and results.rowcount <= MAXIMUM_CHOICES:
763+
values = [] # The values found
764+
counts = [] # The number or each value
765+
for result in results:
766+
c = result.f
767+
if c != 0:
768+
counts.append(c)
769+
v = result.v
770+
if type(v) is decimal.Decimal:
771+
v = float(v)
772+
values.append(v)
773+
if counts:
774+
generators += [
775+
ZipfChoiceGenerator(table_name, column_name, values, counts),
776+
UniformChoiceGenerator(table_name, column_name, values, counts),
777+
]
778+
results = connection.execute(
779+
text("SELECT v, COUNT(v) AS f FROM (SELECT {column} as v FROM {table} ORDER BY RANDOM() LIMIT {sample_count}) AS _inner GROUP BY v ORDER BY f DESC".format(
780+
table=table_name,
781+
column=column_name,
782+
sample_count=self.SAMPLE_COUNT,
783+
))
784+
)
785+
if results is not None:
786+
values = [] # All values found
787+
counts = [] # The number or each value
788+
values_not_suppressed = [] # All values found more than SUPPRESS_COUNT times
789+
counts_not_suppressed = [] # The number for each value not suppressed
790+
for result in results:
791+
c = result.f
792+
if c != 0:
793+
counts.append(c)
794+
v = result.v
795+
if type(v) is decimal.Decimal:
796+
v = float(v)
797+
values.append(v)
798+
if self.SUPPRESS_COUNT < c:
799+
counts_not_suppressed.append(c)
800+
v = result.v
801+
if type(v) is decimal.Decimal:
802+
v = float(v)
803+
values_not_suppressed.append(v)
804+
if counts:
805+
generators += [
806+
ZipfChoiceGenerator(table_name, column_name, values, counts, sample_count=self.SAMPLE_COUNT),
807+
UniformChoiceGenerator(table_name, column_name, values, counts, sample_count=self.SAMPLE_COUNT),
808+
]
809+
if counts_not_suppressed:
810+
generators += [
811+
ZipfChoiceGenerator(
812+
table_name,
813+
column_name,
814+
values_not_suppressed,
815+
counts_not_suppressed,
816+
sample_count=self.SAMPLE_COUNT,
817+
suppress_count=self.SUPPRESS_COUNT,
818+
),
819+
UniformChoiceGenerator(
820+
table_name,
821+
column_name,
822+
values_not_suppressed,
823+
counts_not_suppressed,
824+
sample_count=self.SAMPLE_COUNT,
825+
suppress_count=self.SUPPRESS_COUNT,
826+
),
827+
]
828+
return generators
742829

743830

744831
class NullGenerator(Generator):

datafaker/interactive.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -967,7 +967,7 @@ def set_prompt(self):
967967
self.prompt = "({table}.{column} ({generator})) ".format(
968968
table=table_name,
969969
column=column,
970-
generator=gen_info.new_gen.function_name(),
970+
generator=gen_info.new_gen.name(),
971971
)
972972
else:
973973
self.prompt = "({table}.{column}) ".format(
@@ -1037,8 +1037,8 @@ def do_quit(self, _arg):
10371037
self.print(
10381038
"...changing {0} from {1} to {2}",
10391039
gen.column,
1040-
gen.old_gen.function_name() if gen.old_gen else "nothing",
1041-
gen.new_gen.function_name() if gen.new_gen else "nothing",
1040+
gen.old_gen.name() if gen.old_gen else "nothing",
1041+
gen.new_gen.name() if gen.new_gen else "nothing",
10421042
)
10431043
if count == 0:
10441044
self.print("You have made no changes.")
@@ -1063,15 +1063,15 @@ def do_list(self, arg):
10631063
self.print("Error: no table {0}", self.table_index)
10641064
return
10651065
for gen in self.table_entries[self.table_index].generators:
1066-
old = "" if gen.old_gen is None else gen.old_gen.function_name()
1066+
old = "" if gen.old_gen is None else gen.old_gen.name()
10671067
if gen.old_gen == gen.new_gen:
10681068
becomes = ""
10691069
if old == "":
10701070
old = "(not set)"
10711071
elif gen.new_gen is None:
10721072
becomes = "(delete)"
10731073
else:
1074-
becomes = f"->{gen.new_gen.function_name()}"
1074+
becomes = f"->{gen.new_gen.name()}"
10751075
primary = "[primary-key]" if gen.is_primary_key else ""
10761076
self.print("{0}{1}{2} {3}", old, becomes, primary, gen.column)
10771077

@@ -1239,7 +1239,7 @@ def do_compare(self, arg: str):
12391239
n = int(argument)
12401240
if 0 < n and n <= len(gens):
12411241
gen = gens[n - 1]
1242-
comparison[f"{n}. {gen.function_name()}"] = gen.generate_data(limit)
1242+
comparison[f"{n}. {gen.name()}"] = gen.generate_data(limit)
12431243
self._print_values_queried(table_name, n, gen)
12441244
self.print_table_by_columns(comparison)
12451245

@@ -1255,13 +1255,13 @@ def _print_values_queried(self, table_name: str, n: int, gen: Generator):
12551255
self.print(
12561256
"{0}. {1} requires no data from the source database.",
12571257
n,
1258-
gen.function_name(),
1258+
gen.name(),
12591259
)
12601260
else:
12611261
self.print(
12621262
"{0}. {1} requires the following data from the source database:",
12631263
n,
1264-
gen.function_name(),
1264+
gen.name(),
12651265
)
12661266
self._print_select_aggregate_query(table_name, gen)
12671267
self._print_custom_queries(gen)
@@ -1314,9 +1314,9 @@ def _print_select_aggregate_query(self, table_name, gen: Generator) -> None:
13141314
if ak in kwa:
13151315
vals.append(kwa[ak])
13161316
else:
1317-
logger.warning("actual_kwargs for %s does not report %s", gen.function_name(), ak)
1317+
logger.warning("actual_kwargs for %s does not report %s", gen.name(), ak)
13181318
else:
1319-
logger.warning('nominal_kwargs for %s does not have a value SRC_STATS["auto__%s"]["results"][0]["%s"]', gen.function_name(), table_name, n)
1319+
logger.warning('nominal_kwargs for %s does not have a value SRC_STATS["auto__%s"]["results"][0]["%s"]', gen.name(), table_name, n)
13201320
select_q = self._get_aggregate_query([gen], table_name)
13211321
self.print("{0}; providing the following values: {1}", select_q, vals)
13221322

@@ -1365,7 +1365,7 @@ def do_propose(self, _arg):
13651365
self.print(
13661366
self.PROPOSE_GENERATOR_SAMPLE_TEXT,
13671367
index=index + 1,
1368-
name=gen.function_name(),
1368+
name=gen.name(),
13691369
fit=fit_s,
13701370
sample=", ".join(map(repr, gen.generate_data(limit)))
13711371
)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "datafaker"
3-
version = "0.2.1"
3+
version = "0.2.2"
44
description = "Generates fake SQL data"
55
authors = ["Tim Band <3266052+tim-band@users.noreply.github.com>"]
66
license = "MIT"

tests/examples/choice.sql

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
-- DROP DATABASE IF EXISTS instrument WITH (FORCE);
2+
CREATE DATABASE numbers WITH TEMPLATE template0 ENCODING = 'UTF8' LOCALE = 'en_US.utf8';
3+
ALTER DATABASE numbers OWNER TO postgres;
4+
5+
\connect numbers
6+
7+
CREATE TABLE public.number_table (
8+
id INTEGER NOT NULL,
9+
one INTEGER NOT NULL,
10+
two INTEGER NOT NULL,
11+
three INTEGER NOT NULL
12+
);
13+
14+
ALTER TABLE ONLY public.number_table ADD CONSTRAINT number_table_pkey PRIMARY KEY (id);
15+
16+
ALTER TABLE public.number_table OWNER TO postgres;
17+
18+
INSERT INTO public.number_table VALUES (1, 1, 1, 1);
19+
INSERT INTO public.number_table VALUES (2, 2, 2, 2);
20+
INSERT INTO public.number_table VALUES (3, 3, 3, 3);
21+
INSERT INTO public.number_table VALUES (4, 4, 4, 4);
22+
INSERT INTO public.number_table VALUES (5, 5, 5, 5);
23+
INSERT INTO public.number_table VALUES (6, 1, 1, 1);
24+
INSERT INTO public.number_table VALUES (7, 1, 2, 2);
25+
INSERT INTO public.number_table VALUES (8, 1, 3, 3);
26+
INSERT INTO public.number_table VALUES (9, 1, 3, 4);
27+
INSERT INTO public.number_table VALUES (10, 1, 3, 5);
28+
INSERT INTO public.number_table VALUES (11, 1, 2, 1);
29+
INSERT INTO public.number_table VALUES (12, 1, 2, 2);
30+
INSERT INTO public.number_table VALUES (13, 4, 1, 3);
31+
INSERT INTO public.number_table VALUES (14, 4, 3, 4);
32+
INSERT INTO public.number_table VALUES (15, 1, 3, 5);
33+
INSERT INTO public.number_table VALUES (16, 1, 2, 1);
34+
INSERT INTO public.number_table VALUES (17, 4, 3, 2);
35+
INSERT INTO public.number_table VALUES (18, 4, 2, 3);
36+
INSERT INTO public.number_table VALUES (19, 4, 3, 4);
37+
INSERT INTO public.number_table VALUES (20, 4, 1, 5);

tests/test_interactive.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,72 @@ def test_empty_tables_are_not_configured(self):
676676
self.assertNotIn("string", table_names)
677677

678678

679+
class GeneratorsOutputTests(GeneratesDBTestCase):
680+
""" Testing configure-missing with generation. """
681+
dump_file_path = "choice.sql"
682+
database_name = "numbers"
683+
schema_name = "public"
684+
685+
def _get_cmd(self, config) -> TestGeneratorCmd:
686+
return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config)
687+
688+
def test_create_with_sampled_choice(self):
689+
""" Test that we can sample real missingness and reproduce it. """
690+
table_name = "number_table"
691+
with self._get_cmd({}) as gc:
692+
gc.do_next("number_table.one")
693+
gc.reset()
694+
gc.do_propose("")
695+
proposals = gc.get_proposals()
696+
self.assertIn("dist_gen.choice", proposals)
697+
self.assertIn("dist_gen.zipf_choice", proposals)
698+
self.assertIn("dist_gen.choice [sampled]", proposals)
699+
self.assertIn("dist_gen.zipf_choice [sampled]", proposals)
700+
self.assertIn("dist_gen.choice [sampled and suppressed]", proposals)
701+
self.assertIn("dist_gen.zipf_choice [sampled and suppressed]", proposals)
702+
gc.do_set(str(proposals["dist_gen.choice [sampled and suppressed]"][0]))
703+
gc.do_next("number_table.two")
704+
gc.reset()
705+
gc.do_propose("")
706+
proposals = gc.get_proposals()
707+
self.assertIn("dist_gen.choice", proposals)
708+
self.assertIn("dist_gen.zipf_choice", proposals)
709+
self.assertIn("dist_gen.choice [sampled]", proposals)
710+
self.assertIn("dist_gen.zipf_choice [sampled]", proposals)
711+
self.assertIn("dist_gen.choice [sampled and suppressed]", proposals)
712+
self.assertIn("dist_gen.zipf_choice [sampled and suppressed]", proposals)
713+
gc.do_set(str(proposals["dist_gen.zipf_choice [sampled and suppressed]"][0]))
714+
gc.do_next("number_table.three")
715+
gc.reset()
716+
gc.do_propose("")
717+
proposals = gc.get_proposals()
718+
self.assertIn("dist_gen.choice", proposals)
719+
self.assertIn("dist_gen.zipf_choice", proposals)
720+
self.assertIn("dist_gen.choice [sampled]", proposals)
721+
self.assertIn("dist_gen.zipf_choice [sampled]", proposals)
722+
self.assertNotIn("dist_gen.choice [sampled and suppressed]", proposals)
723+
self.assertNotIn("dist_gen.zipf_choice [sampled and suppressed]", proposals)
724+
gc.do_set(str(proposals["dist_gen.choice [sampled]"][0]))
725+
gc.do_quit("")
726+
config = gc.config
727+
self.generate_data(config, num_passes=200)
728+
# Test that each missingness pattern is present in the database
729+
with self.engine.connect() as conn:
730+
stmt = select(self.metadata.tables[table_name])
731+
rows = conn.execute(stmt).fetchall()
732+
ones = set()
733+
twos = set()
734+
threes = set()
735+
for row in rows:
736+
ones.add(row.one)
737+
twos.add(row.two)
738+
threes.add(row.three)
739+
# all pattern possibilities should be present
740+
self.assertSetEqual(ones, {1, 4})
741+
self.assertSetEqual(twos, {2, 3})
742+
self.assertSetEqual(threes, {1, 2, 3, 4, 5})
743+
744+
679745
class TestMissingnessCmd(MissingnessCmd, TestDbCmdMixin):
680746
""" MissingnessCmd but mocked """
681747

0 commit comments

Comments
 (0)