|
32 | 32 | from jsonschema.validators import validate |
33 | 33 | from sqlalchemy import Connection, Engine, ForeignKey, create_engine, event, select |
34 | 34 | from sqlalchemy.engine.interfaces import DBAPIConnection |
35 | | -from sqlalchemy.exc import IntegrityError, ProgrammingError |
| 35 | +from sqlalchemy.exc import ( |
| 36 | + IntegrityError, |
| 37 | + NoSuchModuleError, |
| 38 | + OperationalError, |
| 39 | + ProgrammingError, |
| 40 | +) |
36 | 41 | from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine |
37 | 42 | from sqlalchemy.orm import Session |
38 | 43 | from sqlalchemy.schema import ( |
|
43 | 48 | MetaData, |
44 | 49 | Table, |
45 | 50 | ) |
| 51 | +from typer import Exit |
| 52 | + |
| 53 | +from datafaker.settings import SettingsError |
46 | 54 |
|
47 | 55 | # Define some types used repeatedly in the code base |
48 | 56 | MaybeAsyncEngine = Union[Engine, AsyncEngine] |
@@ -108,9 +116,15 @@ def import_file(file_path: str) -> ModuleType: |
108 | 116 | """ |
109 | 117 | spec = importlib.util.spec_from_file_location("df", file_path) |
110 | 118 | if spec is None or spec.loader is None: |
111 | | - raise ImportError(f"No loadable module at {file_path}") |
| 119 | + raise SettingsError(f"No loadable module '{file_path}'") |
112 | 120 | module = importlib.util.module_from_spec(spec) |
113 | | - spec.loader.exec_module(module) |
| 121 | + try: |
| 122 | + spec.loader.exec_module(module) |
| 123 | + except ModuleNotFoundError as e: |
| 124 | + logger.error("Failed to load module at %s with error:", file_path) |
| 125 | + logger.error(e) |
| 126 | + except FileNotFoundError as e: |
| 127 | + raise SettingsError(f"No module found '{file_path}'") from e |
114 | 128 | return module |
115 | 129 |
|
116 | 130 |
|
@@ -193,11 +207,19 @@ def create_db_engine( |
193 | 207 | **kwargs: Any, |
194 | 208 | ) -> MaybeAsyncEngine: |
195 | 209 | """Create a SQLAlchemy Engine.""" |
196 | | - if use_asyncio: |
197 | | - async_dsn = db_dsn.replace("postgresql://", "postgresql+asyncpg://") |
198 | | - engine: MaybeAsyncEngine = create_async_engine(async_dsn, **kwargs) |
199 | | - else: |
200 | | - engine = create_engine(db_dsn, **kwargs) |
| 210 | + try: |
| 211 | + if use_asyncio: |
| 212 | + async_dsn = db_dsn.replace("postgresql://", "postgresql+asyncpg://") |
| 213 | + engine: MaybeAsyncEngine = create_async_engine(async_dsn, **kwargs) |
| 214 | + else: |
| 215 | + engine = create_engine(db_dsn, **kwargs) |
| 216 | + except NoSuchModuleError as exc: |
| 217 | + logger.error("Failed to connect to the database: %s", exc) |
| 218 | + logger.error("Perhaps the dialect '%s' is invalid.", db_dsn.split(":")[0]) |
| 219 | + raise Exit(1) from exc |
| 220 | + except ValueError as exc: |
| 221 | + logger.error("DSN %s is malformed: %s", db_dsn, exc) |
| 222 | + raise Exit(1) from exc |
201 | 223 |
|
202 | 224 | settings = {} |
203 | 225 | if schema_name is not None: |
@@ -248,6 +270,17 @@ def create_db_engine_dst( |
248 | 270 | return create_db_engine(db_dsn, schema_name, use_asyncio) |
249 | 271 |
|
250 | 272 |
|
| 273 | +def get_metadata(engine: Engine) -> MetaData: |
| 274 | + """Get the MetaData object associated with the engine passed.""" |
| 275 | + md = MetaData() |
| 276 | + try: |
| 277 | + md.reflect(engine) |
| 278 | + except OperationalError as exc: |
| 279 | + logger.error("Cannot connect to database: %s", exc) |
| 280 | + raise Exit(1) from exc |
| 281 | + return md |
| 282 | + |
| 283 | + |
251 | 284 | def _find_parquet_directories(parquet_dir: Path) -> list[str]: |
252 | 285 | """Find all the directories under ``parquet_dir`` that contain parquet files.""" |
253 | 286 | return [ |
|
0 commit comments