Skip to content

Commit 692460f

Browse files
committed
Pre-compile topic filter regex, cache subscription patterns, docs clarify
1 parent 980d377 commit 692460f

4 files changed

Lines changed: 88 additions & 71 deletions

File tree

docs/events.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,9 +242,9 @@ async def on_build_event(params: EventParams) -> None:
242242
print(f"Build: {params.payload}")
243243
```
244244

245-
The optional `topic_filter` applies an additional client-side filter using the same wildcard syntax as subscription patterns. Events that do not match the filter are silently dropped before reaching the handler.
245+
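The optional `topic_filter` applies an additional client-side filter using the same wildcard syntax as subscription patterns. The filter is compiled once when the handler is registered and reused for every incoming event, so an invalid filter (for example, one with `#` anywhere other than the final segment) raises `ValueError` at registration time rather than when an event arrives. Events that do not match the filter are silently dropped before reaching the handler.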
246246

247-
The client also tracks subscribed patterns internally. Events for topics that do not match any active subscription are dropped, even if the server sends them.
247+
The client also tracks subscribed patterns internally. Once a client has at least one active subscription, events whose topic does not match any subscribed pattern are dropped before reaching the handler, even if the server sends them. A client that never calls `subscribe_events` has no subscription patterns registered and will pass every event received from the server through to the handler, subject only to the optional `topic_filter`. If you want strict subscription-only delivery, subscribe explicitly.
248248

249249
### Unsubscribing
250250

src/mcp/client/session.py

Lines changed: 44 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from mcp.shared.context import RequestContext
1515
from mcp.shared.message import SessionMessage
1616
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
17+
from mcp.shared.topic_patterns import pattern_to_regex
1718
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
1819

1920
DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
@@ -148,7 +149,11 @@ def __init__(
148149
self._experimental_features: ExperimentalClientFeatures | None = None
149150
self._event_handler: EventHandlerFnT | None = None
150151
self._event_topic_filter: str | None = None
152+
self._event_topic_filter_regex: re.Pattern[str] | None = None
151153
self._subscribed_patterns: set[str] = set()
154+
# Cache compiled regexes for subscription patterns to avoid
155+
# recompiling on every incoming event.
156+
self._subscription_regex_cache: dict[str, re.Pattern[str]] = {}
152157

153158
# Experimental: Task handlers (use defaults if not provided)
154159
self._task_handlers = experimental_task_handlers or ExperimentalTaskHandlers()
@@ -239,6 +244,8 @@ async def subscribe_events(self, topics: list[str]) -> types.EventSubscribeResul
239244
)
240245
for sub in result.subscribed:
241246
self._subscribed_patterns.add(sub.pattern)
247+
if sub.pattern not in self._subscription_regex_cache:
248+
self._subscription_regex_cache[sub.pattern] = pattern_to_regex(sub.pattern)
242249
return result
243250

244251
async def unsubscribe_events(self, topics: list[str]) -> types.EventUnsubscribeResult:
@@ -253,6 +260,7 @@ async def unsubscribe_events(self, topics: list[str]) -> types.EventUnsubscribeR
253260
)
254261
for pattern in result.unsubscribed:
255262
self._subscribed_patterns.discard(pattern)
263+
self._subscription_regex_cache.pop(pattern, None)
256264
return result
257265

258266
async def list_events(self) -> types.EventListResult:
@@ -268,9 +276,17 @@ def set_event_handler(
268276
*,
269277
topic_filter: str | None = None,
270278
) -> None:
271-
"""Register a callback for incoming event notifications."""
279+
"""Register a callback for incoming event notifications.
280+
281+
If *topic_filter* is provided, it is compiled once here and the
282+
cached regex is reused for every incoming event. The filter uses
283+
the same MQTT-style wildcard syntax as subscription patterns
284+
(``+`` for a single segment, ``#`` as a trailing multi-segment
285+
wildcard).
286+
"""
272287
self._event_handler = handler
273288
self._event_topic_filter = topic_filter
289+
self._event_topic_filter_regex = pattern_to_regex(topic_filter) if topic_filter is not None else None
274290

275291
def on_event(self, topic_filter: str | None = None):
276292
"""Decorator for registering an event handler."""
@@ -282,58 +298,43 @@ def decorator(fn: EventHandlerFnT) -> EventHandlerFnT:
282298
return decorator
283299

284300
def _topic_matches_subscriptions(self, topic: str) -> bool:
285-
"""Check if a topic matches any of our subscribed patterns."""
301+
"""Check if *topic* matches any of our subscribed patterns.
302+
303+
Compiled regexes are cached per subscription pattern so incoming
304+
events do not pay a recompile cost on every match attempt.
305+
"""
286306
for pattern in self._subscribed_patterns:
287-
parts = pattern.split("/")
288-
regex_parts: list[str] = []
289-
for i, part in enumerate(parts):
290-
if part == "#":
291-
if regex_parts:
292-
regex = "^" + "/".join(regex_parts) + "(/.*)?$"
293-
else:
294-
regex = "^.*$"
295-
if re.match(regex, topic):
296-
return True
297-
break
298-
elif part == "+":
299-
regex_parts.append("[^/]+")
300-
else:
301-
regex_parts.append(re.escape(part))
302-
else:
303-
regex = "^" + "/".join(regex_parts) + "$"
304-
if re.match(regex, topic):
305-
return True
307+
regex = self._subscription_regex_cache.get(pattern)
308+
if regex is None:
309+
regex = pattern_to_regex(pattern)
310+
self._subscription_regex_cache[pattern] = regex
311+
if regex.match(topic):
312+
return True
306313
return False
307314

308315
async def _handle_event(self, params: types.EventParams) -> None:
309-
"""Dispatch an incoming event to the registered handler."""
316+
"""Dispatch an incoming event to the registered handler.
317+
318+
Filtering order:
319+
320+
1. If no handler is registered, drop the event.
321+
2. If the client has any active subscriptions, the topic must
322+
match at least one of them. Events for unsubscribed topics
323+
are dropped. (A client with zero subscriptions accepts any
324+
topic the server chooses to deliver; this is the "pass
325+
through" fallback documented in ``docs/events.md``.)
326+
3. If an additional ``topic_filter`` was provided to
327+
``set_event_handler``, the topic must also match that
328+
filter.
329+
"""
310330
if self._event_handler is None:
311331
return
312332

313333
if self._subscribed_patterns and not self._topic_matches_subscriptions(params.topic):
314334
return
315335

316-
if self._event_topic_filter is not None:
317-
parts = self._event_topic_filter.split("/")
318-
regex_parts: list[str] = []
319-
matched = False
320-
for i, part in enumerate(parts):
321-
if part == "#":
322-
if regex_parts:
323-
regex = "^" + "/".join(regex_parts) + "(/.*)?$"
324-
else:
325-
regex = "^.*$"
326-
matched = bool(re.match(regex, params.topic))
327-
break
328-
elif part == "+":
329-
regex_parts.append("[^/]+")
330-
else:
331-
regex_parts.append(re.escape(part))
332-
else:
333-
regex = "^" + "/".join(regex_parts) + "$"
334-
matched = bool(re.match(regex, params.topic))
335-
if not matched:
336-
return
336+
if self._event_topic_filter_regex is not None and not self._event_topic_filter_regex.match(params.topic):
337+
return
337338

338339
await self._event_handler(params)
339340

src/mcp/server/events.py

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,35 +15,10 @@
1515
import re
1616
from datetime import datetime, timezone
1717

18+
from mcp.shared.topic_patterns import pattern_to_regex as _pattern_to_regex
1819
from mcp.types import RetainedEvent
1920

2021

21-
def _pattern_to_regex(pattern: str) -> re.Pattern[str]:
22-
"""Convert an MQTT-style topic pattern to a compiled regex.
23-
24-
``+`` becomes a single-segment match, ``#`` becomes a greedy
25-
multi-segment match (only valid as the final segment).
26-
"""
27-
parts = pattern.split("/")
28-
regex_parts: list[str] = []
29-
for i, part in enumerate(parts):
30-
if part == "#":
31-
if i != len(parts) - 1:
32-
raise ValueError("'#' wildcard is only valid as the last segment")
33-
# # matches zero or more trailing segments.
34-
# If preceding segments exist, the / before # is optional
35-
# so "myapp/#" matches both "myapp" and "myapp/anything".
36-
# If # is the sole segment, it matches everything.
37-
if regex_parts:
38-
return re.compile("^" + "/".join(regex_parts) + "(/.*)?$")
39-
else:
40-
return re.compile("^.*$")
41-
elif part == "+":
42-
regex_parts.append("[^/]+")
43-
else:
44-
regex_parts.append(re.escape(part))
45-
return re.compile("^" + "/".join(regex_parts) + "$")
46-
4722

4823
class SubscriptionRegistry:
4924
"""Thread-safe registry mapping session IDs to topic subscription patterns.

src/mcp/shared/topic_patterns.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
"""Shared helpers for MQTT-style topic pattern matching.
2+
3+
Both the client (for subscription filtering) and the server (for the
4+
subscription registry and retained-event store) need to compile MQTT-style
5+
topic patterns into regular expressions. Keeping the implementation here
6+
avoids a client -> server import and guarantees identical semantics on both
7+
sides of the protocol.
8+
"""
9+
10+
from __future__ import annotations
11+
12+
import re
13+
14+
__all__ = ["pattern_to_regex"]
15+
16+
17+
def pattern_to_regex(pattern: str) -> re.Pattern[str]:
    """Convert an MQTT-style topic pattern to a compiled regex.

    The pattern is split on ``/`` and translated segment by segment:

    * ``+`` matches exactly one non-empty segment (no ``/``).
    * ``#`` matches zero or more trailing segments and is only valid as
      the final segment. ``"myapp/#"`` matches both ``"myapp"`` and
      ``"myapp/anything"``; a bare ``"#"`` matches every topic.
    * Any other segment is matched literally (regex metacharacters are
      escaped).

    The returned regex is anchored at both ends, so ``match``, ``search``
    and ``fullmatch`` all give the same answer for a given topic.

    Args:
        pattern: The MQTT-style topic pattern to translate.

    Returns:
        A compiled regular expression that matches whole topic strings.

    Raises:
        ValueError: If ``#`` appears anywhere other than the last segment.
    """
    segments = pattern.split("/")
    last_index = len(segments) - 1
    regex_parts: list[str] = []
    for index, segment in enumerate(segments):
        if segment == "#":
            if index != last_index:
                raise ValueError("'#' wildcard is only valid as the last segment")
            # re.DOTALL lets "." cross newlines so "#" really matches
            # everything, consistent with "+" ([^/]+ already accepts "\n").
            if regex_parts:
                # The "/" before "#" is optional so the parent topic
                # itself is matched too.
                return re.compile("^" + "/".join(regex_parts) + r"(/.*)?\Z", re.DOTALL)
            return re.compile(r"^.*\Z", re.DOTALL)
        if segment == "+":
            regex_parts.append("[^/]+")
        else:
            regex_parts.append(re.escape(segment))
    # Anchor with \Z rather than $: "$" also matches just before a trailing
    # newline, which would let "a/b\n" slip through a filter for "a/b".
    return re.compile("^" + "/".join(regex_parts) + r"\Z", re.DOTALL)

0 commit comments

Comments
 (0)