Skip to content

Commit 176ed63

Browse files
author
Tim Band
committed
Merge branch 'main' into remove-generator-file
2 parents f914804 + cf80f78 commit 176ed63

8 files changed

Lines changed: 247 additions & 21 deletions

File tree

datafaker/main.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def _check_file_non_existence(file_path: Path) -> None:
7878
"""Check that a given file does not exist. Exit with an error message if it does."""
7979
if file_path.exists():
8080
logger.error("%s should not already exist. Exiting...", file_path)
81-
sys.exit(1)
81+
raise Exit(1)
8282

8383

8484
def load_metadata_config(
@@ -202,6 +202,8 @@ def create_data(
202202
return
203203
except RuntimeError as e:
204204
logger.error(e.args[0])
205+
except SettingsError as e:
206+
logger.error(str(e))
205207
raise Exit(1)
206208

207209

@@ -554,7 +556,7 @@ def convert_table_names_to_tables(
554556
"%s is not the name of a table in the destination database", name
555557
)
556558
if failed_count:
557-
sys.exit(1)
559+
raise Exit(1)
558560
return results
559561

560562

@@ -643,7 +645,7 @@ def dump_data(
643645
"Must specify exactly one table if the output name is"
644646
" specified, or specify an existing directory"
645647
)
646-
sys.exit(1)
648+
raise Exit(1)
647649
dst_dsn = get_destination_dsn()
648650
schema_name = get_destination_schema()
649651
config = read_config_file(config_file) if config_file is not None else {}
@@ -676,7 +678,7 @@ def validate_config(
676678
validate(config, schema_config)
677679
except ValidationError as e:
678680
logger.error(e)
679-
sys.exit(1)
681+
raise Exit(1) from e
680682
logger.debug("Config file is valid.")
681683

682684

@@ -772,7 +774,7 @@ def remove_tables(
772774
except InternalError as exc:
773775
logger.error("Failed to drop tables: %s", exc)
774776
logger.error("Please try again using the --all option.")
775-
sys.exit(1)
777+
raise Exit(1) from exc
776778
logger.debug("Tables dropped.")
777779
else:
778780
logger.info("Would remove tables if called with --yes.")

datafaker/make.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import pandas as pd
1313
import snsql
14+
import typer
1415
import yaml
1516
from black import FileMode, format_str
1617
from jinja2 import Environment, FileSystemLoader, Template
@@ -38,6 +39,7 @@
3839
create_db_engine,
3940
download_table,
4041
get_columns_assigned,
42+
get_metadata,
4143
get_property,
4244
get_property_or_none,
4345
get_related_table_names,
@@ -605,9 +607,21 @@ def get_generation_info(
605607
row_generator_module_name = get_property_or_none(
606608
config, "row_generators_module", str
607609
)
610+
if row_generator_module_name and "-" in row_generator_module_name:
611+
logger.error(
612+
"Row generator name %s should not contain a hyphen",
613+
row_generator_module_name,
614+
)
615+
raise typer.Exit(1)
608616
story_generator_module_name = get_property_or_none(
609617
config, "story_generators_module", str
610618
)
619+
if story_generator_module_name and "-" in story_generator_module_name:
620+
logger.error(
621+
"Story generator name %s should not contain a hyphen",
622+
story_generator_module_name,
623+
)
624+
raise typer.Exit(1)
611625
object_instantiation: dict[str, Any] = get_property(
612626
config, "object_instantiation", {}
613627
)
@@ -703,8 +717,7 @@ def make_tables_file(
703717
"""Construct the YAML file representing the schema."""
704718
engine = get_sync_engine(create_db_engine(db_dsn, schema_name=schema_name))
705719

706-
metadata = MetaData()
707-
metadata.reflect(engine)
720+
metadata = get_metadata(engine)
708721
meta_dict = metadata_to_dict(metadata, schema_name, engine, parquet_dir)
709722

710723
if parquet_dir is not None:

datafaker/remove.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from datafaker.settings import get_destination_dsn, get_destination_schema
77
from datafaker.utils import (
88
create_db_engine,
9+
get_metadata,
910
get_sync_engine,
1011
get_vocabulary_table_names,
1112
logger,
@@ -67,6 +68,5 @@ def remove_db_tables(metadata: Optional[MetaData]) -> None:
6768
)
6869
)
6970
if metadata is None:
70-
metadata = MetaData()
71-
metadata.reflect(dst_engine)
71+
metadata = get_metadata(dst_engine)
7272
metadata.drop_all(dst_engine)

datafaker/utils.py

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,12 @@
3232
from jsonschema.validators import validate
3333
from sqlalchemy import Connection, Engine, ForeignKey, create_engine, event, select
3434
from sqlalchemy.engine.interfaces import DBAPIConnection
35-
from sqlalchemy.exc import IntegrityError, ProgrammingError
35+
from sqlalchemy.exc import (
36+
IntegrityError,
37+
NoSuchModuleError,
38+
OperationalError,
39+
ProgrammingError,
40+
)
3641
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
3742
from sqlalchemy.orm import Session
3843
from sqlalchemy.schema import (
@@ -43,6 +48,9 @@
4348
MetaData,
4449
Table,
4550
)
51+
from typer import Exit
52+
53+
from datafaker.settings import SettingsError
4654

4755
# Define some types used repeatedly in the code base
4856
MaybeAsyncEngine = Union[Engine, AsyncEngine]
@@ -108,9 +116,15 @@ def import_file(file_path: str) -> ModuleType:
108116
"""
109117
spec = importlib.util.spec_from_file_location("df", file_path)
110118
if spec is None or spec.loader is None:
111-
raise ImportError(f"No loadable module at {file_path}")
119+
raise SettingsError(f"No loadable module '{file_path}'")
112120
module = importlib.util.module_from_spec(spec)
113-
spec.loader.exec_module(module)
121+
try:
122+
spec.loader.exec_module(module)
123+
except ModuleNotFoundError as e:
124+
logger.error("Failed to load module at %s with error:", file_path)
125+
logger.error(e)
126+
except FileNotFoundError as e:
127+
raise SettingsError(f"No module found '{file_path}'") from e
114128
return module
115129

116130

@@ -193,11 +207,19 @@ def create_db_engine(
193207
**kwargs: Any,
194208
) -> MaybeAsyncEngine:
195209
"""Create a SQLAlchemy Engine."""
196-
if use_asyncio:
197-
async_dsn = db_dsn.replace("postgresql://", "postgresql+asyncpg://")
198-
engine: MaybeAsyncEngine = create_async_engine(async_dsn, **kwargs)
199-
else:
200-
engine = create_engine(db_dsn, **kwargs)
210+
try:
211+
if use_asyncio:
212+
async_dsn = db_dsn.replace("postgresql://", "postgresql+asyncpg://")
213+
engine: MaybeAsyncEngine = create_async_engine(async_dsn, **kwargs)
214+
else:
215+
engine = create_engine(db_dsn, **kwargs)
216+
except NoSuchModuleError as exc:
217+
logger.error("Failed to connect to the database: %s", exc)
218+
logger.error("Perhaps the dialect '%s' is invalid.", db_dsn.split(":")[0])
219+
raise Exit(1) from exc
220+
except ValueError as exc:
221+
logger.error("DSN %s is malformed: %s", db_dsn, exc)
222+
raise Exit(1) from exc
201223

202224
settings = {}
203225
if schema_name is not None:
@@ -248,6 +270,17 @@ def create_db_engine_dst(
248270
return create_db_engine(db_dsn, schema_name, use_asyncio)
249271

250272

273+
def get_metadata(engine: Engine) -> MetaData:
274+
"""Get the MetaData object associated with the engine passed."""
275+
md = MetaData()
276+
try:
277+
md.reflect(engine)
278+
except OperationalError as exc:
279+
logger.error("Cannot connect to database: %s", exc)
280+
raise Exit(1) from exc
281+
return md
282+
283+
251284
def _find_parquet_directories(parquet_dir: Path) -> list[str]:
252285
"""Find all the directories under ``parquet_dir`` that contain parquet files."""
253286
return [

tests/test_create.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from datafaker.make import FunctionCall, StoryGeneratorInfo
2424
from datafaker.populate import TableGenerator
2525
from datafaker.serialize_metadata import dict_to_metadata, metadata_to_dict
26+
from datafaker.settings import SettingsError
2627
from datafaker.utils import sorted_non_vocabulary_tables
2728
from tests.utils import DatafakerTestCase, GeneratesDBTestCase, RequiresDBTestCase
2829

@@ -431,3 +432,34 @@ def test_unique_constraint_minimal(self) -> None:
431432
self.schema_name,
432433
metadata,
433434
)
435+
436+
def test_story_incorrect_name_minimal(self) -> None:
437+
"""Test we get a proper error message if the story generator module does not exist."""
438+
config = {
439+
"story_generators_module": "incorrect_module",
440+
}
441+
orm = {
442+
"tables": {
443+
"one": {
444+
"columns": {
445+
"id": {
446+
"primary": True,
447+
"type": "INTEGER",
448+
}
449+
}
450+
}
451+
}
452+
}
453+
metadata = dict_to_metadata(orm, config)
454+
create_db_tables_into(metadata, self.dsn, self.schema_name)
455+
self.assertRaises(
456+
SettingsError,
457+
create_db_data_into,
458+
sorted_non_vocabulary_tables(metadata, config),
459+
config,
460+
None,
461+
1,
462+
self.dsn,
463+
self.schema_name,
464+
metadata,
465+
)

tests/test_functional.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pathlib import Path
66
from typing import Any, Mapping
77

8+
import yaml
89
from sqlalchemy import create_engine, inspect
910
from typer.testing import CliRunner, Result
1011

@@ -514,6 +515,71 @@ def test_create_schema(self) -> None:
514515
inspector = inspect(engine)
515516
self.assertTrue(inspector.has_schema(env["dst_schema"]))
516517

518+
def test_story_incorrect_name(self) -> None:
519+
"""Test we get a proper error message if the story generator module does not exist."""
520+
config_file = "config_story_incorrect.yaml"
521+
config = {
522+
"story_generators_module": "incorrect_module",
523+
}
524+
with Path(config_file).open("w", encoding="utf-8") as fh:
525+
fh.write(yaml.dump(config))
526+
self.invoke(
527+
"make-tables",
528+
"--force",
529+
)
530+
self.invoke(
531+
"create-tables",
532+
"--config-file",
533+
config_file,
534+
)
535+
completed_process = self.invoke(
536+
"create-data",
537+
"--config-file",
538+
config_file,
539+
expected_error="No module found 'incorrect_module",
540+
)
541+
self.assertReturnCode(completed_process, 1)
542+
543+
def test_story_hyphens_in_name(self) -> None:
544+
"""Test hyphens in story generator names cause an error to be emitted."""
545+
config_file = "config_story_hyphens.yaml"
546+
config = {
547+
"story_generators_module": "story-generators",
548+
}
549+
with Path(config_file).open("w", encoding="utf-8") as fh:
550+
fh.write(yaml.dump(config))
551+
self.invoke(
552+
"make-tables",
553+
"--force",
554+
)
555+
completed_process = self.invoke(
556+
"create-data",
557+
"--config-file",
558+
config_file,
559+
expected_error="hyphen",
560+
)
561+
self.assertReturnCode(completed_process, 1)
562+
563+
def test_row_hyphens_in_name(self) -> None:
564+
"""Test hyphens in row generator names cause an error to be emitted."""
565+
config_file = "config_row_hyphens.yaml"
566+
config = {
567+
"row_generators_module": "row-generators",
568+
}
569+
with Path(config_file).open("w", encoding="utf-8") as fh:
570+
fh.write(yaml.dump(config))
571+
self.invoke(
572+
"make-tables",
573+
"--force",
574+
)
575+
completed_process = self.invoke(
576+
"create-data",
577+
"--config-file",
578+
config_file,
579+
expected_error="hyphen",
580+
)
581+
self.assertReturnCode(completed_process, 1)
582+
517583

518584
class DuckDbFunctionalTestCase(DBFunctionalTestCaseBase):
519585
"""End-to-end tests for the DuckDB workflow."""

0 commit comments

Comments
 (0)