1- from typing import Any , Dict , Optional , Sequence
1+ from typing import Any , Dict , Optional , Sequence , List , Union , Protocol
22
33import jwt
44from jwt .exceptions import InvalidIssuerError , InvalidTokenError
@@ -33,7 +33,29 @@ def get_kid(token: str) -> Optional[str]:
3333 return headers .get ("kid" )
3434
3535
36- class JWTValidator :
36+ class JWTValidatorProtocol (Protocol ):
37+ """Protocol defining the interface for JWT validators"""
38+
39+ async def validate_jwt (self , access_token : str ) -> Dict [str , Any ]: ...
40+
41+
42+ class AbstractJWTValidator :
43+ """Base class for JWT validators with common functionality"""
44+
45+ def __init__ (
46+ self ,
47+ * ,
48+ valid_issuers : Sequence [str ],
49+ valid_audiences : Sequence [str ],
50+ algorithms : Sequence [str ],
51+ ) -> None :
52+ self ._valid_issuers = list (valid_issuers )
53+ self ._valid_audiences = list (valid_audiences )
54+ self ._algorithms = list (algorithms )
55+ self .logger = get_logger ()
56+
57+
58+ class AsymmetricJWTValidator (AbstractJWTValidator ):
3759 def __init__ (
3860 self ,
3961 * ,
@@ -48,7 +70,7 @@ def __init__(
4870 refresh_time : float = 120 ,
4971 ) -> None :
5072 """
51- Creates a new instance of JWTValidator . This class only supports validating
73+ Creates a new instance of AsymmetricJWTValidator . This class supports validating
5274 access tokens signed using asymmetric keys and handling JWKs of RSA type.
5375
5476 Parameters
@@ -83,6 +105,12 @@ def __init__(
83105 JWKS were last fetched more than `refresh_time` seconds ago (by default
84106 120 seconds)
85107 """
108+ super ().__init__ (
109+ valid_issuers = valid_issuers ,
110+ valid_audiences = valid_audiences ,
111+ algorithms = algorithms ,
112+ )
113+
86114 if keys_provider :
87115 pass
88116 elif authority :
@@ -96,14 +124,10 @@ def __init__(
96124 "`authority`, or `keys_provider`."
97125 )
98126
99- keys_provider = CachingKeysProvider (keys_provider , cache_time , refresh_time )
100-
101- self ._valid_issuers = list (valid_issuers )
102- self ._valid_audiences = list (valid_audiences )
103- self ._algorithms = list (algorithms )
104- self ._keys_provider = keys_provider
127+ self ._keys_provider = CachingKeysProvider (
128+ keys_provider , cache_time , refresh_time
129+ )
105130 self .require_kid = require_kid
106- self .logger = get_logger ()
107131
108132 async def get_jwks (self ) -> JWKS :
109133 return await self ._keys_provider .get_keys ()
@@ -170,3 +194,110 @@ async def validate_jwt(self, access_token: str) -> Dict[str, Any]:
170194 return data
171195
172196 raise InvalidAccessToken ()
197+
198+
199+ class SymmetricJWTValidator (AbstractJWTValidator ):
200+ def __init__ (
201+ self ,
202+ * ,
203+ valid_issuers : Sequence [str ],
204+ valid_audiences : Sequence [str ],
205+ secret_key : Union [str , bytes ],
206+ algorithms : Sequence [str ] = ["HS256" ],
207+ ) -> None :
208+ """
209+ Creates a new instance of SymmetricJWTValidator. This class supports validating
210+ access tokens signed using symmetric keys (HMAC).
211+
212+ Parameters
213+ ----------
214+ valid_issuers : Sequence[str]
215+ Sequence of acceptable issuers (iss).
216+ valid_audiences : Sequence[str]
217+ Sequence of acceptable audiences (aud).
218+ secret_key : Union[str, bytes]
219+ The secret key used for symmetric validation.
220+ algorithms : Sequence[str], optional
221+ Sequence of acceptable algorithms, by default ["HS256"].
222+ Supported algorithms: HS256, HS384, HS512
223+ """
224+ super ().__init__ (
225+ valid_issuers = valid_issuers ,
226+ valid_audiences = valid_audiences ,
227+ algorithms = algorithms ,
228+ )
229+
230+ supported_algorithms = ["HS256" , "HS384" , "HS512" ]
231+ for algorithm in algorithms :
232+ if algorithm not in supported_algorithms :
233+ raise ValueError (
234+ f"Algorithm '{ algorithm } ' is not supported for symmetric validation. "
235+ f"Use one of: { ', ' .join (supported_algorithms )} "
236+ )
237+
238+ self ._secret_key = secret_key
239+
240+ async def validate_jwt (self , access_token : str ) -> Dict [str , Any ]:
241+ """
242+ Validates the given JWT using symmetric key and returns its payload.
243+ This method throws exception if the JWT is not valid.
244+ """
245+ for issuer in self ._valid_issuers :
246+ try :
247+ return jwt .decode (
248+ access_token ,
249+ self ._secret_key ,
250+ verify = True ,
251+ algorithms = self ._algorithms ,
252+ audience = self ._valid_audiences ,
253+ issuer = issuer ,
254+ )
255+ except InvalidIssuerError :
256+ # Try the next issuer
257+ pass
258+ except InvalidTokenError as exc :
259+ self .logger .debug ("Invalid access token: " , exc_info = exc )
260+
261+ # If we've tried all issuers and none worked
262+ raise InvalidAccessToken ()
263+
264+
265+ class CompositeJWTValidator (AbstractJWTValidator ):
266+ def __init__ (self , validators : List [JWTValidatorProtocol ]) -> None :
267+ """
268+ Creates a composite validator that tries multiple validation strategies.
269+ Useful when you need to support both symmetric and asymmetric validation.
270+
271+ Parameters
272+ ----------
273+ validators : List[JWTValidatorProtocol]
274+ List of validators to try in sequence
275+ """
276+ self ._validators = validators
277+ self .logger = get_logger ()
278+
279+ async def validate_jwt (self , access_token : str ) -> Dict [str , Any ]:
280+ """
281+ Attempts to validate the JWT using each validator in sequence.
282+ Returns the first successful validation result or raises InvalidAccessToken
283+ if all validators fail.
284+ """
285+ exceptions = []
286+
287+ for validator in self ._validators :
288+ try :
289+ return await validator .validate_jwt (access_token )
290+ except InvalidAccessToken as exc :
291+ exceptions .append (exc )
292+ # Continue to the next validator
293+
294+ # If we get here, all validators failed
295+ if exceptions :
296+ self .logger .debug (f"All validators failed: { exceptions } " )
297+ raise InvalidAccessToken (
298+ "Token validation failed with all configured validators"
299+ )
300+
301+
302+ # For backward compatibility, keep the original name
303+ JWTValidator = AsymmetricJWTValidator
0 commit comments