From 095f12b0e5ee29c399831b3681fc159e390a8fa9 Mon Sep 17 00:00:00 2001 From: agent-of-mkmeral Date: Thu, 26 Mar 2026 18:21:36 +0000 Subject: [PATCH 1/4] feat(hooks): accept callable hook callbacks in Agent constructor hooks parameter Previously, the hooks parameter in Agent.__init__ only accepted HookProvider instances. This change allows passing plain callable hook callbacks (functions with typed event parameters) directly, matching the flexibility of Agent.add_hook(). The hooks param now accepts a mixed list of HookProvider instances and/or callable hook callbacks: def on_start(event: BeforeInvocationEvent) -> None: print('Starting!') agent = Agent(hooks=[on_start, MyHookProvider()]) This provides a more convenient API for simple hook use cases where creating a full HookProvider class is unnecessary. Changes: - Updated hooks param type: list[HookProvider | HookCallback] | None - Added isinstance check to dispatch HookProviders vs callables - Added ValueError for invalid hook types - Added comprehensive tests (12 test cases) --- src/strands/agent/agent.py | 14 +- .../agent/test_agent_hooks_callable.py | 192 ++++++++++++++++++ 2 files changed, 203 insertions(+), 3 deletions(-) create mode 100644 tests/strands/agent/test_agent_hooks_callable.py diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 3a23133de..439471a84 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -132,7 +132,7 @@ def __init__( description: str | None = None, state: AgentState | dict | None = None, plugins: list[Plugin] | None = None, - hooks: list[HookProvider] | None = None, + hooks: list[HookProvider | HookCallback] | None = None, session_manager: SessionManager | None = None, structured_output_prompt: str | None = None, tool_executor: ToolExecutor | None = None, @@ -187,7 +187,8 @@ def __init__( Plugins are initialized with the agent instance after construction and can register hooks, modify agent attributes, or perform other setup tasks. Defaults to None. - hooks: hooks to be added to the agent hook registry + hooks: Hooks to be added to the agent hook registry. Accepts HookProvider instances + or plain callable hook callbacks (functions with typed event parameters). Defaults to None. session_manager: Manager for handling agent sessions including conversation history and state. If provided, enables session-based persistence and state management. @@ -341,7 +342,14 @@ def __init__( if hooks: for hook in hooks: - self.hooks.add_hook(hook) + if isinstance(hook, HookProvider): + self.hooks.add_hook(hook) + elif callable(hook): + self.hooks.add_callback(None, hook) + else: + raise ValueError( + f"Invalid hook: {hook!r}. Must be a HookProvider instance or a callable hook callback." + ) # Register built-in plugins self._plugin_registry.add_and_init(_ModelPlugin()) diff --git a/tests/strands/agent/test_agent_hooks_callable.py b/tests/strands/agent/test_agent_hooks_callable.py new file mode 100644 index 000000000..eb48aa02b --- /dev/null +++ b/tests/strands/agent/test_agent_hooks_callable.py @@ -0,0 +1,192 @@ +"""Tests for accepting callable hook callbacks in Agent constructor's hooks parameter.""" + +from unittest.mock import MagicMock + +import pytest + +from strands import Agent +from strands.hooks import ( + AfterInvocationEvent, + AgentInitializedEvent, + BeforeInvocationEvent, + BeforeModelCallEvent, + HookProvider, + HookRegistry, +) +from tests.fixtures.mocked_model_provider import MockedModelProvider + + +class TestHooksParamAcceptsCallables: + """Test that the Agent constructor's hooks parameter accepts both HookProviders and callables.""" + + def test_hooks_param_accepts_callable(self): + """Verify that a plain callable can be passed via hooks parameter.""" + events_received = [] + + def my_callback(event: AgentInitializedEvent) -> None: + events_received.append(event) + + agent = Agent(hooks=[my_callback], callback_handler=None) + + assert len(events_received) == 1 + assert isinstance(events_received[0], AgentInitializedEvent) + assert events_received[0].agent is agent + + def test_hooks_param_accepts_hook_provider(self): + """Verify that HookProvider still works as before (backward compatibility).""" + + class MyProvider(HookProvider): + def __init__(self): + self.events = [] + + def register_hooks(self, registry: HookRegistry) -> None: + registry.add_callback(AgentInitializedEvent, self.on_init) + + def on_init(self, event: AgentInitializedEvent) -> None: + self.events.append(event) + + provider = MyProvider() + agent = Agent(hooks=[provider], callback_handler=None) + + assert len(provider.events) == 1 + assert isinstance(provider.events[0], AgentInitializedEvent) + + def test_hooks_param_accepts_mixed_list(self): + """Verify that a mix of HookProviders and callables can be passed.""" + callback_events = [] + provider_events = [] + + def my_callback(event: AgentInitializedEvent) -> None: + callback_events.append(event) + + class MyProvider(HookProvider): + def register_hooks(self, registry: HookRegistry) -> None: + registry.add_callback(AgentInitializedEvent, lambda e: provider_events.append(e)) + + agent = Agent(hooks=[MyProvider(), my_callback], callback_handler=None) + + assert len(callback_events) == 1 + assert len(provider_events) == 1 + assert callback_events[0].agent is agent + assert provider_events[0].agent is agent + + def test_hooks_param_callable_invoked_during_agent_lifecycle(self): + """Verify that callable hooks registered via hooks param fire during agent lifecycle.""" + before_events = [] + after_events = [] + + def on_before(event: BeforeInvocationEvent) -> None: + before_events.append(event) + + def on_after(event: AfterInvocationEvent) -> None: + after_events.append(event) + + mock_model = MockedModelProvider( + [{"role": "assistant", "content": [{"text": "Hello!"}]}] + ) + + agent = Agent( + model=mock_model, + hooks=[on_before, on_after], + callback_handler=None, + ) + agent("test prompt") + + assert len(before_events) == 1 + assert len(after_events) == 1 + assert isinstance(before_events[0], BeforeInvocationEvent) + assert isinstance(after_events[0], AfterInvocationEvent) + + def test_hooks_param_invalid_hook_raises_error(self): + """Verify that passing an invalid hook (not HookProvider or callable) raises ValueError.""" + with pytest.raises(ValueError, match="Invalid hook"): + Agent(hooks=["not_a_hook"], callback_handler=None) # type: ignore + + def test_hooks_param_none_is_valid(self): + """Verify that passing None for hooks is still valid.""" + agent = Agent(hooks=None, callback_handler=None) + assert agent is not None + + def test_hooks_param_empty_list_is_valid(self): + """Verify that passing an empty list for hooks is still valid.""" + agent = Agent(hooks=[], callback_handler=None) + assert agent is not None + + def test_hooks_param_callable_with_explicit_type_hint(self): + """Verify that callables with typed event parameters work via hooks param.""" + model_call_events = [] + + def on_model_call(event: BeforeModelCallEvent) -> None: + model_call_events.append(event) + + mock_model = MockedModelProvider( + [{"role": "assistant", "content": [{"text": "result"}]}] + ) + + agent = Agent( + model=mock_model, + hooks=[on_model_call], + callback_handler=None, + ) + agent("prompt") + + assert len(model_call_events) >= 1 + assert isinstance(model_call_events[0], BeforeModelCallEvent) + + def test_hooks_param_lambda_without_type_hint_raises_error(self): + """Verify that lambda functions without type hints raise ValueError.""" + with pytest.raises(ValueError, match="cannot infer event type"): + Agent( + hooks=[lambda event: None], # type: ignore + callback_handler=None, + ) + + def test_hooks_param_multiple_callables(self): + """Verify that multiple callables can be registered.""" + events_a = [] + events_b = [] + + def callback_a(event: AgentInitializedEvent) -> None: + events_a.append(event) + + def callback_b(event: AgentInitializedEvent) -> None: + events_b.append(event) + + agent = Agent(hooks=[callback_a, callback_b], callback_handler=None) + + assert len(events_a) == 1 + assert len(events_b) == 1 + + +class TestHooksParamAsyncCallables: + """Test that the Agent constructor's hooks parameter accepts async callables.""" + + def test_hooks_param_accepts_async_before_invocation_callback(self): + """Verify that async callable hooks can be registered for non-init events.""" + events_received = [] + + async def my_async_callback(event: BeforeInvocationEvent) -> None: + events_received.append(event) + + mock_model = MockedModelProvider( + [{"role": "assistant", "content": [{"text": "Hello!"}]}] + ) + + agent = Agent( + model=mock_model, + hooks=[my_async_callback], + callback_handler=None, + ) + agent("test") + + assert len(events_received) == 1 + assert isinstance(events_received[0], BeforeInvocationEvent) + + def test_hooks_param_rejects_async_agent_initialized_callback(self): + """Verify that async callbacks for AgentInitializedEvent raise ValueError.""" + + async def my_async_callback(event: AgentInitializedEvent) -> None: + pass + + with pytest.raises(ValueError, match="AgentInitializedEvent can only be registered with a synchronous callback"): + Agent(hooks=[my_async_callback], callback_handler=None) From 5bb7efc8caa0125647d54ded15968e4ddc4e974d Mon Sep 17 00:00:00 2001 From: agent-of-mkmeral Date: Fri, 27 Mar 2026 14:31:40 +0000 Subject: [PATCH 2/4] fix: address hatch run prepare lint and formatting issues - Remove unused variable assignments (F841 lint errors) - Apply auto-formatting from hatch run format --- .../agent/test_agent_hooks_callable.py | 22 +++++++------------ 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/tests/strands/agent/test_agent_hooks_callable.py b/tests/strands/agent/test_agent_hooks_callable.py index eb48aa02b..d42017505 100644 --- a/tests/strands/agent/test_agent_hooks_callable.py +++ b/tests/strands/agent/test_agent_hooks_callable.py @@ -1,7 +1,5 @@ """Tests for accepting callable hook callbacks in Agent constructor's hooks parameter.""" -from unittest.mock import MagicMock - import pytest from strands import Agent @@ -46,7 +44,7 @@ def on_init(self, event: AgentInitializedEvent) -> None: self.events.append(event) provider = MyProvider() - agent = Agent(hooks=[provider], callback_handler=None) + Agent(hooks=[provider], callback_handler=None) assert len(provider.events) == 1 assert isinstance(provider.events[0], AgentInitializedEvent) @@ -81,9 +79,7 @@ def on_before(event: BeforeInvocationEvent) -> None: def on_after(event: AfterInvocationEvent) -> None: after_events.append(event) - mock_model = MockedModelProvider( - [{"role": "assistant", "content": [{"text": "Hello!"}]}] - ) + mock_model = MockedModelProvider([{"role": "assistant", "content": [{"text": "Hello!"}]}]) agent = Agent( model=mock_model, @@ -119,9 +115,7 @@ def test_hooks_param_callable_with_explicit_type_hint(self): def on_model_call(event: BeforeModelCallEvent) -> None: model_call_events.append(event) - mock_model = MockedModelProvider( - [{"role": "assistant", "content": [{"text": "result"}]}] - ) + mock_model = MockedModelProvider([{"role": "assistant", "content": [{"text": "result"}]}]) agent = Agent( model=mock_model, @@ -152,7 +146,7 @@ def callback_a(event: AgentInitializedEvent) -> None: def callback_b(event: AgentInitializedEvent) -> None: events_b.append(event) - agent = Agent(hooks=[callback_a, callback_b], callback_handler=None) + Agent(hooks=[callback_a, callback_b], callback_handler=None) assert len(events_a) == 1 assert len(events_b) == 1 @@ -168,9 +162,7 @@ def test_hooks_param_accepts_async_before_invocation_callback(self): async def my_async_callback(event: BeforeInvocationEvent) -> None: events_received.append(event) - mock_model = MockedModelProvider( - [{"role": "assistant", "content": [{"text": "Hello!"}]}] - ) + mock_model = MockedModelProvider([{"role": "assistant", "content": [{"text": "Hello!"}]}]) agent = Agent( model=mock_model, @@ -188,5 +180,7 @@ def test_hooks_param_rejects_async_agent_initialized_callback(self): async def my_async_callback(event: AgentInitializedEvent) -> None: pass - with pytest.raises(ValueError, match="AgentInitializedEvent can only be registered with a synchronous callback"): + with pytest.raises( + ValueError, match="AgentInitializedEvent can only be registered with a synchronous callback" + ): Agent(hooks=[my_async_callback], callback_handler=None) From abc819dcb9b9244d2d996133fa0b80e1f570c052 Mon Sep 17 00:00:00 2001 From: agent-of-mkmeral <217235299+strands-agent@users.noreply.github.com> Date: Tue, 31 Mar 2026 19:56:34 +0000 Subject: [PATCH 3/4] refactor: address review feedback - trim tests, support Sequence Review feedback from @Unshure: 1. hooks param now accepts Sequence (list, tuple) instead of just list - Changed type hint to Sequence[HookProvider | HookCallback] - Added Sequence import from collections.abc 2. Moved tests to existing test_agent_hooks.py and trimmed down - Removed separate test_agent_hooks_callable.py file - Kept 5 essential tests: callable, mixed list, tuple, invalid, lifecycle - Removed duplicate HookProvider test (already covered by test_agent__init__hooks) - Removed redundant tests (None, empty list, async, lambda, multiple) 3. Removed duplicated test_hooks_param_accepts_hook_provider - This was already covered by existing test_agent__init__hooks --- src/strands/agent/agent.py | 7 +- tests/strands/agent/test_agent_hooks.py | 65 ++++++ .../agent/test_agent_hooks_callable.py | 186 ------------------ 3 files changed, 69 insertions(+), 189 deletions(-) delete mode 100644 tests/strands/agent/test_agent_hooks_callable.py diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 439471a84..0b49b36af 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -12,7 +12,7 @@ import logging import threading import warnings -from collections.abc import AsyncGenerator, AsyncIterator, Callable, Mapping +from collections.abc import AsyncGenerator, AsyncIterator, Callable, Mapping, Sequence from typing import ( TYPE_CHECKING, Any, @@ -132,7 +132,7 @@ def __init__( description: str | None = None, state: AgentState | dict | None = None, plugins: list[Plugin] | None = None, - hooks: list[HookProvider | HookCallback] | None = None, + hooks: Sequence[HookProvider | HookCallback] | None = None, session_manager: SessionManager | None = None, structured_output_prompt: str | None = None, tool_executor: ToolExecutor | None = None, @@ -187,7 +187,8 @@ def __init__( Plugins are initialized with the agent instance after construction and can register hooks, modify agent attributes, or perform other setup tasks. Defaults to None. - hooks: Hooks to be added to the agent hook registry. Accepts HookProvider instances + hooks: Hooks to be added to the agent hook registry. Accepts any sequence + (list, tuple) of HookProvider instances or plain callable hook callbacks (functions with typed event parameters). Defaults to None. session_manager: Manager for handling agent sessions including conversation history and state. diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index 1da245d70..2036db114 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -1021,3 +1021,68 @@ def interrupt_tool(event: BeforeToolCallEvent): assert result.stop_reason == "end_turn" assert result.message["content"][0]["text"] == "Final response" assert agent._interrupt_state.activated is False + + +def test_hooks_param_accepts_callable(): + """Verify that a plain callable can be passed via hooks parameter.""" + events_received = [] + + def my_callback(event: AgentInitializedEvent) -> None: + events_received.append(event) + + agent = Agent(hooks=[my_callback], callback_handler=None) + + assert len(events_received) == 1 + assert isinstance(events_received[0], AgentInitializedEvent) + assert events_received[0].agent is agent + + +def test_hooks_param_accepts_mixed_list(): + """Verify that a mix of HookProviders and callables can be passed.""" + callback_events = [] + + def my_callback(event: AgentInitializedEvent) -> None: + callback_events.append(event) + + provider = MockHookProvider(event_types=[AgentInitializedEvent]) + + agent = Agent(hooks=[provider, my_callback], callback_handler=None) + + assert len(callback_events) == 1 + assert callback_events[0].agent is agent + length, _ = provider.get_events() + assert length == 1 + + +def test_hooks_param_accepts_tuple(): + """Verify that a tuple of hooks can be passed (Sequence support).""" + events_received = [] + + def my_callback(event: AgentInitializedEvent) -> None: + events_received.append(event) + + agent = Agent(hooks=(my_callback,), callback_handler=None) + + assert len(events_received) == 1 + assert events_received[0].agent is agent + + +def test_hooks_param_invalid_hook_raises_error(): + """Verify that passing an invalid hook raises ValueError.""" + with pytest.raises(ValueError, match="Invalid hook"): + Agent(hooks=["not_a_hook"], callback_handler=None) # type: ignore + + +def test_hooks_param_callable_invoked_during_lifecycle(): + """Verify callable hooks fire during agent lifecycle.""" + before_events = [] + + def on_before(event: BeforeInvocationEvent) -> None: + before_events.append(event) + + mock_model = MockedModelProvider([{"role": "assistant", "content": [{"text": "Hello!"}]}]) + agent = Agent(model=mock_model, hooks=[on_before], callback_handler=None) + agent("test") + + assert len(before_events) == 1 + assert isinstance(before_events[0], BeforeInvocationEvent) diff --git a/tests/strands/agent/test_agent_hooks_callable.py b/tests/strands/agent/test_agent_hooks_callable.py deleted file mode 100644 index d42017505..000000000 --- a/tests/strands/agent/test_agent_hooks_callable.py +++ /dev/null @@ -1,186 +0,0 @@ -"""Tests for accepting callable hook callbacks in Agent constructor's hooks parameter.""" - -import pytest - -from strands import Agent -from strands.hooks import ( - AfterInvocationEvent, - AgentInitializedEvent, - BeforeInvocationEvent, - BeforeModelCallEvent, - HookProvider, - HookRegistry, -) -from tests.fixtures.mocked_model_provider import MockedModelProvider - - -class TestHooksParamAcceptsCallables: - """Test that the Agent constructor's hooks parameter accepts both HookProviders and callables.""" - - def test_hooks_param_accepts_callable(self): - """Verify that a plain callable can be passed via hooks parameter.""" - events_received = [] - - def my_callback(event: AgentInitializedEvent) -> None: - events_received.append(event) - - agent = Agent(hooks=[my_callback], callback_handler=None) - - assert len(events_received) == 1 - assert isinstance(events_received[0], AgentInitializedEvent) - assert events_received[0].agent is agent - - def test_hooks_param_accepts_hook_provider(self): - """Verify that HookProvider still works as before (backward compatibility).""" - - class MyProvider(HookProvider): - def __init__(self): - self.events = [] - - def register_hooks(self, registry: HookRegistry) -> None: - registry.add_callback(AgentInitializedEvent, self.on_init) - - def on_init(self, event: AgentInitializedEvent) -> None: - self.events.append(event) - - provider = MyProvider() - Agent(hooks=[provider], callback_handler=None) - - assert len(provider.events) == 1 - assert isinstance(provider.events[0], AgentInitializedEvent) - - def test_hooks_param_accepts_mixed_list(self): - """Verify that a mix of HookProviders and callables can be passed.""" - callback_events = [] - provider_events = [] - - def my_callback(event: AgentInitializedEvent) -> None: - callback_events.append(event) - - class MyProvider(HookProvider): - def register_hooks(self, registry: HookRegistry) -> None: - registry.add_callback(AgentInitializedEvent, lambda e: provider_events.append(e)) - - agent = Agent(hooks=[MyProvider(), my_callback], callback_handler=None) - - assert len(callback_events) == 1 - assert len(provider_events) == 1 - assert callback_events[0].agent is agent - assert provider_events[0].agent is agent - - def test_hooks_param_callable_invoked_during_agent_lifecycle(self): - """Verify that callable hooks registered via hooks param fire during agent lifecycle.""" - before_events = [] - after_events = [] - - def on_before(event: BeforeInvocationEvent) -> None: - before_events.append(event) - - def on_after(event: AfterInvocationEvent) -> None: - after_events.append(event) - - mock_model = MockedModelProvider([{"role": "assistant", "content": [{"text": "Hello!"}]}]) - - agent = Agent( - model=mock_model, - hooks=[on_before, on_after], - callback_handler=None, - ) - agent("test prompt") - - assert len(before_events) == 1 - assert len(after_events) == 1 - assert isinstance(before_events[0], BeforeInvocationEvent) - assert isinstance(after_events[0], AfterInvocationEvent) - - def test_hooks_param_invalid_hook_raises_error(self): - """Verify that passing an invalid hook (not HookProvider or callable) raises ValueError.""" - with pytest.raises(ValueError, match="Invalid hook"): - Agent(hooks=["not_a_hook"], callback_handler=None) # type: ignore - - def test_hooks_param_none_is_valid(self): - """Verify that passing None for hooks is still valid.""" - agent = Agent(hooks=None, callback_handler=None) - assert agent is not None - - def test_hooks_param_empty_list_is_valid(self): - """Verify that passing an empty list for hooks is still valid.""" - agent = Agent(hooks=[], callback_handler=None) - assert agent is not None - - def test_hooks_param_callable_with_explicit_type_hint(self): - """Verify that callables with typed event parameters work via hooks param.""" - model_call_events = [] - - def on_model_call(event: BeforeModelCallEvent) -> None: - model_call_events.append(event) - - mock_model = MockedModelProvider([{"role": "assistant", "content": [{"text": "result"}]}]) - - agent = Agent( - model=mock_model, - hooks=[on_model_call], - callback_handler=None, - ) - agent("prompt") - - assert len(model_call_events) >= 1 - assert isinstance(model_call_events[0], BeforeModelCallEvent) - - def test_hooks_param_lambda_without_type_hint_raises_error(self): - """Verify that lambda functions without type hints raise ValueError.""" - with pytest.raises(ValueError, match="cannot infer event type"): - Agent( - hooks=[lambda event: None], # type: ignore - callback_handler=None, - ) - - def test_hooks_param_multiple_callables(self): - """Verify that multiple callables can be registered.""" - events_a = [] - events_b = [] - - def callback_a(event: AgentInitializedEvent) -> None: - events_a.append(event) - - def callback_b(event: AgentInitializedEvent) -> None: - events_b.append(event) - - Agent(hooks=[callback_a, callback_b], callback_handler=None) - - assert len(events_a) == 1 - assert len(events_b) == 1 - - -class TestHooksParamAsyncCallables: - """Test that the Agent constructor's hooks parameter accepts async callables.""" - - def test_hooks_param_accepts_async_before_invocation_callback(self): - """Verify that async callable hooks can be registered for non-init events.""" - events_received = [] - - async def my_async_callback(event: BeforeInvocationEvent) -> None: - events_received.append(event) - - mock_model = MockedModelProvider([{"role": "assistant", "content": [{"text": "Hello!"}]}]) - - agent = Agent( - model=mock_model, - hooks=[my_async_callback], - callback_handler=None, - ) - agent("test") - - assert len(events_received) == 1 - assert isinstance(events_received[0], BeforeInvocationEvent) - - def test_hooks_param_rejects_async_agent_initialized_callback(self): - """Verify that async callbacks for AgentInitializedEvent raise ValueError.""" - - async def my_async_callback(event: AgentInitializedEvent) -> None: - pass - - with pytest.raises( - ValueError, match="AgentInitializedEvent can only be registered with a synchronous callback" - ): - Agent(hooks=[my_async_callback], callback_handler=None) From 3127130c3e52e8ef94d87c8d685b4cadc811d8b8 Mon Sep 17 00:00:00 2001 From: agent-of-mkmeral Date: Tue, 31 Mar 2026 22:00:56 +0000 Subject: [PATCH 4/4] revert: keep hooks param as list, not Sequence Per @mkmeral's feedback: keep hooks as list[...] for devx simplicity. If users want explicit tuple/Sequence control, they can use add_hook(). Changes: - Revert Sequence -> list in hooks type hint - Revert docstring to not mention tuple support - Remove Sequence import (no longer needed) - Remove test_hooks_param_accepts_tuple test --- src/strands/agent/agent.py | 7 +++---- tests/strands/agent/test_agent_hooks.py | 13 ------------- 2 files changed, 3 insertions(+), 17 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 0b49b36af..439471a84 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -12,7 +12,7 @@ import logging import threading import warnings -from collections.abc import AsyncGenerator, AsyncIterator, Callable, Mapping, Sequence +from collections.abc import AsyncGenerator, AsyncIterator, Callable, Mapping from typing import ( TYPE_CHECKING, Any, @@ -132,7 +132,7 @@ def __init__( description: str | None = None, state: AgentState | dict | None = None, plugins: list[Plugin] | None = None, - hooks: Sequence[HookProvider | HookCallback] | None = None, + hooks: list[HookProvider | HookCallback] | None = None, session_manager: SessionManager | None = None, structured_output_prompt: str | None = None, tool_executor: ToolExecutor | None = None, @@ -187,8 +187,7 @@ def __init__( Plugins are initialized with the agent instance after construction and can register hooks, modify agent attributes, or perform other setup tasks. Defaults to None. - hooks: Hooks to be added to the agent hook registry. Accepts any sequence - (list, tuple) of HookProvider instances + hooks: Hooks to be added to the agent hook registry. Accepts HookProvider instances or plain callable hook callbacks (functions with typed event parameters). Defaults to None. session_manager: Manager for handling agent sessions including conversation history and state. diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index 2036db114..3a40d69a8 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -1054,19 +1054,6 @@ def my_callback(event: AgentInitializedEvent) -> None: assert length == 1 -def test_hooks_param_accepts_tuple(): - """Verify that a tuple of hooks can be passed (Sequence support).""" - events_received = [] - - def my_callback(event: AgentInitializedEvent) -> None: - events_received.append(event) - - agent = Agent(hooks=(my_callback,), callback_handler=None) - - assert len(events_received) == 1 - assert events_received[0].agent is agent - - def test_hooks_param_invalid_hook_raises_error(): """Verify that passing an invalid hook raises ValueError.""" with pytest.raises(ValueError, match="Invalid hook"):