11"""Unit tests for steering handler base class."""
22
3+ import inspect
34from unittest .mock import AsyncMock , Mock
45
56import pytest
89from strands .experimental .steering .core .context import SteeringContext , SteeringContextCallback , SteeringContextProvider
910from strands .experimental .steering .core .handler import SteeringHandler
1011from strands .hooks .events import AfterModelCallEvent , BeforeToolCallEvent
12+ from strands .hooks .registry import HookRegistry
1113from strands .plugins import Plugin
1214
1315
@@ -39,18 +41,14 @@ def test_steering_handler_is_plugin():
3941
4042def test_init_plugin ():
4143 """Test init_plugin registers hooks on agent."""
42- from strands .hooks import HookRegistry
43-
4444 handler = TestSteeringHandler ()
4545 agent = Mock ()
46- agent .hooks = HookRegistry ()
47- agent .tool_registry = Mock ()
4846
4947 handler .init_plugin (agent )
5048
51- # Verify hooks were auto- registered via @hook decorator
52- assert len ( agent .hooks . _registered_callbacks . get ( BeforeToolCallEvent , [])) >= 1
53- assert len ( agent .hooks . _registered_callbacks . get ( AfterModelCallEvent , [])) >= 1
49+ # Verify hooks were registered (tool and model steering hooks)
50+ assert agent .add_hook . call_count >= 2
51+ agent .add_hook . assert_any_call ( handler . provide_tool_steering_guidance , BeforeToolCallEvent )
5452
5553
5654def test_steering_context_initialization ():
@@ -174,7 +172,6 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs):
174172
175173def test_init_plugin_override ():
176174 """Test that init_plugin can be overridden."""
177- from strands .hooks import HookRegistry
178175
179176 class CustomHandler (SteeringHandler ):
180177 async def steer_before_tool (self , * , agent , tool_use , ** kwargs ):
@@ -186,14 +183,11 @@ def init_plugin(self, agent):
186183
187184 handler = CustomHandler ()
188185 agent = Mock ()
189- agent .hooks = HookRegistry ()
190- agent .tool_registry = Mock ()
191186
192187 handler .init_plugin (agent )
193188
194- # Should not register any hooks since parent init_plugin wasn't called
195- assert len (agent .hooks ._registered_callbacks .get (BeforeToolCallEvent , [])) == 0
196- assert len (agent .hooks ._registered_callbacks .get (AfterModelCallEvent , [])) == 0
189+ # Should not register any hooks
190+ assert agent .add_hook .call_count == 0
197191
198192
199193# Integration tests with context providers
@@ -227,77 +221,68 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs):
227221
228222def test_handler_registers_context_provider_hooks ():
229223 """Test that handler registers hooks from context callbacks."""
230- from strands .hooks import HookRegistry
231-
232224 mock_callback = MockContextCallback ()
233225 handler = TestSteeringHandlerWithProvider (context_callbacks = [mock_callback ])
234226 agent = Mock ()
235- agent .hooks = HookRegistry ()
236- agent .tool_registry = Mock ()
237- agent .add_hook = Mock ()
238227
239228 handler .init_plugin (agent )
240229
241- # Should register 1 context callback via add_hook ( steering hooks are auto-registered)
242- assert agent .add_hook .call_count >= 1
230+ # Should register hooks for context callback and steering guidance
231+ assert agent .add_hook .call_count >= 2
243232
244- # Check that BeforeToolCallEvent was registered (either via add_hook or auto-registration)
233+ # Check that BeforeToolCallEvent was registered
245234 call_args = [call [0 ] for call in agent .add_hook .call_args_list ]
246235 event_types = [args [1 ] for args in call_args ]
247236
248237 # Context callback should be registered
249- assert (
250- BeforeToolCallEvent in event_types or len (agent .hooks ._registered_callbacks .get (BeforeToolCallEvent , [])) >= 1
251- )
238+ assert BeforeToolCallEvent in event_types
252239
253-
254- def test_context_callbacks_receive_steering_context ():
240+ @ pytest . mark . asyncio
241+ async def test_context_callbacks_receive_steering_context ():
255242 """Test that context callbacks receive the handler's steering context."""
256243 mock_callback = MockContextCallback ()
257244 handler = TestSteeringHandlerWithProvider (context_callbacks = [mock_callback ])
258245 agent = Mock ()
259-
246+ agent .hooks = HookRegistry ()
247+ agent .tool_registry = Mock ()
248+ agent .add_hook = Mock (side_effect = lambda callback , event_type = None : agent .hooks .add_callback (event_type , callback ))
260249 handler .init_plugin (agent )
261250
262- # Get the registered callback for BeforeToolCallEvent
263- before_callback = None
264- for call in agent .add_hook .call_args_list :
265- if call [0 ][1 ] == BeforeToolCallEvent :
266- before_callback = call [0 ][0 ]
267- break
251+ # Get the registered callbacks for BeforeToolCallEvent
252+ callbacks = agent .hooks ._registered_callbacks .get (BeforeToolCallEvent , [])
253+ assert len (callbacks ) > 0
268254
269- assert before_callback is not None
270-
271- # Create a mock event and call the callback
255+ # The context callback is wrapped in a lambda, so we just call all callbacks
256+ # and check if the steering context was updated
272257 event = Mock (spec = BeforeToolCallEvent )
273258 event .tool_use = {"name" : "test_tool" , "input" : {}}
274259
275- # The callback should execute without error and update the steering context
276- before_callback (event )
260+ # Call all callbacks, handling both sync and async
261+ for cb in callbacks :
262+ try :
263+ result = await cb (event )
264+ if inspect .iscoroutine (result ):
265+ await result
266+ except Exception :
267+ pass # Some callbacks might be async or have other requirements
277268
278- # Verify the steering context was updated
269+ # Verify the steering context was updated by at least one callback
279270 assert handler .steering_context .data .get ("test_key" ) == "test_value"
280271
281272
282273def test_multiple_context_callbacks_registered ():
283274 """Test that multiple context callbacks are registered."""
284- from strands .hooks import HookRegistry
285-
286275 callback1 = MockContextCallback ()
287276 callback2 = MockContextCallback ()
288277
289278 handler = TestSteeringHandlerWithProvider (context_callbacks = [callback1 , callback2 ])
290279 agent = Mock ()
291- agent .hooks = HookRegistry ()
292- agent .tool_registry = Mock ()
293- agent .add_hook = Mock ()
294280
295281 handler .init_plugin (agent )
296282
297- # Should register 2 context callbacks via add_hook, plus auto-registered @hook methods
298- assert agent .add_hook .call_count == 2 # Only context callbacks use add_hook
299- assert len (agent .hooks ._registered_callbacks .get (BeforeToolCallEvent , [])) >= 1
300- assert len (agent .hooks ._registered_callbacks .get (AfterModelCallEvent , [])) >= 1
283+ # Should register one callback for each context provider plus tool and model steering guidance
284+ expected_calls = 2 + 2 # 2 callbacks + 2 for steering guidance (tool and model)
285+ assert agent .add_hook .call_count >= expected_calls
301286
302287
303288def test_handler_initialization_with_callbacks ():
@@ -509,14 +494,10 @@ async def test_default_steer_after_model_returns_proceed():
509494
510495def test_init_plugin_registers_model_steering ():
511496 """Test that init_plugin registers model steering callback."""
512- from strands .hooks import HookRegistry
513-
514497 handler = TestSteeringHandler ()
515498 agent = Mock ()
516- agent .hooks = HookRegistry ()
517- agent .tool_registry = Mock ()
518499
519500 handler .init_plugin (agent )
520501
521- # Verify model steering hook was auto- registered via @hook decorator
522- assert len ( agent .hooks . _registered_callbacks . get ( AfterModelCallEvent , [])) >= 1
502+ # Verify model steering hook was registered
503+ agent .add_hook . assert_any_call ( handler . provide_model_steering_guidance , AfterModelCallEvent )
0 commit comments