@@ -2,6 +2,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
+import os
 from concurrent.futures import ThreadPoolExecutor
 from copy import deepcopy
 from typing import Callable, Dict, List
@@ -13,8 +14,6 @@
 
 from utils import split_batch
 
-load_dotenv()
-
 
 def log_costs(completions):
     costs = [r._hidden_params["response_cost"] for r in completions]
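
For reference, `_hidden_params["response_cost"]` is the per-request dollar cost that LiteLLM attaches to each completion response. The rest of `log_costs` is not shown in this diff; a minimal sketch of what such a logger might do (the aggregation below is an assumption, not the file's actual body):

    def log_costs(completions):
        # Each LiteLLM response carries its dollar cost in _hidden_params;
        # the cost may be None for providers without pricing metadata.
        costs = [r._hidden_params["response_cost"] for r in completions]
        total = sum(c for c in costs if c is not None)
        print(f"LLM judge cost: ${total:.4f} across {len(costs)} calls")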
@@ -57,6 +56,19 @@ def run_batched_inference(
     assert batched_rows and "messages" in batched_rows[0]
     batched_rows = [row_transform(row) for row in batched_rows]
     print("Running batched completion for LLM judge")
+
+    if model.startswith("openai"):
+        kwargs["api_key"] = (
+            os.getenv("OPENAI_API_KEY", "none") if model.count("/") == 1 else "none"
+        )
+        kwargs["api_base"] = (
+            os.getenv("OPENAI_API_BASE", "http://0.0.0.0:8000/v1")
+            if model.count("/") == 1
+            else "http://0.0.0.0:8000/v1"
+        )
+    elif model.startswith("bedrock"):
+        load_dotenv()
+
     parameters = {
         "model": model,
         "parallel": parallel,
@@ -69,7 +81,14 @@ def run_batched_inference(
         assert parameters["max_tokens"] is None
         assert parameters["temperature"] is None
     else:
-        if parameters["temperature"] is None:
+        if (
+            model.startswith("openai/o1-")
+            or model.startswith("openai/o3-")
+            or model.startswith("openai/o4-")
+        ):
+            if "temperature" in parameters:
+                del parameters["temperature"]
+        elif parameters["temperature"] is None:
             parameters["temperature"] = 0.0
 
     outputs = mini_batch_completion(**parameters)
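
The second hunk handles OpenAI reasoning models: the `o1-`/`o3-`/`o4-` families do not accept an explicit `temperature` parameter, so it is deleted for them, while every other model still defaults to `temperature=0.0` for reproducible judging. A self-contained sketch of the same rule (`sanitize_sampling_params` is an illustrative name, not from the diff):

    def sanitize_sampling_params(model: str, parameters: dict) -> dict:
        # Reasoning models reject "temperature" outright, so drop it;
        # other models fall back to 0.0 when the caller left it unset.
        if model.startswith(("openai/o1-", "openai/o3-", "openai/o4-")):
            parameters.pop("temperature", None)
        elif parameters.get("temperature") is None:
            parameters["temperature"] = 0.0
        return parameters

    assert "temperature" not in sanitize_sampling_params(
        "openai/o3-mini", {"temperature": None}
    )
    assert sanitize_sampling_params("openai/gpt-4o", {}) == {"temperature": 0.0}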