@@ -44,6 +44,27 @@ def mini_batch_completion(messages, parallel: int = 32, **kwargs):
4444 return outputs
4545
4646
def configure_openai_api(model: str) -> dict:
    """Build the ``api_key``/``api_base`` kwargs for an OpenAI-style model.

    Environment variables are consulted only when the model id has exactly
    one ``/`` separator (e.g. ``openai/gpt-4``); otherwise the hard-coded
    fallbacks (``"none"`` key, local vLLM endpoint) are used.

    Args:
        model: Model identifier, e.g. ``"openai/gpt-4"``.

    Returns:
        Dict with ``"api_key"`` and ``"api_base"`` entries.
    """
    use_env = model.count("/") == 1
    api_key = os.getenv("OPENAI_API_KEY", "none") if use_env else "none"
    api_base = (
        os.getenv("OPENAI_API_BASE", "http://0.0.0.0:8000/v1")
        if use_env
        else "http://0.0.0.0:8000/v1"
    )
    return {"api_key": api_key, "api_base": api_base}
58+
59+
def is_o_series_model(model: str) -> bool:
    """Return True if *model* is an OpenAI o-series reasoning model.

    These models (``o1``/``o3``/``o4`` variants) reject the ``temperature``
    sampling parameter, so callers strip it before dispatch.

    Args:
        model: Model identifier, e.g. ``"openai/o1-mini"``.

    Returns:
        True for ``openai/o1-``, ``openai/o3-``, or ``openai/o4-`` prefixes.
    """
    # str.startswith accepts a tuple of prefixes — one call replaces the
    # three-way `or` chain.
    return model.startswith(("openai/o1-", "openai/o3-", "openai/o4-"))
66+
67+
4768def run_batched_inference (
4869 batched_rows : List , # each row includes at least "messages"
4970 row_transform : Callable [[Dict ], Dict ] = lambda x : x ,
@@ -58,14 +79,7 @@ def run_batched_inference(
5879 print ("Running batched completion for LLM judge" )
5980
6081 if model .startswith ("openai" ):
61- kwargs ["api_key" ] = (
62- os .getenv ("OPENAI_API_KEY" , "none" ) if model .count ("/" ) == 1 else "none"
63- )
64- kwargs ["api_base" ] = (
65- os .getenv ("OPENAI_API_BASE" , "http://0.0.0.0:8000/v1" )
66- if model .count ("/" ) == 1
67- else "http://0.0.0.0:8000/v1"
68- )
82+ kwargs .update (configure_openai_api (model ))
6983 elif model .startswith ("bedrock" ):
7084 load_dotenv ()
7185
@@ -81,11 +95,7 @@ def run_batched_inference(
8195 assert parameters ["max_tokens" ] is None
8296 assert parameters ["temperature" ] is None
8397 else :
84- if (
85- model .startswith ("openai/o1-" )
86- or model .startswith ("openai/o3-" )
87- or model .startswith ("openai/o4-" )
88- ):
98+ if is_o_series_model (model ):
8999 if "temperature" in parameters :
90100 del parameters ["temperature" ]
91101 elif parameters ["temperature" ] is None :
0 commit comments