@@ -45,15 +45,18 @@ def _fetch_key(self, key_id=None):
4545 """
4646 raise NotImplementedError
4747
48- def verify_signature (self , token ):
49- """Verifies the signature of the given JSON web token.
48+ def _get_kid (self , token ):
49+ """Gets the key id from the kid claim of the header of the token
5050
5151 Args:
52- token (str): The JWT to get its signature verified .
52+ token (str): The JWT to get the header from .
5353
5454 Raises:
5555 TokenValidationError: if the token cannot be decoded, the algorithm is invalid
5656 or the token's signature doesn't match the calculated one.
57+
58+ Returns:
59+ the key id or None
5760 """
5861 try :
5962 header = jwt .get_unverified_header (token )
@@ -67,9 +70,19 @@ def verify_signature(self, token):
6770 'to be signed with "{}"' .format (alg , self ._algorithm )
6871 )
6972
70- kid = header .get ("kid" , None )
71- secret_or_certificate = self ._fetch_key (key_id = kid )
73+ return header .get ("kid" , None )
74+
75+ def _decode_jwt (self , token , secret_or_certificate ):
76+ """Verifies the signature of the given JSON web token.
77+
78+ Args:
79+ token (str): The JWT to get its signature verified.
80+ secret_or_certificate (str): The public key or shared secret.
7281
82+ Raises:
83+ TokenValidationError: if the token cannot be decoded, the algorithm is invalid
84+ or the token's signature doesn't match the calculated one.
85+ """
7386 try :
7487 decoded = jwt .decode (
7588 jwt = token ,
@@ -81,6 +94,21 @@ def verify_signature(self, token):
8194 raise TokenValidationError ("Invalid token signature." )
8295 return decoded
8396
97+ def verify_signature (self , token ):
98+ """Verifies the signature of the given JSON web token.
99+
100+ Args:
101+ token (str): The JWT to get its signature verified.
102+
103+ Raises:
104+ TokenValidationError: if the token cannot be decoded, the algorithm is invalid
105+ or the token's signature doesn't match the calculated one.
106+ """
107+ kid = self ._get_kid (token )
108+ secret_or_certificate = self ._fetch_key (key_id = kid )
109+
110+ return self ._decode_jwt (token , secret_or_certificate )
111+
84112
85113class SymmetricSignatureVerifier (SignatureVerifier ):
86114 """Verifier for HMAC signatures, which rely on shared secrets.
@@ -136,6 +164,24 @@ def _init_cache(self, cache_ttl):
136164 self ._cache_ttl = cache_ttl
137165 self ._cache_is_fresh = False
138166
167+ def _cache_expired (self ):
168+ """Checks if the cache is expired
169+
170+ Returns:
171+ True if it should use the cache.
172+ """
173+ return self ._cache_date + self ._cache_ttl < time .time ()
174+
175+ def _cache_jwks (self , jwks ):
176+ """Cache the response of the JWKS request
177+
178+ Args:
179+ jwks (dict): The JWKS
180+ """
181+ self ._cache_value = self ._parse_jwks (jwks )
182+ self ._cache_is_fresh = True
183+ self ._cache_date = time .time ()
184+
139185 def _fetch_jwks (self , force = False ):
140186 """Attempts to obtain the JWK set from the cache, as long as it's still valid.
141187 When not, it will perform a network request to the jwks_url to obtain a fresh result
@@ -144,23 +190,15 @@ def _fetch_jwks(self, force=False):
144190 Args:
145191 force (bool, optional): whether to ignore the cache and force a network request or not. Defaults to False.
146192 """
147- has_expired = self ._cache_date + self ._cache_ttl < time .time ()
148-
149- if not force and not has_expired :
150- # Return from cache
151- self ._cache_is_fresh = False
193+ if force or self ._cache_expired ():
194+ self ._cache_value = {}
195+ response = requests .get (self ._jwks_url )
196+ if response .ok :
197+ jwks = response .json ()
198+ self ._cache_jwks (jwks )
152199 return self ._cache_value
153200
154- # Invalidate cache and fetch fresh data
155- self ._cache_value = {}
156- response = requests .get (self ._jwks_url )
157-
158- if response .ok :
159- # Update cache
160- jwks = response .json ()
161- self ._cache_value = self ._parse_jwks (jwks )
162- self ._cache_is_fresh = True
163- self ._cache_date = time .time ()
201+ self ._cache_is_fresh = False
164202 return self ._cache_value
165203
166204 @staticmethod
0 commit comments