From bc76e74b20ace193113c5f140fc511a159a569dc Mon Sep 17 00:00:00 2001 From: STiFLeR7 Date: Thu, 21 May 2026 11:35:10 +0530 Subject: [PATCH] feat: add protocol version override support for client session initialization --- src/mcp/client/client.py | 8 ++- src/mcp/client/session.py | 4 +- src/mcp/client/session_group.py | 6 ++- tests/client/test_client.py | 7 +++ tests/client/test_session.py | 84 ++++++++++++++++++++++++++++++ tests/client/test_session_group.py | 34 ++++++++++++ 6 files changed, 139 insertions(+), 4 deletions(-) diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index 34d6a360f..570958be9 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -95,6 +95,9 @@ async def main(): elicitation_callback: ElicitationFnT | None = None """Callback for handling elicitation requests.""" + protocol_version: str | None = None + """The protocol version to request during initialization. Defaults to the latest version.""" + _session: ClientSession | None = field(init=False, default=None) _exit_stack: AsyncExitStack | None = field(init=False, default=None) _transport: Transport = field(init=False) @@ -129,7 +132,10 @@ async def __aenter__(self) -> Client: ) ) - await self._session.initialize() + if self.protocol_version is not None: + await self._session.initialize(protocol_version=self.protocol_version) + else: + await self._session.initialize() # Transfer ownership to self for __aexit__ to handle self._exit_stack = exit_stack.pop_all() diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 0cea454a7..ee4a6b2cb 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -145,7 +145,7 @@ def _receive_request_adapter(self) -> TypeAdapter[types.ServerRequest]: def _receive_notification_adapter(self) -> TypeAdapter[types.ServerNotification]: return types.server_notification_adapter - async def initialize(self) -> types.InitializeResult: + async def initialize(self, protocol_version: str = types.LATEST_PROTOCOL_VERSION) -> types.InitializeResult: sampling = ( (self._sampling_capabilities or types.SamplingCapability()) if self._sampling_callback is not _default_sampling_callback @@ -168,7 +168,7 @@ async def initialize(self) -> types.InitializeResult: result = await self.send_request( types.InitializeRequest( params=types.InitializeRequestParams( - protocol_version=types.LATEST_PROTOCOL_VERSION, + protocol_version=protocol_version, capabilities=types.ClientCapabilities( sampling=sampling, elicitation=elicitation, diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 961021264..543a2f980 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -80,6 +80,7 @@ class ClientSessionParameters: logging_callback: LoggingFnT | None = None message_handler: MessageHandlerFnT | None = None client_info: types.Implementation | None = None + protocol_version: str | None = None class ClientSessionGroup: @@ -313,7 +314,10 @@ async def _establish_session( ) ) - result = await session.initialize() + if session_params.protocol_version is not None: + result = await session.initialize(protocol_version=session_params.protocol_version) + else: + result = await session.initialize() # Session successfully initialized. # Store its stack and register the stack with the main group stack. diff --git a/tests/client/test_client.py b/tests/client/test_client.py index ac52a9024..15ab538e8 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -113,6 +113,13 @@ async def test_client_is_initialized(app: MCPServer): assert client.initialize_result.server_info.name == "test" +async def test_client_custom_protocol_version(app: MCPServer): + """Test that the client negotiates a custom protocol version when configured.""" + async with Client(app, protocol_version="2024-11-05") as client: + assert client.initialize_result.protocol_version == "2024-11-05" + assert client.initialize_result.server_info.name == "test" + + async def test_client_with_simple_server(simple_server: Server): """Test that from_server works with a basic Server instance.""" async with Client(simple_server) as client: diff --git a/tests/client/test_session.py b/tests/client/test_session.py index f25c964f0..f6b1e2b65 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -110,6 +110,90 @@ async def message_handler( # pragma: no cover assert isinstance(initialized_notification, InitializedNotification) +@pytest.mark.anyio +async def test_client_session_initialize_custom_protocol_version(): + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + initialized_notification = None + result = None + + async def mock_server(): + nonlocal initialized_notification + + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request, JSONRPCRequest) + request = client_request_adapter.validate_python( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request, InitializeRequest) + assert request.params.protocol_version == "2024-11-05" + + result = InitializeResult( + protocol_version="2024-11-05", + capabilities=ServerCapabilities( + logging=None, + resources=None, + tools=None, + experimental=None, + prompts=None, + ), + server_info=Implementation(name="mock-server", version="0.1.0"), + instructions="The server instructions.", + ) + + async with server_to_client_send: + await server_to_client_send.send( + SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + session_notification = await client_to_server_receive.receive() + jsonrpc_notification = session_notification.message + assert isinstance(jsonrpc_notification, JSONRPCNotification) + initialized_notification = client_notification_adapter.validate_python( + jsonrpc_notification.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + + # Create a message handler to catch exceptions + async def message_handler( # pragma: no cover + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + result = await session.initialize(protocol_version="2024-11-05") + + # Assert the result + assert isinstance(result, InitializeResult) + assert result.protocol_version == "2024-11-05" + assert isinstance(result.capabilities, ServerCapabilities) + assert result.server_info == Implementation(name="mock-server", version="0.1.0") + assert result.instructions == "The server instructions." + + # Check that the client sent the initialized notification + assert initialized_notification + assert isinstance(initialized_notification, InitializedNotification) + + @pytest.mark.anyio async def test_client_session_custom_client_info(): client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 6a58b39f3..49e4cd0ce 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -385,3 +385,37 @@ async def test_client_session_group_establish_session_parameterized( # 3. Assert returned values assert returned_server_info is mock_initialize_result.server_info assert returned_session is mock_entered_session + + +@pytest.mark.anyio +async def test_client_session_group_establish_session_custom_protocol_version(): + with mock.patch("mcp.client.session_group.mcp.ClientSession") as mock_ClientSession_class: + with mock.patch("mcp.client.session_group.mcp.stdio_client") as mock_stdio_client: + mock_client_cm_instance = mock.AsyncMock(name="stdioClientCM") + mock_read_stream = mock.AsyncMock(name="stdioRead") + mock_write_stream = mock.AsyncMock(name="stdioWrite") + + mock_client_cm_instance.__aenter__.return_value = (mock_read_stream, mock_write_stream) + mock_client_cm_instance.__aexit__ = mock.AsyncMock(return_value=None) + mock_stdio_client.return_value = mock_client_cm_instance + + mock_raw_session_cm = mock.AsyncMock(name="RawSessionCM") + mock_ClientSession_class.return_value = mock_raw_session_cm + + mock_entered_session = mock.AsyncMock(name="EnteredSessionInstance") + mock_raw_session_cm.__aenter__.return_value = mock_entered_session + mock_raw_session_cm.__aexit__ = mock.AsyncMock(return_value=None) + + mock_initialize_result = mock.AsyncMock(name="InitializeResult") + mock_initialize_result.server_info = types.Implementation(name="foo", version="1") + mock_entered_session.initialize.return_value = mock_initialize_result + + group = ClientSessionGroup() + server_params = StdioServerParameters(command="test_stdio_cmd") + session_params = ClientSessionParameters(protocol_version="2024-11-05") + + async with contextlib.AsyncExitStack() as stack: + group._exit_stack = stack + await group._establish_session(server_params, session_params) + + mock_entered_session.initialize.assert_awaited_once_with(protocol_version="2024-11-05")