|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import base64 |
3 | 4 | import hashlib |
| 5 | +import hmac |
| 6 | +import json |
4 | 7 | import time |
| 8 | +from collections.abc import Mapping |
5 | 9 | from dataclasses import dataclass |
6 | 10 | from enum import Enum |
7 | 11 |
|
|
10 | 14 |
|
11 | 15 | class IdentityProviderType(str, Enum): |
12 | 16 | LOCAL = "local" |
| 17 | + LOCAL_IDP = "local_idp" |
13 | 18 | OIDC = "oidc" |
14 | 19 | ENTRA = "entra" |
15 | 20 | OKTA = "okta" |
@@ -39,6 +44,14 @@ class EntraBridgeConfig: |
39 | 44 | token_ttl_seconds: int = 300 |
40 | 45 |
|
41 | 46 |
|
| 47 | +@dataclass(frozen=True) |
| 48 | +class LocalIdPBridgeConfig: |
| 49 | + issuer: str = "http://localhost/predicate-local-idp" |
| 50 | + audience: str = "api://predicate-authority" |
| 51 | + signing_key: str = "predicate-local-idp-dev-key" |
| 52 | + token_ttl_seconds: int = 300 |
| 53 | + |
| 54 | + |
42 | 55 | class IdentityBridge: |
43 | 56 | """Local bridge implementation for development/local-only mode.""" |
44 | 57 |
|
@@ -120,3 +133,84 @@ def exchange_token( |
120 | 133 | token_type=result.token_type, |
121 | 134 | provider=IdentityProviderType.ENTRA, |
122 | 135 | ) |
| 136 | + |
| 137 | + |
| 138 | +class LocalIdPBridge: |
| 139 | + """Local IdP emulator for dev/offline/air-gapped workflows.""" |
| 140 | + |
| 141 | + def __init__(self, config: LocalIdPBridgeConfig) -> None: |
| 142 | + self._config = config |
| 143 | + |
| 144 | + def exchange_token( |
| 145 | + self, subject: PrincipalRef, state_evidence: StateEvidence |
| 146 | + ) -> TokenExchangeResult: |
| 147 | + expires_at = int(time.time()) + self._config.token_ttl_seconds |
| 148 | + token = self._mint_token( |
| 149 | + subject=subject, |
| 150 | + state_evidence=state_evidence, |
| 151 | + expires_at_epoch_s=expires_at, |
| 152 | + grant_kind="access", |
| 153 | + refresh_token=None, |
| 154 | + ) |
| 155 | + return TokenExchangeResult( |
| 156 | + access_token=token, |
| 157 | + expires_at_epoch_s=expires_at, |
| 158 | + provider=IdentityProviderType.LOCAL_IDP, |
| 159 | + ) |
| 160 | + |
| 161 | + def refresh_token( |
| 162 | + self, refresh_token: str, subject: PrincipalRef, state_evidence: StateEvidence |
| 163 | + ) -> TokenExchangeResult: |
| 164 | + expires_at = int(time.time()) + self._config.token_ttl_seconds |
| 165 | + token = self._mint_token( |
| 166 | + subject=subject, |
| 167 | + state_evidence=state_evidence, |
| 168 | + expires_at_epoch_s=expires_at, |
| 169 | + grant_kind="refresh_access", |
| 170 | + refresh_token=refresh_token, |
| 171 | + ) |
| 172 | + return TokenExchangeResult( |
| 173 | + access_token=token, |
| 174 | + expires_at_epoch_s=expires_at, |
| 175 | + provider=IdentityProviderType.LOCAL_IDP, |
| 176 | + ) |
| 177 | + |
| 178 | + def _mint_token( |
| 179 | + self, |
| 180 | + subject: PrincipalRef, |
| 181 | + state_evidence: StateEvidence, |
| 182 | + expires_at_epoch_s: int, |
| 183 | + grant_kind: str, |
| 184 | + refresh_token: str | None, |
| 185 | + ) -> str: |
| 186 | + header = {"alg": "HS256", "typ": "JWT", "kid": "predicate-local-idp-dev"} |
| 187 | + payload: dict[str, str | int | None] = { |
| 188 | + "iss": self._config.issuer, |
| 189 | + "aud": self._config.audience, |
| 190 | + "sub": subject.principal_id, |
| 191 | + "state_hash": state_evidence.state_hash, |
| 192 | + "state_source": state_evidence.source, |
| 193 | + "token_kind": grant_kind, |
| 194 | + "exp": expires_at_epoch_s, |
| 195 | + "iat": int(time.time()), |
| 196 | + "tenant_id": subject.tenant_id, |
| 197 | + "session_id": subject.session_id, |
| 198 | + "refresh_token_hash": ( |
| 199 | + hashlib.sha256(refresh_token.encode("utf-8")).hexdigest() |
| 200 | + if refresh_token is not None |
| 201 | + else None |
| 202 | + ), |
| 203 | + } |
| 204 | + header_b64 = _b64url_json(header) |
| 205 | + payload_b64 = _b64url_json(payload) |
| 206 | + signing_input = f"{header_b64}.{payload_b64}".encode() |
| 207 | + signature = hmac.new( |
| 208 | + self._config.signing_key.encode("utf-8"), signing_input, hashlib.sha256 |
| 209 | + ).digest() |
| 210 | + signature_b64 = base64.urlsafe_b64encode(signature).rstrip(b"=").decode("utf-8") |
| 211 | + return f"{header_b64}.{payload_b64}.{signature_b64}" |
| 212 | + |
| 213 | + |
| 214 | +def _b64url_json(value: Mapping[str, str | int | None]) -> str: |
| 215 | + encoded = json.dumps(value, separators=(",", ":")).encode("utf-8") |
| 216 | + return base64.urlsafe_b64encode(encoded).rstrip(b"=").decode("utf-8") |
0 commit comments