1212import base64
1313import hashlib
1414import json
15+ import jwt
1516import os
1617import urllib
1718import time
18- from datetime import datetime
19+ import uuid
20+ from datetime import datetime , timezone , timedelta
1921from abc import ABC , abstractmethod
2022from typing import Optional
2123
@@ -428,6 +430,7 @@ def __init__(self, access_token, token_type='Bearer', expires_in=None,
428430 def authenticate (self ) -> AuthContext :
429431 return self ._context
430432
433+
431434class RefreshTokenAuthManager (AuthManager ):
432435 def __init__ (self , client_id , refresh_token , host , scope = "openid" , requests_hooks = None ):
433436 super ().__init__ (host , client_id , requests_hooks = requests_hooks )
@@ -448,3 +451,47 @@ def authenticate(self):
448451
449452 response = self ._post_token (** data )
450453 return AuthContext (** response .json ())
454+
455+
456+ class ServicePrincipalAuthManager (AuthManager ):
457+ def __init__ (self , host , principal_name , key , kid , algorithm = "ES256" , ** kwargs ):
458+ """
459+ Creates an AuthManager that uses Service Principals to authenticate.
460+
461+ principal_name is the principal_name of the authenticating service principal
462+ key is the PEM formatted private key registered with the service principal
463+ kid is the key_id of `key`
464+ algorithm is the algorithm that generated `key`
465+ """
466+ super ().__init__ (host = host , client_id = None , ** kwargs )
467+ self ._principal_name = principal_name
468+ self ._key = key
469+ self ._kid = kid
470+ self ._algorithm = algorithm
471+
472+ def authenticate (self ):
473+ """Authenticate using the "client assertion" flow."""
474+ if not self ._principal_name :
475+ raise ValueError ("missing principal_name" )
476+ if not self ._key :
477+ raise ValueError ("missing key" )
478+ if not self ._kid :
479+ raise ValueError ("missing kid" )
480+ if not self ._algorithm :
481+ raise ValueError ("missing algorithm" )
482+
483+ # Client assertion expires in 10 minutes
484+ ten_minutes_from_now = datetime .now (timezone .utc ) + timedelta (minutes = 10 )
485+ jwt_payload = {"sub" : self ._principal_name , "iss" : self ._principal_name , "jti" : str (uuid .uuid4 ()),
486+ "exp" : int (ten_minutes_from_now .timestamp ()), "aud" : [self ._url (PATH_TOKEN )]}
487+
488+ client_assertion = jwt .encode (payload = jwt_payload , key = self ._key , algorithm = self ._algorithm ,
489+ headers = {"kid" : self ._kid })
490+
491+ data = {"grant_type" : "client_credentials" , "client_assertion" : client_assertion .decode ("utf-8" ),
492+ "client_assertion_type" : "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" }
493+
494+ response = self ._post_token (** data )
495+ if response .status_code != 200 :
496+ raise AuthnError ("Unable to authenticate. Check credentials." , response )
497+ return AuthContext (** response .json ())
0 commit comments