66
77import pydantic
88
9- from writerai .types .shared_params .tool_param import ToolParam
10-
119from .._tools import PydanticFunctionTool
1210from ..._types import NOT_GIVEN , NotGiven
13- from ..._utils import is_given
11+ from ..._utils import is_dict , is_given
1412from ..._compat import PYDANTIC_V2 , model_parse_json
1513from ..._models import construct_type_unchecked
16- from .._pydantic import is_basemodel_type , is_dataclass_like_type
14+ from .._pydantic import is_basemodel_type , to_strict_json_schema , is_dataclass_like_type
1715from ..._exceptions import LengthFinishReasonError , ContentFilterFinishReasonError
1816from ...types .parsed_chat import (
17+ ChatCompletion ,
1918 ParsedChatCompletion ,
19+ ChatCompletionMessage ,
20+ ParsedFunctionToolCall ,
2021 ParsedChatCompletionChoice ,
2122 ParsedChatCompletionMessage ,
2223)
23- from ...types .shared_params import FunctionDefinition
24- from ...types .chat_completion import ChatCompletion
24+ from ...types .shared_params import ToolParam as ChatCompletionToolParam , FunctionDefinition
25+ from ...types .chat_chat_params import ResponseFormat as ResponseFormatParam
2526from ...types .shared .tool_call import Function
26- from ...types .chat_completion_message import ChatCompletionMessage
27- from ...types .parsed_function_tool_call import (
28- ParsedFunction ,
29- ParsedFunctionToolCall ,
30- )
27+ from ...types .parsed_function_tool_call import ParsedFunction
3128
3229ResponseFormatT = TypeVar (
3330 "ResponseFormatT" ,
3835
3936
4037def validate_input_tools (
41- tools : Iterable [ToolParam ] | NotGiven = NOT_GIVEN ,
38+ tools : Iterable [ChatCompletionToolParam ] | NotGiven = NOT_GIVEN ,
4239) -> None :
4340 if not is_given (tools ):
4441 return
@@ -58,9 +55,8 @@ def validate_input_tools(
5855
5956def parse_chat_completion (
6057 * ,
61- # response_format: type[ResponseFormatT] | chat_chat_params.ResponseFormat | NotGiven,
62- response_format : type [ResponseFormatT ] | NotGiven ,
63- input_tools : Iterable [ToolParam ] | NotGiven ,
58+ response_format : type [ResponseFormatT ] | ResponseFormatParam | NotGiven ,
59+ input_tools : Iterable [ChatCompletionToolParam ] | NotGiven ,
6460 chat_completion : ChatCompletion | ParsedChatCompletion [object ],
6561) -> ParsedChatCompletion [ResponseFormatT ]:
6662 if is_given (input_tools ):
@@ -113,7 +109,7 @@ def parse_chat_completion(
113109 response_format = response_format ,
114110 message = message ,
115111 ),
116- "tool_calls" : tool_calls ,
112+ "tool_calls" : tool_calls if tool_calls else None ,
117113 },
118114 },
119115 )
@@ -131,11 +127,13 @@ def parse_chat_completion(
131127 )
132128
133129
134- def get_input_tool_by_name (* , input_tools : list [ToolParam ], name : str ) -> ToolParam | None :
130+ def get_input_tool_by_name (* , input_tools : list [ChatCompletionToolParam ], name : str ) -> ChatCompletionToolParam | None :
135131 return next ((t for t in input_tools if t .get ("function" , {}).get ("name" ) == name ), None )
136132
137133
138- def parse_function_tool_arguments (* , input_tools : list [ToolParam ], function : Function | ParsedFunction ) -> object :
134+ def parse_function_tool_arguments (
135+ * , input_tools : list [ChatCompletionToolParam ], function : Function | ParsedFunction
136+ ) -> object :
139137 assert function .name is not None
140138 input_tool = get_input_tool_by_name (input_tools = input_tools , name = function .name )
141139 if not input_tool :
@@ -155,8 +153,7 @@ def parse_function_tool_arguments(*, input_tools: list[ToolParam], function: Fun
155153
156154def maybe_parse_content (
157155 * ,
158- # response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
159- response_format : type [ResponseFormatT ] | NotGiven ,
156+ response_format : type [ResponseFormatT ] | ResponseFormatParam | NotGiven ,
160157 message : ChatCompletionMessage | ParsedChatCompletionMessage [object ],
161158) -> ResponseFormatT | None :
162159 if has_rich_response_format (response_format ) and message .content and not message .refusal :
@@ -166,8 +163,7 @@ def maybe_parse_content(
166163
167164
168165def solve_response_format_t (
169- # response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
170- response_format : type [ResponseFormatT ] | NotGiven ,
166+ response_format : type [ResponseFormatT ] | ResponseFormatParam | NotGiven ,
171167) -> type [ResponseFormatT ]:
172168 """Return the runtime type for the given response format.
173169
@@ -182,9 +178,8 @@ def solve_response_format_t(
182178
183179def has_parseable_input (
184180 * ,
185- # response_format: type | ResponseFormatParam | NotGiven,
186- response_format : type | NotGiven ,
187- input_tools : Iterable [ToolParam ] | NotGiven = NOT_GIVEN ,
181+ response_format : type | ResponseFormatParam | NotGiven ,
182+ input_tools : Iterable [ChatCompletionToolParam ] | NotGiven = NOT_GIVEN ,
188183) -> bool :
189184 if has_rich_response_format (response_format ):
190185 return True
@@ -197,30 +192,27 @@ def has_parseable_input(
197192
198193
199194def has_rich_response_format (
200- # response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
201- response_format : type [ResponseFormatT ] | NotGiven ,
195+ response_format : type [ResponseFormatT ] | ResponseFormatParam | NotGiven ,
202196) -> TypeGuard [type [ResponseFormatT ]]:
203197 if not is_given (response_format ):
204198 return False
205199
206- # if is_response_format_param(response_format):
207- # return False
200+ if is_response_format_param (response_format ):
201+ return False
208202
209203 return True
210204
211205
212- # def is_response_format_param(response_format: object) -> TypeGuard[ResponseFormatParam]:
213- # return is_dict(response_format)
206+ def is_response_format_param (response_format : object ) -> TypeGuard [ResponseFormatParam ]:
207+ return is_dict (response_format )
214208
215209
216- def is_parseable_tool (input_tool : ToolParam ) -> bool :
210+ def is_parseable_tool (input_tool : ChatCompletionToolParam ) -> bool :
217211 input_fn = cast (object , input_tool .get ("function" ))
218212 if isinstance (input_fn , PydanticFunctionTool ):
219213 return True
220214
221- # FIXME: `strict` currently missing in the schema definition
222- # return cast(FunctionDefinition, input_fn).get("strict") or False
223- return False
215+ return cast (FunctionDefinition , input_fn ).get ("strict" ) or False # type: ignore
224216
225217
226218def _parse_content (response_format : type [ResponseFormatT ], content : str ) -> ResponseFormatT :
@@ -236,43 +228,36 @@ def _parse_content(response_format: type[ResponseFormatT], content: str) -> Resp
236228 raise TypeError (f"Unable to automatically parse response format type { response_format } " )
237229
238230
239- # def type_to_response_format_param(
240- # response_format: type | completion_create_params.ResponseFormat | NotGiven,
241- # ) -> ResponseFormatParam | NotGiven:
242- # if not is_given(response_format):
243- # return NOT_GIVEN
244-
245- # if is_response_format_param(response_format):
246- # return response_format
247-
248- # # type checkers don't narrow the negation of a `TypeGuard` as it isn't
249- # # a safe default behaviour but we know that at this point the `response_format`
250- # # can only be a `type`
251- # response_format = cast(type, response_format)
252-
253- # json_schema_type: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any] | None = None
254-
255- # if is_basemodel_type(response_format):
256- # name = response_format.__name__
257- # json_schema_type = response_format
258- # elif is_dataclass_like_type(response_format):
259- # name = response_format.__name__
260- # json_schema_type = pydantic.TypeAdapter(response_format)
261- # else:
262- # raise TypeError(f"Unsupported response_format type - {response_format}")
263-
264-
265- # return {
266- # "type": "json_schema",
267- # "json_schema": {
268- # "schema": to_strict_json_schema(json_schema_type),
269- # "name": name,
270- # "strict": True,
271- # },
272- # }
273231def type_to_response_format_param (
274- response_format : type | NotGiven ,
275- ) -> NotGiven :
276- if is_given (response_format ):
277- raise NotImplementedError ("Support for response_format is not implemented yet" )
278- return NOT_GIVEN
232+ response_format : type | ResponseFormatParam | NotGiven ,
233+ ) -> ResponseFormatParam | NotGiven :
234+ if not is_given (response_format ):
235+ return NOT_GIVEN
236+
237+ if is_response_format_param (response_format ):
238+ return response_format
239+
240+ # type checkers don't narrow the negation of a `TypeGuard` as it isn't
241+ # a safe default behaviour but we know that at this point the `response_format`
242+ # can only be a `type`
243+ response_format = cast (type , response_format )
244+
245+ json_schema_type : type [pydantic .BaseModel ] | pydantic .TypeAdapter [Any ] | None = None
246+
247+ if is_basemodel_type (response_format ):
248+ name = response_format .__name__
249+ json_schema_type = response_format
250+ elif is_dataclass_like_type (response_format ):
251+ name = response_format .__name__
252+ json_schema_type = pydantic .TypeAdapter (response_format )
253+ else :
254+ raise TypeError (f"Unsupported response_format type - { response_format } " )
255+
256+ return {
257+ "type" : "json_schema" ,
258+ "json_schema" : {
259+ "schema" : to_strict_json_schema (json_schema_type ),
260+ "name" : name ,
261+ "strict" : True ,
262+ },
263+ }
0 commit comments