|
5 | 5 | import os |
6 | 6 | from concurrent.futures import ThreadPoolExecutor |
7 | 7 | from copy import deepcopy |
| 8 | +from types import SimpleNamespace |
8 | 9 | from typing import Callable, Dict, List |
9 | 10 |
|
10 | 11 | from dotenv import load_dotenv |
11 | 12 | from litellm import completion_with_retries |
12 | 13 | from termcolor import cprint |
13 | 14 | from tqdm import tqdm |
| 15 | +from vllm import LLM |
14 | 16 |
|
15 | 17 | from utils import split_batch |
16 | 18 |
|
@@ -78,10 +80,44 @@ def run_batched_inference( |
78 | 80 | batched_rows = [row_transform(row) for row in batched_rows] |
79 | 81 | print("Running batched completion for LLM judge") |
80 | 82 |
|
81 | | - if model.startswith("openai"): |
| 83 | + if model.startswith("openai/"): |
82 | 84 | kwargs.update(configure_openai_api(model)) |
83 | | - elif model.startswith("bedrock"): |
| 85 | + elif model.startswith("bedrock/"): |
84 | 86 | load_dotenv() |
| 87 | + else: |
| 88 | + model = LLM( |
| 89 | + model=model, |
| 90 | + generation_config="auto", |
| 91 | + trust_remote_code=True, |
| 92 | + tensor_parallel_size=8, |
| 93 | + ) |
| 94 | + sampling_params = model.get_default_sampling_params() |
| 95 | + sampling_params.temperature = temperature if temperature is not None else 0.0 |
| 96 | + sampling_params.max_tokens = ( |
| 97 | + max_new_tokens if max_new_tokens is not None else 2048 |
| 98 | + ) |
| 99 | + sampling_params.skip_special_tokens = True |
| 100 | + |
| 101 | + prompts = [row["messages"] for row in batched_rows] |
| 102 | + vllm_outputs = model.chat(prompts, sampling_params, use_tqdm=True) |
| 103 | + |
| 104 | + outputs = [SimpleNamespace(content=o.outputs[0].text) for o in vllm_outputs] |
| 105 | + |
| 106 | + output_rows = [] |
| 107 | + for row, ext in zip(batched_rows, outputs): |
| 108 | + row = deepcopy(row) |
| 109 | + reasoning_content = ( |
| 110 | + "<think>\n" + ext.reasoning_content + "\n</think>\n" |
| 111 | + if hasattr(ext, "reasoning_content") |
| 112 | + and ext.reasoning_content |
| 113 | + or "thinking" in kwargs |
| 114 | + else "" |
| 115 | + ) |
| 116 | + row["messages"].append( |
| 117 | + {"role": "assistant", "content": reasoning_content + ext.content} |
| 118 | + ) |
| 119 | + output_rows.append(row) |
| 120 | + return output_rows |
85 | 121 |
|
86 | 122 | parameters = { |
87 | 123 | "model": model, |
|
0 commit comments