diff --git a/.github/actions/conformance/client.py b/.github/actions/conformance/client.py
index 58f684f01..4c0f1d841 100644
--- a/.github/actions/conformance/client.py
+++ b/.github/actions/conformance/client.py
@@ -5,7 +5,7 @@
Contract:
- MCP_CONFORMANCE_SCENARIO env var -> scenario name
- - MCP_CONFORMANCE_CONTEXT env var -> optional JSON (for client-credentials scenarios)
+ - MCP_CONFORMANCE_CONTEXT env var -> optional JSON (for auth scenarios)
- Server URL as last CLI argument (sys.argv[1])
- Must exit 0 within 30 seconds
@@ -16,7 +16,16 @@
elicitation-sep1034-client-defaults - Elicitation with default accept callback
auth/client-credentials-jwt - Client credentials with private_key_jwt
auth/client-credentials-basic - Client credentials with client_secret_basic
+ auth/cross-app-access-complete-flow - Enterprise managed OAuth (SEP-990) - v0.1.14+
auth/* - Authorization code flow (default for auth scenarios)
+
+Enterprise Auth (SEP-990):
+ The conformance package v0.1.14+ (https://github.com/modelcontextprotocol/conformance/pull/110)
+ provides the scenario 'auth/cross-app-access-complete-flow' which tests the complete
+ enterprise managed OAuth flow: IDP ID token → ID-JAG → access token.
+
+ The client receives test context (idp_id_token, idp_token_endpoint, etc.) via
+ MCP_CONFORMANCE_CONTEXT environment variable and performs the token exchange flows automatically.
"""
import asyncio
@@ -314,9 +323,100 @@ async def run_auth_code_client(server_url: str) -> None:
await _run_auth_session(server_url, oauth_auth)
+@register("auth/cross-app-access-complete-flow")
+async def run_cross_app_access_complete_flow(server_url: str) -> None:
+ """Enterprise managed auth: Complete SEP-990 flow (OIDC ID token → ID-JAG → access token).
+
+ This scenario is provided by @modelcontextprotocol/conformance@0.1.14+ (PR #110).
+ It tests the complete enterprise managed OAuth flow using token exchange (RFC 8693)
+ and JWT bearer grant (RFC 7523).
+ """
+ from mcp.client.auth.extensions.enterprise_managed_auth import (
+ EnterpriseAuthOAuthClientProvider,
+ TokenExchangeParameters,
+ )
+
+ context = get_conformance_context()
+ # The conformance package provides these fields
+ idp_id_token = context.get("idp_id_token")
+ idp_token_endpoint = context.get("idp_token_endpoint")
+ idp_issuer = context.get("idp_issuer")
+
+ # For cross-app access, we need to determine the MCP server's resource ID and auth issuer
+ # The conformance package sets up the auth server, and the MCP server URL is passed to us
+
+ if not idp_id_token:
+ raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'idp_id_token'")
+ if not idp_token_endpoint:
+ raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'idp_token_endpoint'")
+ if not idp_issuer:
+ raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'idp_issuer'")
+
+ # Extract base URL by stripping trailing /mcp path (Python 3.9+).
+ # The conformance harness always serves MCP at /mcp, so stripping
+ # the suffix gives us the auth-server base URL for fallback defaults.
+ base_url = server_url.removesuffix("/mcp")
+ auth_issuer = context.get("auth_issuer", base_url)
+ resource_id = context.get("resource_id", server_url)
+
+ logger.debug("Cross-app access flow:")
+ logger.debug(f" IDP Issuer: {idp_issuer}")
+ logger.debug(f" IDP Token Endpoint: {idp_token_endpoint}")
+ logger.debug(f" Auth Issuer: {auth_issuer}")
+ logger.debug(f" Resource ID: {resource_id}")
+
+ # Create token exchange parameters from IDP ID token
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=idp_id_token,
+ mcp_server_auth_issuer=auth_issuer,
+ mcp_server_resource_id=resource_id,
+ scope=context.get("scope"),
+ )
+
+ # Get pre-configured client credentials from context (if provided)
+ client_id = context.get("client_id")
+ client_secret = context.get("client_secret")
+
+ storage = InMemoryTokenStorage()
+
+ # Create enterprise auth provider
+ enterprise_auth = EnterpriseAuthOAuthClientProvider(
+ server_url=server_url,
+ client_metadata=OAuthClientMetadata(
+ client_name="conformance-cross-app-client",
+ redirect_uris=[AnyUrl("http://localhost:3000/callback")],
+ grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"],
+ response_types=["token"],
+ ),
+ storage=storage,
+ idp_token_endpoint=idp_token_endpoint,
+ token_exchange_params=token_exchange_params,
+ )
+
+ # If client credentials are provided in context, use them instead of dynamic registration
+ if client_id and client_secret:
+ from mcp.shared.auth import OAuthClientInformationFull
+
+ logger.debug(f"Using pre-configured client credentials: {client_id}")
+ client_info = OAuthClientInformationFull(
+ client_id=client_id,
+ client_secret=client_secret,
+ token_endpoint_auth_method="client_secret_basic",
+ grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"],
+ response_types=["token"],
+ redirect_uris=[AnyUrl("http://localhost:3000/callback")],
+ )
+ enterprise_auth.context.client_info = client_info
+ await storage.set_client_info(client_info)
+
+ await _run_auth_session(server_url, enterprise_auth)
+
+
async def _run_auth_session(server_url: str, oauth_auth: OAuthClientProvider) -> None:
"""Common session logic for all OAuth flows."""
- client = httpx.AsyncClient(auth=oauth_auth, timeout=30.0)
+ # Allow timeout to be configured via environment variable for different test scenarios
+ timeout = float(os.environ.get("MCP_CONFORMANCE_TIMEOUT", "30.0"))
+ client = httpx.AsyncClient(auth=oauth_auth, timeout=timeout)
async with streamable_http_client(url=server_url, http_client=client) as (read_stream, write_stream):
async with ClientSession(
read_stream, write_stream, elicitation_callback=default_elicitation_callback
diff --git a/.github/workflows/conformance.yml b/.github/workflows/conformance.yml
index d876da00b..7af927ddf 100644
--- a/.github/workflows/conformance.yml
+++ b/.github/workflows/conformance.yml
@@ -42,4 +42,4 @@ jobs:
with:
node-version: 24
- run: uv sync --frozen --all-extras --package mcp
- - run: npx @modelcontextprotocol/conformance@0.1.13 client --command 'uv run --frozen python .github/actions/conformance/client.py' --suite all
+ - run: npx @modelcontextprotocol/conformance@0.1.14 client --command 'uv run --frozen python .github/actions/conformance/client.py' --suite all
diff --git a/README.v2.md b/README.v2.md
index 55d867586..b0b0b7286 100644
--- a/README.v2.md
+++ b/README.v2.md
@@ -70,6 +70,7 @@
- [Writing MCP Clients](#writing-mcp-clients)
- [Client Display Utilities](#client-display-utilities)
- [OAuth Authentication for Clients](#oauth-authentication-for-clients)
+ - [Enterprise Managed Authorization](#enterprise-managed-authorization)
- [Parsing Tool Results](#parsing-tool-results)
- [MCP Primitives](#mcp-primitives)
- [Server Capabilities](#server-capabilities)
@@ -2395,6 +2396,328 @@ _Full example: [examples/snippets/clients/oauth_client.py](https://github.com/mo
For a complete working example, see [`examples/clients/simple-auth-client/`](examples/clients/simple-auth-client/).
+#### Enterprise Managed Authorization
+
+The SDK includes support for Enterprise Managed Authorization (SEP-990), which enables MCP clients to connect to protected servers using enterprise Single Sign-On (SSO) systems. This implementation supports:
+
+- **RFC 8693**: OAuth 2.0 Token Exchange (ID Token -> ID-JAG)
+- **RFC 7523**: JSON Web Token (JWT) Profile for OAuth 2.0 Authorization Grants (ID-JAG -> Access Token)
+- Integration with enterprise identity providers (Okta, Azure AD, etc.)
+
+**Key Components:**
+
+The `EnterpriseAuthOAuthClientProvider` class extends the standard OAuth provider to implement the enterprise authorization flow:
+
+**Token Exchange Flow:**
+
+1. **Obtain ID Token** from your enterprise IdP (e.g., Okta, Azure AD)
+2. **Exchange ID Token for ID-JAG** using RFC 8693 Token Exchange
+3. **Exchange ID-JAG for Access Token** using RFC 7523 JWT Bearer Grant
+4. **Use Access Token** to call protected MCP server tools
+
+**Using the Access Token with MCP Server:**
+
+1. Once you have obtained the access token, you can use it to authenticate requests to the MCP server
+2. The access token is automatically included in all subsequent requests to the MCP server, allowing you to access protected tools and resources based on your enterprise identity and permissions.
+
+**Handling Token Expiration and Refresh:**
+
+Access tokens have a limited lifetime and will expire. When tokens expire:
+
+- **Check Token Expiration**: Use the `expires_in` field to determine when the token expires
+- **Refresh Flow**: When expired, repeat the token exchange flow with a fresh ID token from your IdP
+- **Automatic Refresh**: Implement automatic token refresh before expiration (recommended for production)
+- **Error Handling**: Catch authentication errors and retry with refreshed tokens
+
+**Important Notes:**
+
+- **ID Token Expiration**: If the ID token from your IdP expires, you must re-authenticate with the IdP to obtain a new ID token before performing token exchange
+- **Token Storage**: Store tokens securely and implement the `TokenStorage` interface to persist tokens between application restarts
+- **Scope Changes**: If you need different scopes, you must obtain a new ID token from the IdP with the required scopes
+- **Security**: Never log or expose access tokens or ID tokens in production environments
+
+**Example Usage:**
+
+
+```python
+import asyncio
+
+import httpx
+from pydantic import AnyUrl
+
+from mcp import ClientSession
+from mcp.client.auth import TokenStorage
+from mcp.client.auth.extensions import (
+ EnterpriseAuthOAuthClientProvider,
+ TokenExchangeParameters,
+)
+from mcp.client.streamable_http import streamable_http_client
+from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
+
+
+# Placeholder function for IdP authentication
+async def get_id_token_from_idp() -> str:
+ """Placeholder function to get ID token from your IdP.
+
+ In production, implement actual IdP authentication flow.
+ """
+ raise NotImplementedError("Implement your IdP authentication flow here")
+
+
+# Define token storage implementation
+class SimpleTokenStorage(TokenStorage):
+ def __init__(self) -> None:
+ self._tokens: OAuthToken | None = None
+ self._client_info: OAuthClientInformationFull | None = None
+
+ async def get_tokens(self) -> OAuthToken | None:
+ return self._tokens
+
+ async def set_tokens(self, tokens: OAuthToken) -> None:
+ self._tokens = tokens
+
+ async def get_client_info(self) -> OAuthClientInformationFull | None:
+ return self._client_info
+
+ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
+ self._client_info = client_info
+
+
+async def discover_mcp_server_metadata(server_url: str) -> tuple[str, str]:
+ """Discover MCP server's OAuth metadata and resource identifier.
+
+ Returns:
+ Tuple of (auth_issuer, resource_id)
+ """
+ from mcp.client.auth.utils import (
+ build_oauth_authorization_server_metadata_discovery_urls,
+ build_protected_resource_metadata_discovery_urls,
+ handle_auth_metadata_response,
+ handle_protected_resource_response,
+ )
+
+ async with httpx.AsyncClient() as client:
+ # Step 1: Discover Protected Resource Metadata (PRM)
+ prm_urls = build_protected_resource_metadata_discovery_urls(None, server_url)
+
+ prm = None
+ for url in prm_urls:
+ response = await client.get(url)
+ prm = await handle_protected_resource_response(response)
+ if prm:
+ break
+
+ if not prm:
+ raise ValueError("Could not discover Protected Resource Metadata")
+
+ # Extract resource identifier and authorization server URL
+ resource_id = str(prm.resource)
+ auth_server_url = str(prm.authorization_servers[0]) if prm.authorization_servers else None
+
+ # Step 2: Discover OAuth Authorization Server Metadata
+ oauth_urls = build_oauth_authorization_server_metadata_discovery_urls(auth_server_url, server_url)
+
+ oauth_metadata = None
+ for url in oauth_urls:
+ response = await client.get(url)
+ ok, asm = await handle_auth_metadata_response(response)
+ if ok and asm:
+ oauth_metadata = asm
+ break
+
+ if not oauth_metadata or not oauth_metadata.issuer:
+ raise ValueError("Could not discover OAuth metadata or issuer")
+
+ auth_issuer = str(oauth_metadata.issuer)
+
+ return auth_issuer, resource_id
+
+
+async def main() -> None:
+ """Example demonstrating enterprise managed authorization with MCP."""
+ server_url = "https://mcp-server.example.com"
+
+ # Step 1: Get ID token from your IdP (e.g., Okta, Azure AD)
+ id_token = await get_id_token_from_idp()
+
+ # Step 2: Discover MCP server's OAuth metadata and resource identifier
+ # This replaces hardcoding these values
+ mcp_server_auth_issuer, mcp_server_resource_id = await discover_mcp_server_metadata(server_url)
+ print(f"Discovered auth issuer: {mcp_server_auth_issuer}")
+ print(f"Discovered resource ID: {mcp_server_resource_id}")
+
+ # Step 3: Configure token exchange parameters using discovered values
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=id_token,
+ mcp_server_auth_issuer=mcp_server_auth_issuer,
+ mcp_server_resource_id=mcp_server_resource_id,
+ scope="mcp:tools mcp:resources", # Optional scopes
+ )
+
+ # Step 4: Create enterprise auth provider
+ enterprise_auth = EnterpriseAuthOAuthClientProvider(
+ server_url=server_url,
+ client_metadata=OAuthClientMetadata(
+ client_name="Enterprise MCP Client",
+ redirect_uris=[AnyUrl("http://localhost:3000/callback")],
+ grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"],
+ response_types=["token"],
+ ),
+ storage=SimpleTokenStorage(),
+ idp_token_endpoint="https://your-idp.com/oauth2/v1/token", # Your IdP's token endpoint
+ token_exchange_params=token_exchange_params,
+ # Optional: IdP client credentials if your IdP requires client authentication for token exchange
+ # idp_client_id="your-idp-client-id",
+ # idp_client_secret="your-idp-client-secret",
+ )
+
+ # Step 5: Create authenticated HTTP client
+ # The auth provider automatically handles the two-step token exchange:
+ # 1. ID Token -> ID-JAG (via IDP)
+ # 2. ID-JAG -> Access Token (via MCP server)
+ client = httpx.AsyncClient(auth=enterprise_auth, timeout=30.0)
+
+ # Step 6: Connect to MCP server with authenticated client
+ async with streamable_http_client(url=server_url, http_client=client) as (read, write):
+ async with ClientSession(read, write) as session:
+ await session.initialize()
+
+ # List available tools
+ tools_result = await session.list_tools()
+ print(f"Available tools: {[t.name for t in tools_result.tools]}")
+
+ # Call a tool - auth tokens are automatically managed
+ if tools_result.tools:
+ tool_name = tools_result.tools[0].name
+ result = await session.call_tool(tool_name, {})
+ print(f"Tool result: {result.content}")
+
+ # List available resources
+ resources = await session.list_resources()
+ for resource in resources.resources:
+ print(f"Resource: {resource.uri}")
+
+
+async def advanced_manual_flow() -> None:
+ """Advanced example showing manual token exchange.
+
+ Use cases for manual token exchange:
+ - Testing and debugging: Inspect ID-JAG claims before exchanging for access token
+ - Token caching: Store and reuse ID-JAG across multiple MCP server connections
+ - Custom error handling: Implement specific retry logic for each token exchange step
+ - Monitoring: Log token exchange metrics and performance
+ - Token introspection: Validate ID-JAG structure before sending to MCP server
+ """
+ id_token = await get_id_token_from_idp()
+
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example.com",
+ mcp_server_resource_id="https://mcp-server.example.com",
+ )
+
+ enterprise_auth = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example.com",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:3000/callback")],
+ ),
+ storage=SimpleTokenStorage(),
+ idp_token_endpoint="https://your-idp.com/oauth2/v1/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Manual token exchange (useful for debugging, caching, custom error handling, etc.)
+ async with httpx.AsyncClient() as client:
+ # Step 1: Exchange ID token for ID-JAG
+ id_jag = await enterprise_auth.exchange_token_for_id_jag(client)
+ # WARNING: Only log tokens in development/testing environments
+ # In production, NEVER log tokens or token fragments as they are sensitive credentials
+ print(f"Obtained ID-JAG: {id_jag[:50]}...")
+
+ # Step 2: Build JWT bearer grant request
+ jwt_bearer_request = await enterprise_auth.exchange_id_jag_for_access_token(id_jag)
+ print(f"Built JWT bearer grant request to: {jwt_bearer_request.url}")
+
+ # Step 3: Execute the request to get access token
+ response = await client.send(jwt_bearer_request)
+ response.raise_for_status()
+ token_data = response.json()
+
+ access_token = OAuthToken(
+ access_token=token_data["access_token"],
+ token_type=token_data["token_type"],
+ expires_in=token_data.get("expires_in"),
+ )
+ # WARNING: In production, do not log token expiry or any token information
+ print(f"Access token obtained, expires in: {access_token.expires_in}s")
+
+ # Use the access token for API calls
+ _ = {"Authorization": f"Bearer {access_token.access_token}"}
+ # ... make authenticated requests with headers
+
+
+async def token_refresh_example() -> None:
+ """Example showing how to refresh tokens when they expire.
+
+ When your access token expires, you need to obtain a fresh ID token
+ from your enterprise IdP and use the refresh helper method.
+ """
+ # Initial setup
+ id_token = await get_id_token_from_idp()
+
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example.com",
+ mcp_server_resource_id="https://mcp-server.example.com",
+ )
+
+ enterprise_auth = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example.com",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:3000/callback")],
+ ),
+ storage=SimpleTokenStorage(),
+ idp_token_endpoint="https://your-idp.com/oauth2/v1/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ _ = httpx.AsyncClient(auth=enterprise_auth, timeout=30.0)
+
+ # Use the client for MCP operations...
+ # ... time passes and token expires ...
+
+ # When token expires, get a fresh ID token from your IdP
+ new_id_token = await get_id_token_from_idp()
+
+ # Refresh the authentication using the helper method
+ await enterprise_auth.refresh_with_new_id_token(new_id_token)
+
+ # Next API call will automatically use the refreshed tokens
+ # No need to recreate the client or session
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
+```
+
+_Full example: [examples/snippets/clients/enterprise_managed_auth_client.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/clients/enterprise_managed_auth_client.py)_
+
+
+**Working with SAML Assertions:**
+
+If your enterprise uses SAML instead of OIDC, you can exchange SAML assertions:
+
+```python
+token_exchange_params = TokenExchangeParameters.from_saml_assertion(
+ saml_assertion=saml_assertion_string,
+ mcp_server_auth_issuer="https://your-idp.com",
+ mcp_server_resource_id="https://mcp-server.example.com",
+ scope="mcp:tools",
+)
+```
+
+For more details on the enterprise authorization flow, see the [MCP Enterprise Authorization specification](https://modelcontextprotocol.io/specification/2025-06-18/basic/authorization).
+
### Parsing Tool Results
When calling tools through MCP, the `CallToolResult` object contains the tool's response in a structured format. Understanding how to parse this result is essential for properly handling tool outputs.
diff --git a/examples/snippets/clients/enterprise_managed_auth_client.py b/examples/snippets/clients/enterprise_managed_auth_client.py
new file mode 100644
index 000000000..3888fdc21
--- /dev/null
+++ b/examples/snippets/clients/enterprise_managed_auth_client.py
@@ -0,0 +1,258 @@
+import asyncio
+
+import httpx
+from pydantic import AnyUrl
+
+from mcp import ClientSession
+from mcp.client.auth import TokenStorage
+from mcp.client.auth.extensions import (
+ EnterpriseAuthOAuthClientProvider,
+ TokenExchangeParameters,
+)
+from mcp.client.streamable_http import streamable_http_client
+from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
+
+
+# Placeholder function for IdP authentication
+async def get_id_token_from_idp() -> str:
+ """Placeholder function to get ID token from your IdP.
+
+ In production, implement actual IdP authentication flow.
+ """
+ raise NotImplementedError("Implement your IdP authentication flow here")
+
+
+# Define token storage implementation
+class SimpleTokenStorage(TokenStorage):
+ def __init__(self) -> None:
+ self._tokens: OAuthToken | None = None
+ self._client_info: OAuthClientInformationFull | None = None
+
+ async def get_tokens(self) -> OAuthToken | None:
+ return self._tokens
+
+ async def set_tokens(self, tokens: OAuthToken) -> None:
+ self._tokens = tokens
+
+ async def get_client_info(self) -> OAuthClientInformationFull | None:
+ return self._client_info
+
+ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
+ self._client_info = client_info
+
+
+async def discover_mcp_server_metadata(server_url: str) -> tuple[str, str]:
+ """Discover MCP server's OAuth metadata and resource identifier.
+
+ Returns:
+ Tuple of (auth_issuer, resource_id)
+ """
+ from mcp.client.auth.utils import (
+ build_oauth_authorization_server_metadata_discovery_urls,
+ build_protected_resource_metadata_discovery_urls,
+ handle_auth_metadata_response,
+ handle_protected_resource_response,
+ )
+
+ async with httpx.AsyncClient() as client:
+ # Step 1: Discover Protected Resource Metadata (PRM)
+ prm_urls = build_protected_resource_metadata_discovery_urls(None, server_url)
+
+ prm = None
+ for url in prm_urls:
+ response = await client.get(url)
+ prm = await handle_protected_resource_response(response)
+ if prm:
+ break
+
+ if not prm:
+ raise ValueError("Could not discover Protected Resource Metadata")
+
+ # Extract resource identifier and authorization server URL
+ resource_id = str(prm.resource)
+ auth_server_url = str(prm.authorization_servers[0]) if prm.authorization_servers else None
+
+ # Step 2: Discover OAuth Authorization Server Metadata
+ oauth_urls = build_oauth_authorization_server_metadata_discovery_urls(auth_server_url, server_url)
+
+ oauth_metadata = None
+ for url in oauth_urls:
+ response = await client.get(url)
+ ok, asm = await handle_auth_metadata_response(response)
+ if ok and asm:
+ oauth_metadata = asm
+ break
+
+ if not oauth_metadata or not oauth_metadata.issuer:
+ raise ValueError("Could not discover OAuth metadata or issuer")
+
+ auth_issuer = str(oauth_metadata.issuer)
+
+ return auth_issuer, resource_id
+
+
+async def main() -> None:
+ """Example demonstrating enterprise managed authorization with MCP."""
+ server_url = "https://mcp-server.example.com"
+
+ # Step 1: Get ID token from your IdP (e.g., Okta, Azure AD)
+ id_token = await get_id_token_from_idp()
+
+ # Step 2: Discover MCP server's OAuth metadata and resource identifier
+ # This replaces hardcoding these values
+ mcp_server_auth_issuer, mcp_server_resource_id = await discover_mcp_server_metadata(server_url)
+ print(f"Discovered auth issuer: {mcp_server_auth_issuer}")
+ print(f"Discovered resource ID: {mcp_server_resource_id}")
+
+ # Step 3: Configure token exchange parameters using discovered values
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=id_token,
+ mcp_server_auth_issuer=mcp_server_auth_issuer,
+ mcp_server_resource_id=mcp_server_resource_id,
+ scope="mcp:tools mcp:resources", # Optional scopes
+ )
+
+ # Step 4: Create enterprise auth provider
+ enterprise_auth = EnterpriseAuthOAuthClientProvider(
+ server_url=server_url,
+ client_metadata=OAuthClientMetadata(
+ client_name="Enterprise MCP Client",
+ redirect_uris=[AnyUrl("http://localhost:3000/callback")],
+ grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"],
+ response_types=["token"],
+ ),
+ storage=SimpleTokenStorage(),
+ idp_token_endpoint="https://your-idp.com/oauth2/v1/token", # Your IdP's token endpoint
+ token_exchange_params=token_exchange_params,
+ # Optional: IdP client credentials if your IdP requires client authentication for token exchange
+ # idp_client_id="your-idp-client-id",
+ # idp_client_secret="your-idp-client-secret",
+ )
+
+ # Step 5: Create authenticated HTTP client
+ # The auth provider automatically handles the two-step token exchange:
+ # 1. ID Token -> ID-JAG (via IDP)
+ # 2. ID-JAG -> Access Token (via MCP server)
+ client = httpx.AsyncClient(auth=enterprise_auth, timeout=30.0)
+
+ # Step 6: Connect to MCP server with authenticated client
+ async with streamable_http_client(url=server_url, http_client=client) as (read, write):
+ async with ClientSession(read, write) as session:
+ await session.initialize()
+
+ # List available tools
+ tools_result = await session.list_tools()
+ print(f"Available tools: {[t.name for t in tools_result.tools]}")
+
+ # Call a tool - auth tokens are automatically managed
+ if tools_result.tools:
+ tool_name = tools_result.tools[0].name
+ result = await session.call_tool(tool_name, {})
+ print(f"Tool result: {result.content}")
+
+ # List available resources
+ resources = await session.list_resources()
+ for resource in resources.resources:
+ print(f"Resource: {resource.uri}")
+
+
+async def advanced_manual_flow() -> None:
+ """Advanced example showing manual token exchange.
+
+ Use cases for manual token exchange:
+ - Testing and debugging: Inspect ID-JAG claims before exchanging for access token
+ - Token caching: Store and reuse ID-JAG across multiple MCP server connections
+ - Custom error handling: Implement specific retry logic for each token exchange step
+ - Monitoring: Log token exchange metrics and performance
+ - Token introspection: Validate ID-JAG structure before sending to MCP server
+ """
+ id_token = await get_id_token_from_idp()
+
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example.com",
+ mcp_server_resource_id="https://mcp-server.example.com",
+ )
+
+ enterprise_auth = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example.com",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:3000/callback")],
+ ),
+ storage=SimpleTokenStorage(),
+ idp_token_endpoint="https://your-idp.com/oauth2/v1/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Manual token exchange (useful for debugging, caching, custom error handling, etc.)
+ async with httpx.AsyncClient() as client:
+ # Step 1: Exchange ID token for ID-JAG
+ id_jag = await enterprise_auth.exchange_token_for_id_jag(client)
+ # WARNING: Only log tokens in development/testing environments
+ # In production, NEVER log tokens or token fragments as they are sensitive credentials
+ print(f"Obtained ID-JAG: {id_jag[:50]}...")
+
+ # Step 2: Build JWT bearer grant request
+ jwt_bearer_request = await enterprise_auth.exchange_id_jag_for_access_token(id_jag)
+ print(f"Built JWT bearer grant request to: {jwt_bearer_request.url}")
+
+ # Step 3: Execute the request to get access token
+ response = await client.send(jwt_bearer_request)
+ response.raise_for_status()
+ token_data = response.json()
+
+ access_token = OAuthToken(
+ access_token=token_data["access_token"],
+ token_type=token_data["token_type"],
+ expires_in=token_data.get("expires_in"),
+ )
+ # WARNING: In production, do not log token expiry or any token information
+ print(f"Access token obtained, expires in: {access_token.expires_in}s")
+
+ # Use the access token for API calls
+ _ = {"Authorization": f"Bearer {access_token.access_token}"}
+ # ... make authenticated requests with headers
+
+
+async def token_refresh_example() -> None:
+ """Example showing how to refresh tokens when they expire.
+
+ When your access token expires, you need to obtain a fresh ID token
+ from your enterprise IdP and use the refresh helper method.
+ """
+ # Initial setup
+ id_token = await get_id_token_from_idp()
+
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example.com",
+ mcp_server_resource_id="https://mcp-server.example.com",
+ )
+
+ enterprise_auth = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example.com",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:3000/callback")],
+ ),
+ storage=SimpleTokenStorage(),
+ idp_token_endpoint="https://your-idp.com/oauth2/v1/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ _ = httpx.AsyncClient(auth=enterprise_auth, timeout=30.0)
+
+ # Use the client for MCP operations...
+ # ... time passes and token expires ...
+
+ # When token expires, get a fresh ID token from your IdP
+ new_id_token = await get_id_token_from_idp()
+
+ # Refresh the authentication using the helper method
+ await enterprise_auth.refresh_with_new_id_token(new_id_token)
+
+ # Next API call will automatically use the refreshed tokens
+ # No need to recreate the client or session
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/src/mcp/client/auth/extensions/__init__.py b/src/mcp/client/auth/extensions/__init__.py
index e69de29bb..f9594864f 100644
--- a/src/mcp/client/auth/extensions/__init__.py
+++ b/src/mcp/client/auth/extensions/__init__.py
@@ -0,0 +1,33 @@
+"""MCP Client Auth Extensions."""
+
+from mcp.client.auth.extensions.client_credentials import (
+ ClientCredentialsOAuthProvider,
+ JWTParameters,
+ PrivateKeyJWTOAuthProvider,
+ RFC7523OAuthClientProvider,
+ SignedJWTParameters,
+ static_assertion_provider,
+)
+from mcp.client.auth.extensions.enterprise_managed_auth import (
+ EnterpriseAuthOAuthClientProvider,
+ IDJAGClaims,
+ IDJAGTokenExchangeResponse,
+ TokenExchangeParameters,
+ decode_id_jag,
+ validate_token_exchange_params,
+)
+
+__all__ = [
+ "ClientCredentialsOAuthProvider",
+ "static_assertion_provider",
+ "SignedJWTParameters",
+ "PrivateKeyJWTOAuthProvider",
+ "JWTParameters",
+ "RFC7523OAuthClientProvider",
+ "EnterpriseAuthOAuthClientProvider",
+ "IDJAGClaims",
+ "IDJAGTokenExchangeResponse",
+ "TokenExchangeParameters",
+ "decode_id_jag",
+ "validate_token_exchange_params",
+]
diff --git a/src/mcp/client/auth/extensions/enterprise_managed_auth.py b/src/mcp/client/auth/extensions/enterprise_managed_auth.py
new file mode 100644
index 000000000..947f13a9e
--- /dev/null
+++ b/src/mcp/client/auth/extensions/enterprise_managed_auth.py
@@ -0,0 +1,479 @@
+"""Enterprise Managed Authorization extension for MCP (SEP-990).
+
+Implements RFC 8693 Token Exchange and RFC 7523 JWT Bearer Grant for
+enterprise SSO integration.
+"""
+
+import logging
+from json import JSONDecodeError
+
+import httpx
+import jwt
+from pydantic import BaseModel, Field, ValidationError
+from typing_extensions import NotRequired, Required, TypedDict
+
+from mcp.client.auth import OAuthClientProvider, OAuthFlowError, OAuthTokenError, TokenStorage
+from mcp.shared.auth import OAuthClientMetadata
+
+logger = logging.getLogger(__name__)
+
+
+class TokenExchangeRequestData(TypedDict):
+ """Type definition for RFC 8693 Token Exchange request data.
+
+ Required fields are those mandated by RFC 8693.
+ Optional fields (NotRequired) may be included based on IdP requirements.
+ """
+
+ grant_type: Required[str]
+ requested_token_type: Required[str]
+ audience: Required[str]
+ resource: Required[str]
+ subject_token: Required[str]
+ subject_token_type: Required[str]
+ scope: NotRequired[str]
+ client_id: NotRequired[str]
+ client_secret: NotRequired[str]
+
+
+class TokenExchangeParameters(BaseModel):
+ """Parameters for RFC 8693 Token Exchange request."""
+
+ requested_token_type: str = Field(
+ default="urn:ietf:params:oauth:token-type:id-jag",
+ description="Type of token being requested (ID-JAG)",
+ )
+
+ audience: str = Field(
+ ...,
+ description="Issuer URL of the MCP Server's authorization server",
+ )
+
+ resource: str = Field(
+ ...,
+ description="RFC 9728 Resource Identifier of the MCP Server",
+ )
+
+ scope: str | None = Field(
+ default=None,
+ description="Space-separated list of scopes being requested",
+ )
+
+ subject_token: str = Field(
+ ...,
+ description="ID Token or SAML assertion for the end user",
+ )
+
+ subject_token_type: str = Field(
+ ...,
+ description="Type of subject token (id_token or saml2)",
+ )
+
+ @classmethod
+ def from_id_token(
+ cls,
+ id_token: str,
+ mcp_server_auth_issuer: str,
+ mcp_server_resource_id: str,
+ scope: str | None = None,
+ ) -> "TokenExchangeParameters":
+ """Create parameters for OIDC ID Token exchange."""
+ return cls(
+ subject_token=id_token,
+ subject_token_type="urn:ietf:params:oauth:token-type:id_token",
+ audience=mcp_server_auth_issuer,
+ resource=mcp_server_resource_id,
+ scope=scope,
+ )
+
+ @classmethod
+ def from_saml_assertion(
+ cls,
+ saml_assertion: str,
+ mcp_server_auth_issuer: str,
+ mcp_server_resource_id: str,
+ scope: str | None = None,
+ ) -> "TokenExchangeParameters":
+ """Create parameters for SAML assertion exchange."""
+ return cls(
+ subject_token=saml_assertion,
+ subject_token_type="urn:ietf:params:oauth:token-type:saml2",
+ audience=mcp_server_auth_issuer,
+ resource=mcp_server_resource_id,
+ scope=scope,
+ )
+
+
+class IDJAGTokenExchangeResponse(BaseModel):
+ """Response from RFC 8693 Token Exchange for ID-JAG."""
+
+ issued_token_type: str = Field(
+ ...,
+ description="Type of token issued (should be id-jag)",
+ )
+
+ access_token: str = Field(
+ ...,
+ description="The ID-JAG token (named access_token per RFC 8693)",
+ )
+
+ token_type: str = Field(
+ ...,
+ description="Token type (should be N_A for ID-JAG)",
+ )
+
+ scope: str | None = Field(
+ default=None,
+ description="Granted scopes",
+ )
+
+ expires_in: int | None = Field(
+ default=None,
+ description="Lifetime in seconds",
+ )
+
+ @property
+ def id_jag(self) -> str:
+ """Get the ID-JAG token."""
+ return self.access_token
+
+
+class IDJAGClaims(BaseModel):
+ """Claims structure for Identity Assertion JWT Authorization Grant.
+
+ Note: ``typ`` is sourced from the JWT *header* (not the payload) by
+ ``decode_id_jag``. It is included here for convenience so callers
+ can inspect the full ID-JAG structure from a single object.
+ """
+
+ model_config = {"extra": "allow"}
+
+ # JWT header
+ typ: str = Field(
+ ...,
+ description="JWT type - must be 'oauth-id-jag+jwt'",
+ )
+
+ # Required claims
+ jti: str = Field(..., description="Unique JWT ID")
+ iss: str = Field(..., description="IdP issuer URL")
+ sub: str = Field(..., description="Subject (user) identifier")
+ aud: str = Field(..., description="MCP Server's auth server issuer")
+ resource: str = Field(..., description="MCP Server resource identifier")
+ client_id: str = Field(..., description="MCP Client identifier")
+ exp: int = Field(..., description="Expiration timestamp")
+ iat: int = Field(..., description="Issued-at timestamp")
+
+ # Optional claims
+ scope: str | None = Field(None, description="Space-separated scopes")
+ email: str | None = Field(None, description="User email")
+
+
+class EnterpriseAuthOAuthClientProvider(OAuthClientProvider):
+ """OAuth client provider for Enterprise Managed Authorization (SEP-990).
+
+ Implements:
+ - RFC 8693: Token Exchange (ID Token → ID-JAG)
+ - RFC 7523: JWT Bearer Grant (ID-JAG → Access Token)
+
+ Concurrency & Thread Safety:
+ - SAFE: Concurrent requests within a single asyncio event loop. Token
+ operations are protected by the parent class's ``OAuthContext.lock``
+ via ``async_auth_flow``.
+ - UNSAFE: Sharing a provider instance across multiple OS threads. Each
+ thread must instantiate its own provider and event loop.
+ - Note: Ensure any shared ``TokenStorage`` implementation is async-safe.
+ """
+
+ def __init__(
+ self,
+ server_url: str,
+ client_metadata: OAuthClientMetadata,
+ storage: TokenStorage,
+ idp_token_endpoint: str,
+ token_exchange_params: TokenExchangeParameters,
+ timeout: float = 300.0,
+ idp_client_id: str | None = None,
+ idp_client_secret: str | None = None,
+ override_audience_with_issuer: bool = True,
+ ) -> None:
+ """Initialize Enterprise Auth OAuth Client.
+
+ Args:
+ server_url: MCP server URL
+ client_metadata: OAuth client metadata
+ storage: Token storage implementation
+ idp_token_endpoint: Enterprise IdP token endpoint URL
+ token_exchange_params: Token exchange parameters (not mutated)
+ timeout: Request timeout in seconds
+ idp_client_id: Optional client ID registered with the IdP for token exchange
+ idp_client_secret: Optional client secret registered with the IdP.
+ Must be accompanied by ``idp_client_id``; providing a secret
+ without an ID raises ``ValueError``.
+ override_audience_with_issuer: If True (default), replaces the IdP
+ audience with the discovered OAuth issuer URL. Set to False for
+ federated identity setups where the audience must differ.
+
+ Raises:
+ ValueError: If ``idp_client_secret`` is provided without ``idp_client_id``.
+ OAuthFlowError: If ``token_exchange_params`` fail validation.
+ """
+ # Validate pure parameters before creating any state (fail-fast)
+ if idp_client_secret is not None and idp_client_id is None:
+ raise ValueError(
+ "idp_client_secret was provided without idp_client_id. Provide both together, or omit the secret."
+ )
+ validate_token_exchange_params(token_exchange_params)
+
+ super().__init__(
+ server_url=server_url,
+ client_metadata=client_metadata,
+ storage=storage,
+ timeout=timeout,
+ )
+ self.idp_token_endpoint = idp_token_endpoint
+ # Keep original params immutable; track mutable subject_token separately
+ self.token_exchange_params = token_exchange_params
+ self._subject_token = token_exchange_params.subject_token
+ self.idp_client_id = idp_client_id
+ self.idp_client_secret = idp_client_secret
+ self.override_audience_with_issuer = override_audience_with_issuer
+
+ async def exchange_token_for_id_jag(
+ self,
+ client: httpx.AsyncClient,
+ ) -> str:
+ """Exchange ID Token for ID-JAG using RFC 8693 Token Exchange.
+
+ Args:
+ client: HTTP client for making requests
+
+ Returns:
+ The ID-JAG token string
+
+ Raises:
+ OAuthTokenError: If token exchange fails
+ """
+ logger.debug("Starting token exchange for ID-JAG")
+
+ audience = self.token_exchange_params.audience
+ if self.override_audience_with_issuer:
+ # OAuthMetadata.issuer is a required AnyHttpUrl field (RFC 8414),
+ # so it is always non-None when oauth_metadata is present.
+ if self.context.oauth_metadata:
+ discovered_issuer = str(self.context.oauth_metadata.issuer)
+ if audience != discovered_issuer:
+ logger.warning(
+ f"Overriding audience '{audience}' with discovered issuer "
+ f"'{discovered_issuer}'. To prevent this, pass "
+ f"override_audience_with_issuer=False."
+ )
+ audience = discovered_issuer
+
+ token_data: TokenExchangeRequestData = {
+ "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
+ "requested_token_type": self.token_exchange_params.requested_token_type,
+ "audience": audience,
+ "resource": self.token_exchange_params.resource,
+ "subject_token": self._subject_token,
+ "subject_token_type": self.token_exchange_params.subject_token_type,
+ }
+
+ if self.token_exchange_params.scope and self.token_exchange_params.scope.strip():
+ token_data["scope"] = self.token_exchange_params.scope
+
+ # Add IdP client authentication if provided.
+ # Sent as POST body parameters (not HTTP Basic) because this is the
+ # IdP's token-exchange endpoint — most enterprise IdPs (Okta, Azure AD,
+ # Ping) accept body credentials for token exchange. HTTP Basic is
+ # allowed by RFC 6749 §2.3.1 but not universally required here.
+ if self.idp_client_id is not None:
+ token_data["client_id"] = self.idp_client_id
+ if self.idp_client_secret is not None:
+ token_data["client_secret"] = self.idp_client_secret
+
+ try:
+ response = await client.post(
+ self.idp_token_endpoint,
+ data=token_data,
+ timeout=self.context.timeout,
+ )
+
+ if response.status_code != 200:
+ error_data: dict[str, str] = {}
+ try:
+ if response.headers.get("content-type", "").startswith("application/json"):
+ error_data = response.json()
+ except JSONDecodeError:
+ pass
+
+ error: str = error_data.get("error", "unknown_error")
+ error_description: str = error_data.get(
+ "error_description", f"Token exchange failed (HTTP {response.status_code})"
+ )
+ raise OAuthTokenError(f"Token exchange failed: {error} - {error_description}")
+
+ token_response = IDJAGTokenExchangeResponse.model_validate_json(response.content)
+
+ if token_response.issued_token_type != "urn:ietf:params:oauth:token-type:id-jag":
+ raise OAuthTokenError(f"Unexpected token type: {token_response.issued_token_type}")
+
+ if token_response.token_type != "N_A":
+ logger.warning(f"Expected token_type 'N_A', got '{token_response.token_type}'")
+
+ logger.debug("Successfully obtained ID-JAG")
+
+ return token_response.id_jag
+
+ except httpx.HTTPError as e:
+ raise OAuthTokenError(f"HTTP error during token exchange: {e}") from e
+ except ValidationError as e:
+ raise OAuthTokenError("Invalid token exchange response from IdP") from e
+
+ async def exchange_id_jag_for_access_token(
+ self,
+ id_jag: str,
+ ) -> httpx.Request:
+ """Build a JWT bearer grant request to exchange an ID-JAG for an access token (RFC 7523).
+
+ This method only *builds* the ``httpx.Request``; it does not execute
+ it. HTTP execution and error parsing are deferred to the parent
+ class's ``async_auth_flow`` via ``_handle_token_response``.
+
+ Follows the same pattern as ``ClientCredentialsOAuthProvider._exchange_token_client_credentials``
+ and ``RFC7523OAuthClientProvider._exchange_token_jwt_bearer``:
+ use ``_get_token_endpoint()`` for the URL and ``prepare_token_auth()``
+ for client authentication — no manual ``client_id`` injection or
+ context swapping needed.
+
+ Args:
+ id_jag: The ID-JAG token obtained from ``exchange_token_for_id_jag``
+
+ Returns:
+ An ``httpx.Request`` for the JWT bearer grant
+ """
+ logger.debug("Building JWT bearer grant request for ID-JAG")
+
+ token_data: dict[str, str] = {
+ "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
+ "assertion": id_jag,
+ }
+
+ headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"}
+
+ # Delegate client authentication (client_secret_basic, client_secret_post,
+ # or none) to the parent's context helper — same as every other grant type.
+ token_data, headers = self.context.prepare_token_auth(token_data, headers)
+
+ # Include resource parameter per RFC 8707 — same guard as every sibling provider
+ if self.context.should_include_resource_param(self.context.protocol_version):
+ token_data["resource"] = self.context.get_resource_url()
+
+ # Include scope if configured (may have been updated by parent's async_auth_flow
+ # from the server's WWW-Authenticate header before _perform_authorization is called)
+ if self.context.client_metadata.scope:
+ token_data["scope"] = self.context.client_metadata.scope
+
+ token_url = self._get_token_endpoint()
+ return httpx.Request("POST", token_url, data=token_data, headers=headers)
+
+ async def _perform_authorization(self) -> httpx.Request:
+ """Perform enterprise authorization flow.
+
+ Called by the parent's ``async_auth_flow`` when a new access token is needed.
+ Unconditionally performs full token exchange as the parent already handles
+ token validity checks.
+
+ Flow:
+ 1. Exchange IdP subject token for ID-JAG (RFC 8693, direct HTTP call)
+ 2. Return an ``httpx.Request`` for the JWT bearer grant (RFC 7523)
+ that the parent will execute and pass to ``_handle_token_response``
+
+ Returns:
+ httpx.Request for the JWT bearer grant to the MCP authorization server
+ """
+ # Step 1: Exchange IDP subject token for ID-JAG (RFC 8693)
+ async with httpx.AsyncClient(timeout=self.context.timeout) as client:
+ id_jag = await self.exchange_token_for_id_jag(client)
+
+ # Step 2: Build JWT bearer grant request (RFC 7523)
+ jwt_bearer_request = await self.exchange_id_jag_for_access_token(id_jag)
+
+ logger.debug("Returning JWT bearer grant request to async_auth_flow")
+ return jwt_bearer_request
+
+ async def refresh_with_new_id_token(self, new_id_token: str) -> None:
+ """Refresh MCP server access tokens using a fresh ID token from the IdP.
+
+ Updates the subject token and clears cached state so that the next API
+ request triggers a full re-authentication. Acquires the context lock
+ to prevent racing with an in-progress ``async_auth_flow``.
+
+ Note: OAuth metadata is not re-discovered. If the MCP server's OAuth
+ configuration has changed, create a new provider instance instead.
+
+ Warning: This method is NOT safe to call from a different OS thread.
+ Call it only from the same thread and event loop that owns this
+ provider instance.
+
+ Args:
+ new_id_token: Fresh ID token obtained from your enterprise IdP.
+ """
+ async with self.context.lock:
+ logger.info("Refreshing tokens with new ID token from IdP")
+ # Update the mutable subject token (does NOT mutate the original params object)
+ self._subject_token = new_id_token
+
+ # Clear tokens to force full re-exchange on next request
+ self.context.clear_tokens()
+ logger.debug("Token refresh prepared — will re-authenticate on next request")
+
+
+def decode_id_jag(id_jag: str) -> IDJAGClaims:
+ """Decode an ID-JAG token without verification.
+
+ Relies on the receiving server to validate the JWT signature.
+
+ Args:
+ id_jag: The ID-JAG token string
+
+ Returns:
+ Decoded ID-JAG claims
+ """
+ claims = jwt.decode(id_jag, options={"verify_signature": False})
+ header = jwt.get_unverified_header(id_jag)
+ claims["typ"] = header.get("typ", "")
+
+ return IDJAGClaims.model_validate(claims)
+
+
+def validate_token_exchange_params(
+ params: TokenExchangeParameters,
+) -> None:
+ """Validate token exchange parameters beyond Pydantic field constraints.
+
+ Pydantic ``Field(...)`` rejects *missing* values but permits empty strings.
+ This function adds:
+ - Empty-string checks for ``subject_token``, ``audience``, ``resource``
+ - Allow-list check for ``subject_token_type`` (id_token or saml2)
+
+ Args:
+ params: Token exchange parameters to validate
+
+ Raises:
+ OAuthFlowError: If parameters are invalid
+ """
+ if not params.subject_token:
+ raise OAuthFlowError("subject_token is required")
+
+ if not params.audience:
+ raise OAuthFlowError("audience is required")
+
+ if not params.resource:
+ raise OAuthFlowError("resource is required")
+
+ if params.subject_token_type not in {
+ "urn:ietf:params:oauth:token-type:id_token",
+ "urn:ietf:params:oauth:token-type:saml2",
+ }:
+ raise OAuthFlowError(f"Invalid subject_token_type: {params.subject_token_type}")
diff --git a/tests/client/auth/test_enterprise_managed_auth_client.py b/tests/client/auth/test_enterprise_managed_auth_client.py
new file mode 100644
index 000000000..e8b8e4b15
--- /dev/null
+++ b/tests/client/auth/test_enterprise_managed_auth_client.py
@@ -0,0 +1,1792 @@
+"""Tests for Enterprise Managed Authorization client-side implementation."""
+
+import logging
+import time
+import urllib.parse
+from typing import Any
+from unittest.mock import AsyncMock, Mock, patch
+
+import httpx
+import jwt
+import pytest
+from pydantic import AnyHttpUrl, AnyUrl
+
+from mcp.client.auth import OAuthFlowError, OAuthTokenError
+from mcp.client.auth.extensions.enterprise_managed_auth import (
+ EnterpriseAuthOAuthClientProvider,
+ IDJAGClaims,
+ IDJAGTokenExchangeResponse,
+ TokenExchangeParameters,
+ decode_id_jag,
+ validate_token_exchange_params,
+)
+from mcp.shared.auth import (
+ OAuthClientInformationFull,
+ OAuthClientMetadata,
+ OAuthMetadata,
+ OAuthToken,
+ ProtectedResourceMetadata,
+)
+
+
+@pytest.fixture
+def sample_id_token() -> str:
+ """Generate a sample ID token for testing."""
+ payload = {
+ "iss": "https://idp.example.com",
+ "sub": "user123",
+ "aud": "mcp-client-app",
+ "exp": int(time.time()) + 3600,
+ "iat": int(time.time()),
+ "email": "user@example.com",
+ }
+ return jwt.encode(payload, "secret", algorithm="HS256")
+
+
+@pytest.fixture
+def sample_id_jag() -> str:
+ """Generate a sample ID-JAG token for testing."""
+ # Create typed claims using IDJAGClaims model
+ claims = IDJAGClaims(
+ typ="oauth-id-jag+jwt",
+ jti="unique-jwt-id-12345",
+ iss="https://idp.example.com",
+ sub="user123",
+ aud="https://auth.mcp-server.example/",
+ resource="https://mcp-server.example/",
+ client_id="mcp-client-app",
+ exp=int(time.time()) + 300,
+ iat=int(time.time()),
+ scope="read write",
+ email=None, # Optional field
+ )
+
+ # Dump to dict for JWT encoding (exclude typ as it goes in header)
+ payload = claims.model_dump(exclude={"typ"}, exclude_none=True)
+
+ return jwt.encode(payload, "secret", algorithm="HS256", headers={"typ": "oauth-id-jag+jwt"})
+
+
+@pytest.fixture
+def mock_token_storage() -> Any:
+ """Create a mock token storage."""
+ storage = Mock()
+ storage.get_tokens = AsyncMock(return_value=None)
+ storage.set_tokens = AsyncMock()
+ storage.get_client_info = AsyncMock(return_value=None)
+ storage.set_client_info = AsyncMock()
+ return storage
+
+
+def test_token_exchange_params_from_id_token():
+ """Test creating TokenExchangeParameters from ID token."""
+ params = TokenExchangeParameters.from_id_token(
+ id_token="eyJhbGc...",
+ mcp_server_auth_issuer="https://auth.server.example/",
+ mcp_server_resource_id="https://server.example/",
+ scope="read write",
+ )
+
+ assert params.subject_token == "eyJhbGc..."
+ assert params.subject_token_type == "urn:ietf:params:oauth:token-type:id_token"
+ assert params.audience == "https://auth.server.example/"
+ assert params.resource == "https://server.example/"
+ assert params.scope == "read write"
+ assert params.requested_token_type == "urn:ietf:params:oauth:token-type:id-jag"
+
+
+def test_token_exchange_params_from_saml_assertion():
+ """Test creating TokenExchangeParameters from SAML assertion."""
+ params = TokenExchangeParameters.from_saml_assertion(
+ saml_assertion="...",
+ mcp_server_auth_issuer="https://auth.server.example/",
+ mcp_server_resource_id="https://server.example/",
+ scope="read",
+ )
+
+ assert params.subject_token == "..."
+ assert params.subject_token_type == "urn:ietf:params:oauth:token-type:saml2"
+ assert params.audience == "https://auth.server.example/"
+ assert params.resource == "https://server.example/"
+ assert params.scope == "read"
+
+
+def test_validate_token_exchange_params_valid():
+ """Test validating valid token exchange parameters."""
+ params = TokenExchangeParameters.from_id_token(
+ id_token="token",
+ mcp_server_auth_issuer="https://auth.example/",
+ mcp_server_resource_id="https://server.example/",
+ )
+
+ # Should not raise
+ validate_token_exchange_params(params)
+
+
+def test_validate_token_exchange_params_invalid_token_type():
+ """Test validation fails for invalid subject token type."""
+ params = TokenExchangeParameters(
+ subject_token="token",
+ subject_token_type="invalid:type",
+ audience="https://auth.example/",
+ resource="https://server.example/",
+ )
+
+ with pytest.raises(OAuthFlowError, match="Invalid subject_token_type"):
+ validate_token_exchange_params(params)
+
+
+def test_validate_token_exchange_params_missing_subject_token():
+ """Test validation fails for missing subject token."""
+ params = TokenExchangeParameters(
+ subject_token="",
+ subject_token_type="urn:ietf:params:oauth:token-type:id_token",
+ audience="https://auth.example/",
+ resource="https://server.example/",
+ )
+
+ with pytest.raises(OAuthFlowError, match="subject_token is required"):
+ validate_token_exchange_params(params)
+
+
+def test_token_exchange_response_parsing():
+ """Test parsing token exchange response."""
+ response_json = """{
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": "eyJhbGc...",
+ "token_type": "N_A",
+ "scope": "read write",
+ "expires_in": 300
+ }"""
+
+ response = IDJAGTokenExchangeResponse.model_validate_json(response_json)
+
+ assert response.issued_token_type == "urn:ietf:params:oauth:token-type:id-jag"
+ assert response.id_jag == "eyJhbGc..."
+ assert response.access_token == "eyJhbGc..."
+ assert response.token_type == "N_A"
+ assert response.scope == "read write"
+ assert response.expires_in == 300
+
+
+def test_token_exchange_response_id_jag_property():
+ """Test id_jag property returns access_token."""
+ response = IDJAGTokenExchangeResponse(
+ issued_token_type="urn:ietf:params:oauth:token-type:id-jag",
+ access_token="the-id-jag-token",
+ token_type="N_A",
+ )
+
+ assert response.id_jag == "the-id-jag-token"
+
+
+def test_decode_id_jag(sample_id_jag: str):
+ """Test decoding ID-JAG token."""
+ claims = decode_id_jag(sample_id_jag)
+
+ assert claims.iss == "https://idp.example.com"
+ assert claims.sub == "user123"
+ assert claims.aud == "https://auth.mcp-server.example/"
+ assert claims.resource == "https://mcp-server.example/"
+ assert claims.client_id == "mcp-client-app"
+ assert claims.scope == "read write"
+
+
+def test_decode_id_jag_invalid_jwt():
+ """Test decoding malformed ID-JAG raises appropriate error."""
+ with pytest.raises(jwt.DecodeError):
+ decode_id_jag("not.a.valid.jwt")
+
+
+def test_decode_id_jag_incomplete_jwt():
+ """Test decoding incomplete JWT raises error."""
+ with pytest.raises(jwt.DecodeError):
+ decode_id_jag("only.two.parts")
+
+
+def test_id_jag_claims_with_extra_fields():
+ """Test IDJAGClaims allows extra fields."""
+ claims_data = {
+ "typ": "oauth-id-jag+jwt",
+ "jti": "jti123",
+ "iss": "https://idp.example.com",
+ "sub": "user123",
+ "aud": "https://auth.server.example/",
+ "resource": "https://server.example/",
+ "client_id": "client123",
+ "exp": int(time.time()) + 300,
+ "iat": int(time.time()),
+ "scope": "read",
+ "email": "user@example.com",
+ "custom_claim": "custom_value", # Extra field
+ }
+
+ claims = IDJAGClaims.model_validate(claims_data)
+ assert claims.email == "user@example.com"
+ # Extra field should be preserved
+ assert claims.model_extra is not None and claims.model_extra.get("custom_claim") == "custom_value"
+
+
+# ============================================================================
+# Tests for EnterpriseAuthOAuthClientProvider
+# ============================================================================
+
+
+@pytest.mark.anyio
+async def test_exchange_token_for_id_jag_success(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any):
+ """Test successful token exchange for ID-JAG."""
+ # Create provider
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ scope="read write",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ client_name="Test Client",
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Mock HTTP response
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "N_A",
+ "scope": "read write",
+ "expires_in": 300,
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Perform token exchange
+ id_jag = await provider.exchange_token_for_id_jag(mock_client)
+
+ # Verify
+ assert id_jag == sample_id_jag
+
+ # Verify request was made correctly
+ mock_client.post.assert_called_once()
+ call_args = mock_client.post.call_args
+ assert call_args[0][0] == "https://idp.example.com/oauth2/token"
+ assert call_args[1]["data"]["grant_type"] == "urn:ietf:params:oauth:grant-type:token-exchange"
+ assert call_args[1]["data"]["requested_token_type"] == "urn:ietf:params:oauth:token-type:id-jag"
+ assert call_args[1]["data"]["audience"] == "https://auth.mcp-server.example/"
+ assert call_args[1]["data"]["resource"] == "https://mcp-server.example/"
+
+
+@pytest.mark.anyio
+async def test_exchange_token_for_id_jag_error(sample_id_token: str, mock_token_storage: Any):
+ """Test token exchange failure handling."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Mock error response
+ mock_response = httpx.Response(
+ status_code=400,
+ json={
+ "error": "invalid_request",
+ "error_description": "Invalid subject token",
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Should raise OAuthTokenError
+ with pytest.raises(OAuthTokenError, match="Token exchange failed"):
+ await provider.exchange_token_for_id_jag(mock_client)
+
+
+@pytest.mark.anyio
+async def test_exchange_token_for_id_jag_unexpected_token_type(sample_id_token: str, mock_token_storage: Any):
+ """Test token exchange with unexpected token type."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Mock response with wrong token type
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
+ "access_token": "some-token",
+ "token_type": "Bearer",
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Should raise OAuthTokenError
+ with pytest.raises(OAuthTokenError, match="Unexpected token type"):
+ await provider.exchange_token_for_id_jag(mock_client)
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_for_access_token_success(sample_id_jag: str, mock_token_storage: Any):
+ """Test building JWT bearer grant request."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set up OAuth metadata
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Build JWT bearer grant request
+ request = await provider.exchange_id_jag_for_access_token(sample_id_jag)
+
+ # Verify the request was built correctly
+ assert isinstance(request, httpx.Request)
+ assert request.method == "POST"
+ assert str(request.url) == "https://auth.mcp-server.example/oauth2/token"
+
+ # Parse the request body
+ body_params = urllib.parse.parse_qs(request.content.decode())
+ assert body_params["grant_type"][0] == "urn:ietf:params:oauth:grant-type:jwt-bearer"
+ assert body_params["assertion"][0] == sample_id_jag
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_for_access_token_no_metadata(sample_id_jag: str, mock_token_storage: Any):
+ """Test JWT bearer grant falls back to constructed token URL without OAuth metadata."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # No OAuth metadata set — _get_token_endpoint() falls back to server URL base
+
+ request = await provider.exchange_id_jag_for_access_token(sample_id_jag)
+
+ # Should fall back to constructing token URL from server URL (same as parent class)
+ assert isinstance(request, httpx.Request)
+ assert request.method == "POST"
+ assert str(request.url) == "https://mcp-server.example/token"
+
+ body_params = urllib.parse.parse_qs(request.content.decode())
+ assert body_params["grant_type"][0] == "urn:ietf:params:oauth:grant-type:jwt-bearer"
+ assert body_params["assertion"][0] == sample_id_jag
+
+
+@pytest.mark.anyio
+async def test_perform_authorization_full_flow(mock_token_storage: Any, sample_id_jag: str):
+ """Test that _perform_authorization performs token exchange and builds JWT bearer request."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set up OAuth metadata
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Mock the IDP token exchange response
+ with patch("httpx.AsyncClient") as mock_client_class:
+ mock_client = AsyncMock()
+ mock_client_class.return_value.__aenter__.return_value = mock_client
+
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "N_A",
+ },
+ )
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Perform authorization
+ request = await provider._perform_authorization()
+
+ # Verify it returns an httpx.Request for JWT bearer grant
+ assert isinstance(request, httpx.Request)
+ assert request.method == "POST"
+ assert str(request.url) == "https://auth.mcp-server.example/oauth2/token"
+
+ # Verify the request contains JWT bearer grant
+ body_params = urllib.parse.parse_qs(request.content.decode())
+ assert body_params["grant_type"][0] == "urn:ietf:params:oauth:grant-type:jwt-bearer"
+ assert body_params["assertion"][0] == sample_id_jag
+
+
+@pytest.mark.anyio
+async def test_perform_authorization_with_valid_tokens(mock_token_storage: Any, sample_id_jag: str):
+ """Test that _perform_authorization always performs full exchange.
+
+ _perform_authorization is only called by the parent when tokens need to be
+ obtained or refreshed, so it unconditionally performs the full exchange flow.
+ """
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set up OAuth metadata
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Mock the IDP token exchange response
+ with patch("httpx.AsyncClient") as mock_client_class:
+ mock_client = AsyncMock()
+ mock_client_class.return_value.__aenter__.return_value = mock_client
+
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "N_A",
+ },
+ )
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Should always perform full flow (no is_token_valid shortcircuit)
+ request = await provider._perform_authorization()
+
+ assert isinstance(request, httpx.Request)
+ assert request.method == "POST"
+ assert str(request.url) == "https://auth.mcp-server.example/oauth2/token"
+
+ # Verify it uses the exchanged ID-JAG
+ body_params = urllib.parse.parse_qs(request.content.decode())
+ assert body_params["assertion"][0] == sample_id_jag
+
+
+@pytest.mark.anyio
+async def test_exchange_token_with_client_authentication(
+ sample_id_token: str, sample_id_jag: str, mock_token_storage: Any
+):
+ """Test token exchange with client authentication."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ scope="read write",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ client_name="Test Client",
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ idp_client_id="test-idp-client-id", # IdP client ID, not MCP client ID
+ idp_client_secret="test-idp-client-secret", # IdP client secret
+ )
+
+ # Mock HTTP response
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "N_A",
+ "scope": "read write",
+ "expires_in": 300,
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Perform token exchange
+ id_jag = await provider.exchange_token_for_id_jag(mock_client)
+
+ # Verify the ID-JAG was returned
+ assert id_jag == sample_id_jag
+
+ # Verify client credentials were included
+ call_args = mock_client.post.call_args
+ assert call_args[1]["data"]["client_id"] == "test-idp-client-id"
+ assert call_args[1]["data"]["client_secret"] == "test-idp-client-secret"
+
+
+@pytest.mark.anyio
+async def test_exchange_token_with_client_id_only(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any):
+ """Test token exchange with client_id but no client_secret (covers branch 232->235)."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ scope="read write",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ client_name="Test Client",
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ idp_client_id="test-idp-client-id", # IdP client ID, not MCP client ID
+ idp_client_secret=None, # No secret
+ )
+
+ # Mock HTTP response
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "N_A",
+ "scope": "read write",
+ "expires_in": 300,
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Perform token exchange
+ id_jag = await provider.exchange_token_for_id_jag(mock_client)
+
+ # Verify the ID-JAG was returned
+ assert id_jag == sample_id_jag
+
+ # Verify client_id was included but NOT client_secret
+ call_args = mock_client.post.call_args
+ assert call_args[1]["data"]["client_id"] == "test-idp-client-id"
+ assert "client_secret" not in call_args[1]["data"]
+
+
+@pytest.mark.anyio
+async def test_exchange_token_http_error(sample_id_token: str, mock_token_storage: Any):
+ """Test token exchange with HTTP error."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(side_effect=httpx.ConnectError("Connection failed"))
+
+ # Should raise OAuthTokenError
+ with pytest.raises(OAuthTokenError, match="HTTP error during token exchange"):
+ await provider.exchange_token_for_id_jag(mock_client)
+
+
+@pytest.mark.anyio
+async def test_exchange_token_malformed_json_error_response(sample_id_token: str, mock_token_storage: Any):
+ """Test token exchange with malformed JSON error response that raises JSONDecodeError."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:3000/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Mock error response with malformed JSON (will raise JSONDecodeError when parsing)
+ mock_response = httpx.Response(
+ status_code=400,
+ content=b'{"error": "invalid_request", "invalid json structure', # Malformed JSON
+ headers={"content-type": "application/json"},
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Should raise OAuthTokenError with default error message including status code
+ with pytest.raises(OAuthTokenError, match=r"Token exchange failed.*HTTP 400"):
+ await provider.exchange_token_for_id_jag(mock_client)
+
+
+@pytest.mark.anyio
+async def test_exchange_token_non_json_error_response(sample_id_token: str, mock_token_storage: Any):
+ """Test token exchange with non-JSON error response."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Mock error response with non-JSON content
+ mock_response = httpx.Response(
+ status_code=500,
+ content=b"Internal Server Error",
+ headers={"content-type": "text/plain"},
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Should raise OAuthTokenError with default error
+ with pytest.raises(OAuthTokenError, match="Token exchange failed: unknown_error"):
+ await provider.exchange_token_for_id_jag(mock_client)
+
+
+@pytest.mark.anyio
+async def test_exchange_token_invalid_response_schema(sample_id_token: str, mock_token_storage: Any):
+ """Test token exchange with HTTP 200 but response missing required fields.
+
+ Regression test: ValidationError from model_validate_json must be caught
+ and wrapped in OAuthTokenError, not propagated raw.
+ """
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # HTTP 200 but missing required fields (e.g., access_token)
+ mock_response = httpx.Response(
+ status_code=200,
+ json={"unexpected_field": "value"},
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ with pytest.raises(OAuthTokenError, match="Invalid token exchange response from IdP"):
+ await provider.exchange_token_for_id_jag(mock_client)
+
+
+@pytest.mark.anyio
+async def test_exchange_token_warning_for_non_na_token_type(
+ sample_id_token: str, sample_id_jag: str, mock_token_storage: Any
+):
+ """Test token exchange logs warning for non-N_A token type."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Mock response with different token_type
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "Bearer", # Not N_A
+ "scope": "read write",
+ "expires_in": 300,
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Should succeed but log warning
+ with patch.object(
+ logging.getLogger("mcp.client.auth.extensions.enterprise_managed_auth"), "warning"
+ ) as mock_warning:
+ id_jag = await provider.exchange_token_for_id_jag(mock_client)
+ assert id_jag == sample_id_jag
+ mock_warning.assert_called_once()
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_with_client_authentication(sample_id_jag: str, mock_token_storage: Any):
+ """Test JWT bearer grant request building with client authentication."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set client info with secret (auth method set by registration, as in real flow)
+ provider.context.client_info = OAuthClientInformationFull(
+ client_id="test-client-id",
+ client_secret="test-client-secret",
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ token_endpoint_auth_method="client_secret_basic",
+ )
+
+ # Set up OAuth metadata
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Build JWT bearer grant request
+ request = await provider.exchange_id_jag_for_access_token(sample_id_jag)
+
+ # Verify request was built correctly
+ assert isinstance(request, httpx.Request)
+ assert request.method == "POST"
+
+ # Verify client credentials were handled via client_secret_basic
+ body_params = urllib.parse.parse_qs(request.content.decode())
+ # With client_secret_basic, credentials are in Authorization header, not body
+ assert "client_secret" not in body_params
+ assert "Authorization" in request.headers
+ assert request.headers["Authorization"].startswith("Basic ")
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_with_client_id_only(sample_id_jag: str, mock_token_storage: Any):
+ """Test JWT bearer grant request building with client_id but no client_secret."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set client info WITHOUT secret (client_secret=None), auth method "none" as set by registration
+ provider.context.client_info = OAuthClientInformationFull(
+ client_id="test-client-id",
+ client_secret=None, # No secret
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ token_endpoint_auth_method="none",
+ )
+
+ # Set up OAuth metadata
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Build JWT bearer grant request
+ request = await provider.exchange_id_jag_for_access_token(sample_id_jag)
+
+ # Verify request was built correctly
+ assert isinstance(request, httpx.Request)
+
+ # With auth method "none", no credentials are added
+ body_params = urllib.parse.parse_qs(request.content.decode())
+ assert "client_secret" not in body_params
+ # With auth_method == "none", no Basic Authorization header is sent.
+ assert "Authorization" not in request.headers
+
+
+@pytest.mark.anyio
+async def test_exchange_token_with_client_info_but_no_client_id(
+ sample_id_token: str, sample_id_jag: str, mock_token_storage: Any
+):
+ """Test that providing idp_client_secret without idp_client_id raises ValueError."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ scope="read write",
+ )
+
+ with pytest.raises(ValueError, match="idp_client_secret was provided without idp_client_id"):
+ EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ client_name="Test Client",
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ idp_client_id=None, # No client ID
+ idp_client_secret="test-idp-secret", # But has secret
+ )
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_with_client_info_but_no_client_id(sample_id_jag: str, mock_token_storage: Any):
+ """Test ID-JAG exchange request building when client_info exists but client_id is None."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set up OAuth metadata
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Set client info with client_id=None
+ provider.context.client_info = OAuthClientInformationFull(
+ client_id=None, # This should skip the client_id assignment
+ client_secret="test-secret",
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ )
+
+ # Build JWT bearer grant request
+ request = await provider.exchange_id_jag_for_access_token(sample_id_jag)
+
+ # Verify request was built correctly
+ assert isinstance(request, httpx.Request)
+
+ # Verify client_id was not included (None), but client_secret should be handled
+ body_params = urllib.parse.parse_qs(request.content.decode())
+ assert "client_id" not in body_params or body_params["client_id"][0] == ""
+
+
+def test_validate_token_exchange_params_missing_audience():
+ """Test validation fails for missing audience."""
+ params = TokenExchangeParameters(
+ subject_token="token",
+ subject_token_type="urn:ietf:params:oauth:token-type:id_token",
+ audience="",
+ resource="https://server.example/",
+ )
+
+ with pytest.raises(OAuthFlowError, match="audience is required"):
+ validate_token_exchange_params(params)
+
+
+def test_validate_token_exchange_params_missing_resource():
+ """Test validation fails for missing resource."""
+ params = TokenExchangeParameters(
+ subject_token="token",
+ subject_token_type="urn:ietf:params:oauth:token-type:id_token",
+ audience="https://auth.example/",
+ resource="",
+ )
+
+ with pytest.raises(OAuthFlowError, match="resource is required"):
+ validate_token_exchange_params(params)
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_with_existing_auth_method(sample_id_jag: str, mock_token_storage: Any):
+ """Test JWT bearer grant when token_endpoint_auth_method is already set (covers branch 323->326)."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set client info WITH auth method already set
+ provider.context.client_info = OAuthClientInformationFull(
+ client_id="test-client-id",
+ client_secret="test-client-secret",
+ token_endpoint_auth_method="client_secret_post", # Already set
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ )
+
+ # Set up OAuth metadata
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Build JWT bearer grant request
+ request = await provider.exchange_id_jag_for_access_token(sample_id_jag)
+
+ # Verify request was built correctly
+ assert isinstance(request, httpx.Request)
+
+ # Verify it used client_secret_post (in body, not header)
+ body_params = urllib.parse.parse_qs(request.content.decode())
+ assert body_params["client_id"][0] == "test-client-id"
+ assert body_params["client_secret"][0] == "test-client-secret"
+ # Should NOT have Authorization header for client_secret_post
+ assert "Authorization" not in request.headers or not request.headers["Authorization"].startswith("Basic ")
+
+
+@pytest.mark.anyio
+async def test_perform_authorization_with_valid_tokens_no_id_jag(mock_token_storage: Any):
+ """Test _perform_authorization when tokens are valid but no cached ID-JAG (covers branch 354->360)."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set up OAuth metadata
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Set valid tokens
+ provider.context.current_tokens = OAuthToken(
+ token_type="Bearer",
+ access_token="valid-token",
+ expires_in=3600,
+ )
+ provider.context.token_expiry_time = time.time() + 3600
+
+ # Mock the IDP token exchange response
+ sample_id_jag = "test-id-jag-token"
+ with patch("httpx.AsyncClient") as mock_client_class:
+ mock_client = AsyncMock()
+ mock_client_class.return_value.__aenter__.return_value = mock_client
+
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "N_A",
+ },
+ )
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Should fall through and perform full flow
+ request = await provider._perform_authorization()
+
+ # Verify it returns a JWT bearer grant request
+ assert isinstance(request, httpx.Request)
+ assert request.method == "POST"
+
+ # Verify it made the IDP token exchange call
+ mock_client.post.assert_called_once()
+
+
+@pytest.mark.anyio
+async def test_refresh_with_new_id_token(mock_token_storage: Any):
+ """Test refresh_with_new_id_token helper method."""
+ old_id_token = "old-id-token"
+ new_id_token = "new-id-token"
+
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=old_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:3000/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set some existing tokens
+ provider.context.current_tokens = OAuthToken(
+ token_type="Bearer",
+ access_token="old-access-token",
+ expires_in=3600,
+ )
+
+ # Verify initial state
+ assert provider._subject_token == old_id_token
+ assert provider.context.current_tokens.access_token == "old-access-token"
+
+ # Call refresh with new ID token
+ await provider.refresh_with_new_id_token(new_id_token)
+
+ # Verify state after refresh
+ assert provider._subject_token == new_id_token
+ # Original token_exchange_params.subject_token must NOT be mutated (#7)
+ assert provider.token_exchange_params.subject_token == old_id_token
+ assert provider.context.current_tokens is None # Tokens should be cleared
+ assert provider.context.token_expiry_time is None # Expiry should be cleared
+
+
+@pytest.mark.anyio
+async def test_perform_authorization_always_exchanges_even_with_tokens(mock_token_storage: Any, sample_id_jag: str):
+ """Test that _perform_authorization always performs full exchange even with existing tokens."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set up OAuth metadata
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Set valid tokens
+ provider.context.current_tokens = OAuthToken(
+ token_type="Bearer",
+ access_token="valid-token",
+ expires_in=3600,
+ )
+ provider.context.token_expiry_time = time.time() + 3600
+
+ # Mock the IDP token exchange response for new ID-JAG
+ new_id_jag = "new-id-jag-token"
+ with patch("httpx.AsyncClient") as mock_client_class:
+ mock_client = AsyncMock()
+ mock_client_class.return_value.__aenter__.return_value = mock_client
+
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": new_id_jag,
+ "token_type": "N_A",
+ },
+ )
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Should get a new ID-JAG since the cached one is expired
+ request = await provider._perform_authorization()
+
+ # Verify it made the IDP token exchange call (didn't reuse expired ID-JAG)
+ mock_client.post.assert_called_once()
+
+ # Verify the request uses the NEW ID-JAG
+ body_params = urllib.parse.parse_qs(request.content.decode())
+ assert body_params["assertion"][0] == new_id_jag
+
+
+@pytest.mark.anyio
+async def test_perform_authorization_always_exchanges(mock_token_storage: Any, sample_id_jag: str):
+ """Test that _perform_authorization always does full exchange even with cached state.
+
+ The parent class only calls _perform_authorization when new tokens are
+ needed, so it must not short-circuit based on cached state.
+ """
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Set valid tokens
+ provider.context.current_tokens = OAuthToken(
+ token_type="Bearer",
+ access_token="valid-token",
+ expires_in=3600,
+ )
+ provider.context.token_expiry_time = time.time() + 3600
+
+ new_id_jag = "freshly-exchanged-id-jag"
+
+ with patch("httpx.AsyncClient") as mock_client_class:
+ mock_client = AsyncMock()
+ mock_client_class.return_value.__aenter__.return_value = mock_client
+ mock_client.post = AsyncMock(
+ return_value=httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": new_id_jag,
+ "token_type": "N_A",
+ },
+ )
+ )
+
+ request = await provider._perform_authorization()
+ mock_client.post.assert_called_once()
+
+ body_params = urllib.parse.parse_qs(request.content.decode())
+ assert body_params["assertion"][0] == new_id_jag
+
+
+@pytest.mark.anyio
+async def test_audience_override_warning(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any):
+ """Test that audience override logs a warning when values differ."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://configured-audience.example/", # Different from issuer
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set OAuth metadata with different issuer
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://actual-issuer.example/"), # Different from configured
+ authorization_endpoint=AnyHttpUrl("https://actual-issuer.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://actual-issuer.example/oauth2/token"),
+ )
+
+ # Mock HTTP response
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "N_A",
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Should log warning about audience override
+ with patch.object(
+ logging.getLogger("mcp.client.auth.extensions.enterprise_managed_auth"), "warning"
+ ) as mock_warning:
+ await provider.exchange_token_for_id_jag(mock_client)
+
+ # Verify warning was called with message about override
+ mock_warning.assert_called_once()
+ warning_message = mock_warning.call_args[0][0]
+ assert "Overriding audience" in warning_message
+ assert "https://configured-audience.example/" in warning_message
+ assert "https://actual-issuer.example/" in warning_message
+
+
+@pytest.mark.anyio
+async def test_audience_no_warning_when_matching(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any):
+ """Test that no warning is logged when audience matches issuer."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/", # Same as issuer
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set OAuth metadata with matching issuer
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"), # Same as configured
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Mock HTTP response
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "N_A",
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Should NOT log warning when values match
+ with patch.object(
+ logging.getLogger("mcp.client.auth.extensions.enterprise_managed_auth"), "warning"
+ ) as mock_warning:
+ await provider.exchange_token_for_id_jag(mock_client)
+
+ # Verify warning was NOT called
+ mock_warning.assert_not_called()
+
+
+@pytest.mark.anyio
+async def test_empty_scope_not_included(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any):
+ """Test that empty or whitespace-only scope is not included in token request."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ scope=" ", # Whitespace-only scope
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Mock HTTP response
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "N_A",
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Perform token exchange
+ await provider.exchange_token_for_id_jag(mock_client)
+
+ # Verify scope was NOT included in request
+ call_args = mock_client.post.call_args
+ assert "scope" not in call_args[1]["data"]
+
+
+@pytest.mark.anyio
+async def test_audience_override_disabled(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any):
+ """Test that override_audience_with_issuer=False preserves the configured audience."""
+ configured_audience = "https://configured-audience.example/"
+
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer=configured_audience,
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ override_audience_with_issuer=False, # Disable override
+ )
+
+ # Set OAuth metadata with a DIFFERENT issuer
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://different-issuer.example/"),
+ authorization_endpoint=AnyHttpUrl("https://different-issuer.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://different-issuer.example/oauth2/token"),
+ )
+
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "N_A",
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Should NOT override audience even though discovered issuer is different
+ with patch.object(
+ logging.getLogger("mcp.client.auth.extensions.enterprise_managed_auth"), "warning"
+ ) as mock_warning:
+ await provider.exchange_token_for_id_jag(mock_client)
+ # No warning should be logged because override is disabled
+ mock_warning.assert_not_called()
+
+ # Verify the CONFIGURED audience was used (not the discovered issuer)
+ call_args = mock_client.post.call_args
+ assert call_args[1]["data"]["audience"] == configured_audience
+
+
+@pytest.mark.anyio
+async def test_exchange_token_without_oauth_metadata(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any):
+ """Test token exchange when oauth_metadata is not set.
+
+ This tests the scenario where OAuth metadata discovery hasn't happened yet.
+ The configured audience from token_exchange_params should be used directly.
+
+ Note: Testing issuer=None is not possible because OAuthMetadata.issuer is a
+ required AnyHttpUrl field per RFC 8414, so the Pydantic model prevents None.
+ """
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.configured.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # No OAuth metadata set (None)
+ assert provider.context.oauth_metadata is None
+
+ # Mock HTTP response
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "N_A",
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Perform token exchange
+ await provider.exchange_token_for_id_jag(mock_client)
+
+ # Verify the configured audience was used (no override since metadata is None)
+ call_args = mock_client.post.call_args
+ assert call_args[1]["data"]["audience"] == "https://auth.configured.example/"
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_does_not_mutate_context_client_info(sample_id_jag: str, mock_token_storage: Any):
+ """Test that exchange_id_jag_for_access_token does not mutate context.client_info.
+
+ Regression test: the method previously mutated client_info.token_endpoint_auth_method
+ as a side effect, which could cause unexpected behavior on subsequent calls.
+ """
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set client info with no auth method (None) and a secret
+ provider.context.client_info = OAuthClientInformationFull(
+ client_id="test-client-id",
+ client_secret="test-secret",
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ token_endpoint_auth_method=None, # Explicitly None
+ )
+
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Call twice to verify idempotency
+ await provider.exchange_id_jag_for_access_token(sample_id_jag)
+ assert provider.context.client_info.token_endpoint_auth_method is None
+
+ await provider.exchange_id_jag_for_access_token(sample_id_jag)
+ assert provider.context.client_info.token_endpoint_auth_method is None
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_includes_resource_param(sample_id_jag: str, mock_token_storage: Any):
+ """Test that JWT bearer grant includes resource parameter per RFC 8707."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Set protected resource metadata so should_include_resource_param returns True
+ provider.context.protected_resource_metadata = ProtectedResourceMetadata(
+ resource=AnyHttpUrl("https://mcp-server.example/"),
+ authorization_servers=[AnyHttpUrl("https://auth.mcp-server.example/")],
+ )
+
+ request = await provider.exchange_id_jag_for_access_token(sample_id_jag)
+ body_params = urllib.parse.parse_qs(request.content.decode())
+
+ assert "resource" in body_params
+ assert body_params["resource"][0] == "https://mcp-server.example/"
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_includes_resource_param_via_protocol_version(
+ sample_id_jag: str, mock_token_storage: Any
+):
+ """Test that JWT bearer grant includes resource when protocol version qualifies."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # No PRM, but protocol version >= 2025-06-18 qualifies
+ provider.context.protocol_version = "2025-06-18"
+
+ request = await provider.exchange_id_jag_for_access_token(sample_id_jag)
+ body_params = urllib.parse.parse_qs(request.content.decode())
+
+ assert "resource" in body_params
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_omits_resource_when_not_applicable(sample_id_jag: str, mock_token_storage: Any):
+ """Test that JWT bearer grant omits resource when conditions are not met."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # No PRM and no qualifying protocol version
+ provider.context.protected_resource_metadata = None
+ provider.context.protocol_version = None
+
+ request = await provider.exchange_id_jag_for_access_token(sample_id_jag)
+ body_params = urllib.parse.parse_qs(request.content.decode())
+
+ assert "resource" not in body_params
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_includes_scope(sample_id_jag: str, mock_token_storage: Any):
+ """Test that JWT bearer grant includes scope from client metadata."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ scope="read write",
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ request = await provider.exchange_id_jag_for_access_token(sample_id_jag)
+ body_params = urllib.parse.parse_qs(request.content.decode())
+
+ assert "scope" in body_params
+ assert body_params["scope"][0] == "read write"
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_omits_scope_when_not_set(sample_id_jag: str, mock_token_storage: Any):
+ """Test that JWT bearer grant omits scope when client metadata has no scope."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ # No scope
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ request = await provider.exchange_id_jag_for_access_token(sample_id_jag)
+ body_params = urllib.parse.parse_qs(request.content.decode())
+
+ assert "scope" not in body_params
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_without_client_info(sample_id_jag: str, mock_token_storage: Any):
+ """Test JWT bearer grant request building when context has no client_info at all."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # No client_info set (None)
+ assert provider.context.client_info is None
+
+ request = await provider.exchange_id_jag_for_access_token(sample_id_jag)
+
+ # Should still produce a valid request without client credentials
+ assert isinstance(request, httpx.Request)
+ body_params = urllib.parse.parse_qs(request.content.decode())
+ assert body_params["grant_type"][0] == "urn:ietf:params:oauth:grant-type:jwt-bearer"
+ assert body_params["assertion"][0] == sample_id_jag
+ assert "client_id" not in body_params