22import itertools as itt
33import os
44import random
5+ import tempfile
56from collections import Counter
67from pathlib import Path
78from typing import Any , Generator , Mapping , Tuple
89from unittest .mock import MagicMock , call , patch
910
10- from sqlalchemy import Connection , select
11+ import duckdb
12+ import pandas as pd
13+ from sqlalchemy import Connection , Engine , select
1114from sqlalchemy .schema import MetaData , Table
1215
1316from datafaker .base import TableGenerator
14- from datafaker .create import create_db_vocab , populate
15- from datafaker .remove import remove_db_vocab
17+ from datafaker .create import (
18+ create_db_data_into ,
19+ create_db_tables ,
20+ create_db_vocab ,
21+ populate ,
22+ )
1623from datafaker .serialize_metadata import metadata_to_dict
1724from tests .utils import DatafakerTestCase , GeneratesDBTestCase
1825
@@ -28,7 +35,7 @@ def test_create_vocab(self) -> None:
2835 """Test the create_db_vocab function."""
2936 with patch .dict (
3037 os .environ ,
31- {"DST_DSN" : self .dsn , "DST_SCHEMA" : self . schema_name },
38+ {"DST_DSN" : self .dst_dsn },
3239 clear = True ,
3340 ):
3441 config = {
@@ -42,10 +49,9 @@ def test_create_vocab(self) -> None:
4249 meta_dict = metadata_to_dict (
4350 self .metadata , self .schema_name , self .sync_engine
4451 )
45- self .remove_data (config )
46- remove_db_vocab (self .metadata , meta_dict , config )
52+ create_db_tables (self .metadata )
4753 create_db_vocab (self .metadata , meta_dict , config , Path ("./tests/examples" ))
48- with self .sync_engine .connect () as conn :
54+ with self .dst_sync_engine .connect () as conn :
4955 stmt = select (self .metadata .tables ["player" ])
5056 rows = list (conn .execute (stmt ).mappings ().fetchall ())
5157 self .assertEqual (len (rows ), 3 )
@@ -64,7 +70,7 @@ def test_make_table_generators(self) -> None:
6470 random .seed (56 )
6571 config : Mapping [str , Any ] = {}
6672 self .generate_data (config , num_passes = 2 )
67- with self .sync_engine .connect () as conn :
73+ with self .dst_sync_engine .connect () as conn :
6874 stmt = select (self .metadata .tables ["string" ])
6975 rows = list (conn .execute (stmt ).mappings ().fetchall ())
7076 a = rows [0 ]
@@ -183,3 +189,152 @@ def test_populate_diff_length(self, mock_insert: MagicMock) -> None:
183189
184190 mock_gen_two .assert_called_once ()
185191 mock_gen_three .assert_called_once ()
192+
193+
194+ class MockFunctionUsingConnection :
195+ """Base mock callable that should not be permitted to read parquet files."""
196+
197+ @classmethod
198+ def is_parquet_permitted (cls , connection : Any ) -> bool :
199+ """Test if a normal DuckDB can access the ``fruit.parquet`` file."""
200+ try :
201+ connection .execute ("SELECT * FROM fruit.parquet" )
202+ except duckdb .PermissionException :
203+ return False
204+ return True
205+
206+ def __init__ (self ) -> None :
207+ """Initialize as uncalled."""
208+ self .called = False
209+
210+ def do_call (self , connection : Any ) -> None :
211+ """Test for parquet access not being permitted."""
212+ assert not self .is_parquet_permitted (connection )
213+ self .called = True
214+
215+
216+ class CreateReadsNoParquetTestCase (DatafakerTestCase ):
217+ """
218+ Output to the database should not have access to parquet files.
219+
220+ Otherwise there is a risk of leakage of source data.
221+ """
222+
223+ examples_dir = Path ("tests/examples/duckdb" )
224+ parquet_name = "fruit.parquet"
225+
226+ def setUp (self ) -> None :
227+ """Go to the directory where there are parquet files."""
228+ super ().setUp ()
229+ self .start_dir = os .getcwd ()
230+ self .parquet_dir = Path (tempfile .mkdtemp ("parq" ))
231+ os .chdir (self .parquet_dir )
232+ self .write_parquet ()
233+ assert MockFunctionUsingConnection .is_parquet_permitted (duckdb .connect ())
234+
235+ def tearDown (self ) -> None :
236+ """Return to the start directory."""
237+ os .chdir (self .start_dir )
238+ return super ().tearDown ()
239+
240+ def write_parquet (self ) -> None :
241+ """Write a parquet file into the current directory."""
242+ fruit : dict [str , list [Any ]] = {
243+ "id" : [1 , 2 , 3 ],
244+ "orange" : [True , True , False ],
245+ "banana" : ["one" , "two" , "three" ],
246+ }
247+ pd .DataFrame .from_dict (fruit ).to_parquet (self .parquet_name )
248+
249+ class MockCreateAll (MockFunctionUsingConnection ):
250+ """Mock for the MetaData.create_all function."""
251+
252+ def __call__ (self , engine : Engine ) -> None :
253+ self .do_call (engine .raw_connection ())
254+
255+ def test_create_db_tables_cannot_access_parquet (self ) -> None :
256+ """Test the database connection cannot access parquet file."""
257+ meta_data = MagicMock ()
258+ meta_data .create_all = self .MockCreateAll ()
259+ with patch .dict (
260+ os .environ ,
261+ {"DST_DSN" : "duckdb:///:memory:tables" },
262+ clear = True ,
263+ ):
264+ create_db_tables (meta_data )
265+ assert meta_data .create_all .called
266+
267+ def test_create_db_tables_cannot_access_parquet_with_schema (self ) -> None :
268+ """
269+ Test the database connection cannot access parquet file.
270+
271+ We use a schema because this activates a different code path.
272+ """
273+ meta_data = MagicMock ()
274+ meta_data .create_all = self .MockCreateAll ()
275+ testdb = duckdb .connect ("./test.db" )
276+ testdb .execute ("CREATE SCHEMA fruity" )
277+ testdb .close ()
278+ with patch .dict (
279+ os .environ ,
280+ {"DST_SCHEMA" : "fruity" , "DST_DSN" : "duckdb:///./test.db" },
281+ clear = True ,
282+ ):
283+ create_db_tables (meta_data )
284+ assert meta_data .create_all .called
285+
286+ @patch ("datafaker.create.populate" )
287+ def test_create_db_data_cannot_access_parquet (
288+ self , mock_populate : MagicMock
289+ ) -> None :
290+ """Test the database connection cannot access parquet file while creating data."""
291+
292+ class MockPopulate (MockFunctionUsingConnection ):
293+ """Mock ``populate`` function."""
294+
295+ def __call__ (
296+ self , connection : Connection , _a2 : Any , _a3 : Any , _a4 : Any , _a5 : Any
297+ ) -> dict [str , Any ]:
298+ super ().do_call (connection .connection .dbapi_connection )
299+ return {"vocab1" : 1 }
300+
301+ mock_populate .side_effect = MockPopulate ()
302+ create_db_data_into (
303+ [MagicMock ()],
304+ MagicMock (),
305+ 1 ,
306+ "duckdb:///:memory:data" ,
307+ None ,
308+ MagicMock (),
309+ )
310+ assert mock_populate .side_effect .called
311+
312+ @patch ("datafaker.create.FileUploader" )
313+ def test_create_db_vocab_cannot_access_parquet (
314+ self , file_uploader : MagicMock
315+ ) -> None :
316+ """Test we cannot access parquet file while populating vocabulary tables."""
317+
318+ class MockLoader (MockFunctionUsingConnection ):
319+ """Mock ``FileUploader.load`` function."""
320+
321+ def __call__ (self , connection : Connection , base_path : Path ) -> None :
322+ assert str (base_path ) == "base"
323+ super ().do_call (connection .connection .dbapi_connection )
324+
325+ file_uploader .return_value .load = MockLoader ()
326+ assert not file_uploader .return_value .load .called
327+ meta_data = MetaData ()
328+ Table ("table1" , meta_data )
329+ with patch .dict (
330+ os .environ ,
331+ {"DST_DSN" : "duckdb:///:memory:vocab" },
332+ clear = True ,
333+ ):
334+ create_db_vocab (
335+ meta_data ,
336+ {"tables" : {"table1" : {"columns" : {}}}},
337+ {"tables" : {"table1" : {"vocabulary_table" : True }}},
338+ base_path = Path ("base" ),
339+ )
340+ assert file_uploader .return_value .load .called
0 commit comments