Skip to content

Commit 75d32dc

Browse files
authored
Merge pull request #1747 from rumpl/context
Thread context.Context through modelsdev store API
2 parents 8f179d0 + 753dffe commit 75d32dc

17 files changed

Lines changed: 71 additions & 67 deletions

File tree

pkg/config/examples_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ func TestParseExamples(t *testing.T) {
7070
continue
7171
}
7272

73-
model, err := modelsStore.GetModel(model.Provider + "/" + model.Model)
73+
model, err := modelsStore.GetModel(t.Context(), model.Provider+"/"+model.Model)
7474
require.NoError(t, err)
7575
require.NotNil(t, model)
7676
}

pkg/config/model_alias.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package config
22

33
import (
4+
"context"
45
"log/slog"
56
"strings"
67

@@ -16,7 +17,7 @@ import (
1617
// either set directly on the model or inherited from a custom provider definition.
1718
// This is necessary because external providers (like Azure Foundry) may use the alias
1819
// names directly as deployment names rather than the pinned version names.
19-
func ResolveModelAliases(cfg *latest.Config, store *modelsdev.Store) {
20+
func ResolveModelAliases(ctx context.Context, cfg *latest.Config, store *modelsdev.Store) {
2021
// Resolve model aliases in the models section
2122
for name, modelCfg := range cfg.Models {
2223
// Skip alias resolution for models with custom base_url (direct or via provider)
@@ -27,15 +28,15 @@ func ResolveModelAliases(cfg *latest.Config, store *modelsdev.Store) {
2728
continue
2829
}
2930

30-
if resolved := store.ResolveModelAlias(modelCfg.Provider, modelCfg.Model); resolved != modelCfg.Model {
31+
if resolved := store.ResolveModelAlias(ctx, modelCfg.Provider, modelCfg.Model); resolved != modelCfg.Model {
3132
modelCfg.Model = resolved
3233
cfg.Models[name] = modelCfg
3334
}
3435

3536
// Resolve model aliases in routing rules
3637
for i, rule := range modelCfg.Routing {
3738
if provider, model, ok := strings.Cut(rule.Model, "/"); ok {
38-
if resolved := store.ResolveModelAlias(provider, model); resolved != model {
39+
if resolved := store.ResolveModelAlias(ctx, provider, model); resolved != model {
3940
modelCfg.Routing[i].Model = provider + "/" + resolved
4041
}
4142
}
@@ -52,7 +53,7 @@ func ResolveModelAliases(cfg *latest.Config, store *modelsdev.Store) {
5253
var resolvedModels []string
5354
for modelRef := range strings.SplitSeq(agent.Model, ",") {
5455
if provider, model, ok := strings.Cut(modelRef, "/"); ok {
55-
if resolved := store.ResolveModelAlias(provider, model); resolved != model {
56+
if resolved := store.ResolveModelAlias(ctx, provider, model); resolved != model {
5657
resolvedModels = append(resolvedModels, provider+"/"+resolved)
5758
continue
5859
}

pkg/config/model_alias_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ func TestResolveModelAliases(t *testing.T) {
237237

238238
for _, tt := range tests {
239239
t.Run(tt.name, func(t *testing.T) {
240-
ResolveModelAliases(tt.cfg, store)
240+
ResolveModelAliases(t.Context(), tt.cfg, store)
241241
assert.Equal(t, tt.expected, tt.cfg)
242242
})
243243
}

pkg/model/provider/bedrock/client.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
112112

113113
// Detect prompt caching capability at init time for efficiency.
114114
// Uses models.dev cache pricing as proxy for capability detection.
115-
cachingSupported := detectCachingSupport(cfg.Model)
115+
cachingSupported := detectCachingSupport(ctx, cfg.Model)
116116

117117
slog.Debug("Bedrock client created successfully",
118118
"model", cfg.Model,
@@ -133,15 +133,15 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
133133
// detectCachingSupport checks if a model supports prompt caching using models.dev data.
134134
// Models with non-zero CacheRead/CacheWrite costs support prompt caching.
135135
// Returns false on lookup failure (safe default for unsupported models).
136-
func detectCachingSupport(model string) bool {
136+
func detectCachingSupport(ctx context.Context, model string) bool {
137137
store, err := modelsdev.NewStore()
138138
if err != nil {
139139
slog.Debug("Bedrock models store unavailable, prompt caching disabled", "error", err)
140140
return false
141141
}
142142

143143
modelID := "amazon-bedrock/" + model
144-
m, err := store.GetModel(modelID)
144+
m, err := store.GetModel(ctx, modelID)
145145
if err != nil {
146146
slog.Debug("Bedrock prompt caching disabled: model not found in models.dev",
147147
"model_id", modelID, "error", err)

pkg/model/provider/bedrock/client_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1249,23 +1249,23 @@ func TestDetectCachingSupport_SupportedModel(t *testing.T) {
12491249
t.Parallel()
12501250

12511251
// Uses real models.dev lookup to verify Claude models support caching
1252-
supported := detectCachingSupport("anthropic.claude-3-5-sonnet-20241022-v2:0")
1252+
supported := detectCachingSupport(t.Context(), "anthropic.claude-3-5-sonnet-20241022-v2:0")
12531253
assert.True(t, supported)
12541254
}
12551255

12561256
func TestDetectCachingSupport_UnsupportedModel(t *testing.T) {
12571257
t.Parallel()
12581258

12591259
// Llama doesn't have cache pricing in models.dev
1260-
supported := detectCachingSupport("meta.llama3-8b-instruct-v1:0")
1260+
supported := detectCachingSupport(t.Context(), "meta.llama3-8b-instruct-v1:0")
12611261
assert.False(t, supported)
12621262
}
12631263

12641264
func TestDetectCachingSupport_UnknownModel(t *testing.T) {
12651265
t.Parallel()
12661266

12671267
// Unknown model should gracefully return false, not panic
1268-
supported := detectCachingSupport("nonexistent.model.that.does.not.exist:v1")
1268+
supported := detectCachingSupport(t.Context(), "nonexistent.model.that.does.not.exist:v1")
12691269
assert.False(t, supported)
12701270
}
12711271

pkg/modelsdev/store.go

Lines changed: 35 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -21,25 +21,16 @@ const (
2121
)
2222

2323
// Store manages access to the models.dev data.
24-
// The database is loaded lazily on first access and cached for the
25-
// lifetime of the Store. All methods are safe for concurrent use.
24+
// All methods are safe for concurrent use.
2625
type Store struct {
27-
db func() (*Database, error)
26+
cacheFile string
27+
mu sync.Mutex
28+
db *Database
2829
}
2930

30-
// defaultStore is a cached singleton store instance for repeated access.
31-
var defaultStore = sync.OnceValues(newStoreInternal)
32-
33-
// NewStore returns the cached default store instance.
34-
// The underlying database is fetched lazily on first access
35-
// from a local cache file or the models.dev API.
31+
// NewStore creates a new models.dev store.
32+
// The database is loaded on first access via GetDatabase.
3633
func NewStore() (*Store, error) {
37-
return defaultStore()
38-
}
39-
40-
// newStoreInternal creates a new models.dev store that loads data
41-
// from the filesystem cache or the network on first access.
42-
func newStoreInternal() (*Store, error) {
4334
homeDir, err := os.UserHomeDir()
4435
if err != nil {
4536
return nil, fmt.Errorf("failed to get user home directory: %w", err)
@@ -50,12 +41,8 @@ func newStoreInternal() (*Store, error) {
5041
return nil, fmt.Errorf("failed to create cache directory: %w", err)
5142
}
5243

53-
cacheFile := filepath.Join(cacheDir, CacheFileName)
54-
5544
return &Store{
56-
db: sync.OnceValues(func() (*Database, error) {
57-
return loadDatabase(cacheFile)
58-
}),
45+
cacheFile: filepath.Join(cacheDir, CacheFileName),
5946
}, nil
6047
}
6148

@@ -64,19 +51,30 @@ func newStoreInternal() (*Store, error) {
6451
// from the network or touches the filesystem, making it suitable for
6552
// tests and any scenario where the provider data is already known.
6653
func NewDatabaseStore(db *Database) *Store {
67-
return &Store{
68-
db: func() (*Database, error) { return db, nil },
69-
}
54+
return &Store{db: db}
7055
}
7156

7257
// GetDatabase returns the models.dev database, fetching from cache or API as needed.
73-
func (s *Store) GetDatabase() (*Database, error) {
74-
return s.db()
58+
func (s *Store) GetDatabase(ctx context.Context) (*Database, error) {
59+
s.mu.Lock()
60+
defer s.mu.Unlock()
61+
62+
if s.db != nil {
63+
return s.db, nil
64+
}
65+
66+
db, err := loadDatabase(ctx, s.cacheFile)
67+
if err != nil {
68+
return nil, err
69+
}
70+
71+
s.db = db
72+
return db, nil
7573
}
7674

7775
// GetProvider returns a specific provider by ID.
78-
func (s *Store) GetProvider(providerID string) (*Provider, error) {
79-
db, err := s.db()
76+
func (s *Store) GetProvider(ctx context.Context, providerID string) (*Provider, error) {
77+
db, err := s.GetDatabase(ctx)
8078
if err != nil {
8179
return nil, err
8280
}
@@ -90,15 +88,15 @@ func (s *Store) GetProvider(providerID string) (*Provider, error) {
9088
}
9189

9290
// GetModel returns a specific model by provider ID and model ID.
93-
func (s *Store) GetModel(id string) (*Model, error) {
91+
func (s *Store) GetModel(ctx context.Context, id string) (*Model, error) {
9492
parts := strings.SplitN(id, "/", 2)
9593
if len(parts) != 2 {
9694
return nil, fmt.Errorf("invalid model ID: %q", id)
9795
}
9896
providerID := parts[0]
9997
modelID := parts[1]
10098

101-
provider, err := s.GetProvider(providerID)
99+
provider, err := s.GetProvider(ctx, providerID)
102100
if err != nil {
103101
return nil, err
104102
}
@@ -130,15 +128,15 @@ func (s *Store) GetModel(id string) (*Model, error) {
130128

131129
// loadDatabase loads the database from the local cache file or
132130
// falls back to fetching from the models.dev API.
133-
func loadDatabase(cacheFile string) (*Database, error) {
131+
func loadDatabase(ctx context.Context, cacheFile string) (*Database, error) {
134132
// Try to load from cache first
135133
cached, err := loadFromCache(cacheFile)
136134
if err == nil && time.Since(cached.LastRefresh) < refreshInterval {
137135
return &cached.Database, nil
138136
}
139137

140138
// Cache is invalid or doesn't exist, fetch from API
141-
database, fetchErr := fetchFromAPI()
139+
database, fetchErr := fetchFromAPI(ctx)
142140
if fetchErr != nil {
143141
// If API fetch fails, but we have cached data, use it
144142
if cached != nil {
@@ -156,8 +154,8 @@ func loadDatabase(cacheFile string) (*Database, error) {
156154
return database, nil
157155
}
158156

159-
func fetchFromAPI() (*Database, error) {
160-
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, ModelsDevAPIURL, http.NoBody)
157+
func fetchFromAPI(ctx context.Context) (*Database, error) {
158+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, ModelsDevAPIURL, http.NoBody)
161159
if err != nil {
162160
return nil, fmt.Errorf("failed to create request: %w", err)
163161
}
@@ -225,7 +223,7 @@ var datePattern = regexp.MustCompile(`-\d{4}-?\d{2}-?\d{2}$`)
225223
// For example, ("anthropic", "claude-sonnet-4-5") might resolve to "claude-sonnet-4-5-20250929".
226224
// If the model is not an alias (already pinned or unknown), the original model name is returned.
227225
// This method uses the models.dev database to find the corresponding pinned version.
228-
func (s *Store) ResolveModelAlias(providerID, modelName string) string {
226+
func (s *Store) ResolveModelAlias(ctx context.Context, providerID, modelName string) string {
229227
if providerID == "" || modelName == "" {
230228
return modelName
231229
}
@@ -236,7 +234,7 @@ func (s *Store) ResolveModelAlias(providerID, modelName string) string {
236234
}
237235

238236
// Get the provider from the database
239-
provider, err := s.GetProvider(providerID)
237+
provider, err := s.GetProvider(ctx, providerID)
240238
if err != nil {
241239
return modelName
242240
}
@@ -285,7 +283,7 @@ func isBedrockRegionPrefix(prefix string) bool {
285283
// - If modelID is empty or not in "provider/model" format, returns true (fail-open)
286284
// - If models.dev lookup fails for any reason, returns true (fail-open)
287285
// - If lookup succeeds, returns the model's Reasoning field value
288-
func ModelSupportsReasoning(modelID string) bool {
286+
func ModelSupportsReasoning(ctx context.Context, modelID string) bool {
289287
// Fail-open for empty model ID
290288
if modelID == "" {
291289
return true
@@ -303,7 +301,7 @@ func ModelSupportsReasoning(modelID string) bool {
303301
return true
304302
}
305303

306-
model, err := store.GetModel(modelID)
304+
model, err := store.GetModel(ctx, modelID)
307305
if err != nil {
308306
slog.Debug("Failed to lookup model in models.dev, assuming reasoning supported to allow user choice", "model_id", modelID, "error", err)
309307
return true

pkg/modelsdev/store_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ func TestResolveModelAlias(t *testing.T) {
5757

5858
for _, tt := range tests {
5959
t.Run(tt.name, func(t *testing.T) {
60-
result := store.ResolveModelAlias(tt.provider, tt.model)
60+
result := store.ResolveModelAlias(t.Context(), tt.provider, tt.model)
6161
assert.Equal(t, tt.expected, result)
6262
})
6363
}

pkg/rag/strategy/semantic_embeddings.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ func calculateSemanticUsageCost(modelsStore modelStore, modelID string, usage *c
506506
return 0
507507
}
508508

509-
model, err := modelsStore.GetModel(modelID)
509+
model, err := modelsStore.GetModel(context.Background(), modelID)
510510
if err != nil {
511511
slog.Debug("Failed to get semantic model pricing from models.dev, cost will be 0",
512512
"model_id", modelID,

pkg/rag/strategy/vector_store.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ type VectorStore struct {
8787
}
8888

8989
type modelStore interface {
90-
GetModel(modelID string) (*modelsdev.Model, error)
90+
GetModel(ctx context.Context, modelID string) (*modelsdev.Model, error)
9191
}
9292

9393
// EmbeddingInputBuilder builds the string that will be sent to the embedding model
@@ -174,7 +174,7 @@ func (s *VectorStore) calculateCost(tokens int64) float64 {
174174
return 0
175175
}
176176

177-
model, err := s.modelsStore.GetModel(s.modelID)
177+
model, err := s.modelsStore.GetModel(context.Background(), s.modelID)
178178
if err != nil {
179179
slog.Debug("Failed to get model pricing from models.dev, cost will be 0",
180180
"model_id", s.modelID,

pkg/runtime/model_switcher.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ func (r *LocalRuntime) AvailableModels(ctx context.Context) []ModelChoice {
286286
// buildCatalogChoices builds ModelChoice entries from the models.dev catalog,
287287
// filtered by supported providers and available credentials.
288288
func (r *LocalRuntime) buildCatalogChoices(ctx context.Context) []ModelChoice {
289-
db, err := r.modelsStore.GetDatabase()
289+
db, err := r.modelsStore.GetDatabase(ctx)
290290
if err != nil {
291291
slog.Debug("Failed to get models.dev database for catalog", "error", err)
292292
return nil
@@ -446,7 +446,7 @@ func (r *LocalRuntime) createProviderFromConfig(ctx context.Context, cfg *latest
446446
if cfg.MaxTokens != nil {
447447
opts = append(opts, options.WithMaxTokens(*cfg.MaxTokens))
448448
} else if r.modelsStore != nil {
449-
m, err := r.modelsStore.GetModel(cfg.Provider + "/" + cfg.Model)
449+
m, err := r.modelsStore.GetModel(ctx, cfg.Provider+"/"+cfg.Model)
450450
if err == nil && m != nil {
451451
opts = append(opts, options.WithMaxTokens(m.Limit.Output))
452452
}

0 commit comments

Comments
 (0)