Skip to content

Commit 9412597

Browse files
committed
Add ref to seq transformer wrapper
1 parent 46fbfdd commit 9412597

8 files changed

Lines changed: 647 additions & 37 deletions

File tree

notebooks/tst.py

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
from typing import cast, Any
2+
3+
import pandas as pd
4+
from pandas import DataFrame, Series
5+
6+
from pasteur.transform import SeqTransformer, TransformerFactory, Transformer
7+
from pasteur.module import ModuleFactory, get_module_dict, Module
8+
from pasteur.attribute import Attribute, Attributes, SeqValue, get_dtype, SeqAttribute, GenAttribute
9+
from pasteur.extras.transformers import DatetimeTransformer
10+
11+
from project.settings import PASTEUR_MODULES as modules
12+
13+
14+
def _backref_cols(
15+
ids: pd.DataFrame, seq: pd.Series, data: pd.DataFrame | pd.Series, parent: str
16+
):
17+
# Ref is calculated by mapping each id in data_df by merging its parent
18+
# key, sequence number to parent key, and the number - 1 and finding the
19+
# corresponding id for that row. Then, a join is performed.
20+
_IDX_NAME = "_id_lkjijk"
21+
_JOIN_NAME = "_id_zdjwk"
22+
ids_seq_prev = ids.join(seq + 1).reset_index(names=_JOIN_NAME)
23+
ids_seq = ids.join(seq, how="right").reset_index(names=_IDX_NAME)
24+
# FIXME: ids become float
25+
join_ids = ids_seq.merge(ids_seq_prev, on=[parent, seq.name], how='left').set_index(_IDX_NAME)[
26+
[_JOIN_NAME]
27+
] # type: ignore
28+
ref_df = join_ids.join(data, on=_JOIN_NAME).drop(columns=_JOIN_NAME)
29+
ref_df.index.name = data.index.name
30+
if isinstance(data, pd.Series):
31+
return ref_df[data.name]
32+
return ref_df
33+
34+
35+
def _calculate_seq(
    data: Series, parent: str, col_seq: str, ids: DataFrame | None = None
):
    """Number the rows of each parent by the order of their `data` values.

    Within every group of rows sharing the same `parent` key, rows are ranked
    by `data` (ties broken by order of appearance) and the zero-based rank
    becomes the sequence number.

    Args:
        data: Values used to order the rows, indexed by row id.
        parent: Name of the parent key column in `ids`.
        col_seq: Name given to the returned series.
        ids: Frame indexed by row id that holds the `parent` key column.
            When omitted, a module-level variable named `ids` is used, which
            is what this function silently depended on before the parameter
            existed. New callers should pass it explicitly.

    Returns:
        The sequence numbers, cast to the dtype `get_dtype` picks for the
        sequence length and renamed to `col_seq`.

    Raises:
        NameError: If `ids` is neither passed nor defined at module level.
    """
    if ids is None:
        # Backward-compatible fallback for callers that still rely on the
        # notebook-global `ids`; fail the same way the free variable did.
        try:
            ids = globals()["ids"]
        except KeyError:
            raise NameError("name 'ids' is not defined") from None

    _ID_SEQ = "_id_sdfasdf"
    keyed = pd.concat({parent: ids[parent], _ID_SEQ: data}, axis=1)
    # rank() is one-based, sequence numbers are zero-based.
    seq = cast(pd.Series, keyed.groupby(parent)[_ID_SEQ].rank("first")) - 1
    max_len = int(cast(float, seq.max())) + 1
    return seq.astype(get_dtype(max_len + 1)).rename(col_seq)
48+
49+
50+
class SeqTransformerWrapper(SeqTransformer):
    """Sequence transformer that combines a context and a reference transformer.

    `ctx` encodes the value of the first row (`seq == 0`) of every parent and
    its output is stored under the parent table. `seq` encodes the column
    itself, receiving as reference the value of the previous row of the same
    parent (see `_backref_cols`).

    When no sequence is handed to `fit`, the wrapper derives one from the data
    and also emits the per-parent sequence length as column `<table>_n`.
    """

    name = "seqwrap"

    def __init__(
        self,
        modules: list[Module],
        ctx: dict[str, Any],
        seq: dict[str, Any],
        parent: str | None = None,
        seq_col: str | None = None,
        **kwargs,
    ) -> None:
        """Build the wrapped transformers from their specifications.

        Args:
            modules: Modules searched for transformer factories.
            ctx: Keyword arguments for the context transformer; the `type`
                key selects the factory.
            seq: Same as `ctx`, for the sequence transformer, which has to be
                a `RefTransformer`.
            parent: Name of the parent table. May be left out and resolved
                in `fit`.
            seq_col: Column used to order the rows when `data` is a frame and
                no sequence is provided.
            **kwargs: Forwarded to `SeqTransformer.__init__`.
        """
        super().__init__(**kwargs)
        self.parent = parent
        self.seq_col_ref = seq_col

        # Load transformers
        assert ctx and seq
        factories = get_module_dict(TransformerFactory, modules)

        self.ctx = self._build(factories, ctx)
        assert isinstance(self.ctx, Transformer)

        self.seq = self._build(factories, seq)
        assert self._is_ref_transformer(self.seq)

    @staticmethod
    def _build(factories: dict[str, Any], spec: dict[str, Any]):
        """Instantiate the transformer described by `spec` without mutating it."""
        spec_kwargs = spec.copy()
        spec_type = spec_kwargs.pop("type")
        return factories[cast(str, spec_type)].build(**spec_kwargs)

    @staticmethod
    def _is_ref_transformer(obj: Any) -> bool:
        """Return whether `obj` is a `RefTransformer`."""
        # Imported here because the module-level imports of this file do not
        # bring `RefTransformer` into scope.
        from pasteur.transform import RefTransformer

        return isinstance(obj, RefTransformer)

    def _resolve_seq(self, data: Series | DataFrame, seq: Series | None) -> Series:
        """Return the provided sequence, or compute one if this wrapper generates it."""
        if not self.generate_seq:
            assert seq is not None
            return seq

        if isinstance(data, DataFrame):
            assert self.seq_col_ref is not None, 'Multiple columns are provided as input, specify which one is used sequence the table through parameter `seq_col`.'
            seq_col = data[self.seq_col_ref]
        else:
            seq_col = data
        return _calculate_seq(seq_col, cast(str, self.parent), self.col_seq)

    def _context_inputs(
        self,
        data: Series | DataFrame,
        ref: dict[str, DataFrame],
        ids: DataFrame,
        seq: Series,
    ):
        """Assemble the per-parent inputs of the context transformer.

        Returns:
            `(ctx_data, ctx_ref)`. `ctx_data` holds the first value of every
            parent, indexed by parent key. `ctx_ref` is `None` when `ref` is
            empty, otherwise the referenced tables joined per parent.
        """
        ctx_data = (
            ids.join(data[seq == 0], how="right")
            .drop_duplicates(subset=[self.parent])
            .set_index(self.parent)[self.col]
        )
        if not ref:
            return ctx_data, None

        ctx_ref = ids.drop_duplicates(subset=[self.parent])
        for name, ref_table in ref.items():
            ctx_ref = ctx_ref.join(ref_table, on=name, how="left")
        ctx_ref = ctx_ref.set_index(self.parent)

        # A single reference column is handed over as a series.
        if isinstance(ctx_ref, DataFrame) and ctx_ref.shape[1] == 1:
            ctx_ref = ctx_ref[next(iter(ctx_ref))]

        assert self._is_ref_transformer(
            self.ctx
        ), "Reference found, initial transformer should be a reference transformer."
        return ctx_data, ctx_ref

    def fit(
        self,
        table: str,
        data: Series | DataFrame,
        ref: dict[str, DataFrame],
        ids: DataFrame,
        seq_val: SeqValue | None = None,
        seq: Series | None = None,
    ) -> tuple[SeqValue, Series] | None:
        """Fit both wrapped transformers.

        Args:
            table: Name of the table the column belongs to.
            data: Column(s) to fit on, indexed by row id.
            ref: Referenced tables keyed by table name.
            ids: Frame indexed by row id that holds the parent key column.
            seq_val: Existing sequence value, if another transformer already
                sequences the table. Overrides `parent` and the sequence
                column name.
            seq: Existing sequence numbers. When `None` they are computed
                from the data.

        Returns:
            `(SeqValue, sequence)` when `seq_val` is `None`, i.e. when this
            wrapper becomes the sequencer, otherwise `None`.
        """
        # NOTE(review): a DataFrame has no `.name`; frame input looks
        # unfinished here - confirm which column should become `self.col`.
        self.col = cast(str, data.name)
        self.table = table

        # Grab parent from seq_val if available
        if seq_val is not None:
            self.parent = seq_val.table
            self.col_seq = seq_val.name
        else:
            self.col_seq = f"{table}_seq"
        self.col_n = f'{table}_n'

        if not self.parent:
            # Inferring parent through references. The default keeps an empty
            # `ref` from raising StopIteration so the assertion below reports
            # the actual problem.
            self.parent = next(iter(ref), None)

        assert (
            self.parent
        ), "Parent table not specified, use parameter 'parent' or a foreign reference."

        # If seq was not provided, this wrapper generates it
        self.generate_seq = seq is None
        seq = self._resolve_seq(data, seq)
        self.max_len = cast(int, seq.max()) + 1

        ctx_data, ctx_ref = self._context_inputs(data, ref, ids, seq)
        if ctx_ref is None:
            self.ctx.fit(ctx_data)
        else:
            self.ctx.fit(ctx_data, ctx_ref)

        # Each row is fitted against the value of its predecessor; rows that
        # start a sequence get a missing reference.
        ref_df = _backref_cols(ids, seq, data, self.parent)
        self.seq.fit(data, ref_df)

        # If a seq_val was not provided, assume seq was also none and
        # become the sequencer
        if seq_val is None:
            return SeqValue(self.col_seq, self.parent), cast(Series, seq)

    def reduce(self, other: "SeqTransformerWrapper"):
        """Merge the fitted state of `other` into this wrapper."""
        # Each child is reduced with its counterpart, not with the wrapper.
        self.ctx.reduce(other.ctx)
        self.seq.reduce(other.seq)
        self.max_len = max(other.max_len, self.max_len)

    def transform(
        self,
        data: Series | DataFrame,
        ref: dict[str, DataFrame],
        ids: DataFrame,
        seq: Series | None = None,
    ) -> tuple[DataFrame, dict[str, DataFrame]] | tuple[
        DataFrame, dict[str, DataFrame], Series
    ]:
        """Encode `data` with the fitted transformers.

        Args:
            data: Column(s) to encode, indexed by row id.
            ref: Referenced tables keyed by table name.
            ids: Frame indexed by row id that holds the parent key column.
            seq: Sequence numbers; required unless this wrapper generated
                the sequence during `fit`.

        Returns:
            `(enc, {parent: ctx})`, where `enc` is the output of the sequence
            transformer and `ctx` the output of the context transformer. When
            this wrapper generates the sequence, `ctx` is extended with the
            per-parent sequence length and the sequence is returned as a
            third element.
        """
        parent = cast(str, self.parent)
        seq = self._resolve_seq(data, seq)

        ctx_data, ctx_ref = self._context_inputs(data, ref, ids, seq)
        if ctx_ref is None:
            ctx = self.ctx.transform(ctx_data)
        else:
            ctx = self.ctx.transform(ctx_data, ctx_ref)

        # Each row is encoded against the value of its predecessor; rows that
        # start a sequence get a missing reference.
        ref_df = _backref_cols(ids, seq, data, parent)
        enc = self.seq.transform(data, ref_df)

        if self.generate_seq:
            # Sequence numbers are zero-based, so the length is max + 1.
            lengths = (
                ids.join(seq)
                .groupby(self.parent)[cast(str, seq.name)]
                .max()
                .rename(self.col_n)
                + 1
            )
            return enc, {parent: pd.concat([ctx, lengths], axis=1)}, seq
        return enc, {parent: ctx}

    def get_attributes(self) -> tuple[Attributes, dict[str, Attributes]]:
        """Return the attributes of this table and those added to the parent table."""
        parent = cast(str, self.parent)
        table_attrs = {
            self.col_seq: SeqAttribute(self.col_seq, parent),
            **self.seq.get_attributes(),
        }
        parent_attrs = {
            **self.ctx.get_attributes(),
            self.col_n: GenAttribute(self.col_n, self.table, self.max_len),
        }
        return table_attrs, {parent: parent_attrs}
213+
214+
215+
# Scratch run of the wrapper. Expects `admissions`, `patients` and `ids` to be
# defined already in the notebook session.
s = SeqTransformerWrapper(
    modules,
    {"type": "datetime", "nullable": True},
    {"type": "datetime", "nullable": True},
)

_admit_times = admissions["admittime"]
_ref_tables = {"patients": patients[["birth_year"]]}

s.fit("admissions", _admit_times, _ref_tables, ids)
r = s.transform(_admit_times, _ref_tables, ids)

# Bare expression so the notebook displays the fitted sequence length.
s.max_len

src/pasteur/attribute.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def get_dtype(domain: int):
4444

4545

4646
class Grouping(list[GI]):
47-
""" An enchanced form of list that holds the type of grouping (categorical, ordinal),
47+
"""An enhanced form of list that holds the type of grouping (categorical, ordinal),
4848
and implements helper functions and an enhanced string representation."""
4949

5050
def __init__(self, type: Literal["cat", "ord"], arr: list["Grouping | Any"]):
@@ -190,19 +190,28 @@ def from_str(
190190

191191

192192
class Value:
193-
""" Base value class """
194-
name: str | tuple[str] | None = None
193+
"""Base value class"""
194+
195+
name: str
195196
common: int = 0
196197

197198

199+
class SeqValue(Value):
    """Value that records its own name and the name of a table."""

    # Table name supplied at construction.
    table: str

    def __init__(self, name: str, table: str) -> None:
        self.name, self.table = name, table
205+
206+
198207
class CatValue(Value):
199-
""" Class for a Categorical Value.
200-
208+
"""Class for a Categorical Value.
209+
201210
Each Categorical Value is represented by an unsigned integer.
202211
It can also group its different values together based on an integer parameter
203212
named height.
204213
The implementation of this class remains abstract, and is expanded in
205-
the StratifiedValue class. """
214+
the StratifiedValue class."""
206215

207216
def get_domain(self, height: int = 0) -> int:
208217
"""Returns the domain of the attribute in the given height."""
@@ -228,7 +237,7 @@ def is_ordinal(self) -> bool:
228237
return False
229238

230239
def downsample(self, value: np.ndarray, height: int):
231-
""" Receives an array named `value` and downsamples it based on the provided
240+
"""Receives an array named `value` and downsamples it based on the provided
232241
height, by grouping certain values together. The proper implementation
233242
is provided by pasteur.hierarchy."""
234243
if height == 0:
@@ -239,7 +248,7 @@ def upsample(self, value: np.ndarray, height: int, deterministic: bool = True):
239248
"""Does the opposite of downsample. If deterministic is True, for each
240249
group at a given height one of its values is chosen arbitrarily to represent
241250
all children of the group.
242-
251+
243252
If deterministic is False, the group is sampled based on this Value's
244253
histogram (not implemented in this class; see pasteur.hierarchy)."""
245254
if height == 0:
@@ -263,12 +272,14 @@ def upsample(self, value: np.ndarray, height: int, deterministic: bool = True):
263272
def select_height(self) -> int:
264273
return 0
265274

275+
266276
IdxValue = CatValue
267277

278+
268279
class StratifiedValue(CatValue):
269-
"""A version of CategoricalValue which uses a Stratification to represent
270-
the domain knowledge of the Value.
271-
280+
"""A version of CategoricalValue which uses a Stratification to represent
281+
the domain knowledge of the Value.
282+
272283
Each unique value is mapped to a tree
273284
with nodes where the child order matters.
274285
By traversing the tree in DFS, each leaf is mapped to an integer."""
@@ -304,6 +315,14 @@ def is_ordinal(self) -> bool:
304315
def height(self):
305316
return self.head.height
306317

318+
class GenerationValue(StratifiedValue):
    """Stratified value over the ordinal range `0 .. max_len - 1`."""

    # Table name supplied at construction.
    table: str
    # Exclusive upper bound of the ordinal range.
    max_len: int

    def __init__(self, table: str, max_len: int) -> None:
        self.table, self.max_len = table, max_len

        levels = list(range(max_len))
        ordinal = Grouping('ord', levels)
        super().__init__(ordinal, 0)
307326

308327
def _create_strat_value_cat(vals, na: bool = False, ukn_val: Any | None = None):
309328
arr = []
@@ -434,6 +453,16 @@ def NumAttribute(
434453
return Attribute(name, {name: NumValue(bins, min, max)}, nullable, False)
435454

436455

456+
def SeqAttribute(name: str, table: str):
    """Create an Attribute named `name` whose only value is a SeqValue for `table`."""
    value = SeqValue(name, table)
    vals = {name: value}
    return Attribute(name, vals, False, False)
459+
460+
461+
def GenAttribute(name: str, table: str, max_len: int):
    """Create an Attribute named `name` whose only value is a GenerationValue.

    The GenerationValue is built from `table` and `max_len`.
    """
    value = GenerationValue(table, max_len)
    vals = {name: value}
    return Attribute(name, vals, False, False)
464+
465+
437466
__all__ = [
438467
"get_dtype",
439468
"Grouping",

src/pasteur/extras/metrics/distr.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010
from numpy import ndarray
1111
from scipy.special import rel_entr
1212
from scipy.stats import chisquare
13-
from pasteur.metric import Summaries
1413

14+
from pasteur.metric import Summaries
1515
from pasteur.utils import LazyDataset
1616

17-
from ...attribute import Attributes, CatValue, get_dtype
17+
from ...attribute import Attributes, CatValue, SeqValue, get_dtype
1818
from ...metric import Metric, Summaries
1919
from ...utils import LazyChunk, LazyFrame, data_to_tables
2020
from ...utils.progress import process_in_parallel
@@ -209,6 +209,8 @@ def fit(
209209
for table, attrs in meta.items():
210210
for attr in attrs.values():
211211
for name, val in attr.vals.items():
212+
if isinstance(val, SeqValue):
213+
continue
212214
assert isinstance(val, CatValue)
213215
self.domain[table][name] = val.domain
214216

src/pasteur/extras/transformers.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import numpy as np
44
import pandas as pd
5-
from pandas.api.types import is_categorical_dtype
5+
from pandas.api.types import is_categorical_dtype, is_float_dtype
66

77
from pasteur.attribute import Attributes
88
from pasteur.transform import RefTransformer, Transformer
@@ -401,7 +401,7 @@ def reverse(self, data: pd.DataFrame, ref: pd.Series | None = None) -> pd.Series
401401
na_mask |= np.any(vals[dcols] == 0, axis=1)
402402

403403
if ref is not None:
404-
na_mask |= pd.isna(ref)
404+
na_mask = pd.isna(ref) | na_mask
405405
ref = ref[~na_mask]
406406
vals = vals[~na_mask]
407407
ofs = 1
@@ -646,6 +646,9 @@ def transform(self, data: pd.Series, ref: pd.Series | None = None) -> pd.DataFra
646646
date_enc = self.dt.transform(data, ref)
647647
time_enc = self.tt.transform(data)
648648
del data, ref
649+
if self.nullable:
650+
c = date_enc[next(iter(date_enc))]
651+
time_enc[pd.isna(c) if is_float_dtype(c) else c == 0] = 0
649652
return pd.concat([date_enc, time_enc], axis=1, copy=False, join="inner")
650653

651654
def reverse(

src/pasteur/extras/views/mimic/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def ingest(self, name, **tables: LazyChunk):
5454
case "admissions":
5555
return tables["core_admissions"]()
5656
case "transfers":
57-
return tables["core_transfers"]()
57+
return tables["core_transfers"]().dropna(subset=['hadm_id'])
5858
case other:
5959
assert False, f"Table {other} not part of view {self.name}"
6060

0 commit comments

Comments
 (0)