# src/mcp/client/client.py -- pagination helpers on `Client`. The class header
# lies outside the visible hunks, so the methods are shown at column 0.


async def list_tools(self, *, cursor: str | None = None, meta: RequestParamsMeta | None = None) -> ListToolsResult:
    """List a single page of available tools from the server.

    Args:
        cursor: Cursor taken from a previous result's `next_cursor`;
            `None` requests the first page.
        meta: Optional request metadata, forwarded as `_meta`.

    Returns:
        One page only. The result may include a `next_cursor` if more
        pages are available. Use `list_all_tools` to drain every page.
    """
    return await self.session.list_tools(params=PaginatedRequestParams(cursor=cursor, _meta=meta))


async def iter_all_tools(self, *, meta: RequestParamsMeta | None = None) -> AsyncIterator[Tool]:
    """Yield every tool from the server, paging through `next_cursor`.

    Useful for streaming consumers that want to process tools without
    materializing the full list in memory. The next page is requested only
    after every item of the current page has been yielded.

    Args:
        meta: Optional request metadata, forwarded with every page request.

    Raises:
        RuntimeError: If the server returns a cursor it already returned
            during this iteration (the loop would otherwise never end).
    """
    cursor: str | None = None
    seen_cursors: set[str] = set()
    while True:
        result = await self.list_tools(cursor=cursor, meta=meta)
        for tool in result.tools:
            yield tool
        cursor = result.next_cursor
        if cursor is None:
            return
        # A repeated cursor means the server is cycling; fail loudly rather
        # than request the same pages forever.
        if cursor in seen_cursors:
            raise RuntimeError(f"tools/list returned pagination cursor {cursor!r} more than once")
        seen_cursors.add(cursor)


async def list_all_tools(self, *, meta: RequestParamsMeta | None = None) -> list[Tool]:
    """List every tool from the server, draining `next_cursor` across pages.

    Unlike `list_tools`, which returns one page, this walks pagination
    until the server reports no further pages and returns the combined
    list, in the order the pages were served.

    Raises:
        RuntimeError: Propagated from `iter_all_tools` when the server
            repeats a pagination cursor.
    """
    return [tool async for tool in self.iter_all_tools(meta=meta)]


async def iter_all_prompts(self, *, meta: RequestParamsMeta | None = None) -> AsyncIterator[Prompt]:
    """Yield every prompt from the server, paging through `next_cursor`.

    Args:
        meta: Optional request metadata, forwarded with every page request.

    Raises:
        RuntimeError: If the server returns a cursor it already returned
            during this iteration.
    """
    cursor: str | None = None
    seen_cursors: set[str] = set()
    while True:
        result = await self.list_prompts(cursor=cursor, meta=meta)
        for prompt in result.prompts:
            yield prompt
        cursor = result.next_cursor
        if cursor is None:
            return
        # Same cycle guard as the other `iter_all_*` helpers.
        if cursor in seen_cursors:
            raise RuntimeError(f"prompts/list returned pagination cursor {cursor!r} more than once")
        seen_cursors.add(cursor)


async def list_all_prompts(self, *, meta: RequestParamsMeta | None = None) -> list[Prompt]:
    """List every prompt from the server, draining `next_cursor` across pages."""
    return [prompt async for prompt in self.iter_all_prompts(meta=meta)]


async def iter_all_resources(self, *, meta: RequestParamsMeta | None = None) -> AsyncIterator[Resource]:
    """Yield every resource from the server, paging through `next_cursor`.

    Args:
        meta: Optional request metadata, forwarded with every page request.

    Raises:
        RuntimeError: If the server returns a cursor it already returned
            during this iteration.
    """
    cursor: str | None = None
    seen_cursors: set[str] = set()
    while True:
        result = await self.list_resources(cursor=cursor, meta=meta)
        for resource in result.resources:
            yield resource
        cursor = result.next_cursor
        if cursor is None:
            return
        # Same cycle guard as the other `iter_all_*` helpers.
        if cursor in seen_cursors:
            raise RuntimeError(f"resources/list returned pagination cursor {cursor!r} more than once")
        seen_cursors.add(cursor)


async def list_all_resources(self, *, meta: RequestParamsMeta | None = None) -> list[Resource]:
    """List every resource from the server, draining `next_cursor` across pages."""
    return [resource async for resource in self.iter_all_resources(meta=meta)]


async def iter_all_resource_templates(
    self, *, meta: RequestParamsMeta | None = None
) -> AsyncIterator[ResourceTemplate]:
    """Yield every resource template from the server, paging through `next_cursor`.

    Args:
        meta: Optional request metadata, forwarded with every page request.

    Raises:
        RuntimeError: If the server returns a cursor it already returned
            during this iteration.
    """
    cursor: str | None = None
    seen_cursors: set[str] = set()
    while True:
        result = await self.list_resource_templates(cursor=cursor, meta=meta)
        for template in result.resource_templates:
            yield template
        cursor = result.next_cursor
        if cursor is None:
            return
        # Same cycle guard as the other `iter_all_*` helpers.
        if cursor in seen_cursors:
            raise RuntimeError(f"resources/templates/list returned pagination cursor {cursor!r} more than once")
        seen_cursors.add(cursor)
# src/mcp/client/client.py -- `Client` method (class header outside the visible hunks).


async def list_all_resource_templates(self, *, meta: RequestParamsMeta | None = None) -> list[ResourceTemplate]:
    """List every resource template from the server, draining `next_cursor` across pages.

    Args:
        meta: Optional request metadata, forwarded to `iter_all_resource_templates`.

    Returns:
        The templates of every page, combined in the order they were yielded.
    """
    return [template async for template in self.iter_all_resource_templates(meta=meta)]


# src/mcp/client/session_group.py -- module-level helper.


async def _drain_paginated(
    fetch_page: Callable[..., Awaitable[Any]],
    attribute: str,
) -> list[Any]:
    """Drain a paginated `session.list_*` call across `next_cursor` pages.

    Args:
        fetch_page: One of the ClientSession `list_*` methods that takes a
            `params=PaginatedRequestParams(...)` keyword. It is called with
            `params=None` for the first page.
        attribute: Name of the list attribute on the result (e.g.
            `"tools"`, `"prompts"`).

    Returns:
        The items of every page, concatenated in the order the pages were
        served.

    Raises:
        RuntimeError: If the server returns a cursor it already returned
            during this drain. Without the check a cycling server would
            keep this loop running and `items` growing without bound.
    """
    items: list[Any] = []
    seen_cursors: set[str] = set()
    cursor: str | None = None
    while True:
        # The first request carries no params at all; later ones carry the cursor.
        params = types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None
        result = await fetch_page(params=params)
        items.extend(getattr(result, attribute))
        # A result without a `next_cursor` attribute is treated as the last page.
        cursor = getattr(result, "next_cursor", None)
        if cursor is None:
            return items
        if cursor in seen_cursors:
            raise RuntimeError(f"Server returned pagination cursor {cursor!r} more than once")
        seen_cursors.add(cursor)
@dataclass @@ -344,9 +366,11 @@ async def _aggregate_components(self, server_info: types.Implementation, session tools_temp: dict[str, types.Tool] = {} tool_to_session_temp: dict[str, mcp.ClientSession] = {} - # Query the server for its prompts and aggregate to list. + # Query the server for its prompts and aggregate to list. Drain + # pagination so we don't drop later pages on servers that split + # results across multiple `next_cursor` responses. try: - prompts = (await session.list_prompts()).prompts + prompts = await _drain_paginated(session.list_prompts, "prompts") for prompt in prompts: name = self._component_name(prompt.name, server_info) prompts_temp[name] = prompt @@ -356,7 +380,7 @@ async def _aggregate_components(self, server_info: types.Implementation, session # Query the server for its resources and aggregate to list. try: - resources = (await session.list_resources()).resources + resources = await _drain_paginated(session.list_resources, "resources") for resource in resources: name = self._component_name(resource.name, server_info) resources_temp[name] = resource @@ -366,7 +390,7 @@ async def _aggregate_components(self, server_info: types.Implementation, session # Query the server for its tools and aggregate to list. try: - tools = (await session.list_tools()).tools + tools = await _drain_paginated(session.list_tools, "tools") for tool in tools: name = self._component_name(tool.name, server_info) tools_temp[name] = tool diff --git a/tests/client/test_list_all_pagination.py b/tests/client/test_list_all_pagination.py new file mode 100644 index 000000000..8fbb3b087 --- /dev/null +++ b/tests/client/test_list_all_pagination.py @@ -0,0 +1,206 @@ +"""Tests for client `list_all_*` and `iter_all_*` pagination helpers. + +These helpers drain `next_cursor` across pages, so a server can split +its tools/prompts/resources/resource_templates across multiple list +calls and the client still sees the full collection. 
ItemT = TypeVar("ItemT")
ResultT = TypeVar("ResultT")


def _paginated_handler(
    pages: list[list[str]],
    make_item: Callable[[str], ItemT],
    result_cls: Callable[..., ResultT],
    items_field: str,
) -> Callable[[ServerRequestContext, types.PaginatedRequestParams | None], Awaitable[ResultT]]:
    """Create a lowlevel-server list handler that serves `pages` one at a time.

    The first page is returned for a request without a cursor. Page `i`
    (0-based) announces `next_cursor=str(i + 1)`, and the final page
    announces `None`. Lookup is driven by the cursor found in the incoming
    request, so both the client's and the server's cursor handling are
    exercised. A cursor that was never issued raises `KeyError`.
    """
    # Incoming cursor -> index of the page to serve (`None` means first page).
    page_for_cursor: dict[str | None, int] = {None: 0, **{str(i): i for i in range(1, len(pages))}}

    async def handler(_ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ResultT:
        requested = params.cursor if params else None
        index = page_for_cursor[requested]
        names = pages[index]
        following = index + 1
        items = list(map(make_item, names))
        return result_cls(
            **{items_field: items},
            next_cursor=str(following) if following < len(pages) else None,
        )

    return handler


def _make_tool(name: str) -> types.Tool:
    """Build a `types.Tool` called `name` with an empty input schema."""
    return types.Tool(name=name, input_schema={})


def _make_prompt(name: str) -> types.Prompt:
    """Build a `types.Prompt` called `name`."""
    return types.Prompt(name=name)


def _make_resource(name: str) -> types.Resource:
    """Build a `types.Resource` called `name` whose URI is derived from the name."""
    return types.Resource(name=name, uri=f"test://{name}")


def _make_resource_template(name: str) -> types.ResourceTemplate:
    """Build a `types.ResourceTemplate` called `name` with an `{id}` URI placeholder."""
    return types.ResourceTemplate(name=name, uri_template=f"test://{name}/{{id}}")
# ---- list_all_tools / iter_all_tools ---------------------------------------


async def test_list_all_tools_drains_all_pages(
    stream_spy: Callable[[], StreamSpyCollection],
):
    """list_all_tools follows `next_cursor` and returns the union of pages."""
    pages = [["a", "b"], ["c", "d"], ["e"]]
    server = Server(
        "paginated-tools",
        on_list_tools=_paginated_handler(pages, _make_tool, types.ListToolsResult, "tools"),
    )

    async with Client(server) as client:
        spies = stream_spy()
        tools = await client.list_all_tools()

        assert [t.name for t in tools] == ["a", "b", "c", "d", "e"]
        # One request per page.
        requests = spies.get_client_requests(method="tools/list")
        assert len(requests) == 3
        # First request has no cursor; subsequent ones carry the previous cursor.
        assert requests[0].params is None or "cursor" not in requests[0].params
        assert requests[1].params is not None and requests[1].params["cursor"] == "1"
        assert requests[2].params is not None and requests[2].params["cursor"] == "2"


async def test_list_all_tools_single_page():
    """A server that returns one page (no cursor) should give back one list."""

    async def handle_list_tools(
        _ctx: ServerRequestContext, _params: types.PaginatedRequestParams | None
    ) -> types.ListToolsResult:
        return types.ListToolsResult(tools=[_make_tool("only")])

    server = Server("single-page-tools", on_list_tools=handle_list_tools)

    async with Client(server) as client:
        tools = await client.list_all_tools()
        assert [t.name for t in tools] == ["only"]


async def test_list_all_tools_empty_server():
    """An empty server should yield an empty list, not raise."""

    async def handle_list_tools(
        _ctx: ServerRequestContext, _params: types.PaginatedRequestParams | None
    ) -> types.ListToolsResult:
        return types.ListToolsResult(tools=[])

    server = Server("no-tools", on_list_tools=handle_list_tools)

    async with Client(server) as client:
        tools = await client.list_all_tools()
        assert tools == []


async def test_iter_all_tools_streams_pages(
    stream_spy: Callable[[], StreamSpyCollection],
):
    """iter_all_tools yields every tool across pages, one request per page.

    The iterator is consumed to exhaustion here, so this checks the item
    order and the request count (two pages -> two `tools/list` requests).
    It does not assert that later pages are fetched lazily.
    """
    pages = [["a", "b"], ["c"]]
    server = Server(
        "stream-tools",
        on_list_tools=_paginated_handler(pages, _make_tool, types.ListToolsResult, "tools"),
    )

    async with Client(server) as client:
        spies = stream_spy()
        seen = [tool.name async for tool in client.iter_all_tools()]

        assert seen == ["a", "b", "c"]
        assert len(spies.get_client_requests(method="tools/list")) == 2


# ---- list_all_prompts ------------------------------------------------------


async def test_list_all_prompts_drains_all_pages(
    stream_spy: Callable[[], StreamSpyCollection],
):
    """list_all_prompts returns the prompts of both pages using two `prompts/list` requests."""
    pages = [["p1", "p2"], ["p3"]]
    server = Server(
        "paginated-prompts",
        on_list_prompts=_paginated_handler(pages, _make_prompt, types.ListPromptsResult, "prompts"),
    )

    async with Client(server) as client:
        spies = stream_spy()
        prompts = await client.list_all_prompts()
        assert [p.name for p in prompts] == ["p1", "p2", "p3"]
        assert len(spies.get_client_requests(method="prompts/list")) == 2


# ---- list_all_resources ----------------------------------------------------


async def test_list_all_resources_drains_all_pages(
    stream_spy: Callable[[], StreamSpyCollection],
):
    """list_all_resources returns the resources of all three pages using three `resources/list` requests."""
    pages = [["r1", "r2"], ["r3"], ["r4"]]
    server = Server(
        "paginated-resources",
        on_list_resources=_paginated_handler(pages, _make_resource, types.ListResourcesResult, "resources"),
    )

    async with Client(server) as client:
        spies = stream_spy()
        resources = await client.list_all_resources()
        assert [r.name for r in resources] == ["r1", "r2", "r3", "r4"]
        assert len(spies.get_client_requests(method="resources/list")) == 3


# ---- list_all_resource_templates ------------------------------------------


async def test_list_all_resource_templates_drains_all_pages(
    stream_spy: Callable[[], StreamSpyCollection],
):
    """list_all_resource_templates returns both pages using two `resources/templates/list` requests."""
    pages = [["t1"], ["t2", "t3"]]
    server = Server(
        "paginated-templates",
        on_list_resource_templates=_paginated_handler(
            pages,
            _make_resource_template,
            types.ListResourceTemplatesResult,
            "resource_templates",
        ),
    )

    async with Client(server) as client:
        spies = stream_spy()
        templates = await client.list_all_resource_templates()
        assert [t.name for t in templates] == ["t1", "t2", "t3"]
        assert len(spies.get_client_requests(method="resources/templates/list")) == 2
@pytest.mark.anyio
async def test_client_session_group_aggregates_paginated_tools(
    mock_exit_stack: contextlib.AsyncExitStack,
):
    """ClientSessionGroup must drain `next_cursor` so it sees every page.

    Regression for https://github.com/modelcontextprotocol/python-sdk/issues/2556:
    aggregators across multiple MCP servers are the most likely place to hit
    pagination, so the group should not stop at page one.
    """
    mock_server_info = mock.Mock(spec=types.Implementation)
    mock_server_info.name = "PaginatedServer"
    mock_session = mock.AsyncMock(spec=mcp.ClientSession)

    # Three tools split over two pages: two on the first, one on the second.
    tool_page1_a = mock.Mock(spec=types.Tool)
    tool_page1_a.name = "tool_a"
    tool_page1_b = mock.Mock(spec=types.Tool)
    tool_page1_b.name = "tool_b"
    tool_page2 = mock.Mock(spec=types.Tool)
    tool_page2.name = "tool_c"

    # `side_effect` as a list: each successive `list_tools` call returns the
    # next entry, so the first call announces "page-2" and the second ends
    # pagination with `next_cursor=None`.
    list_tools_responses = [
        mock.AsyncMock(tools=[tool_page1_a, tool_page1_b], next_cursor="page-2"),
        mock.AsyncMock(tools=[tool_page2], next_cursor=None),
    ]
    mock_session.list_tools.side_effect = list_tools_responses
    # `next_cursor=None` is set explicitly; otherwise the mock would hand back
    # an auto-created (non-None) attribute.
    mock_session.list_resources.return_value = mock.AsyncMock(resources=[], next_cursor=None)
    mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[], next_cursor=None)

    group = ClientSessionGroup(exit_stack=mock_exit_stack)
    with mock.patch.object(group, "_establish_session", return_value=(mock_server_info, mock_session)):
        await group.connect_to_server(StdioServerParameters(command="test"))

    assert set(group.tools.keys()) == {"tool_a", "tool_b", "tool_c"}
    # Two pages -> two `list_tools` calls.
    assert mock_session.list_tools.await_count == 2
    # Second call should have supplied the cursor returned by the first.
    second_call_kwargs = mock_session.list_tools.await_args_list[1].kwargs
    assert second_call_kwargs["params"] is not None
    assert second_call_kwargs["params"].cursor == "page-2"