11import inspect
22import os
33from abc import ABC , abstractmethod
4- from typing import Iterable , Union , Tuple
4+ from typing import Iterable , Tuple , Union
55
6+ import numpy as np
67import pandas as pd
78import ray
89
910
11+ def convert_to_serializable (obj ):
12+ if isinstance (obj , np .ndarray ):
13+ return obj .tolist ()
14+ if isinstance (obj , np .generic ):
15+ return obj .item ()
16+ if isinstance (obj , dict ):
17+ return {k : convert_to_serializable (v ) for k , v in obj .items ()}
18+ if isinstance (obj , list ):
19+ return [convert_to_serializable (v ) for v in obj ]
20+ return obj
21+
22+
1023class BaseOperator (ABC ):
1124 def __init__ (
1225 self ,
@@ -21,6 +34,7 @@ def __init__(
2134 log_dir = os .path .join (working_dir , "logs" )
2235 self .op_name = op_name or self .__class__ .__name__
2336 self .working_dir = working_dir
37+ self .kv_backend = kv_backend
2438 self .kv_storage = init_storage (
2539 backend = kv_backend , working_dir = working_dir , namespace = self .op_name
2640 )
@@ -118,6 +132,9 @@ def split(self, batch: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
118132 return to_process , pd .DataFrame (recovered_chunks )
119133
120134 def store (self , results : list , meta_update : dict ):
135+ results = convert_to_serializable (results )
136+ meta_update = convert_to_serializable (meta_update )
137+
121138 batch = {res ["_trace_id" ]: res for res in results }
122139 self .kv_storage .upsert (batch )
123140
0 commit comments