Skip to content

Commit 0a342e9

Browse files
committed
Share rest client
1 parent 6485a18 commit 0a342e9

5 files changed

Lines changed: 40 additions & 145 deletions

File tree

auth0/v3/authentication/base.py

Lines changed: 8 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
import requests
77

8+
from auth0.v3.rest import RestClient, RestClientOptions
9+
810
from ..exceptions import Auth0Error, RateLimitError
911

1012
UNKNOWN_ERROR = "a0.sdk.internal.unknown"
@@ -24,132 +26,14 @@ class AuthenticationBase(object):
2426

2527
def __init__(self, domain, telemetry=True, timeout=5.0, protocol="https"):
2628
self.domain = domain
27-
self.timeout = timeout
2829
self.protocol = protocol
29-
self.base_headers = {"Content-Type": "application/json"}
30-
31-
if telemetry:
32-
py_version = platform.python_version()
33-
version = sys.modules["auth0"].__version__
34-
35-
auth0_client = json.dumps(
36-
{
37-
"name": "auth0-python",
38-
"version": version,
39-
"env": {
40-
"python": py_version,
41-
},
42-
}
43-
).encode("utf-8")
44-
45-
self.base_headers.update(
46-
{
47-
"User-Agent": "Python/{}".format(py_version),
48-
"Auth0-Client": base64.b64encode(auth0_client),
49-
}
50-
)
30+
self.client = RestClient(
31+
jwt=None,
32+
options=RestClientOptions(telemetry=telemetry, timeout=timeout, retries=0),
33+
)
5134

5235
def post(self, url, data=None, headers=None):
53-
request_headers = self.base_headers.copy()
54-
request_headers.update(headers or {})
55-
response = requests.post(
56-
url=url, json=data, headers=request_headers, timeout=self.timeout
57-
)
58-
return self._process_response(response)
36+
return self.client.post(url, data, headers)
5937

6038
def get(self, url, params=None, headers=None):
61-
request_headers = self.base_headers.copy()
62-
request_headers.update(headers or {})
63-
response = requests.get(
64-
url=url, params=params, headers=request_headers, timeout=self.timeout
65-
)
66-
return self._process_response(response)
67-
68-
def _process_response(self, response):
69-
return self._parse(response).content()
70-
71-
def _parse(self, response):
72-
if not response.text:
73-
return EmptyResponse(response.status_code)
74-
try:
75-
return JsonResponse(response)
76-
except ValueError:
77-
return PlainResponse(response)
78-
79-
80-
class Response(object):
81-
def __init__(self, status_code, content, headers):
82-
self._status_code = status_code
83-
self._content = content
84-
self._headers = headers
85-
86-
def content(self):
87-
if not self._is_error():
88-
return self._content
89-
90-
if self._status_code == 429:
91-
reset_at = int(self._headers.get("x-ratelimit-reset", "-1"))
92-
raise RateLimitError(
93-
error_code=self._error_code(),
94-
message=self._error_message(),
95-
reset_at=reset_at,
96-
)
97-
98-
raise Auth0Error(
99-
status_code=self._status_code,
100-
error_code=self._error_code(),
101-
message=self._error_message(),
102-
)
103-
104-
def _is_error(self):
105-
return self._status_code is None or self._status_code >= 400
106-
107-
# Adding these methods to force implementation in subclasses because they are references in this parent class
108-
def _error_code(self):
109-
raise NotImplementedError
110-
111-
def _error_message(self):
112-
raise NotImplementedError
113-
114-
115-
class JsonResponse(Response):
116-
def __init__(self, response):
117-
content = json.loads(response.text)
118-
super(JsonResponse, self).__init__(
119-
response.status_code, content, response.headers
120-
)
121-
122-
def _error_code(self):
123-
if "error" in self._content:
124-
return self._content.get("error")
125-
elif "code" in self._content:
126-
return self._content.get("code")
127-
else:
128-
return UNKNOWN_ERROR
129-
130-
def _error_message(self):
131-
return self._content.get("error_description", "")
132-
133-
134-
class PlainResponse(Response):
135-
def __init__(self, response):
136-
super(PlainResponse, self).__init__(
137-
response.status_code, response.text, response.headers
138-
)
139-
140-
def _error_code(self):
141-
return UNKNOWN_ERROR
142-
143-
def _error_message(self):
144-
return self._content
145-
146-
147-
class EmptyResponse(Response):
148-
def __init__(self, status_code):
149-
super(EmptyResponse, self).__init__(status_code, "", {})
150-
151-
def _error_code(self):
152-
return UNKNOWN_ERROR
153-
154-
def _error_message(self):
155-
return ""
39+
return self.client.get(url, params, headers)

auth0/v3/rest.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,12 @@ def __init__(self, jwt, telemetry=True, timeout=5.0, options=None):
7575
self._skip_sleep = False
7676

7777
self.base_headers = {
78-
"Authorization": "Bearer {}".format(self.jwt),
7978
"Content-Type": "application/json",
8079
}
8180

81+
if jwt is not None:
82+
self.base_headers["Authorization"] = "Bearer {}".format(self.jwt)
83+
8284
if options.telemetry:
8385
py_version = platform.python_version()
8486
version = sys.modules["auth0"].__version__
@@ -124,8 +126,9 @@ def MAX_REQUEST_RETRY_DELAY(self):
124126
def MIN_REQUEST_RETRY_DELAY(self):
125127
return 100
126128

127-
def get(self, url, params=None):
128-
headers = self.base_headers.copy()
129+
def get(self, url, params=None, headers=None):
130+
request_headers = self.base_headers.copy()
131+
request_headers.update(headers or {})
129132

130133
# Track the API request attempt number
131134
attempt = 0
@@ -139,7 +142,10 @@ def get(self, url, params=None):
139142

140143
# Issue the request
141144
response = requests.get(
142-
url, params=params, headers=headers, timeout=self.options.timeout
145+
url,
146+
params=params,
147+
headers=request_headers,
148+
timeout=self.options.timeout,
143149
)
144150

145151
# If the response did not have a 429 header, or the attempt number is greater than the configured retries, break
@@ -156,11 +162,12 @@ def get(self, url, params=None):
156162
# Return the final Response
157163
return self._process_response(response)
158164

159-
def post(self, url, data=None):
160-
headers = self.base_headers.copy()
165+
def post(self, url, data=None, headers=None):
166+
request_headers = self.base_headers.copy()
167+
request_headers.update(headers or {})
161168

162169
response = requests.post(
163-
url, json=data, headers=headers, timeout=self.options.timeout
170+
url, json=data, headers=request_headers, timeout=self.options.timeout
164171
)
165172
return self._process_response(response)
166173

@@ -281,10 +288,14 @@ def _error_code(self):
281288
return self._content.get("errorCode")
282289
elif "error" in self._content:
283290
return self._content.get("error")
291+
elif "code" in self._content:
292+
return self._content.get("code")
284293
else:
285294
return UNKNOWN_ERROR
286295

287296
def _error_message(self):
297+
if "error_description" in self._content:
298+
return self._content.get("error_description")
288299
message = self._content.get("message", "")
289300
if message is not None and message != "":
290301
return message

auth0/v3/test/authentication/test_base.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@
1313
class TestBase(unittest.TestCase):
1414
def test_telemetry_enabled_by_default(self):
1515
ab = AuthenticationBase("auth0.com")
16+
base_headers = ab.client.base_headers
1617

17-
user_agent = ab.base_headers["User-Agent"]
18-
auth0_client_bytes = base64.b64decode(ab.base_headers["Auth0-Client"])
18+
user_agent = base_headers["User-Agent"]
19+
auth0_client_bytes = base64.b64decode(base_headers["Auth0-Client"])
1920
auth0_client_json = auth0_client_bytes.decode("utf-8")
2021
auth0_client = json.loads(auth0_client_json)
21-
content_type = ab.base_headers["Content-Type"]
22+
content_type = base_headers["Content-Type"]
2223

2324
from auth0 import __version__ as auth0_version
2425

@@ -39,7 +40,7 @@ def test_telemetry_enabled_by_default(self):
3940
def test_telemetry_disabled(self):
4041
ab = AuthenticationBase("auth0.com", telemetry=False)
4142

42-
self.assertEqual(ab.base_headers, {"Content-Type": "application/json"})
43+
self.assertEqual(ab.client.base_headers, {"Content-Type": "application/json"})
4344

4445
@mock.patch("requests.post")
4546
def test_post(self, mock_post):
@@ -51,7 +52,7 @@ def test_post(self, mock_post):
5152
data = ab.post("the-url", data={"a": "b"}, headers={"c": "d"})
5253

5354
mock_post.assert_called_with(
54-
url="the-url",
55+
"the-url",
5556
json={"a": "b"},
5657
headers={"c": "d", "Content-Type": "application/json"},
5758
timeout=(10, 2),
@@ -70,7 +71,7 @@ def test_post_with_defaults(self, mock_post):
7071
data = ab.post("the-url")
7172

7273
mock_post.assert_called_with(
73-
url="the-url",
74+
"the-url",
7475
json=None,
7576
headers={"Content-Type": "application/json"},
7677
timeout=5.0,
@@ -88,8 +89,8 @@ def test_post_includes_telemetry(self, mock_post):
8889
data = ab.post("the-url", data={"a": "b"}, headers={"c": "d"})
8990

9091
self.assertEqual(mock_post.call_count, 1)
91-
call_kwargs = mock_post.call_args[1]
92-
self.assertEqual(call_kwargs["url"], "the-url")
92+
call_args, call_kwargs = mock_post.call_args
93+
self.assertEqual(call_args[0], "the-url")
9394
self.assertEqual(call_kwargs["json"], {"a": "b"})
9495
headers = call_kwargs["headers"]
9596
self.assertEqual(headers["c"], "d")
@@ -228,7 +229,7 @@ def test_get(self, mock_get):
228229
data = ab.get("the-url", params={"a": "b"}, headers={"c": "d"})
229230

230231
mock_get.assert_called_with(
231-
url="the-url",
232+
"the-url",
232233
params={"a": "b"},
233234
headers={"c": "d", "Content-Type": "application/json"},
234235
timeout=(10, 2),
@@ -247,7 +248,7 @@ def test_get_with_defaults(self, mock_get):
247248
data = ab.get("the-url")
248249

249250
mock_get.assert_called_with(
250-
url="the-url",
251+
"the-url",
251252
params=None,
252253
headers={"Content-Type": "application/json"},
253254
timeout=5.0,
@@ -265,8 +266,8 @@ def test_get_includes_telemetry(self, mock_get):
265266
data = ab.get("the-url", params={"a": "b"}, headers={"c": "d"})
266267

267268
self.assertEqual(mock_get.call_count, 1)
268-
call_kwargs = mock_get.call_args[1]
269-
self.assertEqual(call_kwargs["url"], "the-url")
269+
call_args, call_kwargs = mock_get.call_args
270+
self.assertEqual(call_args[0], "the-url")
270271
self.assertEqual(call_kwargs["params"], {"a": "b"})
271272
headers = call_kwargs["headers"]
272273
self.assertEqual(headers["c"], "d")

auth0/v3/test_async/test_asyncify.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import asyncio
21
import base64
32
import json
43
import platform
@@ -12,8 +11,8 @@
1211
from callee import Attrs
1312
from mock import ANY, MagicMock
1413

14+
from auth0.v3.asyncify import asyncify
1515
from auth0.v3.management import Clients, Guardian, Jobs
16-
from auth0.v3.management.asyncify import asyncify
1716

1817
clients = re.compile(r"^https://example\.com/api/v2/clients.*")
1918
factors = re.compile(r"^https://example\.com/api/v2/guardian/factors.*")

0 commit comments

Comments
 (0)