Skip to content

Commit d4f4f69

Browse files
appleboyclaude
andcommitted
refactor(sdk): deduplicate shared logic and fix type annotations
- Extract _parse_error_response into oauth/_parsing.py shared by sync and async clients - Extract _validate_and_enrich and _copy_metadata helpers in discovery to eliminate duplication - Replace copy.deepcopy with dataclasses.replace for Metadata copying - Extract _is_token_valid module function in clientcreds/token_source.py - Extract _handle_poll_error helper in authflow/device.py for polling error classification - Add in-memory token cache to authflow/token_source.py to avoid redundant store reads - Fix _with_file_lock type annotation from object to Callable[[], None] - Remove unused request parameter from Django middleware _ensure_configured Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent a8b7e68 commit d4f4f69

10 files changed

Lines changed: 130 additions & 123 deletions

File tree

src/authgate/authflow/device.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,23 @@ def run_device_flow(
4747
return _poll_device_code(client, auth)
4848

4949

50+
def _handle_poll_error(exc: OAuthError) -> str:
51+
"""Classify a device-code polling error.
52+
53+
Returns ``"continue"`` or ``"slow_down"`` for retriable errors.
54+
Raises the appropriate exception for terminal errors.
55+
"""
56+
if exc.code == "authorization_pending":
57+
return "continue"
58+
if exc.code == "slow_down":
59+
return "slow_down"
60+
if exc.code == "expired_token":
61+
raise TokenExpiredError("device code expired") from exc
62+
if exc.code == "access_denied":
63+
raise AccessDeniedError("access denied by user") from exc
64+
raise AuthFlowError(f"exchange device code: {exc}") from exc
65+
66+
5067
def _poll_device_code(client: OAuthClient, auth: DeviceAuth) -> Token:
5168
"""Poll the token endpoint until the user authorizes or the code expires."""
5269
interval = max(auth.interval, 5)
@@ -61,16 +78,9 @@ def _poll_device_code(client: OAuthClient, auth: DeviceAuth) -> Token:
6178
try:
6279
return client.exchange_device_code(auth.device_code)
6380
except OAuthError as exc:
64-
if exc.code == "authorization_pending":
65-
continue
66-
if exc.code == "slow_down":
81+
signal = _handle_poll_error(exc)
82+
if signal == "slow_down":
6783
interval += 5
68-
continue
69-
if exc.code == "expired_token":
70-
raise TokenExpiredError("device code expired") from exc
71-
if exc.code == "access_denied":
72-
raise AccessDeniedError("access denied by user") from exc
73-
raise AuthFlowError(f"exchange device code: {exc}") from exc
7484

7585

7686
async def async_run_device_flow(
@@ -108,13 +118,6 @@ async def _async_poll_device_code(client: AsyncOAuthClient, auth: DeviceAuth) ->
108118
try:
109119
return await client.exchange_device_code(auth.device_code)
110120
except OAuthError as exc:
111-
if exc.code == "authorization_pending":
112-
continue
113-
if exc.code == "slow_down":
121+
signal = _handle_poll_error(exc)
122+
if signal == "slow_down":
114123
interval += 5
115-
continue
116-
if exc.code == "expired_token":
117-
raise TokenExpiredError("device code expired") from exc
118-
if exc.code == "access_denied":
119-
raise AccessDeniedError("access denied by user") from exc
120-
raise AuthFlowError(f"exchange device code: {exc}") from exc

src/authgate/authflow/token_source.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,18 +46,25 @@ def __init__(
4646
self._client = client
4747
self._store = store
4848
self._lock = threading.RLock()
49+
self._cached: Token | None = None
4950
self._inflight: threading.Event | None = None
5051
self._inflight_result: Token | None = None
5152
self._inflight_error: Exception | None = None
5253

5354
def token(self) -> Token:
5455
"""Return a valid token, refreshing from store or server as needed."""
55-
# Fast path: check for valid cached token
56+
# Fast path: check in-memory cache first
57+
if self._cached is not None and not self._cached.is_expired():
58+
return self._cached
59+
60+
# Check persistent store
5661
if self._store is not None:
5762
try:
5863
stored = self._store.load(self._client.client_id)
5964
if stored.is_valid():
60-
return _credstore_to_oauth(stored)
65+
tok = _credstore_to_oauth(stored)
66+
self._cached = tok
67+
return tok
6168
except NotFoundError:
6269
pass
6370

@@ -100,12 +107,15 @@ def _do_refresh(self) -> Token:
100107
try:
101108
stored = self._store.load(self._client.client_id)
102109
if stored.is_valid():
103-
return _credstore_to_oauth(stored)
110+
tok = _credstore_to_oauth(stored)
111+
self._cached = tok
112+
return tok
104113

105114
# Try refresh if we have a refresh token
106115
if stored.refresh_token:
107116
refreshed = self._client.refresh_token(stored.refresh_token)
108117
self._save_token(refreshed)
118+
self._cached = refreshed
109119
return refreshed
110120
except NotFoundError:
111121
pass
@@ -115,6 +125,7 @@ def _do_refresh(self) -> Token:
115125
def save_token(self, token: Token) -> None:
116126
"""Persist a token to the store (if configured)."""
117127
with self._lock:
128+
self._cached = token
118129
self._save_token(token)
119130

120131
def _save_token(self, token: Token) -> None:

src/authgate/clientcreds/token_source.py

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,15 @@
1313
_DEFAULT_EXPIRY_DELTA = 30.0 # seconds
1414

1515

16+
def _is_token_valid(token: Token | None, expiry_delta: float) -> bool:
17+
"""Check whether a token is present and not expired."""
18+
if token is None or not token.access_token:
19+
return False
20+
if token.expires_at == 0:
21+
return True
22+
return (time.time() + expiry_delta) < token.expires_at
23+
24+
1625
class TokenSource:
1726
"""Thread-safe, auto-caching token source for client credentials (sync).
1827
@@ -39,23 +48,16 @@ def token(self) -> Token:
3948
"""Return a valid access token, fetching a new one if expired."""
4049
# Fast path
4150
with self._lock:
42-
if self._token is not None and self._is_valid():
43-
return self._token
51+
if _is_token_valid(self._token, self._expiry_delta):
52+
return self._token # type: ignore[return-value]
4453

4554
return self._slow_path()
4655

47-
def _is_valid(self) -> bool:
48-
if self._token is None or not self._token.access_token:
49-
return False
50-
if self._token.expires_at == 0:
51-
return True
52-
return (time.time() + self._expiry_delta) < self._token.expires_at
53-
5456
def _slow_path(self) -> Token:
5557
with self._lock:
5658
# Re-check after acquiring lock
57-
if self._token is not None and self._is_valid():
58-
return self._token
59+
if _is_token_valid(self._token, self._expiry_delta):
60+
return self._token # type: ignore[return-value]
5961

6062
if self._inflight is not None:
6163
event = self._inflight
@@ -104,18 +106,11 @@ def __init__(
104106

105107
async def token(self) -> Token:
106108
"""Return a valid access token, fetching a new one if expired."""
107-
if self._token is not None and self._is_valid():
108-
return self._token
109+
if _is_token_valid(self._token, self._expiry_delta):
110+
return self._token # type: ignore[return-value]
109111

110112
async with self._lock:
111-
if self._token is not None and self._is_valid():
112-
return self._token
113+
if _is_token_valid(self._token, self._expiry_delta):
114+
return self._token # type: ignore[return-value]
113115
self._token = await self._client.client_credentials(self._scopes)
114116
return self._token
115-
116-
def _is_valid(self) -> bool:
117-
if self._token is None or not self._token.access_token:
118-
return False
119-
if self._token.expires_at == 0:
120-
return True
121-
return (time.time() + self._expiry_delta) < self._token.expires_at

src/authgate/credstore/file_store.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import json
77
import os
88
import time
9+
from collections.abc import Callable
910
from typing import Generic, TypeVar
1011

1112
from authgate.credstore.protocols import Codec
@@ -143,10 +144,10 @@ def _ensure_dir(self) -> None:
143144
if parent:
144145
os.makedirs(parent, mode=0o700, exist_ok=True)
145146

146-
def _with_file_lock(self, fn: object) -> None:
147+
def _with_file_lock(self, fn: Callable[[], None]) -> None:
147148
lock = _FileLock(self._file_path + ".lock")
148149
lock.acquire()
149150
try:
150-
fn() # type: ignore[operator]
151+
fn()
151152
finally:
152153
lock.release()

src/authgate/discovery/async_client.py

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,20 @@
33
from __future__ import annotations
44

55
import asyncio
6-
import copy
76
import time
87

98
import httpx
109

11-
from authgate.discovery.client import _parse_metadata
10+
from authgate.discovery.client import (
11+
_DEFAULT_CACHE_TTL,
12+
_WELL_KNOWN_PATH,
13+
_copy_metadata,
14+
_parse_metadata,
15+
_validate_and_enrich,
16+
)
1217
from authgate.discovery.models import Metadata
1318
from authgate.exceptions import DiscoveryError
1419

15-
_WELL_KNOWN_PATH = "/.well-known/openid-configuration"
16-
_DEFAULT_CACHE_TTL = 3600.0
17-
1820

1921
class AsyncDiscoveryClient:
2022
"""OIDC discovery client with caching (async)."""
@@ -36,13 +38,13 @@ def __init__(
3638
async def fetch(self) -> Metadata:
3739
"""Retrieve the OIDC provider metadata, using the cache if still valid."""
3840
if self._cached is not None and (time.time() - self._fetched_at) < self._cache_ttl:
39-
return copy.deepcopy(self._cached)
41+
return _copy_metadata(self._cached)
4042
return await self._refresh()
4143

4244
async def _refresh(self) -> Metadata:
4345
async with self._lock:
4446
if self._cached is not None and (time.time() - self._fetched_at) < self._cache_ttl:
45-
return copy.deepcopy(self._cached)
47+
return _copy_metadata(self._cached)
4648

4749
url = self._issuer_url + _WELL_KNOWN_PATH
4850
resp = await self._http.get(url)
@@ -51,20 +53,8 @@ async def _refresh(self) -> Metadata:
5153

5254
body = resp.json()
5355
meta = _parse_metadata(body)
54-
55-
issuer = meta.issuer.rstrip("/")
56-
if issuer != self._issuer_url:
57-
raise DiscoveryError(
58-
f"discovery: issuer mismatch: got {meta.issuer!r},"
59-
f" expected {self._issuer_url!r}"
60-
)
61-
62-
if not meta.device_authorization_endpoint:
63-
meta.device_authorization_endpoint = issuer + "/oauth/device/code"
64-
65-
if not meta.introspection_endpoint:
66-
meta.introspection_endpoint = issuer + "/oauth/introspect"
56+
_validate_and_enrich(meta, self._issuer_url)
6757

6858
self._cached = meta
6959
self._fetched_at = time.time()
70-
return copy.deepcopy(meta)
60+
return _copy_metadata(meta)

src/authgate/discovery/client.py

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
import copy
5+
import dataclasses
66
import threading
77
import time
88

@@ -15,6 +15,35 @@
1515
_DEFAULT_CACHE_TTL = 3600.0 # 1 hour
1616

1717

18+
def _copy_metadata(meta: Metadata) -> Metadata:
19+
"""Return a shallow copy of Metadata with independent list fields."""
20+
return dataclasses.replace(
21+
meta,
22+
response_types_supported=list(meta.response_types_supported),
23+
subject_types_supported=list(meta.subject_types_supported),
24+
id_token_signing_alg_values_supported=list(meta.id_token_signing_alg_values_supported),
25+
scopes_supported=list(meta.scopes_supported),
26+
token_endpoint_auth_methods_supported=list(meta.token_endpoint_auth_methods_supported),
27+
grant_types_supported=list(meta.grant_types_supported),
28+
claims_supported=list(meta.claims_supported),
29+
code_challenge_methods_supported=list(meta.code_challenge_methods_supported),
30+
)
31+
32+
33+
def _validate_and_enrich(meta: Metadata, expected_issuer: str) -> Metadata:
34+
"""Validate issuer and fill in default endpoint URLs."""
35+
issuer = meta.issuer.rstrip("/")
36+
if issuer != expected_issuer:
37+
raise DiscoveryError(
38+
f"discovery: issuer mismatch: got {meta.issuer!r}, expected {expected_issuer!r}"
39+
)
40+
if not meta.device_authorization_endpoint:
41+
meta.device_authorization_endpoint = issuer + "/oauth/device/code"
42+
if not meta.introspection_endpoint:
43+
meta.introspection_endpoint = issuer + "/oauth/introspect"
44+
return meta
45+
46+
1847
class DiscoveryClient:
1948
"""OIDC discovery client with caching (sync)."""
2049

@@ -39,13 +68,13 @@ def fetch(self) -> Metadata:
3968
"""
4069
with self._lock:
4170
if self._cached is not None and (time.time() - self._fetched_at) < self._cache_ttl:
42-
return copy.deepcopy(self._cached)
71+
return _copy_metadata(self._cached)
4372
return self._refresh()
4473

4574
def _refresh(self) -> Metadata:
4675
with self._lock:
4776
if self._cached is not None and (time.time() - self._fetched_at) < self._cache_ttl:
48-
return copy.deepcopy(self._cached)
77+
return _copy_metadata(self._cached)
4978

5079
url = self._issuer_url + _WELL_KNOWN_PATH
5180
resp = self._http.get(url)
@@ -54,23 +83,11 @@ def _refresh(self) -> Metadata:
5483

5584
body = resp.json()
5685
meta = _parse_metadata(body)
57-
58-
issuer = meta.issuer.rstrip("/")
59-
if issuer != self._issuer_url:
60-
raise DiscoveryError(
61-
f"discovery: issuer mismatch: got {meta.issuer!r},"
62-
f" expected {self._issuer_url!r}"
63-
)
64-
65-
if not meta.device_authorization_endpoint:
66-
meta.device_authorization_endpoint = issuer + "/oauth/device/code"
67-
68-
if not meta.introspection_endpoint:
69-
meta.introspection_endpoint = issuer + "/oauth/introspect"
86+
_validate_and_enrich(meta, self._issuer_url)
7087

7188
self._cached = meta
7289
self._fetched_at = time.time()
73-
return copy.deepcopy(meta)
90+
return _copy_metadata(meta)
7491

7592

7693
def _get_str(body: dict[str, object], key: str) -> str:

src/authgate/middleware/django.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]) -> None:
4040
self._mode = ValidationMode.TOKEN_INFO
4141
self._required_scopes: list[str] = []
4242

43-
def _ensure_configured(self, request: HttpRequest) -> None:
43+
def _ensure_configured(self) -> None:
4444
if self._client is not None:
4545
return
4646
from django.conf import settings
@@ -50,7 +50,7 @@ def _ensure_configured(self, request: HttpRequest) -> None:
5050
self._required_scopes = getattr(settings, "AUTHGATE_REQUIRED_SCOPES", [])
5151

5252
def __call__(self, request: HttpRequest) -> HttpResponse:
53-
self._ensure_configured(request)
53+
self._ensure_configured()
5454

5555
if self._client is None:
5656
return self.get_response(request)

src/authgate/oauth/_parsing.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
"""Shared OAuth response parsing utilities."""
2+
3+
from __future__ import annotations
4+
5+
import httpx
6+
7+
from authgate.exceptions import OAuthError
8+
9+
10+
def _parse_error_response(resp: httpx.Response) -> OAuthError:
11+
"""Parse an OAuth error response body."""
12+
try:
13+
body = resp.json()
14+
if isinstance(body, dict) and body.get("error"):
15+
return OAuthError(
16+
code=body["error"],
17+
description=body.get("error_description", ""),
18+
status_code=resp.status_code,
19+
)
20+
except Exception:
21+
pass
22+
return OAuthError(
23+
code=resp.reason_phrase or "server_error",
24+
description=resp.text,
25+
status_code=resp.status_code,
26+
)

0 commit comments

Comments
 (0)