1313 ChatCompletionMessageToolCall ,
1414 ChatCompletionTextObject ,
1515 ChatCompletionToolMessage ,
16+ ChatCompletionToolParam ,
1617 OpenAIChatCompletionAssistantMessage ,
1718 OpenAIChatCompletionSystemMessage ,
1819 OpenAIChatCompletionUserMessage ,
@@ -44,6 +45,35 @@ class DecodedToolCall[T](typing.NamedTuple):
4445type MessageResult [T ] = tuple [Message , typing .Sequence [DecodedToolCall ], T | None ]
4546
4647
48+ @functools .cache
49+ def _param_model (tool : Tool ) -> type [pydantic .BaseModel ]:
50+ sig = inspect .signature (tool )
51+ return pydantic .create_model (
52+ "Params" ,
53+ __config__ = {"extra" : "forbid" },
54+ ** {
55+ name : Encodable .define (param .annotation ).enc
56+ for name , param in sig .parameters .items ()
57+ }, # type: ignore
58+ )
59+
60+
61+ @functools .cache
62+ def _function_model (tool : Tool ) -> ChatCompletionToolParam :
63+ response_format = litellm .utils .type_to_response_format_param (_param_model (tool ))
64+ assert response_format is not None
65+ assert tool .__default__ .__doc__ is not None
66+ return {
67+ "type" : "function" ,
68+ "function" : {
69+ "name" : tool .__name__ ,
70+ "description" : textwrap .dedent (tool .__default__ .__doc__ ),
71+ "parameters" : response_format ["json_schema" ]["schema" ],
72+ "strict" : True ,
73+ },
74+ }
75+
76+
4777def decode_tool_call (
4878 tool_call : ChatCompletionMessageToolCall ,
4979 tools : collections .abc .Mapping [str , Tool ],
@@ -56,7 +86,7 @@ def decode_tool_call(
5686 sig = inspect .signature (tool )
5787
5888 # build dict of raw encodable types U
59- raw_args = tool . param_model .model_validate_json (json_str )
89+ raw_args = _param_model ( tool ) .model_validate_json (json_str )
6090
6191 # use encoders to decode Us to python types T
6292 bound_sig : inspect .BoundArguments = sig .bind (
@@ -96,7 +126,7 @@ def call_assistant[T, U](
96126 observe/log requests.
97127
98128 """
99- tool_specs = {k : t . model for k , t in tools .items ()}
129+ tool_specs = {k : _function_model ( t ) for k , t in tools .items ()}
100130 response_model = pydantic .create_model (
101131 "Response" , value = response_format .enc , __config__ = {"extra" : "forbid" }
102132 )
0 commit comments