Skip to content

Commit ac9f3eb

Browse files
committed
Fix # root wildcard, top-level imports, regex cache
1 parent 59c3265 commit ac9f3eb

4 files changed

Lines changed: 35 additions & 18 deletions

File tree

src/mcp/client/session.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import re
23
from datetime import timedelta
34
from typing import Any, Protocol, overload
45

@@ -282,24 +283,25 @@ def decorator(fn: EventHandlerFnT) -> EventHandlerFnT:
282283

283284
def _topic_matches_subscriptions(self, topic: str) -> bool:
284285
"""Check if a topic matches any of our subscribed patterns."""
285-
import re as _re
286-
287286
for pattern in self._subscribed_patterns:
288287
parts = pattern.split("/")
289288
regex_parts: list[str] = []
290289
for i, part in enumerate(parts):
291290
if part == "#":
292-
regex = "^" + "/".join(regex_parts) + "(/.*)?$"
293-
if _re.match(regex, topic):
291+
if regex_parts:
292+
regex = "^" + "/".join(regex_parts) + "(/.*)?$"
293+
else:
294+
regex = "^.*$"
295+
if re.match(regex, topic):
294296
return True
295297
break
296298
elif part == "+":
297299
regex_parts.append("[^/]+")
298300
else:
299-
regex_parts.append(_re.escape(part))
301+
regex_parts.append(re.escape(part))
300302
else:
301303
regex = "^" + "/".join(regex_parts) + "$"
302-
if _re.match(regex, topic):
304+
if re.match(regex, topic):
303305
return True
304306
return False
305307

@@ -312,23 +314,24 @@ async def _handle_event(self, params: types.EventParams) -> None:
312314
return
313315

314316
if self._event_topic_filter is not None:
315-
import re as _re
316-
317317
parts = self._event_topic_filter.split("/")
318318
regex_parts: list[str] = []
319319
matched = False
320320
for i, part in enumerate(parts):
321321
if part == "#":
322-
regex = "^" + "/".join(regex_parts) + "(/.*)?$"
323-
matched = bool(_re.match(regex, params.topic))
322+
if regex_parts:
323+
regex = "^" + "/".join(regex_parts) + "(/.*)?$"
324+
else:
325+
regex = "^.*$"
326+
matched = bool(re.match(regex, params.topic))
324327
break
325328
elif part == "+":
326329
regex_parts.append("[^/]+")
327330
else:
328-
regex_parts.append(_re.escape(part))
331+
regex_parts.append(re.escape(part))
329332
else:
330333
regex = "^" + "/".join(regex_parts) + "$"
331-
matched = bool(_re.match(regex, params.topic))
334+
matched = bool(re.match(regex, params.topic))
332335
if not matched:
333336
return
334337

src/mcp/server/events.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,14 @@ def _pattern_to_regex(pattern: str) -> re.Pattern[str]:
3030
if part == "#":
3131
if i != len(parts) - 1:
3232
raise ValueError("'#' wildcard is only valid as the last segment")
33-
# Use (/.*)?$ so that # matches zero or more trailing segments.
34-
# e.g. "a/#" -> "^a(/.*)?$" matches "a", "a/b", "a/b/c"
35-
return re.compile("^" + "/".join(regex_parts) + "(/.*)?$")
33+
# # matches zero or more trailing segments.
34+
# If preceding segments exist, the / before # is optional
35+
# so "myapp/#" matches both "myapp" and "myapp/anything".
36+
# If # is the sole segment, it matches everything.
37+
if regex_parts:
38+
return re.compile("^" + "/".join(regex_parts) + "(/.*)?$")
39+
else:
40+
return re.compile("^.*$")
3641
elif part == "+":
3742
regex_parts.append("[^/]+")
3843
else:
@@ -126,6 +131,7 @@ def __init__(self) -> None:
126131
self._lock = asyncio.Lock()
127132
self._store: dict[str, RetainedEvent] = {}
128133
self._expires: dict[str, str] = {} # topic -> ISO 8601 expires_at
134+
self._regex_cache: dict[str, re.Pattern[str]] = {}
129135

130136
async def set(self, topic: str, event: RetainedEvent, expires_at: str | None = None) -> None:
131137
"""Store or replace the retained value for *topic*."""
@@ -151,7 +157,9 @@ async def get(self, topic: str) -> RetainedEvent | None:
151157
async def get_matching(self, pattern: str) -> list[RetainedEvent]:
152158
"""Return all non-expired retained events whose topic matches *pattern*."""
153159
async with self._lock:
154-
regex = _pattern_to_regex(pattern)
160+
if pattern not in self._regex_cache:
161+
self._regex_cache[pattern] = _pattern_to_regex(pattern)
162+
regex = self._regex_cache[pattern]
155163
result: list[RetainedEvent] = []
156164
expired_topics: list[str] = []
157165
for topic, event in self._store.items():

src/mcp/server/session.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
4444
import anyio.lowlevel
4545
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
4646
from pydantic import AnyUrl
47+
from ulid import ULID
4748

4849
import mcp.types as types
4950
from mcp.server.experimental.session_features import ExperimentalServerSessionFeatures
@@ -218,8 +219,6 @@ async def emit_event(
218219
) -> None:
219220
"""Push an event to the client on the given topic."""
220221
if event_id is None:
221-
from ulid import ULID
222-
223222
event_id = str(ULID())
224223
if timestamp is None:
225224
from datetime import datetime, timezone

tests/test_subscription_registry.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,13 @@ async def test_rejects_pattern_exceeding_max_depth(self, registry: SubscriptionR
109109
with pytest.raises(ValueError, match="exceeds maximum depth of 8 segments"):
110110
await registry.add("s1", "a/b/c/d/e/f/g/h/i")
111111

112+
async def test_hash_root_wildcard_matches_everything(self, registry: SubscriptionRegistry):
113+
"""Pattern '#' (sole segment) should match any topic."""
114+
await registry.add("s1", "#")
115+
assert await registry.match("any/topic/at/all") == {"s1"}
116+
assert await registry.match("single") == {"s1"}
117+
assert await registry.match("a/b") == {"s1"}
118+
112119
async def test_hash_matches_zero_trailing_no_slash(self, registry: SubscriptionRegistry):
113120
"""# should match the prefix with no trailing slash (zero segments after prefix).
114121

0 commit comments

Comments
 (0)