Skip to content

Commit 1679826

Browse files
tim-band and Tim Band authored
dist_gen.constant generates NULL, 0 or "" as appropriate (#52)
Co-authored-by: Tim Band <t.b@ucl>
1 parent ed1415f commit 1679826

2 files changed

Lines changed: 67 additions & 6 deletions

File tree

datafaker/generators.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -828,17 +828,19 @@ def get_generators(self, column: Column, engine: Engine):
828828
return generators
829829

830830

831-
class NullGenerator(Generator):
832-
def __init__(self):
831+
class ConstantGenerator(Generator):
832+
def __init__(self, value):
833833
super().__init__()
834+
self.value = value
835+
self.repr = repr(value)
834836
def function_name(self) -> str:
835837
return "dist_gen.constant"
836838
def nominal_kwargs(self) -> dict[str, str]:
837-
return {"value": "None"}
839+
return {"value": self.repr}
838840
def actual_kwargs(self) -> dict[str, any]:
839-
return {"value": None}
841+
return {"value": self.value}
840842
def generate_data(self, count) -> list[any]:
841-
return [None for _ in range(count)]
843+
return [self.value for _ in range(count)]
842844

843845

844846
class ConstantGeneratorFactory(GeneratorFactory):
@@ -847,7 +849,14 @@ class ConstantGeneratorFactory(GeneratorFactory):
847849
"""
848850
def get_generators(self, column: Column, _engine: Engine):
849851
if column.nullable:
850-
return [NullGenerator()]
852+
return [ConstantGenerator(None)]
853+
c_type = get_column_type(column)
854+
if isinstance(c_type, String):
855+
return [ConstantGenerator("")]
856+
if isinstance(c_type, Numeric):
857+
return [ConstantGenerator(0.0)]
858+
if isinstance(c_type, Integer):
859+
return [ConstantGenerator(0)]
851860
return []
852861

853862

tests/test_interactive.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -831,3 +831,55 @@ def test_create_with_missingness(self):
831831
patterns.add(p + b)
832832
# all pattern possibilities should be present
833833
self.assertSetEqual(patterns, {0, 1, 2, 3})
834+
835+
836+
class GeneratorTests(GeneratesDBTestCase):
837+
""" Testing interactive generator configuration followed by data generation. """
838+
dump_file_path = "instrument.sql"
839+
database_name = "instrument"
840+
schema_name = "public"
841+
842+
def _get_cmd(self, config) -> TestGeneratorCmd:
843+
return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config)
844+
845+
def test_set_null(self):
846+
""" Test that dist_gen.constant yields NULL, 0, 0.0 or the empty string as appropriate for each column. """
847+
with self._get_cmd({}) as gc:
848+
gc.do_next("string.position")
849+
gc.do_set("dist_gen.constant")
850+
self.assertListEqual(gc.messages, [])
851+
gc.reset()
852+
gc.do_next("string.frequency")
853+
gc.do_set("dist_gen.constant")
854+
self.assertListEqual(gc.messages, [])
855+
gc.reset()
856+
gc.do_next("signature_model.name")
857+
gc.do_set("dist_gen.constant")
858+
self.assertListEqual(gc.messages, [])
859+
gc.reset()
860+
gc.do_next("signature_model.based_on")
861+
gc.do_set("dist_gen.constant")
862+
# we have got to the end of the columns, but shouldn't have any errors
863+
self.assertListEqual(gc.messages, [("No more tables", (), {})])
864+
gc.reset()
865+
gc.do_quit("")
866+
config = gc.config
867+
self.generate_data(config, num_passes=3)
868+
# Test that every generated row holds the expected constant value
869+
with self.engine.connect() as conn:
870+
stmt = select(self.metadata.tables["string"].c["position", "frequency"])
871+
rows = conn.execute(stmt).fetchall()
872+
count = 0
873+
for row in rows:
874+
count += 1
875+
self.assertEqual(row.position, 0)
876+
self.assertEqual(row.frequency, 0.0)
877+
self.assertEqual(count, 3)
878+
stmt = select(self.metadata.tables["signature_model"].c["name", "based_on"])
879+
rows = conn.execute(stmt).fetchall()
880+
count = 0
881+
for row in rows:
882+
count += 1
883+
self.assertEqual(row.name, "")
884+
self.assertIsNone(row.based_on)
885+
self.assertEqual(count, 3)

0 commit comments

Comments
 (0)