Skip to content

Commit 0379322

Browse files
tim-band and Tim Band authored
Fixed dist_gen.zipf_choice sampled ordering (#53)
Co-authored-by: Tim Band <t.b@ucl>
1 parent 1679826 commit 0379322

4 files changed

Lines changed: 33 additions & 6 deletions

File tree

datafaker/generators.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -667,20 +667,20 @@ def __init__(
667667
self._fit = fit_from_buckets(counts, estimated_counts)
668668
if suppress_count == 0:
669669
if sample_count is None:
670-
self._query = f"SELECT {column_name} AS value FROM {table_name} GROUP BY value ORDER BY COUNT({column_name}) DESC"
670+
self._query = f"SELECT {column_name} AS value FROM {table_name} WHERE {column_name} IS NOT NULL GROUP BY value ORDER BY COUNT({column_name}) DESC"
671671
self._comment = f"All the values that appear in column {column_name} of table {table_name}"
672672
self._annotation = None
673673
else:
674-
self._query = f"SELECT value FROM (SELECT {column_name} AS value FROM {table_name} ORDER BY RANDOM() LIMIT {sample_count}) AS _inner GROUP BY value"
674+
self._query = f"SELECT value FROM (SELECT {column_name} AS value FROM {table_name} WHERE {column_name} IS NOT NULL ORDER BY RANDOM() LIMIT {sample_count}) AS _inner GROUP BY value ORDER BY COUNT(value) DESC"
675675
self._comment = f"The values that appear in column {column_name} of a random sample of {sample_count} rows of table {table_name}"
676676
self._annotation = "sampled"
677677
else:
678678
if sample_count is None:
679-
self._query = f"SELECT value FROM (SELECT {column_name} AS value, COUNT({column_name}) AS count FROM {table_name} GROUP BY value) AS _inner WHERE {suppress_count} < count"
679+
self._query = f"SELECT value FROM (SELECT {column_name} AS value, COUNT({column_name}) AS count FROM {table_name} WHERE {column_name} IS NOT NULL GROUP BY value ORDER BY count DESC) AS _inner WHERE {suppress_count} < count"
680680
self._comment = f"All the values that appear in column {column_name} of table {table_name} more than {suppress_count} times"
681681
self._annotation = "suppressed"
682682
else:
683-
self._query = f"SELECT value FROM (SELECT value, COUNT(value) AS count FROM (SELECT {column_name} AS value FROM {table_name} ORDER BY RANDOM() LIMIT {sample_count}) AS _inner GROUP BY value) AS _inner WHERE {suppress_count} < count"
683+
self._query = f"SELECT value FROM (SELECT value, COUNT(value) AS count FROM (SELECT {column_name} AS value FROM {table_name} WHERE {column_name} IS NOT NULL ORDER BY RANDOM() LIMIT {sample_count}) AS _inner GROUP BY value ORDER BY count DESC) AS _inner WHERE {suppress_count} < count"
684684
self._comment = f"The values that appear more than {suppress_count} times in column {column_name}, out of a random sample of {sample_count} rows of table {table_name}"
685685
self._annotation = "sampled and suppressed"
686686
def nominal_kwargs(self):

tests/examples/instrument.sql

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,7 @@ INSERT INTO public.signature_model VALUES (3, 'Veleno', 2, 2);
9999
INSERT INTO public.signature_model VALUES (4, 'Grifter', NULL, NULL);
100100
INSERT INTO public.signature_model VALUES (5, 'Proton', 3, 1);
101101
INSERT INTO public.signature_model VALUES (6, 'Isabelle', NULL, 3);
102+
INSERT INTO public.signature_model VALUES (7, 'Natalia', NULL, 1);
103+
INSERT INTO public.signature_model VALUES (8, 'Ray', 2, NULL);
104+
INSERT INTO public.signature_model VALUES (9, 'DamageControl', 2, 3);
105+
INSERT INTO public.signature_model VALUES (10, 'Lucy', 3, 1);

tests/test_interactive.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,7 @@ def test_set_generator_choice(self):
481481
self.assertEqual(gc.config["src-stats"][0]["name"], f"auto__{TABLE}__{COLUMN}")
482482
self.assertEqual(
483483
gc.config["src-stats"][0]["query"],
484-
f"SELECT {COLUMN} AS value FROM {TABLE} GROUP BY value ORDER BY COUNT({COLUMN}) DESC",
484+
f"SELECT {COLUMN} AS value FROM {TABLE} WHERE {COLUMN} IS NOT NULL GROUP BY value ORDER BY COUNT({COLUMN}) DESC",
485485
)
486486

487487
def test_old_generators_remain(self):
@@ -834,7 +834,7 @@ def test_create_with_missingness(self):
834834

835835

836836
class GeneratorTests(GeneratesDBTestCase):
837-
""" Testing configure-missing with generation. """
837+
""" Testing configure-generators with generation. """
838838
dump_file_path = "instrument.sql"
839839
database_name = "instrument"
840840
schema_name = "public"
@@ -883,3 +883,25 @@ def test_set_null(self):
883883
self.assertEqual(row.name, "")
884884
self.assertIsNone(row.based_on)
885885
self.assertEqual(count, 3)
886+
887+
def test_dist_gen_sampled_produces_ordered_src_stats(self):
888+
""" Tests that choosing a sampled choice generator produces ordered src stats """
889+
with self._get_cmd({}) as gc:
890+
gc.do_next("signature_model.player_id")
891+
gc.do_set("dist_gen.zipf_choice [sampled]")
892+
gc.do_next("signature_model.based_on")
893+
gc.do_set("dist_gen.zipf_choice [sampled]")
894+
gc.do_quit("")
895+
config = gc.config
896+
self.set_configuration(config)
897+
src_stats = self.get_src_stats(config)
898+
player_ids = [
899+
s["value"]
900+
for s in src_stats["auto__signature_model__player_id"]["results"]
901+
]
902+
self.assertListEqual(player_ids, [2, 3, 1])
903+
based_ons = [
904+
s["value"]
905+
for s in src_stats["auto__signature_model__based_on"]["results"]
906+
]
907+
self.assertListEqual(based_ons, [1, 3, 2])

tests/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ def get_src_stats(self, config) -> dict[str, any]:
177177
(self.stats_fd, self.stats_file_path) = mkstemp(".yaml", "src_stats_", text=True)
178178
with os.fdopen(self.stats_fd, "w", encoding="utf-8") as stats_fh:
179179
stats_fh.write(yaml.dump(src_stats))
180+
return src_stats
180181

181182
def create_generators(self, config) -> None:
182183
""" ``create-generators`` with ``src-stats.yaml`` and the rest, producing ``df.py`` """

0 commit comments

Comments
 (0)