Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def __init__(
description: str | None = None,
state: AgentState | dict | None = None,
plugins: list[Plugin] | None = None,
hooks: list[HookProvider] | None = None,
hooks: list[HookProvider | HookCallback] | None = None,
session_manager: SessionManager | None = None,
structured_output_prompt: str | None = None,
tool_executor: ToolExecutor | None = None,
Expand Down Expand Up @@ -187,7 +187,8 @@ def __init__(
Plugins are initialized with the agent instance after construction and can register hooks,
modify agent attributes, or perform other setup tasks.
Defaults to None.
hooks: hooks to be added to the agent hook registry
hooks: Hooks to be added to the agent hook registry. Accepts HookProvider instances
or plain callable hook callbacks (functions with typed event parameters).
Defaults to None.
session_manager: Manager for handling agent sessions including conversation history and state.
If provided, enables session-based persistence and state management.
Expand Down Expand Up @@ -341,7 +342,14 @@ def __init__(

if hooks:
for hook in hooks:
self.hooks.add_hook(hook)
if isinstance(hook, HookProvider):
self.hooks.add_hook(hook)
elif callable(hook):
self.hooks.add_callback(None, hook)
else:
raise ValueError(
f"Invalid hook: {hook!r}. Must be a HookProvider instance or a callable hook callback."
)

# Register built-in plugins
self._plugin_registry.add_and_init(_ModelPlugin())
Expand Down
52 changes: 52 additions & 0 deletions tests/strands/agent/test_agent_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,3 +1021,55 @@ def interrupt_tool(event: BeforeToolCallEvent):
assert result.stop_reason == "end_turn"
assert result.message["content"][0]["text"] == "Final response"
assert agent._interrupt_state.activated is False


def test_hooks_param_accepts_callable():
"""Verify that a plain callable can be passed via hooks parameter."""
events_received = []

def my_callback(event: AgentInitializedEvent) -> None:
events_received.append(event)

agent = Agent(hooks=[my_callback], callback_handler=None)

assert len(events_received) == 1
assert isinstance(events_received[0], AgentInitializedEvent)
assert events_received[0].agent is agent


def test_hooks_param_accepts_mixed_list():
"""Verify that a mix of HookProviders and callables can be passed."""
callback_events = []

def my_callback(event: AgentInitializedEvent) -> None:
callback_events.append(event)

provider = MockHookProvider(event_types=[AgentInitializedEvent])

agent = Agent(hooks=[provider, my_callback], callback_handler=None)

assert len(callback_events) == 1
assert callback_events[0].agent is agent
length, _ = provider.get_events()
assert length == 1


def test_hooks_param_invalid_hook_raises_error():
"""Verify that passing an invalid hook raises ValueError."""
with pytest.raises(ValueError, match="Invalid hook"):
Agent(hooks=["not_a_hook"], callback_handler=None) # type: ignore


def test_hooks_param_callable_invoked_during_lifecycle():
"""Verify callable hooks fire during agent lifecycle."""
before_events = []

def on_before(event: BeforeInvocationEvent) -> None:
before_events.append(event)

mock_model = MockedModelProvider([{"role": "assistant", "content": [{"text": "Hello!"}]}])
agent = Agent(model=mock_model, hooks=[on_before], callback_handler=None)
agent("test")

assert len(before_events) == 1
assert isinstance(before_events[0], BeforeInvocationEvent)
Loading