diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 3a23133de..439471a84 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -132,7 +132,7 @@ def __init__( description: str | None = None, state: AgentState | dict | None = None, plugins: list[Plugin] | None = None, - hooks: list[HookProvider] | None = None, + hooks: list[HookProvider | HookCallback] | None = None, session_manager: SessionManager | None = None, structured_output_prompt: str | None = None, tool_executor: ToolExecutor | None = None, @@ -187,7 +187,8 @@ def __init__( Plugins are initialized with the agent instance after construction and can register hooks, modify agent attributes, or perform other setup tasks. Defaults to None. - hooks: hooks to be added to the agent hook registry + hooks: Hooks to be added to the agent hook registry. Accepts HookProvider instances + or plain callable hook callbacks (functions with typed event parameters). Defaults to None. session_manager: Manager for handling agent sessions including conversation history and state. If provided, enables session-based persistence and state management. @@ -341,7 +342,14 @@ def __init__( if hooks: for hook in hooks: - self.hooks.add_hook(hook) + if isinstance(hook, HookProvider): + self.hooks.add_hook(hook) + elif callable(hook): + self.hooks.add_callback(None, hook) + else: + raise ValueError( + f"Invalid hook: {hook!r}. Must be a HookProvider instance or a callable hook callback." + ) # Register built-in plugins self._plugin_registry.add_and_init(_ModelPlugin()) diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index 1da245d70..3a40d69a8 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -1021,3 +1021,55 @@ def interrupt_tool(event: BeforeToolCallEvent): assert result.stop_reason == "end_turn" assert result.message["content"][0]["text"] == "Final response" assert agent._interrupt_state.activated is False + + +def test_hooks_param_accepts_callable(): + """Verify that a plain callable can be passed via hooks parameter.""" + events_received = [] + + def my_callback(event: AgentInitializedEvent) -> None: + events_received.append(event) + + agent = Agent(hooks=[my_callback], callback_handler=None) + + assert len(events_received) == 1 + assert isinstance(events_received[0], AgentInitializedEvent) + assert events_received[0].agent is agent + + +def test_hooks_param_accepts_mixed_list(): + """Verify that a mix of HookProviders and callables can be passed.""" + callback_events = [] + + def my_callback(event: AgentInitializedEvent) -> None: + callback_events.append(event) + + provider = MockHookProvider(event_types=[AgentInitializedEvent]) + + agent = Agent(hooks=[provider, my_callback], callback_handler=None) + + assert len(callback_events) == 1 + assert callback_events[0].agent is agent + length, _ = provider.get_events() + assert length == 1 + + +def test_hooks_param_invalid_hook_raises_error(): + """Verify that passing an invalid hook raises ValueError.""" + with pytest.raises(ValueError, match="Invalid hook"): + Agent(hooks=["not_a_hook"], callback_handler=None) # type: ignore + + +def test_hooks_param_callable_invoked_during_lifecycle(): + """Verify callable hooks fire during agent lifecycle.""" + before_events = [] + + def on_before(event: BeforeInvocationEvent) -> None: + before_events.append(event) + + mock_model = MockedModelProvider([{"role": "assistant", "content": [{"text": "Hello!"}]}]) + agent = Agent(model=mock_model, hooks=[on_before], callback_handler=None) + agent("test") + + assert len(before_events) == 1 + assert isinstance(before_events[0], BeforeInvocationEvent)