Skip to content

Commit 81aae52

Browse files
feat(chat): add parse method (#237)
Co-authored-by: Robert Craigie <robert@craigie.dev>
1 parent fa87048 commit 81aae52

3 files changed

Lines changed: 309 additions & 80 deletions

File tree

src/writerai/lib/_parsing/_completions.py

Lines changed: 59 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -6,28 +6,25 @@
66

77
import pydantic
88

9-
from writerai.types.shared_params.tool_param import ToolParam
10-
119
from .._tools import PydanticFunctionTool
1210
from ..._types import NOT_GIVEN, NotGiven
13-
from ..._utils import is_given
11+
from ..._utils import is_dict, is_given
1412
from ..._compat import PYDANTIC_V2, model_parse_json
1513
from ..._models import construct_type_unchecked
16-
from .._pydantic import is_basemodel_type, is_dataclass_like_type
14+
from .._pydantic import is_basemodel_type, to_strict_json_schema, is_dataclass_like_type
1715
from ..._exceptions import LengthFinishReasonError, ContentFilterFinishReasonError
1816
from ...types.parsed_chat import (
17+
ChatCompletion,
1918
ParsedChatCompletion,
19+
ChatCompletionMessage,
20+
ParsedFunctionToolCall,
2021
ParsedChatCompletionChoice,
2122
ParsedChatCompletionMessage,
2223
)
23-
from ...types.shared_params import FunctionDefinition
24-
from ...types.chat_completion import ChatCompletion
24+
from ...types.shared_params import ToolParam as ChatCompletionToolParam, FunctionDefinition
25+
from ...types.chat_chat_params import ResponseFormat as ResponseFormatParam
2526
from ...types.shared.tool_call import Function
26-
from ...types.chat_completion_message import ChatCompletionMessage
27-
from ...types.parsed_function_tool_call import (
28-
ParsedFunction,
29-
ParsedFunctionToolCall,
30-
)
27+
from ...types.parsed_function_tool_call import ParsedFunction
3128

3229
ResponseFormatT = TypeVar(
3330
"ResponseFormatT",
@@ -38,7 +35,7 @@
3835

3936

4037
def validate_input_tools(
41-
tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
38+
tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
4239
) -> None:
4340
if not is_given(tools):
4441
return
@@ -58,9 +55,8 @@ def validate_input_tools(
5855

5956
def parse_chat_completion(
6057
*,
61-
# response_format: type[ResponseFormatT] | chat_chat_params.ResponseFormat | NotGiven,
62-
response_format: type[ResponseFormatT] | NotGiven,
63-
input_tools: Iterable[ToolParam] | NotGiven,
58+
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
59+
input_tools: Iterable[ChatCompletionToolParam] | NotGiven,
6460
chat_completion: ChatCompletion | ParsedChatCompletion[object],
6561
) -> ParsedChatCompletion[ResponseFormatT]:
6662
if is_given(input_tools):
@@ -113,7 +109,7 @@ def parse_chat_completion(
113109
response_format=response_format,
114110
message=message,
115111
),
116-
"tool_calls": tool_calls,
112+
"tool_calls": tool_calls if tool_calls else None,
117113
},
118114
},
119115
)
@@ -131,11 +127,13 @@ def parse_chat_completion(
131127
)
132128

133129

134-
def get_input_tool_by_name(*, input_tools: list[ToolParam], name: str) -> ToolParam | None:
130+
def get_input_tool_by_name(*, input_tools: list[ChatCompletionToolParam], name: str) -> ChatCompletionToolParam | None:
135131
return next((t for t in input_tools if t.get("function", {}).get("name") == name), None)
136132

137133

138-
def parse_function_tool_arguments(*, input_tools: list[ToolParam], function: Function | ParsedFunction) -> object:
134+
def parse_function_tool_arguments(
135+
*, input_tools: list[ChatCompletionToolParam], function: Function | ParsedFunction
136+
) -> object:
139137
assert function.name is not None
140138
input_tool = get_input_tool_by_name(input_tools=input_tools, name=function.name)
141139
if not input_tool:
@@ -155,8 +153,7 @@ def parse_function_tool_arguments(*, input_tools: list[ToolParam], function: Fun
155153

156154
def maybe_parse_content(
157155
*,
158-
# response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
159-
response_format: type[ResponseFormatT] | NotGiven,
156+
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
160157
message: ChatCompletionMessage | ParsedChatCompletionMessage[object],
161158
) -> ResponseFormatT | None:
162159
if has_rich_response_format(response_format) and message.content and not message.refusal:
@@ -166,8 +163,7 @@ def maybe_parse_content(
166163

167164

168165
def solve_response_format_t(
169-
# response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
170-
response_format: type[ResponseFormatT] | NotGiven,
166+
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
171167
) -> type[ResponseFormatT]:
172168
"""Return the runtime type for the given response format.
173169
@@ -182,9 +178,8 @@ def solve_response_format_t(
182178

183179
def has_parseable_input(
184180
*,
185-
# response_format: type | ResponseFormatParam | NotGiven,
186-
response_format: type | NotGiven,
187-
input_tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
181+
response_format: type | ResponseFormatParam | NotGiven,
182+
input_tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
188183
) -> bool:
189184
if has_rich_response_format(response_format):
190185
return True
@@ -197,30 +192,27 @@ def has_parseable_input(
197192

198193

199194
def has_rich_response_format(
200-
# response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
201-
response_format: type[ResponseFormatT] | NotGiven,
195+
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
202196
) -> TypeGuard[type[ResponseFormatT]]:
203197
if not is_given(response_format):
204198
return False
205199

206-
# if is_response_format_param(response_format):
207-
# return False
200+
if is_response_format_param(response_format):
201+
return False
208202

209203
return True
210204

211205

212-
# def is_response_format_param(response_format: object) -> TypeGuard[ResponseFormatParam]:
213-
# return is_dict(response_format)
206+
def is_response_format_param(response_format: object) -> TypeGuard[ResponseFormatParam]:
207+
return is_dict(response_format)
214208

215209

216-
def is_parseable_tool(input_tool: ToolParam) -> bool:
210+
def is_parseable_tool(input_tool: ChatCompletionToolParam) -> bool:
217211
input_fn = cast(object, input_tool.get("function"))
218212
if isinstance(input_fn, PydanticFunctionTool):
219213
return True
220214

221-
# FIXME: `strict` currently missing in the schema definition
222-
# return cast(FunctionDefinition, input_fn).get("strict") or False
223-
return False
215+
return cast(FunctionDefinition, input_fn).get("strict") or False # type: ignore
224216

225217

226218
def _parse_content(response_format: type[ResponseFormatT], content: str) -> ResponseFormatT:
@@ -236,43 +228,36 @@ def _parse_content(response_format: type[ResponseFormatT], content: str) -> Resp
236228
raise TypeError(f"Unable to automatically parse response format type {response_format}")
237229

238230

239-
# def type_to_response_format_param(
240-
# response_format: type | completion_create_params.ResponseFormat | NotGiven,
241-
# ) -> ResponseFormatParam | NotGiven:
242-
# if not is_given(response_format):
243-
# return NOT_GIVEN
244-
245-
# if is_response_format_param(response_format):
246-
# return response_format
247-
248-
# # type checkers don't narrow the negation of a `TypeGuard` as it isn't
249-
# # a safe default behaviour but we know that at this point the `response_format`
250-
# # can only be a `type`
251-
# response_format = cast(type, response_format)
252-
253-
# json_schema_type: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any] | None = None
254-
255-
# if is_basemodel_type(response_format):
256-
# name = response_format.__name__
257-
# json_schema_type = response_format
258-
# elif is_dataclass_like_type(response_format):
259-
# name = response_format.__name__
260-
# json_schema_type = pydantic.TypeAdapter(response_format)
261-
# else:
262-
# raise TypeError(f"Unsupported response_format type - {response_format}")
263-
264-
265-
# return {
266-
# "type": "json_schema",
267-
# "json_schema": {
268-
# "schema": to_strict_json_schema(json_schema_type),
269-
# "name": name,
270-
# "strict": True,
271-
# },
272-
# }
273231
def type_to_response_format_param(
274-
response_format: type | NotGiven,
275-
) -> NotGiven:
276-
if is_given(response_format):
277-
raise NotImplementedError("Support for response_format is not implemented yet")
278-
return NOT_GIVEN
232+
response_format: type | ResponseFormatParam | NotGiven,
233+
) -> ResponseFormatParam | NotGiven:
234+
if not is_given(response_format):
235+
return NOT_GIVEN
236+
237+
if is_response_format_param(response_format):
238+
return response_format
239+
240+
# type checkers don't narrow the negation of a `TypeGuard` as it isn't
241+
# a safe default behaviour but we know that at this point the `response_format`
242+
# can only be a `type`
243+
response_format = cast(type, response_format)
244+
245+
json_schema_type: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any] | None = None
246+
247+
if is_basemodel_type(response_format):
248+
name = response_format.__name__
249+
json_schema_type = response_format
250+
elif is_dataclass_like_type(response_format):
251+
name = response_format.__name__
252+
json_schema_type = pydantic.TypeAdapter(response_format)
253+
else:
254+
raise TypeError(f"Unsupported response_format type - {response_format}")
255+
256+
return {
257+
"type": "json_schema",
258+
"json_schema": {
259+
"schema": to_strict_json_schema(json_schema_type),
260+
"name": name,
261+
"strict": True,
262+
},
263+
}

0 commit comments

Comments
 (0)