22from collections import Counter
33from pathlib import Path
44from typing import Any , Generator , Iterable , Iterator , Mapping , Sequence , Tuple
5+ import yaml
56
67from sqlalchemy import Connection , insert , inspect
78from sqlalchemy .exc import IntegrityError
1213from datafaker .base import FileUploader
1314from datafaker .make import get_generation_info
1415from datafaker .populate import (
16+ StoryGeneratorInfo ,
1517 TableGenerator ,
18+ call_function ,
1619 get_symbols ,
1720 get_table_generator_dict ,
18- get_story_generator_list ,
1921 get_vocab_dict ,
2022)
2123from datafaker .settings import get_destination_dsn , get_destination_schema , get_settings
@@ -154,10 +156,15 @@ def create_db_data(
154156 metadata : MetaData ,
155157) -> RowCounts :
156158 """Connect to a database and populate it with data."""
159+ if src_stats_filename :
160+ with src_stats_filename .open (encoding = "utf-8" ) as fh :
161+ src_stats = yaml .load (fh , yaml .SafeLoader )
162+ else :
163+ src_stats = None
157164 return create_db_data_into (
158165 sorted_tables ,
159166 config ,
160- src_stats_filename ,
167+ src_stats ,
161168 num_passes ,
162169 get_destination_dsn (),
163170 get_destination_schema (),
@@ -169,7 +176,7 @@ def create_db_data(
169176def create_db_data_into (
170177 sorted_tables : Sequence [Table ],
171178 config : Mapping [str , Any ],
172- src_stats_filename : Path | None ,
179+ src_stats : dict [ str , dict [ str , Any ]] | None ,
173180 num_passes : int ,
174181 db_dsn : str ,
175182 schema_name : str | None ,
@@ -190,17 +197,17 @@ def create_db_data_into(
190197 :param metadata: Destination database metadata.
191198 """
192199 dst_engine = get_sync_engine (create_db_engine_dst (db_dsn , schema_name = schema_name ))
193- gen_info = get_generation_info (metadata , config , Path ("orm.blah" ), Path ("config.blah" ), src_stats_filename )
200+ gen_info = get_generation_info (metadata , config )
201+ context = get_symbols (
202+ gen_info .row_generator_module_name ,
203+ gen_info .story_generator_module_name ,
204+ get_property (config , "object_instantiation" , dict , {}),
205+ src_stats ,
206+ metadata ,
207+ )
194208 row_counts : Counter [str ] = Counter ()
195209 with dst_engine .connect () as dst_conn :
196- context = get_symbols (
197- gen_info .row_generator_module_name ,
198- gen_info .story_generator_module_name ,
199- get_property (config , "object_instantiation" , dict , {}),
200- gen_info .src_stats_filename ,
201- dst_conn ,
202- metadata ,
203- )
210+ context ["dst_db_conn" ] = dst_conn
204211 for _ in range (num_passes ):
205212 row_counts += populate (
206213 dst_conn ,
@@ -211,8 +218,8 @@ def create_db_data_into(
211218 gen_info .max_unique_constraint_tries ,
212219 context ,
213220 ),
214- get_story_generator_list ( gen_info .story_generators , context ) ,
215- metadata ,
221+ gen_info .story_generators ,
222+ context ,
216223 )
217224 dst_engine .dispose ()
218225 return row_counts
@@ -224,32 +231,60 @@ class StoryIterator:
224231
225232 def __init__ (
226233 self ,
227- stories : Iterable [tuple [ str , Story ] ],
234+ stories : Iterable [StoryGeneratorInfo ],
228235 table_dict : Mapping [str , Table ],
229236 table_generator_dict : Mapping [str , TableGenerator ],
230237 dst_conn : Connection ,
238+ context : Mapping ,
231239 ):
232240 """Initialise a Story Iterator."""
233- self ._stories : Iterator [tuple [ str , Story ] ] = iter (stories )
241+ self ._story_infos : Iterator [StoryGeneratorInfo ] = iter (stories )
234242 self ._table_dict : Mapping [str , Table ] = table_dict
235243 self ._table_generator_dict : Mapping [str , TableGenerator ] = table_generator_dict
236244 self ._dst_conn : Connection = dst_conn
237- self ._table_name : str | None
245+ self ._table_name : str | None = None
238246 self ._final_values : dict [str , Any ] | None = None
247+ # Remaining runs of the current story; next() sets it to -1 once every story is exhausted
248+ self ._story_counts = 1
249+ self ._story_function_call = None  # call used to (re)start the current story; set by _get_next_story()
250+ self ._context = context  # passed to call_function() whenever a story is started
251+ self ._story = iter ([])  # empty placeholder: the first next() hits StopIteration at once and loads the first real story
252+ self .next ()  # position the iterator on the first row, if there is one
254+ def _get_next_story (self ) -> None :  # NOTE(review): the body returns a bool; the ``-> None`` annotation looks stale
255+ """
256+ Start the next story generator taken from ``_story_infos``.
257+
258+ :return: True if a new story was started, False if there are no more.
259+ """
239260 try :
240- name , self ._story = next (self ._stories )
241- logger .info ("Generating data for story '%s'" , name )
242- self ._table_name , self ._provided_values = next (self ._story )
261+ sgi = next (self ._story_infos )
262+ self ._story_counts = sgi .num_stories_per_pass
263+ self ._story_function_call = sgi .function_call
264+ logger .info ("Generating data for story '%s'" , sgi .function_call .function_name )
265+ self ._story = call_function (sgi .function_call , self ._context )  # the result is consumed as a generator by _get_values()
266+ self ._final_values = None  # a fresh story must be primed with next(), not send()
243267 except StopIteration :
244268 self ._table_name = None
269+ return False
270+ return True
271+
272+ def _get_values (self ) -> None :  # Pull the next (table name, provided values) pair from the current story; StopIteration propagates to the caller when the story is exhausted
273+ if self ._final_values is None :  # nothing to report back to the story yet, so use a plain next()
274+ self ._table_name , self ._provided_values = next (self ._story )
275+ else :  # hand the stored _final_values back into the story generator
276+ self ._table_name , self ._provided_values = self ._story .send (
277+ self ._final_values
278+ )
245279
246280 def is_ended (self ) -> bool :
247281 """
248282 Check if every story has been exhausted.
249283
250284 If this returns False there is another row, so insert() can be called.
251285 """
252- return self ._table_name is None
286+ return self ._story_counts == - 1  # -1 is the sentinel next() sets when no stories remain
287+
253288
254289 def has_table (self , table_name : str ) -> bool :
255290 """Check if we have a row for table ``table_name``."""
@@ -264,7 +299,7 @@ def table_name(self) -> str | None:
264299 """
265300 return self ._table_name
266301
267- def insert (self , metadata : MetaData ) -> None :
302+ def insert (self ) -> None :
268303 """
269304 Put the row in the table.
270305
@@ -276,7 +311,7 @@ def insert(self, metadata: MetaData) -> None:
276311 table = self ._table_dict [self ._table_name ]
277312 if table .name in self ._table_generator_dict :
278313 table_generator = self ._table_generator_dict [table .name ]
279- default_values = table_generator (self ._dst_conn , metadata )
314+ default_values = table_generator (self ._dst_conn )
280315 else :
281316 default_values = {}
282317 insert_values = {** default_values , ** self ._provided_values }
@@ -300,54 +335,46 @@ def next(self) -> None:
300335 """Advance to the next row."""
301336 while True :
302337 try :
303- if self ._final_values is None :
304- self ._table_name , self ._provided_values = next (self ._story )
305- return
306- self ._table_name , self ._provided_values = self ._story .send (
307- self ._final_values
308- )
338+ self ._get_values ()
309339 return
310- except StopIteration :
311- try :
312- name , self ._story = next (self ._stories )
313- logger .info ("Generating data for story '%s'" , name )
314- self ._final_values = None
315- except StopIteration :
316- self ._table_name = None
340+ except StopIteration as exc :
341+ self ._final_values = None
342+ self ._story_counts -= 1
343+ if 0 < self ._story_counts :
344+ # Reinitialize the same story again
345+ self ._story = call_function (self ._story_function_call , self ._context )
346+ elif not self ._get_next_story ():
347+ self ._story_counts = - 1
317348 return
318349
319350
320351def populate (
321352 dst_conn : Connection ,
322353 tables : Sequence [Table ],
323354 table_generator_dict : Mapping [str , TableGenerator ],
324- story_generator_list : Sequence [Mapping [ str , Any ] ],
325- metadata : MetaData ,
355+ story_generator_infos : Sequence [StoryGeneratorInfo ],
356+ context : Mapping ,
326357) -> RowCounts :
327358 """Populate a database schema with synthetic data."""
328359 row_counts : Counter [str ] = Counter ()
329360 table_dict = {table .name : table for table in tables }
330361 # Generate stories
331362 # Each story generator returns a python generator (an unfortunate naming clash with
332363 # what we call generators). Iterating over it yields individual rows for the
333- # database. First, collect all of the python generators into a single list.
334- stories : list [tuple [str , Story ]] = sum (
335- [
336- [
337- (sg ["name" ], sg ["function" ](dst_conn ))
338- for _ in range (sg ["num_stories_per_pass" ])
339- ]
340- for sg in story_generator_list
341- ],
342- [],
364+ # database.
365+ story_iterator = StoryIterator (
366+ story_generator_infos ,
367+ table_dict ,
368+ table_generator_dict ,
369+ dst_conn ,
370+ context ,
343371 )
344- story_iterator = StoryIterator (stories , table_dict , table_generator_dict , dst_conn )
345372
346373 # Generate individual rows, table by table.
347374 for table in tables :
348375 # Do we have a story row to enter into this table?
349376 if story_iterator .has_table (table .name ):
350- story_iterator .insert (metadata )
377+ story_iterator .insert ()
351378 row_counts [table .name ] = row_counts .get (table .name , 0 ) + 1
352379 story_iterator .next ()
353380 if table .name not in table_generator_dict :
@@ -358,20 +385,20 @@ def populate(
358385 continue
359386 logger .debug ("Generating data for table '%s'" , table .name )
360387 # Run all the inserts for one table in a transaction
361- try :
362- with dst_conn . begin () :
388+ with dst_conn . begin () :
389+ try :
363390 for _ in range (table_generator .num_rows_per_pass ):
364391 stmt = insert (table ).values (table_generator (dst_conn ))
365392 dst_conn .execute (stmt )
366393 row_counts [table .name ] = row_counts .get (table .name , 0 ) + 1
367394 dst_conn .commit ()
368- except :
369- dst_conn .rollback ()
370- raise
395+ except :
396+ dst_conn .rollback ()
397+ raise
371398
372399 # Insert any remaining stories
373400 while not story_iterator .is_ended ():
374- story_iterator .insert (metadata )
401+ story_iterator .insert ()
375402 t = story_iterator .table_name ()
376403 if t is None :
377404 raise AssertionError (
0 commit comments