Skip to content

Commit 167bfca

Browse files
authored
Merge pull request #621 from AutomationSolutionz/task-1433-snowflake-db-integration
Snowflake DB support in zeuz node
2 parents 5172332 + 84948e0 commit 167bfca

3 files changed

Lines changed: 319 additions & 96 deletions

File tree

Framework/Built_In_Automation/Database/BuiltInFunctions.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434
DB_ODBC_DRIVER = "odbc_driver"
3535
DB_SESSION = "session"
3636
DB_ODBC_UTF8 = "odbc: enable utf-8 encoding"
37+
DB_WAREHOUSE = "warehouse"
38+
DB_SCHEMA = "schema"
39+
DB_ACCOUNT = "account"
3740

3841

3942
# [NON ACTION]
@@ -102,6 +105,16 @@ def find_odbc_driver(db_type="postgresql"):
102105

103106
def handle_db_exception(sModuleInfo, e):
104107
import pyodbc
108+
109+
# Handle Snowflake exceptions
110+
try:
111+
import snowflake.connector.errors as snowflake_errors
112+
if isinstance(e, snowflake_errors.Error):
113+
traceback.print_exc()
114+
CommonUtil.ExecLog(sModuleInfo, f"Snowflake Error: {e}", 3)
115+
return CommonUtil.Exception_Handler(e)
116+
except ImportError:
117+
pass # Snowflake connector not installed
105118

106119
if isinstance(e, pyodbc.DataError):
107120
traceback.print_exc()
@@ -243,6 +256,26 @@ def db_get_connection(session_name):
243256
host=db_host,
244257
port=db_port
245258
)
259+
elif "snowflake" in db_type:
260+
import snowflake.connector
261+
262+
# Get Snowflake-specific parameters
263+
account = db_params.get(DB_ACCOUNT)
264+
if not account:
265+
account = db_host.replace('.snowflakecomputing.com', '') if '.snowflakecomputing.com' in db_host else db_host
266+
warehouse = db_params.get(DB_WAREHOUSE) or 'COMPUTE_WH'
267+
schema = db_params.get(DB_SCHEMA) or 'PUBLIC'
268+
269+
# Connect to Snowflake
270+
db_con = snowflake.connector.connect(
271+
user=db_user_id,
272+
password=db_password,
273+
account=account,
274+
database=db_name,
275+
warehouse=warehouse,
276+
schema=schema
277+
)
278+
CommonUtil.ExecLog(sModuleInfo, "Connected to Snowflake.", 1)
246279
elif "oracle" in db_type:
247280
import cx_Oracle
248281

@@ -303,14 +336,17 @@ def connect_to_db(data_set):
303336
This action just stores the different database specific configs into shared variables for use by other actions.
304337
NOTE: The actual db connection does not happen here, connection to db is made inside the actions which require it.
305338
306-
db_type input parameter <type of db, ex: postgres, mysql>
339+
db_type input parameter <type of db, ex: postgres, mysql, snowflake>
307340
db_name input parameter <name of db, ex: zeuz_db>
308341
db_user_id input parameter <user id of the os who have access to the db, ex: postgres>
309342
db_password input parameter <password of db, ex: mydbpass-mY1-t23z>
310343
db_host input parameter <host of db, ex: localhost, 127.0.0.1>
311344
db_port input parameter <port of db, ex: 5432 for postgres by default>
312345
sid optional parameter <sid of db, ex: 15321 for oracle by default>
313346
service_name optional parameter <service_name of db, ex: 'somename' for oracle by default>
347+
warehouse optional parameter <warehouse for Snowflake, ex: COMPUTE_WH>
348+
schema optional parameter <schema for Snowflake, ex: PUBLIC>
349+
account optional parameter <account identifier for Snowflake>
314350
odbc_driver optional parameter <specify the odbc driver, optional, can be found from pyodbc.drivers()>
315351
odbc: enable utf-8 encoding optional parameter true/false - optionally enable utf-8 encoding
316352
connect to db database action Connect to a database
@@ -324,6 +360,7 @@ def connect_to_db(data_set):
324360
try:
325361
# Default values
326362
db_type = db_name = db_user_id = db_password = db_host = db_port = db_sid = db_service_name = db_odbc_driver = db_params = None
363+
db_warehouse = db_schema = db_account = None
327364
db_enable_odbc_utf8 = True
328365
session_name = "default"
329366

@@ -349,6 +386,12 @@ def connect_to_db(data_set):
349386
sr.Set_Shared_Variables(DB_ODBC_DRIVER,right.strip())
350387
if left == DB_ODBC_UTF8:
351388
db_enable_odbc_utf8 = CommonUtil.parse_value_into_object(right.strip()) == True
389+
if left == DB_WAREHOUSE or left == "warehouse":
390+
db_warehouse = right.strip()
391+
if left == DB_SCHEMA:
392+
db_schema = right.strip()
393+
if left == DB_ACCOUNT:
394+
db_account = right.strip()
352395
if DB_SESSION in left:
353396
session_name = right.strip()
354397

@@ -363,6 +406,9 @@ def connect_to_db(data_set):
363406
DB_SERVICE_NAME: db_service_name,
364407
DB_ODBC_DRIVER: db_odbc_driver,
365408
DB_ODBC_UTF8: db_enable_odbc_utf8,
409+
DB_WAREHOUSE: db_warehouse,
410+
DB_SCHEMA: db_schema,
411+
DB_ACCOUNT: db_account,
366412
}
367413

368414
if sr.Test_Shared_Variables('db_sessions'):
@@ -426,6 +472,10 @@ def db_select(data_set):
426472

427473
# Get db_cursor and execute
428474
db_con = db_get_connection(session_name)
475+
if db_con == "zeuz_failed":
476+
CommonUtil.ExecLog(sModuleInfo, "Failed to get database connection", 3)
477+
return "zeuz_failed"
478+
429479
with db_con:
430480
with db_con.cursor() as db_cursor:
431481
db_cursor.execute(query)
@@ -539,6 +589,10 @@ def select_from_db(data_set):
539589

540590
# Get db_cursor and execute
541591
db_con = db_get_connection(session_name)
592+
if db_con == "zeuz_failed":
593+
CommonUtil.ExecLog(sModuleInfo, "Failed to get database connection", 3)
594+
return "zeuz_failed"
595+
542596
with db_con:
543597
with db_con.cursor() as db_cursor:
544598
db_cursor.execute(query)
@@ -627,6 +681,10 @@ def insert_into_db(data_set):
627681
CommonUtil.ExecLog(sModuleInfo, "Executing query:\n%s." % query, 1)
628682

629683
db_con = db_get_connection(session_name)
684+
if db_con == "zeuz_failed":
685+
CommonUtil.ExecLog(sModuleInfo, "Failed to get database connection", 3)
686+
return "zeuz_failed"
687+
630688
with db_con:
631689
with db_con.cursor() as db_cursor:
632690
db_cursor.execute(query)
@@ -699,6 +757,10 @@ def delete_from_db(data_set):
699757

700758
# Get db_cursor and execute
701759
db_con = db_get_connection(session_name)
760+
if db_con == "zeuz_failed":
761+
CommonUtil.ExecLog(sModuleInfo, "Failed to get database connection", 3)
762+
return "zeuz_failed"
763+
702764
with db_con:
703765
with db_con.cursor() as db_cursor:
704766
db_cursor.execute(query)
@@ -784,6 +846,10 @@ def update_into_db(data_set):
784846

785847
# Get db_cursor and execute
786848
db_con = db_get_connection(session_name)
849+
if db_con == "zeuz_failed":
850+
CommonUtil.ExecLog(sModuleInfo, "Failed to get database connection", 3)
851+
return "zeuz_failed"
852+
787853
with db_con:
788854
with db_con.cursor() as db_cursor:
789855
db_cursor.execute(query)
@@ -849,6 +915,10 @@ def db_non_query(data_set):
849915

850916
# Get db_cursor and execute
851917
db_con = db_get_connection(session_name)
918+
if db_con == "zeuz_failed":
919+
CommonUtil.ExecLog(sModuleInfo, "Failed to get database connection", 3)
920+
return "zeuz_failed"
921+
852922
with db_con:
853923
with db_con.cursor() as db_cursor:
854924
db_cursor.execute(query)

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,9 @@ dependencies = [
9494
"xvfbwrapper>=0.2.9 ; sys_platform == 'linux'",
9595
"pyodbc>=5.2.0",
9696
"psycopg2-binary>=2.9.10",
97-
"cryptography==42.0.8",
97+
"cryptography>=42.0.8",
98+
"snowflake-connector-python>=3.12.0",
99+
"pyopenssl>=23.0.0",
98100
"pipdeptree>=2.26.1",
99101
"axe-selenium-python>=2.1.6",
100102
"filelock>=3.20.0",

0 commit comments

Comments
 (0)