1+ import json
2+ import os
13from typing import Optional , Union , List , Dict , Iterator
24from pydantic import BaseModel , Field
35import openai
@@ -66,10 +68,17 @@ def complete_response(self, prompt: OpenAIPrompt) -> str:
6668 config_params ['function_call' ] = 'auto'
6769 config_params ['stream' ] = False
6870
69- response = openai .ChatCompletion .create (
71+ client = openai .OpenAI (
72+ api_key = os .environ .get ("OPENAI_API_KEY" , None ),
73+ base_url = os .environ .get ("OPENAI_API_BASE" , None )
74+ )
75+
76+ response = client .chat .completions .create (
7077 messages = prompt .messages ,
7178 ** config_params
7279 )
80+ if isinstance (response , openai .types .chat .chat_completion .ChatCompletion ):
81+ return json .dumps (response .dict ())
7382 return str (response )
7483
7584 def stream_response (self , prompt : OpenAIPrompt ) -> Iterator :
@@ -80,8 +89,14 @@ def stream_response(self, prompt: OpenAIPrompt) -> Iterator:
8089 config_params ['function_call' ] = 'auto'
8190 config_params ['stream' ] = True
8291
83- response = openai .ChatCompletion .create (
92+ client = openai .OpenAI (
93+ api_key = os .environ .get ("OPENAI_API_KEY" , None ),
94+ base_url = os .environ .get ("OPENAI_API_BASE" , None )
95+ )
96+
97+ response = client .chat .completions .create (
8498 messages = prompt .messages ,
85- ** config_params
99+ ** config_params ,
100+ timeout = 8
86101 )
87102 return response
0 commit comments