
Commit d3dfa85

add litellm chat model for genericagent
1 parent 519abed commit d3dfa85

2 files changed

Lines changed: 185 additions & 1 deletion


src/agentlab/llm/chat_api.py

Lines changed: 122 additions & 1 deletion
@@ -89,6 +89,18 @@ def make_model(self):
             log_probs=self.log_probs,
         )
 
+@dataclass
+class LiteLLMModelArgs(BaseModelArgs):
+
+    def make_model(self):
+        return LiteLLMChatModel(
+            model_name=self.model_name,
+            temperature=self.temperature,
+            max_tokens=self.max_new_tokens,
+            log_probs=self.log_probs,
+            reasoning_effort=self.reasoning_effort,
+        )
+
 
 @dataclass
 class OpenAIModelArgs(BaseModelArgs):
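
A minimal usage sketch of the new args class (not part of the diff): the model identifier below is a hypothetical LiteLLM-style name chosen for illustration, and it assumes the remaining BaseModelArgs fields have defaults.

    from agentlab.llm.chat_api import LiteLLMModelArgs

    # hypothetical model id; any identifier accepted by litellm should work here
    args = LiteLLMModelArgs(
        model_name="openai/gpt-4o-mini",
        temperature=0.5,
        max_new_tokens=512,
    )
    llm = args.make_model()  # returns a LiteLLMChatModel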
@@ -393,7 +405,6 @@ def __init__(
             log_probs=log_probs,
         )
 
-
 class AzureChatModel(ChatModel):
     def __init__(
         self,
@@ -627,3 +638,113 @@ def make_model(self):
             temperature=self.temperature,
             max_tokens=self.max_new_tokens,
         )
+
+
+class LiteLLMChatModel(AbstractChatModel):
+    def __init__(
+        self,
+        model_name,
+        api_key=None,
+        temperature=0.5,
+        max_tokens=100,
+        max_retry=4,
+        min_retry_wait_time=60,
+        api_key_env_var=None,
+        client_class=OpenAI,
+        client_args=None,
+        pricing_func=None,
+        log_probs=False,
+        reasoning_effort=None,
+    ):
+        assert max_retry > 0, "max_retry should be greater than 0"
+
+        self.model_name = model_name
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.max_retry = max_retry
+        self.min_retry_wait_time = min_retry_wait_time
+        self.log_probs = log_probs
+        self.reasoning_effort = reasoning_effort
+
+        # Get pricing information
+        if pricing_func:
+            pricings = pricing_func()
+            try:
+                self.input_cost = float(pricings[model_name]["prompt"])
+                self.output_cost = float(pricings[model_name]["completion"])
+            except KeyError:
+                logging.warning(
+                    f"Model {model_name} not found in the pricing information, prices are set to 0. Maybe try upgrading langchain_community."
+                )
+                self.input_cost = 0.0
+                self.output_cost = 0.0
+        else:
+            self.input_cost = 0.0
+            self.output_cost = 0.0
+
+    def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float = None) -> dict:
+        from litellm import completion as litellm_completion
+
+        # Initialize retry tracking attributes
+        self.retries = 0
+        self.success = False
+        self.error_types = []
+
+        completion = None
+        e = None
+        for itr in range(self.max_retry):
+            self.retries += 1
+            temperature = temperature if temperature is not None else self.temperature
+            try:
+                completion = litellm_completion(
+                    model=self.model_name,
+                    messages=messages,
+                    # n=n_samples,
+                    # temperature=temperature,
+                    # max_completion_tokens=self.max_tokens,
+                    reasoning_effort=self.reasoning_effort,
+                )
+
+                if completion.usage is None:
+                    raise OpenRouterError(
+                        "The completion object does not contain usage information. This is likely a bug in the OpenRouter API."
+                    )
+
+                self.success = True
+                break
+            except openai.OpenAIError as e:
+                error_type = handle_error(e, itr, self.min_retry_wait_time, self.max_retry)
+                self.error_types.append(error_type)
+
+        if not completion:
+            raise RetryError(
+                f"Failed to get a response from the API after {self.max_retry} retries\n"
+                f"Last error: {error_type}"
+            )
+
+        input_tokens = completion.usage.prompt_tokens
+        output_tokens = completion.usage.completion_tokens
+        cost = input_tokens * self.input_cost + output_tokens * self.output_cost
+
+        if hasattr(tracking.TRACKER, "instance") and isinstance(
+            tracking.TRACKER.instance, tracking.LLMTracker
+        ):
+            tracking.TRACKER.instance(input_tokens, output_tokens, cost)
+
+        if n_samples == 1:
+            res_text = completion.choices[0].message.content
+            if res_text is not None:
+                res_text = res_text.removesuffix("<|end|>").strip()
+            else:
+                res_text = ""
+            res = AIMessage(res_text)
+            if self.log_probs:
+                res["log_probs"] = completion.choices[0].log_probs
+            return res
+        else:
+            return [
+                AIMessage(c.message.content.removesuffix("<|end|>").strip())
+                for c in completion.choices
+            ]
+
+    def get_stats(self):
+        return {
+            "n_retry_llm": self.retries,
+            # "busted_retry_llm": int(not self.success),  # not logged if it occurs anyways
+        }
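
A minimal call sketch for the new chat model (not part of the diff), assuming a LiteLLM-routable model name and the matching provider API key in the environment:

    from agentlab.llm.chat_api import LiteLLMChatModel

    chat = LiteLLMChatModel(model_name="gpt-4o-mini", temperature=0.5, max_tokens=256)
    messages = [
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "Say hello."},
    ]
    answer = chat(messages)    # AIMessage wrapping the completion text
    print(chat.get_stats())    # e.g. {"n_retry_llm": 1}

Note that the __call__ above currently forwards only the model name, messages, and reasoning_effort to litellm.completion; n_samples, temperature, and max_tokens are accepted by the method but commented out in the request.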

src/agentlab/llm/llm_configs.py

Lines changed: 63 additions & 0 deletions
@@ -7,6 +7,7 @@
     OpenAIModelArgs,
     OpenRouterModelArgs,
     SelfHostedModelArgs,
+    LiteLLMModelArgs,
 )
 
 default_oss_llms_args = {
@@ -200,6 +201,68 @@
         temperature=1,  # temperature param not supported by gpt-5
         vision_support=True,
     ),
+    "azure/gpt-5-high-2025-08-07": AzureModelArgs(
+        model_name="gpt-5",
+        max_total_tokens=400_000,
+        max_input_tokens=256_000,
+        max_new_tokens=128_000,
+        temperature=1,  # temperature param not supported by gpt-5
+        vision_support=True,
+        reasoning_effort="high",
+    ),
+    "azure/gpt-5-mini-high-2025-08-07": AzureModelArgs(
+        model_name="gpt-5-mini",
+        max_total_tokens=400_000,
+        max_input_tokens=256_000,
+        max_new_tokens=128_000,
+        temperature=1,  # temperature param not supported by gpt-5
+        vision_support=True,
+        reasoning_effort="high",
+    ),
+    "azure/gpt-5-nano-high-2025-08-07": AzureModelArgs(
+        model_name="gpt-5-nano",
+        max_total_tokens=400_000,
+        max_input_tokens=256_000,
+        max_new_tokens=128_000,
+        temperature=1,  # temperature param not supported by gpt-5
+        vision_support=True,
+        reasoning_effort="high",
+    ),
+    "azure/gpt-oss-120b": AzureModelArgs(
+        model_name="gpt-oss-120b",
+        max_total_tokens=200_000,
+        max_input_tokens=200_000,
+        max_new_tokens=100_000,
+        temperature=1,
+        vision_support=False,
+        reasoning_effort="low",
+    ),
+    "azure/o3-high-2025-04-16": AzureModelArgs(
+        model_name="o3",
+        max_total_tokens=200_000,
+        max_input_tokens=200_000,
+        max_new_tokens=100_000,
+        temperature=1,
+        vision_support=False,
+        reasoning_effort="high",
+    ),
+    "azure/o3-mini-2025-01-31": AzureModelArgs(
+        model_name="o3-mini",
+        max_total_tokens=200_000,
+        max_input_tokens=200_000,
+        max_new_tokens=100_000,
+        temperature=1,
+        vision_support=False,
+    ),
+    "azure/o3-mini-high-2025-01-31": AzureModelArgs(
+        model_name="o3-mini",
+        max_total_tokens=200_000,
+        max_input_tokens=200_000,
+        max_new_tokens=100_000,
+        temperature=1,
+        vision_support=False,
+        reasoning_effort="high",
+    ),
     # ---------------- Anthropic ----------------#
     "anthropic/claude-3-7-sonnet-20250219": AnthropicModelArgs(
         model_name="claude-3-7-sonnet-20250219",