Skip to content

Commit 45833bb

Browse files
committed
fix: graceful SSE drain on session manager shutdown
Terminate all active transports before cancelling the task group during shutdown. This closes in-memory anyio streams cleanly, allowing EventSourceResponse to send a final `more_body=False` chunk — a clean HTTP close instead of a connection reset that triggers "upstream prematurely closed connection" errors at reverse proxies. Changes: - Track in-flight stateless transports in _stateless_transports set - In run() finally block, call terminate() on all transports (both stateful and stateless) before tg.cancel_scope.cancel() - Add E2E tests for graceful shutdown in both stateless and stateful modes using httpx.ASGITransport Upstream PR: modelcontextprotocol#2239
1 parent 02ed42b commit 45833bb

2 files changed

Lines changed: 222 additions & 9 deletions

File tree

src/mcp/server/streamable_http_manager.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ def __init__(
9090
self._session_creation_lock = anyio.Lock()
9191
self._server_instances: dict[str, StreamableHTTPServerTransport] = {}
9292

93+
# Track in-flight stateless transports for graceful shutdown
94+
self._stateless_transports: set[StreamableHTTPServerTransport] = set()
95+
9396
# The task group will be set during lifespan
9497
self._task_group = None
9598
# Thread-safe tracking of run() calls
@@ -130,11 +133,28 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]:
130133
yield # Let the application run
131134
finally:
132135
logger.info("StreamableHTTP session manager shutting down")
136+
137+
# Terminate all active transports before cancelling the task
138+
# group. This closes their in-memory streams, which lets
139+
# EventSourceResponse send a final ``more_body=False`` chunk
140+
# — a clean HTTP close instead of a connection reset.
141+
for transport in list(self._server_instances.values()):
142+
try:
143+
await transport.terminate()
144+
except Exception:
145+
logger.debug("Error terminating transport during shutdown", exc_info=True)
146+
for transport in list(self._stateless_transports):
147+
try:
148+
await transport.terminate()
149+
except Exception:
150+
logger.debug("Error terminating stateless transport during shutdown", exc_info=True)
151+
133152
# Cancel task group to stop all spawned tasks
134153
tg.cancel_scope.cancel()
135154
self._task_group = None
136155
# Clear any remaining server instances
137156
self._server_instances.clear()
157+
self._stateless_transports.clear()
138158

139159
async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
140160
"""Process ASGI request with proper session handling and transport setup.
@@ -166,6 +186,9 @@ async def _handle_stateless_request(self, scope: Scope, receive: Receive, send:
166186
security_settings=self.security_settings,
167187
)
168188

189+
# Track for graceful shutdown
190+
self._stateless_transports.add(http_transport)
191+
169192
# Start server in a new task
170193
async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED):
171194
async with http_transport.connect() as streams:
@@ -185,13 +208,16 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA
185208
# This ensures the server task is cancelled when the request
186209
# finishes, preventing zombie tasks from accumulating.
187210
# See: https://github.com/modelcontextprotocol/python-sdk/issues/1764
188-
async with anyio.create_task_group() as request_tg:
189-
await request_tg.start(run_stateless_server)
190-
# Handle the HTTP request directly in the caller's context
191-
# (not as a child task) so execution flows back naturally.
192-
await http_transport.handle_request(scope, receive, send)
193-
# Cancel the request-scoped task group to stop the server task.
194-
request_tg.cancel_scope.cancel()
211+
try:
212+
async with anyio.create_task_group() as request_tg:
213+
await request_tg.start(run_stateless_server)
214+
# Handle the HTTP request directly in the caller's context
215+
# (not as a child task) so execution flows back naturally.
216+
await http_transport.handle_request(scope, receive, send)
217+
# Cancel the request-scoped task group to stop the server task.
218+
request_tg.cancel_scope.cancel()
219+
finally:
220+
self._stateless_transports.discard(http_transport)
195221

196222
# Terminate after the task group exits — the server task is already
197223
# cancelled at this point, so this is just cleanup (sets _terminated

tests/server/test_streamable_http_manager.py

Lines changed: 189 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010
import pytest
1111
from starlette.types import Message
1212

13-
from mcp import Client
13+
from mcp import Client, types
1414
from mcp.client.streamable_http import streamable_http_client
1515
from mcp.server import Server, ServerRequestContext, streamable_http_manager
1616
from mcp.server.streamable_http import MCP_SESSION_ID_HEADER, StreamableHTTPServerTransport
17-
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
17+
from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager
1818
from mcp.types import INVALID_REQUEST, ListToolsResult, PaginatedRequestParams
1919

2020

@@ -490,3 +490,190 @@ def test_session_idle_timeout_rejects_non_positive():
490490
def test_session_idle_timeout_rejects_stateless():
491491
with pytest.raises(RuntimeError, match="not supported in stateless"):
492492
StreamableHTTPSessionManager(app=Server("test"), session_idle_timeout=30, stateless=True)
493+
494+
495+
MCP_HEADERS = {
496+
"Accept": "application/json, text/event-stream",
497+
"Content-Type": "application/json",
498+
}
499+
500+
_INITIALIZE_REQUEST = {
501+
"jsonrpc": "2.0",
502+
"id": 1,
503+
"method": "initialize",
504+
"params": {
505+
"protocolVersion": "2025-03-26",
506+
"capabilities": {},
507+
"clientInfo": {"name": "test", "version": "0.1"},
508+
},
509+
}
510+
511+
_INITIALIZED_NOTIFICATION = {
512+
"jsonrpc": "2.0",
513+
"method": "notifications/initialized",
514+
}
515+
516+
_TOOL_CALL_REQUEST = {
517+
"jsonrpc": "2.0",
518+
"id": 2,
519+
"method": "tools/call",
520+
"params": {"name": "slow_tool", "arguments": {"message": "hello"}},
521+
}
522+
523+
524+
def _make_slow_tool_server() -> tuple[Server, anyio.Event]:
525+
"""Create an MCP server with a tool that blocks forever, returning
526+
the server and an event that fires when the tool starts executing."""
527+
tool_started = anyio.Event()
528+
529+
async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult:
530+
tool_started.set()
531+
await anyio.sleep_forever()
532+
return types.CallToolResult( # pragma: no cover
533+
content=[types.TextContent(type="text", text="never reached")]
534+
)
535+
536+
async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult:
537+
return ListToolsResult(
538+
tools=[
539+
types.Tool(
540+
name="slow_tool",
541+
description="A tool that blocks forever",
542+
inputSchema={"type": "object", "properties": {"message": {"type": "string"}}},
543+
)
544+
]
545+
)
546+
547+
app = Server("test-graceful-shutdown", on_call_tool=handle_call_tool, on_list_tools=handle_list_tools)
548+
return app, tool_started
549+
550+
551+
@pytest.mark.anyio
552+
async def test_graceful_shutdown_terminates_active_stateless_transports():
553+
"""Verify that shutting down the session manager terminates in-flight
554+
stateless transports so SSE streams close cleanly (``more_body=False``)
555+
instead of being abruptly cancelled.
556+
557+
This prevents "upstream prematurely closed connection" errors at reverse
558+
proxies like nginx.
559+
"""
560+
app, tool_started = _make_slow_tool_server()
561+
manager = StreamableHTTPSessionManager(app=app, stateless=True)
562+
563+
mcp_app = StreamableHTTPASGIApp(manager)
564+
565+
manager_ready = anyio.Event()
566+
stream_outcome: str | None = None
567+
568+
with anyio.fail_after(10):
569+
async with anyio.create_task_group() as tg:
570+
571+
async def run_lifespan_and_shutdown():
572+
async with manager.run():
573+
manager_ready.set()
574+
with anyio.fail_after(5):
575+
await tool_started.wait()
576+
577+
async def make_requests():
578+
nonlocal stream_outcome
579+
with anyio.fail_after(5):
580+
await manager_ready.wait()
581+
async with (
582+
httpx.ASGITransport(mcp_app) as transport,
583+
httpx.AsyncClient(transport=transport, base_url="http://testserver") as client,
584+
):
585+
# Initialize
586+
resp = await client.post("/mcp/", json=_INITIALIZE_REQUEST, headers=MCP_HEADERS)
587+
resp.raise_for_status()
588+
589+
# Send initialized notification
590+
resp = await client.post("/mcp/", json=_INITIALIZED_NOTIFICATION, headers=MCP_HEADERS)
591+
assert resp.status_code == 202
592+
593+
# Send slow tool call — this returns an SSE stream
594+
try:
595+
async with client.stream(
596+
"POST",
597+
"/mcp/",
598+
json=_TOOL_CALL_REQUEST,
599+
headers=MCP_HEADERS,
600+
timeout=httpx.Timeout(10, connect=5),
601+
) as stream:
602+
stream.raise_for_status()
603+
async for _chunk in stream.aiter_bytes():
604+
pass
605+
stream_outcome = "clean"
606+
except httpx.RemoteProtocolError:
607+
stream_outcome = "reset"
608+
609+
tg.start_soon(run_lifespan_and_shutdown)
610+
tg.start_soon(make_requests)
611+
612+
assert stream_outcome == "clean", f"Expected clean HTTP close, got {stream_outcome}"
613+
614+
615+
@pytest.mark.anyio
616+
async def test_graceful_shutdown_terminates_active_stateful_transports():
617+
"""Verify that shutting down the session manager terminates in-flight
618+
stateful transports so SSE streams close cleanly."""
619+
app, tool_started = _make_slow_tool_server()
620+
manager = StreamableHTTPSessionManager(app=app, stateless=False)
621+
622+
mcp_app = StreamableHTTPASGIApp(manager)
623+
624+
manager_ready = anyio.Event()
625+
stream_outcome: str | None = None
626+
627+
with anyio.fail_after(10):
628+
async with anyio.create_task_group() as tg:
629+
630+
async def run_lifespan_and_shutdown():
631+
async with manager.run():
632+
manager_ready.set()
633+
with anyio.fail_after(5):
634+
await tool_started.wait()
635+
636+
async def make_requests():
637+
nonlocal stream_outcome
638+
with anyio.fail_after(5):
639+
await manager_ready.wait()
640+
async with (
641+
httpx.ASGITransport(mcp_app) as transport,
642+
httpx.AsyncClient(transport=transport, base_url="http://testserver") as client,
643+
):
644+
# Initialize (creates a session)
645+
resp = await client.post("/mcp/", json=_INITIALIZE_REQUEST, headers=MCP_HEADERS)
646+
resp.raise_for_status()
647+
session_id = resp.headers.get(MCP_SESSION_ID_HEADER)
648+
assert session_id is not None
649+
650+
session_headers = {
651+
**MCP_HEADERS,
652+
MCP_SESSION_ID_HEADER: session_id,
653+
"mcp-protocol-version": "2025-03-26",
654+
}
655+
656+
# Send initialized notification
657+
resp = await client.post("/mcp/", json=_INITIALIZED_NOTIFICATION, headers=session_headers)
658+
assert resp.status_code == 202
659+
660+
# Send slow tool call
661+
try:
662+
async with client.stream(
663+
"POST",
664+
"/mcp/",
665+
json=_TOOL_CALL_REQUEST,
666+
headers=session_headers,
667+
timeout=httpx.Timeout(10, connect=5),
668+
) as stream:
669+
stream.raise_for_status()
670+
async for _chunk in stream.aiter_bytes():
671+
pass
672+
stream_outcome = "clean"
673+
except httpx.RemoteProtocolError:
674+
stream_outcome = "reset"
675+
676+
tg.start_soon(run_lifespan_and_shutdown)
677+
tg.start_soon(make_requests)
678+
679+
assert stream_outcome == "clean", f"Expected clean HTTP close, got {stream_outcome}"

0 commit comments

Comments
 (0)