Skip to content

Commit 9d8e829

Browse files
Don't let .netrc take precedence in ApiClient (#526)
1 parent e2338b3 commit 9d8e829

2 files changed

Lines changed: 69 additions & 0 deletions

File tree

databricks_cli/sdk/api_client.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
from . import version
4040

4141
from requests.adapters import HTTPAdapter
42+
from requests.utils import get_netrc_auth
43+
from requests.auth import HTTPBasicAuth
4244
from six.moves.urllib.parse import urlparse
4345

4446
try:
@@ -63,6 +65,30 @@ class TlsV1HttpAdapter(HTTPAdapter):
6365
def init_poolmanager(self, connections, maxsize, block=False):
6466
self.poolmanager = PoolManager(num_pools=connections, maxsize=maxsize, block=block, ssl_version=ssl.PROTOCOL_TLSv1_2)
6567

68+
# https://github.com/psf/requests/issues/2773#issuecomment-174312831
69+
class FallbackNetrcAuth(requests.auth.AuthBase):
70+
'''Force requests to ignore the ``.netrc`` if other authentication
71+
methods have been setup. Fallback to ``.netrc`` if not.
72+
73+
Use with::
74+
75+
requests.get(url, auth=FallbackNetrcAuth())
76+
77+
s = requests.Session()
78+
s.auth = FallbackNetrcAuth()
79+
'''
80+
81+
def __call__(self, r):
82+
if "Authorization" in r.headers:
83+
return r
84+
85+
netrc_tuple = get_netrc_auth(r.url)
86+
87+
if netrc_tuple is None or not any(netrc_tuple):
88+
return r
89+
90+
return HTTPBasicAuth(*netrc_tuple)(r)
91+
6692
class ApiClient(object):
6793
"""
6894
A partial Python implementation of dbc rest api
@@ -82,6 +108,7 @@ def __init__(self, user=None, password=None, host=None, token=None,
82108
raise_on_status=False # return original response when retries have been exhausted
83109
)
84110
self.session = requests.Session()
111+
self.session.auth = FallbackNetrcAuth()
85112
self.session.mount('https://', TlsV1HttpAdapter(max_retries=retries))
86113

87114
parsed_url = urlparse(host)

tests/sdk/test_api_client.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
# See the License for the specific language governing permissions and
2222
# limitations under the License.
2323
import json
24+
import os
2425

2526
import pytest
2627
import requests
@@ -44,6 +45,12 @@ def m():
4445
with requests_mock.Mocker() as m:
4546
yield m
4647

48+
@pytest.fixture(autouse=False)
49+
def netrc_file(tmp_path, monkeypatch):
50+
netrc_file_path = os.path.join(str(tmp_path), ".netrc")
51+
monkeypatch.setenv("NETRC", netrc_file_path)
52+
return netrc_file_path
53+
4754
def test_simple_request(m):
4855
data = {'cucumber': 'dade'}
4956
m.get('https://databricks.com/api/2.0/endpoint', text=json.dumps(data))
@@ -130,3 +137,38 @@ def test_api_client_url_parsing():
130137
# databricks_cli.configure.cli
131138
client = ApiClient(host='http://databricks.com')
132139
assert client.get_url('') == 'http://databricks.com/api/2.0'
140+
141+
def test_api_client_auth_netrc_and_user_password(m, netrc_file):
142+
with open(netrc_file, "w+") as netrc:
143+
#generates header Authorization: 'Basic bmV0cmM6cGFzc3dvcmQ='
144+
netrc.write("machine databricks.com login netrc password password")
145+
146+
m.get('https://databricks.com/api/2.0/endpoint', text=json.dumps({}))
147+
client = ApiClient(user="apple", password="banana", host="https://databricks.com")
148+
client.perform_query("GET", "/endpoint")
149+
assert m.request_history[0].headers['Authorization'] == "Basic YXBwbGU6YmFuYW5h"
150+
151+
def test_api_client_auth_only_valid_netrc(m, netrc_file):
152+
with open(netrc_file, "w+") as netrc:
153+
#generates header Authorization: 'Basic bmV0cmM6cGFzc3dvcmQ='
154+
netrc.write("machine databricks.com login netrc password password")
155+
156+
m.get('https://databricks.com/api/2.0/endpoint', text=json.dumps({}))
157+
client = ApiClient(host="https://databricks.com")
158+
client.perform_query("GET", "/endpoint")
159+
assert m.request_history[0].headers['Authorization'] == "Basic bmV0cmM6cGFzc3dvcmQ="
160+
161+
def test_api_client_auth_invalid_netrc(m, netrc_file):
162+
with open(netrc_file, "w+") as netrc:
163+
netrc.write("garbage")
164+
165+
m.get('https://databricks.com/api/2.0/endpoint', text=json.dumps({}))
166+
client = ApiClient(host="https://databricks.com")
167+
client.perform_query("GET", "/endpoint")
168+
assert "Authorization" not in m.request_history[0].headers
169+
170+
def test_api_client_auth_no_netrc(m):
171+
m.get('https://databricks.com/api/2.0/endpoint', text=json.dumps({}))
172+
client = ApiClient(host="https://databricks.com")
173+
client.perform_query("GET", "/endpoint")
174+
assert "Authorization" not in m.request_history[0].headers

0 commit comments

Comments
 (0)