Skip to content

Commit 6bb2b13

Browse files
committed
moved model and param model back to internals of completions
1 parent 43e5b78 commit 6bb2b13

2 files changed

Lines changed: 32 additions & 36 deletions

File tree

effectful/handlers/llm/completions.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
ChatCompletionMessageToolCall,
1414
ChatCompletionTextObject,
1515
ChatCompletionToolMessage,
16+
ChatCompletionToolParam,
1617
OpenAIChatCompletionAssistantMessage,
1718
OpenAIChatCompletionSystemMessage,
1819
OpenAIChatCompletionUserMessage,
@@ -44,6 +45,35 @@ class DecodedToolCall[T](typing.NamedTuple):
4445
type MessageResult[T] = tuple[Message, typing.Sequence[DecodedToolCall], T | None]
4546

4647

48+
@functools.cache
49+
def _param_model(tool: Tool) -> type[pydantic.BaseModel]:
50+
sig = inspect.signature(tool)
51+
return pydantic.create_model(
52+
"Params",
53+
__config__={"extra": "forbid"},
54+
**{
55+
name: Encodable.define(param.annotation).enc
56+
for name, param in sig.parameters.items()
57+
}, # type: ignore
58+
)
59+
60+
61+
@functools.cache
62+
def _function_model(tool: Tool) -> ChatCompletionToolParam:
63+
response_format = litellm.utils.type_to_response_format_param(_param_model(tool))
64+
assert response_format is not None
65+
assert tool.__default__.__doc__ is not None
66+
return {
67+
"type": "function",
68+
"function": {
69+
"name": tool.__name__,
70+
"description": textwrap.dedent(tool.__default__.__doc__),
71+
"parameters": response_format["json_schema"]["schema"],
72+
"strict": True,
73+
},
74+
}
75+
76+
4777
def decode_tool_call(
4878
tool_call: ChatCompletionMessageToolCall,
4979
tools: collections.abc.Mapping[str, Tool],
@@ -56,7 +86,7 @@ def decode_tool_call(
5686
sig = inspect.signature(tool)
5787

5888
# build dict of raw encodable types U
59-
raw_args = tool.param_model.model_validate_json(json_str)
89+
raw_args = _param_model(tool).model_validate_json(json_str)
6090

6191
# use encoders to decode Us to python types T
6292
bound_sig: inspect.BoundArguments = sig.bind(
@@ -96,7 +126,7 @@ def call_assistant[T, U](
96126
observe/log requests.
97127
98128
"""
99-
tool_specs = {k: t.model for k, t in tools.items()}
129+
tool_specs = {k: _function_model(t) for k, t in tools.items()}
100130
response_model = pydantic.create_model(
101131
"Response", value=response_format.enc, __config__={"extra": "forbid"}
102132
)

effectful/handlers/llm/template.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,11 @@
1-
import functools
21
import inspect
3-
import textwrap
42
import types
53
import typing
64
from collections import ChainMap
75
from collections.abc import Callable, Mapping, MutableMapping
86
from dataclasses import dataclass
97
from typing import Annotated, Any
108

11-
import litellm
12-
import pydantic
13-
from litellm import ChatCompletionToolParam
14-
15-
from effectful.handlers.llm.encoding import Encodable
169
from effectful.ops.types import INSTANCE_OP_PREFIX, Annotation, Operation
1710

1811

@@ -102,33 +95,6 @@ def __init__(
10295
signature = IsRecursive.infer_annotations(signature)
10396
super().__init__(signature, name, default)
10497

105-
@functools.cached_property
106-
def param_model(self) -> type[pydantic.BaseModel]:
107-
sig = inspect.signature(self)
108-
return pydantic.create_model(
109-
"Params",
110-
__config__={"extra": "forbid"},
111-
**{
112-
name: Encodable.define(param.annotation).enc
113-
for name, param in sig.parameters.items()
114-
}, # type: ignore
115-
)
116-
117-
@functools.cached_property
118-
def model(self) -> ChatCompletionToolParam:
119-
response_format = litellm.utils.type_to_response_format_param(self.param_model)
120-
assert response_format is not None
121-
assert self.__default__.__doc__ is not None
122-
return {
123-
"type": "function",
124-
"function": {
125-
"name": self.__name__,
126-
"description": textwrap.dedent(self.__default__.__doc__),
127-
"parameters": response_format["json_schema"]["schema"],
128-
"strict": True,
129-
},
130-
}
131-
13298
@classmethod
13399
def define(cls, *args, **kwargs) -> "Tool[P, T]":
134100
"""Define a tool.

0 commit comments

Comments
 (0)