Skip to content

Commit 168d76a

Browse files
committed
refactor: address review feedback - trim tests, support Sequence
Review feedback from @Unshure: 1. hooks param now accepts Sequence (list, tuple) instead of just list - Changed type hint to Sequence[HookProvider | HookCallback] - Added Sequence import from collections.abc 2. Moved tests to existing test_agent_hooks.py and trimmed down - Removed separate test_agent_hooks_callable.py file - Kept 5 essential tests: callable, mixed list, tuple, invalid, lifecycle - Removed duplicate HookProvider test (already covered by test_agent__init__hooks) - Removed redundant tests (None, empty list, async, lambda, multiple) 3. Removed duplicated test_hooks_param_accepts_hook_provider - This was already covered by existing test_agent__init__hooks
1 parent e3f9481 commit 168d76a

3 files changed

Lines changed: 69 additions & 189 deletions

File tree

src/strands/agent/agent.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import logging
1313
import threading
1414
import warnings
15-
from collections.abc import AsyncGenerator, AsyncIterator, Callable, Mapping
15+
from collections.abc import AsyncGenerator, AsyncIterator, Callable, Mapping, Sequence
1616
from typing import (
1717
TYPE_CHECKING,
1818
Any,
@@ -129,7 +129,7 @@ def __init__(
129129
description: str | None = None,
130130
state: AgentState | dict | None = None,
131131
plugins: list[Plugin] | None = None,
132-
hooks: list[HookProvider | HookCallback] | None = None,
132+
hooks: Sequence[HookProvider | HookCallback] | None = None,
133133
session_manager: SessionManager | None = None,
134134
structured_output_prompt: str | None = None,
135135
tool_executor: ToolExecutor | None = None,
@@ -183,7 +183,8 @@ def __init__(
183183
Plugins are initialized with the agent instance after construction and can register hooks,
184184
modify agent attributes, or perform other setup tasks.
185185
Defaults to None.
186-
hooks: Hooks to be added to the agent hook registry. Accepts HookProvider instances
186+
hooks: Hooks to be added to the agent hook registry. Accepts any sequence
187+
(list, tuple) of HookProvider instances
187188
or plain callable hook callbacks (functions with typed event parameters).
188189
Defaults to None.
189190
session_manager: Manager for handling agent sessions including conversation history and state.

tests/strands/agent/test_agent_hooks.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,3 +1021,68 @@ def interrupt_tool(event: BeforeToolCallEvent):
10211021
assert result.stop_reason == "end_turn"
10221022
assert result.message["content"][0]["text"] == "Final response"
10231023
assert agent._interrupt_state.activated is False
1024+
1025+
1026+
def test_hooks_param_accepts_callable():
1027+
"""Verify that a plain callable can be passed via hooks parameter."""
1028+
events_received = []
1029+
1030+
def my_callback(event: AgentInitializedEvent) -> None:
1031+
events_received.append(event)
1032+
1033+
agent = Agent(hooks=[my_callback], callback_handler=None)
1034+
1035+
assert len(events_received) == 1
1036+
assert isinstance(events_received[0], AgentInitializedEvent)
1037+
assert events_received[0].agent is agent
1038+
1039+
1040+
def test_hooks_param_accepts_mixed_list():
1041+
"""Verify that a mix of HookProviders and callables can be passed."""
1042+
callback_events = []
1043+
1044+
def my_callback(event: AgentInitializedEvent) -> None:
1045+
callback_events.append(event)
1046+
1047+
provider = MockHookProvider(event_types=[AgentInitializedEvent])
1048+
1049+
agent = Agent(hooks=[provider, my_callback], callback_handler=None)
1050+
1051+
assert len(callback_events) == 1
1052+
assert callback_events[0].agent is agent
1053+
length, _ = provider.get_events()
1054+
assert length == 1
1055+
1056+
1057+
def test_hooks_param_accepts_tuple():
1058+
"""Verify that a tuple of hooks can be passed (Sequence support)."""
1059+
events_received = []
1060+
1061+
def my_callback(event: AgentInitializedEvent) -> None:
1062+
events_received.append(event)
1063+
1064+
agent = Agent(hooks=(my_callback,), callback_handler=None)
1065+
1066+
assert len(events_received) == 1
1067+
assert events_received[0].agent is agent
1068+
1069+
1070+
def test_hooks_param_invalid_hook_raises_error():
1071+
"""Verify that passing an invalid hook raises ValueError."""
1072+
with pytest.raises(ValueError, match="Invalid hook"):
1073+
Agent(hooks=["not_a_hook"], callback_handler=None) # type: ignore
1074+
1075+
1076+
def test_hooks_param_callable_invoked_during_lifecycle():
1077+
"""Verify callable hooks fire during agent lifecycle."""
1078+
before_events = []
1079+
1080+
def on_before(event: BeforeInvocationEvent) -> None:
1081+
before_events.append(event)
1082+
1083+
mock_model = MockedModelProvider([{"role": "assistant", "content": [{"text": "Hello!"}]}])
1084+
agent = Agent(model=mock_model, hooks=[on_before], callback_handler=None)
1085+
agent("test")
1086+
1087+
assert len(before_events) == 1
1088+
assert isinstance(before_events[0], BeforeInvocationEvent)

tests/strands/agent/test_agent_hooks_callable.py

Lines changed: 0 additions & 186 deletions
This file was deleted.

0 commit comments

Comments
 (0)