Skip to content

Commit 050398e

Browse files
authored
Merge pull request #2016 from dgageot/board/gorgonia-golearn-82dce730
Simplify rulebased router: remove redundant types and score aggregation
2 parents 201d8a7 + ad29365 commit 050398e

2 files changed

Lines changed: 83 additions & 76 deletions

File tree

pkg/model/provider/rulebased/client.go

Lines changed: 50 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
// Package rulebased provides a rule-based model router that selects
2-
// the appropriate model based on NLP analysis of the input using Bleve.
3-
//
4-
// Routes are defined with example texts, and Bleve's full-text search
5-
// determines the best matching route based on text similarity.
2+
// the appropriate model based on text similarity using Bleve full-text search.
63
//
74
// A model becomes a rule-based router when it has routing rules configured.
85
// The model's provider/model fields define the fallback model, and each
@@ -43,17 +40,11 @@ type ProviderFactory func(ctx context.Context, modelSpec string, models map[stri
4340
// Client implements the Provider interface for rule-based model routing.
4441
type Client struct {
4542
base.Config
46-
routes []route
43+
routes []Provider
4744
fallback Provider
4845
index bleve.Index
4946
}
5047

51-
// route represents a single routing rule.
52-
type route struct {
53-
model string
54-
provider Provider
55-
}
56-
5748
// NewClient creates a new rule-based routing client.
5849
// The cfg parameter should have Routing rules configured. The provider/model
5950
// fields of cfg define the fallback model that is used when no routing rule matches.
@@ -69,11 +60,21 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, models map[string]l
6960
return nil, fmt.Errorf("creating bleve index: %w", err)
7061
}
7162

72-
// Create fallback provider from the model's provider/model fields
63+
// On any subsequent error, close the index before returning.
64+
var cleanupErr error
65+
defer func() {
66+
if cleanupErr != nil {
67+
_ = index.Close()
68+
}
69+
}()
70+
71+
routeOpts := filterOutMaxTokens(opts)
72+
73+
// Create fallback provider from the model's provider/model fields.
7374
fallbackSpec := cfg.Provider + "/" + cfg.Model
74-
fallback, err := providerFactory(ctx, fallbackSpec, models, env, filterOutMaxTokens(opts)...)
75+
fallback, err := providerFactory(ctx, fallbackSpec, models, env, routeOpts...)
7576
if err != nil {
76-
_ = index.Close()
77+
cleanupErr = err
7778
return nil, fmt.Errorf("creating fallback provider %q: %w", fallbackSpec, err)
7879
}
7980

@@ -87,27 +88,28 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, models map[string]l
8788
fallback: fallback,
8889
}
8990

90-
// Process routing rules
91+
// Process routing rules. Each example is indexed with a doc ID
92+
// that encodes the route index (e.g. "r0_e1") so we can map
93+
// search hits back to the corresponding provider.
9194
for i, rule := range cfg.Routing {
9295
if rule.Model == "" {
93-
_ = index.Close()
94-
return nil, fmt.Errorf("routing rule %d: 'model' field is required", i)
96+
cleanupErr = fmt.Errorf("routing rule %d: 'model' field is required", i)
97+
return nil, cleanupErr
9598
}
9699

97-
provider, err := providerFactory(ctx, rule.Model, models, env, filterOutMaxTokens(opts)...)
100+
provider, err := providerFactory(ctx, rule.Model, models, env, routeOpts...)
98101
if err != nil {
99-
_ = index.Close()
102+
cleanupErr = err
100103
return nil, fmt.Errorf("creating provider for routing rule %q: %w", rule.Model, err)
101104
}
102105

103106
routeIndex := len(client.routes)
104-
client.routes = append(client.routes, route{model: rule.Model, provider: provider})
107+
client.routes = append(client.routes, provider)
105108

106-
// Index examples for this route
107109
for j, example := range rule.Examples {
108110
docID := fmt.Sprintf("r%d_e%d", routeIndex, j)
109-
if err := index.Index(docID, map[string]any{"text": example, "route": routeIndex}); err != nil {
110-
_ = index.Close()
111+
if err := index.Index(docID, map[string]any{"text": example}); err != nil {
112+
cleanupErr = err
111113
return nil, fmt.Errorf("indexing example: %w", err)
112114
}
113115
}
@@ -124,27 +126,23 @@ func createIndex() (bleve.Index, error) {
124126
textField := mapping.NewTextFieldMapping()
125127
textField.Analyzer = "en"
126128
docMapping.AddFieldMappingsAt("text", textField)
127-
docMapping.AddFieldMappingsAt("route", mapping.NewNumericFieldMapping())
128129

129130
indexMapping.DefaultMapping = docMapping
130131

131132
return bleve.NewMemOnly(indexMapping)
132133
}
133134

134135
// filterOutMaxTokens removes WithMaxTokens options from the slice.
135-
// This is necessary because child providers may have different token limits
136-
// than the parent router, and should determine their own limits.
136+
// Child providers may have different token limits than the parent router.
137137
func filterOutMaxTokens(opts []options.Opt) []options.Opt {
138138
var filtered []options.Opt
139139
for _, opt := range opts {
140140
if opt == nil {
141141
continue
142142
}
143-
// Test if this option sets maxTokens by applying it to an empty ModelOptions
144-
var test options.ModelOptions
145-
opt(&test)
146-
// If maxTokens was set, skip this option
147-
if test.MaxTokens() != 0 {
143+
var probe options.ModelOptions
144+
opt(&probe)
145+
if probe.MaxTokens() != 0 {
148146
continue
149147
}
150148
filtered = append(filtered, opt)
@@ -173,6 +171,7 @@ func (c *Client) CreateChatCompletionStream(
173171
}
174172

175173
// selectProvider finds the best matching provider for the messages.
174+
// Bleve returns hits sorted by score, so the top hit determines the route.
176175
func (c *Client) selectProvider(messages []chat.Message) Provider {
177176
userMessage := getLastUserMessage(messages)
178177
if userMessage == "" {
@@ -183,8 +182,7 @@ func (c *Client) selectProvider(messages []chat.Message) Provider {
183182
query.SetField("text")
184183

185184
searchRequest := bleve.NewSearchRequest(query)
186-
searchRequest.Size = 10
187-
searchRequest.Fields = []string{"route"}
185+
searchRequest.Size = 1
188186

189187
results, err := c.index.Search(searchRequest)
190188
if err != nil {
@@ -196,41 +194,36 @@ func (c *Client) selectProvider(messages []chat.Message) Provider {
196194
return c.defaultProvider()
197195
}
198196

199-
// Find best matching route by aggregating scores
200-
scores := make(map[int]float64)
201-
for _, hit := range results.Hits {
202-
var routeIdx int
203-
if _, err := fmt.Sscanf(hit.ID, "r%d_e", &routeIdx); err == nil {
204-
if hit.Score > scores[routeIdx] {
205-
scores[routeIdx] = hit.Score
206-
}
207-
}
197+
// Parse the route index from the top hit's doc ID (e.g. "r2_e0" → 2).
198+
hit := results.Hits[0]
199+
routeIdx, ok := parseRouteIndex(hit.ID)
200+
if !ok || routeIdx >= len(c.routes) {
201+
return c.defaultProvider()
208202
}
209203

210-
bestRoute, bestScore := -1, 0.0
211-
for idx, score := range scores {
212-
if score > bestScore {
213-
bestRoute, bestScore = idx, score
214-
}
215-
}
204+
selected := c.routes[routeIdx]
205+
slog.Debug("Route matched",
206+
"model", selected.ID(),
207+
"score", hit.Score,
208+
)
209+
return selected
210+
}
216211

217-
if bestRoute >= 0 && bestRoute < len(c.routes) {
218-
slog.Debug("Route matched",
219-
"model", c.routes[bestRoute].model,
220-
"score", bestScore,
221-
)
222-
return c.routes[bestRoute].provider
212+
// parseRouteIndex extracts the route index from a doc ID like "r2_e0".
213+
func parseRouteIndex(docID string) (int, bool) {
214+
var idx int
215+
if _, err := fmt.Sscanf(docID, "r%d_e", &idx); err != nil || idx < 0 {
216+
return 0, false
223217
}
224-
225-
return c.defaultProvider()
218+
return idx, true
226219
}
227220

228221
func (c *Client) defaultProvider() Provider {
229222
if c.fallback != nil {
230223
return c.fallback
231224
}
232225
if len(c.routes) > 0 {
233-
return c.routes[0].provider
226+
return c.routes[0]
234227
}
235228
return nil
236229
}

pkg/model/provider/rulebased/client_test.go

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,9 @@ func (m *mockProvider) BaseConfig() base.Config {
4040
// mockProviderFactory creates a mock provider factory for testing.
4141
// It resolves model references from the models map or parses inline specs.
4242
func mockProviderFactory(_ context.Context, modelSpec string, models map[string]latest.ModelConfig, _ environment.Provider, _ ...options.Opt) (Provider, error) {
43-
// Check if it's a model reference
4443
if cfg, exists := models[modelSpec]; exists {
4544
return &mockProvider{id: cfg.Provider + "/" + cfg.Model}, nil
4645
}
47-
// Otherwise treat as inline spec
4846
return &mockProvider{id: modelSpec}, nil
4947
}
5048

@@ -62,7 +60,7 @@ func TestNewClient(t *testing.T) {
6260
name: "valid config with routing rules",
6361
modelCfg: latest.ModelConfig{
6462
Provider: "openai",
65-
Model: "gpt-4o", // fallback
63+
Model: "gpt-4o",
6664
Routing: []latest.RoutingRule{
6765
{
6866
Model: "anthropic/claude-3-haiku",
@@ -80,7 +78,7 @@ func TestNewClient(t *testing.T) {
8078
name: "routing with model references",
8179
modelCfg: latest.ModelConfig{
8280
Provider: "anthropic",
83-
Model: "claude-haiku-4-5", // fallback
81+
Model: "claude-haiku-4-5",
8482
Routing: []latest.RoutingRule{
8583
{
8684
Model: "fast",
@@ -183,7 +181,7 @@ func TestClient_SelectProvider(t *testing.T) {
183181

184182
cfg := &latest.ModelConfig{
185183
Provider: "openai",
186-
Model: "gpt-4o", // fallback
184+
Model: "gpt-4o",
187185
Routing: []latest.RoutingRule{
188186
{
189187
Model: "anthropic/claude-3-haiku",
@@ -262,11 +260,9 @@ func TestCreateIndex(t *testing.T) {
262260
require.NoError(t, err)
263261
defer index.Close()
264262

265-
// Index a document
266-
err = index.Index("test", map[string]any{"text": "hello world", "route": 0})
263+
err = index.Index("test", map[string]any{"text": "hello world"})
267264
require.NoError(t, err)
268265

269-
// Search for it
270266
query := bleve.NewMatchQuery("hello")
271267
query.SetField("text")
272268
results, err := index.Search(bleve.NewSearchRequest(query))
@@ -298,10 +294,9 @@ func TestClient_ID(t *testing.T) {
298294
func TestClient_DefaultProvider(t *testing.T) {
299295
t.Parallel()
300296

301-
// Test that fallback is always used for empty messages
302297
cfg := &latest.ModelConfig{
303298
Provider: "openai",
304-
Model: "gpt-4o", // fallback
299+
Model: "gpt-4o",
305300
Routing: []latest.RoutingRule{
306301
{
307302
Model: "anthropic/claude-3-haiku",
@@ -314,16 +309,13 @@ func TestClient_DefaultProvider(t *testing.T) {
314309
require.NoError(t, err)
315310
defer client.Close()
316311

317-
// Empty message should use fallback
318312
provider := client.selectProvider(nil)
319313
assert.Equal(t, "openai/gpt-4o", provider.ID())
320314
}
321315

322316
func TestClient_CreateChatCompletionStream_NilProvider(t *testing.T) {
323317
t.Parallel()
324318

325-
// Create a client with no routes and no fallback by directly manipulating the struct
326-
// This simulates an edge case where defaultProvider returns nil
327319
index, err := createIndex()
328320
require.NoError(t, err)
329321

@@ -335,7 +327,6 @@ func TestClient_CreateChatCompletionStream_NilProvider(t *testing.T) {
335327
}
336328
defer client.Close()
337329

338-
// Attempt to create stream should return error, not panic
339330
messages := []chat.Message{{Role: chat.MessageRoleUser, Content: "hello"}}
340331
_, err = client.CreateChatCompletionStream(t.Context(), messages, nil)
341332
require.Error(t, err)
@@ -348,8 +339,6 @@ func TestClient_ModelsMapStoredInBaseConfig(t *testing.T) {
348339
// This test verifies that the models map and env are stored in the base config.
349340
// This is required for CloneWithOptions to work correctly with routers
350341
// that use model references (e.g., "fast" instead of "anthropic/claude-haiku-4-5").
351-
// Without this, cloning a router would fail because model references can't be resolved
352-
// and the environment provider would be nil.
353342

354343
models := map[string]latest.ModelConfig{
355344
"fast": {Provider: "anthropic", Model: "claude-haiku-4-5"},
@@ -358,7 +347,7 @@ func TestClient_ModelsMapStoredInBaseConfig(t *testing.T) {
358347

359348
cfg := &latest.ModelConfig{
360349
Provider: "anthropic",
361-
Model: "claude-haiku-4-5", // fallback
350+
Model: "claude-haiku-4-5",
362351
Routing: []latest.RoutingRule{
363352
{
364353
Model: "fast",
@@ -371,17 +360,42 @@ func TestClient_ModelsMapStoredInBaseConfig(t *testing.T) {
371360
},
372361
}
373362

374-
// Create a mock env provider
375363
mockEnv := environment.NewNoEnvProvider()
376364

377365
client, err := NewClient(t.Context(), cfg, models, mockEnv, mockProviderFactory)
378366
require.NoError(t, err)
379367
defer client.Close()
380368

381-
// Verify the models map and env are stored in the base config
382369
baseConfig := client.BaseConfig()
383370
assert.NotNil(t, baseConfig.Models, "Models map should be stored in base config for cloning")
384371
assert.Equal(t, models, baseConfig.Models, "Models map should match what was passed to NewClient")
385372
assert.NotNil(t, baseConfig.Env, "Env should be stored in base config for cloning")
386373
assert.Equal(t, mockEnv, baseConfig.Env, "Env should match what was passed to NewClient")
387374
}
375+
376+
func TestParseRouteIndex(t *testing.T) {
377+
t.Parallel()
378+
379+
tests := []struct {
380+
docID string
381+
wantIdx int
382+
wantOK bool
383+
}{
384+
{"r0_e0", 0, true},
385+
{"r2_e5", 2, true},
386+
{"r10_e3", 10, true},
387+
{"invalid", 0, false},
388+
{"", 0, false},
389+
}
390+
391+
for _, tt := range tests {
392+
t.Run(tt.docID, func(t *testing.T) {
393+
t.Parallel()
394+
idx, ok := parseRouteIndex(tt.docID)
395+
assert.Equal(t, tt.wantOK, ok)
396+
if ok {
397+
assert.Equal(t, tt.wantIdx, idx)
398+
}
399+
})
400+
}
401+
}

0 commit comments

Comments
 (0)