Skip to content

Commit ba4d966

Browse files
authored
Merge pull request #2127 from dgageot/board/look-at-the-buitin-shell-tool-code-and-f-65e4e4bf
Fix two data races in shell tool
2 parents 7ec0a5b + 3568356 commit ba4d966

1 file changed

Lines changed: 34 additions & 28 deletions

File tree

pkg/tools/builtin/shell.go

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,11 @@ type backgroundJob struct {
7474
err error
7575
}
7676

77-
// limitedWriter wraps a buffer and stops writing after maxSize bytes
77+
// limitedWriter wraps a buffer and stops writing after maxSize bytes.
78+
// It uses an external mutex (mu) so that readers of the underlying buffer
79+
// can share the same lock.
7880
type limitedWriter struct {
79-
mu sync.Mutex
81+
mu *sync.RWMutex
8082
buf *bytes.Buffer
8183
written int64
8284
maxSize int64
@@ -86,20 +88,12 @@ func (lw *limitedWriter) Write(p []byte) (n int, err error) {
8688
lw.mu.Lock()
8789
defer lw.mu.Unlock()
8890

89-
if lw.written >= lw.maxSize {
90-
return len(p), nil // Discard but report success
91+
if remaining := lw.maxSize - lw.written; remaining > 0 {
92+
toWrite := min(int64(len(p)), remaining)
93+
lw.buf.Write(p[:toWrite]) // bytes.Buffer.Write never errors
94+
lw.written += toWrite
9195
}
92-
93-
remaining := lw.maxSize - lw.written
94-
toWrite := min(int64(len(p)), remaining)
95-
96-
n, err = lw.buf.Write(p[:toWrite])
97-
lw.written += int64(n)
98-
99-
if err == nil && int64(n) < int64(len(p)) {
100-
return len(p), nil // Report full write even if truncated
101-
}
102-
return n, err
96+
return len(p), nil // always report full write
10397
}
10498

10599
type RunShellArgs struct {
@@ -184,6 +178,15 @@ func (h *shellHandler) runNativeCommand(timeoutCtx, ctx context.Context, command
184178
select {
185179
case <-timeoutCtx.Done():
186180
_ = kill(cmd.Process, pg)
181+
// Wait for cmd.Wait() to complete so that the internal pipe-copy
182+
// goroutines finish writing to outBuf before we read it.
183+
// Use a grace period: if SIGTERM is ignored, escalate to SIGKILL.
184+
select {
185+
case <-done:
186+
case <-time.After(3 * time.Second):
187+
_ = cmd.Process.Kill()
188+
<-done
189+
}
187190
case cmdErr = <-done:
188191
}
189192

@@ -200,10 +203,20 @@ func (h *shellHandler) RunShellBackground(_ context.Context, params RunShellBack
200203
cmd.Dir = h.resolveWorkDir(params.Cwd)
201204
cmd.SysProcAttr = platformSpecificSysProcAttr()
202205

203-
outputBuf := &bytes.Buffer{}
204-
limitedWriter := &limitedWriter{buf: outputBuf, maxSize: 10 * 1024 * 1024}
205-
cmd.Stdout = limitedWriter
206-
cmd.Stderr = limitedWriter
206+
job := &backgroundJob{
207+
id: jobID,
208+
cmd: params.Cmd,
209+
cwd: params.Cwd,
210+
output: &bytes.Buffer{},
211+
startTime: time.Now(),
212+
}
213+
214+
// The limitedWriter shares the job's outputMu so that readers
215+
// (ViewBackgroundJob, ListBackgroundJobs) and the pipe-copy
216+
// goroutines spawned by exec.Cmd use the same lock.
217+
lw := &limitedWriter{mu: &job.outputMu, buf: job.output, maxSize: 10 * 1024 * 1024}
218+
cmd.Stdout = lw
219+
cmd.Stderr = lw
207220

208221
if err := cmd.Start(); err != nil {
209222
return tools.ResultError(fmt.Sprintf("Error starting background command: %s", err)), nil
@@ -215,15 +228,8 @@ func (h *shellHandler) RunShellBackground(_ context.Context, params RunShellBack
215228
return tools.ResultError(fmt.Sprintf("Error creating process group: %s", err)), nil
216229
}
217230

218-
job := &backgroundJob{
219-
id: jobID,
220-
cmd: params.Cmd,
221-
cwd: params.Cwd,
222-
process: cmd.Process,
223-
processGroup: pg,
224-
output: outputBuf,
225-
startTime: time.Now(),
226-
}
231+
job.process = cmd.Process
232+
job.processGroup = pg
227233
job.status.Store(statusRunning)
228234
h.jobs.Store(jobID, job)
229235

0 commit comments

Comments
 (0)