diff --git a/.github/actions/conformance/client.py b/.github/actions/conformance/client.py index 58f684f01..4c0f1d841 100644 --- a/.github/actions/conformance/client.py +++ b/.github/actions/conformance/client.py @@ -5,7 +5,7 @@ Contract: - MCP_CONFORMANCE_SCENARIO env var -> scenario name - - MCP_CONFORMANCE_CONTEXT env var -> optional JSON (for client-credentials scenarios) + - MCP_CONFORMANCE_CONTEXT env var -> optional JSON (for auth scenarios) - Server URL as last CLI argument (sys.argv[1]) - Must exit 0 within 30 seconds @@ -16,7 +16,16 @@ elicitation-sep1034-client-defaults - Elicitation with default accept callback auth/client-credentials-jwt - Client credentials with private_key_jwt auth/client-credentials-basic - Client credentials with client_secret_basic + auth/cross-app-access-complete-flow - Enterprise managed OAuth (SEP-990) - v0.1.14+ auth/* - Authorization code flow (default for auth scenarios) + +Enterprise Auth (SEP-990): + The conformance package v0.1.14+ (https://github.com/modelcontextprotocol/conformance/pull/110) + provides the scenario 'auth/cross-app-access-complete-flow' which tests the complete + enterprise managed OAuth flow: IDP ID token → ID-JAG → access token. + + The client receives test context (idp_id_token, idp_token_endpoint, etc.) via + MCP_CONFORMANCE_CONTEXT environment variable and performs the token exchange flows automatically. """ import asyncio @@ -314,9 +323,100 @@ async def run_auth_code_client(server_url: str) -> None: await _run_auth_session(server_url, oauth_auth) +@register("auth/cross-app-access-complete-flow") +async def run_cross_app_access_complete_flow(server_url: str) -> None: + """Enterprise managed auth: Complete SEP-990 flow (OIDC ID token → ID-JAG → access token). + + This scenario is provided by @modelcontextprotocol/conformance@0.1.14+ (PR #110). + It tests the complete enterprise managed OAuth flow using token exchange (RFC 8693) + and JWT bearer grant (RFC 7523). + """ + from mcp.client.auth.extensions.enterprise_managed_auth import ( + EnterpriseAuthOAuthClientProvider, + TokenExchangeParameters, + ) + + context = get_conformance_context() + # The conformance package provides these fields + idp_id_token = context.get("idp_id_token") + idp_token_endpoint = context.get("idp_token_endpoint") + idp_issuer = context.get("idp_issuer") + + # For cross-app access, we need to determine the MCP server's resource ID and auth issuer + # The conformance package sets up the auth server, and the MCP server URL is passed to us + + if not idp_id_token: + raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'idp_id_token'") + if not idp_token_endpoint: + raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'idp_token_endpoint'") + if not idp_issuer: + raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'idp_issuer'") + + # Extract base URL by stripping trailing /mcp path (Python 3.9+). + # The conformance harness always serves MCP at /mcp, so stripping + # the suffix gives us the auth-server base URL for fallback defaults. + base_url = server_url.removesuffix("/mcp") + auth_issuer = context.get("auth_issuer", base_url) + resource_id = context.get("resource_id", server_url) + + logger.debug("Cross-app access flow:") + logger.debug(f" IDP Issuer: {idp_issuer}") + logger.debug(f" IDP Token Endpoint: {idp_token_endpoint}") + logger.debug(f" Auth Issuer: {auth_issuer}") + logger.debug(f" Resource ID: {resource_id}") + + # Create token exchange parameters from IDP ID token + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=idp_id_token, + mcp_server_auth_issuer=auth_issuer, + mcp_server_resource_id=resource_id, + scope=context.get("scope"), + ) + + # Get pre-configured client credentials from context (if provided) + client_id = context.get("client_id") + client_secret = context.get("client_secret") + + storage = InMemoryTokenStorage() + + # Create enterprise auth provider + enterprise_auth = EnterpriseAuthOAuthClientProvider( + server_url=server_url, + client_metadata=OAuthClientMetadata( + client_name="conformance-cross-app-client", + redirect_uris=[AnyUrl("http://localhost:3000/callback")], + grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"], + response_types=["token"], + ), + storage=storage, + idp_token_endpoint=idp_token_endpoint, + token_exchange_params=token_exchange_params, + ) + + # If client credentials are provided in context, use them instead of dynamic registration + if client_id and client_secret: + from mcp.shared.auth import OAuthClientInformationFull + + logger.debug(f"Using pre-configured client credentials: {client_id}") + client_info = OAuthClientInformationFull( + client_id=client_id, + client_secret=client_secret, + token_endpoint_auth_method="client_secret_basic", + grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"], + response_types=["token"], + redirect_uris=[AnyUrl("http://localhost:3000/callback")], + ) + enterprise_auth.context.client_info = client_info + await storage.set_client_info(client_info) + + await _run_auth_session(server_url, enterprise_auth) + + async def _run_auth_session(server_url: str, oauth_auth: OAuthClientProvider) -> None: """Common session logic for all OAuth flows.""" - client = httpx.AsyncClient(auth=oauth_auth, timeout=30.0) + # Allow timeout to be configured via environment variable for different test scenarios + timeout = float(os.environ.get("MCP_CONFORMANCE_TIMEOUT", "30.0")) + client = httpx.AsyncClient(auth=oauth_auth, timeout=timeout) async with streamable_http_client(url=server_url, http_client=client) as (read_stream, write_stream): async with ClientSession( read_stream, write_stream, elicitation_callback=default_elicitation_callback diff --git a/.github/workflows/conformance.yml b/.github/workflows/conformance.yml index d876da00b..7af927ddf 100644 --- a/.github/workflows/conformance.yml +++ b/.github/workflows/conformance.yml @@ -42,4 +42,4 @@ jobs: with: node-version: 24 - run: uv sync --frozen --all-extras --package mcp - - run: npx @modelcontextprotocol/conformance@0.1.13 client --command 'uv run --frozen python .github/actions/conformance/client.py' --suite all + - run: npx @modelcontextprotocol/conformance@0.1.14 client --command 'uv run --frozen python .github/actions/conformance/client.py' --suite all diff --git a/README.v2.md b/README.v2.md index 55d867586..b0b0b7286 100644 --- a/README.v2.md +++ b/README.v2.md @@ -70,6 +70,7 @@ - [Writing MCP Clients](#writing-mcp-clients) - [Client Display Utilities](#client-display-utilities) - [OAuth Authentication for Clients](#oauth-authentication-for-clients) + - [Enterprise Managed Authorization](#enterprise-managed-authorization) - [Parsing Tool Results](#parsing-tool-results) - [MCP Primitives](#mcp-primitives) - [Server Capabilities](#server-capabilities) @@ -2395,6 +2396,328 @@ _Full example: [examples/snippets/clients/oauth_client.py](https://github.com/mo For a complete working example, see [`examples/clients/simple-auth-client/`](examples/clients/simple-auth-client/). +#### Enterprise Managed Authorization + +The SDK includes support for Enterprise Managed Authorization (SEP-990), which enables MCP clients to connect to protected servers using enterprise Single Sign-On (SSO) systems. This implementation supports: + +- **RFC 8693**: OAuth 2.0 Token Exchange (ID Token -> ID-JAG) +- **RFC 7523**: JSON Web Token (JWT) Profile for OAuth 2.0 Authorization Grants (ID-JAG -> Access Token) +- Integration with enterprise identity providers (Okta, Azure AD, etc.) + +**Key Components:** + +The `EnterpriseAuthOAuthClientProvider` class extends the standard OAuth provider to implement the enterprise authorization flow: + +**Token Exchange Flow:** + +1. **Obtain ID Token** from your enterprise IdP (e.g., Okta, Azure AD) +2. **Exchange ID Token for ID-JAG** using RFC 8693 Token Exchange +3. **Exchange ID-JAG for Access Token** using RFC 7523 JWT Bearer Grant +4. **Use Access Token** to call protected MCP server tools + +**Using the Access Token with MCP Server:** + +1. Once you have obtained the access token, you can use it to authenticate requests to the MCP server +2. The access token is automatically included in all subsequent requests to the MCP server, allowing you to access protected tools and resources based on your enterprise identity and permissions. + +**Handling Token Expiration and Refresh:** + +Access tokens have a limited lifetime and will expire. When tokens expire: + +- **Check Token Expiration**: Use the `expires_in` field to determine when the token expires +- **Refresh Flow**: When expired, repeat the token exchange flow with a fresh ID token from your IdP +- **Automatic Refresh**: Implement automatic token refresh before expiration (recommended for production) +- **Error Handling**: Catch authentication errors and retry with refreshed tokens + +**Important Notes:** + +- **ID Token Expiration**: If the ID token from your IdP expires, you must re-authenticate with the IdP to obtain a new ID token before performing token exchange +- **Token Storage**: Store tokens securely and implement the `TokenStorage` interface to persist tokens between application restarts +- **Scope Changes**: If you need different scopes, you must obtain a new ID token from the IdP with the required scopes +- **Security**: Never log or expose access tokens or ID tokens in production environments + +**Example Usage:** + + +```python +import asyncio + +import httpx +from pydantic import AnyUrl + +from mcp import ClientSession +from mcp.client.auth import TokenStorage +from mcp.client.auth.extensions import ( + EnterpriseAuthOAuthClientProvider, + TokenExchangeParameters, +) +from mcp.client.streamable_http import streamable_http_client +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken + + +# Placeholder function for IdP authentication +async def get_id_token_from_idp() -> str: + """Placeholder function to get ID token from your IdP. + + In production, implement actual IdP authentication flow. + """ + raise NotImplementedError("Implement your IdP authentication flow here") + + +# Define token storage implementation +class SimpleTokenStorage(TokenStorage): + def __init__(self) -> None: + self._tokens: OAuthToken | None = None + self._client_info: OAuthClientInformationFull | None = None + + async def get_tokens(self) -> OAuthToken | None: + return self._tokens + + async def set_tokens(self, tokens: OAuthToken) -> None: + self._tokens = tokens + + async def get_client_info(self) -> OAuthClientInformationFull | None: + return self._client_info + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: + self._client_info = client_info + + +async def discover_mcp_server_metadata(server_url: str) -> tuple[str, str]: + """Discover MCP server's OAuth metadata and resource identifier. + + Returns: + Tuple of (auth_issuer, resource_id) + """ + from mcp.client.auth.utils import ( + build_oauth_authorization_server_metadata_discovery_urls, + build_protected_resource_metadata_discovery_urls, + handle_auth_metadata_response, + handle_protected_resource_response, + ) + + async with httpx.AsyncClient() as client: + # Step 1: Discover Protected Resource Metadata (PRM) + prm_urls = build_protected_resource_metadata_discovery_urls(None, server_url) + + prm = None + for url in prm_urls: + response = await client.get(url) + prm = await handle_protected_resource_response(response) + if prm: + break + + if not prm: + raise ValueError("Could not discover Protected Resource Metadata") + + # Extract resource identifier and authorization server URL + resource_id = str(prm.resource) + auth_server_url = str(prm.authorization_servers[0]) if prm.authorization_servers else None + + # Step 2: Discover OAuth Authorization Server Metadata + oauth_urls = build_oauth_authorization_server_metadata_discovery_urls(auth_server_url, server_url) + + oauth_metadata = None + for url in oauth_urls: + response = await client.get(url) + ok, asm = await handle_auth_metadata_response(response) + if ok and asm: + oauth_metadata = asm + break + + if not oauth_metadata or not oauth_metadata.issuer: + raise ValueError("Could not discover OAuth metadata or issuer") + + auth_issuer = str(oauth_metadata.issuer) + + return auth_issuer, resource_id + + +async def main() -> None: + """Example demonstrating enterprise managed authorization with MCP.""" + server_url = "https://mcp-server.example.com" + + # Step 1: Get ID token from your IdP (e.g., Okta, Azure AD) + id_token = await get_id_token_from_idp() + + # Step 2: Discover MCP server's OAuth metadata and resource identifier + # This replaces hardcoding these values + mcp_server_auth_issuer, mcp_server_resource_id = await discover_mcp_server_metadata(server_url) + print(f"Discovered auth issuer: {mcp_server_auth_issuer}") + print(f"Discovered resource ID: {mcp_server_resource_id}") + + # Step 3: Configure token exchange parameters using discovered values + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=id_token, + mcp_server_auth_issuer=mcp_server_auth_issuer, + mcp_server_resource_id=mcp_server_resource_id, + scope="mcp:tools mcp:resources", # Optional scopes + ) + + # Step 4: Create enterprise auth provider + enterprise_auth = EnterpriseAuthOAuthClientProvider( + server_url=server_url, + client_metadata=OAuthClientMetadata( + client_name="Enterprise MCP Client", + redirect_uris=[AnyUrl("http://localhost:3000/callback")], + grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"], + response_types=["token"], + ), + storage=SimpleTokenStorage(), + idp_token_endpoint="https://your-idp.com/oauth2/v1/token", # Your IdP's token endpoint + token_exchange_params=token_exchange_params, + # Optional: IdP client credentials if your IdP requires client authentication for token exchange + # idp_client_id="your-idp-client-id", + # idp_client_secret="your-idp-client-secret", + ) + + # Step 5: Create authenticated HTTP client + # The auth provider automatically handles the two-step token exchange: + # 1. ID Token -> ID-JAG (via IDP) + # 2. ID-JAG -> Access Token (via MCP server) + client = httpx.AsyncClient(auth=enterprise_auth, timeout=30.0) + + # Step 6: Connect to MCP server with authenticated client + async with streamable_http_client(url=server_url, http_client=client) as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + + # List available tools + tools_result = await session.list_tools() + print(f"Available tools: {[t.name for t in tools_result.tools]}") + + # Call a tool - auth tokens are automatically managed + if tools_result.tools: + tool_name = tools_result.tools[0].name + result = await session.call_tool(tool_name, {}) + print(f"Tool result: {result.content}") + + # List available resources + resources = await session.list_resources() + for resource in resources.resources: + print(f"Resource: {resource.uri}") + + +async def advanced_manual_flow() -> None: + """Advanced example showing manual token exchange. + + Use cases for manual token exchange: + - Testing and debugging: Inspect ID-JAG claims before exchanging for access token + - Token caching: Store and reuse ID-JAG across multiple MCP server connections + - Custom error handling: Implement specific retry logic for each token exchange step + - Monitoring: Log token exchange metrics and performance + - Token introspection: Validate ID-JAG structure before sending to MCP server + """ + id_token = await get_id_token_from_idp() + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example.com", + mcp_server_resource_id="https://mcp-server.example.com", + ) + + enterprise_auth = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example.com", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:3000/callback")], + ), + storage=SimpleTokenStorage(), + idp_token_endpoint="https://your-idp.com/oauth2/v1/token", + token_exchange_params=token_exchange_params, + ) + + # Manual token exchange (useful for debugging, caching, custom error handling, etc.) + async with httpx.AsyncClient() as client: + # Step 1: Exchange ID token for ID-JAG + id_jag = await enterprise_auth.exchange_token_for_id_jag(client) + # WARNING: Only log tokens in development/testing environments + # In production, NEVER log tokens or token fragments as they are sensitive credentials + print(f"Obtained ID-JAG: {id_jag[:50]}...") + + # Step 2: Build JWT bearer grant request + jwt_bearer_request = await enterprise_auth.exchange_id_jag_for_access_token(id_jag) + print(f"Built JWT bearer grant request to: {jwt_bearer_request.url}") + + # Step 3: Execute the request to get access token + response = await client.send(jwt_bearer_request) + response.raise_for_status() + token_data = response.json() + + access_token = OAuthToken( + access_token=token_data["access_token"], + token_type=token_data["token_type"], + expires_in=token_data.get("expires_in"), + ) + # WARNING: In production, do not log token expiry or any token information + print(f"Access token obtained, expires in: {access_token.expires_in}s") + + # Use the access token for API calls + _ = {"Authorization": f"Bearer {access_token.access_token}"} + # ... make authenticated requests with headers + + +async def token_refresh_example() -> None: + """Example showing how to refresh tokens when they expire. + + When your access token expires, you need to obtain a fresh ID token + from your enterprise IdP and use the refresh helper method. + """ + # Initial setup + id_token = await get_id_token_from_idp() + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example.com", + mcp_server_resource_id="https://mcp-server.example.com", + ) + + enterprise_auth = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example.com", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:3000/callback")], + ), + storage=SimpleTokenStorage(), + idp_token_endpoint="https://your-idp.com/oauth2/v1/token", + token_exchange_params=token_exchange_params, + ) + + _ = httpx.AsyncClient(auth=enterprise_auth, timeout=30.0) + + # Use the client for MCP operations... + # ... time passes and token expires ... + + # When token expires, get a fresh ID token from your IdP + new_id_token = await get_id_token_from_idp() + + # Refresh the authentication using the helper method + await enterprise_auth.refresh_with_new_id_token(new_id_token) + + # Next API call will automatically use the refreshed tokens + # No need to recreate the client or session + + +if __name__ == "__main__": + asyncio.run(main()) +``` + +_Full example: [examples/snippets/clients/enterprise_managed_auth_client.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/clients/enterprise_managed_auth_client.py)_ + + +**Working with SAML Assertions:** + +If your enterprise uses SAML instead of OIDC, you can exchange SAML assertions: + +```python +token_exchange_params = TokenExchangeParameters.from_saml_assertion( + saml_assertion=saml_assertion_string, + mcp_server_auth_issuer="https://your-idp.com", + mcp_server_resource_id="https://mcp-server.example.com", + scope="mcp:tools", +) +``` + +For more details on the enterprise authorization flow, see the [MCP Enterprise Authorization specification](https://modelcontextprotocol.io/specification/2025-06-18/basic/authorization). + ### Parsing Tool Results When calling tools through MCP, the `CallToolResult` object contains the tool's response in a structured format. Understanding how to parse this result is essential for properly handling tool outputs. diff --git a/examples/snippets/clients/enterprise_managed_auth_client.py b/examples/snippets/clients/enterprise_managed_auth_client.py new file mode 100644 index 000000000..3888fdc21 --- /dev/null +++ b/examples/snippets/clients/enterprise_managed_auth_client.py @@ -0,0 +1,258 @@ +import asyncio + +import httpx +from pydantic import AnyUrl + +from mcp import ClientSession +from mcp.client.auth import TokenStorage +from mcp.client.auth.extensions import ( + EnterpriseAuthOAuthClientProvider, + TokenExchangeParameters, +) +from mcp.client.streamable_http import streamable_http_client +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken + + +# Placeholder function for IdP authentication +async def get_id_token_from_idp() -> str: + """Placeholder function to get ID token from your IdP. + + In production, implement actual IdP authentication flow. + """ + raise NotImplementedError("Implement your IdP authentication flow here") + + +# Define token storage implementation +class SimpleTokenStorage(TokenStorage): + def __init__(self) -> None: + self._tokens: OAuthToken | None = None + self._client_info: OAuthClientInformationFull | None = None + + async def get_tokens(self) -> OAuthToken | None: + return self._tokens + + async def set_tokens(self, tokens: OAuthToken) -> None: + self._tokens = tokens + + async def get_client_info(self) -> OAuthClientInformationFull | None: + return self._client_info + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: + self._client_info = client_info + + +async def discover_mcp_server_metadata(server_url: str) -> tuple[str, str]: + """Discover MCP server's OAuth metadata and resource identifier. + + Returns: + Tuple of (auth_issuer, resource_id) + """ + from mcp.client.auth.utils import ( + build_oauth_authorization_server_metadata_discovery_urls, + build_protected_resource_metadata_discovery_urls, + handle_auth_metadata_response, + handle_protected_resource_response, + ) + + async with httpx.AsyncClient() as client: + # Step 1: Discover Protected Resource Metadata (PRM) + prm_urls = build_protected_resource_metadata_discovery_urls(None, server_url) + + prm = None + for url in prm_urls: + response = await client.get(url) + prm = await handle_protected_resource_response(response) + if prm: + break + + if not prm: + raise ValueError("Could not discover Protected Resource Metadata") + + # Extract resource identifier and authorization server URL + resource_id = str(prm.resource) + auth_server_url = str(prm.authorization_servers[0]) if prm.authorization_servers else None + + # Step 2: Discover OAuth Authorization Server Metadata + oauth_urls = build_oauth_authorization_server_metadata_discovery_urls(auth_server_url, server_url) + + oauth_metadata = None + for url in oauth_urls: + response = await client.get(url) + ok, asm = await handle_auth_metadata_response(response) + if ok and asm: + oauth_metadata = asm + break + + if not oauth_metadata or not oauth_metadata.issuer: + raise ValueError("Could not discover OAuth metadata or issuer") + + auth_issuer = str(oauth_metadata.issuer) + + return auth_issuer, resource_id + + +async def main() -> None: + """Example demonstrating enterprise managed authorization with MCP.""" + server_url = "https://mcp-server.example.com" + + # Step 1: Get ID token from your IdP (e.g., Okta, Azure AD) + id_token = await get_id_token_from_idp() + + # Step 2: Discover MCP server's OAuth metadata and resource identifier + # This replaces hardcoding these values + mcp_server_auth_issuer, mcp_server_resource_id = await discover_mcp_server_metadata(server_url) + print(f"Discovered auth issuer: {mcp_server_auth_issuer}") + print(f"Discovered resource ID: {mcp_server_resource_id}") + + # Step 3: Configure token exchange parameters using discovered values + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=id_token, + mcp_server_auth_issuer=mcp_server_auth_issuer, + mcp_server_resource_id=mcp_server_resource_id, + scope="mcp:tools mcp:resources", # Optional scopes + ) + + # Step 4: Create enterprise auth provider + enterprise_auth = EnterpriseAuthOAuthClientProvider( + server_url=server_url, + client_metadata=OAuthClientMetadata( + client_name="Enterprise MCP Client", + redirect_uris=[AnyUrl("http://localhost:3000/callback")], + grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"], + response_types=["token"], + ), + storage=SimpleTokenStorage(), + idp_token_endpoint="https://your-idp.com/oauth2/v1/token", # Your IdP's token endpoint + token_exchange_params=token_exchange_params, + # Optional: IdP client credentials if your IdP requires client authentication for token exchange + # idp_client_id="your-idp-client-id", + # idp_client_secret="your-idp-client-secret", + ) + + # Step 5: Create authenticated HTTP client + # The auth provider automatically handles the two-step token exchange: + # 1. ID Token -> ID-JAG (via IDP) + # 2. ID-JAG -> Access Token (via MCP server) + client = httpx.AsyncClient(auth=enterprise_auth, timeout=30.0) + + # Step 6: Connect to MCP server with authenticated client + async with streamable_http_client(url=server_url, http_client=client) as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + + # List available tools + tools_result = await session.list_tools() + print(f"Available tools: {[t.name for t in tools_result.tools]}") + + # Call a tool - auth tokens are automatically managed + if tools_result.tools: + tool_name = tools_result.tools[0].name + result = await session.call_tool(tool_name, {}) + print(f"Tool result: {result.content}") + + # List available resources + resources = await session.list_resources() + for resource in resources.resources: + print(f"Resource: {resource.uri}") + + +async def advanced_manual_flow() -> None: + """Advanced example showing manual token exchange. + + Use cases for manual token exchange: + - Testing and debugging: Inspect ID-JAG claims before exchanging for access token + - Token caching: Store and reuse ID-JAG across multiple MCP server connections + - Custom error handling: Implement specific retry logic for each token exchange step + - Monitoring: Log token exchange metrics and performance + - Token introspection: Validate ID-JAG structure before sending to MCP server + """ + id_token = await get_id_token_from_idp() + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example.com", + mcp_server_resource_id="https://mcp-server.example.com", + ) + + enterprise_auth = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example.com", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:3000/callback")], + ), + storage=SimpleTokenStorage(), + idp_token_endpoint="https://your-idp.com/oauth2/v1/token", + token_exchange_params=token_exchange_params, + ) + + # Manual token exchange (useful for debugging, caching, custom error handling, etc.) + async with httpx.AsyncClient() as client: + # Step 1: Exchange ID token for ID-JAG + id_jag = await enterprise_auth.exchange_token_for_id_jag(client) + # WARNING: Only log tokens in development/testing environments + # In production, NEVER log tokens or token fragments as they are sensitive credentials + print(f"Obtained ID-JAG: {id_jag[:50]}...") + + # Step 2: Build JWT bearer grant request + jwt_bearer_request = await enterprise_auth.exchange_id_jag_for_access_token(id_jag) + print(f"Built JWT bearer grant request to: {jwt_bearer_request.url}") + + # Step 3: Execute the request to get access token + response = await client.send(jwt_bearer_request) + response.raise_for_status() + token_data = response.json() + + access_token = OAuthToken( + access_token=token_data["access_token"], + token_type=token_data["token_type"], + expires_in=token_data.get("expires_in"), + ) + # WARNING: In production, do not log token expiry or any token information + print(f"Access token obtained, expires in: {access_token.expires_in}s") + + # Use the access token for API calls + _ = {"Authorization": f"Bearer {access_token.access_token}"} + # ... make authenticated requests with headers + + +async def token_refresh_example() -> None: + """Example showing how to refresh tokens when they expire. + + When your access token expires, you need to obtain a fresh ID token + from your enterprise IdP and use the refresh helper method. + """ + # Initial setup + id_token = await get_id_token_from_idp() + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example.com", + mcp_server_resource_id="https://mcp-server.example.com", + ) + + enterprise_auth = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example.com", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:3000/callback")], + ), + storage=SimpleTokenStorage(), + idp_token_endpoint="https://your-idp.com/oauth2/v1/token", + token_exchange_params=token_exchange_params, + ) + + _ = httpx.AsyncClient(auth=enterprise_auth, timeout=30.0) + + # Use the client for MCP operations... + # ... time passes and token expires ... + + # When token expires, get a fresh ID token from your IdP + new_id_token = await get_id_token_from_idp() + + # Refresh the authentication using the helper method + await enterprise_auth.refresh_with_new_id_token(new_id_token) + + # Next API call will automatically use the refreshed tokens + # No need to recreate the client or session + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/mcp/client/auth/extensions/__init__.py b/src/mcp/client/auth/extensions/__init__.py index e69de29bb..f9594864f 100644 --- a/src/mcp/client/auth/extensions/__init__.py +++ b/src/mcp/client/auth/extensions/__init__.py @@ -0,0 +1,33 @@ +"""MCP Client Auth Extensions.""" + +from mcp.client.auth.extensions.client_credentials import ( + ClientCredentialsOAuthProvider, + JWTParameters, + PrivateKeyJWTOAuthProvider, + RFC7523OAuthClientProvider, + SignedJWTParameters, + static_assertion_provider, +) +from mcp.client.auth.extensions.enterprise_managed_auth import ( + EnterpriseAuthOAuthClientProvider, + IDJAGClaims, + IDJAGTokenExchangeResponse, + TokenExchangeParameters, + decode_id_jag, + validate_token_exchange_params, +) + +__all__ = [ + "ClientCredentialsOAuthProvider", + "static_assertion_provider", + "SignedJWTParameters", + "PrivateKeyJWTOAuthProvider", + "JWTParameters", + "RFC7523OAuthClientProvider", + "EnterpriseAuthOAuthClientProvider", + "IDJAGClaims", + "IDJAGTokenExchangeResponse", + "TokenExchangeParameters", + "decode_id_jag", + "validate_token_exchange_params", +] diff --git a/src/mcp/client/auth/extensions/enterprise_managed_auth.py b/src/mcp/client/auth/extensions/enterprise_managed_auth.py new file mode 100644 index 000000000..947f13a9e --- /dev/null +++ b/src/mcp/client/auth/extensions/enterprise_managed_auth.py @@ -0,0 +1,479 @@ +"""Enterprise Managed Authorization extension for MCP (SEP-990). + +Implements RFC 8693 Token Exchange and RFC 7523 JWT Bearer Grant for +enterprise SSO integration. +""" + +import logging +from json import JSONDecodeError + +import httpx +import jwt +from pydantic import BaseModel, Field, ValidationError +from typing_extensions import NotRequired, Required, TypedDict + +from mcp.client.auth import OAuthClientProvider, OAuthFlowError, OAuthTokenError, TokenStorage +from mcp.shared.auth import OAuthClientMetadata + +logger = logging.getLogger(__name__) + + +class TokenExchangeRequestData(TypedDict): + """Type definition for RFC 8693 Token Exchange request data. + + Required fields are those mandated by RFC 8693. + Optional fields (NotRequired) may be included based on IdP requirements. + """ + + grant_type: Required[str] + requested_token_type: Required[str] + audience: Required[str] + resource: Required[str] + subject_token: Required[str] + subject_token_type: Required[str] + scope: NotRequired[str] + client_id: NotRequired[str] + client_secret: NotRequired[str] + + +class TokenExchangeParameters(BaseModel): + """Parameters for RFC 8693 Token Exchange request.""" + + requested_token_type: str = Field( + default="urn:ietf:params:oauth:token-type:id-jag", + description="Type of token being requested (ID-JAG)", + ) + + audience: str = Field( + ..., + description="Issuer URL of the MCP Server's authorization server", + ) + + resource: str = Field( + ..., + description="RFC 9728 Resource Identifier of the MCP Server", + ) + + scope: str | None = Field( + default=None, + description="Space-separated list of scopes being requested", + ) + + subject_token: str = Field( + ..., + description="ID Token or SAML assertion for the end user", + ) + + subject_token_type: str = Field( + ..., + description="Type of subject token (id_token or saml2)", + ) + + @classmethod + def from_id_token( + cls, + id_token: str, + mcp_server_auth_issuer: str, + mcp_server_resource_id: str, + scope: str | None = None, + ) -> "TokenExchangeParameters": + """Create parameters for OIDC ID Token exchange.""" + return cls( + subject_token=id_token, + subject_token_type="urn:ietf:params:oauth:token-type:id_token", + audience=mcp_server_auth_issuer, + resource=mcp_server_resource_id, + scope=scope, + ) + + @classmethod + def from_saml_assertion( + cls, + saml_assertion: str, + mcp_server_auth_issuer: str, + mcp_server_resource_id: str, + scope: str | None = None, + ) -> "TokenExchangeParameters": + """Create parameters for SAML assertion exchange.""" + return cls( + subject_token=saml_assertion, + subject_token_type="urn:ietf:params:oauth:token-type:saml2", + audience=mcp_server_auth_issuer, + resource=mcp_server_resource_id, + scope=scope, + ) + + +class IDJAGTokenExchangeResponse(BaseModel): + """Response from RFC 8693 Token Exchange for ID-JAG.""" + + issued_token_type: str = Field( + ..., + description="Type of token issued (should be id-jag)", + ) + + access_token: str = Field( + ..., + description="The ID-JAG token (named access_token per RFC 8693)", + ) + + token_type: str = Field( + ..., + description="Token type (should be N_A for ID-JAG)", + ) + + scope: str | None = Field( + default=None, + description="Granted scopes", + ) + + expires_in: int | None = Field( + default=None, + description="Lifetime in seconds", + ) + + @property + def id_jag(self) -> str: + """Get the ID-JAG token.""" + return self.access_token + + +class IDJAGClaims(BaseModel): + """Claims structure for Identity Assertion JWT Authorization Grant. + + Note: ``typ`` is sourced from the JWT *header* (not the payload) by + ``decode_id_jag``. It is included here for convenience so callers + can inspect the full ID-JAG structure from a single object. + """ + + model_config = {"extra": "allow"} + + # JWT header + typ: str = Field( + ..., + description="JWT type - must be 'oauth-id-jag+jwt'", + ) + + # Required claims + jti: str = Field(..., description="Unique JWT ID") + iss: str = Field(..., description="IdP issuer URL") + sub: str = Field(..., description="Subject (user) identifier") + aud: str = Field(..., description="MCP Server's auth server issuer") + resource: str = Field(..., description="MCP Server resource identifier") + client_id: str = Field(..., description="MCP Client identifier") + exp: int = Field(..., description="Expiration timestamp") + iat: int = Field(..., description="Issued-at timestamp") + + # Optional claims + scope: str | None = Field(None, description="Space-separated scopes") + email: str | None = Field(None, description="User email") + + +class EnterpriseAuthOAuthClientProvider(OAuthClientProvider): + """OAuth client provider for Enterprise Managed Authorization (SEP-990). + + Implements: + - RFC 8693: Token Exchange (ID Token → ID-JAG) + - RFC 7523: JWT Bearer Grant (ID-JAG → Access Token) + + Concurrency & Thread Safety: + - SAFE: Concurrent requests within a single asyncio event loop. Token + operations are protected by the parent class's ``OAuthContext.lock`` + via ``async_auth_flow``. + - UNSAFE: Sharing a provider instance across multiple OS threads. Each + thread must instantiate its own provider and event loop. + - Note: Ensure any shared ``TokenStorage`` implementation is async-safe. + """ + + def __init__( + self, + server_url: str, + client_metadata: OAuthClientMetadata, + storage: TokenStorage, + idp_token_endpoint: str, + token_exchange_params: TokenExchangeParameters, + timeout: float = 300.0, + idp_client_id: str | None = None, + idp_client_secret: str | None = None, + override_audience_with_issuer: bool = True, + ) -> None: + """Initialize Enterprise Auth OAuth Client. + + Args: + server_url: MCP server URL + client_metadata: OAuth client metadata + storage: Token storage implementation + idp_token_endpoint: Enterprise IdP token endpoint URL + token_exchange_params: Token exchange parameters (not mutated) + timeout: Request timeout in seconds + idp_client_id: Optional client ID registered with the IdP for token exchange + idp_client_secret: Optional client secret registered with the IdP. + Must be accompanied by ``idp_client_id``; providing a secret + without an ID raises ``ValueError``. + override_audience_with_issuer: If True (default), replaces the IdP + audience with the discovered OAuth issuer URL. Set to False for + federated identity setups where the audience must differ. + + Raises: + ValueError: If ``idp_client_secret`` is provided without ``idp_client_id``. + OAuthFlowError: If ``token_exchange_params`` fail validation. + """ + # Validate pure parameters before creating any state (fail-fast) + if idp_client_secret is not None and idp_client_id is None: + raise ValueError( + "idp_client_secret was provided without idp_client_id. Provide both together, or omit the secret." + ) + validate_token_exchange_params(token_exchange_params) + + super().__init__( + server_url=server_url, + client_metadata=client_metadata, + storage=storage, + timeout=timeout, + ) + self.idp_token_endpoint = idp_token_endpoint + # Keep original params immutable; track mutable subject_token separately + self.token_exchange_params = token_exchange_params + self._subject_token = token_exchange_params.subject_token + self.idp_client_id = idp_client_id + self.idp_client_secret = idp_client_secret + self.override_audience_with_issuer = override_audience_with_issuer + + async def exchange_token_for_id_jag( + self, + client: httpx.AsyncClient, + ) -> str: + """Exchange ID Token for ID-JAG using RFC 8693 Token Exchange. + + Args: + client: HTTP client for making requests + + Returns: + The ID-JAG token string + + Raises: + OAuthTokenError: If token exchange fails + """ + logger.debug("Starting token exchange for ID-JAG") + + audience = self.token_exchange_params.audience + if self.override_audience_with_issuer: + # OAuthMetadata.issuer is a required AnyHttpUrl field (RFC 8414), + # so it is always non-None when oauth_metadata is present. + if self.context.oauth_metadata: + discovered_issuer = str(self.context.oauth_metadata.issuer) + if audience != discovered_issuer: + logger.warning( + f"Overriding audience '{audience}' with discovered issuer " + f"'{discovered_issuer}'. To prevent this, pass " + f"override_audience_with_issuer=False." + ) + audience = discovered_issuer + + token_data: TokenExchangeRequestData = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "requested_token_type": self.token_exchange_params.requested_token_type, + "audience": audience, + "resource": self.token_exchange_params.resource, + "subject_token": self._subject_token, + "subject_token_type": self.token_exchange_params.subject_token_type, + } + + if self.token_exchange_params.scope and self.token_exchange_params.scope.strip(): + token_data["scope"] = self.token_exchange_params.scope + + # Add IdP client authentication if provided. + # Sent as POST body parameters (not HTTP Basic) because this is the + # IdP's token-exchange endpoint — most enterprise IdPs (Okta, Azure AD, + # Ping) accept body credentials for token exchange. HTTP Basic is + # allowed by RFC 6749 §2.3.1 but not universally required here. + if self.idp_client_id is not None: + token_data["client_id"] = self.idp_client_id + if self.idp_client_secret is not None: + token_data["client_secret"] = self.idp_client_secret + + try: + response = await client.post( + self.idp_token_endpoint, + data=token_data, + timeout=self.context.timeout, + ) + + if response.status_code != 200: + error_data: dict[str, str] = {} + try: + if response.headers.get("content-type", "").startswith("application/json"): + error_data = response.json() + except JSONDecodeError: + pass + + error: str = error_data.get("error", "unknown_error") + error_description: str = error_data.get( + "error_description", f"Token exchange failed (HTTP {response.status_code})" + ) + raise OAuthTokenError(f"Token exchange failed: {error} - {error_description}") + + token_response = IDJAGTokenExchangeResponse.model_validate_json(response.content) + + if token_response.issued_token_type != "urn:ietf:params:oauth:token-type:id-jag": + raise OAuthTokenError(f"Unexpected token type: {token_response.issued_token_type}") + + if token_response.token_type != "N_A": + logger.warning(f"Expected token_type 'N_A', got '{token_response.token_type}'") + + logger.debug("Successfully obtained ID-JAG") + + return token_response.id_jag + + except httpx.HTTPError as e: + raise OAuthTokenError(f"HTTP error during token exchange: {e}") from e + except ValidationError as e: + raise OAuthTokenError("Invalid token exchange response from IdP") from e + + async def exchange_id_jag_for_access_token( + self, + id_jag: str, + ) -> httpx.Request: + """Build a JWT bearer grant request to exchange an ID-JAG for an access token (RFC 7523). + + This method only *builds* the ``httpx.Request``; it does not execute + it. HTTP execution and error parsing are deferred to the parent + class's ``async_auth_flow`` via ``_handle_token_response``. + + Follows the same pattern as ``ClientCredentialsOAuthProvider._exchange_token_client_credentials`` + and ``RFC7523OAuthClientProvider._exchange_token_jwt_bearer``: + use ``_get_token_endpoint()`` for the URL and ``prepare_token_auth()`` + for client authentication — no manual ``client_id`` injection or + context swapping needed. + + Args: + id_jag: The ID-JAG token obtained from ``exchange_token_for_id_jag`` + + Returns: + An ``httpx.Request`` for the JWT bearer grant + """ + logger.debug("Building JWT bearer grant request for ID-JAG") + + token_data: dict[str, str] = { + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + "assertion": id_jag, + } + + headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"} + + # Delegate client authentication (client_secret_basic, client_secret_post, + # or none) to the parent's context helper — same as every other grant type. + token_data, headers = self.context.prepare_token_auth(token_data, headers) + + # Include resource parameter per RFC 8707 — same guard as every sibling provider + if self.context.should_include_resource_param(self.context.protocol_version): + token_data["resource"] = self.context.get_resource_url() + + # Include scope if configured (may have been updated by parent's async_auth_flow + # from the server's WWW-Authenticate header before _perform_authorization is called) + if self.context.client_metadata.scope: + token_data["scope"] = self.context.client_metadata.scope + + token_url = self._get_token_endpoint() + return httpx.Request("POST", token_url, data=token_data, headers=headers) + + async def _perform_authorization(self) -> httpx.Request: + """Perform enterprise authorization flow. + + Called by the parent's ``async_auth_flow`` when a new access token is needed. + Unconditionally performs full token exchange as the parent already handles + token validity checks. + + Flow: + 1. Exchange IdP subject token for ID-JAG (RFC 8693, direct HTTP call) + 2. Return an ``httpx.Request`` for the JWT bearer grant (RFC 7523) + that the parent will execute and pass to ``_handle_token_response`` + + Returns: + httpx.Request for the JWT bearer grant to the MCP authorization server + """ + # Step 1: Exchange IDP subject token for ID-JAG (RFC 8693) + async with httpx.AsyncClient(timeout=self.context.timeout) as client: + id_jag = await self.exchange_token_for_id_jag(client) + + # Step 2: Build JWT bearer grant request (RFC 7523) + jwt_bearer_request = await self.exchange_id_jag_for_access_token(id_jag) + + logger.debug("Returning JWT bearer grant request to async_auth_flow") + return jwt_bearer_request + + async def refresh_with_new_id_token(self, new_id_token: str) -> None: + """Refresh MCP server access tokens using a fresh ID token from the IdP. + + Updates the subject token and clears cached state so that the next API + request triggers a full re-authentication. Acquires the context lock + to prevent racing with an in-progress ``async_auth_flow``. + + Note: OAuth metadata is not re-discovered. If the MCP server's OAuth + configuration has changed, create a new provider instance instead. + + Warning: This method is NOT safe to call from a different OS thread. + Call it only from the same thread and event loop that owns this + provider instance. + + Args: + new_id_token: Fresh ID token obtained from your enterprise IdP. + """ + async with self.context.lock: + logger.info("Refreshing tokens with new ID token from IdP") + # Update the mutable subject token (does NOT mutate the original params object) + self._subject_token = new_id_token + + # Clear tokens to force full re-exchange on next request + self.context.clear_tokens() + logger.debug("Token refresh prepared — will re-authenticate on next request") + + +def decode_id_jag(id_jag: str) -> IDJAGClaims: + """Decode an ID-JAG token without verification. + + Relies on the receiving server to validate the JWT signature. + + Args: + id_jag: The ID-JAG token string + + Returns: + Decoded ID-JAG claims + """ + claims = jwt.decode(id_jag, options={"verify_signature": False}) + header = jwt.get_unverified_header(id_jag) + claims["typ"] = header.get("typ", "") + + return IDJAGClaims.model_validate(claims) + + +def validate_token_exchange_params( + params: TokenExchangeParameters, +) -> None: + """Validate token exchange parameters beyond Pydantic field constraints. + + Pydantic ``Field(...)`` rejects *missing* values but permits empty strings. + This function adds: + - Empty-string checks for ``subject_token``, ``audience``, ``resource`` + - Allow-list check for ``subject_token_type`` (id_token or saml2) + + Args: + params: Token exchange parameters to validate + + Raises: + OAuthFlowError: If parameters are invalid + """ + if not params.subject_token: + raise OAuthFlowError("subject_token is required") + + if not params.audience: + raise OAuthFlowError("audience is required") + + if not params.resource: + raise OAuthFlowError("resource is required") + + if params.subject_token_type not in { + "urn:ietf:params:oauth:token-type:id_token", + "urn:ietf:params:oauth:token-type:saml2", + }: + raise OAuthFlowError(f"Invalid subject_token_type: {params.subject_token_type}") diff --git a/tests/client/auth/test_enterprise_managed_auth_client.py b/tests/client/auth/test_enterprise_managed_auth_client.py new file mode 100644 index 000000000..e8b8e4b15 --- /dev/null +++ b/tests/client/auth/test_enterprise_managed_auth_client.py @@ -0,0 +1,1792 @@ +"""Tests for Enterprise Managed Authorization client-side implementation.""" + +import logging +import time +import urllib.parse +from typing import Any +from unittest.mock import AsyncMock, Mock, patch + +import httpx +import jwt +import pytest +from pydantic import AnyHttpUrl, AnyUrl + +from mcp.client.auth import OAuthFlowError, OAuthTokenError +from mcp.client.auth.extensions.enterprise_managed_auth import ( + EnterpriseAuthOAuthClientProvider, + IDJAGClaims, + IDJAGTokenExchangeResponse, + TokenExchangeParameters, + decode_id_jag, + validate_token_exchange_params, +) +from mcp.shared.auth import ( + OAuthClientInformationFull, + OAuthClientMetadata, + OAuthMetadata, + OAuthToken, + ProtectedResourceMetadata, +) + + +@pytest.fixture +def sample_id_token() -> str: + """Generate a sample ID token for testing.""" + payload = { + "iss": "https://idp.example.com", + "sub": "user123", + "aud": "mcp-client-app", + "exp": int(time.time()) + 3600, + "iat": int(time.time()), + "email": "user@example.com", + } + return jwt.encode(payload, "secret", algorithm="HS256") + + +@pytest.fixture +def sample_id_jag() -> str: + """Generate a sample ID-JAG token for testing.""" + # Create typed claims using IDJAGClaims model + claims = IDJAGClaims( + typ="oauth-id-jag+jwt", + jti="unique-jwt-id-12345", + iss="https://idp.example.com", + sub="user123", + aud="https://auth.mcp-server.example/", + resource="https://mcp-server.example/", + client_id="mcp-client-app", + exp=int(time.time()) + 300, + iat=int(time.time()), + scope="read write", + email=None, # Optional field + ) + + # Dump to dict for JWT encoding (exclude typ as it goes in header) + payload = claims.model_dump(exclude={"typ"}, exclude_none=True) + + return jwt.encode(payload, "secret", algorithm="HS256", headers={"typ": "oauth-id-jag+jwt"}) + + +@pytest.fixture +def mock_token_storage() -> Any: + """Create a mock token storage.""" + storage = Mock() + storage.get_tokens = AsyncMock(return_value=None) + storage.set_tokens = AsyncMock() + storage.get_client_info = AsyncMock(return_value=None) + storage.set_client_info = AsyncMock() + return storage + + +def test_token_exchange_params_from_id_token(): + """Test creating TokenExchangeParameters from ID token.""" + params = TokenExchangeParameters.from_id_token( + id_token="eyJhbGc...", + mcp_server_auth_issuer="https://auth.server.example/", + mcp_server_resource_id="https://server.example/", + scope="read write", + ) + + assert params.subject_token == "eyJhbGc..." + assert params.subject_token_type == "urn:ietf:params:oauth:token-type:id_token" + assert params.audience == "https://auth.server.example/" + assert params.resource == "https://server.example/" + assert params.scope == "read write" + assert params.requested_token_type == "urn:ietf:params:oauth:token-type:id-jag" + + +def test_token_exchange_params_from_saml_assertion(): + """Test creating TokenExchangeParameters from SAML assertion.""" + params = TokenExchangeParameters.from_saml_assertion( + saml_assertion="...", + mcp_server_auth_issuer="https://auth.server.example/", + mcp_server_resource_id="https://server.example/", + scope="read", + ) + + assert params.subject_token == "..." + assert params.subject_token_type == "urn:ietf:params:oauth:token-type:saml2" + assert params.audience == "https://auth.server.example/" + assert params.resource == "https://server.example/" + assert params.scope == "read" + + +def test_validate_token_exchange_params_valid(): + """Test validating valid token exchange parameters.""" + params = TokenExchangeParameters.from_id_token( + id_token="token", + mcp_server_auth_issuer="https://auth.example/", + mcp_server_resource_id="https://server.example/", + ) + + # Should not raise + validate_token_exchange_params(params) + + +def test_validate_token_exchange_params_invalid_token_type(): + """Test validation fails for invalid subject token type.""" + params = TokenExchangeParameters( + subject_token="token", + subject_token_type="invalid:type", + audience="https://auth.example/", + resource="https://server.example/", + ) + + with pytest.raises(OAuthFlowError, match="Invalid subject_token_type"): + validate_token_exchange_params(params) + + +def test_validate_token_exchange_params_missing_subject_token(): + """Test validation fails for missing subject token.""" + params = TokenExchangeParameters( + subject_token="", + subject_token_type="urn:ietf:params:oauth:token-type:id_token", + audience="https://auth.example/", + resource="https://server.example/", + ) + + with pytest.raises(OAuthFlowError, match="subject_token is required"): + validate_token_exchange_params(params) + + +def test_token_exchange_response_parsing(): + """Test parsing token exchange response.""" + response_json = """{ + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "access_token": "eyJhbGc...", + "token_type": "N_A", + "scope": "read write", + "expires_in": 300 + }""" + + response = IDJAGTokenExchangeResponse.model_validate_json(response_json) + + assert response.issued_token_type == "urn:ietf:params:oauth:token-type:id-jag" + assert response.id_jag == "eyJhbGc..." + assert response.access_token == "eyJhbGc..." + assert response.token_type == "N_A" + assert response.scope == "read write" + assert response.expires_in == 300 + + +def test_token_exchange_response_id_jag_property(): + """Test id_jag property returns access_token.""" + response = IDJAGTokenExchangeResponse( + issued_token_type="urn:ietf:params:oauth:token-type:id-jag", + access_token="the-id-jag-token", + token_type="N_A", + ) + + assert response.id_jag == "the-id-jag-token" + + +def test_decode_id_jag(sample_id_jag: str): + """Test decoding ID-JAG token.""" + claims = decode_id_jag(sample_id_jag) + + assert claims.iss == "https://idp.example.com" + assert claims.sub == "user123" + assert claims.aud == "https://auth.mcp-server.example/" + assert claims.resource == "https://mcp-server.example/" + assert claims.client_id == "mcp-client-app" + assert claims.scope == "read write" + + +def test_decode_id_jag_invalid_jwt(): + """Test decoding malformed ID-JAG raises appropriate error.""" + with pytest.raises(jwt.DecodeError): + decode_id_jag("not.a.valid.jwt") + + +def test_decode_id_jag_incomplete_jwt(): + """Test decoding incomplete JWT raises error.""" + with pytest.raises(jwt.DecodeError): + decode_id_jag("only.two.parts") + + +def test_id_jag_claims_with_extra_fields(): + """Test IDJAGClaims allows extra fields.""" + claims_data = { + "typ": "oauth-id-jag+jwt", + "jti": "jti123", + "iss": "https://idp.example.com", + "sub": "user123", + "aud": "https://auth.server.example/", + "resource": "https://server.example/", + "client_id": "client123", + "exp": int(time.time()) + 300, + "iat": int(time.time()), + "scope": "read", + "email": "user@example.com", + "custom_claim": "custom_value", # Extra field + } + + claims = IDJAGClaims.model_validate(claims_data) + assert claims.email == "user@example.com" + # Extra field should be preserved + assert claims.model_extra is not None and claims.model_extra.get("custom_claim") == "custom_value" + + +# ============================================================================ +# Tests for EnterpriseAuthOAuthClientProvider +# ============================================================================ + + +@pytest.mark.anyio +async def test_exchange_token_for_id_jag_success(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any): + """Test successful token exchange for ID-JAG.""" + # Create provider + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + scope="read write", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + client_name="Test Client", + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Mock HTTP response + mock_response = httpx.Response( + status_code=200, + json={ + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "access_token": sample_id_jag, + "token_type": "N_A", + "scope": "read write", + "expires_in": 300, + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Perform token exchange + id_jag = await provider.exchange_token_for_id_jag(mock_client) + + # Verify + assert id_jag == sample_id_jag + + # Verify request was made correctly + mock_client.post.assert_called_once() + call_args = mock_client.post.call_args + assert call_args[0][0] == "https://idp.example.com/oauth2/token" + assert call_args[1]["data"]["grant_type"] == "urn:ietf:params:oauth:grant-type:token-exchange" + assert call_args[1]["data"]["requested_token_type"] == "urn:ietf:params:oauth:token-type:id-jag" + assert call_args[1]["data"]["audience"] == "https://auth.mcp-server.example/" + assert call_args[1]["data"]["resource"] == "https://mcp-server.example/" + + +@pytest.mark.anyio +async def test_exchange_token_for_id_jag_error(sample_id_token: str, mock_token_storage: Any): + """Test token exchange failure handling.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Mock error response + mock_response = httpx.Response( + status_code=400, + json={ + "error": "invalid_request", + "error_description": "Invalid subject token", + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Should raise OAuthTokenError + with pytest.raises(OAuthTokenError, match="Token exchange failed"): + await provider.exchange_token_for_id_jag(mock_client) + + +@pytest.mark.anyio +async def test_exchange_token_for_id_jag_unexpected_token_type(sample_id_token: str, mock_token_storage: Any): + """Test token exchange with unexpected token type.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Mock response with wrong token type + mock_response = httpx.Response( + status_code=200, + json={ + "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", + "access_token": "some-token", + "token_type": "Bearer", + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Should raise OAuthTokenError + with pytest.raises(OAuthTokenError, match="Unexpected token type"): + await provider.exchange_token_for_id_jag(mock_client) + + +@pytest.mark.anyio +async def test_exchange_id_jag_for_access_token_success(sample_id_jag: str, mock_token_storage: Any): + """Test building JWT bearer grant request.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set up OAuth metadata + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + # Build JWT bearer grant request + request = await provider.exchange_id_jag_for_access_token(sample_id_jag) + + # Verify the request was built correctly + assert isinstance(request, httpx.Request) + assert request.method == "POST" + assert str(request.url) == "https://auth.mcp-server.example/oauth2/token" + + # Parse the request body + body_params = urllib.parse.parse_qs(request.content.decode()) + assert body_params["grant_type"][0] == "urn:ietf:params:oauth:grant-type:jwt-bearer" + assert body_params["assertion"][0] == sample_id_jag + + +@pytest.mark.anyio +async def test_exchange_id_jag_for_access_token_no_metadata(sample_id_jag: str, mock_token_storage: Any): + """Test JWT bearer grant falls back to constructed token URL without OAuth metadata.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # No OAuth metadata set — _get_token_endpoint() falls back to server URL base + + request = await provider.exchange_id_jag_for_access_token(sample_id_jag) + + # Should fall back to constructing token URL from server URL (same as parent class) + assert isinstance(request, httpx.Request) + assert request.method == "POST" + assert str(request.url) == "https://mcp-server.example/token" + + body_params = urllib.parse.parse_qs(request.content.decode()) + assert body_params["grant_type"][0] == "urn:ietf:params:oauth:grant-type:jwt-bearer" + assert body_params["assertion"][0] == sample_id_jag + + +@pytest.mark.anyio +async def test_perform_authorization_full_flow(mock_token_storage: Any, sample_id_jag: str): + """Test that _perform_authorization performs token exchange and builds JWT bearer request.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set up OAuth metadata + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + # Mock the IDP token exchange response + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value.__aenter__.return_value = mock_client + + mock_response = httpx.Response( + status_code=200, + json={ + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "access_token": sample_id_jag, + "token_type": "N_A", + }, + ) + mock_client.post = AsyncMock(return_value=mock_response) + + # Perform authorization + request = await provider._perform_authorization() + + # Verify it returns an httpx.Request for JWT bearer grant + assert isinstance(request, httpx.Request) + assert request.method == "POST" + assert str(request.url) == "https://auth.mcp-server.example/oauth2/token" + + # Verify the request contains JWT bearer grant + body_params = urllib.parse.parse_qs(request.content.decode()) + assert body_params["grant_type"][0] == "urn:ietf:params:oauth:grant-type:jwt-bearer" + assert body_params["assertion"][0] == sample_id_jag + + +@pytest.mark.anyio +async def test_perform_authorization_with_valid_tokens(mock_token_storage: Any, sample_id_jag: str): + """Test that _perform_authorization always performs full exchange. + + _perform_authorization is only called by the parent when tokens need to be + obtained or refreshed, so it unconditionally performs the full exchange flow. + """ + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set up OAuth metadata + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + # Mock the IDP token exchange response + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value.__aenter__.return_value = mock_client + + mock_response = httpx.Response( + status_code=200, + json={ + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "access_token": sample_id_jag, + "token_type": "N_A", + }, + ) + mock_client.post = AsyncMock(return_value=mock_response) + + # Should always perform full flow (no is_token_valid shortcircuit) + request = await provider._perform_authorization() + + assert isinstance(request, httpx.Request) + assert request.method == "POST" + assert str(request.url) == "https://auth.mcp-server.example/oauth2/token" + + # Verify it uses the exchanged ID-JAG + body_params = urllib.parse.parse_qs(request.content.decode()) + assert body_params["assertion"][0] == sample_id_jag + + +@pytest.mark.anyio +async def test_exchange_token_with_client_authentication( + sample_id_token: str, sample_id_jag: str, mock_token_storage: Any +): + """Test token exchange with client authentication.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + scope="read write", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + client_name="Test Client", + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + idp_client_id="test-idp-client-id", # IdP client ID, not MCP client ID + idp_client_secret="test-idp-client-secret", # IdP client secret + ) + + # Mock HTTP response + mock_response = httpx.Response( + status_code=200, + json={ + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "access_token": sample_id_jag, + "token_type": "N_A", + "scope": "read write", + "expires_in": 300, + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Perform token exchange + id_jag = await provider.exchange_token_for_id_jag(mock_client) + + # Verify the ID-JAG was returned + assert id_jag == sample_id_jag + + # Verify client credentials were included + call_args = mock_client.post.call_args + assert call_args[1]["data"]["client_id"] == "test-idp-client-id" + assert call_args[1]["data"]["client_secret"] == "test-idp-client-secret" + + +@pytest.mark.anyio +async def test_exchange_token_with_client_id_only(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any): + """Test token exchange with client_id but no client_secret (covers branch 232->235).""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + scope="read write", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + client_name="Test Client", + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + idp_client_id="test-idp-client-id", # IdP client ID, not MCP client ID + idp_client_secret=None, # No secret + ) + + # Mock HTTP response + mock_response = httpx.Response( + status_code=200, + json={ + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "access_token": sample_id_jag, + "token_type": "N_A", + "scope": "read write", + "expires_in": 300, + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Perform token exchange + id_jag = await provider.exchange_token_for_id_jag(mock_client) + + # Verify the ID-JAG was returned + assert id_jag == sample_id_jag + + # Verify client_id was included but NOT client_secret + call_args = mock_client.post.call_args + assert call_args[1]["data"]["client_id"] == "test-idp-client-id" + assert "client_secret" not in call_args[1]["data"] + + +@pytest.mark.anyio +async def test_exchange_token_http_error(sample_id_token: str, mock_token_storage: Any): + """Test token exchange with HTTP error.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(side_effect=httpx.ConnectError("Connection failed")) + + # Should raise OAuthTokenError + with pytest.raises(OAuthTokenError, match="HTTP error during token exchange"): + await provider.exchange_token_for_id_jag(mock_client) + + +@pytest.mark.anyio +async def test_exchange_token_malformed_json_error_response(sample_id_token: str, mock_token_storage: Any): + """Test token exchange with malformed JSON error response that raises JSONDecodeError.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:3000/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Mock error response with malformed JSON (will raise JSONDecodeError when parsing) + mock_response = httpx.Response( + status_code=400, + content=b'{"error": "invalid_request", "invalid json structure', # Malformed JSON + headers={"content-type": "application/json"}, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Should raise OAuthTokenError with default error message including status code + with pytest.raises(OAuthTokenError, match=r"Token exchange failed.*HTTP 400"): + await provider.exchange_token_for_id_jag(mock_client) + + +@pytest.mark.anyio +async def test_exchange_token_non_json_error_response(sample_id_token: str, mock_token_storage: Any): + """Test token exchange with non-JSON error response.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Mock error response with non-JSON content + mock_response = httpx.Response( + status_code=500, + content=b"Internal Server Error", + headers={"content-type": "text/plain"}, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Should raise OAuthTokenError with default error + with pytest.raises(OAuthTokenError, match="Token exchange failed: unknown_error"): + await provider.exchange_token_for_id_jag(mock_client) + + +@pytest.mark.anyio +async def test_exchange_token_invalid_response_schema(sample_id_token: str, mock_token_storage: Any): + """Test token exchange with HTTP 200 but response missing required fields. + + Regression test: ValidationError from model_validate_json must be caught + and wrapped in OAuthTokenError, not propagated raw. + """ + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # HTTP 200 but missing required fields (e.g., access_token) + mock_response = httpx.Response( + status_code=200, + json={"unexpected_field": "value"}, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + with pytest.raises(OAuthTokenError, match="Invalid token exchange response from IdP"): + await provider.exchange_token_for_id_jag(mock_client) + + +@pytest.mark.anyio +async def test_exchange_token_warning_for_non_na_token_type( + sample_id_token: str, sample_id_jag: str, mock_token_storage: Any +): + """Test token exchange logs warning for non-N_A token type.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Mock response with different token_type + mock_response = httpx.Response( + status_code=200, + json={ + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "access_token": sample_id_jag, + "token_type": "Bearer", # Not N_A + "scope": "read write", + "expires_in": 300, + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Should succeed but log warning + with patch.object( + logging.getLogger("mcp.client.auth.extensions.enterprise_managed_auth"), "warning" + ) as mock_warning: + id_jag = await provider.exchange_token_for_id_jag(mock_client) + assert id_jag == sample_id_jag + mock_warning.assert_called_once() + + +@pytest.mark.anyio +async def test_exchange_id_jag_with_client_authentication(sample_id_jag: str, mock_token_storage: Any): + """Test JWT bearer grant request building with client authentication.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set client info with secret (auth method set by registration, as in real flow) + provider.context.client_info = OAuthClientInformationFull( + client_id="test-client-id", + client_secret="test-client-secret", + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + token_endpoint_auth_method="client_secret_basic", + ) + + # Set up OAuth metadata + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + # Build JWT bearer grant request + request = await provider.exchange_id_jag_for_access_token(sample_id_jag) + + # Verify request was built correctly + assert isinstance(request, httpx.Request) + assert request.method == "POST" + + # Verify client credentials were handled via client_secret_basic + body_params = urllib.parse.parse_qs(request.content.decode()) + # With client_secret_basic, credentials are in Authorization header, not body + assert "client_secret" not in body_params + assert "Authorization" in request.headers + assert request.headers["Authorization"].startswith("Basic ") + + +@pytest.mark.anyio +async def test_exchange_id_jag_with_client_id_only(sample_id_jag: str, mock_token_storage: Any): + """Test JWT bearer grant request building with client_id but no client_secret.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set client info WITHOUT secret (client_secret=None), auth method "none" as set by registration + provider.context.client_info = OAuthClientInformationFull( + client_id="test-client-id", + client_secret=None, # No secret + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + token_endpoint_auth_method="none", + ) + + # Set up OAuth metadata + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + # Build JWT bearer grant request + request = await provider.exchange_id_jag_for_access_token(sample_id_jag) + + # Verify request was built correctly + assert isinstance(request, httpx.Request) + + # With auth method "none", no credentials are added + body_params = urllib.parse.parse_qs(request.content.decode()) + assert "client_secret" not in body_params + # With auth_method == "none", no Basic Authorization header is sent. + assert "Authorization" not in request.headers + + +@pytest.mark.anyio +async def test_exchange_token_with_client_info_but_no_client_id( + sample_id_token: str, sample_id_jag: str, mock_token_storage: Any +): + """Test that providing idp_client_secret without idp_client_id raises ValueError.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + scope="read write", + ) + + with pytest.raises(ValueError, match="idp_client_secret was provided without idp_client_id"): + EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + client_name="Test Client", + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + idp_client_id=None, # No client ID + idp_client_secret="test-idp-secret", # But has secret + ) + + +@pytest.mark.anyio +async def test_exchange_id_jag_with_client_info_but_no_client_id(sample_id_jag: str, mock_token_storage: Any): + """Test ID-JAG exchange request building when client_info exists but client_id is None.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set up OAuth metadata + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + # Set client info with client_id=None + provider.context.client_info = OAuthClientInformationFull( + client_id=None, # This should skip the client_id assignment + client_secret="test-secret", + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ) + + # Build JWT bearer grant request + request = await provider.exchange_id_jag_for_access_token(sample_id_jag) + + # Verify request was built correctly + assert isinstance(request, httpx.Request) + + # Verify client_id was not included (None), but client_secret should be handled + body_params = urllib.parse.parse_qs(request.content.decode()) + assert "client_id" not in body_params or body_params["client_id"][0] == "" + + +def test_validate_token_exchange_params_missing_audience(): + """Test validation fails for missing audience.""" + params = TokenExchangeParameters( + subject_token="token", + subject_token_type="urn:ietf:params:oauth:token-type:id_token", + audience="", + resource="https://server.example/", + ) + + with pytest.raises(OAuthFlowError, match="audience is required"): + validate_token_exchange_params(params) + + +def test_validate_token_exchange_params_missing_resource(): + """Test validation fails for missing resource.""" + params = TokenExchangeParameters( + subject_token="token", + subject_token_type="urn:ietf:params:oauth:token-type:id_token", + audience="https://auth.example/", + resource="", + ) + + with pytest.raises(OAuthFlowError, match="resource is required"): + validate_token_exchange_params(params) + + +@pytest.mark.anyio +async def test_exchange_id_jag_with_existing_auth_method(sample_id_jag: str, mock_token_storage: Any): + """Test JWT bearer grant when token_endpoint_auth_method is already set (covers branch 323->326).""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set client info WITH auth method already set + provider.context.client_info = OAuthClientInformationFull( + client_id="test-client-id", + client_secret="test-client-secret", + token_endpoint_auth_method="client_secret_post", # Already set + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ) + + # Set up OAuth metadata + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + # Build JWT bearer grant request + request = await provider.exchange_id_jag_for_access_token(sample_id_jag) + + # Verify request was built correctly + assert isinstance(request, httpx.Request) + + # Verify it used client_secret_post (in body, not header) + body_params = urllib.parse.parse_qs(request.content.decode()) + assert body_params["client_id"][0] == "test-client-id" + assert body_params["client_secret"][0] == "test-client-secret" + # Should NOT have Authorization header for client_secret_post + assert "Authorization" not in request.headers or not request.headers["Authorization"].startswith("Basic ") + + +@pytest.mark.anyio +async def test_perform_authorization_with_valid_tokens_no_id_jag(mock_token_storage: Any): + """Test _perform_authorization when tokens are valid but no cached ID-JAG (covers branch 354->360).""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set up OAuth metadata + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + # Set valid tokens + provider.context.current_tokens = OAuthToken( + token_type="Bearer", + access_token="valid-token", + expires_in=3600, + ) + provider.context.token_expiry_time = time.time() + 3600 + + # Mock the IDP token exchange response + sample_id_jag = "test-id-jag-token" + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value.__aenter__.return_value = mock_client + + mock_response = httpx.Response( + status_code=200, + json={ + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "access_token": sample_id_jag, + "token_type": "N_A", + }, + ) + mock_client.post = AsyncMock(return_value=mock_response) + + # Should fall through and perform full flow + request = await provider._perform_authorization() + + # Verify it returns a JWT bearer grant request + assert isinstance(request, httpx.Request) + assert request.method == "POST" + + # Verify it made the IDP token exchange call + mock_client.post.assert_called_once() + + +@pytest.mark.anyio +async def test_refresh_with_new_id_token(mock_token_storage: Any): + """Test refresh_with_new_id_token helper method.""" + old_id_token = "old-id-token" + new_id_token = "new-id-token" + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=old_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:3000/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set some existing tokens + provider.context.current_tokens = OAuthToken( + token_type="Bearer", + access_token="old-access-token", + expires_in=3600, + ) + + # Verify initial state + assert provider._subject_token == old_id_token + assert provider.context.current_tokens.access_token == "old-access-token" + + # Call refresh with new ID token + await provider.refresh_with_new_id_token(new_id_token) + + # Verify state after refresh + assert provider._subject_token == new_id_token + # Original token_exchange_params.subject_token must NOT be mutated (#7) + assert provider.token_exchange_params.subject_token == old_id_token + assert provider.context.current_tokens is None # Tokens should be cleared + assert provider.context.token_expiry_time is None # Expiry should be cleared + + +@pytest.mark.anyio +async def test_perform_authorization_always_exchanges_even_with_tokens(mock_token_storage: Any, sample_id_jag: str): + """Test that _perform_authorization always performs full exchange even with existing tokens.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set up OAuth metadata + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + # Set valid tokens + provider.context.current_tokens = OAuthToken( + token_type="Bearer", + access_token="valid-token", + expires_in=3600, + ) + provider.context.token_expiry_time = time.time() + 3600 + + # Mock the IDP token exchange response for new ID-JAG + new_id_jag = "new-id-jag-token" + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value.__aenter__.return_value = mock_client + + mock_response = httpx.Response( + status_code=200, + json={ + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "access_token": new_id_jag, + "token_type": "N_A", + }, + ) + mock_client.post = AsyncMock(return_value=mock_response) + + # Should get a new ID-JAG since the cached one is expired + request = await provider._perform_authorization() + + # Verify it made the IDP token exchange call (didn't reuse expired ID-JAG) + mock_client.post.assert_called_once() + + # Verify the request uses the NEW ID-JAG + body_params = urllib.parse.parse_qs(request.content.decode()) + assert body_params["assertion"][0] == new_id_jag + + +@pytest.mark.anyio +async def test_perform_authorization_always_exchanges(mock_token_storage: Any, sample_id_jag: str): + """Test that _perform_authorization always does full exchange even with cached state. + + The parent class only calls _perform_authorization when new tokens are + needed, so it must not short-circuit based on cached state. + """ + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + # Set valid tokens + provider.context.current_tokens = OAuthToken( + token_type="Bearer", + access_token="valid-token", + expires_in=3600, + ) + provider.context.token_expiry_time = time.time() + 3600 + + new_id_jag = "freshly-exchanged-id-jag" + + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value.__aenter__.return_value = mock_client + mock_client.post = AsyncMock( + return_value=httpx.Response( + status_code=200, + json={ + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "access_token": new_id_jag, + "token_type": "N_A", + }, + ) + ) + + request = await provider._perform_authorization() + mock_client.post.assert_called_once() + + body_params = urllib.parse.parse_qs(request.content.decode()) + assert body_params["assertion"][0] == new_id_jag + + +@pytest.mark.anyio +async def test_audience_override_warning(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any): + """Test that audience override logs a warning when values differ.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://configured-audience.example/", # Different from issuer + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set OAuth metadata with different issuer + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://actual-issuer.example/"), # Different from configured + authorization_endpoint=AnyHttpUrl("https://actual-issuer.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://actual-issuer.example/oauth2/token"), + ) + + # Mock HTTP response + mock_response = httpx.Response( + status_code=200, + json={ + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "access_token": sample_id_jag, + "token_type": "N_A", + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Should log warning about audience override + with patch.object( + logging.getLogger("mcp.client.auth.extensions.enterprise_managed_auth"), "warning" + ) as mock_warning: + await provider.exchange_token_for_id_jag(mock_client) + + # Verify warning was called with message about override + mock_warning.assert_called_once() + warning_message = mock_warning.call_args[0][0] + assert "Overriding audience" in warning_message + assert "https://configured-audience.example/" in warning_message + assert "https://actual-issuer.example/" in warning_message + + +@pytest.mark.anyio +async def test_audience_no_warning_when_matching(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any): + """Test that no warning is logged when audience matches issuer.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", # Same as issuer + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set OAuth metadata with matching issuer + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), # Same as configured + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + # Mock HTTP response + mock_response = httpx.Response( + status_code=200, + json={ + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "access_token": sample_id_jag, + "token_type": "N_A", + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Should NOT log warning when values match + with patch.object( + logging.getLogger("mcp.client.auth.extensions.enterprise_managed_auth"), "warning" + ) as mock_warning: + await provider.exchange_token_for_id_jag(mock_client) + + # Verify warning was NOT called + mock_warning.assert_not_called() + + +@pytest.mark.anyio +async def test_empty_scope_not_included(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any): + """Test that empty or whitespace-only scope is not included in token request.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + scope=" ", # Whitespace-only scope + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Mock HTTP response + mock_response = httpx.Response( + status_code=200, + json={ + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "access_token": sample_id_jag, + "token_type": "N_A", + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Perform token exchange + await provider.exchange_token_for_id_jag(mock_client) + + # Verify scope was NOT included in request + call_args = mock_client.post.call_args + assert "scope" not in call_args[1]["data"] + + +@pytest.mark.anyio +async def test_audience_override_disabled(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any): + """Test that override_audience_with_issuer=False preserves the configured audience.""" + configured_audience = "https://configured-audience.example/" + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer=configured_audience, + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + override_audience_with_issuer=False, # Disable override + ) + + # Set OAuth metadata with a DIFFERENT issuer + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://different-issuer.example/"), + authorization_endpoint=AnyHttpUrl("https://different-issuer.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://different-issuer.example/oauth2/token"), + ) + + mock_response = httpx.Response( + status_code=200, + json={ + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "access_token": sample_id_jag, + "token_type": "N_A", + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Should NOT override audience even though discovered issuer is different + with patch.object( + logging.getLogger("mcp.client.auth.extensions.enterprise_managed_auth"), "warning" + ) as mock_warning: + await provider.exchange_token_for_id_jag(mock_client) + # No warning should be logged because override is disabled + mock_warning.assert_not_called() + + # Verify the CONFIGURED audience was used (not the discovered issuer) + call_args = mock_client.post.call_args + assert call_args[1]["data"]["audience"] == configured_audience + + +@pytest.mark.anyio +async def test_exchange_token_without_oauth_metadata(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any): + """Test token exchange when oauth_metadata is not set. + + This tests the scenario where OAuth metadata discovery hasn't happened yet. + The configured audience from token_exchange_params should be used directly. + + Note: Testing issuer=None is not possible because OAuthMetadata.issuer is a + required AnyHttpUrl field per RFC 8414, so the Pydantic model prevents None. + """ + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.configured.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # No OAuth metadata set (None) + assert provider.context.oauth_metadata is None + + # Mock HTTP response + mock_response = httpx.Response( + status_code=200, + json={ + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "access_token": sample_id_jag, + "token_type": "N_A", + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Perform token exchange + await provider.exchange_token_for_id_jag(mock_client) + + # Verify the configured audience was used (no override since metadata is None) + call_args = mock_client.post.call_args + assert call_args[1]["data"]["audience"] == "https://auth.configured.example/" + + +@pytest.mark.anyio +async def test_exchange_id_jag_does_not_mutate_context_client_info(sample_id_jag: str, mock_token_storage: Any): + """Test that exchange_id_jag_for_access_token does not mutate context.client_info. + + Regression test: the method previously mutated client_info.token_endpoint_auth_method + as a side effect, which could cause unexpected behavior on subsequent calls. + """ + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set client info with no auth method (None) and a secret + provider.context.client_info = OAuthClientInformationFull( + client_id="test-client-id", + client_secret="test-secret", + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + token_endpoint_auth_method=None, # Explicitly None + ) + + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + # Call twice to verify idempotency + await provider.exchange_id_jag_for_access_token(sample_id_jag) + assert provider.context.client_info.token_endpoint_auth_method is None + + await provider.exchange_id_jag_for_access_token(sample_id_jag) + assert provider.context.client_info.token_endpoint_auth_method is None + + +@pytest.mark.anyio +async def test_exchange_id_jag_includes_resource_param(sample_id_jag: str, mock_token_storage: Any): + """Test that JWT bearer grant includes resource parameter per RFC 8707.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + # Set protected resource metadata so should_include_resource_param returns True + provider.context.protected_resource_metadata = ProtectedResourceMetadata( + resource=AnyHttpUrl("https://mcp-server.example/"), + authorization_servers=[AnyHttpUrl("https://auth.mcp-server.example/")], + ) + + request = await provider.exchange_id_jag_for_access_token(sample_id_jag) + body_params = urllib.parse.parse_qs(request.content.decode()) + + assert "resource" in body_params + assert body_params["resource"][0] == "https://mcp-server.example/" + + +@pytest.mark.anyio +async def test_exchange_id_jag_includes_resource_param_via_protocol_version( + sample_id_jag: str, mock_token_storage: Any +): + """Test that JWT bearer grant includes resource when protocol version qualifies.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + # No PRM, but protocol version >= 2025-06-18 qualifies + provider.context.protocol_version = "2025-06-18" + + request = await provider.exchange_id_jag_for_access_token(sample_id_jag) + body_params = urllib.parse.parse_qs(request.content.decode()) + + assert "resource" in body_params + + +@pytest.mark.anyio +async def test_exchange_id_jag_omits_resource_when_not_applicable(sample_id_jag: str, mock_token_storage: Any): + """Test that JWT bearer grant omits resource when conditions are not met.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + # No PRM and no qualifying protocol version + provider.context.protected_resource_metadata = None + provider.context.protocol_version = None + + request = await provider.exchange_id_jag_for_access_token(sample_id_jag) + body_params = urllib.parse.parse_qs(request.content.decode()) + + assert "resource" not in body_params + + +@pytest.mark.anyio +async def test_exchange_id_jag_includes_scope(sample_id_jag: str, mock_token_storage: Any): + """Test that JWT bearer grant includes scope from client metadata.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + scope="read write", + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + request = await provider.exchange_id_jag_for_access_token(sample_id_jag) + body_params = urllib.parse.parse_qs(request.content.decode()) + + assert "scope" in body_params + assert body_params["scope"][0] == "read write" + + +@pytest.mark.anyio +async def test_exchange_id_jag_omits_scope_when_not_set(sample_id_jag: str, mock_token_storage: Any): + """Test that JWT bearer grant omits scope when client metadata has no scope.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + # No scope + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + request = await provider.exchange_id_jag_for_access_token(sample_id_jag) + body_params = urllib.parse.parse_qs(request.content.decode()) + + assert "scope" not in body_params + + +@pytest.mark.anyio +async def test_exchange_id_jag_without_client_info(sample_id_jag: str, mock_token_storage: Any): + """Test JWT bearer grant request building when context has no client_info at all.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + # No client_info set (None) + assert provider.context.client_info is None + + request = await provider.exchange_id_jag_for_access_token(sample_id_jag) + + # Should still produce a valid request without client credentials + assert isinstance(request, httpx.Request) + body_params = urllib.parse.parse_qs(request.content.decode()) + assert body_params["grant_type"][0] == "urn:ietf:params:oauth:grant-type:jwt-bearer" + assert body_params["assertion"][0] == sample_id_jag + assert "client_id" not in body_params