Skip to content

Commit 8b7973e

Browse files
committed
feat(project): allow setting auth provider
1 parent 9965124 commit 8b7973e

10 files changed

Lines changed: 438 additions & 190 deletions

File tree

libsimba/auth/__init__.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from abc import ABC, abstractmethod
2424
from typing import Any, Dict, Optional
2525

26-
from libsimba.schemas import AuthProviderName, AuthToken, ConnectionConfig
26+
from libsimba.schemas import AuthProviderName, AuthToken, ConnectionConfig, Login
2727

2828

2929
logger = logging.getLogger(__name__)
@@ -37,8 +37,7 @@ def provider(self) -> AuthProviderName:
3737
@abstractmethod
3838
async def login(
3939
self,
40-
client_id: str,
41-
client_secret: str,
40+
login: Login,
4241
headers: Dict[str, Any],
4342
config: ConnectionConfig = None,
4443
) -> Optional[AuthToken]:
@@ -47,8 +46,7 @@ async def login(
4746
@abstractmethod
4847
def login_sync(
4948
self,
50-
client_id: str,
51-
client_secret: str,
49+
login: Login,
5250
headers: Dict[str, Any],
5351
config: ConnectionConfig = None,
5452
) -> Optional[AuthToken]:

libsimba/auth/apikey.py

Lines changed: 0 additions & 46 deletions
This file was deleted.

libsimba/auth/client_credentials.py

Lines changed: 104 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@
2121
import logging
2222
import os
2323

24-
from datetime import datetime, timedelta
24+
from datetime import datetime, timedelta, timezone
2525
from typing import Any, Dict, Optional, Tuple
2626

2727
from httpx import BasicAuth
2828
from libsimba.auth import AuthProvider
2929
from libsimba.config import settings
30-
from libsimba.schemas import AuthProviderName, AuthToken, ConnectionConfig
30+
from libsimba.schemas import AuthProviderName, AuthToken, ConnectionConfig, Login
3131
from libsimba.utils import Path, async_http_client, build_url, http_client
3232

3333

@@ -44,14 +44,21 @@ def __init__(self, do_init: bool = True):
4444
self.registry: Dict[AuthProviderName, AuthProvider] = {}
4545
ad = BlocksAuthProvider()
4646
kc = KcAuthProvider()
47+
pc = PlatformAuthProvider()
4748
self.registry[ad.provider()] = ad
4849
self.registry[kc.provider()] = kc
50+
self.registry[pc.provider()] = pc
4951

50-
def do_login(self, client_id: str) -> Tuple[Optional[AuthToken], AuthProvider]:
51-
provider = self.registry.get(settings().AUTH_PROVIDER)
52+
def do_login(
53+
self,
54+
client_id: str,
55+
auth_provider: Optional[AuthProviderName] = None
56+
) -> Tuple[Optional[AuthToken], AuthProvider]:
57+
provider_name = auth_provider or settings().AUTH_PROVIDER
58+
provider = self.registry.get(provider_name)
5259
if not provider:
5360
raise ValueError(
54-
f"No provider found for provider type: {settings().AUTH_PROVIDER}"
61+
f"No provider found for provider type: {provider}"
5562
)
5663
token = self.get_cached_token(client_id=client_id)
5764
return token, provider
@@ -61,42 +68,38 @@ def add_header(self, token: AuthToken, headers: Dict[str, Any]) -> None:
6168

6269
def login_sync(
6370
self,
64-
client_id: str,
65-
client_secret: str,
71+
login: Login,
6672
headers: Dict[str, Any],
6773
config: ConnectionConfig = None,
6874
) -> Optional[AuthToken]:
6975
if not headers.get(self.header):
70-
token, provider = self.do_login(client_id=client_id)
76+
token, provider = self.do_login(client_id=login.client_id, auth_provider=login.provider)
7177
if not token:
7278
token = provider.login_sync(
73-
client_id=client_id,
74-
client_secret=client_secret,
79+
login=login,
7580
headers=headers,
7681
config=config,
7782
)
7883
self.add_header(token=token, headers=headers)
79-
self.cache_token(client_id=client_id, token=token)
84+
self.cache_token(client_id=login.client_id, token=token)
8085
return token
8186

8287
async def login(
8388
self,
84-
client_id: str,
85-
client_secret: str,
89+
login: Login,
8690
headers: Dict[str, Any],
8791
config: ConnectionConfig = None,
8892
) -> Optional[AuthToken]:
8993
if not headers.get(self.header):
90-
token, provider = self.do_login(client_id=client_id)
94+
token, provider = self.do_login(client_id=login.client_id, auth_provider=login.provider)
9195
if not token:
9296
token = await provider.login(
93-
client_id=client_id,
94-
client_secret=client_secret,
97+
login=login,
9598
headers=headers,
9699
config=config,
97100
)
98101
self.add_header(token=token, headers=headers)
99-
self.cache_token(client_id=client_id, token=token)
102+
self.cache_token(client_id=login.client_id, token=token)
100103
return token
101104

102105
def token_expired(self, token: AuthToken, offset: int = 60) -> bool:
@@ -109,7 +112,7 @@ def token_expired(self, token: AuthToken, offset: int = 60) -> bool:
109112
:return:
110113
"""
111114

112-
now_w_offset = datetime.utcnow() + timedelta(seconds=offset)
115+
now_w_offset = datetime.now(tz=timezone.utc) + timedelta(seconds=offset)
113116
expiry = token.expires
114117
if now_w_offset >= expiry:
115118
logger.debug(
@@ -231,19 +234,18 @@ def provider(self) -> AuthProviderName:
231234

232235
async def login_sync(
233236
self,
234-
client_id: str,
235-
client_secret: str,
237+
login: Login,
236238
headers: Dict[str, Any],
237239
config: ConnectionConfig = None,
238240
) -> Optional[AuthToken]:
239241
data = {
240-
"client_id": client_id,
241-
"client_secret": client_secret,
242+
"client_id": login.client_id,
243+
"client_secret": login.client_secret,
242244
"grant_type": "client_credentials",
243245
"scope": settings().AUTH_SCOPE or "email profile roles web-origins",
244246
}
245247
sso_host = "{}/auth/realms/{}/protocol/openid-connect/token".format(
246-
settings().BASE_AUTH_URL, settings().AUTH_REALM
248+
settings().AUTH_BASE_URL, settings().AUTH_REALM
247249
)
248250
with http_client(config=config) as client:
249251
r = client.post(
@@ -258,7 +260,7 @@ async def login_sync(
258260
"token": resp["access_token"],
259261
"type": resp["token_type"],
260262
"expires": (
261-
datetime.utcnow() + timedelta(seconds=int(resp["expires_in"]))
263+
datetime.now(tz=timezone.utc) + timedelta(seconds=int(resp["expires_in"]))
262264
),
263265
}
264266
return AuthToken(**data)
@@ -268,20 +270,19 @@ async def login_sync(
268270

269271
async def login(
270272
self,
271-
client_id: str,
272-
client_secret: str,
273+
login: Login,
273274
headers: Dict[str, Any],
274275
config: ConnectionConfig = None,
275276
) -> Optional[AuthToken]:
276277
data = {
277-
"client_id": client_id,
278-
"client_secret": client_secret,
278+
"client_id": login.client_id,
279+
"client_secret": login.client_secret,
279280
"grant_type": "client_credentials",
280281
"scope": settings().SCOPE or "email profile roles web-origins",
281282
}
282283
try:
283284
sso_host = "{}/auth/realms/{}/protocol/openid-connect/token".format(
284-
settings().AUTH_BASE_URL, settings().AUTH_REALM_ID
285+
settings().AUTH_BASE_URL, settings().AUTH_REALM
285286
)
286287
async with async_http_client(config=config) as client:
287288
r = await client.post(
@@ -295,7 +296,7 @@ async def login(
295296
"token": resp["access_token"],
296297
"type": resp["token_type"],
297298
"expires": (
298-
datetime.utcnow() + timedelta(seconds=int(resp["expires_in"]))
299+
datetime.now(tz=timezone.utc) + timedelta(seconds=int(resp["expires_in"]))
299300
),
300301
}
301302
return AuthToken(**data)
@@ -314,13 +315,12 @@ def provider(self) -> AuthProviderName:
314315

315316
async def login(
316317
self,
317-
client_id: str,
318-
client_secret: str,
318+
login: Login,
319319
headers: Dict[str, Any],
320320
config: ConnectionConfig = None,
321321
) -> Optional[AuthToken]:
322322
try:
323-
auth = BasicAuth(client_id, client_secret)
323+
auth = BasicAuth(login.client_id, login.client_secret)
324324
data = {"grant_type": "client_credentials"}
325325
async with async_http_client(config=config) as client:
326326
token_response = await client.post(
@@ -334,7 +334,7 @@ async def login(
334334
"token": resp["access_token"],
335335
"type": resp["token_type"],
336336
"expires": (
337-
datetime.utcnow() + timedelta(seconds=int(resp["expires_in"]))
337+
datetime.now(tz=timezone.utc) + timedelta(seconds=int(resp["expires_in"]))
338338
),
339339
}
340340
return AuthToken(**data)
@@ -344,13 +344,12 @@ async def login(
344344

345345
def login_sync(
346346
self,
347-
client_id: str,
348-
client_secret: str,
347+
login: Login,
349348
headers: Dict[str, Any],
350349
config: ConnectionConfig = None,
351350
) -> Optional[AuthToken]:
352351
try:
353-
auth = BasicAuth(client_id, client_secret)
352+
auth = BasicAuth(login.client_id, login.client_secret)
354353
data = {"grant_type": "client_credentials"}
355354
with http_client(config=config) as client:
356355
token_response = client.post(
@@ -364,10 +363,77 @@ def login_sync(
364363
"token": resp["access_token"],
365364
"type": resp["token_type"],
366365
"expires": (
367-
datetime.utcnow() + timedelta(seconds=int(resp["expires_in"]))
366+
datetime.now(tz=timezone.utc) + timedelta(seconds=int(resp["expires_in"]))
368367
),
369368
}
370369
return AuthToken(**data)
371370
except Exception as e:
372371
logger.warning("[BlocksAuthProvider] :: Error fetching token: {}".format(e))
373372
raise e
373+
374+
375+
class PlatformAuthProvider(ClientCredentials):
376+
377+
def __init__(self):
378+
super().__init__(do_init=False)
379+
380+
def provider(self) -> AuthProviderName:
381+
return AuthProviderName.PLAT
382+
383+
async def login(
384+
self,
385+
login: Login,
386+
headers: Dict[str, Any],
387+
config: ConnectionConfig = None,
388+
) -> Optional[AuthToken]:
389+
try:
390+
data = {
391+
"grant_type": "client_credentials",
392+
"client_id": login.client_id,
393+
"client_secret": login.client_secret
394+
}
395+
async with async_http_client(config=config) as client:
396+
token_response = await client.post(
397+
f"{settings().AUTH_BASE_URL}/oauth/token",
398+
data=data,
399+
)
400+
token_response.raise_for_status()
401+
resp = token_response.json()
402+
data = {
403+
"token": resp["access_token"],
404+
"type": resp["token_type"],
405+
"expires": datetime.fromtimestamp(resp["expires_at"], tz=timezone.utc),
406+
}
407+
return AuthToken(**data)
408+
except Exception as e:
409+
logger.warning("[PlatformAuthProvider] :: Error fetching token: {}".format(e))
410+
raise e
411+
412+
def login_sync(
413+
self,
414+
login: Login,
415+
headers: Dict[str, Any],
416+
config: ConnectionConfig = None,
417+
) -> Optional[AuthToken]:
418+
try:
419+
data = {
420+
"grant_type": "client_credentials",
421+
"client_id": login.client_id,
422+
"client_secret": login.client_secret
423+
}
424+
with http_client(config=config) as client:
425+
token_response = client.post(
426+
f"{settings().AUTH_BASE_URL}/oauth/token",
427+
data=data,
428+
)
429+
token_response.raise_for_status()
430+
resp = token_response.json()
431+
data = {
432+
"token": resp["access_token"],
433+
"type": resp["token_type"],
434+
"expires": datetime.fromtimestamp(resp["expires_at"], tz=timezone.utc),
435+
}
436+
return AuthToken(**data)
437+
except Exception as e:
438+
logger.warning("[PlatformAuthProvider] :: Error fetching token: {}".format(e))
439+
raise e

libsimba/schemas.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
import mimetypes
2222

23-
from datetime import datetime
23+
from datetime import datetime, timezone
2424
from enum import Enum
2525
from pathlib import Path
2626
from typing import IO, Any, AnyStr, Dict, List, Optional, Tuple, Union
@@ -48,8 +48,8 @@ class AuthToken(BaseModel):
4848
@field_validator("expires")
4949
def do_datetime(cls, v: Union[datetime, str]):
5050
if isinstance(v, str):
51-
return datetime.fromisoformat(v)
52-
return v
51+
v = datetime.fromisoformat(v)
52+
return v.replace(tzinfo=timezone.utc)
5353

5454

5555
class ConnectionConfig(BaseModel):
@@ -64,7 +64,8 @@ class ConnectionConfig(BaseModel):
6464
class Login(BaseModel):
6565
auth_flow: AuthFlow
6666
client_id: str
67-
client_secret: Optional[str]
67+
client_secret: Optional[str] = None
68+
provider: Optional[AuthProviderName] = None
6869

6970
@field_validator("client_secret")
7071
def set_secret(cls, v: Optional[str], info: FieldValidationInfo) -> str:

0 commit comments

Comments
 (0)