Skip to content

Commit cf80f78

Browse files
tim-bandTim Band
andauthored
Some bugfixes (#84)
* Fixed #65 Stack trace if DSN is malformed. * Fix for #68 incorrect story_generators_module error * Fixed #69 hyphens in generator modules * Fixes some more of #65 --------- Co-authored-by: Tim Band <t.b@ucl>
1 parent d8e84d7 commit cf80f78

7 files changed

Lines changed: 220 additions & 20 deletions

File tree

datafaker/main.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def _check_file_non_existence(file_path: Path) -> None:
8484
"""Check that a given file does not exist. Exit with an error message if it does."""
8585
if file_path.exists():
8686
logger.error("%s should not already exist. Exiting...", file_path)
87-
sys.exit(1)
87+
raise Exit(1)
8888

8989

9090
def load_metadata_config(
@@ -580,7 +580,7 @@ def convert_table_names_to_tables(
580580
"%s is not the name of a table in the destination database", name
581581
)
582582
if failed_count:
583-
sys.exit(1)
583+
raise Exit(1)
584584
return results
585585

586586

@@ -669,7 +669,7 @@ def dump_data(
669669
"Must specify exactly one table if the output name is"
670670
" specified, or specify an existing directory"
671671
)
672-
sys.exit(1)
672+
raise Exit(1)
673673
dst_dsn = get_destination_dsn()
674674
schema_name = get_destination_schema()
675675
config = read_config_file(config_file) if config_file is not None else {}
@@ -702,7 +702,7 @@ def validate_config(
702702
validate(config, schema_config)
703703
except ValidationError as e:
704704
logger.error(e)
705-
sys.exit(1)
705+
raise Exit(1) from e
706706
logger.debug("Config file is valid.")
707707

708708

@@ -798,7 +798,7 @@ def remove_tables(
798798
except InternalError as exc:
799799
logger.error("Failed to drop tables: %s", exc)
800800
logger.error("Please try again using the --all option.")
801-
sys.exit(1)
801+
raise Exit(1) from exc
802802
logger.debug("Tables dropped.")
803803
else:
804804
logger.info("Would remove tables if called with --yes.")

datafaker/make.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import pandas as pd
1313
import snsql
14+
import typer
1415
import yaml
1516
from black import FileMode, format_str
1617
from jinja2 import Environment, FileSystemLoader, Template
@@ -31,6 +32,7 @@
3132
create_db_engine,
3233
download_table,
3334
get_columns_assigned,
35+
get_metadata,
3436
get_property,
3537
get_related_table_names,
3638
get_row_generators,
@@ -606,7 +608,21 @@ def make_table_generators( # pylint: disable=too-many-locals
606608
:return: A string that is a valid Python module, once written to file.
607609
"""
608610
row_generator_module_name: str = config.get("row_generators_module", None)
611+
if row_generator_module_name and "-" in row_generator_module_name:
612+
logger.error(
613+
"Row generator name %s specified in %s should not contain a hyphen",
614+
row_generator_module_name,
615+
config_filename,
616+
)
617+
raise typer.Exit(1)
609618
story_generator_module_name = config.get("story_generators_module", None)
619+
if story_generator_module_name and "-" in story_generator_module_name:
620+
logger.error(
621+
"Story generator name %s specified in %s should not contain a hyphen",
622+
story_generator_module_name,
623+
config_filename,
624+
)
625+
raise typer.Exit(1)
610626
object_instantiation: dict[str, dict] = config.get("object_instantiation", {})
611627
tables_config = config.get("tables", {})
612628

@@ -703,8 +719,7 @@ def make_tables_file(
703719
"""Construct the YAML file representing the schema."""
704720
engine = get_sync_engine(create_db_engine(db_dsn, schema_name=schema_name))
705721

706-
metadata = MetaData()
707-
metadata.reflect(engine)
722+
metadata = get_metadata(engine)
708723
meta_dict = metadata_to_dict(metadata, schema_name, engine, parquet_dir)
709724

710725
if parquet_dir is not None:

datafaker/remove.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from datafaker.settings import get_destination_dsn, get_destination_schema
77
from datafaker.utils import (
88
create_db_engine,
9+
get_metadata,
910
get_sync_engine,
1011
get_vocabulary_table_names,
1112
logger,
@@ -67,6 +68,5 @@ def remove_db_tables(metadata: Optional[MetaData]) -> None:
6768
)
6869
)
6970
if metadata is None:
70-
metadata = MetaData()
71-
metadata.reflect(dst_engine)
71+
metadata = get_metadata(dst_engine)
7272
metadata.drop_all(dst_engine)

datafaker/utils.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,12 @@
3333
from jsonschema.validators import validate
3434
from sqlalchemy import Connection, Engine, ForeignKey, create_engine, event, select
3535
from sqlalchemy.engine.interfaces import DBAPIConnection
36-
from sqlalchemy.exc import IntegrityError, ProgrammingError
36+
from sqlalchemy.exc import (
37+
IntegrityError,
38+
NoSuchModuleError,
39+
OperationalError,
40+
ProgrammingError,
41+
)
3742
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
3843
from sqlalchemy.orm import Session
3944
from sqlalchemy.schema import (
@@ -43,6 +48,7 @@
4348
MetaData,
4449
Table,
4550
)
51+
from typer import Exit
4652

4753
# Define some types used repeatedly in the code base
4854
MaybeAsyncEngine = Union[Engine, AsyncEngine]
@@ -110,7 +116,11 @@ def import_file(file_path: str) -> ModuleType:
110116
if spec is None or spec.loader is None:
111117
raise ImportError(f"No loadable module at {file_path}")
112118
module = importlib.util.module_from_spec(spec)
113-
spec.loader.exec_module(module)
119+
try:
120+
spec.loader.exec_module(module)
121+
except ModuleNotFoundError as e:
122+
logger.error("Failed to load module at %s with error:", file_path)
123+
logger.error(e)
114124
return module
115125

116126

@@ -193,11 +203,19 @@ def create_db_engine(
193203
**kwargs: Any,
194204
) -> MaybeAsyncEngine:
195205
"""Create a SQLAlchemy Engine."""
196-
if use_asyncio:
197-
async_dsn = db_dsn.replace("postgresql://", "postgresql+asyncpg://")
198-
engine: MaybeAsyncEngine = create_async_engine(async_dsn, **kwargs)
199-
else:
200-
engine = create_engine(db_dsn, **kwargs)
206+
try:
207+
if use_asyncio:
208+
async_dsn = db_dsn.replace("postgresql://", "postgresql+asyncpg://")
209+
engine: MaybeAsyncEngine = create_async_engine(async_dsn, **kwargs)
210+
else:
211+
engine = create_engine(db_dsn, **kwargs)
212+
except NoSuchModuleError as exc:
213+
logger.error("Failed to connect to the database: %s", exc)
214+
logger.error("Perhaps the dialect '%s' is invalid.", db_dsn.split(":")[0])
215+
raise Exit(1) from exc
216+
except ValueError as exc:
217+
logger.error("DSN %s is malformed: %s", db_dsn, exc)
218+
raise Exit(1) from exc
201219

202220
settings = {}
203221
if schema_name is not None:
@@ -248,6 +266,17 @@ def create_db_engine_dst(
248266
return create_db_engine(db_dsn, schema_name, use_asyncio)
249267

250268

269+
def get_metadata(engine: Engine) -> MetaData:
270+
"""Get the MetaData object associated with the engine passed."""
271+
md = MetaData()
272+
try:
273+
md.reflect(engine)
274+
except OperationalError as exc:
275+
logger.error("Cannot connect to database: %s", exc)
276+
raise Exit(1) from exc
277+
return md
278+
279+
251280
def _find_parquet_directories(parquet_dir: Path) -> list[str]:
252281
"""Find all the directories under ``parquet_dir`` that contain parquet files."""
253282
return [

tests/test_functional.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pathlib import Path
66
from typing import Any, Mapping
77

8+
import yaml
89
from sqlalchemy import create_engine, inspect
910
from typer.testing import CliRunner, Result
1011

@@ -606,6 +607,81 @@ def test_create_schema(self) -> None:
606607
inspector = inspect(engine)
607608
self.assertTrue(inspector.has_schema(env["dst_schema"]))
608609

610+
def test_story_incorrect_name(self) -> None:
611+
"""Test we get a proper error message if the story generator module does not exist."""
612+
config_file = "config_story_incorrect.yaml"
613+
config = {
614+
"story_generators_module": "incorrect_module",
615+
}
616+
with Path(config_file).open("w", encoding="utf-8") as fh:
617+
fh.write(yaml.dump(config))
618+
self.invoke(
619+
"make-tables",
620+
"--force",
621+
)
622+
completed_process = self.invoke(
623+
"create-generators",
624+
"--force",
625+
"--config-file",
626+
config_file,
627+
)
628+
self.assertSuccess(completed_process)
629+
self.invoke(
630+
"create-tables",
631+
"--config-file",
632+
config_file,
633+
)
634+
self.assertSuccess(completed_process)
635+
completed_process = self.invoke(
636+
"create-data",
637+
"--config-file",
638+
config_file,
639+
expected_error="No module named 'incorrect_module'",
640+
)
641+
self.assertReturnCode(completed_process, 1)
642+
643+
def test_story_hyphens_in_name(self) -> None:
644+
"""Test hyphens in story generator names cause an error to be emitted."""
645+
config_file = "config_story_hyphens.yaml"
646+
config = {
647+
"story_generators_module": "story-generators",
648+
}
649+
with Path(config_file).open("w", encoding="utf-8") as fh:
650+
fh.write(yaml.dump(config))
651+
self.invoke(
652+
"make-tables",
653+
"--force",
654+
)
655+
completed_process = self.invoke(
656+
"create-generators",
657+
"--force",
658+
"--config-file",
659+
config_file,
660+
expected_error="hyphen",
661+
)
662+
self.assertReturnCode(completed_process, 1)
663+
664+
def test_row_hyphens_in_name(self) -> None:
665+
"""Test hyphens in row generator names cause an error to be emitted."""
666+
config_file = "config_row_hyphens.yaml"
667+
config = {
668+
"row_generators_module": "row-generators",
669+
}
670+
with Path(config_file).open("w", encoding="utf-8") as fh:
671+
fh.write(yaml.dump(config))
672+
self.invoke(
673+
"make-tables",
674+
"--force",
675+
)
676+
completed_process = self.invoke(
677+
"create-generators",
678+
"--force",
679+
"--config-file",
680+
config_file,
681+
expected_error="hyphen",
682+
)
683+
self.assertReturnCode(completed_process, 1)
684+
609685

610686
class DuckDbFunctionalTestCase(DBFunctionalTestCaseBase):
611687
"""End-to-end tests for the DuckDB workflow."""

tests/test_main.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,90 @@ def test_make_tables_with_force_enabled(
341341
mock_make_tables.reset_mock()
342342
mock_path.reset_mock()
343343

344+
@patch("datafaker.main.Path")
345+
@patch("datafaker.settings.get_settings")
346+
def test_incorrect_dialect_causes_nice_error_message(
347+
self,
348+
mock_get_settings: MagicMock,
349+
mock_path: MagicMock,
350+
) -> None:
351+
"""Test the make-tables sub-command, when the force option is activated."""
352+
mock_get_settings.return_value = Settings(
353+
# postgres: not postgresql: will cause sqlalchemy to fail to connect
354+
src_dsn="postgres://suser:spassword@shost:5432/sdbname",
355+
dst_dsn="postgresql://duser:dpassword@dhost:5432/ddbname",
356+
# To stop any local .env files influencing the test
357+
# The mypy ignore can be removed once we upgrade to pydantic 2.
358+
_env_file=None, # type: ignore[call-arg]
359+
)
360+
mock_path.return_value.exists.return_value = True
361+
362+
result = runner.invoke(
363+
app,
364+
[
365+
"make-tables",
366+
"--force",
367+
"--orm-file=tests/examples/example_orm.yaml",
368+
],
369+
)
370+
self.assertIs(type(result.exception), SystemExit)
371+
372+
@patch("datafaker.main.Path")
373+
@patch("datafaker.settings.get_settings")
374+
def test_invalid_host_causes_nice_error_message(
375+
self,
376+
mock_get_settings: MagicMock,
377+
mock_path: MagicMock,
378+
) -> None:
379+
"""Test the make-tables sub-command, when the force option is activated."""
380+
mock_get_settings.return_value = Settings(
381+
# postgres: not postgresql: will cause sqlalchemy to fail to connect
382+
src_dsn="postgresql://suser:spassword@invalid_host:5432/sdbname",
383+
dst_dsn="postgresql://duser:dpassword@dhost:5432/ddbname",
384+
# To stop any local .env files influencing the test
385+
# The mypy ignore can be removed once we upgrade to pydantic 2.
386+
_env_file=None, # type: ignore[call-arg]
387+
)
388+
mock_path.return_value.exists.return_value = True
389+
390+
result = runner.invoke(
391+
app,
392+
[
393+
"make-tables",
394+
"--force",
395+
"--orm-file=tests/examples/example_orm.yaml",
396+
],
397+
)
398+
self.assertIs(type(result.exception), SystemExit)
399+
400+
@patch("datafaker.main.Path")
401+
@patch("datafaker.settings.get_settings")
402+
def test_incorrect_dsn_causes_nice_error_message(
403+
self,
404+
mock_get_settings: MagicMock,
405+
mock_path: MagicMock,
406+
) -> None:
407+
"""Test the make-tables sub-command, when the force option is activated."""
408+
mock_get_settings.return_value = Settings(
409+
# postgres: not postgresql: will cause sqlalchemy to fail to connect
410+
src_dsn="postgresql://suser:spassword:localhost:5432/sdbname",
411+
dst_dsn="postgresql://duser:dpassword@dhost:5432/ddbname",
412+
# To stop any local .env files influencing the test
413+
# The mypy ignore can be removed once we upgrade to pydantic 2.
414+
_env_file=None, # type: ignore[call-arg]
415+
)
416+
mock_path.return_value.exists.return_value = True
417+
418+
result = runner.invoke(
419+
app,
420+
[
421+
"make-tables",
422+
"--force",
423+
"--orm-file=tests/examples/example_orm.yaml",
424+
],
425+
)
426+
self.assertIs(type(result.exception), SystemExit)
427+
344428
def test_validate_config(self) -> None:
345429
"""Test the validate-config sub-command."""
346430
result = runner.invoke(

tests/utils.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,6 @@
3838
)
3939

4040

41-
class SysExit(Exception):
42-
"""To force the function to exit as sys.exit() would."""
43-
44-
4541
@lru_cache(1)
4642
def get_test_settings() -> settings.Settings:
4743
"""Get a Settings object that ignores .env files and environment variables."""

0 commit comments

Comments
 (0)