11import os
2- from typing import Iterable
3-
4- import pandas as pd
2+ from typing import Iterable , Tuple
53
64from graphgen .bases import BaseGraphStorage , BaseOperator , BaseTokenizer
75from graphgen .common import init_storage
@@ -24,7 +22,9 @@ def __init__(
2422 graph_backend : str = "kuzu" ,
2523 ** partition_kwargs ,
2624 ):
27- super ().__init__ (working_dir = working_dir , op_name = "partition" )
25+ super ().__init__ (
26+ working_dir = working_dir , kv_backend = kv_backend , op_name = "partition"
27+ )
2828 self .kg_instance : BaseGraphStorage = init_storage (
2929 backend = graph_backend ,
3030 working_dir = working_dir ,
@@ -55,7 +55,7 @@ def __init__(
5555 else :
5656 raise ValueError (f"Unsupported partition method: { method } " )
5757
58- def process (self , batch : pd . DataFrame ) -> Iterable [pd . DataFrame ]:
58+ def process (self , batch : list ) -> Tuple [ Iterable [list ], dict ]:
5959 # this operator does not consume any batch data
6060 # but for compatibility we keep the interface
6161 self .kg_instance .reload ()
@@ -64,19 +64,22 @@ def process(self, batch: pd.DataFrame) -> Iterable[pd.DataFrame]:
6464 g = self .kg_instance , ** self .method_params
6565 )
6666
67- count = 0
68- for community in communities :
69- count += 1
70- batch = self .partitioner .community2batch (community , g = self .kg_instance )
71- # batch = self._attach_additional_data_to_node(batch)
67+ def generator ():
68+ count = 0
69+ for community in communities :
70+ count += 1
71+ batch = self .partitioner .community2batch (community , g = self .kg_instance )
72+ # batch = self._attach_additional_data_to_node(batch)
73+
74+ result = {
75+ "nodes" : batch [0 ],
76+ "edges" : batch [1 ],
77+ }
78+ result ["_trace_id" ] = self .get_trace_id (result )
79+ yield result
80+ logger .info ("Total communities partitioned: %d" , count )
7281
73- result = {
74- "nodes" : batch [0 ],
75- "edges" : batch [1 ],
76- }
77- result ["_trace_id" ] = self .generate_trace_id (result )
78- yield pd .DataFrame ([result ])
79- logger .info ("Total communities partitioned: %d" , count )
82+ return generator (), {}
8083
8184 # def _attach_additional_data_to_node(self, batch: tuple) -> tuple:
8285 # """
0 commit comments