Skip to content

Commit 2c5e19a

Browse files
committed
Fix ruff formatting, add tests for 100% coverage
1 parent 692460f commit 2c5e19a

5 files changed

Lines changed: 141 additions & 28 deletions

File tree

src/mcp/server/events.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from mcp.types import RetainedEvent
2020

2121

22-
2322
class SubscriptionRegistry:
2423
"""Thread-safe registry mapping session IDs to topic subscription patterns.
2524
@@ -48,10 +47,7 @@ async def add(self, session_id: str, pattern: str) -> None:
4847
"""
4948
segments = pattern.split("/")
5049
if len(segments) > 8:
51-
raise ValueError(
52-
f"Topic pattern exceeds maximum depth of 8 segments "
53-
f"(got {len(segments)}): {pattern}"
54-
)
50+
raise ValueError(f"Topic pattern exceeds maximum depth of 8 segments (got {len(segments)}): {pattern}")
5551
async with self._lock:
5652
self._subscriptions.setdefault(session_id, set()).add(pattern)
5753
self._compile(pattern)

src/mcp/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1463,6 +1463,7 @@ class EventParams(NotificationParams):
14631463
correlationId: str | None = None
14641464
requestedEffects: list[EventEffect] | None = None
14651465
expiresAt: str | None = None
1466+
14661467
@property
14671468
def event_id(self) -> str:
14681469
return self.eventId

tests/test_event_roundtrip.py

Lines changed: 91 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,25 @@
22

33
from __future__ import annotations
44

5-
import asyncio
65
from typing import Any
76

87
import anyio
98
import pytest
109

1110
from mcp import types
1211
from mcp.client.session import ClientSession
13-
from mcp.server.lowlevel.server import Server, request_ctx
14-
from mcp.shared.context import RequestContext
1512
from mcp.server.events import RetainedValueStore, SubscriptionRegistry
1613
from mcp.server.lowlevel import NotificationOptions
14+
from mcp.server.lowlevel.server import Server, request_ctx
1715
from mcp.server.models import InitializationOptions
1816
from mcp.server.session import ServerSession
17+
from mcp.shared.context import RequestContext
1918
from mcp.shared.message import SessionMessage
2019
from mcp.shared.session import RequestResponder
2120
from mcp.types import (
22-
EventEmitNotification,
2321
EventListRequest,
2422
EventListResult,
2523
EventParams,
26-
EventsCapability,
2724
EventSubscribeParams,
2825
EventSubscribeRequest,
2926
EventSubscribeResult,
@@ -33,11 +30,9 @@
3330
EventUnsubscribeResult,
3431
RejectedTopic,
3532
RetainedEvent,
36-
ServerCapabilities,
3733
SubscribedTopic,
3834
)
3935

40-
4136
# Shared registry and store for the test server
4237
_registry = SubscriptionRegistry()
4338
_retained_store = RetainedValueStore()
@@ -141,7 +136,7 @@ async def _run_server(server_session: ServerSession, server: Server) -> None:
141136
@pytest.fixture(autouse=True)
142137
async def reset_registry():
143138
"""Reset the global registry and store between tests."""
144-
global _registry, _retained_store
139+
global _registry, _retained_store # noqa: PLW0603
145140
_registry = SubscriptionRegistry()
146141
_retained_store = RetainedValueStore()
147142
yield
@@ -189,11 +184,14 @@ async def event_handler(params: EventParams):
189184
sub_result = await client_session.subscribe_events(["test/+"])
190185
assert len(sub_result.subscribed) == 1
191186

192-
# Server emits
187+
# Server emits with an explicit timestamp, exercising the
188+
# branch where emit_event does NOT auto-generate one.
189+
explicit_ts = "2025-01-01T00:00:00+00:00"
193190
await server_session.emit_event(
194191
topic="test/hello",
195192
payload={"message": "world"},
196193
event_id="evt-1",
194+
timestamp=explicit_ts,
197195
)
198196

199197
# Give the notification time to propagate
@@ -203,6 +201,7 @@ async def event_handler(params: EventParams):
203201
assert received_events[0].topic == "test/hello"
204202
assert received_events[0].payload == {"message": "world"}
205203
assert received_events[0].event_id == "evt-1"
204+
assert received_events[0].timestamp == explicit_ts
206205

207206
tg.cancel_scope.cancel()
208207
except (anyio.ClosedResourceError, anyio.EndOfStream):
@@ -502,3 +501,86 @@ async def test_subscribe_rejects_undeclared_topic():
502501
tg.cancel_scope.cancel()
503502
except (anyio.ClosedResourceError, anyio.EndOfStream):
504503
pass
504+
505+
506+
@pytest.mark.anyio
507+
async def test_topic_matches_subscriptions_recompiles_on_cache_miss():
508+
"""_topic_matches_subscriptions should recompile when the cache entry is missing.
509+
510+
This exercises the fallback branch where a pattern is in ``_subscribed_patterns``
511+
but not in ``_subscription_regex_cache`` (e.g. after manual cache eviction).
512+
"""
513+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10)
514+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10)
515+
516+
try:
517+
async with (
518+
server_to_client_send,
519+
server_to_client_receive,
520+
client_to_server_send,
521+
client_to_server_receive,
522+
ClientSession(
523+
server_to_client_receive,
524+
client_to_server_send,
525+
) as client_session,
526+
):
527+
# Seed a pattern without populating the regex cache.
528+
client_session._subscribed_patterns.add("foo/+")
529+
assert "foo/+" not in client_session._subscription_regex_cache
530+
531+
assert client_session._topic_matches_subscriptions("foo/bar") is True
532+
# The cache should now be populated as a side effect.
533+
assert "foo/+" in client_session._subscription_regex_cache
534+
535+
# Non-matching topic exercises the return False path.
536+
assert client_session._topic_matches_subscriptions("other/thing") is False
537+
except (anyio.ClosedResourceError, anyio.EndOfStream):
538+
pass
539+
540+
541+
@pytest.mark.anyio
542+
async def test_subscribe_events_skips_recompile_for_cached_pattern():
543+
"""subscribe_events should not recompile a regex for an already-cached pattern.
544+
545+
Covers the branch where ``sub.pattern`` is already present in
546+
``_subscription_regex_cache`` so the compile step is skipped.
547+
"""
548+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10)
549+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10)
550+
551+
server = _create_test_server()
552+
# Reset shared state for isolation.
553+
_registry._subscriptions.clear()
554+
555+
try:
556+
async with (
557+
ServerSession(
558+
client_to_server_receive,
559+
server_to_client_send,
560+
InitializationOptions(
561+
server_name="test",
562+
server_version="0.1.0",
563+
capabilities=server.get_capabilities(NotificationOptions(), {}),
564+
),
565+
) as server_session,
566+
ClientSession(
567+
server_to_client_receive,
568+
client_to_server_send,
569+
message_handler=_message_handler,
570+
) as client_session,
571+
anyio.create_task_group() as tg,
572+
):
573+
tg.start_soon(_run_server, server_session, server)
574+
await client_session.initialize()
575+
576+
# First subscribe populates the cache.
577+
await client_session.subscribe_events(["test/+"])
578+
cached_regex = client_session._subscription_regex_cache["test/+"]
579+
580+
# Second subscribe to the same pattern should reuse the cached compile.
581+
await client_session.subscribe_events(["test/+"])
582+
assert client_session._subscription_regex_cache["test/+"] is cached_regex
583+
584+
tg.cancel_scope.cancel()
585+
except (anyio.ClosedResourceError, anyio.EndOfStream):
586+
pass

tests/test_event_types.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,20 @@
77

88
import anyio
99
import pytest
10-
1110
from pydantic import ValidationError
1211

1312
from mcp import types
1413
from mcp.client.session import ClientSession
15-
from mcp.server.lowlevel.server import Server
16-
from mcp.shared.context import RequestContext
1714
from mcp.server.events import RetainedValueStore, SubscriptionRegistry
1815
from mcp.server.lowlevel import NotificationOptions
16+
from mcp.server.lowlevel.server import Server
1917
from mcp.server.models import InitializationOptions
2018
from mcp.server.session import ServerSession
19+
from mcp.shared.context import RequestContext
2120
from mcp.shared.message import SessionMessage
2221
from mcp.shared.session import RequestResponder
2322
from mcp.types import (
23+
ClientRequest,
2424
EventEffect,
2525
EventEmitNotification,
2626
EventListRequest,
@@ -37,10 +37,9 @@
3737
RejectedTopic,
3838
RetainedEvent,
3939
ServerCapabilities,
40-
SubscribedTopic,
41-
ClientRequest,
4240
ServerNotification,
4341
ServerResult,
42+
SubscribedTopic,
4443
)
4544

4645

@@ -181,9 +180,7 @@ def test_roundtrip_via_root_model(self):
181180

182181
class TestEventSubscribeRequest:
183182
def test_roundtrip_via_root_model(self):
184-
req = EventSubscribeRequest(
185-
params=EventSubscribeParams(topics=["a/+", "b/#"])
186-
)
183+
req = EventSubscribeRequest(params=EventSubscribeParams(topics=["a/+", "b/#"]))
187184
data = req.model_dump(by_alias=True, mode="json")
188185
wrapped = ClientRequest.model_validate(data)
189186
parsed = wrapped.root
@@ -193,9 +190,7 @@ def test_roundtrip_via_root_model(self):
193190

194191
class TestEventUnsubscribeRequest:
195192
def test_roundtrip_via_root_model(self):
196-
req = EventUnsubscribeRequest(
197-
params=EventUnsubscribeParams(topics=["a/+"])
198-
)
193+
req = EventUnsubscribeRequest(params=EventUnsubscribeParams(topics=["a/+"]))
199194
data = req.model_dump(by_alias=True, mode="json")
200195
wrapped = ClientRequest.model_validate(data)
201196
parsed = wrapped.root
@@ -316,15 +311,16 @@ async def _on_unsubscribe_events(
316311

317312
def _create_test_server() -> Server:
318313
server = Server("test-events-server")
314+
319315
# Register event handlers via request_handlers dict (keyed by type)
320316
async def subscribe_handler(req: EventSubscribeRequest):
321317
ctx = server.request_context
322-
result = await _on_subscribe_events(ctx, req.root.params if hasattr(req, 'root') else req.params)
318+
result = await _on_subscribe_events(ctx, req.root.params if hasattr(req, "root") else req.params)
323319
return types.ServerResult(result)
324320

325321
async def unsubscribe_handler(req: EventUnsubscribeRequest):
326322
ctx = server.request_context
327-
result = await _on_unsubscribe_events(ctx, req.root.params if hasattr(req, 'root') else req.params)
323+
result = await _on_unsubscribe_events(ctx, req.root.params if hasattr(req, "root") else req.params)
328324
return types.ServerResult(result)
329325

330326
server.request_handlers[EventSubscribeRequest] = subscribe_handler
@@ -350,6 +346,7 @@ async def _run_server(server_session: ServerSession, server: Server) -> None:
350346
handler = server.request_handlers.get(type(req.root))
351347
if handler:
352348
from mcp.server.lowlevel.server import request_ctx
349+
353350
token = request_ctx.set(
354351
RequestContext(
355352
request_id=message.request_id,
@@ -368,7 +365,7 @@ async def _run_server(server_session: ServerSession, server: Server) -> None:
368365
@pytest.fixture(autouse=True)
369366
def _reset_event_types_registry():
370367
"""Reset the global registry and store between tests."""
371-
global _registry, _retained_store
368+
global _registry, _retained_store # noqa: PLW0603
372369
_registry = SubscriptionRegistry()
373370
_retained_store = RetainedValueStore()
374371
yield

tests/test_subscription_registry.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,43 @@ async def test_not_expired_returned(self, store: RetainedValueStore):
183183
await store.set("a/b", event, expires_at=future)
184184
assert await store.get("a/b") == event
185185

186+
async def test_get_matching_reuses_cached_regex(self, store: RetainedValueStore):
187+
"""Second call with same pattern should reuse cached compiled regex."""
188+
e1 = RetainedEvent(topic="a/x", eventId="e1", payload="v1")
189+
await store.set("a/x", e1)
190+
# First call compiles and caches
191+
first = await store.get_matching("a/+")
192+
assert len(first) == 1
193+
# Second call hits the cache branch (skips compile)
194+
second = await store.get_matching("a/+")
195+
assert len(second) == 1
196+
assert second[0].topic == "a/x"
197+
198+
async def test_invalid_expires_at_treated_as_not_expired(self, store: RetainedValueStore):
199+
"""Malformed ``expires_at`` should be treated as not expired rather than raising."""
200+
event = RetainedEvent(topic="a/b", eventId="e1", payload="val")
201+
await store.set("a/b", event, expires_at="not-a-valid-iso-timestamp")
202+
# Parsing fails (ValueError), so _is_expired returns False and the value is returned.
203+
assert await store.get("a/b") == event
204+
205+
async def test_naive_expires_at_assumed_utc(self, store: RetainedValueStore):
206+
"""A naive (tz-less) ISO timestamp should be interpreted as UTC.
207+
208+
Exercises the ``if expiry.tzinfo is None`` branch in ``_is_expired``.
209+
"""
210+
# Naive timestamp in the future (no timezone suffix).
211+
future_naive = (datetime.now(timezone.utc) + timedelta(hours=1)).replace(tzinfo=None).isoformat()
212+
event = RetainedEvent(topic="a/b", eventId="e1", payload="val")
213+
await store.set("a/b", event, expires_at=future_naive)
214+
# Interpreted as UTC -> not expired -> returned.
215+
assert await store.get("a/b") == event
216+
217+
# Naive timestamp in the past -> expired -> None.
218+
past_naive = (datetime.now(timezone.utc) - timedelta(hours=1)).replace(tzinfo=None).isoformat()
219+
event2 = RetainedEvent(topic="c/d", eventId="e2", payload="val2")
220+
await store.set("c/d", event2, expires_at=past_naive)
221+
assert await store.get("c/d") is None
222+
186223
async def test_expired_cleaned_on_get_matching(self, store: RetainedValueStore):
187224
past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat()
188225
future = (datetime.now(timezone.utc) + timedelta(hours=1)).isoformat()

0 commit comments

Comments
 (0)