Skip to content

Commit 20c18ff

Browse files
authored
Fix Tuple Encoding (#568)
* Fix tuple encoding * Lint * Lint * Add more tests to reproduce original bug * Add more tests to reproduce original bug * Moving encoded_ty decoding back to decoder and update tests * Reverse SequenceEncodable to reflect Python's inheritance hierarchy * Add OpenAI Serialize test and Parameterized tests * Linting * Separate encodable logic for tuple, sequence, mutablesequence * Test for bare tuples * Lint * Add fixture * Update fixture * Minor * lint * Lint * more lint
1 parent 68d7645 commit 20c18ff

6 files changed

Lines changed: 374 additions & 91 deletions

effectful/handlers/llm/encoding.py

Lines changed: 170 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,16 @@ def deserialize(self, serialized_value: str) -> pydantic.BaseModel:
225225

226226
@dataclass
227227
class TupleEncodable[T](Encodable[T, typing.Any]):
228+
"""Encodes fixed-length heterogeneous tuples (e.g. ``tuple[int, str]``).
229+
230+
``model_cls`` is a dynamic pydantic model (``TupleItems``) with one field
231+
per position, producing an object JSON schema that OpenAI accepts
232+
(unlike the ``prefixItems`` schema from native tuple types).
233+
"""
234+
228235
base: type[T]
229236
enc: type[typing.Any]
237+
model_cls: type[pydantic.BaseModel]
230238
ctx: Mapping[str, Any]
231239
has_image: bool
232240
element_encoders: list[Encodable]
@@ -239,95 +247,123 @@ def encode(self, value: T) -> typing.Any:
239247
f"Tuple length {len(value)} does not match expected length {len(self.element_encoders)}"
240248
)
241249
return tuple(
242-
[enc.encode(elem) for enc, elem in zip(self.element_encoders, value)]
250+
enc.encode(elem) for enc, elem in zip(self.element_encoders, value)
243251
)
244252

245253
def decode(self, encoded_value: typing.Any) -> T:
246-
if len(encoded_value) != len(self.element_encoders):
254+
# Pydantic validation produces a TupleItems model instance;
255+
# extract the positional fields back into a sequence.
256+
if isinstance(encoded_value, pydantic.BaseModel):
257+
items = list(encoded_value.model_dump().values())
258+
else:
259+
items = list(encoded_value)
260+
if len(items) != len(self.element_encoders):
247261
raise ValueError(
248-
f"tuple length {len(encoded_value)} does not match expected length {len(self.element_encoders)}"
262+
f"tuple length {len(items)} does not match expected length {len(self.element_encoders)}"
249263
)
250-
decoded_elements: list[typing.Any] = [
251-
enc.decode(elem) for enc, elem in zip(self.element_encoders, encoded_value)
252-
]
253-
return typing.cast(T, tuple(decoded_elements))
264+
return typing.cast(
265+
T,
266+
tuple(enc.decode(elem) for enc, elem in zip(self.element_encoders, items)),
267+
)
254268

255269
def serialize(
256270
self, encoded_value: typing.Any
257271
) -> Sequence[OpenAIMessageContentListBlock]:
258272
if self.has_image:
259-
# If tuple contains images, serialize each element and flatten the results
260273
result: list[OpenAIMessageContentListBlock] = []
261-
if not isinstance(encoded_value, tuple):
262-
raise TypeError(f"Expected tuple, got {type(encoded_value)}")
263-
if len(encoded_value) != len(self.element_encoders):
264-
raise ValueError(
265-
f"Tuple length {len(encoded_value)} does not match expected length {len(self.element_encoders)}"
266-
)
267274
for enc, elem in zip(self.element_encoders, encoded_value):
268275
result.extend(enc.serialize(elem))
269276
return result
270-
else:
271-
# Use base serialization for non-image tuples
272-
adapter: pydantic.TypeAdapter[tuple] = pydantic.TypeAdapter(self.enc)
273-
json_str = adapter.dump_json(encoded_value).decode("utf-8")
274-
return [{"type": "text", "text": json_str}]
277+
model_instance = self.model_cls(
278+
**{f"item_{i}": v for i, v in enumerate(encoded_value)}
279+
)
280+
json_str = model_instance.model_dump_json()
281+
return [{"type": "text", "text": json_str}]
275282

276283
def deserialize(self, serialized_value: str) -> typing.Any:
277-
adapter: pydantic.TypeAdapter[tuple] = pydantic.TypeAdapter(self.enc)
278-
return adapter.validate_json(serialized_value)
284+
model = self.model_cls.model_validate_json(serialized_value)
285+
# Return raw field values (preserving nested pydantic models).
286+
# Use tuple to be compatible with SequenceEncodable (which also
287+
# produces tuples), ensuring encode idempotency via nested_type.
288+
return tuple(
289+
getattr(model, f"item_{i}") for i in range(len(self.element_encoders))
290+
)
279291

280292

281293
@dataclass
282294
class NamedTupleEncodable[T](TupleEncodable[T]):
283295
"""Tuple encodable that reconstructs the original NamedTuple type on decode."""
284296

285297
def decode(self, encoded_value: typing.Any) -> T:
286-
if len(encoded_value) != len(self.element_encoders):
298+
if isinstance(encoded_value, pydantic.BaseModel):
299+
items = list(encoded_value.model_dump().values())
300+
else:
301+
items = list(encoded_value)
302+
if len(items) != len(self.element_encoders):
287303
raise ValueError(
288-
f"tuple length {len(encoded_value)} does not match expected length {len(self.element_encoders)}"
304+
f"tuple length {len(items)} does not match expected length {len(self.element_encoders)}"
289305
)
290306
decoded_elements: list[typing.Any] = [
291-
enc.decode(elem) for enc, elem in zip(self.element_encoders, encoded_value)
307+
enc.decode(elem) for enc, elem in zip(self.element_encoders, items)
292308
]
293309
return typing.cast(T, self.base(*decoded_elements))
294310

295311

296312
@dataclass
297-
class MutableSequenceEncodable[T](Encodable[MutableSequence[T], typing.Any]):
298-
base: type[MutableSequence[T]]
313+
class SequenceEncodable[T](Encodable[Sequence[T], typing.Any]):
314+
"""Variable-length sequence encoded as a JSON array, decoded back to tuple."""
315+
316+
base: type[typing.Any]
299317
enc: type[typing.Any]
300318
ctx: Mapping[str, Any]
301319
has_image: bool
302320
element_encoder: Encodable[T, typing.Any]
303321

304-
def encode(self, value: MutableSequence[T]) -> typing.Any:
305-
if not isinstance(value, MutableSequence):
306-
raise TypeError(f"Expected MutableSequence, got {type(value)}")
307-
return [self.element_encoder.encode(elem) for elem in value]
322+
def encode(self, value: Sequence[T]) -> typing.Any:
323+
# Return a tuple so that nested_type routes back through the tuple
324+
# dispatcher, preserving encode idempotency.
325+
return tuple(self.element_encoder.encode(elem) for elem in value)
308326

309-
def decode(self, encoded_value: typing.Any) -> MutableSequence[T]:
310-
decoded_elements: list[T] = [
311-
self.element_encoder.decode(elem) for elem in encoded_value
312-
]
313-
return typing.cast(MutableSequence[T], decoded_elements)
327+
def decode(self, encoded_value: typing.Any) -> Sequence[T]:
328+
return typing.cast(
329+
Sequence[T],
330+
tuple(self.element_encoder.decode(elem) for elem in encoded_value),
331+
)
314332

315333
def serialize(
316334
self, encoded_value: typing.Any
317335
) -> Sequence[OpenAIMessageContentListBlock]:
318336
if self.has_image:
319-
# If list contains images, serialize each element and flatten the results
320337
result: list[OpenAIMessageContentListBlock] = []
321-
if not isinstance(encoded_value, MutableSequence):
322-
raise TypeError(f"Expected MutableSequence, got {type(encoded_value)}")
323338
for elem in encoded_value:
324339
result.extend(self.element_encoder.serialize(elem))
325340
return result
326-
else:
327-
# Use base serialization for non-image lists
328-
adapter = pydantic.TypeAdapter(self.enc)
329-
json_str = adapter.dump_json(encoded_value).decode("utf-8")
330-
return [{"type": "text", "text": json_str}]
341+
adapter = pydantic.TypeAdapter(self.enc)
342+
# Convert to list for pydantic serialization (enc is list[...])
343+
json_str = adapter.dump_json(list(encoded_value)).decode("utf-8")
344+
return [{"type": "text", "text": json_str}]
345+
346+
def deserialize(self, serialized_value: str) -> typing.Any:
347+
adapter = pydantic.TypeAdapter(self.enc)
348+
# validate_json returns a list; convert it to a tuple so the result
349+
# matches what SequenceEncodable.encode produces (a tuple).
350+
return tuple(adapter.validate_json(serialized_value))
351+
352+
353+
@dataclass
354+
class MutableSequenceEncodable[T](SequenceEncodable[T]):
355+
"""Mutable sequence (list) — same as SequenceEncodable but returns a list."""
356+
357+
def encode(self, value: Sequence[T]) -> typing.Any:
358+
if not isinstance(value, MutableSequence):
359+
raise TypeError(f"Expected MutableSequence, got {type(value)}")
360+
return [self.element_encoder.encode(elem) for elem in value]
361+
362+
def decode(self, encoded_value: typing.Any) -> MutableSequence[T]:
363+
decoded_elements: list[T] = [
364+
self.element_encoder.decode(elem) for elem in encoded_value
365+
]
366+
return typing.cast(MutableSequence[T], decoded_elements)
331367

332368
def deserialize(self, serialized_value: str) -> typing.Any:
333369
adapter = pydantic.TypeAdapter(self.enc)
@@ -804,64 +840,116 @@ def _encodable_image(
804840
def _encodable_tuple[T, U](
805841
ty: type[T], ctx: Mapping[str, Any] | None
806842
) -> Encodable[T, U]:
843+
"""Dispatch for ``tuple`` types.
844+
845+
* Bare ``tuple`` (no type params) or ``tuple[T, ...]``
846+
(variadic) to :class:`SequenceEncodable` (JSON array).
847+
* Named tuples (subclasses with ``_fields``) to :class:`NamedTupleEncodable`.
848+
* Finitary forms (``tuple[()]``, ``tuple[T]``, ``tuple[T1, T2]``, ...)
849+
to :class:`TupleEncodable` (JSON object with positional fields).
850+
851+
https://docs.python.org/3/library/typing.html#annotating-tuples
852+
"""
853+
807854
def _is_namedtuple_type(ty: type[Any]) -> bool:
808855
return isinstance(ty, type) and issubclass(ty, tuple) and hasattr(ty, "_fields")
809856

810857
args = typing.get_args(ty)
811858
ctx = {} if ctx is None else ctx
812859

813-
# Handle plain tuple runtime type explicitly.
814-
if ty is tuple:
815-
return typing.cast(
816-
Encodable[T, U],
817-
TupleEncodable(ty, ty, ctx, False, []),
818-
)
819-
820-
# NamedTuple handling is routed through tuple logic, but decoded back into
821-
# the concrete NamedTuple class.
822-
origin = typing.get_origin(ty)
823-
is_namedtuple = origin is None and _is_namedtuple_type(ty)
824-
if origin is None:
825-
if is_namedtuple:
860+
if typing.get_origin(ty) is None:
861+
if ty is tuple:
862+
# Bare tuple — treat as tuple[Any, ...].
863+
element_encoder = Encodable.define(typing.cast(type, typing.Any), ctx)
864+
encoded_ty = typing.cast(type[typing.Any], list[element_encoder.enc]) # type: ignore
865+
return typing.cast(
866+
Encodable[T, U],
867+
SequenceEncodable(ty, encoded_ty, ctx, False, element_encoder),
868+
)
869+
if _is_namedtuple_type(ty):
870+
# NamedTuple — route through tuple logic but decode back into
871+
# the concrete NamedTuple class.
826872
hints = typing.get_type_hints(ty)
827873
tuple_field_types: list[type[Any]] = list(hints.values())
828874
if not tuple_field_types:
829875
tuple_field_types = [typing.Any] * len(getattr(ty, "_fields", ()))
830-
else:
831-
tuple_field_types = []
832-
else:
833-
tuple_field_types = list(args)
834-
835-
if not tuple_field_types:
836-
# Non-parameterized tuple subclasses still use object fallback.
837-
if not is_namedtuple:
838-
return _encodable_object(ty, ctx)
839-
# Empty namedtuple; keep tuple identity behavior.
840-
return typing.cast(Encodable[T, U], NamedTupleEncodable(ty, ty, ctx, False, []))
841-
842-
# Handle empty tuple annotation (tuple[()]).
843-
if tuple_field_types == [()] or args == ((),):
844-
return TupleEncodable(ty, ty, ctx, False, [])
845-
846-
element_encoders = [Encodable.define(arg, ctx) for arg in tuple_field_types]
847-
has_image = any(arg is Image.Image for arg in tuple_field_types)
848-
encoded_ty: type[typing.Any] = typing.cast(
849-
type[typing.Any],
850-
tuple[*(enc.enc for enc in element_encoders)], # type: ignore
851-
)
876+
if not tuple_field_types:
877+
# Empty namedtuple.
878+
empty_model = pydantic.create_model(
879+
"TupleItems", __config__={"extra": "forbid"}
880+
)
881+
return typing.cast(
882+
Encodable[T, U],
883+
NamedTupleEncodable(ty, empty_model, empty_model, ctx, False, []),
884+
)
885+
element_encoders = [Encodable.define(arg, ctx) for arg in tuple_field_types]
886+
has_image = any(arg is Image.Image for arg in tuple_field_types)
887+
model_cls = pydantic.create_model( # type: ignore[call-overload]
888+
"TupleItems",
889+
__config__={"extra": "forbid"},
890+
**{
891+
f"item_{i}": (enc.enc, ...)
892+
for i, enc in enumerate(element_encoders)
893+
},
894+
)
895+
return typing.cast(
896+
Encodable[T, U],
897+
NamedTupleEncodable(
898+
ty, model_cls, model_cls, ctx, has_image, element_encoders
899+
),
900+
)
901+
# Other tuple subclass — delegate to object.
902+
return _encodable_object(ty, ctx)
852903

853-
if is_namedtuple:
904+
# tuple[T, ...] — variable-length, encode as JSON array.
905+
if len(args) == 2 and args[1] is Ellipsis:
906+
element_ty = args[0]
907+
element_encoder = Encodable.define(element_ty, ctx)
908+
has_image = element_ty is Image.Image
909+
encoded_ty = typing.cast(type[typing.Any], list[element_encoder.enc]) # type: ignore
854910
return typing.cast(
855911
Encodable[T, U],
856-
NamedTupleEncodable(ty, encoded_ty, ctx, has_image, element_encoders),
912+
SequenceEncodable(ty, encoded_ty, ctx, has_image, element_encoder),
857913
)
858914

859-
if origin is None:
915+
# Finitary tuple — fixed-length positional struct.
916+
# Build a pydantic model with item_0, item_1, ... fields so the JSON
917+
# schema uses "properties"/"required" (accepted by OpenAI), not
918+
# "prefixItems" (rejected by OpenAI).
919+
effective_args = [] if (not args or args == ((),)) else list(args)
920+
element_encoders = [Encodable.define(arg, ctx) for arg in effective_args]
921+
has_image = any(arg is Image.Image for arg in effective_args)
922+
model_cls = pydantic.create_model( # type: ignore[call-overload]
923+
"TupleItems",
924+
__config__={"extra": "forbid"},
925+
**{f"item_{i}": (enc.enc, ...) for i, enc in enumerate(element_encoders)},
926+
)
927+
928+
return typing.cast(
929+
Encodable[T, U],
930+
TupleEncodable(ty, model_cls, model_cls, ctx, has_image, element_encoders),
931+
)
932+
933+
934+
@Encodable.define.register(Sequence)
935+
def _encodable_sequence[T, U](
936+
ty: type[Sequence[T]], ctx: Mapping[str, Any] | None
937+
) -> Encodable[T, U]:
938+
"""Dispatch for ``Sequence[T]`` — immutable variable-length sequence."""
939+
args = typing.get_args(ty)
940+
ctx = {} if ctx is None else ctx
941+
942+
if not args:
860943
return _encodable_object(ty, ctx)
861944

945+
element_ty = args[0]
946+
element_encoder = Encodable.define(element_ty, ctx)
947+
has_image = element_ty is Image.Image
948+
encoded_ty = typing.cast(type[typing.Any], list[element_encoder.enc]) # type: ignore
949+
862950
return typing.cast(
863951
Encodable[T, U],
864-
TupleEncodable(ty, encoded_ty, ctx, has_image, element_encoders),
952+
SequenceEncodable(ty, encoded_ty, ctx, has_image, element_encoder),
865953
)
866954

867955

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
{
2+
"id": "chatcmpl-bare-tuple-param-fixture",
3+
"created": 1769812600,
4+
"model": "gpt-4o-mini-2024-07-18",
5+
"object": "chat.completion",
6+
"system_fingerprint": "fp_fixture",
7+
"choices": [
8+
{
9+
"finish_reason": "stop",
10+
"index": 0,
11+
"message": {
12+
"content": "{\"value\":\"The items include apples, bananas, and cherries.\"}",
13+
"role": "assistant",
14+
"tool_calls": null,
15+
"function_call": null,
16+
"provider_specific_fields": {
17+
"refusal": null
18+
},
19+
"annotations": []
20+
},
21+
"provider_specific_fields": {}
22+
}
23+
],
24+
"usage": {
25+
"completion_tokens": 15,
26+
"prompt_tokens": 100,
27+
"total_tokens": 115,
28+
"completion_tokens_details": {
29+
"accepted_prediction_tokens": 0,
30+
"audio_tokens": 0,
31+
"reasoning_tokens": 0,
32+
"rejected_prediction_tokens": 0,
33+
"text_tokens": null,
34+
"image_tokens": null
35+
},
36+
"prompt_tokens_details": {
37+
"audio_tokens": 0,
38+
"cached_tokens": 0,
39+
"text_tokens": null,
40+
"image_tokens": null
41+
}
42+
},
43+
"service_tier": "default"
44+
}

0 commit comments

Comments
 (0)