@@ -110,6 +110,90 @@ async def message_handler( # pragma: no cover
110110 assert isinstance (initialized_notification , InitializedNotification )
111111
112112
113+ @pytest .mark .anyio
114+ async def test_client_session_initialize_custom_protocol_version ():
115+ client_to_server_send , client_to_server_receive = anyio .create_memory_object_stream [SessionMessage ](1 )
116+ server_to_client_send , server_to_client_receive = anyio .create_memory_object_stream [SessionMessage ](1 )
117+
118+ initialized_notification = None
119+ result = None
120+
121+ async def mock_server ():
122+ nonlocal initialized_notification
123+
124+ session_message = await client_to_server_receive .receive ()
125+ jsonrpc_request = session_message .message
126+ assert isinstance (jsonrpc_request , JSONRPCRequest )
127+ request = client_request_adapter .validate_python (
128+ jsonrpc_request .model_dump (by_alias = True , mode = "json" , exclude_none = True )
129+ )
130+ assert isinstance (request , InitializeRequest )
131+ assert request .params .protocol_version == "2024-11-05"
132+
133+ result = InitializeResult (
134+ protocol_version = "2024-11-05" ,
135+ capabilities = ServerCapabilities (
136+ logging = None ,
137+ resources = None ,
138+ tools = None ,
139+ experimental = None ,
140+ prompts = None ,
141+ ),
142+ server_info = Implementation (name = "mock-server" , version = "0.1.0" ),
143+ instructions = "The server instructions." ,
144+ )
145+
146+ async with server_to_client_send :
147+ await server_to_client_send .send (
148+ SessionMessage (
149+ JSONRPCResponse (
150+ jsonrpc = "2.0" ,
151+ id = jsonrpc_request .id ,
152+ result = result .model_dump (by_alias = True , mode = "json" , exclude_none = True ),
153+ )
154+ )
155+ )
156+ session_notification = await client_to_server_receive .receive ()
157+ jsonrpc_notification = session_notification .message
158+ assert isinstance (jsonrpc_notification , JSONRPCNotification )
159+ initialized_notification = client_notification_adapter .validate_python (
160+ jsonrpc_notification .model_dump (by_alias = True , mode = "json" , exclude_none = True )
161+ )
162+
163+ # Create a message handler to catch exceptions
164+ async def message_handler ( # pragma: no cover
165+ message : RequestResponder [types .ServerRequest , types .ClientResult ] | types .ServerNotification | Exception ,
166+ ) -> None :
167+ if isinstance (message , Exception ):
168+ raise message
169+
170+ async with (
171+ ClientSession (
172+ server_to_client_receive ,
173+ client_to_server_send ,
174+ message_handler = message_handler ,
175+ ) as session ,
176+ anyio .create_task_group () as tg ,
177+ client_to_server_send ,
178+ client_to_server_receive ,
179+ server_to_client_send ,
180+ server_to_client_receive ,
181+ ):
182+ tg .start_soon (mock_server )
183+ result = await session .initialize (protocol_version = "2024-11-05" )
184+
185+ # Assert the result
186+ assert isinstance (result , InitializeResult )
187+ assert result .protocol_version == "2024-11-05"
188+ assert isinstance (result .capabilities , ServerCapabilities )
189+ assert result .server_info == Implementation (name = "mock-server" , version = "0.1.0" )
190+ assert result .instructions == "The server instructions."
191+
192+ # Check that the client sent the initialized notification
193+ assert initialized_notification
194+ assert isinstance (initialized_notification , InitializedNotification )
195+
196+
113197@pytest .mark .anyio
114198async def test_client_session_custom_client_info ():
115199 client_to_server_send , client_to_server_receive = anyio .create_memory_object_stream [SessionMessage ](1 )
0 commit comments