|
14 | 14 | ) |
15 | 15 | from mcp.server.auth.json_response import PydanticJSONResponse |
16 | 16 | from mcp.server.auth.provider import ( |
| 17 | + AuthorizationErrorCode, |
17 | 18 | AuthorizationParams, |
| 19 | + AuthorizeError, |
18 | 20 | OAuthServerProvider, |
19 | 21 | construct_redirect_uri, |
20 | 22 | ) |
@@ -49,20 +51,9 @@ class AuthorizationRequest(BaseModel): |
49 | 51 | ) |
50 | 52 |
|
51 | 53 |
|
52 | | -AuthorizationErrorCode = Literal[ |
53 | | - "invalid_request", |
54 | | - "unauthorized_client", |
55 | | - "access_denied", |
56 | | - "unsupported_response_type", |
57 | | - "invalid_scope", |
58 | | - "server_error", |
59 | | - "temporarily_unavailable", |
60 | | -] |
61 | | - |
62 | | - |
63 | 54 | class AuthorizationErrorResponse(BaseModel): |
64 | 55 | error: AuthorizationErrorCode |
65 | | - error_description: str |
| 56 | + error_description: str | None |
66 | 57 | error_uri: AnyUrl | None = None |
67 | 58 | # must be set if provided in the request |
68 | 59 | state: str | None = None |
@@ -98,16 +89,14 @@ async def handle(self, request: Request) -> Response: |
98 | 89 |
|
99 | 90 | async def error_response( |
100 | 91 | error: AuthorizationErrorCode, |
101 | | - error_description: str, |
| 92 | + error_description: str | None, |
102 | 93 | attempt_load_client: bool = True, |
103 | 94 | ): |
104 | 95 | nonlocal client, redirect_uri, state |
105 | 96 | if client is None and attempt_load_client: |
106 | 97 | # make last-ditch attempt to load the client |
107 | 98 | client_id = best_effort_extract_string("client_id", params) |
108 | | - client = client_id and await self.provider.clients_store.get_client( |
109 | | - client_id |
110 | | - ) |
| 99 | + client = client_id and await self.provider.get_client(client_id) |
111 | 100 | if redirect_uri is None and client: |
112 | 101 | # make last-ditch effort to load the redirect uri |
113 | 102 | if params is not None and "redirect_uri" not in params: |
@@ -171,7 +160,7 @@ async def error_response( |
171 | 160 | ) |
172 | 161 |
|
173 | 162 | # Get client information |
174 | | - client = await self.provider.clients_store.get_client( |
| 163 | + client = await self.provider.get_client( |
175 | 164 | auth_request.client_id, |
176 | 165 | ) |
177 | 166 | if not client: |
@@ -210,15 +199,22 @@ async def error_response( |
210 | 199 | redirect_uri=redirect_uri, |
211 | 200 | ) |
212 | 201 |
|
213 | | - # Let the provider pick the next URI to redirect to |
214 | | - return RedirectResponse( |
215 | | - url=await self.provider.authorize( |
216 | | - client, |
217 | | - auth_params, |
218 | | - ), |
219 | | - status_code=302, |
220 | | - headers={"Cache-Control": "no-store"}, |
221 | | - ) |
| 202 | + try: |
| 203 | + # Let the provider pick the next URI to redirect to |
| 204 | + return RedirectResponse( |
| 205 | + url=await self.provider.authorize( |
| 206 | + client, |
| 207 | + auth_params, |
| 208 | + ), |
| 209 | + status_code=302, |
| 210 | + headers={"Cache-Control": "no-store"}, |
| 211 | + ) |
| 212 | + except AuthorizeError as e: |
| 213 | + # Handle authorization errors as defined in RFC 6749 Section 4.1.2.1 |
| 214 | + return await error_response( |
| 215 | + error=e.error, |
| 216 | + error_description=e.error_description, |
| 217 | + ) |
222 | 218 |
|
223 | 219 | except Exception as validation_error: |
224 | 220 | # Catch-all for unexpected errors |
|
0 commit comments