Skip to content

Commit c9cc22a

Browse files
committed
feat: add protocol version override support for client session initialization
1 parent e8e6484 commit c9cc22a

6 files changed

Lines changed: 139 additions & 4 deletions

File tree

src/mcp/client/client.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ async def main():
9595
elicitation_callback: ElicitationFnT | None = None
9696
"""Callback for handling elicitation requests."""
9797

98+
protocol_version: str | None = None
99+
"""The protocol version to request during initialization. Defaults to the latest version."""
100+
98101
_session: ClientSession | None = field(init=False, default=None)
99102
_exit_stack: AsyncExitStack | None = field(init=False, default=None)
100103
_transport: Transport = field(init=False)
@@ -129,7 +132,10 @@ async def __aenter__(self) -> Client:
129132
)
130133
)
131134

132-
await self._session.initialize()
135+
if self.protocol_version is not None:
136+
await self._session.initialize(protocol_version=self.protocol_version)
137+
else:
138+
await self._session.initialize()
133139

134140
# Transfer ownership to self for __aexit__ to handle
135141
self._exit_stack = exit_stack.pop_all()

src/mcp/client/session.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def _receive_request_adapter(self) -> TypeAdapter[types.ServerRequest]:
145145
def _receive_notification_adapter(self) -> TypeAdapter[types.ServerNotification]:
146146
return types.server_notification_adapter
147147

148-
async def initialize(self) -> types.InitializeResult:
148+
async def initialize(self, protocol_version: str = types.LATEST_PROTOCOL_VERSION) -> types.InitializeResult:
149149
sampling = (
150150
(self._sampling_capabilities or types.SamplingCapability())
151151
if self._sampling_callback is not _default_sampling_callback
@@ -168,7 +168,7 @@ async def initialize(self) -> types.InitializeResult:
168168
result = await self.send_request(
169169
types.InitializeRequest(
170170
params=types.InitializeRequestParams(
171-
protocol_version=types.LATEST_PROTOCOL_VERSION,
171+
protocol_version=protocol_version,
172172
capabilities=types.ClientCapabilities(
173173
sampling=sampling,
174174
elicitation=elicitation,

src/mcp/client/session_group.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ class ClientSessionParameters:
8080
logging_callback: LoggingFnT | None = None
8181
message_handler: MessageHandlerFnT | None = None
8282
client_info: types.Implementation | None = None
83+
protocol_version: str | None = None
8384

8485

8586
class ClientSessionGroup:
@@ -313,7 +314,10 @@ async def _establish_session(
313314
)
314315
)
315316

316-
result = await session.initialize()
317+
if session_params.protocol_version is not None:
318+
result = await session.initialize(protocol_version=session_params.protocol_version)
319+
else:
320+
result = await session.initialize()
317321

318322
# Session successfully initialized.
319323
# Store its stack and register the stack with the main group stack.

tests/client/test_client.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,13 @@ async def test_client_is_initialized(app: MCPServer):
113113
assert client.initialize_result.server_info.name == "test"
114114

115115

116+
async def test_client_custom_protocol_version(app: MCPServer):
117+
"""Test that the client negotiates a custom protocol version when configured."""
118+
async with Client(app, protocol_version="2024-11-05") as client:
119+
assert client.initialize_result.protocol_version == "2024-11-05"
120+
assert client.initialize_result.server_info.name == "test"
121+
122+
116123
async def test_client_with_simple_server(simple_server: Server):
117124
"""Test that from_server works with a basic Server instance."""
118125
async with Client(simple_server) as client:

tests/client/test_session.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,90 @@ async def message_handler( # pragma: no cover
110110
assert isinstance(initialized_notification, InitializedNotification)
111111

112112

113+
@pytest.mark.anyio
114+
async def test_client_session_initialize_custom_protocol_version():
115+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
116+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
117+
118+
initialized_notification = None
119+
result = None
120+
121+
async def mock_server():
122+
nonlocal initialized_notification
123+
124+
session_message = await client_to_server_receive.receive()
125+
jsonrpc_request = session_message.message
126+
assert isinstance(jsonrpc_request, JSONRPCRequest)
127+
request = client_request_adapter.validate_python(
128+
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
129+
)
130+
assert isinstance(request, InitializeRequest)
131+
assert request.params.protocol_version == "2024-11-05"
132+
133+
result = InitializeResult(
134+
protocol_version="2024-11-05",
135+
capabilities=ServerCapabilities(
136+
logging=None,
137+
resources=None,
138+
tools=None,
139+
experimental=None,
140+
prompts=None,
141+
),
142+
server_info=Implementation(name="mock-server", version="0.1.0"),
143+
instructions="The server instructions.",
144+
)
145+
146+
async with server_to_client_send:
147+
await server_to_client_send.send(
148+
SessionMessage(
149+
JSONRPCResponse(
150+
jsonrpc="2.0",
151+
id=jsonrpc_request.id,
152+
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
153+
)
154+
)
155+
)
156+
session_notification = await client_to_server_receive.receive()
157+
jsonrpc_notification = session_notification.message
158+
assert isinstance(jsonrpc_notification, JSONRPCNotification)
159+
initialized_notification = client_notification_adapter.validate_python(
160+
jsonrpc_notification.model_dump(by_alias=True, mode="json", exclude_none=True)
161+
)
162+
163+
# Create a message handler to catch exceptions
164+
async def message_handler( # pragma: no cover
165+
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
166+
) -> None:
167+
if isinstance(message, Exception):
168+
raise message
169+
170+
async with (
171+
ClientSession(
172+
server_to_client_receive,
173+
client_to_server_send,
174+
message_handler=message_handler,
175+
) as session,
176+
anyio.create_task_group() as tg,
177+
client_to_server_send,
178+
client_to_server_receive,
179+
server_to_client_send,
180+
server_to_client_receive,
181+
):
182+
tg.start_soon(mock_server)
183+
result = await session.initialize(protocol_version="2024-11-05")
184+
185+
# Assert the result
186+
assert isinstance(result, InitializeResult)
187+
assert result.protocol_version == "2024-11-05"
188+
assert isinstance(result.capabilities, ServerCapabilities)
189+
assert result.server_info == Implementation(name="mock-server", version="0.1.0")
190+
assert result.instructions == "The server instructions."
191+
192+
# Check that the client sent the initialized notification
193+
assert initialized_notification
194+
assert isinstance(initialized_notification, InitializedNotification)
195+
196+
113197
@pytest.mark.anyio
114198
async def test_client_session_custom_client_info():
115199
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)

tests/client/test_session_group.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,3 +385,37 @@ async def test_client_session_group_establish_session_parameterized(
385385
# 3. Assert returned values
386386
assert returned_server_info is mock_initialize_result.server_info
387387
assert returned_session is mock_entered_session
388+
389+
390+
@pytest.mark.anyio
391+
async def test_client_session_group_establish_session_custom_protocol_version():
392+
with mock.patch("mcp.client.session_group.mcp.ClientSession") as mock_ClientSession_class:
393+
with mock.patch("mcp.client.session_group.mcp.stdio_client") as mock_stdio_client:
394+
mock_client_cm_instance = mock.AsyncMock(name="stdioClientCM")
395+
mock_read_stream = mock.AsyncMock(name="stdioRead")
396+
mock_write_stream = mock.AsyncMock(name="stdioWrite")
397+
398+
mock_client_cm_instance.__aenter__.return_value = (mock_read_stream, mock_write_stream)
399+
mock_client_cm_instance.__aexit__ = mock.AsyncMock(return_value=None)
400+
mock_stdio_client.return_value = mock_client_cm_instance
401+
402+
mock_raw_session_cm = mock.AsyncMock(name="RawSessionCM")
403+
mock_ClientSession_class.return_value = mock_raw_session_cm
404+
405+
mock_entered_session = mock.AsyncMock(name="EnteredSessionInstance")
406+
mock_raw_session_cm.__aenter__.return_value = mock_entered_session
407+
mock_raw_session_cm.__aexit__ = mock.AsyncMock(return_value=None)
408+
409+
mock_initialize_result = mock.AsyncMock(name="InitializeResult")
410+
mock_initialize_result.server_info = types.Implementation(name="foo", version="1")
411+
mock_entered_session.initialize.return_value = mock_initialize_result
412+
413+
group = ClientSessionGroup()
414+
server_params = StdioServerParameters(command="test_stdio_cmd")
415+
session_params = ClientSessionParameters(protocol_version="2024-11-05")
416+
417+
async with contextlib.AsyncExitStack() as stack:
418+
group._exit_stack = stack
419+
await group._establish_session(server_params, session_params)
420+
421+
mock_entered_session.initialize.assert_awaited_once_with(protocol_version="2024-11-05")

0 commit comments

Comments
 (0)