Skip to content

Commit e0d3f03

Browse files
authored
Merge pull request #66 from dgenio/feat/mcp-driver
feat: add built-in MCPDriver with stdio and Streamable HTTP transports
2 parents e2ff8a7 + 26eeab5 commit e0d3f03

8 files changed

Lines changed: 772 additions & 20 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88
## [Unreleased]
99

1010
### Added
11+
- Built-in `MCPDriver` with stdio and Streamable HTTP transports, tool auto-discovery, normalized MCP result handling, and optional dependency guardrails.
1112
- Declared weaver-spec v0.1.0 compatibility in README: invariants I-01 (firewall), I-02 (authorization + audit), and I-06 (scoped tokens) are satisfied.
1213
- Added placeholder `conformance_stub` CI job that will activate once the weaver-spec conformance suite ships (dgenio/weaver-spec#4).
13-
1414
## [0.4.0] - 2026-03-14
1515

1616
### Added

docs/integrations.md

Lines changed: 77 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,34 +2,94 @@
22

33
## MCP (Model Context Protocol)
44

5-
To integrate with an MCP server, implement a custom driver that wraps the MCP client:
5+
The built-in `MCPDriver` supports both local stdio servers and remote Streamable HTTP servers.
6+
7+
Install the optional dependency first:
8+
9+
```bash
10+
pip install "weaver-kernel[mcp]"
11+
```
12+
13+
### Stdio transport
614

715
```python
8-
from agent_kernel.drivers.base import Driver, ExecutionContext
9-
from agent_kernel.models import RawResult
16+
import asyncio
1017

11-
class MCPDriver:
12-
def __init__(self, mcp_client, driver_id: str = "mcp"):
13-
self._client = mcp_client
14-
self._driver_id = driver_id
18+
from agent_kernel import CapabilityRegistry, Kernel, StaticRouter
19+
from agent_kernel.drivers.mcp import MCPDriver
20+
21+
22+
async def main() -> None:
23+
registry = CapabilityRegistry()
24+
router = StaticRouter(fallback=[])
25+
kernel = Kernel(registry=registry, router=router)
26+
27+
# Connect to a local MCP server process.
28+
driver = MCPDriver.from_stdio(
29+
command="python",
30+
args=["-m", "my_mcp_server"],
31+
server_name="local-tools",
32+
)
33+
kernel.register_driver(driver)
34+
35+
# Discover tools and register them as capabilities.
36+
capabilities = await driver.discover(namespace="local")
37+
registry.register_many(capabilities)
38+
39+
# Route each discovered capability to this MCP driver.
40+
for capability in capabilities:
41+
router.add_route(capability.capability_id, [driver.driver_id])
1542

16-
@property
17-
def driver_id(self) -> str:
18-
return self._driver_id
1943

20-
async def execute(self, ctx: ExecutionContext) -> RawResult:
21-
operation = ctx.args.get("operation", ctx.capability_id)
22-
result = await self._client.call_tool(operation, ctx.args)
23-
return RawResult(capability_id=ctx.capability_id, data=result)
44+
asyncio.run(main())
2445
```
2546

26-
Then register it:
47+
### Streamable HTTP transport
2748

2849
```python
29-
kernel.register_driver(MCPDriver(mcp_client))
30-
router.add_route("mcp.my_tool", ["mcp"])
50+
import asyncio
51+
52+
from agent_kernel import CapabilityRegistry, Kernel, StaticRouter
53+
from agent_kernel.drivers.mcp import MCPDriver
54+
55+
56+
async def main() -> None:
57+
registry = CapabilityRegistry()
58+
router = StaticRouter(fallback=[])
59+
kernel = Kernel(registry=registry, router=router)
60+
61+
# Connect to a remote Streamable HTTP MCP server.
62+
# Note: max_retries > 0 creates at-least-once delivery semantics for
63+
# tools/call — if a connection drops after the server processes the
64+
# request but before the response arrives, the call will be repeated.
65+
# Ensure target tools are idempotent, or set max_retries=0 for
66+
# WRITE/DESTRUCTIVE capabilities.
67+
driver = MCPDriver.from_http(
68+
url="https://example.com/mcp",
69+
server_name="remote-tools",
70+
max_retries=1,
71+
)
72+
kernel.register_driver(driver)
73+
74+
# Discover tools and register them as capabilities.
75+
capabilities = await driver.discover(namespace="remote")
76+
registry.register_many(capabilities)
77+
78+
# Route each discovered capability to this MCP driver.
79+
for capability in capabilities:
80+
router.add_route(capability.capability_id, [driver.driver_id])
81+
82+
83+
asyncio.run(main())
3184
```
3285

86+
### Notes
87+
88+
- `discover()` converts `tools/list` results into `Capability` objects.
89+
- `execute()` calls `tools/call` and normalizes MCP content blocks for the firewall.
90+
- MCP `isError` responses raise `DriverError` with the server-provided detail.
91+
- If `mcp` is not installed, factory methods raise a helpful `ImportError`.
92+
3393
## HTTPDriver
3494

3595
The built-in `HTTPDriver` supports GET, POST, PUT, DELETE:

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,9 @@ dev = [
3838
"ruff>=0.4",
3939
"mypy>=1.10",
4040
"httpx>=0.27",
41+
"mcp>=1.6",
4142
]
42-
mcp = ["mcp>=1.0"]
43+
mcp = ["mcp>=1.6"]
4344
otel = ["opentelemetry-api>=1.20"]
4445

4546
[tool.hatch.build.targets.wheel]

src/agent_kernel/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737

3838
from .drivers.base import Driver, ExecutionContext
3939
from .drivers.http import HTTPDriver
40+
from .drivers.mcp import MCPDriver
4041
from .drivers.memory import InMemoryDriver, make_billing_driver
4142
from .enums import SafetyClass, SensitivityTag
4243
from .errors import (
@@ -129,6 +130,7 @@
129130
"ExecutionContext",
130131
"InMemoryDriver",
131132
"HTTPDriver",
133+
"MCPDriver",
132134
"make_billing_driver",
133135
# firewall
134136
"Firewall",

src/agent_kernel/drivers/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from .base import Driver, ExecutionContext
44
from .http import HTTPDriver
5+
from .mcp import MCPDriver
56
from .memory import InMemoryDriver
67

7-
__all__ = ["Driver", "ExecutionContext", "HTTPDriver", "InMemoryDriver"]
8+
__all__ = ["Driver", "ExecutionContext", "HTTPDriver", "MCPDriver", "InMemoryDriver"]

src/agent_kernel/drivers/mcp.py

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
"""MCP driver: execute capabilities against Model Context Protocol servers."""
2+
3+
from __future__ import annotations
4+
5+
from collections.abc import Awaitable, Callable
6+
from typing import Any
7+
8+
from ..enums import SafetyClass
9+
from ..errors import DriverError
10+
from ..models import Capability, ImplementationRef, RawResult
11+
from .base import ExecutionContext
12+
from .mcp_support import (
13+
SessionFactory,
14+
ToolSpec,
15+
build_http_session_factory,
16+
build_stdio_session_factory,
17+
call_tool,
18+
extract_tool_specs,
19+
normalize_call_result,
20+
)
21+
22+
# Lazy import of McpError — only available when the mcp optional dep is installed.
23+
# If mcp is absent, factory methods raise ImportError before any session is created,
24+
# so _McpError will never be None on a live driver instance.
25+
try:
26+
from mcp.shared.exceptions import McpError as _McpError
27+
except ImportError: # pragma: no cover
28+
_McpError = None # type: ignore[assignment,misc]
29+
30+
31+
def _infer_safety_class(spec: ToolSpec) -> SafetyClass:
32+
"""Infer a SafetyClass from MCP ToolAnnotations hints.
33+
34+
Uses a conservative default of READ when annotations are absent.
35+
The caller's safety_class_map takes precedence over the inferred value.
36+
"""
37+
if spec.destructive_hint:
38+
return SafetyClass.DESTRUCTIVE
39+
if spec.read_only_hint:
40+
return SafetyClass.READ
41+
return SafetyClass.READ
42+
43+
44+
class MCPDriver:
45+
"""A driver that invokes capabilities via MCP tools/call."""
46+
47+
def __init__(
48+
self,
49+
*,
50+
driver_id: str,
51+
session_factory: SessionFactory,
52+
server_name: str,
53+
transport: str,
54+
max_http_retries: int = 1,
55+
) -> None:
56+
self._driver_id = driver_id
57+
self._session_factory = session_factory
58+
self._server_name = server_name
59+
self._transport = transport
60+
self._max_http_retries = max(max_http_retries, 0)
61+
62+
@property
63+
def driver_id(self) -> str:
64+
"""Unique identifier for this driver instance."""
65+
return self._driver_id
66+
67+
@classmethod
68+
def from_stdio(
69+
cls,
70+
command: str,
71+
args: list[str] | None = None,
72+
*,
73+
server_name: str = "stdio",
74+
) -> MCPDriver:
75+
"""Create an MCP driver using stdio transport.
76+
77+
Raises:
78+
ImportError: If the optional ``mcp`` dependency is not installed.
79+
"""
80+
session_factory = build_stdio_session_factory(command=command, args=args or [])
81+
return cls(
82+
driver_id=f"mcp:{server_name}",
83+
session_factory=session_factory,
84+
server_name=server_name,
85+
transport="stdio",
86+
max_http_retries=0,
87+
)
88+
89+
@classmethod
90+
def from_http(
91+
cls,
92+
url: str,
93+
*,
94+
server_name: str = "http",
95+
max_retries: int = 1,
96+
) -> MCPDriver:
97+
"""Create an MCP driver using Streamable HTTP transport.
98+
99+
Raises:
100+
ImportError: If the optional ``mcp`` dependency is not installed.
101+
"""
102+
session_factory = build_http_session_factory(url=url)
103+
return cls(
104+
driver_id=f"mcp:{server_name}",
105+
session_factory=session_factory,
106+
server_name=server_name,
107+
transport="http",
108+
max_http_retries=max_retries,
109+
)
110+
111+
async def discover(
112+
self,
113+
*,
114+
namespace: str | None = None,
115+
safety_class_map: dict[str, SafetyClass] | None = None,
116+
) -> list[Capability]:
117+
"""Discover MCP tools across all pages and convert them to capabilities."""
118+
tools = await self._run_with_retry(
119+
operation_name="tools/list",
120+
action=self._fetch_all_tools,
121+
)
122+
123+
capabilities: list[Capability] = []
124+
for spec in extract_tool_specs(tools):
125+
capability_id = f"{namespace}.{spec.name}" if namespace else spec.name
126+
inferred = _infer_safety_class(spec)
127+
safety_class = (
128+
safety_class_map.get(spec.name, inferred)
129+
if safety_class_map is not None
130+
else inferred
131+
)
132+
capabilities.append(
133+
Capability(
134+
capability_id=capability_id,
135+
name=spec.name,
136+
description=spec.description,
137+
safety_class=safety_class,
138+
tags=["mcp", self._server_name],
139+
impl=ImplementationRef(
140+
driver_id=self._driver_id,
141+
operation=spec.name,
142+
),
143+
)
144+
)
145+
return capabilities
146+
147+
async def _fetch_all_tools(self, session: Any) -> list[Any]:
148+
"""Paginate tools/list to exhaustion and return a flat list of Tool objects."""
149+
all_tools: list[Any] = []
150+
cursor: str | None = None
151+
while True:
152+
result = await session.list_tools(cursor=cursor)
153+
all_tools.extend(getattr(result, "tools", []) or [])
154+
cursor = getattr(result, "nextCursor", None)
155+
if not cursor:
156+
break
157+
return all_tools
158+
159+
async def execute(self, ctx: ExecutionContext) -> RawResult:
160+
"""Execute an MCP tool call for the given capability context."""
161+
operation = str(ctx.args.get("operation", ctx.capability_id))
162+
params = {k: v for k, v in ctx.args.items() if k != "operation"}
163+
164+
# Apply policy constraints as default arguments, without overriding explicit args.
165+
# read_timeout_seconds is an SDK control parameter — applied to the session call
166+
# directly rather than forwarded to the tool as an argument.
167+
read_timeout_seconds_raw = ctx.constraints.get("read_timeout_seconds")
168+
for key, value in ctx.constraints.items():
169+
if key != "read_timeout_seconds":
170+
params.setdefault(key, value)
171+
172+
read_timeout_seconds: float | None = (
173+
float(read_timeout_seconds_raw) if read_timeout_seconds_raw is not None else None
174+
)
175+
176+
result = await self._run_with_retry(
177+
operation_name=f"tools/call:{operation}",
178+
action=lambda session: call_tool(
179+
session,
180+
operation=operation,
181+
params=params,
182+
read_timeout_seconds=read_timeout_seconds,
183+
),
184+
)
185+
186+
data = normalize_call_result(
187+
result,
188+
operation=operation,
189+
driver_id=self._driver_id,
190+
)
191+
return RawResult(
192+
capability_id=ctx.capability_id,
193+
data=data,
194+
metadata={
195+
"driver_id": self._driver_id,
196+
"transport": self._transport,
197+
"operation": operation,
198+
},
199+
)
200+
201+
async def _run_with_retry(
202+
self,
203+
*,
204+
operation_name: str,
205+
action: Callable[[Any], Awaitable[Any]],
206+
) -> Any:
207+
attempts = 1 + self._max_http_retries if self._transport == "http" else 1
208+
last_exc: Exception | None = None
209+
210+
for _attempt in range(attempts):
211+
try:
212+
async with self._session_factory() as session:
213+
return await action(session)
214+
except DriverError:
215+
raise
216+
except Exception as exc:
217+
# McpError is a protocol-level rejection (tool not found, auth
218+
# failure, invalid params) — the server processed and rejected the
219+
# request. It is not retryable; surface it immediately as DriverError.
220+
if _McpError is not None and isinstance(exc, _McpError):
221+
raise DriverError(
222+
f"MCPDriver '{self._driver_id}' received a protocol error "
223+
f"during {operation_name}: {exc}"
224+
) from exc
225+
# All other exceptions are session/transport failures (connection
226+
# refused, EOF, timeout) and are retryable for HTTP transport.
227+
# Note: HTTP retries create at-least-once delivery semantics for
228+
# tools/call. Callers using WRITE/DESTRUCTIVE capabilities over HTTP
229+
# should ensure the target tool is idempotent, or set max_retries=0.
230+
last_exc = exc
231+
232+
reason = str(last_exc) if last_exc is not None else "unknown transport failure"
233+
raise DriverError(
234+
f"MCPDriver '{self._driver_id}' failed during {operation_name} over "
235+
f"{self._transport}: {reason}"
236+
) from last_exc

0 commit comments

Comments
 (0)