Skip to content

Commit fda4ec4

Browse files
authored
Merge pull request #2114 from gtardif/fix_rag_init_context_cancel
Fix rag init context cancel
2 parents bd55840 + c8a9ae8 commit fda4ec4

7 files changed

Lines changed: 277 additions & 43 deletions

File tree

pkg/fsx/collect.go

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package fsx
22

33
import (
4+
"context"
45
"fmt"
56
"os"
67
"path/filepath"
@@ -13,11 +14,18 @@ import (
1314
// Supports glob patterns (via doublestar), directories, and individual files.
1415
// Skips paths that don't exist instead of returning an error.
1516
// Optional shouldIgnore filter can exclude files/directories (return true to skip).
16-
func CollectFiles(paths []string, shouldIgnore func(path string) bool) ([]string, error) {
17+
// Respects context cancellation.
18+
func CollectFiles(ctx context.Context, paths []string, shouldIgnore func(path string) bool) ([]string, error) {
1719
var files []string
1820
seen := make(map[string]bool)
1921

2022
for _, pattern := range paths {
23+
// Check for context cancellation
24+
select {
25+
case <-ctx.Done():
26+
return nil, ctx.Err()
27+
default:
28+
}
2129
expanded, err := expandPattern(pattern)
2230
if err != nil {
2331
return nil, err
@@ -27,6 +35,12 @@ func CollectFiles(paths []string, shouldIgnore func(path string) bool) ([]string
2735
}
2836

2937
for _, entry := range expanded {
38+
// Check for context cancellation
39+
select {
40+
case <-ctx.Done():
41+
return nil, ctx.Err()
42+
default:
43+
}
3044
normalized := normalizePath(entry)
3145

3246
// Check if this path should be ignored
@@ -44,14 +58,21 @@ func CollectFiles(paths []string, shouldIgnore func(path string) bool) ([]string
4458

4559
if info.IsDir() {
4660
// Use DirectoryTree to collect files from directory
47-
tree, err := DirectoryTree(normalized, func(string) error { return nil }, shouldIgnore, 0)
61+
tree, err := DirectoryTree(ctx, normalized, func(string) error { return nil }, shouldIgnore, 0)
4862
if err != nil {
4963
return nil, fmt.Errorf("failed to read directory %s: %w", normalized, err)
5064
}
5165
// Traverse tree and collect absolute file paths
5266
var dirFiles []string
5367
CollectFilesFromTree(tree, filepath.Dir(normalized), &dirFiles)
5468
for _, f := range dirFiles {
69+
// Check for context cancellation
70+
select {
71+
case <-ctx.Done():
72+
return nil, ctx.Err()
73+
default:
74+
}
75+
5576
absPath := normalizePath(f)
5677
if !seen[absPath] {
5778
files = append(files, absPath)
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
package fsx
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"os"
7+
"path/filepath"
8+
"testing"
9+
"time"
10+
11+
"github.com/stretchr/testify/assert"
12+
"github.com/stretchr/testify/require"
13+
)
14+
15+
func TestCollectFiles_ContextCancellation(t *testing.T) {
16+
t.Parallel()
17+
18+
tmpDir := t.TempDir()
19+
20+
// Create a large directory structure to ensure context cancellation has time to kick in
21+
for i := range 100 {
22+
subDir := filepath.Join(tmpDir, "dir", "subdir", "deepdir", fmt.Sprintf("dir%d", i))
23+
require.NoError(t, os.MkdirAll(subDir, 0o755))
24+
for j := range 10 {
25+
filePath := filepath.Join(subDir, fmt.Sprintf("file%d.txt", j))
26+
require.NoError(t, os.WriteFile(filePath, []byte("test content"), 0o644))
27+
}
28+
}
29+
30+
t.Run("respects context cancellation", func(t *testing.T) {
31+
ctx, cancel := context.WithCancel(t.Context())
32+
33+
// Cancel context immediately
34+
cancel()
35+
36+
_, err := CollectFiles(ctx, []string{tmpDir}, nil)
37+
assert.ErrorIs(t, err, context.Canceled)
38+
})
39+
40+
t.Run("respects context timeout", func(t *testing.T) {
41+
ctx, cancel := context.WithTimeout(t.Context(), 1*time.Nanosecond)
42+
defer cancel()
43+
44+
// Give time for timeout to trigger
45+
time.Sleep(10 * time.Millisecond)
46+
47+
_, err := CollectFiles(ctx, []string{tmpDir}, nil)
48+
assert.ErrorIs(t, err, context.DeadlineExceeded)
49+
})
50+
}
51+
52+
func TestDirectoryTree_ContextCancellation(t *testing.T) {
53+
t.Parallel()
54+
55+
tmpDir := t.TempDir()
56+
57+
// Create a large directory structure
58+
for i := range 100 {
59+
subDir := filepath.Join(tmpDir, "dir", "subdir", fmt.Sprintf("dir%d", i))
60+
require.NoError(t, os.MkdirAll(subDir, 0o755))
61+
for j := range 10 {
62+
filePath := filepath.Join(subDir, fmt.Sprintf("file%d.txt", j))
63+
require.NoError(t, os.WriteFile(filePath, []byte("test content"), 0o644))
64+
}
65+
}
66+
67+
t.Run("respects context cancellation", func(t *testing.T) {
68+
ctx, cancel := context.WithCancel(t.Context())
69+
70+
// Cancel context immediately
71+
cancel()
72+
73+
_, err := DirectoryTree(ctx, tmpDir, func(string) error { return nil }, nil, 0)
74+
assert.ErrorIs(t, err, context.Canceled)
75+
})
76+
77+
t.Run("respects context timeout", func(t *testing.T) {
78+
ctx, cancel := context.WithTimeout(t.Context(), 1*time.Nanosecond)
79+
defer cancel()
80+
81+
// Give time for timeout to trigger
82+
time.Sleep(10 * time.Millisecond)
83+
84+
_, err := DirectoryTree(ctx, tmpDir, func(string) error { return nil }, nil, 0)
85+
assert.ErrorIs(t, err, context.DeadlineExceeded)
86+
})
87+
}

pkg/fsx/collect_test.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ func TestCollectFiles_WithShouldIgnoreFilter(t *testing.T) {
5252
t.Run("no filter collects all files", func(t *testing.T) {
5353
t.Parallel()
5454

55-
got, err := CollectFiles([]string{tmpDir}, nil)
55+
got, err := CollectFiles(t.Context(), []string{tmpDir}, nil)
5656
require.NoError(t, err)
5757
assert.Len(t, got, 5)
5858
})
@@ -66,7 +66,7 @@ func TestCollectFiles_WithShouldIgnoreFilter(t *testing.T) {
6666
return base == "vendor" || base == "node_modules"
6767
}
6868

69-
got, err := CollectFiles([]string{tmpDir}, shouldIgnore)
69+
got, err := CollectFiles(t.Context(), []string{tmpDir}, shouldIgnore)
7070
require.NoError(t, err)
7171

7272
// Should only have src/*.go and build/output.bin
@@ -87,7 +87,7 @@ func TestCollectFiles_WithShouldIgnoreFilter(t *testing.T) {
8787
return strings.HasSuffix(path, ".bin")
8888
}
8989

90-
got, err := CollectFiles([]string{tmpDir}, shouldIgnore)
90+
got, err := CollectFiles(t.Context(), []string{tmpDir}, shouldIgnore)
9191
require.NoError(t, err)
9292

9393
assert.Len(t, got, 4)
@@ -104,7 +104,7 @@ func TestCollectFiles_WithShouldIgnoreFilter(t *testing.T) {
104104
return strings.Contains(path, "vendor")
105105
}
106106

107-
got, err := CollectFiles([]string{tmpDir}, shouldIgnore)
107+
got, err := CollectFiles(t.Context(), []string{tmpDir}, shouldIgnore)
108108
require.NoError(t, err)
109109

110110
for _, f := range got {
@@ -148,7 +148,7 @@ func TestCollectFiles_GitDirectoryExclusion(t *testing.T) {
148148
t.Run("without filter includes .git", func(t *testing.T) {
149149
t.Parallel()
150150

151-
got, err := CollectFiles([]string{tmpDir}, nil)
151+
got, err := CollectFiles(t.Context(), []string{tmpDir}, nil)
152152
require.NoError(t, err)
153153

154154
// Should include .git files
@@ -177,7 +177,7 @@ func TestCollectFiles_GitDirectoryExclusion(t *testing.T) {
177177
strings.HasPrefix(normalized, ".git/")
178178
}
179179

180-
got, err := CollectFiles([]string{tmpDir}, shouldIgnore)
180+
got, err := CollectFiles(t.Context(), []string{tmpDir}, shouldIgnore)
181181
require.NoError(t, err)
182182

183183
// Should only have src/main.go
@@ -229,7 +229,7 @@ func TestCollectFiles_GlobsWithFilter(t *testing.T) {
229229
return strings.HasSuffix(path, "_test.go")
230230
}
231231

232-
got, err := CollectFiles([]string{filepath.Join(tmpDir, "pkg", "**", "*.go")}, shouldIgnore)
232+
got, err := CollectFiles(t.Context(), []string{filepath.Join(tmpDir, "pkg", "**", "*.go")}, shouldIgnore)
233233
require.NoError(t, err)
234234

235235
// Should only have non-test .go files
@@ -329,7 +329,7 @@ func TestCollectFiles_Deduplication(t *testing.T) {
329329
tmpDir, // Will also include test.go
330330
}
331331

332-
got, err := CollectFiles(patterns, nil)
332+
got, err := CollectFiles(t.Context(), patterns, nil)
333333
require.NoError(t, err)
334334

335335
// Should only have one entry
@@ -352,7 +352,7 @@ func TestCollectFiles_NonExistentPaths(t *testing.T) {
352352
filepath.Join(tmpDir, "also", "missing", "file.go"),
353353
}
354354

355-
got, err := CollectFiles(patterns, nil)
355+
got, err := CollectFiles(t.Context(), patterns, nil)
356356
require.NoError(t, err)
357357

358358
// Should only have the real file
@@ -371,7 +371,7 @@ func TestCollectFiles_SortedOutput(t *testing.T) {
371371
require.NoError(t, os.WriteFile(filepath.Join(tmpDir, f), []byte("package test"), 0o644))
372372
}
373373

374-
got, err := CollectFiles([]string{tmpDir}, nil)
374+
got, err := CollectFiles(t.Context(), []string{tmpDir}, nil)
375375
require.NoError(t, err)
376376

377377
// Verify we got all files

pkg/fsx/fs.go

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,18 @@ type TreeNode struct {
1515
Children []*TreeNode `json:"children,omitempty"`
1616
}
1717

18-
func DirectoryTree(path string, isPathAllowed func(string) error, shouldIgnore func(string) bool, maxItems int) (*TreeNode, error) {
18+
func DirectoryTree(ctx context.Context, path string, isPathAllowed func(string) error, shouldIgnore func(string) bool, maxItems int) (*TreeNode, error) {
1919
itemCount := 0
20-
return directoryTree(path, isPathAllowed, shouldIgnore, maxItems, &itemCount)
20+
return directoryTree(ctx, path, isPathAllowed, shouldIgnore, maxItems, &itemCount)
2121
}
2222

23-
func directoryTree(path string, isPathAllowed func(string) error, shouldIgnore func(string) bool, maxItems int, itemCount *int) (*TreeNode, error) {
23+
func directoryTree(ctx context.Context, path string, isPathAllowed func(string) error, shouldIgnore func(string) bool, maxItems int, itemCount *int) (*TreeNode, error) {
24+
// Check for context cancellation
25+
select {
26+
case <-ctx.Done():
27+
return nil, ctx.Err()
28+
default:
29+
}
2430
if maxItems > 0 && *itemCount >= maxItems {
2531
return nil, nil
2632
}
@@ -47,6 +53,13 @@ func directoryTree(path string, isPathAllowed func(string) error, shouldIgnore f
4753
}
4854

4955
for _, entry := range entries {
56+
// Check for context cancellation
57+
select {
58+
case <-ctx.Done():
59+
return node, ctx.Err()
60+
default:
61+
}
62+
5063
childPath := filepath.Join(path, entry.Name())
5164
if err := isPathAllowed(childPath); err != nil {
5265
continue // Skip disallowed paths
@@ -57,7 +70,7 @@ func directoryTree(path string, isPathAllowed func(string) error, shouldIgnore f
5770
continue
5871
}
5972

60-
childNode, err := directoryTree(childPath, isPathAllowed, shouldIgnore, maxItems, itemCount)
73+
childNode, err := directoryTree(ctx, childPath, isPathAllowed, shouldIgnore, maxItems, itemCount)
6174
if err != nil || childNode == nil {
6275
continue
6376
}
@@ -68,17 +81,6 @@ func directoryTree(path string, isPathAllowed func(string) error, shouldIgnore f
6881
return node, nil
6982
}
7083

71-
func ListDirectory(path string, shouldIgnore func(string) bool) ([]string, error) {
72-
tree, err := DirectoryTree(path, func(string) error { return nil }, shouldIgnore, 0)
73-
if err != nil {
74-
return nil, err
75-
}
76-
77-
var files []string
78-
CollectFilesFromTree(tree, "", &files)
79-
return files, nil
80-
}
81-
8284
// CollectFilesFromTree recursively collects file paths from a DirectoryTree.
8385
// Pass basePath="" for relative paths, or a parent directory for absolute paths.
8486
func CollectFilesFromTree(node *TreeNode, basePath string, files *[]string) {

0 commit comments

Comments
 (0)