Skip to content

Commit edda835

Browse files
author
Tim Band
committed
All tests pass.
1 parent ba9bf59 commit edda835

7 files changed

Lines changed: 143 additions & 99 deletions

File tree

datafaker/create.py

Lines changed: 13 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -9,16 +9,15 @@
99
from sqlalchemy.ext.compiler import compiles
1010
from sqlalchemy.orm import Session
1111
from sqlalchemy.schema import CreateColumn, CreateSchema, CreateTable, MetaData, Table
12+
import typer
1213

1314
from datafaker.base import FileUploader
14-
from datafaker.make import get_generation_info
15+
from datafaker.make import get_generation_info, StoryGeneratorInfo
1516
from datafaker.populate import (
16-
StoryGeneratorInfo,
1717
TableGenerator,
1818
call_function,
1919
get_symbols,
2020
get_table_generator_dict,
21-
get_vocab_dict,
2221
)
2322
from datafaker.settings import get_destination_dsn, get_destination_schema, get_settings
2423
from datafaker.utils import (
@@ -157,8 +156,15 @@ def create_db_data(
157156
) -> RowCounts:
158157
"""Connect to a database and populate it with data."""
159158
if src_stats_filename:
160-
with src_stats_filename.open(encoding="utf-8") as fh:
161-
src_stats = yaml.load(fh, yaml.SafeLoader)
159+
try:
160+
with src_stats_filename.open(encoding="utf-8") as fh:
161+
src_stats = yaml.load(fh, yaml.SafeLoader)
162+
except FileNotFoundError:
163+
logger.error(
164+
"No source stats file '%', this should be the output of the 'make-stats' command",
165+
src_stats_filename,
166+
)
167+
raise typer.Exit(1)
162168
else:
163169
src_stats = None
164170
return create_db_data_into(
@@ -187,10 +193,8 @@ def create_db_data_into(
187193
188194
:param sorted_tables: The table names to populate, sorted so that foreign
189195
keys' targets are populated before the foreign keys themselves.
190-
:param table_generator_dict: A mapping of table names to the generators
191-
used to make data for them.
192-
:param story_generator_list: A list of story generators to be run after the
193-
table generators on each pass.
196+
:param config: The data from the ``config.yaml`` file.
197+
:param src_stats: The data from the ``src-stats.yaml`` file.
194198
:param num_passes: Number of passes to perform.
195199
:param db_dsn: Connection string for the destination database.
196200
:param schema_name: Destination schema name.

datafaker/make.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -656,7 +656,7 @@ def get_generation_info(
656656
story_generators = _get_story_generators(config)
657657

658658
max_unique_constraint_tries = get_property(
659-
config, "max-unique-constraint-tries", str | None, None
659+
config, "max-unique-constraint-tries", int | None, None
660660
)
661661
return GenerationInfo(
662662
provider_imports=PROVIDER_IMPORTS,

datafaker/populate.py

Lines changed: 29 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
from typing import Any, Callable
77

88
from datafaker.base import FileUploader, ColumnPresence
9-
from datafaker.make import FunctionCall, TableGeneratorInfo, StoryGeneratorInfo
9+
from datafaker.make import FunctionCall, TableGeneratorInfo
1010

1111
from datafaker.providers import (
1212
BytesProvider,
@@ -18,17 +18,34 @@
1818
TimespanProvider,
1919
WeightedBooleanProvider,
2020
)
21-
from datafaker.utils import logging, get_vocabulary_table_names, import_file
21+
from datafaker.utils import get_vocabulary_table_names, import_file
22+
23+
def make_generic():
24+
g = Generic(locale=Locale.EN_GB)
25+
g.add_providers(
26+
BytesProvider,
27+
ColumnValueProvider,
28+
DistributionProvider,
29+
NullProvider,
30+
SQLGroupByProvider,
31+
TimedeltaProvider,
32+
TimespanProvider,
33+
WeightedBooleanProvider,
34+
)
35+
return g
36+
37+
38+
generic = make_generic()
2239

23-
generic = Generic(locale=Locale.EN_GB)
2440

25-
generic.add_provider(BytesProvider)
26-
generic.add_provider(ColumnValueProvider)
27-
generic.add_provider(NullProvider)
28-
generic.add_provider(SQLGroupByProvider)
29-
generic.add_provider(TimedeltaProvider)
30-
generic.add_provider(TimespanProvider)
31-
generic.add_provider(WeightedBooleanProvider)
41+
def reset_generic():
42+
"""
43+
Reset all the generators.
44+
45+
Only really useful in test code.
46+
"""
47+
global generic
48+
generic = make_generic()
3249

3350

3451
def _eval_structure(config: Any, context: Mapping) -> Any:
@@ -57,10 +74,10 @@ def _eval_structure(config: Any, context: Mapping) -> Any:
5774

5875
def _get_object(class_name: str, context: Mapping) -> Any:
5976
"""
60-
Get an object out of the context.
77+
Fetch an object from the context.
6178
6279
:param class_name: The name of the class, qualified if necessary.
63-
Like "module.MyClass.Nested"
80+
Like "module.MyClass.Nested"
6481
:param context: Mapping of strings to objects with those names.
6582
:return: A value from ``context`` if there are no qualifying names,
6683
otherwise the attribute of the base object.
@@ -288,11 +305,3 @@ def get_table_generator_dict(
288305
)
289306
for table_data in tables_data
290307
}
291-
292-
293-
def get_vocab_dict(config: Mapping, metadata: sqlalchemy.MetaData) -> Mapping[str, FileUploader]:
294-
"""Get a dict of table names to objects that can populate those tables from YAML files."""
295-
return {
296-
name: FileUploader(metadata.tables[name])
297-
for name in get_vocabulary_table_names(config)
298-
}

datafaker/utils.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -546,9 +546,7 @@ def get_row_generators(
546546
:param table_config: The element from the ``tables:`` stanza of ``config.xml``.
547547
:return: Pair of (name, row generator config).
548548
"""
549-
rgs = table_config.get("row_generators", None)
550-
if isinstance(rgs, str) or not hasattr(rgs, "__iter__"):
551-
return
549+
rgs = get_property(table_config, "row_generators", list, [])
552550
for rg in rgs:
553551
name = rg.get("name", None)
554552
if name:

tests/test_create.py

Lines changed: 61 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121
create_db_vocab,
2222
populate,
2323
)
24+
from datafaker.make import FunctionCall, StoryGeneratorInfo
2425
from datafaker.serialize_metadata import metadata_to_dict, dict_to_metadata
2526
from datafaker.utils import sorted_non_vocabulary_tables
2627
from tests.utils import DatafakerTestCase, GeneratesDBTestCase, RequiresDBTestCase
@@ -99,17 +100,16 @@ class TestPopulate(DatafakerTestCase):
99100
"""Test create.populate."""
100101

101102
# pylint: disable=too-many-locals
102-
def test_populate(self) -> None:
103+
@patch("datafaker.populate._get_object")
104+
def test_populate(self, mock_get_object: MagicMock) -> None:
103105
"""Test the populate function."""
104106
table_name = "table_name"
105107

106108
def story() -> Generator[Tuple[str, dict], None, None]:
107109
"""Mock story."""
108110
yield table_name, {}
109111

110-
def mock_story_gen(_: Any) -> Generator[Tuple[str, dict], None, None]:
111-
"""A function that returns mock stories."""
112-
return story()
112+
mock_get_object.return_value = story
113113

114114
for num_stories_per_pass, num_rows_per_pass, num_initial_rows in itt.product(
115115
[0, 2], [0, 3], [0, 17]
@@ -130,11 +130,11 @@ def mock_story_gen(_: Any) -> Generator[Tuple[str, dict], None, None]:
130130

131131
story_generators: list[dict[str, Any]] = (
132132
[
133-
{
134-
"function": mock_story_gen,
135-
"num_stories_per_pass": num_stories_per_pass,
136-
"name": "mock_story_gen",
137-
}
133+
StoryGeneratorInfo(
134+
"mock_story_gen name",
135+
FunctionCall("mock_story_gen", [], {}),
136+
num_stories_per_pass,
137+
)
138138
]
139139
if num_stories_per_pass > 0
140140
else []
@@ -304,11 +304,11 @@ def __call__(
304304
create_db_data_into(
305305
[MagicMock()],
306306
MagicMock(),
307+
None,
307308
1,
308309
"duckdb:///:memory:data",
309310
None,
310311
MagicMock(),
311-
MagicMock(),
312312
)
313313
assert mock_populate.side_effect.called
314314

@@ -366,11 +366,12 @@ def test_create_data_minimal(self) -> None:
366366
}
367367
metadata = dict_to_metadata(orm, config)
368368
create_db_tables_into(metadata, self.dsn, self.schema_name)
369+
generate_count = 4
369370
row_counts = create_db_data_into(
370371
sorted_non_vocabulary_tables(metadata, config),
371372
config,
372373
None,
373-
4,
374+
generate_count,
374375
self.dsn,
375376
self.schema_name,
376377
metadata,
@@ -379,3 +380,52 @@ def test_create_data_minimal(self) -> None:
379380
stmt = select(metadata.tables["one"])
380381
rows = connection.execute(stmt).fetchall()
381382
self.assertListEqual(rows, [(1,), (2,), (3,), (4,)])
383+
self.assertListEqual(list(row_counts.keys()), ['one'])
384+
self.assertEqual(row_counts["one"], generate_count)
385+
386+
def test_unique_constraint_minimal(self) -> None:
387+
config = {
388+
"tables": {
389+
"one": {
390+
"row_generators": [{
391+
"name": "dist_gen.constant",
392+
"kwargs": {
393+
"value": 123,
394+
},
395+
"columns_assigned": ["tiger"],
396+
}]
397+
}
398+
},
399+
"max-unique-constraint-tries": 20,
400+
}
401+
orm = {
402+
"tables": {
403+
"one": {
404+
"columns": {
405+
"id": {
406+
"primary": True,
407+
"type": "INTEGER",
408+
},
409+
"tiger": {
410+
"type": "INTEGER",
411+
},
412+
},
413+
"unique": [
414+
{"name": "tiger_uniq", "columns": ["tiger"]}
415+
]
416+
}
417+
}
418+
}
419+
metadata = dict_to_metadata(orm, config)
420+
create_db_tables_into(metadata, self.dsn, self.schema_name)
421+
self.assertRaises(
422+
RuntimeError,
423+
create_db_data_into,
424+
sorted_non_vocabulary_tables(metadata, config),
425+
config,
426+
None,
427+
2,
428+
self.dsn,
429+
self.schema_name,
430+
metadata,
431+
)

tests/test_functional.py

Lines changed: 30 additions & 54 deletions
Original file line number | Diff line number | Diff line change
@@ -272,52 +272,34 @@ def test_workflow_maximal_args(self) -> None:
272272
f"--config-file={self.config_file_path}",
273273
"--num-passes=2",
274274
)
275-
self.assertEqual("", completed_process.stderr)
276-
self.assertEqual(
277-
sorted(
278-
[
279-
"Creating data.",
280-
"Generating data for story 'story_generators.short_story'",
281-
"Generating data for story 'story_generators.short_story'",
282-
"Generating data for story 'story_generators.short_story'",
283-
"Generating data for story 'story_generators.short_story'",
284-
"Generating data for story 'story_generators.short_story'",
285-
"Generating data for story 'story_generators.short_story'",
286-
"Generating data for story 'story_generators.full_row_story'",
287-
"Generating data for story 'story_generators.full_row_story'",
288-
"Generating data for story 'story_generators.long_story'",
289-
"Generating data for story 'story_generators.long_story'",
290-
"Generating data for story 'story_generators.long_story'",
291-
"Generating data for story 'story_generators.long_story'",
292-
"Generating data for table 'data_type_test'",
293-
"Generating data for table 'data_type_test'",
294-
"Generating data for table 'no_pk_test'",
295-
"Generating data for table 'no_pk_test'",
296-
"Generating data for table 'person'",
297-
"Generating data for table 'person'",
298-
"Generating data for table 'strange_type_table'",
299-
"Generating data for table 'strange_type_table'",
300-
"Generating data for table 'unique_constraint_test'",
301-
"Generating data for table 'unique_constraint_test'",
302-
"Generating data for table 'unique_constraint_test2'",
303-
"Generating data for table 'unique_constraint_test2'",
304-
"Generating data for table 'test_entity'",
305-
"Generating data for table 'test_entity'",
306-
"Generating data for table 'hospital_visit'",
307-
"Generating data for table 'hospital_visit'",
308-
"Data created in 2 passes.",
309-
f"person: {2*(3+1+2+2)} rows created.",
310-
f"hospital_visit: {2*(2*2+3)} rows created.",
311-
"data_type_test: 2 rows created.",
312-
"no_pk_test: 2 rows created.",
313-
"strange_type_table: 2 rows created.",
314-
"unique_constraint_test: 2 rows created.",
315-
"unique_constraint_test2: 2 rows created.",
316-
"test_entity: 2 rows created.",
317-
"",
318-
]
319-
),
320-
sorted(completed_process.stdout.split("\n")),
275+
self.assertSetEqual(
276+
{
277+
"Creating data.",
278+
"Generating data for story 'story_generators.short_story'",
279+
"Generating data for story 'story_generators.full_row_story'",
280+
"Generating data for story 'story_generators.full_row_story'",
281+
"Generating data for story 'story_generators.long_story'",
282+
"Generating data for table 'data_type_test'",
283+
"Generating data for table 'no_pk_test'",
284+
"Generating data for table 'person'",
285+
"Generating data for table 'person'",
286+
"Generating data for table 'strange_type_table'",
287+
"Generating data for table 'unique_constraint_test'",
288+
"Generating data for table 'unique_constraint_test2'",
289+
"Generating data for table 'test_entity'",
290+
"Generating data for table 'hospital_visit'",
291+
"Data created in 2 passes.",
292+
f"person: {2*(3+1+2+2)} rows created.",
293+
f"hospital_visit: {2*(2*2+3)} rows created.",
294+
"data_type_test: 2 rows created.",
295+
"no_pk_test: 2 rows created.",
296+
"strange_type_table: 2 rows created.",
297+
"unique_constraint_test: 2 rows created.",
298+
"unique_constraint_test2: 2 rows created.",
299+
"test_entity: 2 rows created.",
300+
"",
301+
},
302+
set(completed_process.stdout.split("\n")),
321303
)
322304

323305
completed_process = self.invoke(
@@ -468,13 +450,9 @@ def test_unique_constraint_fail(self) -> None:
468450
f"--stats-file={self.stats_file_path}",
469451
"--num-passes=1",
470452
)
471-
self.assertEqual("", completed_process.stderr)
472453
self.assertEqual(
473-
"Generating data for story 'story_generators.short_story'\n"
474-
"Generating data for story 'story_generators.short_story'\n"
475454
"Generating data for story 'story_generators.short_story'\n"
476455
"Generating data for story 'story_generators.full_row_story'\n"
477-
"Generating data for story 'story_generators.long_story'\n"
478456
"Generating data for story 'story_generators.long_story'\n",
479457
completed_process.stdout,
480458
)
@@ -483,17 +461,14 @@ def test_unique_constraint_fail(self) -> None:
483461
"create-data",
484462
f"--config-file={self.config_file_path}",
485463
f"--orm-file={self.alt_orm_file_path}",
464+
f"--stats-file={self.stats_file_path}",
486465
"--num-passes=3",
487466
)
488-
self.assertEqual("", completed_process.stderr)
489467
self.assertEqual(
490468
(
491-
"Generating data for story 'story_generators.short_story'\n"
492-
"Generating data for story 'story_generators.short_story'\n"
493469
"Generating data for story 'story_generators.short_story'\n"
494470
"Generating data for story 'story_generators.full_row_story'\n"
495471
"Generating data for story 'story_generators.long_story'\n"
496-
"Generating data for story 'story_generators.long_story'\n"
497472
)
498473
* 3,
499474
completed_process.stdout,
@@ -504,6 +479,7 @@ def test_unique_constraint_fail(self) -> None:
504479
"create-data",
505480
f"--config-file={self.config_file_path}",
506481
f"--orm-file={self.alt_orm_file_path}",
482+
f"--stats-file={self.stats_file_path}",
507483
"--num-passes=1",
508484
expected_error=(
509485
"Failed to satisfy unique constraints for table unique_constraint_test"

0 commit comments

Comments (0)