Skip to content

Commit 1631e9f

Browse files
authored
Merge pull request #218 from openzim/healthcheck-optional-auth
perform authentication healthcheck with oauth token
2 parents 750d6a4 + 90f6836 commit 1631e9f

4 files changed

Lines changed: 166 additions & 16 deletions

File tree

healthcheck/src/healthcheck/context.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import datetime
12
import os
23
from typing import Any
34

@@ -29,8 +30,18 @@ class Context:
2930

3031
cms_api_url = get_mandatory_env("CMS_API_URL")
3132
cms_frontend_url = get_mandatory_env("CMS_FRONTEND_URL")
32-
cms_username = get_mandatory_env("CMS_USERNAME")
33-
cms_password = get_mandatory_env("CMS_PASSWORD")
33+
auth_mode = os.getenv("AUTH_MODE", default="local")
34+
cms_username = os.getenv("CMS_USERNAME", default="")
35+
cms_password = os.getenv("CMS_PASSWORD")
36+
cms_oauth_issuer = os.getenv(
37+
"CMS_OAUTH_ISSUER", default="https://ory.login.kiwix.org"
38+
)
39+
cms_oauth_client_id = os.getenv("CMS_OAUTH_CLIENT_ID", default="")
40+
cms_oauth_client_secret = os.getenv("CMS_OAUTH_CLIENT_SECRET", default="")
41+
cms_oauth_audience_id = os.getenv("CMS_OAUTH_AUDIENCE_ID", default="")
42+
cms_token_renewal_window = datetime.timedelta(
43+
seconds=parse_timespan(os.getenv("CMS_TOKEN_RENEWAL_WINDOW", default="5m"))
44+
)
3445
cms_database_url = get_mandatory_env("CMS_DATABASE_URL")
3546
catalog_generation_timeout = parse_timespan(
3647
os.getenv("CATALOG_GENERATION_TIMEOUT", default="10s")
Lines changed: 144 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
import datetime
2+
from http import HTTPStatus
3+
from typing import cast
24

5+
from aiohttp.helpers import BasicAuth
36
from pydantic import BaseModel
47

58
from healthcheck.context import Context
69
from healthcheck.status import Result
10+
from healthcheck.status import status_logger as logger
711
from healthcheck.status.requests import query_api
812

913

@@ -16,16 +20,144 @@ class Token(BaseModel):
1620
token_type: str = "Bearer"
1721

1822

23+
def getnow():
24+
"""naive UTC now"""
25+
return datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
26+
27+
28+
class ClientTokenProvider:
29+
"""Client to generate access tokens to authenticate with Zimfarm"""
30+
31+
def __init__(self):
32+
self._access_token: str | None = None
33+
self._refresh_token: str | None = None
34+
self._expires_at: datetime.datetime = datetime.datetime.fromtimestamp(
35+
0
36+
).replace(tzinfo=None)
37+
38+
async def _generate_oauth_access_token(self) -> None:
39+
"""Generate oauth access token and update expires_at."""
40+
41+
response = await query_api(
42+
f"{Context.cms_oauth_issuer}/oauth2/token",
43+
method="POST",
44+
data={
45+
"grant_type": "client_credentials",
46+
"audience": Context.cms_oauth_audience_id,
47+
},
48+
auth=BasicAuth(
49+
Context.cms_oauth_client_id, Context.cms_oauth_client_secret
50+
),
51+
timeout=Context.requests_timeout,
52+
check_name="zimfarm-api-authentication",
53+
)
54+
if response.json:
55+
self._access_token = cast(str, response.json["access_token"])
56+
self._expires_at = getnow() + datetime.timedelta(
57+
seconds=response.json["expires_in"]
58+
)
59+
60+
async def _generate_local_access_token(self) -> None:
61+
check_name = "zimfarm-api-authentication"
62+
if self._refresh_token:
63+
response = await query_api(
64+
f"{Context.cms_api_url}/auth/refresh",
65+
method="POST",
66+
payload={
67+
"refresh_token": self._refresh_token,
68+
},
69+
timeout=Context.requests_timeout,
70+
check_name=check_name,
71+
)
72+
else:
73+
response = await query_api(
74+
f"{Context.cms_api_url}/auth/authorize",
75+
method="POST",
76+
payload={
77+
"username": Context.cms_username,
78+
"password": Context.cms_password,
79+
},
80+
timeout=Context.requests_timeout,
81+
check_name=check_name,
82+
)
83+
84+
if response.json:
85+
self._access_token = cast(str, response.json["access_token"])
86+
self._refresh_token = cast(str, response.json["refresh_token"])
87+
self._expires_at = datetime.datetime.fromisoformat(
88+
response.json["expires_time"]
89+
).replace(tzinfo=None)
90+
91+
async def get_access_token(self, *, force_refresh: bool = False) -> str:
92+
"""Retrieve or generate access token depending on if token has expired."""
93+
now = getnow()
94+
if (
95+
force_refresh
96+
or self._access_token is None
97+
or now >= (self._expires_at - Context.cms_token_renewal_window)
98+
):
99+
if Context.auth_mode == "oauth":
100+
await self._generate_oauth_access_token()
101+
elif Context.auth_mode == "local":
102+
await self._generate_local_access_token()
103+
else:
104+
raise ValueError(
105+
f"Unknown authentication mode: {Context.auth_mode}. "
106+
"Allowed values are: 'local', 'oauth'"
107+
)
108+
if self._access_token is None:
109+
raise ValueError("Failed to generate access token.")
110+
return self._access_token
111+
112+
@property
113+
def expires_at(self) -> datetime.datetime:
114+
return self._expires_at
115+
116+
@property
117+
def refresh_token(self) -> str | None:
118+
return self._refresh_token
119+
120+
121+
# Module-level provider shared by authenticate(), so the token cached on the
# instance is reused across calls instead of being requested every time.
_token_provider = ClientTokenProvider()
122+
123+
19124
async def authenticate() -> Result[Token]:
    """Check if authentication is successful with CMS API

    Obtains an access token from the module-level token provider, then calls
    ``/auth/me`` on the CMS API with it as a Bearer token.

    Returns:
        A successful ``Result`` (status 200) carrying the ``Token`` when
        ``/auth/me`` succeeds; a failed ``Result`` with status 401 when it
        does not; a failed ``Result`` with status 500 when obtaining the token
        or issuing the request raises.
    """
    check_name = "zimfarm-api-authentication"
    try:
        access_token = await _token_provider.get_access_token()
        token = Token(
            access_token=access_token,
            expires_time=_token_provider.expires_at,
            # OAuth mode never sets a refresh token; an empty string is used.
            refresh_token=_token_provider.refresh_token or "",
        )

        response = await query_api(
            f"{Context.cms_api_url}/auth/me",
            method="GET",
            headers={"Authorization": f"Bearer {token.access_token}"},
            check_name=check_name,
        )

        if response.success:
            logger.debug(
                f"Authentication successful using {Context.auth_mode} mode",
                extra={"checkname": check_name},
            )

            return Result(
                success=True,
                status_code=HTTPStatus.OK,
                data=token,
            )
        else:
            return Result(
                success=False,
                status_code=HTTPStatus.UNAUTHORIZED,
                data=None,
            )
    except Exception:
        # Healthcheck boundary: report the failure as a Result rather than
        # propagating, but record the traceback so the cause (bad auth mode,
        # token generation failure, network error, ...) is not lost.
        logger.exception(
            f"Authentication check failed using {Context.auth_mode} mode",
            extra={"checkname": check_name},
        )
        return Result(
            success=False,
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            data=None,
        )

healthcheck/src/healthcheck/status/collections.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ async def check_catalog_generation() -> Result[CatalogStatus]:
7878
if failures:
7979
logger.error(
8080
f"Failed to generate catalogs for the following collections: "
81-
f"{','.join([failure.name for failure in failures])}"
81+
f"{','.join([failure.name for failure in failures])}",
82+
extra={"checkname": "cms-check-catalog-generation"},
8283
)
8384

8485
return Result(

healthcheck/src/healthcheck/status/requests.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Any
55

66
import aiohttp
7+
from aiohttp.helpers import BasicAuth
78

89
from healthcheck.context import Context
910
from healthcheck.status import status_logger as logger
@@ -28,14 +29,17 @@ async def query_api(
2829
payload: dict[str, Any] | None = None,
2930
params: dict[str, Any] | None = None,
3031
timeout: float = Context.requests_timeout,
32+
auth: BasicAuth | None = None,
33+
data: dict[str, Any] | None = None,
3134
) -> Response:
3235
req_headers: dict[str, Any] = {}
3336
req_headers.update(headers if headers else {})
3437

3538
# Log request details
3639
logger.debug(
3740
f"Sending request: method={method.upper()}, url={url}, "
38-
f"headers={req_headers}, params={params}, body={payload}",
41+
f"headers={req_headers}, params={params}, body={payload}, "
42+
f"data={data}",
3943
extra={"checkname": check_name},
4044
)
4145

@@ -48,6 +52,8 @@ async def query_api(
4852
json=payload,
4953
params=params,
5054
timeout=aiohttp.ClientTimeout(total=timeout),
55+
auth=auth,
56+
data=data,
5157
) as resp:
5258
try:
5359
text = await resp.text()

0 commit comments

Comments
 (0)