
Commit f0a0f5d

feat(litellm): add vllm server support

1 parent f39e6bd

2 files changed

Lines changed: 39 additions & 3 deletions

eval/evaluate.py
Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@
 
 from eval.generate import preprocess_generation
 
-DEFAULT_LLM_JUDGE = "openai/gpt-4o"
+DEFAULT_LLM_JUDGE = "meta-llama/Llama-3.3-70B-Instruct"
 
 
 def to_evalplus_format(generation_path: str) -> str:
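
Note that the new default judge string carries no provider prefix, so with the routing added in utils/litellm.py below it falls through to the local vLLM branch instead of a hosted API. A quick illustration (not part of the commit):

    model = "meta-llama/Llama-3.3-70B-Instruct"
    model.startswith("openai/")   # False -> not the OpenAI branch
    model.startswith("bedrock/")  # False -> not the Bedrock branch
    # anything without a provider prefix is now served locally via vLLM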

utils/litellm.py
Lines changed: 38 additions & 2 deletions

@@ -5,12 +5,14 @@
 import os
 from concurrent.futures import ThreadPoolExecutor
 from copy import deepcopy
+from types import SimpleNamespace
 from typing import Callable, Dict, List
 
 from dotenv import load_dotenv
 from litellm import completion_with_retries
 from termcolor import cprint
 from tqdm import tqdm
+from vllm import LLM
 
 from utils import split_batch
 
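
SimpleNamespace is imported so that vLLM outputs can duck-type the message objects litellm's completion calls return, i.e. anything exposing a .content attribute. A minimal illustration (not part of the commit):

    from types import SimpleNamespace

    msg = SimpleNamespace(content="Score: 5/5")
    print(msg.content)                            # "Score: 5/5", litellm-message shaped
    print(getattr(msg, "reasoning_content", ""))  # "" -- attribute absent, safe default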

@@ -78,10 +80,44 @@ def run_batched_inference(
     batched_rows = [row_transform(row) for row in batched_rows]
     print("Running batched completion for LLM judge")
 
-    if model.startswith("openai"):
+    if model.startswith("openai/"):
         kwargs.update(configure_openai_api(model))
-    elif model.startswith("bedrock"):
+    elif model.startswith("bedrock/"):
         load_dotenv()
+    else:
+        model = LLM(
+            model=model,
+            generation_config="auto",
+            trust_remote_code=True,
+            tensor_parallel_size=8,
+        )
+        sampling_params = model.get_default_sampling_params()
+        sampling_params.temperature = temperature if temperature is not None else 0.0
+        sampling_params.max_tokens = (
+            max_new_tokens if max_new_tokens is not None else 2048
+        )
+        sampling_params.skip_special_tokens = True
+
+        prompts = [row["messages"] for row in batched_rows]
+        vllm_outputs = model.chat(prompts, sampling_params, use_tqdm=True)
+
+        outputs = [SimpleNamespace(content=o.outputs[0].text) for o in vllm_outputs]
+
+        output_rows = []
+        for row, ext in zip(batched_rows, outputs):
+            row = deepcopy(row)
+            # getattr avoids an AttributeError when "thinking" is set in kwargs
+            reasoning = getattr(ext, "reasoning_content", "") or ""
+            reasoning_content = (
+                "<think>\n" + reasoning + "\n</think>\n"
+                if reasoning or "thinking" in kwargs
+                else ""
+            )
+            row["messages"].append(
+                {"role": "assistant", "content": reasoning_content + ext.content}
+            )
+            output_rows.append(row)
+        return output_rows
 
     parameters = {
         "model": model,
