Skip to content

Commit ea7d585

Browse files
committed
rag support
Signed-off-by: Christopher Petito <chrisjpetito@gmail.com>
1 parent a7e636f commit ea7d585

35 files changed

Lines changed: 5503 additions & 41 deletions

cmd/root/new.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ func (f *newFlags) runNewCommand(cmd *cobra.Command, args []string) error {
6868

6969
sess := session.New(opts...)
7070

71-
a := app.New("", rt, sess, prompt)
71+
a := app.New(ctx, "", rt, sess, prompt)
7272
m := tui.New(a)
7373

7474
progOpts := []tea.ProgramOption{

cmd/root/run.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ func handleRunMode(ctx context.Context, agentFilename string, rt runtime.Runtime
240240
return err
241241
}
242242

243-
a := app.New(agentFilename, rt, sess, firstMessage)
243+
a := app.New(ctx, agentFilename, rt, sess, firstMessage)
244244
m := tui.New(a)
245245

246246
progOpts := []tea.ProgramOption{

pkg/app/app.go

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,26 @@ type App struct {
2222
cancel context.CancelFunc
2323
}
2424

25-
func New(agentFilename string, rt runtime.Runtime, sess *session.Session, firstMessage *string) *App {
26-
return &App{
25+
func New(ctx context.Context, agentFilename string, rt runtime.Runtime, sess *session.Session, firstMessage *string) *App {
26+
app := &App{
2727
agentFilename: agentFilename,
2828
runtime: rt,
2929
session: sess,
3030
firstMessage: firstMessage,
3131
events: make(chan tea.Msg, 128),
3232
throttleDuration: 50 * time.Millisecond, // Throttle rapid events
3333
}
34+
35+
// If the runtime supports background RAG initialization, start it
36+
// and forward events to the TUI. Remote runtimes typically handle RAG server-side
37+
// and won't implement this optional interface.
38+
if ragRuntime, ok := rt.(runtime.RAGInitializer); ok {
39+
go ragRuntime.StartBackgroundRAGInit(ctx, func(event runtime.Event) {
40+
app.events <- event
41+
})
42+
}
43+
44+
return app
3445
}
3546

3647
func (a *App) FirstMessage() *string {

pkg/config/auto.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,22 @@ func PreferredMaxTokens(provider string) int {
5656
}
5757
return 64000
5858
}
59+
60+
// AutoEmbeddingModelConfigs returns the ordered list of embedding-capable models
61+
// to try when a RAG strategy uses `model: auto` for embeddings.
62+
//
63+
// The priority is:
64+
// 1. OpenAI -> text-embedding-3-small model
65+
// 2. DMR -> Google's embeddinggemma model (via Docker Model Runner)
66+
func AutoEmbeddingModelConfigs() []latest.ModelConfig {
67+
return []latest.ModelConfig{
68+
{
69+
Provider: "openai",
70+
Model: "text-embedding-3-small",
71+
},
72+
{
73+
Provider: "dmr",
74+
Model: "ai/embeddinggemma",
75+
},
76+
}
77+
}

pkg/config/overrides.go

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -81,29 +81,63 @@ func ensureModelsExist(cfg *v2.Config) error {
8181
cfg.Models = map[string]v2.ModelConfig{}
8282
}
8383

84+
// Ensure models referenced by agents exist
8485
for agentName := range cfg.Agents {
8586
agentConfig := cfg.Agents[agentName]
8687

8788
modelNames := strings.SplitSeq(agentConfig.Model, ",")
8889
for modelName := range modelNames {
89-
if modelName == "auto" {
90-
continue
90+
if err := ensureSingleModelExists(cfg, modelName, fmt.Sprintf("agent '%s'", agentName)); err != nil {
91+
return err
9192
}
92-
if _, exists := cfg.Models[modelName]; exists {
93+
}
94+
}
95+
96+
// Ensure models referenced by RAG strategies exist
97+
for ragName, ragCfg := range cfg.RAG {
98+
for _, stratCfg := range ragCfg.Strategies {
99+
rawModel, ok := stratCfg.Params["model"]
100+
if !ok {
93101
continue
94102
}
95103

96-
providerName, model, ok := strings.Cut(modelName, "/")
104+
modelName, ok := rawModel.(string)
97105
if !ok {
98-
return fmt.Errorf("agent '%s' references non-existent model '%s'", agentName, modelName)
106+
return fmt.Errorf("RAG strategy '%s' in RAG '%s' has non-string model value", stratCfg.Type, ragName)
99107
}
100108

101-
cfg.Models[modelName] = v2.ModelConfig{
102-
Provider: providerName,
103-
Model: model,
109+
if err := ensureSingleModelExists(cfg, modelName, fmt.Sprintf("RAG strategy '%s' in RAG '%s'", stratCfg.Type, ragName)); err != nil {
110+
return err
104111
}
105112
}
106113
}
107114

108115
return nil
109116
}
117+
118+
// ensureSingleModelExists normalizes shorthand model IDs like "openai/gpt-5-mini"
119+
// into full entries in cfg.Models so they can be reused by agents, RAG, and other
120+
// subsystems without duplicating parsing logic.
121+
func ensureSingleModelExists(cfg *v2.Config, modelName, context string) error {
122+
modelName = strings.TrimSpace(modelName)
123+
if modelName == "" || modelName == "auto" {
124+
// "auto" is handled dynamically at runtime and does not need a config entry.
125+
return nil
126+
}
127+
128+
if _, exists := cfg.Models[modelName]; exists {
129+
return nil
130+
}
131+
132+
providerName, model, ok := strings.Cut(modelName, "/")
133+
if !ok || providerName == "" || model == "" {
134+
return fmt.Errorf("%s references non-existent model '%s'", context, modelName)
135+
}
136+
137+
cfg.Models[modelName] = v2.ModelConfig{
138+
Provider: providerName,
139+
Model: model,
140+
}
141+
142+
return nil
143+
}

0 commit comments

Comments
 (0)