Skip to content

Commit 18bb0bf

Browse files
authored
Destination DuckDB should not have access to parquet files (#81)
Destination DuckDB should not have access to parquet files
1 parent e06ad52 commit 18bb0bf

8 files changed

Lines changed: 261 additions & 48 deletions

File tree

datafaker/create.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from datafaker.base import FileUploader, TableGenerator
1414
from datafaker.settings import get_destination_dsn, get_destination_schema
1515
from datafaker.utils import (
16-
create_db_engine,
16+
create_db_engine_dst,
1717
get_sync_engine,
1818
get_vocabulary_table_names,
1919
logger,
@@ -61,9 +61,15 @@ def remove_on_delete_cascade(element: CreateTable, compiler: Any, **kw: Any) ->
6161
def create_db_tables(metadata: MetaData) -> None:
    """
    Create tables described by the sqlalchemy metadata object.

    The destination DSN and schema are read from the settings and passed on
    to ``create_db_tables_into``, which does the actual work.

    :param metadata: The SQLAlchemy metadata describing the tables to create.
    :raises AssertionError: If no destination DSN has been configured.
    """
    dst_dsn = get_destination_dsn()
    # Raise explicitly instead of using ``assert`` so that the check is not
    # stripped by ``python -O``. The exception type and message are unchanged.
    # A falsy test also rejects ``None``, should the settings ever yield it
    # (NOTE(review): confirm whether get_destination_dsn can return None).
    if not dst_dsn:
        raise AssertionError("Missing DST_DSN setting.")
    create_db_tables_into(metadata, dst_dsn, get_destination_schema())
6666

67+
68+
def create_db_tables_into(
69+
metadata: MetaData, dst_dsn: str, schema_name: str | None = None
70+
) -> None:
71+
"""Create tables described by the sqlalchemy metadata object with explicit DSN."""
72+
engine = get_sync_engine(create_db_engine_dst(dst_dsn))
6773
# Create schema, if necessary.
6874
if schema_name is not None:
6975
with engine.connect() as connection:
@@ -75,9 +81,11 @@ def create_db_tables(metadata: MetaData) -> None:
7581
connection.commit()
7682

7783
# Recreate the engine, this time with a schema specified
78-
engine = get_sync_engine(create_db_engine(dst_dsn, schema_name=schema_name))
84+
engine.dispose()
85+
engine = get_sync_engine(create_db_engine_dst(dst_dsn, schema_name=schema_name))
7986

8087
metadata.create_all(engine)
88+
engine.dispose()
8189

8290

8391
def create_db_vocab(
@@ -95,7 +103,7 @@ def create_db_vocab(
95103
:return: List of table names loaded.
96104
"""
97105
dst_engine = get_sync_engine(
98-
create_db_engine(
106+
create_db_engine_dst(
99107
get_destination_dsn(),
100108
schema_name=get_destination_schema(),
101109
)
@@ -165,7 +173,7 @@ def create_db_data_into(
165173
:param db_dsn: Connection string for the destination database.
166174
:param schema_name: Destination schema name.
167175
"""
168-
dst_engine = get_sync_engine(create_db_engine(db_dsn, schema_name=schema_name))
176+
dst_engine = get_sync_engine(create_db_engine_dst(db_dsn, schema_name=schema_name))
169177

170178
row_counts: Counter[str] = Counter()
171179
with dst_engine.connect() as dst_conn:
@@ -177,6 +185,7 @@ def create_db_data_into(
177185
df_module.story_generator_list,
178186
metadata,
179187
)
188+
dst_engine.dispose()
180189
return row_counts
181190

182191

datafaker/utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,35 @@ def connect(dbapi_connection: DBAPIConnection, _: Any) -> None:
207207
return engine
208208

209209

210+
def create_db_engine_dst(
    db_dsn: str,
    schema_name: Optional[str] = None,
    use_asyncio: bool = False,
) -> MaybeAsyncEngine:
    """
    Build the SQLAlchemy engine used for writing to the destination database.

    For DuckDB connection strings the engine is created with
    ``enable_external_access`` switched off, so the destination connection
    cannot read parquet files. This avoids any possible leakage from existing
    source files into the destination database. Every other DSN is passed
    straight through to ``create_db_engine``.

    :param db_dsn: The database connection string.
    :param schema_name: The name of the schema within the database to use.
    :param use_asyncio: True if an asynchronous connection is required.
    :return: The ``Engine`` or ``AsyncEngine``.
    """
    if not db_dsn.startswith("duckdb:"):
        # Not DuckDB: no extra connection arguments are needed.
        return create_db_engine(db_dsn, schema_name, use_asyncio)

    locked_down_args = {"config": {"enable_external_access": False}}
    return create_db_engine(
        db_dsn,
        schema_name,
        use_asyncio,
        connect_args=locked_down_args,
    )
237+
238+
210239
def set_search_path(connection: DBAPIConnection, schema: str) -> None:
211240
"""Set the SEARCH_PATH for a PostgreSQL connection."""
212241
# https://docs.sqlalchemy.org/en/20/dialects/postgresql.html#remote-schema-table-introspection-and-postgresql-search-path

tests/test_create.py

Lines changed: 163 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,24 @@
22
import itertools as itt
33
import os
44
import random
5+
import tempfile
56
from collections import Counter
67
from pathlib import Path
78
from typing import Any, Generator, Mapping, Tuple
89
from unittest.mock import MagicMock, call, patch
910

10-
from sqlalchemy import Connection, select
11+
import duckdb
12+
import pandas as pd
13+
from sqlalchemy import Connection, Engine, select
1114
from sqlalchemy.schema import MetaData, Table
1215

1316
from datafaker.base import TableGenerator
14-
from datafaker.create import create_db_vocab, populate
15-
from datafaker.remove import remove_db_vocab
17+
from datafaker.create import (
18+
create_db_data_into,
19+
create_db_tables,
20+
create_db_vocab,
21+
populate,
22+
)
1623
from datafaker.serialize_metadata import metadata_to_dict
1724
from tests.utils import DatafakerTestCase, GeneratesDBTestCase
1825

@@ -28,7 +35,7 @@ def test_create_vocab(self) -> None:
2835
"""Test the create_db_vocab function."""
2936
with patch.dict(
3037
os.environ,
31-
{"DST_DSN": self.dsn, "DST_SCHEMA": self.schema_name},
38+
{"DST_DSN": self.dst_dsn},
3239
clear=True,
3340
):
3441
config = {
@@ -42,10 +49,9 @@ def test_create_vocab(self) -> None:
4249
meta_dict = metadata_to_dict(
4350
self.metadata, self.schema_name, self.sync_engine
4451
)
45-
self.remove_data(config)
46-
remove_db_vocab(self.metadata, meta_dict, config)
52+
create_db_tables(self.metadata)
4753
create_db_vocab(self.metadata, meta_dict, config, Path("./tests/examples"))
48-
with self.sync_engine.connect() as conn:
54+
with self.dst_sync_engine.connect() as conn:
4955
stmt = select(self.metadata.tables["player"])
5056
rows = list(conn.execute(stmt).mappings().fetchall())
5157
self.assertEqual(len(rows), 3)
@@ -64,7 +70,7 @@ def test_make_table_generators(self) -> None:
6470
random.seed(56)
6571
config: Mapping[str, Any] = {}
6672
self.generate_data(config, num_passes=2)
67-
with self.sync_engine.connect() as conn:
73+
with self.dst_sync_engine.connect() as conn:
6874
stmt = select(self.metadata.tables["string"])
6975
rows = list(conn.execute(stmt).mappings().fetchall())
7076
a = rows[0]
@@ -183,3 +189,152 @@ def test_populate_diff_length(self, mock_insert: MagicMock) -> None:
183189

184190
mock_gen_two.assert_called_once()
185191
mock_gen_three.assert_called_once()
192+
193+
194+
class MockFunctionUsingConnection:
    """
    Base for mock callables that receive a database connection.

    Subclasses forward the connection they are handed to ``do_call``, which
    checks that the connection is not permitted to read parquet files and
    records that the mock was invoked.
    """

    def __init__(self) -> None:
        """Start in the "not yet called" state."""
        self.called = False

    @classmethod
    def is_parquet_permitted(cls, connection: Any) -> bool:
        """
        Report whether ``connection`` can query the ``fruit.parquet`` file.

        :param connection: An object with a DuckDB-style ``execute`` method.
        :return: False if the query is refused with a
            ``duckdb.PermissionException``, True if it runs.
        """
        try:
            connection.execute("SELECT * FROM fruit.parquet")
        except duckdb.PermissionException:
            permitted = False
        else:
            permitted = True
        return permitted

    def do_call(self, connection: Any) -> None:
        """
        Check that parquet access is refused, then mark this mock as called.

        :param connection: The connection the mocked function was given.
        """
        parquet_readable = self.is_parquet_permitted(connection)
        assert not parquet_readable
        self.called = True
214+
215+
216+
class CreateReadsNoParquetTestCase(DatafakerTestCase):
    """
    Output to the database should not have access to parquet files.

    Otherwise there is a risk of leakage of source data.

    Each test runs inside a fresh temporary directory containing a parquet
    file, and checks that the connections handed out by the ``create``
    functions refuse to read it.
    """

    examples_dir = Path("tests/examples/duckdb")
    parquet_name = "fruit.parquet"

    def setUp(self) -> None:
        """Go to a temporary directory where there is a parquet file."""
        super().setUp()
        self.start_dir = os.getcwd()
        # A TemporaryDirectory (rather than mkdtemp) so that the directory and
        # everything written into it is removed once the test has finished.
        temp_dir = tempfile.TemporaryDirectory(suffix="parq")
        self.addCleanup(temp_dir.cleanup)
        self.parquet_dir = Path(temp_dir.name)
        # Cleanups run last-in-first-out and also run when setUp itself fails,
        # so the working directory is restored before the directory is deleted.
        self.addCleanup(os.chdir, self.start_dir)
        os.chdir(self.parquet_dir)
        self.write_parquet()
        # Sanity check: an unrestricted DuckDB connection can read the file.
        probe = duckdb.connect()
        try:
            assert MockFunctionUsingConnection.is_parquet_permitted(probe)
        finally:
            probe.close()

    def tearDown(self) -> None:
        """Return to the start directory."""
        os.chdir(self.start_dir)
        return super().tearDown()

    def write_parquet(self) -> None:
        """Write a parquet file into the current directory."""
        fruit: dict[str, list[Any]] = {
            "id": [1, 2, 3],
            "orange": [True, True, False],
            "banana": ["one", "two", "three"],
        }
        pd.DataFrame.from_dict(fruit).to_parquet(self.parquet_name)

    class MockCreateAll(MockFunctionUsingConnection):
        """Mock for the MetaData.create_all function."""

        def __call__(self, engine: Engine) -> None:
            """Check a raw connection from ``engine`` cannot read parquet."""
            raw_connection = engine.raw_connection()
            try:
                self.do_call(raw_connection)
            finally:
                # Hand the checked-out connection back to the pool.
                raw_connection.close()

    def test_create_db_tables_cannot_access_parquet(self) -> None:
        """Test the database connection cannot access parquet file."""
        meta_data = MagicMock()
        meta_data.create_all = self.MockCreateAll()
        with patch.dict(
            os.environ,
            {"DST_DSN": "duckdb:///:memory:tables"},
            clear=True,
        ):
            create_db_tables(meta_data)
        assert meta_data.create_all.called

    def test_create_db_tables_cannot_access_parquet_with_schema(self) -> None:
        """
        Test the database connection cannot access parquet file.

        We use a schema because this activates a different code path.
        """
        meta_data = MagicMock()
        meta_data.create_all = self.MockCreateAll()
        # Pre-create the schema in a database file in the temporary directory.
        testdb = duckdb.connect("./test.db")
        try:
            testdb.execute("CREATE SCHEMA fruity")
        finally:
            testdb.close()
        with patch.dict(
            os.environ,
            {"DST_SCHEMA": "fruity", "DST_DSN": "duckdb:///./test.db"},
            clear=True,
        ):
            create_db_tables(meta_data)
        assert meta_data.create_all.called

    @patch("datafaker.create.populate")
    def test_create_db_data_cannot_access_parquet(
        self, mock_populate: MagicMock
    ) -> None:
        """Test the connection cannot access parquet file while creating data."""

        class MockPopulate(MockFunctionUsingConnection):
            """Mock ``populate`` function."""

            def __call__(
                self, connection: Connection, _a2: Any, _a3: Any, _a4: Any, _a5: Any
            ) -> dict[str, Any]:
                """Check the underlying DBAPI connection, then report one row."""
                super().do_call(connection.connection.dbapi_connection)
                return {"vocab1": 1}

        mock_populate.side_effect = MockPopulate()
        create_db_data_into(
            [MagicMock()],
            MagicMock(),
            1,
            "duckdb:///:memory:data",
            None,
            MagicMock(),
        )
        assert mock_populate.side_effect.called

    @patch("datafaker.create.FileUploader")
    def test_create_db_vocab_cannot_access_parquet(
        self, file_uploader: MagicMock
    ) -> None:
        """Test we cannot access parquet file while populating vocabulary tables."""

        class MockLoader(MockFunctionUsingConnection):
            """Mock ``FileUploader.load`` function."""

            def __call__(self, connection: Connection, base_path: Path) -> None:
                """Check the base path and that parquet access is refused."""
                assert str(base_path) == "base"
                super().do_call(connection.connection.dbapi_connection)

        file_uploader.return_value.load = MockLoader()
        assert not file_uploader.return_value.load.called
        meta_data = MetaData()
        Table("table1", meta_data)
        with patch.dict(
            os.environ,
            {"DST_DSN": "duckdb:///:memory:vocab"},
            clear=True,
        ):
            create_db_vocab(
                meta_data,
                {"tables": {"table1": {"columns": {}}}},
                {"tables": {"table1": {"vocabulary_table": True}}},
                base_path=Path("base"),
            )
        assert file_uploader.return_value.load.called

tests/test_dump.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def test_end_to_end_parquet(self) -> None:
146146
# Dump the fake tables
147147
outdir = Path(tempfile.mkdtemp("dump"))
148148
result = runner.invoke(app, ["dump-data", "--output", str(outdir), "--parquet"])
149+
print(result)
149150
self.assertSuccess(result)
150151

151152
# Check the dumped files

tests/test_interactive_generators.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -636,7 +636,7 @@ def test_create_with_sampled_choice(self) -> None:
636636
gc.do_quit("")
637637
self.generate_data(gc.config, num_passes=200)
638638
# all generation possibilities should be present
639-
with self.sync_engine.connect() as conn:
639+
with self.dst_sync_engine.connect() as conn:
640640
stats = ChoiceMeasurementTableStats(self.metadata, conn)
641641
self.assertSetEqual(stats.ones, {1, 4})
642642
self.assertSetEqual(stats.twos, {2, 3})
@@ -654,7 +654,7 @@ def test_create_with_choice(self) -> None:
654654
gc.do_set(str(proposals["dist_gen.zipf_choice"][0]))
655655
gc.do_quit("")
656656
self.generate_data(gc.config, num_passes=200)
657-
with self.sync_engine.connect() as conn:
657+
with self.dst_sync_engine.connect() as conn:
658658
stmt = select(self.metadata.tables[table_name])
659659
rows = conn.execute(stmt).fetchall()
660660
ones = set()
@@ -733,13 +733,12 @@ def test_create_with_weighted_choice(self) -> None:
733733
gc.do_set(str(prop[0]))
734734
gc.do_quit("")
735735
self.generate_data(gc.config, num_passes=200)
736-
with self.sync_engine.connect() as conn:
737-
with self.sync_engine.connect() as conn:
738-
stats = ChoiceMeasurementTableStats(self.metadata, conn)
739-
# all generation possibilities should be present
740-
self.assertSetEqual(stats.ones, {1, 4})
741-
self.assertSetEqual(stats.twos, {1, 2, 3, 4, 5})
742-
self.assertSetEqual(stats.threes, {1, 2, 3, 4, 5})
736+
with self.dst_sync_engine.connect() as conn:
737+
stats = ChoiceMeasurementTableStats(self.metadata, conn)
738+
# all generation possibilities should be present
739+
self.assertSetEqual(stats.ones, {1, 4})
740+
self.assertSetEqual(stats.twos, {1, 2, 3, 4, 5})
741+
self.assertSetEqual(stats.threes, {1, 2, 3, 4, 5})
743742

744743

745744
class GeneratorsOutputTestsDuckDb(GeneratorsOutputTests):
@@ -780,7 +779,7 @@ def test_set_null(self) -> None:
780779
config = gc.config
781780
self.generate_data(config, num_passes=3)
782781
# Test that each missingness pattern is present in the database
783-
with self.sync_engine.connect() as conn:
782+
with self.dst_sync_engine.connect() as conn:
784783
# select(self.metadata.tables["string"].c["position", "frequency"]) would be nicer
785784
# but mypy doesn't like it
786785
stmt = select(
@@ -866,7 +865,7 @@ def test_varchar_ns_are_truncated(self) -> None:
866865
gc.do_quit("")
867866
config = gc.config
868867
self.generate_data(config, num_passes=15)
869-
with self.sync_engine.connect() as conn:
868+
with self.dst_sync_engine.connect() as conn:
870869
stmt = select(self.metadata.tables[table].c[column])
871870
rows = conn.execute(stmt).scalars().fetchall()
872871
self.assert_are_truncated_to(rows, 20)

0 commit comments

Comments
 (0)