11import datetime
2+ from http import HTTPStatus
3+ from typing import cast
24
5+ from aiohttp .helpers import BasicAuth
36from pydantic import BaseModel
47
58from healthcheck .context import Context
69from healthcheck .status import Result
10+ from healthcheck .status import status_logger as logger
711from healthcheck .status .requests import query_api
812
913
@@ -16,16 +20,144 @@ class Token(BaseModel):
1620 token_type : str = "Bearer"
1721
1822
23+ def getnow ():
24+ """naive UTC now"""
25+ return datetime .datetime .now (datetime .UTC ).replace (tzinfo = None )
26+
27+
28+ class ClientTokenProvider :
29+ """Client to generate access tokens to authenticate with Zimfarm"""
30+
31+ def __init__ (self ):
32+ self ._access_token : str | None = None
33+ self ._refresh_token : str | None = None
34+ self ._expires_at : datetime .datetime = datetime .datetime .fromtimestamp (
35+ 0
36+ ).replace (tzinfo = None )
37+
38+ async def _generate_oauth_access_token (self ) -> None :
39+ """Generate oauth access token and update expires_at."""
40+
41+ response = await query_api (
42+ f"{ Context .cms_oauth_issuer } /oauth2/token" ,
43+ method = "POST" ,
44+ data = {
45+ "grant_type" : "client_credentials" ,
46+ "audience" : Context .cms_oauth_audience_id ,
47+ },
48+ auth = BasicAuth (
49+ Context .cms_oauth_client_id , Context .cms_oauth_client_secret
50+ ),
51+ timeout = Context .requests_timeout ,
52+ check_name = "zimfarm-api-authentication" ,
53+ )
54+ if response .json :
55+ self ._access_token = cast (str , response .json ["access_token" ])
56+ self ._expires_at = getnow () + datetime .timedelta (
57+ seconds = response .json ["expires_in" ]
58+ )
59+
60+ async def _generate_local_access_token (self ) -> None :
61+ check_name = "zimfarm-api-authentication"
62+ if self ._refresh_token :
63+ response = await query_api (
64+ f"{ Context .cms_api_url } /auth/refresh" ,
65+ method = "POST" ,
66+ payload = {
67+ "refresh_token" : self ._refresh_token ,
68+ },
69+ timeout = Context .requests_timeout ,
70+ check_name = check_name ,
71+ )
72+ else :
73+ response = await query_api (
74+ f"{ Context .cms_api_url } /auth/authorize" ,
75+ method = "POST" ,
76+ payload = {
77+ "username" : Context .cms_username ,
78+ "password" : Context .cms_password ,
79+ },
80+ timeout = Context .requests_timeout ,
81+ check_name = check_name ,
82+ )
83+
84+ if response .json :
85+ self ._access_token = cast (str , response .json ["access_token" ])
86+ self ._refresh_token = cast (str , response .json ["refresh_token" ])
87+ self ._expires_at = datetime .datetime .fromisoformat (
88+ response .json ["expires_time" ]
89+ ).replace (tzinfo = None )
90+
91+ async def get_access_token (self , * , force_refresh : bool = False ) -> str :
92+ """Retrieve or generate access token depending on if token has expired."""
93+ now = getnow ()
94+ if (
95+ force_refresh
96+ or self ._access_token is None
97+ or now >= (self ._expires_at - Context .cms_token_renewal_window )
98+ ):
99+ if Context .auth_mode == "oauth" :
100+ await self ._generate_oauth_access_token ()
101+ elif Context .auth_mode == "local" :
102+ await self ._generate_local_access_token ()
103+ else :
104+ raise ValueError (
105+ f"Unknown authentication mode: { Context .auth_mode } . "
106+ "Allowed values are: 'local', 'oauth'"
107+ )
108+ if self ._access_token is None :
109+ raise ValueError ("Failed to generate access token." )
110+ return self ._access_token
111+
112+ @property
113+ def expires_at (self ) -> datetime .datetime :
114+ return self ._expires_at
115+
116+ @property
117+ def refresh_token (self ) -> str | None :
118+ return self ._refresh_token
119+
120+
121+ _token_provider = ClientTokenProvider ()
122+
123+
19124async def authenticate () -> Result [Token ]:
20- """Check if authentication is sucessful with CMS"""
21- response = await query_api (
22- f"{ Context .cms_api_url } /auth/authorize" ,
23- method = "POST" ,
24- payload = {"username" : Context .cms_username , "password" : Context .cms_password },
25- check_name = "cms-api-authentication" ,
26- )
27- return Result (
28- success = response .success ,
29- status_code = response .status_code ,
30- data = Token .model_validate (response .json ) if response .success else None ,
31- )
125+ """Check if authentication is successful with CMS API"""
126+ try :
127+ access_token = await _token_provider .get_access_token ()
128+ token = Token (
129+ access_token = access_token ,
130+ expires_time = _token_provider .expires_at ,
131+ refresh_token = _token_provider .refresh_token or "" ,
132+ )
133+
134+ response = await query_api (
135+ f"{ Context .cms_api_url } /auth/me" ,
136+ method = "GET" ,
137+ headers = {"Authorization" : f"Bearer { token .access_token } " },
138+ check_name = "zimfarm-api-authentication" ,
139+ )
140+
141+ if response .success :
142+ logger .debug (
143+ f"Authentication successful using { Context .auth_mode } mode" ,
144+ extra = {"checkname" : "zimfarm-api-authentication" },
145+ )
146+
147+ return Result (
148+ success = True ,
149+ status_code = HTTPStatus .OK ,
150+ data = token ,
151+ )
152+ else :
153+ return Result (
154+ success = False ,
155+ status_code = HTTPStatus .UNAUTHORIZED ,
156+ data = None ,
157+ )
158+ except Exception :
159+ return Result (
160+ success = False ,
161+ status_code = HTTPStatus .INTERNAL_SERVER_ERROR ,
162+ data = None ,
163+ )
0 commit comments