Skip to content

Commit fa676c0

Browse files
authored
Merge pull request #967 from krissetto/code-aware-chunking
Code aware chunking in RAG strategies
2 parents d5de54f + bc2b832 commit fa676c0

17 files changed

Lines changed: 1172 additions & 167 deletions

File tree

Dockerfile

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,31 +6,60 @@ ARG ALPINE_VERSION="3.22"
66
# xx is a helper for cross-compilation
77
FROM --platform=$BUILDPLATFORM tonistiigi/xx:1.7.0 AS xx
88

9+
# osxcross contains the MacOSX cross toolchain for xx
10+
FROM crazymax/osxcross:15.5-debian AS osxcross
11+
912
FROM --platform=$BUILDPLATFORM golang:${GO_VERSION}-alpine${ALPINE_VERSION} AS builder-base
1013
COPY --from=xx / /
1114
WORKDIR /src
1215
RUN --mount=type=cache,target=/go/pkg/mod \
1316
--mount=type=bind,source=go.mod,target=go.mod \
1417
--mount=type=bind,source=go.sum,target=go.sum \
1518
go mod download
19+
ENV CGO_ENABLED=1
20+
1621

17-
FROM builder-base AS builder
18-
COPY . ./
1922
ARG GIT_TAG
2023
ARG GIT_COMMIT
2124
ARG TARGETPLATFORM
2225
ARG TARGETOS
2326
ARG TARGETARCH
24-
RUN --mount=type=cache,target=/root/.cache,id=docker-ai-$TARGETPLATFORM \
27+
28+
FROM builder-base AS builder-darwin
29+
RUN apk add clang
30+
COPY . ./
31+
RUN --mount=type=bind,from=osxcross,src=/osxsdk,target=/xx-sdk \
32+
--mount=type=cache,target=/root/.cache,id=docker-ai-$TARGETPLATFORM \
2533
--mount=type=cache,target=/go/pkg/mod <<EOT
2634
set -ex
2735
xx-go build -trimpath -ldflags "-s -w -X 'github.com/docker/cagent/pkg/version.Version=$GIT_TAG' -X 'github.com/docker/cagent/pkg/version.Commit=$GIT_COMMIT'" -o /binaries/cagent-$TARGETOS-$TARGETARCH .
28-
xx-verify --static /binaries/cagent-$TARGETOS-$TARGETARCH
29-
if [ "$TARGETOS" = "windows" ]; then
30-
mv /binaries/cagent-$TARGETOS-$TARGETARCH /binaries/cagent-$TARGETOS-$TARGETARCH.exe
31-
fi
36+
xx-verify --static /binaries/cagent-darwin-$TARGETARCH
3237
EOT
3338

39+
FROM builder-base AS builder-linux
40+
RUN apk add clang
41+
RUN xx-apk add musl-dev gcc
42+
COPY . ./
43+
RUN --mount=type=cache,target=/root/.cache,id=docker-ai-$TARGETPLATFORM \
44+
--mount=type=cache,target=/go/pkg/mod <<EOT
45+
set -ex
46+
xx-go build -trimpath -ldflags "-s -w -linkmode=external -extldflags '-static' -X 'github.com/docker/cagent/pkg/version.Version=$GIT_TAG' -X 'github.com/docker/cagent/pkg/version.Commit=$GIT_COMMIT'" -o /binaries/cagent-$TARGETOS-$TARGETARCH .
47+
xx-verify --static /binaries/cagent-linux-$TARGETARCH
48+
EOT
49+
50+
FROM builder-base AS builder-windows
51+
RUN apk add zig build-base
52+
COPY . ./
53+
RUN --mount=type=cache,target=/root/.cache,id=docker-ai-$TARGETPLATFORM \
54+
--mount=type=cache,target=/go/pkg/mod <<EOT
55+
set -ex
56+
CC="zig cc -target x86_64-windows-gnu" CXX="zig c++ -target x86_64-windows-gnu" xx-go build -trimpath -ldflags "-s -w -X 'github.com/docker/cagent/pkg/version.Version=$GIT_TAG' -X 'github.com/docker/cagent/pkg/version.Commit=$GIT_COMMIT'" -o /binaries/cagent-$TARGETOS-$TARGETARCH .
57+
mv /binaries/cagent-$TARGETOS-$TARGETARCH /binaries/cagent-$TARGETOS-$TARGETARCH.exe
58+
xx-verify --static /binaries/cagent-windows-$TARGETARCH.exe
59+
EOT
60+
61+
FROM builder-$TARGETOS AS builder
62+
3463
FROM scratch AS local
3564
ARG TARGETOS TARGETARCH
3665
COPY --from=builder /binaries/cagent-$TARGETOS-$TARGETARCH cagent

cagent-schema.json

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -728,7 +728,7 @@
728728
"chunked-embeddings"
729729
]
730730
},
731-
"model": {
731+
"embedding_model": {
732732
"type": "string",
733733
"description": "Embedding model reference for chunked-embeddings strategies (looked up in models map, or 'auto' for automatic selection)",
734734
"examples": [
@@ -804,6 +804,10 @@
804804
"respect_word_boundaries": {
805805
"type": "boolean",
806806
"description": "When true, chunks will split on the nearest whitespace boundary instead of at the exact character limit, preventing words from being truncated."
807+
},
808+
"code_aware": {
809+
"type": "boolean",
810+
"description": "Enable code-aware chunking for source files. When true, the chunking strategy will prefer AST-based or language-aware processors when available (tree-sitter based), and fall back to plain text chunking for unsupported languages."
807811
}
808812
},
809813
"additionalProperties": false

docs/USAGE.md

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -972,6 +972,30 @@ models:
972972
- `chunking.size`: Chunk size in characters (default: `1000`)
973973
- `chunking.overlap`: Overlap between consecutive chunks, in characters (default: `75`)
974974

975+
**Code-Aware Chunking:**
976+
977+
When indexing source code, you can enable code-aware chunking to produce semantically aligned chunks based on the code's AST (Abstract Syntax Tree). This keeps functions and methods intact rather than splitting them arbitrarily:
978+
979+
```yaml
980+
rag:
981+
codebase:
982+
docs: [./src]
983+
strategies:
984+
- type: bm25
985+
database: ./code.db
986+
chunking:
987+
size: 2000
988+
code_aware: true # Enable AST-based chunking
989+
```
990+
991+
- `chunking.code_aware`: When `true`, uses tree-sitter for AST-based chunking (default: `false`). With code-aware chunking enabled, `size` is treated as an indicative target rather than an exact limit.
992+
993+
**Notes:**
994+
- Currently supports **Go** source files (`.go`). More languages will be added incrementally.
995+
- Falls back to plain text chunking for unsupported file types.
996+
- Produces chunks that align with code structure (functions, methods, type declarations).
997+
- Particularly useful for code search and retrieval tasks.
998+
975999
**Results:**
9761000
- `limit`: Final number of results (default: `15`)
9771001
- `deduplicate`: Remove duplicates (default: `true`)

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ require (
3333
github.com/mattn/go-runewidth v0.0.19
3434
github.com/modelcontextprotocol/go-sdk v1.1.0
3535
github.com/openai/openai-go/v3 v3.8.1
36+
github.com/smacker/go-tree-sitter v0.0.0-20240827094217-dd81d9e9be82
3637
github.com/spf13/cobra v1.10.1
3738
github.com/stretchr/testify v1.11.1
3839
github.com/temoto/robotstxt v1.1.2

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,8 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ
240240
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
241241
github.com/skeema/knownhosts v1.3.1 h1:X2osQ+RAjK76shCbvhHHHVl3ZlgDm8apHEHFqRjnBY8=
242242
github.com/skeema/knownhosts v1.3.1/go.mod h1:r7KTdC8l4uxWRyK2TpQZ/1o5HaSzh06ePQNxPwTcfiY=
243+
github.com/smacker/go-tree-sitter v0.0.0-20240827094217-dd81d9e9be82 h1:6C8qej6f1bStuePVkLSFxoU22XBS165D3klxlzRg8F4=
244+
github.com/smacker/go-tree-sitter v0.0.0-20240827094217-dd81d9e9be82/go.mod h1:xe4pgH49k4SsmkQq5OT8abwhWmnzkhpgnXeekbx2efw=
243245
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM=
244246
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
245247
github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s=

pkg/config/latest/types.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,13 @@ func unmarshalChunkingConfig(src any, dst *RAGChunkingConfig) {
518518
dst.RespectWordBoundaries = val
519519
}
520520
}
521+
522+
// Handle code_aware - YAML should give us a bool
523+
if ca, ok := m["code_aware"]; ok {
524+
if val, ok := ca.(bool); ok {
525+
dst.CodeAware = val
526+
}
527+
}
521528
}
522529

523530
// coerceToInt converts various numeric types to int
@@ -579,6 +586,11 @@ type RAGChunkingConfig struct {
579586
Size int `json:"size,omitempty"`
580587
Overlap int `json:"overlap,omitempty"`
581588
RespectWordBoundaries bool `json:"respect_word_boundaries,omitempty"`
589+
// CodeAware enables code-aware chunking for source files. When true, the
590+
// chunking strategy uses tree-sitter for AST-based chunking, producing
591+
// semantically aligned chunks (e.g., whole functions). Falls back to
592+
// plain text chunking for unsupported languages.
593+
CodeAware bool `json:"code_aware,omitempty"`
582594
}
583595

584596
// UnmarshalYAML implements custom unmarshaling to apply sensible defaults for chunking

pkg/rag/chunk/chunk.go

Lines changed: 52 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -19,26 +19,21 @@ type Chunk struct {
1919
Metadata map[string]string
2020
}
2121

22-
// Processor handles document processing
23-
type Processor struct{}
24-
25-
// New creates a new document processor
26-
func New() *Processor {
27-
return &Processor{}
22+
// DocumentProcessor turns the content of a single file into chunks.
23+
// Config (size, overlap, etc.) is set at construction time, so Process only takes the file path and its raw bytes.
24+
type DocumentProcessor interface {
25+
Process(path string, content []byte) ([]Chunk, error)
2826
}
2927

30-
// ProcessFile reads a file and splits it into chunks
31-
func (p *Processor) ProcessFile(path string, chunkSize, overlap int, respectWordBoundaries bool) ([]Chunk, error) {
32-
content, err := os.ReadFile(path)
33-
if err != nil {
34-
return nil, fmt.Errorf("failed to read file: %w", err)
35-
}
36-
37-
return p.ChunkText(string(content), chunkSize, overlap, respectWordBoundaries), nil
28+
// TextDocumentProcessor is the default text-based chunker.
29+
type TextDocumentProcessor struct {
30+
size int // target chunk length, counted in runes
31+
overlap int // runes by which the next chunk's start is moved back from the previous chunk's end
32+
respectWordBoundaries bool // when true, chunk edges are moved to whitespace instead of cutting words
3833
}
3934

40-
// ChunkText splits text into overlapping chunks
41-
func (p *Processor) ChunkText(text string, size, overlap int, respectWordBoundaries bool) []Chunk {
35+
// NewTextDocumentProcessor creates a text-based document processor
36+
func NewTextDocumentProcessor(size, overlap int, respectWordBoundaries bool) *TextDocumentProcessor {
4237
if size <= 0 {
4338
size = 1000
4439
}
@@ -48,7 +43,20 @@ func (p *Processor) ChunkText(text string, size, overlap int, respectWordBoundar
4843
if overlap >= size {
4944
overlap = size / 2
5045
}
46+
return &TextDocumentProcessor{
47+
size: size,
48+
overlap: overlap,
49+
respectWordBoundaries: respectWordBoundaries,
50+
}
51+
}
5152

53+
// Process implements DocumentProcessor for text-based chunking; the path argument is ignored and the returned error is always nil.
54+
func (t *TextDocumentProcessor) Process(_ string, content []byte) ([]Chunk, error) {
55+
return t.chunkText(string(content)), nil
56+
}
57+
58+
// chunkText splits text into overlapping chunks
59+
func (t *TextDocumentProcessor) chunkText(text string) []Chunk {
5260
var chunks []Chunk
5361
runes := []rune(text)
5462
totalLen := len(runes)
@@ -62,7 +70,7 @@ func (p *Processor) ChunkText(text string, size, overlap int, respectWordBoundar
6270

6371
for start < totalLen {
6472
// Calculate end position (start + size, but not beyond document end)
65-
end := start + size
73+
end := start + t.size
6674
if end > totalLen {
6775
end = totalLen
6876
}
@@ -72,14 +80,14 @@ func (p *Processor) ChunkText(text string, size, overlap int, respectWordBoundar
7280
// For the final chunk (end == totalLen) we always take the remainder
7381
// of the document as-is to avoid generating progressively smaller
7482
// tail chunks.
75-
if respectWordBoundaries && end > start && end < totalLen {
83+
if t.respectWordBoundaries && end > start && end < totalLen {
7684
// Limit search to the current chunk window.
7785
target := end
7886

7987
// Backtrack from target to find whitespace; if none is found
8088
// in a reasonable range, keep the original end so that we
8189
// still make progress even for very long "words".
82-
searchEnd := p.findNearestWhitespace(runes[start:target+1], target-start) + start
90+
searchEnd := t.findNearestWhitespace(runes[start:target+1], target-start) + start
8391
if searchEnd > start && searchEnd < end {
8492
end = searchEnd
8593
}
@@ -100,7 +108,7 @@ func (p *Processor) ChunkText(text string, size, overlap int, respectWordBoundar
100108
}
101109

102110
// Next chunk starts at the end of the previous chunk minus overlap
103-
nextStart := end - overlap
111+
nextStart := end - t.overlap
104112

105113
// CRITICAL: Ensure we always make forward progress
106114
// If nextStart would move us backward or keep us in place, advance by at least 1
@@ -111,14 +119,14 @@ func (p *Processor) ChunkText(text string, size, overlap int, respectWordBoundar
111119
// When respecting word boundaries, make sure the next chunk
112120
// does not start in the middle of a word. Move the start
113121
// forward to the next whitespace, then to the next non-whitespace.
114-
if respectWordBoundaries {
122+
if t.respectWordBoundaries {
115123
// Move forward until we hit whitespace or end-of-text
116-
for nextStart < totalLen && !p.isWhitespace(runes[nextStart]) {
124+
for nextStart < totalLen && !t.isWhitespace(runes[nextStart]) {
117125
nextStart++
118126
}
119127
// Skip the whitespace itself so we start at the first character
120128
// of the next word (if any).
121-
for nextStart < totalLen && p.isWhitespace(runes[nextStart]) {
129+
for nextStart < totalLen && t.isWhitespace(runes[nextStart]) {
122130
nextStart++
123131
}
124132
}
@@ -131,7 +139,7 @@ func (p *Processor) ChunkText(text string, size, overlap int, respectWordBoundar
131139

132140
// findNearestWhitespace finds the nearest whitespace boundary to the target position
133141
// It searches backward first (within a reasonable distance), then forward if needed
134-
func (p *Processor) findNearestWhitespace(runes []rune, target int) int {
142+
func (t *TextDocumentProcessor) findNearestWhitespace(runes []rune, target int) int {
135143
// Don't search beyond 20% of the total length in either direction
136144
maxSearchDistance := len(runes) / 5
137145
if maxSearchDistance < 50 {
@@ -144,9 +152,9 @@ func (p *Processor) findNearestWhitespace(runes []rune, target int) int {
144152
// Search backward first (prefer to keep chunks slightly smaller)
145153
for i := 0; i < maxSearchDistance && target-i > 0; i++ {
146154
pos := target - i
147-
if p.isWhitespace(runes[pos]) {
155+
if t.isWhitespace(runes[pos]) {
148156
// Skip consecutive whitespace
149-
for pos > 0 && p.isWhitespace(runes[pos-1]) {
157+
for pos > 0 && t.isWhitespace(runes[pos-1]) {
150158
pos--
151159
}
152160
return pos
@@ -156,7 +164,7 @@ func (p *Processor) findNearestWhitespace(runes []rune, target int) int {
156164
// Search forward if no whitespace found backward
157165
for i := 1; i < maxSearchDistance && target+i < len(runes); i++ {
158166
pos := target + i
159-
if p.isWhitespace(runes[pos]) {
167+
if t.isWhitespace(runes[pos]) {
160168
return pos
161169
}
162170
}
@@ -166,12 +174,23 @@ func (p *Processor) findNearestWhitespace(runes []rune, target int) int {
166174
}
167175

168176
// isWhitespace reports whether r is a space, tab, newline, or carriage return.
169-
func (p *Processor) isWhitespace(r rune) bool {
177+
func (t *TextDocumentProcessor) isWhitespace(r rune) bool {
170178
return r == ' ' || r == '\t' || r == '\n' || r == '\r'
171179
}
172180

181+
// --- File utility functions (standalone, not tied to any processor) ---
182+
183+
// ProcessFile reads the file at path and passes its content to dp.Process; a read failure is returned wrapped.
184+
func ProcessFile(dp DocumentProcessor, path string) ([]Chunk, error) {
185+
content, err := os.ReadFile(path)
186+
if err != nil {
187+
return nil, fmt.Errorf("failed to read file: %w", err)
188+
}
189+
return dp.Process(path, content)
190+
}
191+
173192
// FileHash calculates SHA256 hash of a file
174-
func (p *Processor) FileHash(path string) (string, error) {
193+
func FileHash(path string) (string, error) {
175194
f, err := os.Open(path)
176195
if err != nil {
177196
return "", fmt.Errorf("failed to open file: %w", err)
@@ -188,12 +207,12 @@ func (p *Processor) FileHash(path string) (string, error) {
188207

189208
// CollectFiles recursively collects all files from given paths
190209
// Skips paths that don't exist instead of returning an error
191-
func (p *Processor) CollectFiles(paths []string) ([]string, error) {
210+
func CollectFiles(paths []string) ([]string, error) {
192211
var files []string
193212
seen := make(map[string]bool)
194213

195214
for _, pattern := range paths {
196-
expanded, err := p.expandPattern(pattern)
215+
expanded, err := expandPattern(pattern)
197216
if err != nil {
198217
return nil, err
199218
}
@@ -245,7 +264,7 @@ func (p *Processor) CollectFiles(paths []string) ([]string, error) {
245264

246265
// Matches reports whether the given path matches any configured document path or glob pattern.
247266
// To be used in file watchers to determine if a new/changed file matches the glob patterns or not.
248-
func (p *Processor) Matches(path string, patterns []string) (bool, error) {
267+
func Matches(path string, patterns []string) (bool, error) {
249268
if len(patterns) == 0 {
250269
return false, nil
251270
}
@@ -293,7 +312,7 @@ func (p *Processor) Matches(path string, patterns []string) (bool, error) {
293312
return false, nil
294313
}
295314

296-
func (p *Processor) expandPattern(pattern string) ([]string, error) {
315+
func expandPattern(pattern string) ([]string, error) {
297316
if !hasGlob(pattern) {
298317
return []string{normalizePath(pattern)}, nil
299318
}

0 commit comments

Comments
 (0)