Skip to content

Commit 2e8a268

Browse files
committed
Update typing
1 parent 8fb7244 commit 2e8a268

4 files changed

Lines changed: 84 additions & 147 deletions

File tree

src/strands/plugins/decorator.py

Lines changed: 26 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,111 +1,69 @@
11
"""Hook decorator for Plugin methods.
22
3-
This module provides the @hook decorator that marks methods as hook callbacks
4-
for automatic registration when the plugin is attached to an agent.
5-
6-
The @hook decorator performs several functions:
7-
8-
1. Marks methods as hook callbacks for automatic discovery by Plugin base class
9-
2. Infers event types from the callback's type hints (consistent with HookRegistry.add_callback)
10-
3. Supports both @hook and @hook() syntax
11-
4. Supports union types for multiple event types (e.g., BeforeModelCallEvent | AfterModelCallEvent)
12-
5. Stores hook metadata on the decorated method for later discovery
3+
Marks methods as hook callbacks for automatic registration when the plugin
4+
is attached to an agent. Infers event types from type hints and supports
5+
union types for multiple events.
136
147
Example:
158
```python
16-
from strands.plugins import Plugin, hook
17-
from strands.hooks import BeforeModelCallEvent, AfterModelCallEvent
18-
199
class MyPlugin(Plugin):
20-
name = "my-plugin"
21-
2210
@hook
2311
def on_model_call(self, event: BeforeModelCallEvent):
2412
print(event)
25-
26-
@hook
27-
def on_any_model_event(self, event: BeforeModelCallEvent | AfterModelCallEvent):
28-
print(event)
2913
```
3014
"""
3115

3216
from collections.abc import Callable
33-
from typing import TYPE_CHECKING, TypeVar, overload
17+
from typing import Generic, cast, overload
3418

35-
if TYPE_CHECKING:
36-
from ..hooks.registry import BaseHookEvent
19+
from ..hooks._type_inference import infer_event_types
20+
from ..hooks.registry import HookCallback, TEvent
3721

38-
# Type for wrapped function
39-
T = TypeVar("T", bound=Callable[..., object])
22+
23+
class _WrappedHookCallable(HookCallback, Generic[TEvent]):
24+
"""Wrapped version of HookCallback that includes a `_hook_event_types` argument."""
25+
26+
_hook_event_types: list[TEvent]
4027

4128

4229
# Handle @hook
4330
@overload
44-
def hook(__func: T) -> T: ...
31+
def hook(__func: HookCallback) -> _WrappedHookCallable: ...
4532

4633

4734
# Handle @hook()
4835
@overload
49-
def hook() -> Callable[[T], T]: ...
36+
def hook() -> Callable[[HookCallback], _WrappedHookCallable]: ...
5037

5138

5239
def hook(
53-
func: T | None = None,
54-
) -> T | Callable[[T], T]:
55-
"""Decorator that marks a method as a hook callback for automatic registration.
56-
57-
This decorator enables declarative hook registration in Plugin classes. When a
58-
Plugin is attached to an agent, methods marked with @hook are automatically
59-
discovered and registered with the agent's hook registry.
40+
func: HookCallback | None = None,
41+
) -> _WrappedHookCallable | Callable[[HookCallback], _WrappedHookCallable]:
42+
"""Mark a method as a hook callback for automatic registration.
6043
61-
The event type is inferred from the callback's type hint on the first parameter
62-
(after 'self' for instance methods). Union types are supported for registering
63-
a single callback for multiple event types.
64-
65-
The decorator can be used in two ways:
66-
- As a simple decorator: `@hook`
67-
- With parentheses: `@hook()`
44+
Infers event type from the callback's type hint. Supports union types
45+
for multiple events. Can be used as @hook or @hook().
6846
6947
Args:
70-
func: The function to decorate. When used as a simple decorator, this is
71-
the function being decorated. When used with parentheses, this will be None.
48+
func: The function to decorate.
7249
7350
Returns:
74-
The decorated function with hook metadata attached.
51+
The decorated function with hook metadata.
7552
7653
Raises:
77-
ValueError: If the event type cannot be inferred from type hints, or if
78-
the type hint is not a valid HookEvent subclass.
79-
80-
Example:
81-
```python
82-
class MyPlugin(Plugin):
83-
name = "my-plugin"
84-
85-
@hook
86-
def on_model_call(self, event: BeforeModelCallEvent):
87-
print(f"Model called: {event}")
88-
89-
@hook
90-
def on_any_event(self, event: BeforeModelCallEvent | AfterModelCallEvent):
91-
print(f"Event: {type(event).__name__}")
92-
```
54+
ValueError: If event type cannot be inferred from type hints.
9355
"""
9456

95-
def decorator(f: T) -> T:
96-
# Import here to avoid circular dependency at runtime
97-
from ..hooks._type_inference import infer_event_types
98-
57+
def decorator(f: HookCallback[TEvent]) -> _WrappedHookCallable[TEvent]:
9958
# Infer event types from type hints
100-
event_types: list[type[BaseHookEvent]] = infer_event_types(f) # type: ignore[arg-type]
59+
event_types: list[type[TEvent]] = infer_event_types(f)
10160

10261
# Store hook metadata on the function
103-
f._hook_event_types = event_types # type: ignore[attr-defined]
62+
f_wrapped = cast(_WrappedHookCallable, f)
63+
f_wrapped._hook_event_types = event_types
10464

105-
return f
65+
return f_wrapped
10666

107-
# Handle both @hook and @hook() syntax
10867
if func is None:
10968
return decorator
110-
11169
return decorator(func)

src/strands/plugins/plugin.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
from collections.abc import Awaitable
1010
from typing import TYPE_CHECKING
1111

12-
from strands.tools.decorator import DecoratedFunctionTool
12+
from ..tools.decorator import DecoratedFunctionTool
13+
from .decorator import _WrappedHookCallable
1314

1415
if TYPE_CHECKING:
1516
from ..agent import Agent
@@ -74,17 +75,13 @@ def __init__(self) -> None:
7475
Scans the class for methods decorated with @hook and @tool and stores
7576
references for later registration when init_plugin is called.
7677
"""
77-
self._hooks: list[object] = []
78+
self._hooks: list[_WrappedHookCallable] = []
7879
self._tools: list[DecoratedFunctionTool] = []
7980
self._discover_decorated_methods()
8081

8182
def _discover_decorated_methods(self) -> None:
8283
"""Scan class for @hook and @tool decorated methods."""
8384
for name in dir(self):
84-
# Skip private and dunder methods
85-
if name.startswith("_"):
86-
continue
87-
8885
try:
8986
attr = getattr(self, name)
9087
except Exception:
@@ -118,7 +115,7 @@ def init_plugin(self, agent: "Agent") -> None | Awaitable[None]:
118115
for hook_callback in self._hooks:
119116
event_types = getattr(hook_callback, "_hook_event_types", [])
120117
for event_type in event_types:
121-
agent.hooks.add_callback(event_type, hook_callback)
118+
agent.add_hook(hook_callback, event_type)
122119
logger.debug(
123120
"plugin=<%s>, hook=<%s>, event_type=<%s> | registered hook",
124121
self.name,

tests/strands/experimental/steering/core/test_handler.py

Lines changed: 35 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Unit tests for steering handler base class."""
22

3+
import inspect
34
from unittest.mock import AsyncMock, Mock
45

56
import pytest
@@ -8,6 +9,7 @@
89
from strands.experimental.steering.core.context import SteeringContext, SteeringContextCallback, SteeringContextProvider
910
from strands.experimental.steering.core.handler import SteeringHandler
1011
from strands.hooks.events import AfterModelCallEvent, BeforeToolCallEvent
12+
from strands.hooks.registry import HookRegistry
1113
from strands.plugins import Plugin
1214

1315

@@ -39,18 +41,14 @@ def test_steering_handler_is_plugin():
3941

4042
def test_init_plugin():
4143
"""Test init_plugin registers hooks on agent."""
42-
from strands.hooks import HookRegistry
43-
4444
handler = TestSteeringHandler()
4545
agent = Mock()
46-
agent.hooks = HookRegistry()
47-
agent.tool_registry = Mock()
4846

4947
handler.init_plugin(agent)
5048

51-
# Verify hooks were auto-registered via @hook decorator
52-
assert len(agent.hooks._registered_callbacks.get(BeforeToolCallEvent, [])) >= 1
53-
assert len(agent.hooks._registered_callbacks.get(AfterModelCallEvent, [])) >= 1
49+
# Verify hooks were registered (tool and model steering hooks)
50+
assert agent.add_hook.call_count >= 2
51+
agent.add_hook.assert_any_call(handler.provide_tool_steering_guidance, BeforeToolCallEvent)
5452

5553

5654
def test_steering_context_initialization():
@@ -174,7 +172,6 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs):
174172

175173
def test_init_plugin_override():
176174
"""Test that init_plugin can be overridden."""
177-
from strands.hooks import HookRegistry
178175

179176
class CustomHandler(SteeringHandler):
180177
async def steer_before_tool(self, *, agent, tool_use, **kwargs):
@@ -186,14 +183,11 @@ def init_plugin(self, agent):
186183

187184
handler = CustomHandler()
188185
agent = Mock()
189-
agent.hooks = HookRegistry()
190-
agent.tool_registry = Mock()
191186

192187
handler.init_plugin(agent)
193188

194-
# Should not register any hooks since parent init_plugin wasn't called
195-
assert len(agent.hooks._registered_callbacks.get(BeforeToolCallEvent, [])) == 0
196-
assert len(agent.hooks._registered_callbacks.get(AfterModelCallEvent, [])) == 0
189+
# Should not register any hooks
190+
assert agent.add_hook.call_count == 0
197191

198192

199193
# Integration tests with context providers
@@ -227,77 +221,68 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs):
227221

228222
def test_handler_registers_context_provider_hooks():
229223
"""Test that handler registers hooks from context callbacks."""
230-
from strands.hooks import HookRegistry
231-
232224
mock_callback = MockContextCallback()
233225
handler = TestSteeringHandlerWithProvider(context_callbacks=[mock_callback])
234226
agent = Mock()
235-
agent.hooks = HookRegistry()
236-
agent.tool_registry = Mock()
237-
agent.add_hook = Mock()
238227

239228
handler.init_plugin(agent)
240229

241-
# Should register 1 context callback via add_hook (steering hooks are auto-registered)
242-
assert agent.add_hook.call_count >= 1
230+
# Should register hooks for context callback and steering guidance
231+
assert agent.add_hook.call_count >= 2
243232

244-
# Check that BeforeToolCallEvent was registered (either via add_hook or auto-registration)
233+
# Check that BeforeToolCallEvent was registered
245234
call_args = [call[0] for call in agent.add_hook.call_args_list]
246235
event_types = [args[1] for args in call_args]
247236

248237
# Context callback should be registered
249-
assert (
250-
BeforeToolCallEvent in event_types or len(agent.hooks._registered_callbacks.get(BeforeToolCallEvent, [])) >= 1
251-
)
238+
assert BeforeToolCallEvent in event_types
252239

253-
254-
def test_context_callbacks_receive_steering_context():
240+
@pytest.mark.asyncio
241+
async def test_context_callbacks_receive_steering_context():
255242
"""Test that context callbacks receive the handler's steering context."""
256243
mock_callback = MockContextCallback()
257244
handler = TestSteeringHandlerWithProvider(context_callbacks=[mock_callback])
258245
agent = Mock()
259-
246+
agent.hooks = HookRegistry()
247+
agent.tool_registry = Mock()
248+
agent.add_hook = Mock(side_effect=lambda callback, event_type=None: agent.hooks.add_callback(event_type, callback))
260249
handler.init_plugin(agent)
261250

262-
# Get the registered callback for BeforeToolCallEvent
263-
before_callback = None
264-
for call in agent.add_hook.call_args_list:
265-
if call[0][1] == BeforeToolCallEvent:
266-
before_callback = call[0][0]
267-
break
251+
# Get the registered callbacks for BeforeToolCallEvent
252+
callbacks = agent.hooks._registered_callbacks.get(BeforeToolCallEvent, [])
253+
assert len(callbacks) > 0
268254

269-
assert before_callback is not None
270-
271-
# Create a mock event and call the callback
255+
# The context callback is wrapped in a lambda, so we just call all callbacks
256+
# and check if the steering context was updated
272257
event = Mock(spec=BeforeToolCallEvent)
273258
event.tool_use = {"name": "test_tool", "input": {}}
274259

275-
# The callback should execute without error and update the steering context
276-
before_callback(event)
260+
# Call all callbacks, handling both sync and async
261+
for cb in callbacks:
262+
try:
263+
result = await cb(event)
264+
if inspect.iscoroutine(result):
265+
await result
266+
except Exception:
267+
pass # Some callbacks might be async or have other requirements
277268

278-
# Verify the steering context was updated
269+
# Verify the steering context was updated by at least one callback
279270
assert handler.steering_context.data.get("test_key") == "test_value"
280271

281272

282273
def test_multiple_context_callbacks_registered():
283274
"""Test that multiple context callbacks are registered."""
284-
from strands.hooks import HookRegistry
285-
286275
callback1 = MockContextCallback()
287276
callback2 = MockContextCallback()
288277

289278
handler = TestSteeringHandlerWithProvider(context_callbacks=[callback1, callback2])
290279
agent = Mock()
291-
agent.hooks = HookRegistry()
292-
agent.tool_registry = Mock()
293-
agent.add_hook = Mock()
294280

295281
handler.init_plugin(agent)
296282

297-
# Should register 2 context callbacks via add_hook, plus auto-registered @hook methods
298-
assert agent.add_hook.call_count == 2 # Only context callbacks use add_hook
299-
assert len(agent.hooks._registered_callbacks.get(BeforeToolCallEvent, [])) >= 1
300-
assert len(agent.hooks._registered_callbacks.get(AfterModelCallEvent, [])) >= 1
283+
# Should register one callback for each context provider plus tool and model steering guidance
284+
expected_calls = 2 + 2 # 2 callbacks + 2 for steering guidance (tool and model)
285+
assert agent.add_hook.call_count >= expected_calls
301286

302287

303288
def test_handler_initialization_with_callbacks():
@@ -509,14 +494,10 @@ async def test_default_steer_after_model_returns_proceed():
509494

510495
def test_init_plugin_registers_model_steering():
511496
"""Test that init_plugin registers model steering callback."""
512-
from strands.hooks import HookRegistry
513-
514497
handler = TestSteeringHandler()
515498
agent = Mock()
516-
agent.hooks = HookRegistry()
517-
agent.tool_registry = Mock()
518499

519500
handler.init_plugin(agent)
520501

521-
# Verify model steering hook was auto-registered via @hook decorator
522-
assert len(agent.hooks._registered_callbacks.get(AfterModelCallEvent, [])) >= 1
502+
# Verify model steering hook was registered
503+
agent.add_hook.assert_any_call(handler.provide_model_steering_guidance, AfterModelCallEvent)

0 commit comments

Comments
 (0)