@@ -89,6 +89,18 @@ def make_model(self):
8989 log_probs = self .log_probs ,
9090 )
9191
@dataclass
class LiteLLMModelArgs(BaseModelArgs):
    """Serializable arguments for building a LiteLLM-backed chat model.

    Field declarations are inherited from ``BaseModelArgs``; this subclass
    only knows how to turn them into a ``LiteLLMChatModel`` instance.
    """

    def make_model(self):
        """Instantiate a ``LiteLLMChatModel`` from the stored arguments."""
        # Collect constructor arguments explicitly, then expand them in one
        # call; note max_tokens is sourced from the `max_new_tokens` field.
        model_kwargs = {
            "model_name": self.model_name,
            "temperature": self.temperature,
            "max_tokens": self.max_new_tokens,
            "log_probs": self.log_probs,
            "reasoning_effort": self.reasoning_effort,
        }
        return LiteLLMChatModel(**model_kwargs)
92104
93105@dataclass
94106class OpenAIModelArgs (BaseModelArgs ):
@@ -393,7 +405,6 @@ def __init__(
393405 log_probs = log_probs ,
394406 )
395407
396-
397408class AzureChatModel (ChatModel ):
398409 def __init__ (
399410 self ,
@@ -627,3 +638,113 @@ def make_model(self):
627638 temperature = self .temperature ,
628639 max_tokens = self .max_new_tokens ,
629640 )
641+
class LiteLLMChatModel(AbstractChatModel):
    """Chat model that routes completion requests through LiteLLM.

    Failed API calls are retried up to ``max_retry`` times; token usage and
    cost are reported to the global ``tracking`` tracker when one is active.
    Retry statistics are exposed via :meth:`get_stats`.
    """

    def __init__(
        self,
        model_name,
        api_key=None,
        temperature=0.5,
        max_tokens=100,
        max_retry=4,
        min_retry_wait_time=60,
        api_key_env_var=None,
        client_class=OpenAI,
        client_args=None,
        pricing_func=None,
        log_probs=False,
        reasoning_effort=None,
    ):
        """Configure the model.

        Args:
            model_name: LiteLLM model identifier passed to ``completion()``.
            api_key, api_key_env_var, client_class, client_args: accepted for
                signature parity with the other ChatModel classes but not used
                here — LiteLLM resolves credentials itself.
            temperature: default sampling temperature (currently not forwarded
                to the API; see ``__call__``).
            max_tokens: completion-token budget (currently not forwarded).
            max_retry: number of attempts before raising ``RetryError``.
            min_retry_wait_time: minimum back-off passed to ``handle_error``.
            pricing_func: optional callable returning a ``{model: {"prompt",
                "completion"}}`` price table; missing entries yield zero cost.
            log_probs: if True, attach log-probs to the returned message.
            reasoning_effort: forwarded verbatim to LiteLLM.

        Raises:
            ValueError: if ``max_retry`` is not positive.
        """
        # Validate with a real exception rather than `assert`, which is
        # silently stripped when Python runs with -O.
        if max_retry <= 0:
            raise ValueError("max_retry should be greater than 0")

        self.model_name = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.max_retry = max_retry
        self.min_retry_wait_time = min_retry_wait_time
        self.log_probs = log_probs
        self.reasoning_effort = reasoning_effort

        # Get pricing information; default to free (0.0) when the model is
        # absent from the table or no pricing function was supplied.
        if pricing_func:
            pricings = pricing_func()
            try:
                self.input_cost = float(pricings[model_name]["prompt"])
                self.output_cost = float(pricings[model_name]["completion"])
            except KeyError:
                logging.warning(
                    f"Model {model_name} not found in the pricing information, prices are set to 0. Maybe try upgrading langchain_community."
                )
                self.input_cost = 0.0
                self.output_cost = 0.0
        else:
            self.input_cost = 0.0
            self.output_cost = 0.0

    def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float = None) -> dict:
        """Send ``messages`` to the model and return the reply.

        Args:
            messages: chat history in OpenAI message-dict format.
            n_samples: number of samples requested. NOTE(review): ``n`` is not
                currently forwarded to the API (commented out below), so the
                multi-sample branch only sees however many choices the server
                returns — confirm before relying on n_samples > 1.
            temperature: optional per-call override of the default temperature
                (also not currently forwarded).

        Returns:
            An ``AIMessage`` when ``n_samples == 1``, otherwise a list of
            ``AIMessage``.

        Raises:
            RetryError: when all ``max_retry`` attempts fail.
        """
        from litellm import completion as litellm_completion

        # Retry-tracking attributes, reported by get_stats().
        self.retries = 0
        self.success = False
        self.error_types = []

        # Resolve once, outside the retry loop: an explicit call-site
        # temperature overrides the configured default.
        temperature = temperature if temperature is not None else self.temperature

        completion = None
        error_type = None  # last classified error, used in the RetryError message
        for itr in range(self.max_retry):
            self.retries += 1
            try:
                # NOTE(review): n, temperature and max_completion_tokens are
                # deliberately left commented out here — presumably because
                # some reasoning models reject them; confirm before
                # re-enabling.
                completion = litellm_completion(
                    model=self.model_name,
                    messages=messages,
                    # n=n_samples,
                    # temperature=temperature,
                    # max_completion_tokens=self.max_tokens,
                    reasoning_effort=self.reasoning_effort,
                )

                # Usage is required downstream for cost accounting.
                if completion.usage is None:
                    raise OpenRouterError(
                        "The completion object does not contain usage information. This is likely a bug in the OpenRouter API."
                    )

                self.success = True
                break
            except openai.OpenAIError as e:
                # handle_error sleeps/backs off as needed and classifies the
                # failure; record the classification for get_stats/debugging.
                error_type = handle_error(e, itr, self.min_retry_wait_time, self.max_retry)
                self.error_types.append(error_type)

        if not completion:
            raise RetryError(
                f"Failed to get a response from the API after {self.max_retry} retries\n"
                f"Last error: {error_type}"
            )

        input_tokens = completion.usage.prompt_tokens
        output_tokens = completion.usage.completion_tokens
        cost = input_tokens * self.input_cost + output_tokens * self.output_cost

        # Report usage to the global LLM tracker, if one is installed.
        if hasattr(tracking.TRACKER, "instance") and isinstance(
            tracking.TRACKER.instance, tracking.LLMTracker
        ):
            tracking.TRACKER.instance(input_tokens, output_tokens, cost)

        if n_samples == 1:
            res_text = completion.choices[0].message.content
            if res_text is not None:
                res_text = res_text.removesuffix("<|end|>").strip()
            else:
                res_text = ""
            res = AIMessage(res_text)
            if self.log_probs:
                res["log_probs"] = completion.choices[0].log_probs
            return res
        else:
            # Apply the same None-content guard as the single-sample path;
            # previously a None content here raised AttributeError.
            results = []
            for choice in completion.choices:
                text = choice.message.content
                text = text.removesuffix("<|end|>").strip() if text is not None else ""
                results.append(AIMessage(text))
            return results

    def get_stats(self):
        """Return retry statistics for the most recent ``__call__``."""
        return {
            "n_retry_llm": self.retries,
            # "busted_retry_llm": int(not self.success), # not logged if it occurs anyways
        }
0 commit comments