|
| 1 | +from typing import cast, Any |
| 2 | + |
| 3 | +import pandas as pd |
| 4 | +from pandas import DataFrame, Series |
| 5 | + |
| 6 | +from pasteur.transform import SeqTransformer, TransformerFactory, Transformer |
| 7 | +from pasteur.module import ModuleFactory, get_module_dict, Module |
| 8 | +from pasteur.attribute import Attribute, Attributes, SeqValue, get_dtype, SeqAttribute, GenAttribute |
| 9 | +from pasteur.extras.transformers import DatetimeTransformer |
| 10 | + |
| 11 | +from project.settings import PASTEUR_MODULES as modules |
| 12 | + |
| 13 | + |
| 14 | +def _backref_cols( |
| 15 | + ids: pd.DataFrame, seq: pd.Series, data: pd.DataFrame | pd.Series, parent: str |
| 16 | +): |
| 17 | + # Ref is calculated by mapping each id in data_df by merging its parent |
| 18 | + # key, sequence number to parent key, and the number - 1 and finding the |
| 19 | + # corresponding id for that row. Then, a join is performed. |
| 20 | + _IDX_NAME = "_id_lkjijk" |
| 21 | + _JOIN_NAME = "_id_zdjwk" |
| 22 | + ids_seq_prev = ids.join(seq + 1).reset_index(names=_JOIN_NAME) |
| 23 | + ids_seq = ids.join(seq, how="right").reset_index(names=_IDX_NAME) |
| 24 | + # FIXME: ids become float |
| 25 | + join_ids = ids_seq.merge(ids_seq_prev, on=[parent, seq.name], how='left').set_index(_IDX_NAME)[ |
| 26 | + [_JOIN_NAME] |
| 27 | + ] # type: ignore |
| 28 | + ref_df = join_ids.join(data, on=_JOIN_NAME).drop(columns=_JOIN_NAME) |
| 29 | + ref_df.index.name = data.index.name |
| 30 | + if isinstance(data, pd.Series): |
| 31 | + return ref_df[data.name] |
| 32 | + return ref_df |
| 33 | + |
| 34 | + |
def _calculate_seq(
    data: Series, parent: str, col_seq: str, ids: DataFrame | None = None
):
    """Number the rows of each parent group by ascending ``data`` value.

    Rows are ranked within their parent group (ties broken by order of
    appearance) and the rank is shifted to start at 0.

    Args:
        data: Values that define the ordering inside each parent group.
        parent: Name of the column in ``ids`` that holds the parent key.
        col_seq: Name given to the returned series.
        ids: Table aligned with ``data`` that contains the ``parent`` column.
            When omitted, a module-level variable named ``ids`` is used, which
            is what this function relied on before the parameter existed.

    Returns:
        The 0-based sequence number of every row, named ``col_seq`` and cast
        to the dtype returned by ``get_dtype(max_len + 1)``.

    Raises:
        NameError: If ``ids`` is not passed and no module-level ``ids`` exists.
    """
    if ids is None:
        # Backward compatibility: the original body read a global `ids`.
        try:
            ids = globals()["ids"]
        except KeyError:
            raise NameError(
                "name 'ids' is not defined; pass the ids table through the "
                "`ids` parameter"
            ) from None

    _ID_SEQ = "_id_sdfasdf"
    keyed = pd.concat({parent: ids[parent], _ID_SEQ: data}, axis=1)
    # rank() is 1-based; subtract 1 so each group starts at 0.
    seq = cast(pd.Series, keyed.groupby(parent)[_ID_SEQ].rank("first")) - 1

    max_len = int(cast(float, seq.max())) + 1
    return seq.astype(get_dtype(max_len + 1)).rename(col_seq)
| 48 | + |
| 49 | + |
class SeqTransformerWrapper(SeqTransformer):
    """Encode a sequential column using two wrapped transformers.

    The ``ctx`` transformer handles the first row of every parent group
    (``seq == 0``); its output is attached to the parent table. The ``seq``
    transformer handles every row, with the previous row of the same parent
    supplied as its reference.

    NOTE(review): ``RefTransformer`` is used in ``isinstance`` checks but is
    not imported in the visible part of this file — confirm where it is
    expected to come from.
    """

    name = "seqwrap"

    def __init__(
        self,
        modules: list[Module],
        ctx: dict[str, Any],
        seq: dict[str, Any],
        parent: str | None = None,
        seq_col: str | None = None,
        **kwargs,
    ) -> None:
        """Build the two wrapped transformers from their configuration.

        Args:
            modules: Modules searched for transformer factories.
            ctx: Settings of the transformer used for the first row of each
                parent. The ``"type"`` key selects the factory; the remaining
                keys are forwarded to its ``build``.
            seq: Same as ``ctx`` but for the per-row transformer, which must
                be a ``RefTransformer``.
            parent: Name of the parent table. May be left empty and resolved
                during ``fit``.
            seq_col: Column that orders the rows when ``fit``/``transform``
                receive a DataFrame and have to generate the sequence.
            **kwargs: Forwarded to the base class constructor.
        """
        super().__init__(**kwargs)
        self.parent = parent
        self.seq_col_ref = seq_col

        # Load transformers
        assert ctx and seq
        ctx_kwargs = ctx.copy()  # copy so the caller's dict is not mutated
        ctx_type = ctx_kwargs.pop("type")
        self.ctx = get_module_dict(TransformerFactory, modules)[
            cast(str, ctx_type)
        ].build(**ctx_kwargs)
        assert isinstance(self.ctx, Transformer)

        seq_kwargs = seq.copy()
        seq_type = seq_kwargs.pop("type")
        self.seq = get_module_dict(TransformerFactory, modules)[
            cast(str, seq_type)
        ].build(**seq_kwargs)
        assert isinstance(self.seq, RefTransformer)

    def _seq_source(self, data: Series | DataFrame) -> Series:
        """Return the column whose ordering is used to sequence the rows."""
        if isinstance(data, DataFrame):
            assert self.seq_col_ref is not None, f'Multiple columns are provided as input, specify which one is used sequence the table through parameter `seq_col`.'
            return data[self.seq_col_ref]
        return data

    def fit(
        self,
        table: str,
        data: Series | DataFrame,
        ref: dict[str, DataFrame],
        ids: DataFrame,
        seq_val: SeqValue | None = None,
        seq: Series | None = None,
    ) -> tuple[SeqValue, Series] | None:
        """Fit both wrapped transformers.

        Args:
            table: Name of the table ``data`` belongs to.
            data: Column(s) to fit on.
                NOTE(review): ``data.name`` is read below, which a DataFrame
                does not provide — confirm DataFrame input is supported.
            ref: Reference tables keyed by the ``ids`` column they are joined
                on. The first key is also used as the parent when none was
                configured and ``seq_val`` is not given.
            ids: Table aligned with ``data`` holding the parent key column.
            seq_val: Existing sequence descriptor; when given, its ``table``
                and ``name`` become the parent and sequence column name.
            seq: Existing per-row sequence numbers; generated from ``data``
                when omitted.

        Returns:
            ``(SeqValue, seq)`` when ``seq_val`` was not provided, otherwise
            ``None``.
        """
        self.col = cast(str, data.name)
        self.table = table

        # Grab parent from seq_val if available
        if seq_val is not None:
            self.parent = seq_val.table
            self.col_seq = seq_val.name
            # NOTE(review): `col_n` is not set on this path although
            # `get_attributes` (and `transform` when the sequence is
            # generated) read it — confirm the intended behaviour.
        else:
            self.col_seq = f"{table}_seq"
            self.col_n = f'{table}_n'

            if not self.parent:
                # Infer the parent from the first reference
                self.parent = next(iter(ref))

        assert (
            self.parent
        ), "Parent table not specified, use parameter 'parent' or a foreign reference."

        # If seq was not provided, derive it from the data
        self.generate_seq = False
        if seq is None:
            self.generate_seq = True
            seq = _calculate_seq(self._seq_source(data), self.parent, self.col_seq)
        self.max_len = cast(int, seq.max()) + 1

        # Context data: the first row (seq == 0) of every parent, indexed by
        # the parent key.
        ctx_data = (
            ids.join(data[seq == 0], how="right")
            .drop_duplicates(subset=[self.parent])
            .set_index(self.parent)[self.col]
        )
        if ref:
            ctx_ref = ids.drop_duplicates(subset=[self.parent])
            for name, ref_table in ref.items():
                ctx_ref = ctx_ref.join(ref_table, on=name, how="left")
            ctx_ref = ctx_ref.set_index(self.parent)
            # NOTE(review): `transform` collapses a single-column `ctx_ref`
            # into a Series, `fit` does not — confirm which form the context
            # transformer expects.

            assert isinstance(
                self.ctx, RefTransformer
            ), f"Reference found, initial transformer should be a reference transformer."
            self.ctx.fit(ctx_data, ctx_ref)
        else:
            self.ctx.fit(ctx_data)

        # The sequence transformer sees every row; its reference is the
        # previous row of the same parent (missing for first rows).
        ref_df = _backref_cols(ids, seq, data, self.parent)
        self.seq.fit(data, ref_df)

        # If a seq_val was not provided, assume seq was also none and
        # become the sequencer
        if seq_val is None:
            return SeqValue(self.col_seq, self.parent), cast(Series, seq)

    def reduce(self, other: "SeqTransformerWrapper"):
        """Merge the fitted state of ``other`` into this wrapper.

        Each wrapped transformer is reduced with its counterpart from
        ``other`` and the longest observed sequence length is kept.
        """
        # Pass the matching sub-transformers, not the wrapper itself.
        self.ctx.reduce(other.ctx)
        self.seq.reduce(other.seq)
        self.max_len = max(other.max_len, self.max_len)

    def transform(
        self,
        data: Series | DataFrame,
        ref: dict[str, DataFrame],
        ids: DataFrame,
        seq: Series | None = None,
    ) -> tuple[DataFrame, dict[str, DataFrame]] | tuple[
        DataFrame, dict[str, DataFrame], Series
    ]:
        """Encode ``data`` with the fitted transformers.

        Args:
            data: Column(s) to encode.
            ref: Reference tables keyed by the ``ids`` column they join on.
            ids: Table aligned with ``data`` holding the parent key column.
            seq: Per-row sequence numbers. Required unless this wrapper
                generated the sequence itself during ``fit``.

        Returns:
            ``(enc, {parent: ctx})`` where ``enc`` is the output of the
            sequence transformer and ``ctx`` the output of the context
            transformer. When the sequence is generated here, the parent
            entry also gets a ``col_n`` column (per-parent maximum sequence
            number plus one) and the generated ``seq`` is returned as a third
            element.
        """
        parent = cast(str, self.parent)
        if self.generate_seq:
            seq = _calculate_seq(self._seq_source(data), parent, self.col_seq)
        else:
            assert seq is not None

        # Context data: the first row (seq == 0) of every parent, indexed by
        # the parent key.
        ctx_data = (
            ids.join(data[seq == 0], how="right")
            .drop_duplicates(subset=[self.parent])
            .set_index(self.parent)[self.col]
        )
        if ref:
            ctx_ref = ids.drop_duplicates(subset=[self.parent])
            for name, ref_table in ref.items():
                ctx_ref = ctx_ref.join(ref_table, on=name, how="left")
            ctx_ref = ctx_ref.set_index(self.parent)

            # A single reference column is handed over as a Series.
            if isinstance(ctx_ref, DataFrame) and ctx_ref.shape[1] == 1:
                ctx_ref = ctx_ref[next(iter(ctx_ref))]

            assert isinstance(
                self.ctx, RefTransformer
            ), f"Reference found, initial transformer should be a reference transformer."
            ctx = self.ctx.transform(ctx_data, ctx_ref)
        else:
            ctx = self.ctx.transform(ctx_data)

        # The sequence transformer sees every row; its reference is the
        # previous row of the same parent (missing for first rows).
        ref_df = _backref_cols(ids, seq, data, parent)
        enc = self.seq.transform(data, ref_df)

        if self.generate_seq:
            # Per-parent row count: highest sequence number + 1.
            seq_name = cast(str, seq.name)
            last_seq = ids.join(seq).groupby(self.parent)[seq_name].max()
            n_rows = last_seq.rename(self.col_n) + 1
            return enc, {parent: pd.concat([ctx, n_rows], axis=1)}, seq
        return enc, {parent: ctx}

    def get_attributes(self) -> tuple[Attributes, dict[str, Attributes]]:
        """Describe the columns produced by ``transform``.

        Returns:
            A pair: the attributes of this table (the sequence attribute plus
            those of the sequence transformer) and a mapping from the parent
            table name to the attributes added to it (those of the context
            transformer plus the ``col_n`` attribute).
        """
        parent = cast(str, self.parent)
        own_attrs = {
            self.col_seq: SeqAttribute(self.col_seq, parent),
            **self.seq.get_attributes(),
        }
        parent_attrs = {
            **self.ctx.get_attributes(),
            self.col_n: GenAttribute(self.col_n, self.table, self.max_len),
        }
        return own_attrs, {parent: parent_attrs}
| 213 | + |
| 214 | + |
# Smoke run of the wrapper on the admissions table. `admissions`, `patients`
# and `ids` are not defined in the visible part of this file and must already
# exist in the module namespace when this runs.
s = SeqTransformerWrapper(
    modules=modules,
    ctx={"type": "datetime", "nullable": True},
    seq={"type": "datetime", "nullable": True},
)
s.fit(
    table="admissions",
    data=admissions["admittime"],
    ref={"patients": patients[["birth_year"]]},
    ids=ids,
)
r = s.transform(
    data=admissions["admittime"],
    ref={"patients": patients[["birth_year"]]},
    ids=ids,
)
# Bare expression: shows the longest fitted sequence length in a notebook.
s.max_len
0 commit comments