@@ -858,30 +858,36 @@ def __init__(
858858 ctx : dict [str , Any ] | None = None ,
859859 parent : str | None = None ,
860860 seq_col : str | None = None ,
861+ ctx_to_ref : dict [str , str ] | None = None ,
861862 ** kwargs ,
862863 ) -> None :
863864 super ().__init__ (** kwargs )
864865 self .parent = parent
865866 self .seq_col_ref = seq_col
867+ self .ctx_to_ref = ctx_to_ref
866868
867869 # Load transformers
868870 assert seq
869- if not ctx :
870- ctx = seq
871- ctx_kwargs = ctx .copy ()
872- ctx_type = ctx_kwargs .pop ("type" )
873- self .ctx = get_module_dict (TransformerFactory , modules )[
874- cast (str , ctx_type )
875- ].build (** ctx_kwargs )
876- assert isinstance (self .ctx , Transformer )
877871
878872 seq_kwargs = seq .copy ()
879873 seq_type = seq_kwargs .pop ("type" )
874+ seq_kwargs ["nullable" ] = True
880875 self .seq = get_module_dict (TransformerFactory , modules )[
881876 cast (str , seq_type )
882877 ].build (** seq_kwargs )
883878 assert isinstance (self .seq , RefTransformer )
884879
880+ if ctx is not None :
881+ self .dual = True
882+ ctx_kwargs = ctx .copy ()
883+ ctx_type = ctx_kwargs .pop ("type" )
884+ self .ctx = get_module_dict (TransformerFactory , modules )[
885+ cast (str , ctx_type )
886+ ].build (** ctx_kwargs )
887+ assert isinstance (self .ctx , Transformer )
888+ else :
889+ self .dual = False
890+
885891 def fit (
886892 self ,
887893 table : str ,
@@ -904,10 +910,6 @@ def fit(
904910 if not self .parent :
905911 # Infering parent through references
906912 self .parent = next (iter (ref ))
907- # Process references
908- # if ref:
909- # self.ref_table = next(iter(ref))
910- # self.ref_col = cast(str, next(iter(ref[self.ref_table].keys())))
911913
912914 assert (
913915 self .parent
@@ -927,18 +929,39 @@ def fit(
927929 seq = _calculate_seq (seq_col , ids , self .parent , self .col_seq )
928930 self .max_len = cast (int , seq .max ()) + 1
929931
932+ if self .dual :
933+ self ._dual_fit (self .parent , data , ref , ids , seq )
934+ else :
935+ self ._single_fit (self .parent , data , ref , ids , seq )
936+
937+ # If a seq_val was not provided, assume seq was also none and
938+ # become the sequencer
939+ if seq_val is None :
940+ return SeqValue (self .col_seq , self .parent ), cast (pd .Series , seq )
941+
942+ def _dual_fit (
943+ self ,
944+ parent : str ,
945+ data : pd .Series | pd .DataFrame ,
946+ ref : dict [str , pd .DataFrame ],
947+ ids : pd .DataFrame ,
948+ seq : pd .Series ,
949+ ):
930950 ctx_data = (
931- ids .join (data [seq == 0 ], how = "right" )
932- .drop_duplicates (subset = [self .parent ])
933- .set_index (self .parent )
951+ ids [[parent ]]
952+ .join (data [seq == 0 ], how = "right" )
953+ .drop_duplicates (subset = [parent ])
954+ .set_index (parent )
934955 )
935956 if isinstance (data , pd .Series ):
936957 ctx_data = ctx_data [next (iter (ctx_data ))]
937958 if ref :
938- ctx_ref = ids .drop_duplicates (subset = [self . parent ])
959+ ctx_ref = ids .drop_duplicates (subset = [parent ])
939960 for name , ref_table in ref .items ():
940961 ctx_ref = ctx_ref .join (ref_table , on = name , how = "left" )
941- ctx_ref = ctx_ref .set_index (self .parent )
962+ ctx_ref = ctx_ref .set_index (parent ).drop (
963+ columns = [d for d in ids .columns if d != parent ]
964+ )
942965
943966 if ctx_ref .shape [1 ] == 1 :
944967 ctx_ref = ctx_ref [next (iter (ctx_ref ))]
@@ -951,13 +974,41 @@ def fit(
951974 self .ctx .fit (ctx_data )
952975
953976 # Data series is all rows where seq > 0 (skip initial)
954- ref_df = _backref_cols (ids , seq , data , self . parent )
955- self .seq .fit (data , ref_df )
977+ ref_df = _backref_cols (ids , seq , data , parent )
978+ self .seq .fit (data [ seq > 0 ] , ref_df )
956979
957- # If a seq_val was not provided, assume seq was also none and
958- # become the sequencer
959- if seq_val is None :
960- return SeqValue (self .col_seq , self .parent ), cast (pd .Series , seq )
980+ def _single_fit (
981+ self ,
982+ parent : str ,
983+ data : pd .Series | pd .DataFrame ,
984+ ref : dict [str , pd .DataFrame ],
985+ ids : pd .DataFrame ,
986+ seq : pd .Series ,
987+ ):
988+ ref_df = _backref_cols (ids , seq , data , parent )
989+ if ref :
990+ ctx_ref = ids [seq == 0 ].drop_duplicates (subset = [self .parent ])
991+ for name , ref_table in ref .items ():
992+ ctx_ref = ctx_ref .join (ref_table , on = name , how = "left" )
993+ ctx_ref = ctx_ref .drop (columns = ids .columns )
994+
995+ if ctx_ref .shape [1 ] == 1 :
996+ ctx_ref = ctx_ref [next (iter (ctx_ref ))]
997+
998+ if isinstance (ref_df , pd .Series ) and isinstance (ctx_ref , pd .Series ):
999+ ref_df = pd .concat ([ctx_ref , ref_df ])
1000+ elif isinstance (ref_df , pd .DataFrame ) and isinstance (ctx_ref , pd .DataFrame ):
1001+ if self .ctx_to_ref :
1002+ ctx_ref = ctx_ref .rename (columns = self .ctx_to_ref )
1003+ ref_df = pd .concat ([ctx_ref , ref_df ], axis = 0 )
1004+ assert (
1005+ ref_df .shape [1 ] == ctx_ref .shape [1 ]
1006+ ), f"Parent columns not joined correctly to reference ones. If they have different names, pass in `ctx_to_ref` with names mapping them to parents"
1007+ else :
1008+ assert (
1009+ False
1010+ ), "fixme: mismatched reference column counts. If single column transformer, both should be series, otherwise both should be dataframes"
1011+ self .seq .fit (data , ref_df )
9611012
9621013 def reduce (self , other : "SeqTransformerWrapper" ):
9631014 self .ctx .reduce (other )
@@ -986,6 +1037,39 @@ def transform(
9861037 else :
9871038 assert seq is not None
9881039
1040+ if self .dual :
1041+ enc , ctx = self ._dual_trn (parent , data , ref , ids , seq )
1042+ else :
1043+ enc , ctx = self ._single_trn (parent , data , ref , ids , seq )
1044+
1045+ if self .generate_seq :
1046+ return (
1047+ pd .concat ([enc , seq ], axis = 1 ),
1048+ {
1049+ parent : pd .concat (
1050+ [
1051+ ctx ,
1052+ ids .join (seq )
1053+ .groupby (self .parent )[cast (str , seq .name )]
1054+ .max ()
1055+ .rename (self .col_n )
1056+ + 1 ,
1057+ ],
1058+ axis = 1 ,
1059+ )
1060+ },
1061+ seq ,
1062+ )
1063+ return enc , {parent : ctx }
1064+
1065+ def _dual_trn (
1066+ self ,
1067+ parent : str ,
1068+ data : pd .Series | pd .DataFrame ,
1069+ ref : dict [str , pd .DataFrame ],
1070+ ids : pd .DataFrame ,
1071+ seq : pd .Series ,
1072+ ):
9891073 ctx_data = (
9901074 ids [[parent ]]
9911075 .join (data [seq == 0 ], how = "right" )
@@ -995,10 +1079,12 @@ def transform(
9951079 if ctx_data .shape [1 ] == 1 :
9961080 ctx_data = ctx_data [next (iter (ctx_data ))]
9971081 if ref :
998- ctx_ref = ids [[ parent ]] .drop_duplicates (subset = [parent ])
1082+ ctx_ref = ids .drop_duplicates (subset = [parent ])
9991083 for name , ref_table in ref .items ():
10001084 ctx_ref = ctx_ref .join (ref_table , on = name , how = "left" )
1001- ctx_ref = ctx_ref .set_index (parent )
1085+ ctx_ref = ctx_ref .set_index (parent ).drop (
1086+ columns = [d for d in ids .columns if d != parent ]
1087+ )
10021088
10031089 if ctx_ref .shape [1 ] == 1 :
10041090 ctx_ref = ctx_ref [next (iter (ctx_ref ))]
@@ -1020,27 +1106,97 @@ def transform(
10201106 if is_float_dtype (d ):
10211107 enc .loc [seq == 0 , k ] = np .nan
10221108
1023- if self .generate_seq :
1024- return (
1025- pd .concat ([enc , seq ], axis = 1 ),
1026- {
1027- parent : pd .concat (
1028- [
1029- ctx ,
1030- ids .join (seq )
1031- .groupby (self .parent )[cast (str , seq .name )]
1032- .max ()
1033- .rename (self .col_n )
1034- + 1 ,
1035- ],
1036- axis = 1 ,
1109+ return enc , ctx
1110+
1111+ def _single_trn (
1112+ self ,
1113+ parent : str ,
1114+ data : pd .Series | pd .DataFrame ,
1115+ ref : dict [str , pd .DataFrame ],
1116+ ids : pd .DataFrame ,
1117+ seq : pd .Series ,
1118+ ):
1119+ ref_df = _backref_cols (ids , seq , data , parent )
1120+ if ref :
1121+ ctx_ref = ids [seq == 0 ].drop_duplicates (subset = [self .parent ])
1122+ for name , ref_table in ref .items ():
1123+ ctx_ref = ctx_ref .join (ref_table , on = name , how = "left" )
1124+ ctx_ref = ctx_ref .drop (columns = ids .columns )
1125+
1126+ if ctx_ref .shape [1 ] == 1 :
1127+ ctx_ref = ctx_ref [next (iter (ctx_ref ))]
1128+
1129+ if isinstance (ref_df , pd .Series ) and isinstance (ctx_ref , pd .Series ):
1130+ ref_df = pd .concat ([ctx_ref , ref_df ])
1131+ elif isinstance (ref_df , pd .DataFrame ) and isinstance (ctx_ref , pd .DataFrame ):
1132+ if self .ctx_to_ref :
1133+ ctx_ref = ctx_ref .rename (columns = self .ctx_to_ref )
1134+ ref_df = pd .concat ([ctx_ref , ref_df ], axis = 0 )
1135+ assert (
1136+ ref_df .shape [1 ] == ctx_ref .shape [1 ]
1137+ ), f"Parent columns not joined correctly to reference ones. If they have different names, pass in `ctx_to_ref` with names mapping them to parents"
1138+ else :
1139+ assert (
1140+ False
1141+ ), "fixme: mismatched reference column counts. If single column transformer, both should be series, otherwise both should be dataframes"
1142+
1143+ return self .seq .transform (data , ref_df ), pd .DataFrame ()
1144+
1145+ def _single_reverse (
1146+ self ,
1147+ data : pd .DataFrame ,
1148+ ctx : dict [str , pd .DataFrame ],
1149+ ref : dict [str , pd .DataFrame ],
1150+ ids : pd .DataFrame ,
1151+ ) -> pd .DataFrame :
1152+ seq = data [self .col_seq ]
1153+ parent = cast (str , self .parent )
1154+
1155+ if ref :
1156+ ctx_ref = ids [seq == 0 ].drop_duplicates (subset = [self .parent ])
1157+ for name , ref_table in ref .items ():
1158+ ctx_ref = ctx_ref .join (ref_table , on = name , how = "left" )
1159+ ctx_ref = ctx_ref .drop (columns = ids .columns )
1160+
1161+ if self .ctx_to_ref :
1162+ ctx_ref = ctx_ref .rename (columns = self .ctx_to_ref )
1163+
1164+ if ctx_ref .shape [1 ] == 1 :
1165+ ctx_ref = ctx_ref [next (iter (ctx_ref ))]
1166+ else :
1167+ ctx_ref = None
1168+
1169+ # Data series is all rows where seq > 0 (skip initial)
1170+ out = []
1171+ for i in range (self .max_len ):
1172+ seq_mask = seq == i
1173+ data_df = data [seq_mask ]
1174+ if not len (data_df ):
1175+ break
1176+
1177+ if i > 0 :
1178+ ref_df = (
1179+ ids .loc [data_df .index ]
1180+ .join (
1181+ ids .join (out [- 1 ], how = "right" ).set_index (parent ),
1182+ on = parent ,
1183+ how = "left" ,
10371184 )
1038- },
1039- seq ,
1040- )
1041- return enc , { parent : ctx }
1185+ . drop ( columns = parent )
1186+ )
1187+ if ref_df . shape [ 1 ] == 1 :
1188+ ref_df = ref_df [ next ( iter ( ref_df ))]
10421189
1043- def reverse (
1190+ assert len (ref_df ) == len (
1191+ data_df
1192+ ), "fixme: experimental, there is a join error."
1193+ else :
1194+ ref_df = ctx_ref
1195+ out .append (pd .DataFrame (self .seq .reverse (data_df , ref_df )))
1196+
1197+ return pd .concat (out , axis = 0 )
1198+
1199+ def _dual_reverse (
10441200 self ,
10451201 data : pd .DataFrame ,
10461202 ctx : dict [str , pd .DataFrame ],
@@ -1050,15 +1206,17 @@ def reverse(
10501206 seq = data [self .col_seq ]
10511207 parent = cast (str , self .parent )
10521208
1053- ctx_data = ids .drop_duplicates (subset = [self . parent ])
1209+ ctx_data = ids .drop_duplicates (subset = [parent ])
10541210 for name , ctx_table in ctx .items ():
10551211 ctx_data = ctx_data .join (ctx_table , on = name , how = "left" )
1056- ctx_data = ctx_data .set_index (self . parent )
1212+ ctx_data = ctx_data .set_index (parent )
10571213 if ref :
1058- ctx_ref = ids .drop_duplicates (subset = [self . parent ])
1214+ ctx_ref = ids .drop_duplicates (subset = [parent ])
10591215 for name , ref_table in ref .items ():
10601216 ctx_ref = ctx_ref .join (ref_table , on = name , how = "left" )
1061- ctx_ref = ctx_ref .set_index (self .parent )
1217+ ctx_ref = ctx_ref .set_index (parent ).drop (
1218+ columns = [d for d in ids .columns if d != parent ]
1219+ )
10621220
10631221 if ctx_ref .shape [1 ] == 1 :
10641222 ctx_ref = ctx_ref [next (iter (ctx_ref ))]
@@ -1100,13 +1258,25 @@ def reverse(
11001258
11011259 return pd .concat (out , axis = 0 )
11021260
1261+ def reverse (
1262+ self ,
1263+ data : pd .DataFrame ,
1264+ ctx : dict [str , pd .DataFrame ],
1265+ ref : dict [str , pd .DataFrame ],
1266+ ids : pd .DataFrame ,
1267+ ) -> pd .DataFrame :
1268+ if self .dual :
1269+ return self ._dual_reverse (data , ctx , ref , ids )
1270+ else :
1271+ return self ._single_reverse (data , ctx , ref , ids )
1272+
11031273 def get_attributes (self ) -> tuple [Attributes , dict [str , Attributes ]]:
11041274 return {
11051275 self .col_seq : SeqAttribute (self .col_seq , cast (str , self .parent )),
11061276 ** self .seq .get_attributes (),
11071277 }, {
11081278 cast (str , self .parent ): {
1109- ** self .ctx .get_attributes (),
1279+ ** ( self .ctx .get_attributes () if self . dual else {} ),
11101280 self .col_n : GenAttribute (self .col_n , self .table , self .max_len ),
11111281 }
11121282 }
0 commit comments