Skip to content

Commit 201d8a7

Browse files
authored
Merge pull request #2007 from dgageot/board/fix-code-mode-71471ec1
codemode: fix Start() fail-fast and use tools.As for wrapper unwrapping
2 parents ba9b462 + aec160d commit 201d8a7

2 files changed

Lines changed: 54 additions & 12 deletions

File tree

pkg/tools/codemode/codemode.go

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -103,27 +103,36 @@ func (c *codeModeTool) Tools(ctx context.Context) ([]tools.Tool, error) {
103103
}
104104

105105
func (c *codeModeTool) Start(ctx context.Context) error {
106+
var started []tools.Startable
107+
var errs []error
106108
for _, t := range c.toolsets {
107-
if startable, ok := t.(tools.Startable); ok {
108-
if err := startable.Start(ctx); err != nil {
109-
return err
109+
if s, ok := tools.As[tools.Startable](t); ok {
110+
if err := s.Start(ctx); err != nil {
111+
errs = append(errs, err)
112+
} else {
113+
started = append(started, s)
110114
}
111115
}
112116
}
113-
117+
if len(errs) > 0 {
118+
// Roll back successfully-started toolsets so we don't leave
119+
// the system in a partially-started state.
120+
for _, s := range started {
121+
errs = append(errs, s.Stop(ctx))
122+
}
123+
return errors.Join(errs...)
124+
}
114125
return nil
115126
}
116127

117128
func (c *codeModeTool) Stop(ctx context.Context) error {
118129
var errs []error
119-
120130
for _, t := range c.toolsets {
121-
if startable, ok := t.(tools.Startable); ok {
122-
if err := startable.Stop(ctx); err != nil {
131+
if s, ok := tools.As[tools.Startable](t); ok {
132+
if err := s.Stop(ctx); err != nil {
123133
errs = append(errs, err)
124134
}
125135
}
126136
}
127-
128137
return errors.Join(errs...)
129138
}

pkg/tools/codemode/codemode_test.go

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,10 +187,43 @@ func TestCodeModeTool_CallEcho(t *testing.T) {
187187
require.Empty(t, scriptResult.StdOut)
188188
}
189189

190+
// TestCodeModeTool_StartRollsBackOnError verifies that when one toolset fails
191+
// to start, all successfully-started toolsets are stopped (rolled back).
192+
func TestCodeModeTool_StartRollsBackOnError(t *testing.T) {
193+
failing := &testToolSet{startErr: assert.AnError}
194+
healthy := &testToolSet{}
195+
196+
tool := Wrap(healthy, failing).(tools.Startable)
197+
198+
err := tool.Start(t.Context())
199+
require.ErrorIs(t, err, assert.AnError)
200+
assert.Equal(t, 1, failing.start, "failing toolset should have attempted start")
201+
assert.Equal(t, 1, healthy.start, "healthy toolset should have attempted start")
202+
assert.Equal(t, 1, healthy.stop, "healthy toolset should be rolled back after failure")
203+
}
204+
205+
// TestCodeModeTool_StartStopWrappedToolSet verifies that Start/Stop find
206+
// Startable through a StartableToolSet wrapper via tools.As.
207+
func TestCodeModeTool_StartStopWrappedToolSet(t *testing.T) {
208+
inner := &testToolSet{}
209+
wrapped := tools.NewStartable(inner)
210+
211+
tool := Wrap(wrapped).(tools.Startable)
212+
213+
err := tool.Start(t.Context())
214+
require.NoError(t, err)
215+
assert.Equal(t, 1, inner.start)
216+
217+
err = tool.Stop(t.Context())
218+
require.NoError(t, err)
219+
assert.Equal(t, 1, inner.stop)
220+
}
221+
190222
type testToolSet struct {
191-
tools []tools.Tool
192-
start int
193-
stop int
223+
tools []tools.Tool
224+
start int
225+
stop int
226+
startErr error
194227
}
195228

196229
// Verify interface compliance
@@ -205,7 +238,7 @@ func (t *testToolSet) Tools(context.Context) ([]tools.Tool, error) {
205238

206239
func (t *testToolSet) Start(context.Context) error {
207240
t.start++
208-
return nil
241+
return t.startErr
209242
}
210243

211244
func (t *testToolSet) Stop(context.Context) error {

0 commit comments

Comments
 (0)