Skip to content

Commit bbc0ac3

Browse files
rhamel3drycwo
authored andcommitted
Add AuthProvider construction to MetafoldClient
1 parent 8cf93e7 commit bbc0ac3

3 files changed

Lines changed: 46 additions & 7 deletions

File tree

metafold/__init__.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from metafold.projects import ProjectsEndpoint
33
from metafold.assets import AssetsEndpoint
44
from metafold.jobs import JobsEndpoint
5+
from metafold.auth import AuthProvider
56
from typing import Optional
67

78

@@ -18,8 +19,12 @@ class MetafoldClient(Client):
1819
jobs: JobsEndpoint
1920

2021
def __init__(
21-
self, access_token: str,
22+
self,
23+
access_token: Optional[str] = None,
2224
project_id: Optional[str] = None,
25+
client_id: Optional[str] = None,
26+
client_secret: Optional[str] = None,
27+
auth_domain: str = "metafold3d.us.auth0.com",
2328
base_url: str = "https://api.metafold3d.com",
2429
) -> None:
2530
"""Initialize Metafold API client.
@@ -29,7 +34,17 @@ def __init__(
2934
project_id: ID of the project to make API calls against.
3035
base_url: Metafold API URL. Used for internal testing.
3136
"""
32-
super().__init__(access_token, base_url, project_id=project_id)
37+
# client_id and client_secret have priority
38+
if not any([client_id and client_secret, access_token]):
39+
raise ValueError(
40+
"Expected client_id and client_secret or access_token to be provided"
41+
)
42+
elif client_id and client_secret:
43+
auth = AuthProvider(client_id, client_secret, auth_domain, base_url)
44+
super().__init__(base_url, auth=auth, project_id=project_id)
45+
else:
46+
super().__init__(base_url, access_token=access_token, project_id=project_id)
47+
3348
self.projects = ProjectsEndpoint(self)
3449
self.assets = AssetsEndpoint(self)
3550
self.jobs = JobsEndpoint(self)

metafold/auth.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,21 @@
11
from auth0.authentication import GetToken
22
from collections import namedtuple
33
from datetime import datetime, timedelta, timezone
4+
from typing import Optional
45

56
Token = namedtuple("Token", ["access_token", "expires_at"])
67

78

89
class AuthProvider:
9-
def __init__(self, auth_domain: str, client_id: str, client_secret: str):
10+
def __init__(
11+
self,
12+
client_id: str,
13+
client_secret: str,
14+
auth_domain: str,
15+
base_url: str
16+
) -> None:
1017
self._auth_domain = auth_domain
18+
self._base_url = base_url
1119
self._client_id = client_id
1220

1321
self._get_token = GetToken(
@@ -20,7 +28,7 @@ def __init__(self, auth_domain: str, client_id: str, client_secret: str):
2028
def get_token(self) -> str:
2129
now = datetime.now(timezone.utc)
2230
if not self._token or self._token.expires_at - now < timedelta(minutes=1):
23-
token = get_token.client_credentials(self._base_url)
31+
token = self._get_token.client_credentials(self._base_url)
2432
expires_at = now + timedelta(seconds=token["expires_in"])
2533
self._token = Token(token["access_token"], expires_at)
2634
return self._token.access_token

metafold/client.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from metafold.auth import AuthProvider
12
from requests import HTTPError, Response, Session
23
from typing import Any, Callable, Optional
34
from urllib.parse import urljoin
@@ -8,17 +9,29 @@ class Client:
89
"""Base client."""
910

1011
def __init__(
11-
self, access_token: str, base_url: str,
12+
self,
13+
base_url: str,
14+
access_token: Optional[str] = None,
1215
project_id: Optional[str] = None,
16+
auth: Optional[AuthProvider] = None
1317
) -> None:
18+
if bool(auth) == bool(access_token):
19+
raise ValueError(
20+
"Expected AuthProvider or access_token to be provided"
21+
)
22+
self._auth = auth
1423
self._default_project = project_id
1524
self._base_url = base_url
1625
self._session = Session()
1726
self._session.headers.update({
1827
"Accept": "application/json",
19-
"Authorization": f"Bearer {access_token}",
2028
"User-Agent": f"Python/{platform.python_version()}",
2129
})
30+
if access_token:
31+
self._session.headers.update({
32+
"Authorization": f"Bearer {access_token}",
33+
})
34+
2235

2336
def project_id(self, id: Optional[str] = None) -> str:
2437
id = id or self._default_project
@@ -33,7 +46,10 @@ def _request(
3346
*args: Any, **kwargs: Any,
3447
) -> Response:
3548
url = urljoin(self._base_url, url)
36-
r: Response = request(url, *args, **kwargs)
49+
headers = None
50+
if self._auth:
51+
headers = {"Authorization": f"Bearer {self._auth.get_token()}"}
52+
r: Response = request(url, *args, **kwargs, headers=headers)
3753
if not r.ok:
3854
body: dict[str, Any] = r.json()
3955
# Error responses aren't entirely consistent in the Metafold API,

0 commit comments

Comments
 (0)