@@ -62,6 +62,7 @@ class TypeTag(StrEnum):
6262 TUPLE = "t"
6363 LIST = "l"
6464 DICT = "m"
65+ BATCH_RESULT = "br"
6566
6667
6768@dataclass (frozen = True )
@@ -206,7 +207,18 @@ def dispatcher(self):
206207
207208 def encode (self , obj : Any ) -> EncodedValue :
208209 """Encode container using dispatcher for recursive elements."""
210+ # Import here to avoid circular dependency
211+ from aws_durable_execution_sdk_python .concurrency import (
212+ BatchResult ,
213+ ) # noqa: PLC0415
214+
209215 match obj :
216+ case BatchResult ():
217+ # Encode BatchResult as dict with special tag
218+ return EncodedValue (
219+ TypeTag .BATCH_RESULT ,
220+ self ._wrap (obj .to_dict (), self .dispatcher ).value ,
221+ )
210222 case list ():
211223 return EncodedValue (
212224 TypeTag .LIST , [self ._wrap (v , self .dispatcher ) for v in obj ]
@@ -230,7 +242,16 @@ def encode(self, obj: Any) -> EncodedValue:
230242
231243 def decode (self , tag : TypeTag , value : Any ) -> Any :
232244 """Decode container using dispatcher for recursive elements."""
245+ # Import here to avoid circular dependency
246+ from aws_durable_execution_sdk_python .concurrency import (
247+ BatchResult ,
248+ ) # noqa: PLC0415
249+
233250 match tag :
251+ case TypeTag .BATCH_RESULT :
252+ # Decode as dict (handles all recursive unwrapping) then reconstruct
253+ decoded_dict = self .decode (TypeTag .DICT , value )
254+ return BatchResult .from_dict (decoded_dict )
234255 case TypeTag .LIST :
235256 if not isinstance (value , list ):
236257 msg = f"Expected list, got { type (value )} "
@@ -281,6 +302,9 @@ def __init__(self):
281302 self .container_codec .set_dispatcher (self )
282303
283304 def encode (self , obj : Any ) -> EncodedValue :
305+ # Import here to avoid circular dependency
306+ from aws_durable_execution_sdk_python .concurrency import BatchResult
307+
284308 match obj :
285309 case None | str () | bool () | int () | float ():
286310 return self .primitive_codec .encode (obj )
@@ -292,7 +316,7 @@ def encode(self, obj: Any) -> EncodedValue:
292316 return self .decimal_codec .encode (obj )
293317 case datetime () | date ():
294318 return self .datetime_codec .encode (obj )
295- case list () | tuple () | dict ():
319+ case BatchResult () | list () | tuple () | dict ():
296320 return self .container_codec .encode (obj )
297321 case _:
298322 msg = f"Unsupported type: { type (obj )} "
@@ -301,11 +325,7 @@ def encode(self, obj: Any) -> EncodedValue:
301325 def decode (self , tag : TypeTag , value : Any ) -> Any :
302326 match tag :
303327 case (
304- TypeTag .NONE
305- | TypeTag .STR
306- | TypeTag .BOOL
307- | TypeTag .INT
308- | TypeTag .FLOAT
328+ TypeTag .NONE | TypeTag .STR | TypeTag .BOOL | TypeTag .INT | TypeTag .FLOAT
309329 ):
310330 return self .primitive_codec .decode (tag , value )
311331 case TypeTag .BYTES :
@@ -316,7 +336,7 @@ def decode(self, tag: TypeTag, value: Any) -> Any:
316336 return self .decimal_codec .decode (tag , value )
317337 case TypeTag .DATETIME | TypeTag .DATE :
318338 return self .datetime_codec .decode (tag , value )
319- case TypeTag .LIST | TypeTag .TUPLE | TypeTag .DICT :
339+ case TypeTag .BATCH_RESULT | TypeTag . LIST | TypeTag .TUPLE | TypeTag .DICT :
320340 return self .container_codec .decode (tag , value )
321341 case _:
322342 msg = f"Unknown type tag: { tag } "
0 commit comments