@@ -187,10 +187,43 @@ func TestCodeModeTool_CallEcho(t *testing.T) {
187187 require .Empty (t , scriptResult .StdOut )
188188}
189189
190+ // TestCodeModeTool_StartRollsBackOnError verifies that when one toolset fails
191+ // to start, all successfully-started toolsets are stopped (rolled back).
192+ func TestCodeModeTool_StartRollsBackOnError (t * testing.T ) {
193+ failing := & testToolSet {startErr : assert .AnError }
194+ healthy := & testToolSet {}
195+
196+ tool := Wrap (healthy , failing ).(tools.Startable )
197+
198+ err := tool .Start (t .Context ())
199+ require .ErrorIs (t , err , assert .AnError )
200+ assert .Equal (t , 1 , failing .start , "failing toolset should have attempted start" )
201+ assert .Equal (t , 1 , healthy .start , "healthy toolset should have attempted start" )
202+ assert .Equal (t , 1 , healthy .stop , "healthy toolset should be rolled back after failure" )
203+ }
204+
205+ // TestCodeModeTool_StartStopWrappedToolSet verifies that Start/Stop find
206+ // Startable through a StartableToolSet wrapper via tools.As.
207+ func TestCodeModeTool_StartStopWrappedToolSet (t * testing.T ) {
208+ inner := & testToolSet {}
209+ wrapped := tools .NewStartable (inner )
210+
211+ tool := Wrap (wrapped ).(tools.Startable )
212+
213+ err := tool .Start (t .Context ())
214+ require .NoError (t , err )
215+ assert .Equal (t , 1 , inner .start )
216+
217+ err = tool .Stop (t .Context ())
218+ require .NoError (t , err )
219+ assert .Equal (t , 1 , inner .stop )
220+ }
221+
190222type testToolSet struct {
191- tools []tools.Tool
192- start int
193- stop int
223+ tools []tools.Tool
224+ start int
225+ stop int
226+ startErr error
194227}
195228
196229// Verify interface compliance
@@ -205,7 +238,7 @@ func (t *testToolSet) Tools(context.Context) ([]tools.Tool, error) {
205238
206239func (t * testToolSet ) Start (context.Context ) error {
207240 t .start ++
208- return nil
241+ return t . startErr
209242}
210243
211244func (t * testToolSet ) Stop (context.Context ) error {
0 commit comments