Skip to content

Commit 753dffe

Browse files
committed
Thread context.Context through modelsdev store API
Replace the lazy-loaded closure-based Store with a simple struct that holds a cache file path and loads the database on first GetDatabase call. Pass caller-provided context to the HTTP request in fetchFromAPI instead of using context.Background(). NewStore() remains context-free (no I/O). All Store methods (GetDatabase, GetProvider, GetModel, ResolveModelAlias) and ModelSupportsReasoning accept context for the network call. Updated all callsites, using context.Background() only in TUI callbacks where no caller context is available. Assisted-By: cagent
1 parent 8f179d0 commit 753dffe

17 files changed

Lines changed: 71 additions & 67 deletions

File tree

pkg/config/examples_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ func TestParseExamples(t *testing.T) {
7070
continue
7171
}
7272

73-
model, err := modelsStore.GetModel(model.Provider + "/" + model.Model)
73+
model, err := modelsStore.GetModel(t.Context(), model.Provider+"/"+model.Model)
7474
require.NoError(t, err)
7575
require.NotNil(t, model)
7676
}

pkg/config/model_alias.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package config
22

33
import (
4+
"context"
45
"log/slog"
56
"strings"
67

@@ -16,7 +17,7 @@ import (
1617
// either set directly on the model or inherited from a custom provider definition.
1718
// This is necessary because external providers (like Azure Foundry) may use the alias
1819
// names directly as deployment names rather than the pinned version names.
19-
func ResolveModelAliases(cfg *latest.Config, store *modelsdev.Store) {
20+
func ResolveModelAliases(ctx context.Context, cfg *latest.Config, store *modelsdev.Store) {
2021
// Resolve model aliases in the models section
2122
for name, modelCfg := range cfg.Models {
2223
// Skip alias resolution for models with custom base_url (direct or via provider)
@@ -27,15 +28,15 @@ func ResolveModelAliases(cfg *latest.Config, store *modelsdev.Store) {
2728
continue
2829
}
2930

30-
if resolved := store.ResolveModelAlias(modelCfg.Provider, modelCfg.Model); resolved != modelCfg.Model {
31+
if resolved := store.ResolveModelAlias(ctx, modelCfg.Provider, modelCfg.Model); resolved != modelCfg.Model {
3132
modelCfg.Model = resolved
3233
cfg.Models[name] = modelCfg
3334
}
3435

3536
// Resolve model aliases in routing rules
3637
for i, rule := range modelCfg.Routing {
3738
if provider, model, ok := strings.Cut(rule.Model, "/"); ok {
38-
if resolved := store.ResolveModelAlias(provider, model); resolved != model {
39+
if resolved := store.ResolveModelAlias(ctx, provider, model); resolved != model {
3940
modelCfg.Routing[i].Model = provider + "/" + resolved
4041
}
4142
}
@@ -52,7 +53,7 @@ func ResolveModelAliases(cfg *latest.Config, store *modelsdev.Store) {
5253
var resolvedModels []string
5354
for modelRef := range strings.SplitSeq(agent.Model, ",") {
5455
if provider, model, ok := strings.Cut(modelRef, "/"); ok {
55-
if resolved := store.ResolveModelAlias(provider, model); resolved != model {
56+
if resolved := store.ResolveModelAlias(ctx, provider, model); resolved != model {
5657
resolvedModels = append(resolvedModels, provider+"/"+resolved)
5758
continue
5859
}

pkg/config/model_alias_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ func TestResolveModelAliases(t *testing.T) {
237237

238238
for _, tt := range tests {
239239
t.Run(tt.name, func(t *testing.T) {
240-
ResolveModelAliases(tt.cfg, store)
240+
ResolveModelAliases(t.Context(), tt.cfg, store)
241241
assert.Equal(t, tt.expected, tt.cfg)
242242
})
243243
}

pkg/model/provider/bedrock/client.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
112112

113113
// Detect prompt caching capability at init time for efficiency.
114114
// Uses models.dev cache pricing as proxy for capability detection.
115-
cachingSupported := detectCachingSupport(cfg.Model)
115+
cachingSupported := detectCachingSupport(ctx, cfg.Model)
116116

117117
slog.Debug("Bedrock client created successfully",
118118
"model", cfg.Model,
@@ -133,15 +133,15 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
133133
// detectCachingSupport checks if a model supports prompt caching using models.dev data.
134134
// Models with non-zero CacheRead/CacheWrite costs support prompt caching.
135135
// Returns false on lookup failure (safe default for unsupported models).
136-
func detectCachingSupport(model string) bool {
136+
func detectCachingSupport(ctx context.Context, model string) bool {
137137
store, err := modelsdev.NewStore()
138138
if err != nil {
139139
slog.Debug("Bedrock models store unavailable, prompt caching disabled", "error", err)
140140
return false
141141
}
142142

143143
modelID := "amazon-bedrock/" + model
144-
m, err := store.GetModel(modelID)
144+
m, err := store.GetModel(ctx, modelID)
145145
if err != nil {
146146
slog.Debug("Bedrock prompt caching disabled: model not found in models.dev",
147147
"model_id", modelID, "error", err)

pkg/model/provider/bedrock/client_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1249,23 +1249,23 @@ func TestDetectCachingSupport_SupportedModel(t *testing.T) {
12491249
t.Parallel()
12501250

12511251
// Uses real models.dev lookup to verify Claude models support caching
1252-
supported := detectCachingSupport("anthropic.claude-3-5-sonnet-20241022-v2:0")
1252+
supported := detectCachingSupport(t.Context(), "anthropic.claude-3-5-sonnet-20241022-v2:0")
12531253
assert.True(t, supported)
12541254
}
12551255

12561256
func TestDetectCachingSupport_UnsupportedModel(t *testing.T) {
12571257
t.Parallel()
12581258

12591259
// Llama doesn't have cache pricing in models.dev
1260-
supported := detectCachingSupport("meta.llama3-8b-instruct-v1:0")
1260+
supported := detectCachingSupport(t.Context(), "meta.llama3-8b-instruct-v1:0")
12611261
assert.False(t, supported)
12621262
}
12631263

12641264
func TestDetectCachingSupport_UnknownModel(t *testing.T) {
12651265
t.Parallel()
12661266

12671267
// Unknown model should gracefully return false, not panic
1268-
supported := detectCachingSupport("nonexistent.model.that.does.not.exist:v1")
1268+
supported := detectCachingSupport(t.Context(), "nonexistent.model.that.does.not.exist:v1")
12691269
assert.False(t, supported)
12701270
}
12711271

pkg/modelsdev/store.go

Lines changed: 35 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -21,25 +21,16 @@ const (
2121
)
2222

2323
// Store manages access to the models.dev data.
24-
// The database is loaded lazily on first access and cached for the
25-
// lifetime of the Store. All methods are safe for concurrent use.
24+
// All methods are safe for concurrent use.
2625
type Store struct {
27-
db func() (*Database, error)
26+
cacheFile string
27+
mu sync.Mutex
28+
db *Database
2829
}
2930

30-
// defaultStore is a cached singleton store instance for repeated access.
31-
var defaultStore = sync.OnceValues(newStoreInternal)
32-
33-
// NewStore returns the cached default store instance.
34-
// The underlying database is fetched lazily on first access
35-
// from a local cache file or the models.dev API.
31+
// NewStore creates a new models.dev store.
32+
// The database is loaded on first access via GetDatabase.
3633
func NewStore() (*Store, error) {
37-
return defaultStore()
38-
}
39-
40-
// newStoreInternal creates a new models.dev store that loads data
41-
// from the filesystem cache or the network on first access.
42-
func newStoreInternal() (*Store, error) {
4334
homeDir, err := os.UserHomeDir()
4435
if err != nil {
4536
return nil, fmt.Errorf("failed to get user home directory: %w", err)
@@ -50,12 +41,8 @@ func newStoreInternal() (*Store, error) {
5041
return nil, fmt.Errorf("failed to create cache directory: %w", err)
5142
}
5243

53-
cacheFile := filepath.Join(cacheDir, CacheFileName)
54-
5544
return &Store{
56-
db: sync.OnceValues(func() (*Database, error) {
57-
return loadDatabase(cacheFile)
58-
}),
45+
cacheFile: filepath.Join(cacheDir, CacheFileName),
5946
}, nil
6047
}
6148

@@ -64,19 +51,30 @@ func newStoreInternal() (*Store, error) {
6451
// from the network or touches the filesystem, making it suitable for
6552
// tests and any scenario where the provider data is already known.
6653
func NewDatabaseStore(db *Database) *Store {
67-
return &Store{
68-
db: func() (*Database, error) { return db, nil },
69-
}
54+
return &Store{db: db}
7055
}
7156

7257
// GetDatabase returns the models.dev database, fetching from cache or API as needed.
73-
func (s *Store) GetDatabase() (*Database, error) {
74-
return s.db()
58+
func (s *Store) GetDatabase(ctx context.Context) (*Database, error) {
59+
s.mu.Lock()
60+
defer s.mu.Unlock()
61+
62+
if s.db != nil {
63+
return s.db, nil
64+
}
65+
66+
db, err := loadDatabase(ctx, s.cacheFile)
67+
if err != nil {
68+
return nil, err
69+
}
70+
71+
s.db = db
72+
return db, nil
7573
}
7674

7775
// GetProvider returns a specific provider by ID.
78-
func (s *Store) GetProvider(providerID string) (*Provider, error) {
79-
db, err := s.db()
76+
func (s *Store) GetProvider(ctx context.Context, providerID string) (*Provider, error) {
77+
db, err := s.GetDatabase(ctx)
8078
if err != nil {
8179
return nil, err
8280
}
@@ -90,15 +88,15 @@ func (s *Store) GetProvider(providerID string) (*Provider, error) {
9088
}
9189

9290
// GetModel returns a specific model by provider ID and model ID.
93-
func (s *Store) GetModel(id string) (*Model, error) {
91+
func (s *Store) GetModel(ctx context.Context, id string) (*Model, error) {
9492
parts := strings.SplitN(id, "/", 2)
9593
if len(parts) != 2 {
9694
return nil, fmt.Errorf("invalid model ID: %q", id)
9795
}
9896
providerID := parts[0]
9997
modelID := parts[1]
10098

101-
provider, err := s.GetProvider(providerID)
99+
provider, err := s.GetProvider(ctx, providerID)
102100
if err != nil {
103101
return nil, err
104102
}
@@ -130,15 +128,15 @@ func (s *Store) GetModel(id string) (*Model, error) {
130128

131129
// loadDatabase loads the database from the local cache file or
132130
// falls back to fetching from the models.dev API.
133-
func loadDatabase(cacheFile string) (*Database, error) {
131+
func loadDatabase(ctx context.Context, cacheFile string) (*Database, error) {
134132
// Try to load from cache first
135133
cached, err := loadFromCache(cacheFile)
136134
if err == nil && time.Since(cached.LastRefresh) < refreshInterval {
137135
return &cached.Database, nil
138136
}
139137

140138
// Cache is invalid or doesn't exist, fetch from API
141-
database, fetchErr := fetchFromAPI()
139+
database, fetchErr := fetchFromAPI(ctx)
142140
if fetchErr != nil {
143141
// If API fetch fails, but we have cached data, use it
144142
if cached != nil {
@@ -156,8 +154,8 @@ func loadDatabase(cacheFile string) (*Database, error) {
156154
return database, nil
157155
}
158156

159-
func fetchFromAPI() (*Database, error) {
160-
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, ModelsDevAPIURL, http.NoBody)
157+
func fetchFromAPI(ctx context.Context) (*Database, error) {
158+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, ModelsDevAPIURL, http.NoBody)
161159
if err != nil {
162160
return nil, fmt.Errorf("failed to create request: %w", err)
163161
}
@@ -225,7 +223,7 @@ var datePattern = regexp.MustCompile(`-\d{4}-?\d{2}-?\d{2}$`)
225223
// For example, ("anthropic", "claude-sonnet-4-5") might resolve to "claude-sonnet-4-5-20250929".
226224
// If the model is not an alias (already pinned or unknown), the original model name is returned.
227225
// This method uses the models.dev database to find the corresponding pinned version.
228-
func (s *Store) ResolveModelAlias(providerID, modelName string) string {
226+
func (s *Store) ResolveModelAlias(ctx context.Context, providerID, modelName string) string {
229227
if providerID == "" || modelName == "" {
230228
return modelName
231229
}
@@ -236,7 +234,7 @@ func (s *Store) ResolveModelAlias(providerID, modelName string) string {
236234
}
237235

238236
// Get the provider from the database
239-
provider, err := s.GetProvider(providerID)
237+
provider, err := s.GetProvider(ctx, providerID)
240238
if err != nil {
241239
return modelName
242240
}
@@ -285,7 +283,7 @@ func isBedrockRegionPrefix(prefix string) bool {
285283
// - If modelID is empty or not in "provider/model" format, returns true (fail-open)
286284
// - If models.dev lookup fails for any reason, returns true (fail-open)
287285
// - If lookup succeeds, returns the model's Reasoning field value
288-
func ModelSupportsReasoning(modelID string) bool {
286+
func ModelSupportsReasoning(ctx context.Context, modelID string) bool {
289287
// Fail-open for empty model ID
290288
if modelID == "" {
291289
return true
@@ -303,7 +301,7 @@ func ModelSupportsReasoning(modelID string) bool {
303301
return true
304302
}
305303

306-
model, err := store.GetModel(modelID)
304+
model, err := store.GetModel(ctx, modelID)
307305
if err != nil {
308306
slog.Debug("Failed to lookup model in models.dev, assuming reasoning supported to allow user choice", "model_id", modelID, "error", err)
309307
return true

pkg/modelsdev/store_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ func TestResolveModelAlias(t *testing.T) {
5757

5858
for _, tt := range tests {
5959
t.Run(tt.name, func(t *testing.T) {
60-
result := store.ResolveModelAlias(tt.provider, tt.model)
60+
result := store.ResolveModelAlias(t.Context(), tt.provider, tt.model)
6161
assert.Equal(t, tt.expected, result)
6262
})
6363
}

pkg/rag/strategy/semantic_embeddings.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ func calculateSemanticUsageCost(modelsStore modelStore, modelID string, usage *c
506506
return 0
507507
}
508508

509-
model, err := modelsStore.GetModel(modelID)
509+
model, err := modelsStore.GetModel(context.Background(), modelID)
510510
if err != nil {
511511
slog.Debug("Failed to get semantic model pricing from models.dev, cost will be 0",
512512
"model_id", modelID,

pkg/rag/strategy/vector_store.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ type VectorStore struct {
8787
}
8888

8989
type modelStore interface {
90-
GetModel(modelID string) (*modelsdev.Model, error)
90+
GetModel(ctx context.Context, modelID string) (*modelsdev.Model, error)
9191
}
9292

9393
// EmbeddingInputBuilder builds the string that will be sent to the embedding model
@@ -174,7 +174,7 @@ func (s *VectorStore) calculateCost(tokens int64) float64 {
174174
return 0
175175
}
176176

177-
model, err := s.modelsStore.GetModel(s.modelID)
177+
model, err := s.modelsStore.GetModel(context.Background(), s.modelID)
178178
if err != nil {
179179
slog.Debug("Failed to get model pricing from models.dev, cost will be 0",
180180
"model_id", s.modelID,

pkg/runtime/model_switcher.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ func (r *LocalRuntime) AvailableModels(ctx context.Context) []ModelChoice {
286286
// buildCatalogChoices builds ModelChoice entries from the models.dev catalog,
287287
// filtered by supported providers and available credentials.
288288
func (r *LocalRuntime) buildCatalogChoices(ctx context.Context) []ModelChoice {
289-
db, err := r.modelsStore.GetDatabase()
289+
db, err := r.modelsStore.GetDatabase(ctx)
290290
if err != nil {
291291
slog.Debug("Failed to get models.dev database for catalog", "error", err)
292292
return nil
@@ -446,7 +446,7 @@ func (r *LocalRuntime) createProviderFromConfig(ctx context.Context, cfg *latest
446446
if cfg.MaxTokens != nil {
447447
opts = append(opts, options.WithMaxTokens(*cfg.MaxTokens))
448448
} else if r.modelsStore != nil {
449-
m, err := r.modelsStore.GetModel(cfg.Provider + "/" + cfg.Model)
449+
m, err := r.modelsStore.GetModel(ctx, cfg.Provider+"/"+cfg.Model)
450450
if err == nil && m != nil {
451451
opts = append(opts, options.WithMaxTokens(m.Limit.Output))
452452
}

0 commit comments

Comments
 (0)