@@ -3,6 +3,7 @@ use chrono::Utc;
33use clap:: { Args , Subcommand } ;
44use serde:: Deserialize ;
55use std:: collections:: HashMap ;
6+ use std:: time:: Duration ;
67use strum:: IntoEnumIterator ;
78
89use crate :: {
@@ -66,7 +67,7 @@ impl ModelsArgs {
6667 }
6768 }
6869 ModelsAction :: Update => {
69- let supported: Vec < & str > = Provider :: iter ( ) . map ( |p| p. models_dev_id ( ) ) . collect ( ) ;
70+ let supported: Vec < String > = Provider :: iter ( ) . map ( |p| p. to_string ( ) ) . collect ( ) ;
7071
7172 crate :: output:: progress ( "Fetching models from models.dev..." , output_level) ;
7273
@@ -236,19 +237,24 @@ struct ModelsDevProvider {
236237
237238#[ derive( Deserialize ) ]
238239struct ModelsDevModel {
239- #[ serde( default ) ]
240240 id : Option < String > ,
241241}
242242
243- async fn fetch_models_from_models_dev ( supported_providers : & [ & str ] ) -> Result < Vec < String > > {
244- let response = reqwest:: get ( "https://models.dev/api.json" )
243+ async fn fetch_models_from_models_dev ( supported_providers : & [ String ] ) -> Result < Vec < String > > {
244+ let client = reqwest:: Client :: builder ( )
245+ . timeout ( Duration :: from_secs ( 10 ) )
246+ . build ( ) ?;
247+
248+ let response = client
249+ . get ( "https://models.dev/api.json" )
250+ . send ( )
245251 . await ?
246252 . error_for_status ( ) ?;
247253 let providers: HashMap < String , ModelsDevProvider > = response. json ( ) . await ?;
248254
249255 let mut all_models = Vec :: new ( ) ;
250256 for provider_id in supported_providers {
251- if let Some ( provider) = providers. get ( * provider_id) {
257+ if let Some ( provider) = providers. get ( provider_id) {
252258 for ( model_id, model) in & provider. models {
253259 let id = model. id . as_deref ( ) . unwrap_or ( model_id) ;
254260 all_models. push ( format ! ( "{provider_id}:{id}" ) ) ;
0 commit comments