Skip to content

Commit ad8b676

Browse files
Allow connecting via private key or password
1 parent fb65e44 commit ad8b676

2 files changed

Lines changed: 81 additions & 24 deletions

File tree

src/nypl_py_utils/classes/snowflake_client.py

Lines changed: 53 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,40 +6,71 @@
66
class SnowflakeClient:
77
"""Client for managing connections to Snowflake"""
88

9-
def __init__(self, account, user, password, warehouse=None):
9+
def __init__(self, account, user, private_key=None, password=None):
1010
self.logger = create_log('snowflake_client')
11+
if (password is None) == (private_key is None):
12+
raise SnowflakeClientError(
13+
'Either password or private key must be set (but not both)',
14+
self.logger
15+
) from None
16+
1117
self.conn = None
1218
self.account = account
1319
self.user = user
20+
self.private_key = private_key
1421
self.password = password
15-
self.warehouse = warehouse
1622

17-
def connect(self, **kwargs):
23+
def connect(self, mfa_code=None, **kwargs):
1824
"""
19-
Connects to a Snowflake database using the given credentials. If
20-
warehouse parameter is None, uses the default warehouse for the user.
25+
Connects to Snowflake using the given credentials. If you're connecting
26+
locally, you should be using the password and mfa_code. If the
27+
connection is for production code, a private_key should be set up.
2128
2229
Parameters
2330
----------
31+
mfa_code: str, optional
32+
The six-digit MFA code. Only necessary for connecting as a human
33+
user.
2434
kwargs:
25-
All possible arguments (such as timeouts) can be found here:
35+
All possible arguments (such as which warehouse to use or how
36+
long to wait before timing out) can be found here:
2637
https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-api#connect
2738
"""
2839
self.logger.info('Connecting to Snowflake')
29-
try:
30-
self.conn = sc.connect(
31-
account=self.account,
32-
user=self.user,
33-
password=self.password,
34-
warehouse=self.warehouse,
35-
**kwargs)
36-
except Exception as e:
37-
raise SnowflakeClientError(
38-
f'Error connecting to Snowflake: {e}') from None
40+
if self.private_key is not None:
41+
try:
42+
self.conn = sc.connect(
43+
account=self.account,
44+
user=self.user,
45+
private_key=self.private_key,
46+
**kwargs)
47+
except Exception as e:
48+
raise SnowflakeClientError(
49+
f'Error connecting to Snowflake: {e}', self.logger
50+
) from None
51+
else:
52+
if mfa_code is None:
53+
raise SnowflakeClientError(
54+
'When using a password, an MFA code must also be provided',
55+
self.logger
56+
) from None
57+
58+
pw = self.password + mfa_code
59+
try:
60+
self.conn = sc.connect(
61+
account=self.account,
62+
user=self.user,
63+
password=pw,
64+
passcode_in_password=True,
65+
**kwargs)
66+
except Exception as e:
67+
raise SnowflakeClientError(
68+
f'Error connecting to Snowflake: {e}', self.logger
69+
) from None
3970

4071
def execute_query(self, query, **kwargs):
4172
"""
42-
Executes an arbitrary query against the given database connection.
73+
Executes an arbitrary query against the given connection.
4374
4475
Note that:
4576
1) All results will be fetched by default, so this method is not
@@ -53,6 +84,8 @@ def execute_query(self, query, **kwargs):
5384
5485
Parameters
5586
----------
87+
query: str
88+
The SQL query to execute
5689
kwargs:
5790
All possible arguments (such as timeouts) can be found here:
5891
https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-api#execute
@@ -62,7 +95,7 @@ def execute_query(self, query, **kwargs):
6295
sequence
6396
A list of tuples
6497
"""
65-
self.logger.info('Querying database')
98+
self.logger.info('Querying Snowflake')
6699
cursor = self.conn.cursor()
67100
try:
68101
try:
@@ -73,12 +106,12 @@ def execute_query(self, query, **kwargs):
73106
finally:
74107
cursor.close()
75108
except Exception as e:
76-
# If there was an error, also close the database connection
109+
# If there was an error, also close the connection
77110
self.close_connection()
78111

79112
short_q = str(query)
80113
if len(short_q) > 2500:
81-
short_q = short_q[:2497] + "..."
114+
short_q = short_q[:2497] + '...'
82115
raise SnowflakeClientError(
83116
f'Error executing Snowflake query {short_q}: {e}', self.logger
84117
) from None

tests/test_snowflake_client.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,39 @@ def mock_snowflake_conn(self, mocker):
1313

1414
@pytest.fixture
1515
def test_instance(self):
16-
return SnowflakeClient('test_account', 'test_user', 'test_password')
16+
return SnowflakeClient(
17+
'test_account', 'test_user', private_key='test_pk')
1718

18-
def test_connect(self, mock_snowflake_conn, test_instance):
19+
def test_init_no_pw(self):
20+
with pytest.raises(SnowflakeClientError):
21+
SnowflakeClient('test_account', 'test_user')
22+
23+
def test_init_multiple_pw(self):
24+
with pytest.raises(SnowflakeClientError):
25+
SnowflakeClient('test_account', 'test_user', 'test_pk', 'test_pw')
26+
27+
def test_connect_with_pk(self, mock_snowflake_conn, test_instance):
1928
test_instance.connect()
2029
mock_snowflake_conn.assert_called_once_with(
2130
account='test_account',
2231
user='test_user',
23-
password='test_password',
24-
warehouse=None)
32+
private_key='test_pk')
33+
34+
def test_connect_with_pw(self, mock_snowflake_conn):
35+
test_instance = SnowflakeClient(
36+
'test_account', 'test_user', password='test_pw')
37+
test_instance.connect('123456')
38+
mock_snowflake_conn.assert_called_once_with(
39+
account='test_account',
40+
user='test_user',
41+
password='test_pw123456',
42+
passcode_in_password=True)
43+
44+
def test_connect_no_mfa(self, mock_snowflake_conn):
45+
test_instance = SnowflakeClient(
46+
'test_account', 'test_user', password='test_pw')
47+
with pytest.raises(SnowflakeClientError):
48+
test_instance.connect()
2549

2650
def test_execute_query(
2751
self, mock_snowflake_conn, test_instance, mocker):

0 commit comments

Comments
 (0)