Skip to content

Commit 7b43753

Browse files
committed
add sequencer
1 parent e4f1f33 commit 7b43753

3 files changed

Lines changed: 70 additions & 7 deletions

File tree

src/pasteur/metadata.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ class TableMeta:
137137

138138
def __init__(self, meta: dict):
139139
self.primary_key = meta.get("primary_key", None)
140+
self.sequencer: tuple[str] | str | None = meta.get("sequencer", None)
140141

141142
if "metrics" in meta:
142143
metrics_dict = meta["metrics"]

src/pasteur/table.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,32 @@ def fit_chunk(
268268
), "Properly formatted datasets should have their primary key as their index column"
269269
# table.reindex(meta.primary_key)
270270

271+
# Process sequencer first
272+
seq_name = meta.sequencer
273+
if seq_name:
274+
col = meta.cols[seq_name]
275+
assert (
276+
col.type in self.transformer_cls
277+
), f"Column type {col.type} not in transformers:\n{list(self.transformer_cls.keys())}"
278+
279+
# Fit transformer
280+
if "main_param" in col.args:
281+
t = self.transformer_cls[col.type].build(
282+
col.args["main_param"], **col.args
283+
)
284+
else:
285+
t = self.transformer_cls[col.type].build(**col.args)
286+
287+
assert isinstance(t, SeqTransformer), f"Sequencer must be of type 'SeqTransformer', not '{type(t)}'"
288+
289+
# Add foreign column if required
290+
ref_cols = _calc_unjoined_refs(self.name, get_table, col.ref)
291+
res = t.fit(table[seq_name], ref_cols, loaded_ids)
292+
assert res
293+
seq_attr, seq = res
294+
else:
295+
seq_attr = seq = None
296+
271297
for name, col in meta.cols.items():
272298
if col.is_id():
273299
continue
@@ -287,7 +313,7 @@ def fit_chunk(
287313
if isinstance(t, SeqTransformer):
288314
# Add foreign column if required
289315
ref_cols = _calc_unjoined_refs(self.name, get_table, col.ref)
290-
t.fit(table[name], ref_cols, loaded_ids)
316+
t.fit(table[name], ref_cols, loaded_ids, seq_attr, seq)
291317
elif isinstance(t, RefTransformer):
292318
# Add foreign column if required
293319
ref_cols = _calc_joined_refs(self.name, get_table, loaded_ids, col.ref)
@@ -327,7 +353,29 @@ def transform_chunk(
327353
tts = []
328354
ctxs = defaultdict(list)
329355

356+
# Process sequencer first
357+
seq_name = meta.sequencer
358+
if seq_name:
359+
col = meta.cols[seq_name]
360+
trn = self.transformers[seq_name]
361+
assert isinstance(trn, SeqTransformer)
362+
ref_cols = _calc_unjoined_refs(self.name, get_table, col.ref)
363+
assert loaded_ids is not None
364+
365+
res = trn.transform(table[seq_name], ref_cols, loaded_ids)
366+
assert len(res) == 3
367+
tt, ctx, seq = res
368+
369+
for n, c in ctx.items():
370+
ctxs[n].append(c)
371+
else:
372+
seq = None
373+
330374
for name, col in meta.cols.items():
375+
# Skip sequencer
376+
if seq_name == name:
377+
continue
378+
331379
if col.is_id():
332380
continue
333381

@@ -336,7 +384,9 @@ def transform_chunk(
336384
# Add foreign column if required
337385
ref_cols = _calc_unjoined_refs(self.name, get_table, col.ref)
338386
assert loaded_ids is not None
339-
tt, ctx = trn.transform(table[name], ref_cols, loaded_ids)
387+
res = trn.transform(table[name], ref_cols, loaded_ids, seq)
388+
tt = res[0]
389+
ctx = res[1]
340390

341391
for n, c in ctx.items():
342392
ctxs[n].append(c)

src/pasteur/transform.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,14 +101,22 @@ class SeqTransformer(Transformer):
101101
Sequence Transformers receive unprocessed parent columns, references and the ID table.
102102
Then, it is up to them to process the data and return the encoded version.
103103
They can also push columns upstream to parents, through context tables.
104+
105+
Event-based data is sequential. The Sequential transformers may require the
106+
order of each row. For this case, the main Sequence Transformer, which is named
107+
the sequencer, is processed first and returns an additional data column and
108+
attribute during fitting. This column and attribute are fed to the other
109+
sequence transformers.
104110
"""
105111

106112
def fit(
107113
self,
108114
data: pd.Series | pd.DataFrame,
109115
ref: dict[str, pd.DataFrame] | None = None,
110116
ids: pd.DataFrame | None = None,
111-
) -> tuple[Attributes, dict[str, Attributes]] | None:
117+
seq_attr: Attribute | None = None,
118+
seq: pd.Series | None = None,
119+
) -> tuple[Attribute, pd.Series] | None:
112120
pass
113121

114122
def reduce(self, other: "SeqTransformer"):
@@ -122,16 +130,19 @@ def fit_transform(
122130
data: pd.Series | pd.DataFrame,
123131
ref: dict[str, pd.DataFrame] | None = None,
124132
ids: pd.DataFrame | None = None,
125-
) -> tuple[pd.DataFrame, dict[str, pd.DataFrame]]:
126-
self.fit(data, ref)
127-
return self.transform(data, ref)
133+
seq_attr: Attribute | None = None,
134+
seq: pd.Series | None = None,
135+
) -> tuple[pd.DataFrame, dict[str, pd.DataFrame]] | tuple[pd.DataFrame, dict[str, pd.DataFrame], pd.Series]:
136+
self.fit(data, ref, ids, seq_attr, seq)
137+
return self.transform(data, ref, ids, seq)
128138

129139
def transform(
130140
self,
131141
data: pd.Series | pd.DataFrame,
132142
ref: dict[str, pd.DataFrame] | None = None,
133143
ids: pd.DataFrame | None = None,
134-
) -> tuple[pd.DataFrame, dict[str, pd.DataFrame]]:
144+
seq: pd.Series | None = None,
145+
) -> tuple[pd.DataFrame, dict[str, pd.DataFrame]] | tuple[pd.DataFrame, dict[str, pd.DataFrame], pd.Series]:
135146
raise NotImplementedError()
136147

137148
def reverse(
@@ -140,6 +151,7 @@ def reverse(
140151
ctx: dict[str, pd.DataFrame],
141152
ref: dict[str, pd.DataFrame] | None = None,
142153
ids: pd.DataFrame | None = None,
154+
seq: pd.Series | None = None,
143155
) -> pd.DataFrame:
144156
"""When reversing, the data column contains encoded data, whereas the ref
145157
column contains decoded/original data. Therefore, the referred columns have

0 commit comments

Comments
 (0)