-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathservices.py
More file actions
126 lines (109 loc) · 4.43 KB
/
services.py
File metadata and controls
126 lines (109 loc) · 4.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import functools
import typing as t
from datetime import timedelta
import anyio
import jwt
from ellar.common import serialize_object
from ellar.di import injectable
from jwt import InvalidAlgorithmError, InvalidTokenError, PyJWKClient, PyJWKClientError
from .exceptions import JWTTokenException
from .schemas import JWTConfiguration
from .token import Token
__all__ = ["JWTService"]


@injectable
class JWTService:
    """
    Sign and decode JSON Web Tokens.

    The service holds a base ``JWTConfiguration``. Every public method also
    accepts ``**jwt_config`` keyword overrides, which are merged over that base
    configuration for the duration of that single call only.
    """

    def __init__(self, jwt_config: JWTConfiguration) -> None:
        self.jwt_config = jwt_config

    @staticmethod
    @functools.lru_cache(maxsize=32)
    def _cached_jwks_client(jwk_url: str) -> PyJWKClient:
        # One PyJWKClient per JWKS URL, shared across calls and service
        # instances. Building a fresh client for every decode discarded the
        # client's own JWK-set cache and forced a key-set fetch per token.
        # The cache is bounded so arbitrary per-call ``jwk_url`` overrides
        # cannot grow it without limit.
        return PyJWKClient(jwk_url)

    def get_jwks_client(self, jwt_config: JWTConfiguration) -> t.Optional[PyJWKClient]:
        """
        Return a (cached) ``PyJWKClient`` for ``jwt_config.jwk_url``.

        Returns ``None`` when no JWKS URL is configured.
        """
        if not jwt_config.jwk_url:
            return None
        return self._cached_jwks_client(str(jwt_config.jwk_url))

    def get_leeway(self, jwt_config: JWTConfiguration) -> timedelta:
        """
        Normalise ``jwt_config.leeway`` into a ``timedelta``.

        ``None`` means no leeway, an ``int``/``float`` is a number of seconds,
        and a ``timedelta`` is returned unchanged.

        :raises TypeError: if leeway has any other type. Previously this fell
            through and silently returned ``None``.
        """
        leeway = jwt_config.leeway
        if leeway is None:
            return timedelta(seconds=0)
        if isinstance(leeway, timedelta):
            return leeway
        if isinstance(leeway, (int, float)):
            return timedelta(seconds=leeway)
        raise TypeError(
            f"Unsupported leeway type {type(leeway).__name__!r}; "
            "expected None, int, float or timedelta"
        )

    def get_verifying_key(self, token: t.Any, jwt_config: JWTConfiguration) -> t.Any:
        """
        Resolve the key used to verify ``token``'s signature.

        - ``HS*`` (HMAC) algorithms: the encoded ``signing_secret_key``.
        - otherwise, if ``jwk_url`` is configured: the key object returned by
          the JWKS client for ``token`` (not necessarily ``bytes``).
        - otherwise: the encoded ``verifying_secret_key``.

        The algorithm is read from the ``jwt_config`` argument, i.e. the
        merged per-call configuration. An ``algorithm=`` override passed to
        :meth:`decode` therefore selects the matching key.

        :raises JWTTokenException: if the JWKS client cannot supply a key.
        """
        if jwt_config.algorithm.startswith("HS"):
            return jwt_config.signing_secret_key.encode()

        jwks_client = self.get_jwks_client(jwt_config)
        if jwks_client:
            try:
                return jwks_client.get_signing_key_from_jwt(token).key
            except PyJWKClientError as ex:
                raise JWTTokenException("Token is invalid or expired") from ex

        return jwt_config.verifying_secret_key.encode()

    def _merge_configurations(self, **jwt_config: t.Any) -> JWTConfiguration:
        """Return the base configuration with ``jwt_config`` overrides applied."""
        if not jwt_config:
            # Nothing to override: skip the dict round-trip and re-validation.
            return self.jwt_config
        merged = self.jwt_config.dict()
        merged.update(jwt_config)
        return JWTConfiguration(**merged)

    def sign(
        self,
        payload: dict,
        headers: t.Optional[t.Dict[str, t.Any]] = None,
        **jwt_config: t.Any,
    ) -> str:
        """
        Return an encoded token for the given payload dictionary.

        :param payload: claims to encode. A shallow copy is serialised, so the
            dict object itself is not passed on.
        :param headers: optional extra JWT headers.
        :param jwt_config: per-call overrides of the base configuration.
        """
        _jwt_config = self._merge_configurations(**jwt_config)
        jwt_payload = Token(jwt_config=_jwt_config).build(
            serialize_object(payload.copy())
        )
        # Coerce the subject claim to a string.
        if "sub" in jwt_payload:
            jwt_payload["sub"] = str(jwt_payload["sub"])
        return jwt.encode(
            jwt_payload,
            _jwt_config.signing_secret_key,
            algorithm=_jwt_config.algorithm,
            json_encoder=_jwt_config.json_encoder,
            headers=headers,
        )

    async def sign_async(
        self,
        payload: dict,
        headers: t.Optional[t.Dict[str, t.Any]] = None,
        **jwt_config: t.Any,
    ) -> str:
        """Async variant of :meth:`sign`; runs it in a worker thread via anyio."""
        # anyio.to_thread.run_sync forwards positional args only, so the
        # keyword overrides are bound with functools.partial.
        return await anyio.to_thread.run_sync(
            functools.partial(self.sign, payload, headers, **jwt_config)
        )

    def decode(
        self, token: str, verify: bool = True, **jwt_config: t.Any
    ) -> t.Dict[str, t.Any]:
        """
        Validate the given token and return its payload dictionary.

        :param token: the encoded JWT.
        :param verify: when ``False`` the signature is not checked, and no
            verifying key (possibly a remote JWKS lookup) is resolved.
        :param jwt_config: per-call overrides of the base configuration.
        :raises JWTTokenException: if the token is malformed, uses a
            disallowed algorithm, fails its signature or claim checks (e.g.
            expired ``exp``, audience, issuer), or no JWKS signing key can be
            obtained.
        """
        try:
            _jwt_config = self._merge_configurations(**jwt_config)
            # The key is only consulted for signature verification; skip the
            # lookup entirely when verification is disabled.
            key = self.get_verifying_key(token, _jwt_config) if verify else ""
            return jwt.decode(  # type:ignore[no-any-return]
                token,
                key,
                algorithms=[_jwt_config.algorithm],
                audience=_jwt_config.audience,
                issuer=_jwt_config.issuer,
                leeway=self.get_leeway(_jwt_config),
                options={
                    # Only enforce 'aud' when an audience is configured.
                    "verify_aud": _jwt_config.audience is not None,
                    "verify_signature": verify,
                },
            )
        except InvalidAlgorithmError as ex:
            # Must precede InvalidTokenError, which it subclasses.
            raise JWTTokenException("Invalid algorithm specified") from ex
        except InvalidTokenError as ex:
            raise JWTTokenException("Token is invalid or expired") from ex

    async def decode_async(
        self, token: str, verify: bool = True, **jwt_config: t.Any
    ) -> t.Dict[str, t.Any]:
        """Async variant of :meth:`decode`; runs it in a worker thread via anyio."""
        # anyio.to_thread.run_sync forwards positional args only, so the
        # keyword overrides are bound with functools.partial.
        return await anyio.to_thread.run_sync(
            functools.partial(self.decode, token, verify, **jwt_config)
        )