11// Package rulebased provides a rule-based model router that selects
2- // the appropriate model based on NLP analysis of the input using Bleve.
3- //
4- // Routes are defined with example texts, and Bleve's full-text search
5- // determines the best matching route based on text similarity.
2+ // the appropriate model based on text similarity using Bleve full-text search.
63//
74// A model becomes a rule-based router when it has routing rules configured.
85// The model's provider/model fields define the fallback model, and each
@@ -43,17 +40,11 @@ type ProviderFactory func(ctx context.Context, modelSpec string, models map[stri
4340// Client implements the Provider interface for rule-based model routing.
4441type Client struct {
4542 base.Config
46- routes []route
43+ routes []Provider
4744 fallback Provider
4845 index bleve.Index
4946}
5047
51- // route represents a single routing rule.
52- type route struct {
53- model string
54- provider Provider
55- }
56-
5748// NewClient creates a new rule-based routing client.
5849// The cfg parameter should have Routing rules configured. The provider/model
5950// fields of cfg define the fallback model that is used when no routing rule matches.
@@ -69,11 +60,21 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, models map[string]l
6960 return nil , fmt .Errorf ("creating bleve index: %w" , err )
7061 }
7162
72- // Create fallback provider from the model's provider/model fields
63+ // On any subsequent error, close the index before returning.
64+ var cleanupErr error
65+ defer func () {
66+ if cleanupErr != nil {
67+ _ = index .Close ()
68+ }
69+ }()
70+
71+ routeOpts := filterOutMaxTokens (opts )
72+
73+ // Create fallback provider from the model's provider/model fields.
7374 fallbackSpec := cfg .Provider + "/" + cfg .Model
74- fallback , err := providerFactory (ctx , fallbackSpec , models , env , filterOutMaxTokens ( opts ) ... )
75+ fallback , err := providerFactory (ctx , fallbackSpec , models , env , routeOpts ... )
7576 if err != nil {
76- _ = index . Close ()
77+ cleanupErr = err
7778 return nil , fmt .Errorf ("creating fallback provider %q: %w" , fallbackSpec , err )
7879 }
7980
@@ -87,27 +88,28 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, models map[string]l
8788 fallback : fallback ,
8889 }
8990
90- // Process routing rules
91+ // Process routing rules. Each example is indexed with a doc ID
92+ // that encodes the route index (e.g. "r0_e1") so we can map
93+ // search hits back to the corresponding provider.
9194 for i , rule := range cfg .Routing {
9295 if rule .Model == "" {
93- _ = index . Close ( )
94- return nil , fmt . Errorf ( "routing rule %d: 'model' field is required" , i )
96+ cleanupErr = fmt . Errorf ( "routing rule %d: 'model' field is required" , i )
97+ return nil , cleanupErr
9598 }
9699
97- provider , err := providerFactory (ctx , rule .Model , models , env , filterOutMaxTokens ( opts ) ... )
100+ provider , err := providerFactory (ctx , rule .Model , models , env , routeOpts ... )
98101 if err != nil {
99- _ = index . Close ()
102+ cleanupErr = err
100103 return nil , fmt .Errorf ("creating provider for routing rule %q: %w" , rule .Model , err )
101104 }
102105
103106 routeIndex := len (client .routes )
104- client .routes = append (client .routes , route { model : rule . Model , provider : provider } )
107+ client .routes = append (client .routes , provider )
105108
106- // Index examples for this route
107109 for j , example := range rule .Examples {
108110 docID := fmt .Sprintf ("r%d_e%d" , routeIndex , j )
109- if err := index .Index (docID , map [string ]any {"text" : example , "route" : routeIndex }); err != nil {
110- _ = index . Close ()
111+ if err := index .Index (docID , map [string ]any {"text" : example }); err != nil {
112+ cleanupErr = err
111113 return nil , fmt .Errorf ("indexing example: %w" , err )
112114 }
113115 }
@@ -124,27 +126,23 @@ func createIndex() (bleve.Index, error) {
124126 textField := mapping .NewTextFieldMapping ()
125127 textField .Analyzer = "en"
126128 docMapping .AddFieldMappingsAt ("text" , textField )
127- docMapping .AddFieldMappingsAt ("route" , mapping .NewNumericFieldMapping ())
128129
129130 indexMapping .DefaultMapping = docMapping
130131
131132 return bleve .NewMemOnly (indexMapping )
132133}
133134
134135// filterOutMaxTokens removes WithMaxTokens options from the slice.
135- // This is necessary because child providers may have different token limits
136- // than the parent router, and should determine their own limits.
136+ // Child providers may have different token limits than the parent router.
137137func filterOutMaxTokens (opts []options.Opt ) []options.Opt {
138138 var filtered []options.Opt
139139 for _ , opt := range opts {
140140 if opt == nil {
141141 continue
142142 }
143- // Test if this option sets maxTokens by applying it to an empty ModelOptions
144- var test options.ModelOptions
145- opt (& test )
146- // If maxTokens was set, skip this option
147- if test .MaxTokens () != 0 {
143+ var probe options.ModelOptions
144+ opt (& probe )
145+ if probe .MaxTokens () != 0 {
148146 continue
149147 }
150148 filtered = append (filtered , opt )
@@ -173,6 +171,7 @@ func (c *Client) CreateChatCompletionStream(
173171}
174172
175173// selectProvider finds the best matching provider for the messages.
174+ // Bleve returns hits sorted by score, so the top hit determines the route.
176175func (c * Client ) selectProvider (messages []chat.Message ) Provider {
177176 userMessage := getLastUserMessage (messages )
178177 if userMessage == "" {
@@ -183,8 +182,7 @@ func (c *Client) selectProvider(messages []chat.Message) Provider {
183182 query .SetField ("text" )
184183
185184 searchRequest := bleve .NewSearchRequest (query )
186- searchRequest .Size = 10
187- searchRequest .Fields = []string {"route" }
185+ searchRequest .Size = 1
188186
189187 results , err := c .index .Search (searchRequest )
190188 if err != nil {
@@ -196,41 +194,36 @@ func (c *Client) selectProvider(messages []chat.Message) Provider {
196194 return c .defaultProvider ()
197195 }
198196
199- // Find best matching route by aggregating scores
200- scores := make (map [int ]float64 )
201- for _ , hit := range results .Hits {
202- var routeIdx int
203- if _ , err := fmt .Sscanf (hit .ID , "r%d_e" , & routeIdx ); err == nil {
204- if hit .Score > scores [routeIdx ] {
205- scores [routeIdx ] = hit .Score
206- }
207- }
197+ // Parse the route index from the top hit's doc ID (e.g. "r2_e0" → 2).
198+ hit := results .Hits [0 ]
199+ routeIdx , ok := parseRouteIndex (hit .ID )
200+ if ! ok || routeIdx >= len (c .routes ) {
201+ return c .defaultProvider ()
208202 }
209203
210- bestRoute , bestScore := - 1 , 0.0
211- for idx , score := range scores {
212- if score > bestScore {
213- bestRoute , bestScore = idx , score
214- }
215- }
204+ selected := c .routes [routeIdx ]
205+ slog .Debug ("Route matched" ,
206+ "model" , selected .ID (),
207+ "score" , hit .Score ,
208+ )
209+ return selected
210+ }
216211
217- if bestRoute >= 0 && bestRoute < len (c .routes ) {
218- slog .Debug ("Route matched" ,
219- "model" , c .routes [bestRoute ].model ,
220- "score" , bestScore ,
221- )
222- return c .routes [bestRoute ].provider
212+ // parseRouteIndex extracts the route index from a doc ID like "r2_e0".
213+ func parseRouteIndex (docID string ) (int , bool ) {
214+ var idx int
215+ if _ , err := fmt .Sscanf (docID , "r%d_e" , & idx ); err != nil || idx < 0 {
216+ return 0 , false
223217 }
224-
225- return c .defaultProvider ()
218+ return idx , true
226219}
227220
228221func (c * Client ) defaultProvider () Provider {
229222 if c .fallback != nil {
230223 return c .fallback
231224 }
232225 if len (c .routes ) > 0 {
233- return c .routes [0 ]. provider
226+ return c .routes [0 ]
234227 }
235228 return nil
236229}
0 commit comments