Skip to content

Commit a426911

Browse files
committed
fix: gemini comments
1 parent caf2cc4 commit a426911

2 files changed

Lines changed: 26 additions & 29 deletions

File tree

eval/generate.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
)
2020

2121
from utils import SYSTEM_PROMPT, split_batch
22+
from utils.litellm import configure_openai_api, is_o_series_model
2223

2324
os.environ["TOKENIZERS_PARALLELISM"] = "false"
2425

@@ -193,26 +194,12 @@ def generate_openai(
193194
"retry_strategy": "exponential_backoff_retry",
194195
"max_tokens": max_new_tokens,
195196
"model": model,
196-
"api_key": (
197-
os.getenv("OPENAI_API_KEY", "none")
198-
if model.count("/") == 1
199-
else "none"
200-
),
201-
"api_base": (
202-
os.getenv("OPENAI_API_BASE", "http://0.0.0.0:8000/v1")
203-
if model.count("/") == 1
204-
else "http://0.0.0.0:8000/v1"
205-
),
206197
"temperature": temperature,
207198
"stop": ["<end_of_turn>"],
199+
**configure_openai_api(model),
208200
}
209201

210-
if (
211-
model.startswith("openai/o1-")
212-
or model.startswith("openai/o3-")
213-
or model.startswith("openai/o4-")
214-
):
215-
# O-series models don't support customized temperature. Only default temperature=1 is supported.
202+
if is_o_series_model(model):
216203
del kwargs["temperature"]
217204
del kwargs["stop"]
218205

utils/litellm.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,27 @@ def mini_batch_completion(messages, parallel: int = 32, **kwargs):
4444
return outputs
4545

4646

47+
def configure_openai_api(model: str) -> dict:
    """Build the litellm API credentials kwargs for an OpenAI-compatible model.

    A model string containing exactly one "/" (e.g. "openai/gpt-4o") reads
    ``OPENAI_API_KEY`` / ``OPENAI_API_BASE`` from the environment (presumably a
    hosted endpoint — the env vars only apply in this case); any other shape
    falls back to a dummy key and the local default server URL.

    Returns:
        dict with "api_key" and "api_base" entries, suitable for ** expansion
        into a litellm completion call.
    """
    uses_env = model.count("/") == 1

    api_key = os.getenv("OPENAI_API_KEY", "none") if uses_env else "none"
    if uses_env:
        api_base = os.getenv("OPENAI_API_BASE", "http://0.0.0.0:8000/v1")
    else:
        api_base = "http://0.0.0.0:8000/v1"

    return {"api_key": api_key, "api_base": api_base}
58+
59+
60+
def is_o_series_model(model: str) -> bool:
    """Return True if *model* names an OpenAI o-series model (o1-/o3-/o4-).

    Callers use this to strip sampling parameters (temperature, stop) that
    these models do not accept. Only the dash-suffixed forms match, exactly
    as in the original prefix checks.
    """
    # str.startswith accepts a tuple of prefixes — one call instead of a
    # chain of "or"ed checks.
    return model.startswith(("openai/o1-", "openai/o3-", "openai/o4-"))
66+
67+
4768
def run_batched_inference(
4869
batched_rows: List, # each row includes at least "messages"
4970
row_transform: Callable[[Dict], Dict] = lambda x: x,
@@ -58,14 +79,7 @@ def run_batched_inference(
5879
print("Running batched completion for LLM judge")
5980

6081
if model.startswith("openai"):
61-
kwargs["api_key"] = (
62-
os.getenv("OPENAI_API_KEY", "none") if model.count("/") == 1 else "none"
63-
)
64-
kwargs["api_base"] = (
65-
os.getenv("OPENAI_API_BASE", "http://0.0.0.0:8000/v1")
66-
if model.count("/") == 1
67-
else "http://0.0.0.0:8000/v1"
68-
)
82+
kwargs.update(configure_openai_api(model))
6983
elif model.startswith("bedrock"):
7084
load_dotenv()
7185

@@ -81,11 +95,7 @@ def run_batched_inference(
8195
assert parameters["max_tokens"] is None
8296
assert parameters["temperature"] is None
8397
else:
84-
if (
85-
model.startswith("openai/o1-")
86-
or model.startswith("openai/o3-")
87-
or model.startswith("openai/o4-")
88-
):
98+
if is_o_series_model(model):
8999
if "temperature" in parameters:
90100
del parameters["temperature"]
91101
elif parameters["temperature"] is None:

0 commit comments

Comments
 (0)