3434DB_ODBC_DRIVER = "odbc_driver"
3535DB_SESSION = "session"
3636DB_ODBC_UTF8 = "odbc: enable utf-8 encoding"
37+ DB_WAREHOUSE = "warehouse"
38+ DB_SCHEMA = "schema"
39+ DB_ACCOUNT = "account"
3740
3841
3942# [NON ACTION]
@@ -102,6 +105,16 @@ def find_odbc_driver(db_type="postgresql"):
102105
103106def handle_db_exception (sModuleInfo , e ):
104107 import pyodbc
108+
109+ # Handle Snowflake exceptions
110+ try :
111+ import snowflake .connector .errors as snowflake_errors
112+ if isinstance (e , snowflake_errors .Error ):
113+ traceback .print_exc ()
114+ CommonUtil .ExecLog (sModuleInfo , f"Snowflake Error: { e } " , 3 )
115+ return CommonUtil .Exception_Handler (e )
116+ except ImportError :
117+ pass # Snowflake connector not installed
105118
106119 if isinstance (e , pyodbc .DataError ):
107120 traceback .print_exc ()
@@ -243,6 +256,26 @@ def db_get_connection(session_name):
243256 host = db_host ,
244257 port = db_port
245258 )
259+ elif "snowflake" in db_type :
260+ import snowflake .connector
261+
262+ # Get Snowflake-specific parameters
263+ account = db_params .get (DB_ACCOUNT )
264+ if not account :
265+ account = db_host .replace ('.snowflakecomputing.com' , '' ) if '.snowflakecomputing.com' in db_host else db_host
266+ warehouse = db_params .get (DB_WAREHOUSE ) or 'COMPUTE_WH'
267+ schema = db_params .get (DB_SCHEMA ) or 'PUBLIC'
268+
269+ # Connect to Snowflake
270+ db_con = snowflake .connector .connect (
271+ user = db_user_id ,
272+ password = db_password ,
273+ account = account ,
274+ database = db_name ,
275+ warehouse = warehouse ,
276+ schema = schema
277+ )
278+ CommonUtil .ExecLog (sModuleInfo , "Connected to Snowflake." , 1 )
246279 elif "oracle" in db_type :
247280 import cx_Oracle
248281
@@ -303,14 +336,17 @@ def connect_to_db(data_set):
303336 This action just stores the different database specific configs into shared variables for use by other actions.
304337 NOTE: The actual db connection does not happen here, connection to db is made inside the actions which require it.
305338
306- db_type input parameter <type of db, ex: postgres, mysql>
339+ db_type input parameter <type of db, ex: postgres, mysql, snowflake >
307340 db_name input parameter <name of db, ex: zeuz_db>
308341 db_user_id input parameter <user id of the os who have access to the db, ex: postgres>
309342 db_password input parameter <password of db, ex: mydbpass-mY1-t23z>
310343 db_host input parameter <host of db, ex: localhost, 127.0.0.1>
311344 db_port input parameter <port of db, ex: 5432 for postgres by default>
312345 sid optional parameter <sid of db, ex: 15321 for oracle by default>
313346 service_name optional parameter <service_name of db, ex: 'somename' for oracle by default>
347+ warehouse optional parameter <warehouse for Snowflake, ex: COMPUTE_WH>
348+ schema optional parameter <schema for Snowflake, ex: PUBLIC>
349+ account optional parameter <account identifier for Snowflake>
314350 odbc_driver optional parameter <specify the odbc driver, optional, can be found from pyodbc.drivers()>
315351 odbc: enable utf-8 encoding optional parameter true/false - optionally enable utf-8 encoding
316352 connect to db database action Connect to a database
@@ -324,6 +360,7 @@ def connect_to_db(data_set):
324360 try :
325361 # Default values
326362 db_type = db_name = db_user_id = db_password = db_host = db_port = db_sid = db_service_name = db_odbc_driver = db_params = None
363+ db_warehouse = db_schema = db_account = None
327364 db_enable_odbc_utf8 = True
328365 session_name = "default"
329366
@@ -349,6 +386,12 @@ def connect_to_db(data_set):
349386 sr .Set_Shared_Variables (DB_ODBC_DRIVER ,right .strip ())
350387 if left == DB_ODBC_UTF8 :
351388 db_enable_odbc_utf8 = CommonUtil .parse_value_into_object (right .strip ()) == True
389+ if left == DB_WAREHOUSE or left == "warehouse" :
390+ db_warehouse = right .strip ()
391+ if left == DB_SCHEMA :
392+ db_schema = right .strip ()
393+ if left == DB_ACCOUNT :
394+ db_account = right .strip ()
352395 if DB_SESSION in left :
353396 session_name = right .strip ()
354397
@@ -363,6 +406,9 @@ def connect_to_db(data_set):
363406 DB_SERVICE_NAME : db_service_name ,
364407 DB_ODBC_DRIVER : db_odbc_driver ,
365408 DB_ODBC_UTF8 : db_enable_odbc_utf8 ,
409+ DB_WAREHOUSE : db_warehouse ,
410+ DB_SCHEMA : db_schema ,
411+ DB_ACCOUNT : db_account ,
366412 }
367413
368414 if sr .Test_Shared_Variables ('db_sessions' ):
@@ -426,6 +472,10 @@ def db_select(data_set):
426472
427473 # Get db_cursor and execute
428474 db_con = db_get_connection (session_name )
475+ if db_con == "zeuz_failed" :
476+ CommonUtil .ExecLog (sModuleInfo , "Failed to get database connection" , 3 )
477+ return "zeuz_failed"
478+
429479 with db_con :
430480 with db_con .cursor () as db_cursor :
431481 db_cursor .execute (query )
@@ -539,6 +589,10 @@ def select_from_db(data_set):
539589
540590 # Get db_cursor and execute
541591 db_con = db_get_connection (session_name )
592+ if db_con == "zeuz_failed" :
593+ CommonUtil .ExecLog (sModuleInfo , "Failed to get database connection" , 3 )
594+ return "zeuz_failed"
595+
542596 with db_con :
543597 with db_con .cursor () as db_cursor :
544598 db_cursor .execute (query )
@@ -627,6 +681,10 @@ def insert_into_db(data_set):
627681 CommonUtil .ExecLog (sModuleInfo , "Executing query:\n %s." % query , 1 )
628682
629683 db_con = db_get_connection (session_name )
684+ if db_con == "zeuz_failed" :
685+ CommonUtil .ExecLog (sModuleInfo , "Failed to get database connection" , 3 )
686+ return "zeuz_failed"
687+
630688 with db_con :
631689 with db_con .cursor () as db_cursor :
632690 db_cursor .execute (query )
@@ -699,6 +757,10 @@ def delete_from_db(data_set):
699757
700758 # Get db_cursor and execute
701759 db_con = db_get_connection (session_name )
760+ if db_con == "zeuz_failed" :
761+ CommonUtil .ExecLog (sModuleInfo , "Failed to get database connection" , 3 )
762+ return "zeuz_failed"
763+
702764 with db_con :
703765 with db_con .cursor () as db_cursor :
704766 db_cursor .execute (query )
@@ -784,6 +846,10 @@ def update_into_db(data_set):
784846
785847 # Get db_cursor and execute
786848 db_con = db_get_connection (session_name )
849+ if db_con == "zeuz_failed" :
850+ CommonUtil .ExecLog (sModuleInfo , "Failed to get database connection" , 3 )
851+ return "zeuz_failed"
852+
787853 with db_con :
788854 with db_con .cursor () as db_cursor :
789855 db_cursor .execute (query )
@@ -849,6 +915,10 @@ def db_non_query(data_set):
849915
850916 # Get db_cursor and execute
851917 db_con = db_get_connection (session_name )
918+ if db_con == "zeuz_failed" :
919+ CommonUtil .ExecLog (sModuleInfo , "Failed to get database connection" , 3 )
920+ return "zeuz_failed"
921+
852922 with db_con :
853923 with db_con .cursor () as db_cursor :
854924 db_cursor .execute (query )
0 commit comments