1515
1616class DatabaseManager :
1717
18- def __init__ (self , db_type : str = "postgresql" ):
18+ def __init__ (self , db_type : str | None = None ):
1919 self .db_type = db_type
2020 self ._engine : Engine | AsyncEngine | None = None
2121 self ._async_session_factory = None
2222 self ._sync_session_factory = None
2323
24- def get_database_url (self , async_driver : bool = False ) -> str :
24+ def get_database_url (
25+ self , async_driver : bool = False , db_type : str | None = None
26+ ) -> str :
27+ db_type = db_type or self .db_type or os .getenv ("TEST_DB" , "postgresql" )
2528
26- if self . db_type == "sqlite" :
29+ if db_type == "sqlite" :
2730 if async_driver :
2831 return (
2932 f"sqlite+aiosqlite:///{ os .getenv ('SQLITE_DB' , 'test_sqlnotify.db' )} "
@@ -44,8 +47,9 @@ def get_database_url(self, async_driver: bool = False) -> str:
4447
4548 return f"postgresql+psycopg://{ user } :{ password } @{ host } :{ port } /{ db_name } "
4649
47- def get_base_connection_url (self ) -> str :
48- if self .db_type == "sqlite" :
50+ def get_base_connection_url (self , db_type : str | None = None ) -> str :
51+ db_type = db_type or self .db_type or os .getenv ("TEST_DB" , "postgresql" )
52+ if db_type == "sqlite" :
4953 return "" # SQLite doesn't need base connection
5054
5155 user = os .getenv ("POSTGRES_USER" , "postgres" )
@@ -55,15 +59,16 @@ def get_base_connection_url(self) -> str:
5559
5660 return f"postgresql+psycopg://{ user } :{ password } @{ host } :{ port } /postgres"
5761
58- def create_test_database (self ) -> None :
62+ def create_test_database (self , db_type : str | None = None ) -> None :
63+ db_type = db_type or self .db_type or os .getenv ("TEST_DB" , "postgresql" )
5964
60- if self . db_type == "sqlite" :
65+ if db_type == "sqlite" :
6166 return
6267
6368 worker_id = os .environ .get ("PYTEST_XDIST_WORKER" , "master" )
6469 db_name = f"{ os .getenv ('POSTGRES_DB' , 'sqlnotify_test' )} _{ worker_id } "
6570
66- base_url = self .get_base_connection_url ()
71+ base_url = self .get_base_connection_url (db_type = db_type )
6772 admin_engine = create_engine (base_url , isolation_level = "AUTOCOMMIT" )
6873
6974 with admin_engine .connect () as conn :
@@ -72,9 +77,10 @@ def create_test_database(self) -> None:
7277
7378 admin_engine .dispose ()
7479
75- def drop_test_database (self ) -> None :
80+ def drop_test_database (self , db_type : str | None = None ) -> None :
81+ db_type = db_type or self .db_type or os .getenv ("TEST_DB" , "postgresql" )
7682
77- if self . db_type == "sqlite" :
83+ if db_type == "sqlite" :
7884 db_file = os .getenv ("SQLITE_DB" , "test_sqlnotify.db" )
7985
8086 if os .path .exists (db_file ):
@@ -85,28 +91,29 @@ def drop_test_database(self) -> None:
8591 worker_id = os .environ .get ("PYTEST_XDIST_WORKER" , "master" )
8692 db_name = f"{ os .getenv ('POSTGRES_DB' , 'sqlnotify_test' )} _{ worker_id } "
8793
88- base_url = self .get_base_connection_url ()
94+ base_url = self .get_base_connection_url (db_type = db_type )
8995 admin_engine = create_engine (base_url , isolation_level = "AUTOCOMMIT" )
9096
9197 with admin_engine .connect () as conn :
9298 conn .execute (text (f"DROP DATABASE IF EXISTS { db_name } WITH (FORCE)" ))
9399
94100 admin_engine .dispose ()
95101
96- def create_async_engine (self ) -> AsyncEngine :
97-
98- url = self .get_database_url (async_driver = True )
102+ def create_async_engine (self , db_type : str | None = None ) -> AsyncEngine :
103+ url = self .get_database_url (async_driver = True , db_type = db_type )
99104 return create_async_engine (url , echo = False )
100105
101- def create_sync_engine (self ) -> Engine :
102-
103- url = self .get_database_url (async_driver = False )
106+ def create_sync_engine (self , db_type : str | None = None ) -> Engine :
107+ url = self .get_database_url (async_driver = False , db_type = db_type )
104108 return create_engine (url , echo = False )
105109
106- async def create_tables_async (self , engine : AsyncEngine ) -> None :
110+ async def create_tables_async (
111+ self , engine : AsyncEngine , db_type : str | None = None
112+ ) -> None :
113+ db_type = db_type or self .db_type or os .getenv ("TEST_DB" , "postgresql" )
107114
108115 async with engine .begin () as conn :
109- if self . db_type == "sqlite" :
116+ if db_type == "sqlite" :
110117
111118 def create_filtered_tables (connection ):
112119 tables_to_create = [
@@ -121,8 +128,9 @@ def create_filtered_tables(connection):
121128 await conn .execute (text ("CREATE SCHEMA IF NOT EXISTS analytics" ))
122129 await conn .run_sync (SQLModel .metadata .create_all )
123130
124- def create_tables_sync (self , engine : Engine ) -> None :
125- if self .db_type == "sqlite" :
131+ def create_tables_sync (self , engine : Engine , db_type : str | None = None ) -> None :
132+ db_type = db_type or self .db_type or os .getenv ("TEST_DB" , "postgresql" )
133+ if db_type == "sqlite" :
126134 tables_to_create = [
127135 table
128136 for table in SQLModel .metadata .sorted_tables
@@ -135,10 +143,13 @@ def create_tables_sync(self, engine: Engine) -> None:
135143
136144 SQLModel .metadata .create_all (engine )
137145
138- async def drop_tables_async (self , engine : AsyncEngine ) -> None :
146+ async def drop_tables_async (
147+ self , engine : AsyncEngine , db_type : str | None = None
148+ ) -> None :
149+ db_type = db_type or self .db_type or os .getenv ("TEST_DB" , "postgresql" )
139150
140151 async with engine .begin () as conn :
141- if self . db_type == "sqlite" :
152+ if db_type == "sqlite" :
142153
143154 def drop_filtered_tables (connection ):
144155 tables_to_drop = [
@@ -152,8 +163,9 @@ def drop_filtered_tables(connection):
152163 else :
153164 await conn .run_sync (SQLModel .metadata .drop_all )
154165
155- def drop_tables_sync (self , engine : Engine ) -> None :
156- if self .db_type == "sqlite" :
166+ def drop_tables_sync (self , engine : Engine , db_type : str | None = None ) -> None :
167+ db_type = db_type or self .db_type or os .getenv ("TEST_DB" , "postgresql" )
168+ if db_type == "sqlite" :
157169 tables_to_drop = [
158170 table
159171 for table in SQLModel .metadata .sorted_tables
0 commit comments