Skip to content

Commit 84449cf

Browse files
committed
fix(plugins): make hooks/tools mutable, fix type annotation, export hook
Address additional PR feedback: - Make hooks and tools properties return mutable lists for filtering/customization - Fix type annotation: _hook_event_types is list[type[TEvent]] not list[TEvent] - Export @hook from top-level strands package (from strands import hook) - Fix docstring typo: 'argument' -> 'attribute' - Add tests for filtering hooks and tools
1 parent 44ec90a commit 84449cf

4 files changed

Lines changed: 66 additions & 17 deletions

File tree

src/strands/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44
from .agent.agent import Agent
55
from .agent.base import AgentBase
66
from .event_loop._retry import ModelRetryStrategy
7-
from .plugins import Plugin
7+
from .plugins import Plugin, hook
88
from .tools.decorator import tool
99
from .types.tools import ToolContext
1010

1111
__all__ = [
1212
"Agent",
1313
"AgentBase",
1414
"agent",
15+
"hook",
1516
"models",
1617
"ModelRetryStrategy",
1718
"Plugin",

src/strands/plugins/decorator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ def on_model_call(self, event: BeforeModelCallEvent):
2121

2222

2323
class _WrappedHookCallable(HookCallback, Generic[TEvent]):
24-
"""Wrapped version of HookCallback that includes a `_hook_event_types` argument."""
24+
"""Wrapped version of HookCallback that includes a `_hook_event_types` attribute."""
2525

26-
_hook_event_types: list[TEvent]
26+
_hook_event_types: list[type[TEvent]]
2727

2828

2929
# Handle @hook

src/strands/plugins/plugin.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -81,22 +81,24 @@ def __init__(self) -> None:
8181
self._discover_decorated_methods()
8282

8383
@property
84-
def hooks(self) -> tuple[_WrappedHookCallable, ...]:
84+
def hooks(self) -> list[_WrappedHookCallable]:
8585
"""Discovered @hook decorated methods.
8686
87-
Returns a tuple of hook callbacks that will be auto-registered
88-
when the plugin is attached to an agent.
87+
Returns the list of hook callbacks that will be auto-registered
88+
when the plugin is attached to an agent. This list is mutable,
89+
allowing users to filter or modify hooks before registration.
8990
"""
90-
return tuple(self._hooks)
91+
return self._hooks
9192

9293
@property
93-
def tools(self) -> tuple[DecoratedFunctionTool, ...]:
94+
def tools(self) -> list[DecoratedFunctionTool]:
9495
"""Discovered @tool decorated methods.
9596
96-
Returns a tuple of tools that will be auto-registered
97-
when the plugin is attached to an agent.
97+
Returns the list of tools that will be auto-registered
98+
when the plugin is attached to an agent. This list is mutable,
99+
allowing users to filter or modify tools before registration.
98100
"""
99-
return tuple(self._tools)
101+
return self._tools
100102

101103
def _discover_decorated_methods(self) -> None:
102104
"""Scan class for @hook and @tool decorated methods."""

tests/strands/plugins/test_plugin_base_class.py

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,8 @@ def decorated_hook(self, event: BeforeModelCallEvent):
141141
assert len(plugin.hooks) == 1
142142
assert plugin.hooks[0].__name__ == "decorated_hook"
143143

144-
def test_hooks_property_returns_tuple(self):
145-
"""Test that hooks property returns an immutable tuple."""
144+
def test_hooks_property_returns_list(self):
145+
"""Test that hooks property returns a mutable list."""
146146

147147
class MyPlugin(Plugin):
148148
name = "my-plugin"
@@ -152,10 +152,10 @@ def my_hook(self, event: BeforeModelCallEvent):
152152
pass
153153

154154
plugin = MyPlugin()
155-
assert isinstance(plugin.hooks, tuple)
155+
assert isinstance(plugin.hooks, list)
156156

157-
def test_tools_property_returns_tuple(self):
158-
"""Test that tools property returns an immutable tuple."""
157+
def test_tools_property_returns_list(self):
158+
"""Test that tools property returns a mutable list."""
159159

160160
class MyPlugin(Plugin):
161161
name = "my-plugin"
@@ -166,7 +166,53 @@ def my_tool(self, param: str) -> str:
166166
return param
167167

168168
plugin = MyPlugin()
169-
assert isinstance(plugin.tools, tuple)
169+
assert isinstance(plugin.tools, list)
170+
171+
def test_hooks_can_be_filtered(self):
172+
"""Test that hooks list can be modified before registration."""
173+
174+
class MyPlugin(Plugin):
175+
name = "my-plugin"
176+
177+
@hook
178+
def hook1(self, event: BeforeModelCallEvent):
179+
pass
180+
181+
@hook
182+
def hook2(self, event: BeforeInvocationEvent):
183+
pass
184+
185+
plugin = MyPlugin()
186+
assert len(plugin.hooks) == 2
187+
188+
# Filter out hook1
189+
plugin.hooks[:] = [h for h in plugin.hooks if h.__name__ != "hook1"]
190+
assert len(plugin.hooks) == 1
191+
assert plugin.hooks[0].__name__ == "hook2"
192+
193+
def test_tools_can_be_filtered(self):
194+
"""Test that tools list can be modified before registration."""
195+
196+
class MyPlugin(Plugin):
197+
name = "my-plugin"
198+
199+
@tool
200+
def tool1(self, param: str) -> str:
201+
"""Tool 1."""
202+
return param
203+
204+
@tool
205+
def tool2(self, param: str) -> str:
206+
"""Tool 2."""
207+
return param
208+
209+
plugin = MyPlugin()
210+
assert len(plugin.tools) == 2
211+
212+
# Filter out tool1
213+
plugin.tools[:] = [t for t in plugin.tools if t.tool_name != "tool1"]
214+
assert len(plugin.tools) == 1
215+
assert plugin.tools[0].tool_name == "tool2"
170216

171217

172218
class TestPluginRegistryAutoRegistration:

0 commit comments

Comments
 (0)