Skip to content

Commit c1c2858

Browse files
author
Tim Band
committed
A few pre-commit fixes
1 parent edda835 commit c1c2858

7 files changed

Lines changed: 58 additions & 73 deletions

File tree

datafaker/base.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,15 @@
11
"""Base table generator classes."""
2+
import gzip
3+
import os
4+
import random
25
from abc import ABC, abstractmethod
36
from collections.abc import Callable
47
from dataclasses import dataclass
5-
import gzip
68
from io import TextIOWrapper
7-
import os
89
from pathlib import Path
9-
import random
1010
from typing import Any
11-
import yaml
1211

12+
import yaml
1313
from sqlalchemy import Connection, insert
1414
from sqlalchemy.exc import SQLAlchemyError
1515
from sqlalchemy.schema import MetaData, Table

datafaker/create.py

Lines changed: 10 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -2,17 +2,17 @@
22
from collections import Counter
33
from pathlib import Path
44
from typing import Any, Generator, Iterable, Iterator, Mapping, Sequence, Tuple
5-
import yaml
65

6+
import typer
7+
import yaml
78
from sqlalchemy import Connection, insert, inspect
89
from sqlalchemy.exc import IntegrityError
910
from sqlalchemy.ext.compiler import compiles
1011
from sqlalchemy.orm import Session
1112
from sqlalchemy.schema import CreateColumn, CreateSchema, CreateTable, MetaData, Table
12-
import typer
1313

1414
from datafaker.base import FileUploader
15-
from datafaker.make import get_generation_info, StoryGeneratorInfo
15+
from datafaker.make import StoryGeneratorInfo, get_generation_info
1616
from datafaker.populate import (
1717
TableGenerator,
1818
call_function,
@@ -255,7 +255,7 @@ def __init__(
255255
self._story = iter([])
256256
self.next()
257257

258-
def _get_next_story(self) -> None:
258+
def _get_next_story(self) -> bool:
259259
"""
260260
Iterate to the next ``_story_infos``.
261261
@@ -265,7 +265,9 @@ def _get_next_story(self) -> None:
265265
sgi = next(self._story_infos)
266266
self._story_counts = sgi.num_stories_per_pass
267267
self._story_function_call = sgi.function_call
268-
logger.info("Generating data for story '%s'", sgi.function_call.function_name)
268+
logger.info(
269+
"Generating data for story '%s'", sgi.function_call.function_name
270+
)
269271
self._story = call_function(sgi.function_call, self._context)
270272
self._final_values = None
271273
except StopIteration:
@@ -289,7 +291,6 @@ def is_ended(self) -> bool:
289291
"""
290292
return self._story_counts == -1
291293

292-
293294
def has_table(self, table_name: str) -> bool:
294295
"""Check if we have a row for table ``table_name``."""
295296
return table_name == self._table_name
@@ -346,7 +347,9 @@ def next(self) -> None:
346347
self._story_counts -= 1
347348
if 0 < self._story_counts:
348349
# Reinitialize the same story again
349-
self._story = call_function(self._story_function_call, self._context)
350+
self._story = call_function(
351+
self._story_function_call, self._context
352+
)
350353
elif not self._get_next_story():
351354
self._story_counts = -1
352355
return

datafaker/dump.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,9 @@
11
"""Data dumping functions."""
22
import csv
33
import io
4-
from typing import TYPE_CHECKING
5-
64
from abc import ABC, abstractmethod
75
from pathlib import Path
6+
from typing import TYPE_CHECKING
87

98
import pandas as pd
109
import sqlalchemy

datafaker/main.py

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -29,11 +29,7 @@
2929
update_missingness,
3030
)
3131
from datafaker.interactive.base import DbCmd
32-
from datafaker.make import (
33-
make_src_stats,
34-
make_tables_file,
35-
make_vocabulary_tables,
36-
)
32+
from datafaker.make import make_src_stats, make_tables_file, make_vocabulary_tables
3733
from datafaker.remove import remove_db_data, remove_db_tables, remove_db_vocab
3834
from datafaker.settings import (
3935
SettingsError,

datafaker/make.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -66,8 +66,8 @@ class FunctionCall:
6666
"""Which function to call with what."""
6767

6868
function_name: str
69-
args: list[str]
70-
kwargs: dict[str, str]
69+
args: list[Any]
70+
kwargs: dict[str, Any]
7171

7272

7373
@dataclass
@@ -585,6 +585,7 @@ def make_vocabulary_tables(
585585
@dataclass
586586
class GenerationInfo:
587587
"""Information for the generation of all data."""
588+
588589
provider_imports: list[str]
589590
row_generator_module_name: str | None
590591
story_generator_module_name: str | None
@@ -623,9 +624,7 @@ def get_generation_info(
623624
story_generator_module_name = get_property(
624625
config, "story_generators_module", str | None, None
625626
)
626-
object_instantiation = get_property(
627-
config, "object_instantiation", dict, {}
628-
)
627+
object_instantiation = get_property(config, "object_instantiation", dict, {})
629628
tables_config = get_property(config, "tables", dict, {})
630629

631630
tables: list[TableGeneratorInfo] = []

datafaker/populate.py

Lines changed: 22 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,13 @@
11
from collections.abc import Iterable, Mapping, MutableMapping, Sequence
22
from pathlib import Path
3+
from typing import Any, Callable
4+
5+
import sqlalchemy
36
from mimesis import Generic
47
from mimesis.locales import Locale
5-
import sqlalchemy
6-
from typing import Any, Callable
78

8-
from datafaker.base import FileUploader, ColumnPresence
9+
from datafaker.base import ColumnPresence, FileUploader
910
from datafaker.make import FunctionCall, TableGeneratorInfo
10-
1111
from datafaker.providers import (
1212
BytesProvider,
1313
ColumnValueProvider,
@@ -20,7 +20,8 @@
2020
)
2121
from datafaker.utils import get_vocabulary_table_names, import_file
2222

23-
def make_generic():
23+
24+
def make_generic() -> Generic:
2425
g = Generic(locale=Locale.EN_GB)
2526
g.add_providers(
2627
BytesProvider,
@@ -38,7 +39,7 @@ def make_generic():
3839
generic = make_generic()
3940

4041

41-
def reset_generic():
42+
def reset_generic() -> None:
4243
"""
4344
Reset all the generators.
4445
@@ -63,10 +64,7 @@ def _eval_structure(config: Any, context: Mapping) -> Any:
6364
except NameError as exc:
6465
raise exc
6566
if isinstance(config, Mapping):
66-
return {
67-
k: _eval_structure(v, context)
68-
for k, v in config.items()
69-
}
67+
return {k: _eval_structure(v, context) for k, v in config.items()}
7068
if isinstance(config, Sequence):
7169
return [_eval_structure(v, context) for v in config]
7270
return config
@@ -96,10 +94,7 @@ def _get_object(class_name: str, context: Mapping) -> Any:
9694

9795

9896
def _call_from_context(
99-
callable_name: str,
100-
args: list[Any],
101-
kwargs: dict[str, Any],
102-
context: Mapping
97+
callable_name: str, args: list[Any], kwargs: dict[str, Any], context: Mapping
10398
) -> Any:
10499
"""
105100
Call a callable from the classes (or functions) in the context.
@@ -111,14 +106,8 @@ def _call_from_context(
111106
cls = _get_object(callable_name, context)
112107
if not isinstance(cls, Callable):
113108
return None
114-
arg_objs = [
115-
_eval_structure(arg, context)
116-
for arg in args
117-
]
118-
kwarg_objs = {
119-
k: _eval_structure(v, context)
120-
for k, v in kwargs.items()
121-
}
109+
arg_objs = [_eval_structure(arg, context) for arg in args]
110+
kwarg_objs = {k: _eval_structure(v, context) for k, v in kwargs.items()}
122111
return cls(*arg_objs, **kwarg_objs)
123112

124113

@@ -191,7 +180,6 @@ def _get_symbols_instantiation(symbols: dict[str, Any], objs: dict[str, Any]) ->
191180

192181

193182
class TableGenerator:
194-
195183
def __init__(
196184
self,
197185
dst_db_conn: sqlalchemy.Connection,
@@ -217,10 +205,9 @@ def __init__(
217205
for constraint in table_data.unique_constraints:
218206
expr = sqlalchemy.select(constraint.columns)
219207
query_result = dst_db_conn.execute(expr).fetchall()
220-
self.existing_constraint_hashes[constraint.name] = set([
221-
hash(tuple(result))
222-
for result in query_result
223-
])
208+
self.existing_constraint_hashes[constraint.name] = set(
209+
[hash(tuple(result)) for result in query_result]
210+
)
224211

225212
@property
226213
def num_rows_per_pass(self):
@@ -242,13 +229,17 @@ def __call__(self, db_conn: sqlalchemy.Connection):
242229
columns_to_generate = set(self.table_data.nonnull_columns)
243230
# Which missingness patterns do we want?
244231
for choice in self.table_data.column_choices:
245-
cols = _call_from_context(choice.function_name, choice.args, choice.kwargs, self.context)
232+
cols = _call_from_context(
233+
choice.function_name, choice.args, choice.kwargs, self.context
234+
)
246235
columns_to_generate.update(cols)
247236

248237
max_tries = self.max_unique_constraint_tries
249238
while columns_to_generate:
250239
if max_tries == 0:
251-
raise RuntimeError(f"Failed to satisfy unique constraints for table {self.table_data.table_name} after {self.max_unique_constraint_tries} attempts.")
240+
raise RuntimeError(
241+
f"Failed to satisfy unique constraints for table {self.table_data.table_name} after {self.max_unique_constraint_tries} attempts."
242+
)
252243
if max_tries is not None:
253244
max_tries -= 1
254245
for row_gen in self.table_data.row_gens:
@@ -264,15 +255,11 @@ def __call__(self, db_conn: sqlalchemy.Connection):
264255
result[variable_name] = values[index]
265256
columns_to_generate = set()
266257
for constraint in self.table_data.unique_constraints:
267-
cf_hash = hash(tuple(
268-
result[col.name] for col in constraint.columns
269-
))
258+
cf_hash = hash(tuple(result[col.name] for col in constraint.columns))
270259
if cf_hash in self.existing_constraint_hashes[constraint.name]:
271260
columns_to_generate.update(c.name for c in constraint.columns)
272261
for constraint in self.table_data.unique_constraints:
273-
cf_hash = hash(tuple(
274-
result[col.name] for col in constraint.columns
275-
))
262+
cf_hash = hash(tuple(result[col.name] for col in constraint.columns))
276263
self.existing_constraint_hashes[constraint.name].add(cf_hash)
277264
return result
278265

tests/test_create.py

Lines changed: 16 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,6 @@
1313
from sqlalchemy import Connection, Engine, select
1414
from sqlalchemy.schema import MetaData, Table
1515

16-
from datafaker.populate import TableGenerator
1716
from datafaker.create import (
1817
create_db_data_into,
1918
create_db_tables,
@@ -22,7 +21,8 @@
2221
populate,
2322
)
2423
from datafaker.make import FunctionCall, StoryGeneratorInfo
25-
from datafaker.serialize_metadata import metadata_to_dict, dict_to_metadata
24+
from datafaker.populate import TableGenerator
25+
from datafaker.serialize_metadata import dict_to_metadata, metadata_to_dict
2626
from datafaker.utils import sorted_non_vocabulary_tables
2727
from tests.utils import DatafakerTestCase, GeneratesDBTestCase, RequiresDBTestCase
2828

@@ -128,7 +128,7 @@ def story() -> Generator[Tuple[str, dict], None, None]:
128128
{table_name: num_initial_rows} if num_initial_rows > 0 else {}
129129
)
130130

131-
story_generators: list[dict[str, Any]] = (
131+
story_generators: list[StoryGeneratorInfo] = (
132132
[
133133
StoryGeneratorInfo(
134134
"mock_story_gen name",
@@ -345,6 +345,7 @@ def __call__(self, connection: Connection, base_path: Path) -> None:
345345

346346
class CreateDataTestCase(RequiresDBTestCase):
347347
"""Tests for create-data."""
348+
348349
dump_file_path = "empty.sql"
349350
database_name = "empty"
350351
schema_name = "public"
@@ -379,21 +380,23 @@ def test_create_data_minimal(self) -> None:
379380
with self.sync_engine.connect() as connection:
380381
stmt = select(metadata.tables["one"])
381382
rows = connection.execute(stmt).fetchall()
382-
self.assertListEqual(rows, [(1,), (2,), (3,), (4,)])
383-
self.assertListEqual(list(row_counts.keys()), ['one'])
383+
self.assertEqual(rows, [(1,), (2,), (3,), (4,)])
384+
self.assertListEqual(list(row_counts.keys()), ["one"])
384385
self.assertEqual(row_counts["one"], generate_count)
385386

386387
def test_unique_constraint_minimal(self) -> None:
387388
config = {
388389
"tables": {
389390
"one": {
390-
"row_generators": [{
391-
"name": "dist_gen.constant",
392-
"kwargs": {
393-
"value": 123,
394-
},
395-
"columns_assigned": ["tiger"],
396-
}]
391+
"row_generators": [
392+
{
393+
"name": "dist_gen.constant",
394+
"kwargs": {
395+
"value": 123,
396+
},
397+
"columns_assigned": ["tiger"],
398+
}
399+
]
397400
}
398401
},
399402
"max-unique-constraint-tries": 20,
@@ -410,9 +413,7 @@ def test_unique_constraint_minimal(self) -> None:
410413
"type": "INTEGER",
411414
},
412415
},
413-
"unique": [
414-
{"name": "tiger_uniq", "columns": ["tiger"]}
415-
]
416+
"unique": [{"name": "tiger_uniq", "columns": ["tiger"]}],
416417
}
417418
}
418419
}

0 commit comments

Comments
 (0)