eval/generate.py (31 changes: 25 additions & 6 deletions)

@@ -178,6 +178,11 @@ def generate_openai(
     temperature: float = 0.0,
     max_new_tokens: int = 8192,
 ):
+    assert model.startswith("openai/"), (
+        "If running openai backend, model name must start with 'openai/'. "
+        "For example, 'deepseek-ai/DeepSeek-R1' should be 'openai/deepseek-ai/DeepSeek-R1'"
+    )
+
     outputs = []
     with ThreadPoolExecutor(max_workers=len(messages_batch)) as executor:
         futures = []
@@ -187,15 +192,29 @@
                 "num_retries": 16,
                 "retry_strategy": "exponential_backoff_retry",
                 "max_tokens": max_new_tokens,
-                "model": f"openai/{model}",
-                "api_key": os.getenv("OPENAI_API_KEY", "none"),
-                "api_base": os.getenv("OPENAI_API_BASE", "http://0.0.0.0:8000/v1"),
+                "model": model,
+                "api_key": (
+                    os.getenv("OPENAI_API_KEY", "none")
+                    if model.count("/") == 1
+                    else "none"
+                ),
+                "api_base": (
+                    os.getenv("OPENAI_API_BASE", "http://0.0.0.0:8000/v1")
+                    if model.count("/") == 1
+                    else "http://0.0.0.0:8000/v1"
+                ),
+                "temperature": temperature,
+                "stop": ["<end_of_turn>"],
             }

-            if model != "o4-mini":
+            if (
+                model.startswith("openai/o1-")
+                or model.startswith("openai/o3-")
+                or model.startswith("openai/o4-")
+            ):
                 # O-series models don't support customized temperature. Only default temperature=1 is supported.
-                kwargs["temperature"] = temperature
-                kwargs["stop"] = ["<end_of_turn>"]
+                del kwargs["temperature"]
+                del kwargs["stop"]

             future = executor.submit(completion_with_retries, **kwargs)
             futures.append(future)
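Note on the `eval/generate.py` change: the `openai/` prefix is now passed through on the model name, and the number of remaining `/` separators decides where the request is routed. One slash (e.g. `openai/gpt-4o`) is treated as a hosted OpenAI-compatible endpoint configured through `OPENAI_API_KEY` and `OPENAI_API_BASE`; two or more (e.g. `openai/deepseek-ai/DeepSeek-R1`) is treated as a locally served model at the hard-coded default base URL. A minimal standalone sketch of that routing rule; the `resolve_endpoint` helper is illustrative, not part of the PR:

```python
import os

def resolve_endpoint(model: str) -> tuple[str, str]:
    """Illustrative helper mirroring the slash-count routing in this diff."""
    assert model.startswith("openai/"), "model name must start with 'openai/'"
    if model.count("/") == 1:
        # One slash: hosted OpenAI-compatible endpoint, configured via env vars.
        return (
            os.getenv("OPENAI_API_KEY", "none"),
            os.getenv("OPENAI_API_BASE", "http://0.0.0.0:8000/v1"),
        )
    # Two or more slashes (HF-style org/model names): locally served model.
    return "none", "http://0.0.0.0:8000/v1"

print(resolve_endpoint("openai/gpt-4o"))                   # env-configured endpoint
print(resolve_endpoint("openai/deepseek-ai/DeepSeek-R1"))  # local server defaults
```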
utils/litellm.py (25 changes: 22 additions & 3 deletions)

@@ -2,6 +2,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0

+import os
 from concurrent.futures import ThreadPoolExecutor
 from copy import deepcopy
 from typing import Callable, Dict, List

@@ -13,8 +14,6 @@

 from utils import split_batch

-load_dotenv()
-

 def log_costs(completions):
     costs = [r._hidden_params["response_cost"] for r in completions]
@@ -57,6 +56,19 @@ def run_batched_inference(
     assert batched_rows and "messages" in batched_rows[0]
     batched_rows = [row_transform(row) for row in batched_rows]
     print("Running batched completion for LLM judge")
+
+    if model.startswith("openai"):
+        kwargs["api_key"] = (
+            os.getenv("OPENAI_API_KEY", "none") if model.count("/") == 1 else "none"
+        )
+        kwargs["api_base"] = (
+            os.getenv("OPENAI_API_BASE", "http://0.0.0.0:8000/v1")
+            if model.count("/") == 1
+            else "http://0.0.0.0:8000/v1"
+        )
+    elif model.startswith("bedrock"):
+        load_dotenv()
+
     parameters = {
         "model": model,
         "parallel": parallel,
@@ -69,7 +81,14 @@
         assert parameters["max_tokens"] is None
         assert parameters["temperature"] is None
     else:
-        if parameters["temperature"] is None:
+        if (
+            model.startswith("openai/o1-")
+            or model.startswith("openai/o3-")
+            or model.startswith("openai/o4-")
+        ):
+            if "temperature" in parameters:
+                del parameters["temperature"]
+        elif parameters["temperature"] is None:
             parameters["temperature"] = 0.0

     outputs = mini_batch_completion(**parameters)
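Both files now special-case the OpenAI o-series reasoning models (`o1-*`, `o3-*`, `o4-*`), which only support the default `temperature=1`; `generate.py` also drops the custom `stop` sequence for them, replacing the old check against the single literal `"o4-mini"`. A self-contained sketch blending the two variants; `filter_sampling_params` and `O_SERIES_PREFIXES` are illustrative names, not from the PR:

```python
O_SERIES_PREFIXES = ("openai/o1-", "openai/o3-", "openai/o4-")

def filter_sampling_params(model: str, params: dict) -> dict:
    if model.startswith(O_SERIES_PREFIXES):  # str.startswith accepts a tuple
        # O-series models reject custom sampling settings; drop them.
        params.pop("temperature", None)
        params.pop("stop", None)
    elif params.get("temperature") is None:
        # Everywhere else, default to deterministic decoding.
        params["temperature"] = 0.0
    return params

print(filter_sampling_params("openai/o4-mini", {"temperature": 0.0, "stop": ["<end_of_turn>"]}))
print(filter_sampling_params("openai/gpt-4o", {"temperature": None}))
```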