Skip to content

Commit ba9bf59

Browse files
author
Tim Band
committed
test_workflow_minimal_args passes
1 parent 00df596 commit ba9bf59

9 files changed

Lines changed: 121 additions & 344 deletions

File tree

datafaker/create.py

Lines changed: 83 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from collections import Counter
33
from pathlib import Path
44
from typing import Any, Generator, Iterable, Iterator, Mapping, Sequence, Tuple
5+
import yaml
56

67
from sqlalchemy import Connection, insert, inspect
78
from sqlalchemy.exc import IntegrityError
@@ -12,10 +13,11 @@
1213
from datafaker.base import FileUploader
1314
from datafaker.make import get_generation_info
1415
from datafaker.populate import (
16+
StoryGeneratorInfo,
1517
TableGenerator,
18+
call_function,
1619
get_symbols,
1720
get_table_generator_dict,
18-
get_story_generator_list,
1921
get_vocab_dict,
2022
)
2123
from datafaker.settings import get_destination_dsn, get_destination_schema, get_settings
@@ -154,10 +156,15 @@ def create_db_data(
154156
metadata: MetaData,
155157
) -> RowCounts:
156158
"""Connect to a database and populate it with data."""
159+
if src_stats_filename:
160+
with src_stats_filename.open(encoding="utf-8") as fh:
161+
src_stats = yaml.load(fh, yaml.SafeLoader)
162+
else:
163+
src_stats = None
157164
return create_db_data_into(
158165
sorted_tables,
159166
config,
160-
src_stats_filename,
167+
src_stats,
161168
num_passes,
162169
get_destination_dsn(),
163170
get_destination_schema(),
@@ -169,7 +176,7 @@ def create_db_data(
169176
def create_db_data_into(
170177
sorted_tables: Sequence[Table],
171178
config: Mapping[str, Any],
172-
src_stats_filename: Path | None,
179+
src_stats: dict[str, dict[str, Any]] | None,
173180
num_passes: int,
174181
db_dsn: str,
175182
schema_name: str | None,
@@ -190,17 +197,17 @@ def create_db_data_into(
190197
:param metadata: Destination database metadata.
191198
"""
192199
dst_engine = get_sync_engine(create_db_engine_dst(db_dsn, schema_name=schema_name))
193-
gen_info = get_generation_info(metadata, config, Path("orm.blah"), Path("config.blah"), src_stats_filename)
200+
gen_info = get_generation_info(metadata, config)
201+
context = get_symbols(
202+
gen_info.row_generator_module_name,
203+
gen_info.story_generator_module_name,
204+
get_property(config, "object_instantiation", dict, {}),
205+
src_stats,
206+
metadata,
207+
)
194208
row_counts: Counter[str] = Counter()
195209
with dst_engine.connect() as dst_conn:
196-
context = get_symbols(
197-
gen_info.row_generator_module_name,
198-
gen_info.story_generator_module_name,
199-
get_property(config, "object_instantiation", dict, {}),
200-
gen_info.src_stats_filename,
201-
dst_conn,
202-
metadata,
203-
)
210+
context["dst_db_conn"] = dst_conn
204211
for _ in range(num_passes):
205212
row_counts += populate(
206213
dst_conn,
@@ -211,8 +218,8 @@ def create_db_data_into(
211218
gen_info.max_unique_constraint_tries,
212219
context,
213220
),
214-
get_story_generator_list(gen_info.story_generators, context),
215-
metadata,
221+
gen_info.story_generators,
222+
context,
216223
)
217224
dst_engine.dispose()
218225
return row_counts
@@ -224,32 +231,60 @@ class StoryIterator:
224231

225232
def __init__(
226233
self,
227-
stories: Iterable[tuple[str, Story]],
234+
stories: Iterable[StoryGeneratorInfo],
228235
table_dict: Mapping[str, Table],
229236
table_generator_dict: Mapping[str, TableGenerator],
230237
dst_conn: Connection,
238+
context: Mapping,
231239
):
232240
"""Initialise a Story Iterator."""
233-
self._stories: Iterator[tuple[str, Story]] = iter(stories)
241+
self._story_infos: Iterator[StoryGeneratorInfo] = iter(stories)
234242
self._table_dict: Mapping[str, Table] = table_dict
235243
self._table_generator_dict: Mapping[str, TableGenerator] = table_generator_dict
236244
self._dst_conn: Connection = dst_conn
237-
self._table_name: str | None
245+
self._table_name: str | None = None
238246
self._final_values: dict[str, Any] | None = None
247+
# Number of times the current story should be run
248+
self._story_counts = 1
249+
self._story_function_call = None
250+
self._context = context
251+
self._story = iter([])
252+
self.next()
253+
254+
def _get_next_story(self) -> None:
255+
"""
256+
Advance to the next entry of ``_story_infos`` and start its story.
257+
258+
:return: True if another story was started, False if there are no more.
259+
"""
239260
try:
240-
name, self._story = next(self._stories)
241-
logger.info("Generating data for story '%s'", name)
242-
self._table_name, self._provided_values = next(self._story)
261+
sgi = next(self._story_infos)
262+
self._story_counts = sgi.num_stories_per_pass
263+
self._story_function_call = sgi.function_call
264+
logger.info("Generating data for story '%s'", sgi.function_call.function_name)
265+
self._story = call_function(sgi.function_call, self._context)
266+
self._final_values = None
243267
except StopIteration:
244268
self._table_name = None
269+
return False
270+
return True
271+
272+
def _get_values(self) -> None:
273+
if self._final_values is None:
274+
self._table_name, self._provided_values = next(self._story)
275+
else:
276+
self._table_name, self._provided_values = self._story.send(
277+
self._final_values
278+
)
245279

246280
def is_ended(self) -> bool:
247281
"""
248282
Check if we have run out of rows to process.
249283
250284
If this returns False, insert() can be called.
251285
"""
252-
return self._table_name is None
286+
return self._story_counts == -1
287+
253288

254289
def has_table(self, table_name: str) -> bool:
255290
"""Check if we have a row for table ``table_name``."""
@@ -264,7 +299,7 @@ def table_name(self) -> str | None:
264299
"""
265300
return self._table_name
266301

267-
def insert(self, metadata: MetaData) -> None:
302+
def insert(self) -> None:
268303
"""
269304
Put the row in the table.
270305
@@ -276,7 +311,7 @@ def insert(self, metadata: MetaData) -> None:
276311
table = self._table_dict[self._table_name]
277312
if table.name in self._table_generator_dict:
278313
table_generator = self._table_generator_dict[table.name]
279-
default_values = table_generator(self._dst_conn, metadata)
314+
default_values = table_generator(self._dst_conn)
280315
else:
281316
default_values = {}
282317
insert_values = {**default_values, **self._provided_values}
@@ -300,54 +335,46 @@ def next(self) -> None:
300335
"""Advance to the next row."""
301336
while True:
302337
try:
303-
if self._final_values is None:
304-
self._table_name, self._provided_values = next(self._story)
305-
return
306-
self._table_name, self._provided_values = self._story.send(
307-
self._final_values
308-
)
338+
self._get_values()
309339
return
310-
except StopIteration:
311-
try:
312-
name, self._story = next(self._stories)
313-
logger.info("Generating data for story '%s'", name)
314-
self._final_values = None
315-
except StopIteration:
316-
self._table_name = None
340+
except StopIteration as exc:
341+
self._final_values = None
342+
self._story_counts -= 1
343+
if 0 < self._story_counts:
344+
# Reinitialize the same story again
345+
self._story = call_function(self._story_function_call, self._context)
346+
elif not self._get_next_story():
347+
self._story_counts = -1
317348
return
318349

319350

320351
def populate(
321352
dst_conn: Connection,
322353
tables: Sequence[Table],
323354
table_generator_dict: Mapping[str, TableGenerator],
324-
story_generator_list: Sequence[Mapping[str, Any]],
325-
metadata: MetaData,
355+
story_generator_infos: Sequence[StoryGeneratorInfo],
356+
context: Mapping,
326357
) -> RowCounts:
327358
"""Populate a database schema with synthetic data."""
328359
row_counts: Counter[str] = Counter()
329360
table_dict = {table.name: table for table in tables}
330361
# Generate stories
331362
# Each story generator returns a Python generator (an unfortunate naming clash with
332363
# what we call generators). Iterating over it yields individual rows for the
333-
# database. First, collect all of the python generators into a single list.
334-
stories: list[tuple[str, Story]] = sum(
335-
[
336-
[
337-
(sg["name"], sg["function"](dst_conn))
338-
for _ in range(sg["num_stories_per_pass"])
339-
]
340-
for sg in story_generator_list
341-
],
342-
[],
364+
# database.
365+
story_iterator = StoryIterator(
366+
story_generator_infos,
367+
table_dict,
368+
table_generator_dict,
369+
dst_conn,
370+
context,
343371
)
344-
story_iterator = StoryIterator(stories, table_dict, table_generator_dict, dst_conn)
345372

346373
# Generate individual rows, table by table.
347374
for table in tables:
348375
# Do we have a story row to enter into this table?
349376
if story_iterator.has_table(table.name):
350-
story_iterator.insert(metadata)
377+
story_iterator.insert()
351378
row_counts[table.name] = row_counts.get(table.name, 0) + 1
352379
story_iterator.next()
353380
if table.name not in table_generator_dict:
@@ -358,20 +385,20 @@ def populate(
358385
continue
359386
logger.debug("Generating data for table '%s'", table.name)
360387
# Run all the inserts for one table in a transaction
361-
try:
362-
with dst_conn.begin():
388+
with dst_conn.begin():
389+
try:
363390
for _ in range(table_generator.num_rows_per_pass):
364391
stmt = insert(table).values(table_generator(dst_conn))
365392
dst_conn.execute(stmt)
366393
row_counts[table.name] = row_counts.get(table.name, 0) + 1
367394
dst_conn.commit()
368-
except:
369-
dst_conn.rollback()
370-
raise
395+
except:
396+
dst_conn.rollback()
397+
raise
371398

372399
# Insert any remaining stories
373400
while not story_iterator.is_ended():
374-
story_iterator.insert(metadata)
401+
story_iterator.insert()
375402
t = story_iterator.table_name()
376403
if t is None:
377404
raise AssertionError(

datafaker/main.py

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from datafaker.interactive.base import DbCmd
3232
from datafaker.make import (
3333
make_src_stats,
34-
make_table_generators,
3534
make_tables_file,
3635
make_vocabulary_tables,
3736
)
@@ -182,6 +181,8 @@ def create_data(
182181
"""
183182
logger.debug("Creating data.")
184183
config = read_config_file(config_file) if config_file is not None else {}
184+
if stats_file is None and generators_require_stats(config):
185+
stats_file = Path(STATS_FILENAME)
185186
orm_metadata = load_metadata_for_output(orm_file, config)
186187
try:
187188
row_counts = create_db_data(
@@ -294,34 +295,8 @@ def create_generators(
294295
False, "--force", "-f", help="Overwrite any existing Python generators file."
295296
),
296297
) -> None:
297-
"""Make a datafaker file of generator classes.
298-
299-
This CLI command takes an object relation model output by sqlcodegen and
300-
returns a set of synthetic data generators for each attribute
301-
302-
Example:
303-
$ datafaker create-generators
304-
"""
305-
logger.debug("Making %s.", df_file)
306-
307-
if not force:
308-
_check_file_non_existence(df_file)
309-
310-
generator_config = read_config_file(config_file) if config_file is not None else {}
311-
if stats_file is None and generators_require_stats(generator_config):
312-
stats_file = Path(STATS_FILENAME)
313-
orm_metadata = load_metadata_for_output(orm_file, generator_config)
314-
result: str = make_table_generators(
315-
orm_metadata,
316-
generator_config,
317-
orm_file,
318-
config_file,
319-
stats_file,
320-
)
321-
322-
df_file.write_text(result, encoding="utf-8")
323-
324-
logger.debug("%s created.", df_file)
298+
"""Obsolete command."""
299+
logger.error("This command is deprecated; it does nothing.")
325300

326301

327302
@app.command()

0 commit comments

Comments
 (0)