Skip to content

Commit 2ea5c83

Browse files
committed
add support for single transformer in seq wrapper
1 parent 9412597 commit 2ea5c83

1 file changed

Lines changed: 219 additions & 49 deletions

File tree

src/pasteur/table.py

Lines changed: 219 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -858,30 +858,36 @@ def __init__(
858858
ctx: dict[str, Any] | None = None,
859859
parent: str | None = None,
860860
seq_col: str | None = None,
861+
ctx_to_ref: dict[str, str] | None = None,
861862
**kwargs,
862863
) -> None:
863864
super().__init__(**kwargs)
864865
self.parent = parent
865866
self.seq_col_ref = seq_col
867+
self.ctx_to_ref = ctx_to_ref
866868

867869
# Load transformers
868870
assert seq
869-
if not ctx:
870-
ctx = seq
871-
ctx_kwargs = ctx.copy()
872-
ctx_type = ctx_kwargs.pop("type")
873-
self.ctx = get_module_dict(TransformerFactory, modules)[
874-
cast(str, ctx_type)
875-
].build(**ctx_kwargs)
876-
assert isinstance(self.ctx, Transformer)
877871

878872
seq_kwargs = seq.copy()
879873
seq_type = seq_kwargs.pop("type")
874+
seq_kwargs["nullable"] = True
880875
self.seq = get_module_dict(TransformerFactory, modules)[
881876
cast(str, seq_type)
882877
].build(**seq_kwargs)
883878
assert isinstance(self.seq, RefTransformer)
884879

880+
if ctx is not None:
881+
self.dual = True
882+
ctx_kwargs = ctx.copy()
883+
ctx_type = ctx_kwargs.pop("type")
884+
self.ctx = get_module_dict(TransformerFactory, modules)[
885+
cast(str, ctx_type)
886+
].build(**ctx_kwargs)
887+
assert isinstance(self.ctx, Transformer)
888+
else:
889+
self.dual = False
890+
885891
def fit(
886892
self,
887893
table: str,
@@ -904,10 +910,6 @@ def fit(
904910
if not self.parent:
905911
# Inferring parent through references
906912
self.parent = next(iter(ref))
907-
# Process references
908-
# if ref:
909-
# self.ref_table = next(iter(ref))
910-
# self.ref_col = cast(str, next(iter(ref[self.ref_table].keys())))
911913

912914
assert (
913915
self.parent
@@ -927,18 +929,39 @@ def fit(
927929
seq = _calculate_seq(seq_col, ids, self.parent, self.col_seq)
928930
self.max_len = cast(int, seq.max()) + 1
929931

932+
if self.dual:
933+
self._dual_fit(self.parent, data, ref, ids, seq)
934+
else:
935+
self._single_fit(self.parent, data, ref, ids, seq)
936+
937+
# If a seq_val was not provided, assume seq was also none and
938+
# become the sequencer
939+
if seq_val is None:
940+
return SeqValue(self.col_seq, self.parent), cast(pd.Series, seq)
941+
942+
def _dual_fit(
943+
self,
944+
parent: str,
945+
data: pd.Series | pd.DataFrame,
946+
ref: dict[str, pd.DataFrame],
947+
ids: pd.DataFrame,
948+
seq: pd.Series,
949+
):
930950
ctx_data = (
931-
ids.join(data[seq == 0], how="right")
932-
.drop_duplicates(subset=[self.parent])
933-
.set_index(self.parent)
951+
ids[[parent]]
952+
.join(data[seq == 0], how="right")
953+
.drop_duplicates(subset=[parent])
954+
.set_index(parent)
934955
)
935956
if isinstance(data, pd.Series):
936957
ctx_data = ctx_data[next(iter(ctx_data))]
937958
if ref:
938-
ctx_ref = ids.drop_duplicates(subset=[self.parent])
959+
ctx_ref = ids.drop_duplicates(subset=[parent])
939960
for name, ref_table in ref.items():
940961
ctx_ref = ctx_ref.join(ref_table, on=name, how="left")
941-
ctx_ref = ctx_ref.set_index(self.parent)
962+
ctx_ref = ctx_ref.set_index(parent).drop(
963+
columns=[d for d in ids.columns if d != parent]
964+
)
942965

943966
if ctx_ref.shape[1] == 1:
944967
ctx_ref = ctx_ref[next(iter(ctx_ref))]
@@ -951,13 +974,41 @@ def fit(
951974
self.ctx.fit(ctx_data)
952975

953976
# Data series is all rows where seq > 0 (skip initial)
954-
ref_df = _backref_cols(ids, seq, data, self.parent)
955-
self.seq.fit(data, ref_df)
977+
ref_df = _backref_cols(ids, seq, data, parent)
978+
self.seq.fit(data[seq > 0], ref_df)
956979

957-
# If a seq_val was not provided, assume seq was also none and
958-
# become the sequencer
959-
if seq_val is None:
960-
return SeqValue(self.col_seq, self.parent), cast(pd.Series, seq)
980+
def _single_fit(
981+
self,
982+
parent: str,
983+
data: pd.Series | pd.DataFrame,
984+
ref: dict[str, pd.DataFrame],
985+
ids: pd.DataFrame,
986+
seq: pd.Series,
987+
):
988+
ref_df = _backref_cols(ids, seq, data, parent)
989+
if ref:
990+
ctx_ref = ids[seq == 0].drop_duplicates(subset=[self.parent])
991+
for name, ref_table in ref.items():
992+
ctx_ref = ctx_ref.join(ref_table, on=name, how="left")
993+
ctx_ref = ctx_ref.drop(columns=ids.columns)
994+
995+
if ctx_ref.shape[1] == 1:
996+
ctx_ref = ctx_ref[next(iter(ctx_ref))]
997+
998+
if isinstance(ref_df, pd.Series) and isinstance(ctx_ref, pd.Series):
999+
ref_df = pd.concat([ctx_ref, ref_df])
1000+
elif isinstance(ref_df, pd.DataFrame) and isinstance(ctx_ref, pd.DataFrame):
1001+
if self.ctx_to_ref:
1002+
ctx_ref = ctx_ref.rename(columns=self.ctx_to_ref)
1003+
ref_df = pd.concat([ctx_ref, ref_df], axis=0)
1004+
assert (
1005+
ref_df.shape[1] == ctx_ref.shape[1]
1006+
), f"Parent columns not joined correctly to reference ones. If they have different names, pass in `ctx_to_ref` with names mapping them to parents"
1007+
else:
1008+
assert (
1009+
False
1010+
), "fixme: mismatched reference column counts. If single column transformer, both should be series, otherwise both should be dataframes"
1011+
self.seq.fit(data, ref_df)
9611012

9621013
def reduce(self, other: "SeqTransformerWrapper"):
9631014
self.ctx.reduce(other)
@@ -986,6 +1037,39 @@ def transform(
9861037
else:
9871038
assert seq is not None
9881039

1040+
if self.dual:
1041+
enc, ctx = self._dual_trn(parent, data, ref, ids, seq)
1042+
else:
1043+
enc, ctx = self._single_trn(parent, data, ref, ids, seq)
1044+
1045+
if self.generate_seq:
1046+
return (
1047+
pd.concat([enc, seq], axis=1),
1048+
{
1049+
parent: pd.concat(
1050+
[
1051+
ctx,
1052+
ids.join(seq)
1053+
.groupby(self.parent)[cast(str, seq.name)]
1054+
.max()
1055+
.rename(self.col_n)
1056+
+ 1,
1057+
],
1058+
axis=1,
1059+
)
1060+
},
1061+
seq,
1062+
)
1063+
return enc, {parent: ctx}
1064+
1065+
def _dual_trn(
1066+
self,
1067+
parent: str,
1068+
data: pd.Series | pd.DataFrame,
1069+
ref: dict[str, pd.DataFrame],
1070+
ids: pd.DataFrame,
1071+
seq: pd.Series,
1072+
):
9891073
ctx_data = (
9901074
ids[[parent]]
9911075
.join(data[seq == 0], how="right")
@@ -995,10 +1079,12 @@ def transform(
9951079
if ctx_data.shape[1] == 1:
9961080
ctx_data = ctx_data[next(iter(ctx_data))]
9971081
if ref:
998-
ctx_ref = ids[[parent]].drop_duplicates(subset=[parent])
1082+
ctx_ref = ids.drop_duplicates(subset=[parent])
9991083
for name, ref_table in ref.items():
10001084
ctx_ref = ctx_ref.join(ref_table, on=name, how="left")
1001-
ctx_ref = ctx_ref.set_index(parent)
1085+
ctx_ref = ctx_ref.set_index(parent).drop(
1086+
columns=[d for d in ids.columns if d != parent]
1087+
)
10021088

10031089
if ctx_ref.shape[1] == 1:
10041090
ctx_ref = ctx_ref[next(iter(ctx_ref))]
@@ -1020,27 +1106,97 @@ def transform(
10201106
if is_float_dtype(d):
10211107
enc.loc[seq == 0, k] = np.nan
10221108

1023-
if self.generate_seq:
1024-
return (
1025-
pd.concat([enc, seq], axis=1),
1026-
{
1027-
parent: pd.concat(
1028-
[
1029-
ctx,
1030-
ids.join(seq)
1031-
.groupby(self.parent)[cast(str, seq.name)]
1032-
.max()
1033-
.rename(self.col_n)
1034-
+ 1,
1035-
],
1036-
axis=1,
1109+
return enc, ctx
1110+
1111+
def _single_trn(
1112+
self,
1113+
parent: str,
1114+
data: pd.Series | pd.DataFrame,
1115+
ref: dict[str, pd.DataFrame],
1116+
ids: pd.DataFrame,
1117+
seq: pd.Series,
1118+
):
1119+
ref_df = _backref_cols(ids, seq, data, parent)
1120+
if ref:
1121+
ctx_ref = ids[seq == 0].drop_duplicates(subset=[self.parent])
1122+
for name, ref_table in ref.items():
1123+
ctx_ref = ctx_ref.join(ref_table, on=name, how="left")
1124+
ctx_ref = ctx_ref.drop(columns=ids.columns)
1125+
1126+
if ctx_ref.shape[1] == 1:
1127+
ctx_ref = ctx_ref[next(iter(ctx_ref))]
1128+
1129+
if isinstance(ref_df, pd.Series) and isinstance(ctx_ref, pd.Series):
1130+
ref_df = pd.concat([ctx_ref, ref_df])
1131+
elif isinstance(ref_df, pd.DataFrame) and isinstance(ctx_ref, pd.DataFrame):
1132+
if self.ctx_to_ref:
1133+
ctx_ref = ctx_ref.rename(columns=self.ctx_to_ref)
1134+
ref_df = pd.concat([ctx_ref, ref_df], axis=0)
1135+
assert (
1136+
ref_df.shape[1] == ctx_ref.shape[1]
1137+
), f"Parent columns not joined correctly to reference ones. If they have different names, pass in `ctx_to_ref` with names mapping them to parents"
1138+
else:
1139+
assert (
1140+
False
1141+
), "fixme: mismatched reference column counts. If single column transformer, both should be series, otherwise both should be dataframes"
1142+
1143+
return self.seq.transform(data, ref_df), pd.DataFrame()
1144+
1145+
def _single_reverse(
1146+
self,
1147+
data: pd.DataFrame,
1148+
ctx: dict[str, pd.DataFrame],
1149+
ref: dict[str, pd.DataFrame],
1150+
ids: pd.DataFrame,
1151+
) -> pd.DataFrame:
1152+
seq = data[self.col_seq]
1153+
parent = cast(str, self.parent)
1154+
1155+
if ref:
1156+
ctx_ref = ids[seq == 0].drop_duplicates(subset=[self.parent])
1157+
for name, ref_table in ref.items():
1158+
ctx_ref = ctx_ref.join(ref_table, on=name, how="left")
1159+
ctx_ref = ctx_ref.drop(columns=ids.columns)
1160+
1161+
if self.ctx_to_ref:
1162+
ctx_ref = ctx_ref.rename(columns=self.ctx_to_ref)
1163+
1164+
if ctx_ref.shape[1] == 1:
1165+
ctx_ref = ctx_ref[next(iter(ctx_ref))]
1166+
else:
1167+
ctx_ref = None
1168+
1169+
# Data series is all rows where seq > 0 (skip initial)
1170+
out = []
1171+
for i in range(self.max_len):
1172+
seq_mask = seq == i
1173+
data_df = data[seq_mask]
1174+
if not len(data_df):
1175+
break
1176+
1177+
if i > 0:
1178+
ref_df = (
1179+
ids.loc[data_df.index]
1180+
.join(
1181+
ids.join(out[-1], how="right").set_index(parent),
1182+
on=parent,
1183+
how="left",
10371184
)
1038-
},
1039-
seq,
1040-
)
1041-
return enc, {parent: ctx}
1185+
.drop(columns=parent)
1186+
)
1187+
if ref_df.shape[1] == 1:
1188+
ref_df = ref_df[next(iter(ref_df))]
10421189

1043-
def reverse(
1190+
assert len(ref_df) == len(
1191+
data_df
1192+
), "fixme: experimental, there is a join error."
1193+
else:
1194+
ref_df = ctx_ref
1195+
out.append(pd.DataFrame(self.seq.reverse(data_df, ref_df)))
1196+
1197+
return pd.concat(out, axis=0)
1198+
1199+
def _dual_reverse(
10441200
self,
10451201
data: pd.DataFrame,
10461202
ctx: dict[str, pd.DataFrame],
@@ -1050,15 +1206,17 @@ def reverse(
10501206
seq = data[self.col_seq]
10511207
parent = cast(str, self.parent)
10521208

1053-
ctx_data = ids.drop_duplicates(subset=[self.parent])
1209+
ctx_data = ids.drop_duplicates(subset=[parent])
10541210
for name, ctx_table in ctx.items():
10551211
ctx_data = ctx_data.join(ctx_table, on=name, how="left")
1056-
ctx_data = ctx_data.set_index(self.parent)
1212+
ctx_data = ctx_data.set_index(parent)
10571213
if ref:
1058-
ctx_ref = ids.drop_duplicates(subset=[self.parent])
1214+
ctx_ref = ids.drop_duplicates(subset=[parent])
10591215
for name, ref_table in ref.items():
10601216
ctx_ref = ctx_ref.join(ref_table, on=name, how="left")
1061-
ctx_ref = ctx_ref.set_index(self.parent)
1217+
ctx_ref = ctx_ref.set_index(parent).drop(
1218+
columns=[d for d in ids.columns if d != parent]
1219+
)
10621220

10631221
if ctx_ref.shape[1] == 1:
10641222
ctx_ref = ctx_ref[next(iter(ctx_ref))]
@@ -1100,13 +1258,25 @@ def reverse(
11001258

11011259
return pd.concat(out, axis=0)
11021260

1261+
def reverse(
1262+
self,
1263+
data: pd.DataFrame,
1264+
ctx: dict[str, pd.DataFrame],
1265+
ref: dict[str, pd.DataFrame],
1266+
ids: pd.DataFrame,
1267+
) -> pd.DataFrame:
1268+
if self.dual:
1269+
return self._dual_reverse(data, ctx, ref, ids)
1270+
else:
1271+
return self._single_reverse(data, ctx, ref, ids)
1272+
11031273
def get_attributes(self) -> tuple[Attributes, dict[str, Attributes]]:
11041274
return {
11051275
self.col_seq: SeqAttribute(self.col_seq, cast(str, self.parent)),
11061276
**self.seq.get_attributes(),
11071277
}, {
11081278
cast(str, self.parent): {
1109-
**self.ctx.get_attributes(),
1279+
**(self.ctx.get_attributes() if self.dual else {}),
11101280
self.col_n: GenAttribute(self.col_n, self.table, self.max_len),
11111281
}
11121282
}

0 commit comments

Comments
 (0)