@@ -21,25 +21,16 @@ const (
2121)
2222
2323// Store manages access to the models.dev data.
24- // The database is loaded lazily on first access and cached for the
25- // lifetime of the Store. All methods are safe for concurrent use.
24+ // All methods are safe for concurrent use.
2625type Store struct {
27- db func () (* Database , error )
26+ cacheFile string
27+ mu sync.Mutex
28+ db * Database
2829}
2930
30- // defaultStore is a cached singleton store instance for repeated access.
31- var defaultStore = sync .OnceValues (newStoreInternal )
32-
33- // NewStore returns the cached default store instance.
34- // The underlying database is fetched lazily on first access
35- // from a local cache file or the models.dev API.
31+ // NewStore creates a new models.dev store.
32+ // The database is loaded on first access via GetDatabase.
3633func NewStore () (* Store , error ) {
37- return defaultStore ()
38- }
39-
40- // newStoreInternal creates a new models.dev store that loads data
41- // from the filesystem cache or the network on first access.
42- func newStoreInternal () (* Store , error ) {
4334 homeDir , err := os .UserHomeDir ()
4435 if err != nil {
4536 return nil , fmt .Errorf ("failed to get user home directory: %w" , err )
@@ -50,12 +41,8 @@ func newStoreInternal() (*Store, error) {
5041 return nil , fmt .Errorf ("failed to create cache directory: %w" , err )
5142 }
5243
53- cacheFile := filepath .Join (cacheDir , CacheFileName )
54-
5544 return & Store {
56- db : sync .OnceValues (func () (* Database , error ) {
57- return loadDatabase (cacheFile )
58- }),
45+ cacheFile : filepath .Join (cacheDir , CacheFileName ),
5946 }, nil
6047}
6148
@@ -64,19 +51,30 @@ func newStoreInternal() (*Store, error) {
6451// from the network or touches the filesystem, making it suitable for
6552// tests and any scenario where the provider data is already known.
6653func NewDatabaseStore (db * Database ) * Store {
67- return & Store {
68- db : func () (* Database , error ) { return db , nil },
69- }
54+ return & Store {db : db }
7055}
7156
7257// GetDatabase returns the models.dev database, fetching from cache or API as needed.
73- func (s * Store ) GetDatabase () (* Database , error ) {
74- return s .db ()
58+ func (s * Store ) GetDatabase (ctx context.Context ) (* Database , error ) {
59+ s .mu .Lock ()
60+ defer s .mu .Unlock ()
61+
62+ if s .db != nil {
63+ return s .db , nil
64+ }
65+
66+ db , err := loadDatabase (ctx , s .cacheFile )
67+ if err != nil {
68+ return nil , err
69+ }
70+
71+ s .db = db
72+ return db , nil
7573}
7674
7775// GetProvider returns a specific provider by ID.
78- func (s * Store ) GetProvider (providerID string ) (* Provider , error ) {
79- db , err := s .db ( )
76+ func (s * Store ) GetProvider (ctx context. Context , providerID string ) (* Provider , error ) {
77+ db , err := s .GetDatabase ( ctx )
8078 if err != nil {
8179 return nil , err
8280 }
@@ -90,15 +88,15 @@ func (s *Store) GetProvider(providerID string) (*Provider, error) {
9088}
9189
9290// GetModel returns a specific model by provider ID and model ID.
93- func (s * Store ) GetModel (id string ) (* Model , error ) {
91+ func (s * Store ) GetModel (ctx context. Context , id string ) (* Model , error ) {
9492 parts := strings .SplitN (id , "/" , 2 )
9593 if len (parts ) != 2 {
9694 return nil , fmt .Errorf ("invalid model ID: %q" , id )
9795 }
9896 providerID := parts [0 ]
9997 modelID := parts [1 ]
10098
101- provider , err := s .GetProvider (providerID )
99+ provider , err := s .GetProvider (ctx , providerID )
102100 if err != nil {
103101 return nil , err
104102 }
@@ -130,15 +128,15 @@ func (s *Store) GetModel(id string) (*Model, error) {
130128
131129// loadDatabase loads the database from the local cache file or
132130// falls back to fetching from the models.dev API.
133- func loadDatabase (cacheFile string ) (* Database , error ) {
131+ func loadDatabase (ctx context. Context , cacheFile string ) (* Database , error ) {
134132 // Try to load from cache first
135133 cached , err := loadFromCache (cacheFile )
136134 if err == nil && time .Since (cached .LastRefresh ) < refreshInterval {
137135 return & cached .Database , nil
138136 }
139137
140138 // Cache is invalid or doesn't exist, fetch from API
141- database , fetchErr := fetchFromAPI ()
139+ database , fetchErr := fetchFromAPI (ctx )
142140 if fetchErr != nil {
143141 // If API fetch fails, but we have cached data, use it
144142 if cached != nil {
@@ -156,8 +154,8 @@ func loadDatabase(cacheFile string) (*Database, error) {
156154 return database , nil
157155}
158156
159- func fetchFromAPI () (* Database , error ) {
160- req , err := http .NewRequestWithContext (context . Background () , http .MethodGet , ModelsDevAPIURL , http .NoBody )
157+ func fetchFromAPI (ctx context. Context ) (* Database , error ) {
158+ req , err := http .NewRequestWithContext (ctx , http .MethodGet , ModelsDevAPIURL , http .NoBody )
161159 if err != nil {
162160 return nil , fmt .Errorf ("failed to create request: %w" , err )
163161 }
@@ -225,7 +223,7 @@ var datePattern = regexp.MustCompile(`-\d{4}-?\d{2}-?\d{2}$`)
225223// For example, ("anthropic", "claude-sonnet-4-5") might resolve to "claude-sonnet-4-5-20250929".
226224// If the model is not an alias (already pinned or unknown), the original model name is returned.
227225// This method uses the models.dev database to find the corresponding pinned version.
228- func (s * Store ) ResolveModelAlias (providerID , modelName string ) string {
226+ func (s * Store ) ResolveModelAlias (ctx context. Context , providerID , modelName string ) string {
229227 if providerID == "" || modelName == "" {
230228 return modelName
231229 }
@@ -236,7 +234,7 @@ func (s *Store) ResolveModelAlias(providerID, modelName string) string {
236234 }
237235
238236 // Get the provider from the database
239- provider , err := s .GetProvider (providerID )
237+ provider , err := s .GetProvider (ctx , providerID )
240238 if err != nil {
241239 return modelName
242240 }
@@ -285,7 +283,7 @@ func isBedrockRegionPrefix(prefix string) bool {
285283// - If modelID is empty or not in "provider/model" format, returns true (fail-open)
286284// - If models.dev lookup fails for any reason, returns true (fail-open)
287285// - If lookup succeeds, returns the model's Reasoning field value
288- func ModelSupportsReasoning (modelID string ) bool {
286+ func ModelSupportsReasoning (ctx context. Context , modelID string ) bool {
289287 // Fail-open for empty model ID
290288 if modelID == "" {
291289 return true
@@ -303,7 +301,7 @@ func ModelSupportsReasoning(modelID string) bool {
303301 return true
304302 }
305303
306- model , err := store .GetModel (modelID )
304+ model , err := store .GetModel (ctx , modelID )
307305 if err != nil {
308306 slog .Debug ("Failed to lookup model in models.dev, assuming reasoning supported to allow user choice" , "model_id" , modelID , "error" , err )
309307 return true
0 commit comments