Skip to content

Commit 2883c77

Browse files
committed
fix end-to-end multitable process
1 parent 2ea5c83 commit 2883c77

8 files changed

Lines changed: 145 additions & 54 deletions

File tree

src/pasteur/attribute.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ class GenerationValue(StratifiedValue):
322322
def __init__(self, table: str, max_len: int) -> None:
323323
self.table = table
324324
self.max_len = max_len
325-
super().__init__(Grouping('ord', list(range(max_len))), 0)
325+
super().__init__(Grouping('ord', list(range(max_len + 1))), 0)
326326

327327
def _create_strat_value_cat(vals, na: bool = False, ukn_val: Any | None = None):
328328
arr = []

src/pasteur/extras/metrics/distr.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ def calc_marginal_1way(
4848
mul *= domain[col]
4949

5050
counts = np.bincount(idx, minlength=x_dom)
51+
assert (
52+
len(counts) == x_dom
53+
), f"Overflow error, domain for columns `{x}` is wrong or there is a mistake in encoding."
54+
5155
margin = counts.astype("float")
5256
margin /= margin.sum()
5357
if zero_fill is not None:
@@ -59,7 +63,9 @@ def calc_marginal_1way(
5963

6064

6165
def _visualise_cs(
62-
name: str, domain: dict[str, int], data: dict[str, Summaries[dict[str, np.ndarray]]]
66+
table: str,
67+
domain: dict[str, int],
68+
data: dict[str, Summaries[dict[str, np.ndarray]]],
6369
):
6470
import mlflow
6571

@@ -104,12 +110,12 @@ def _visualise_cs(
104110
split_ref="ref",
105111
)
106112

107-
fn = f"distr/cs.html" if name == "table" else f"distr/{name}_cs.html"
113+
fn = f"distr/cs.html" if table == "table" else f"distr/{table}_cs.html"
108114
mlflow.log_text(gen_html_table(style, FONT_SIZE), fn)
109115

110116

111117
def _visualise_kl(
112-
name: str, data: dict[str, Summaries[dict[tuple[str, str], np.ndarray]]]
118+
table: str, data: dict[str, Summaries[dict[tuple[str, str], np.ndarray]]]
113119
):
114120
import mlflow
115121

@@ -159,7 +165,7 @@ def _visualise_kl(
159165
split_ref="ref",
160166
)
161167

162-
fn = f"distr/kl.html" if name == "table" else f"distr/{name}_kl.html"
168+
fn = f"distr/kl.html" if table == "table" else f"distr/{table}_kl.html"
163169
mlflow.log_text(gen_html_table(style, FONT_SIZE), fn)
164170

165171

src/pasteur/extras/transformers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ def reverse(self, data: pd.DataFrame, ref: pd.Series | None = None) -> pd.Series
403403
if ref is not None:
404404
na_mask = pd.isna(ref) | na_mask
405405
ref = ref[~na_mask]
406-
vals = vals[~na_mask]
406+
vals = vals[~na_mask.reindex(vals.index)]
407407
ofs = 1
408408
else:
409409
ofs = 0

src/pasteur/extras/views/mimic/parameters_core.yml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,16 @@ tables:
2424
admittime:
2525
type: seq
2626
ref: patients.birth_year
27-
seq:
27+
ctx:
2828
type: datetime
2929
span: year.hour
3030
max_len: 99
3131
bins: 32
32+
seq:
33+
type: datetime
34+
span: year.hour
35+
max_len: 10
36+
bins: 10
3237
dischtime: datetime|day.hour:admittime
3338
deathtime: datetime?|day.hour:admittime
3439
admission_type: categorical

src/pasteur/kedro/pipelines/transform.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -129,18 +129,20 @@ def create_fit_pipeline(
129129
outputs=f"{view}.enc.{enc}",
130130
namespace=f"{view}.enc",
131131
)
132-
for enc in encs if enc not in ('raw', 'bst')
132+
for enc in encs
133+
if enc not in ("raw", "bst")
133134
]
134135

135136
return PipelineMeta(
136137
pipeline(trn_fit_nodes + enc_fit_nodes, tags=TAGS_TRANSFORM),
137138
[
138-
D("transformers", f"{view}.trn.{t}", ["view", view, "trn", t], type="pkl")
139+
D("transformers", f"{view}.trn.{t}", ["view", view, "trn", t], type="pkl")
139140
for t in view.tables
140141
]
141142
+ [
142-
D("encoders", f"{view}.enc.{enc}", ["view", view, 'enc', enc], type="pkl")
143-
for enc in encs if enc not in ('raw', 'bst')
143+
D("encoders", f"{view}.enc.{enc}", ["view", view, "enc", enc], type="pkl")
144+
for enc in encs
145+
if enc not in ("raw", "bst")
144146
],
145147
)
146148

@@ -178,7 +180,7 @@ def create_transform_pipeline(
178180
"split_transformed",
179181
f"{view}.{split}.ctx_{t}",
180182
["view", view, split, "ctx", t],
181-
type='multi'
183+
type="multi",
182184
)
183185
)
184186
outputs.append(
@@ -221,7 +223,7 @@ def create_transform_pipeline(
221223
f"{view}.{split}.{enc}",
222224
["synth" if retransform else "view", view, split, enc],
223225
versioned=retransform,
224-
type='multi'
226+
type="multi",
225227
)
226228
)
227229

@@ -262,7 +264,9 @@ def create_reverse_pipeline(view: View, alg: str, enc: str):
262264
"table": f"{view}.{alg}.bst_{t}",
263265
"ctx": f"{view}.{alg}.ctx_{t}",
264266
"ids": f"{view}.{alg}.ids_{t}",
265-
"parents": {req: req for req in view.trn_deps.get(t, [])},
267+
"parents": {
268+
req: f"{view}.{alg}.{req}" for req in view.trn_deps.get(t, [])
269+
},
266270
},
267271
outputs=f"{view}.{alg}.{t}",
268272
namespace=f"{view}.{alg}",
@@ -293,7 +297,7 @@ def create_reverse_pipeline(view: View, alg: str, enc: str):
293297
D(
294298
"synth_reversed",
295299
f"{view}.{alg}.{t}",
296-
["synth", view, alg, 'tables', t],
300+
["synth", view, alg, "tables", t],
297301
versioned=True,
298302
),
299303
]

src/pasteur/metric.py

Lines changed: 92 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -218,9 +218,17 @@ def _fit_column_metrics(
218218
metrics: dict[str, list[ColumnMetricFactory]],
219219
):
220220
get_table = lazy_load_tables(tables)
221+
table = get_table(name)
221222

222223
if ref.table_has_reference():
223224
ids = ref.find_foreign_ids(name, get_table)
225+
226+
if len(table.index.symmetric_difference(ids.index)):
227+
old_len = len(table)
228+
table = table.reindex(ids.index)
229+
logger.warn(
230+
f"There are missing ids for rows in {name}, dropping {old_len-len(table)}/{old_len} rows with missing ids."
231+
)
224232
else:
225233
ids = None
226234

@@ -237,22 +245,22 @@ def _fit_column_metrics(
237245
m = factory.build(**col.args)
238246

239247
if isinstance(m, ColumnMetric):
240-
m.fit(name, col_name, col, get_table(name)[col_name])
248+
m.fit(name, col_name, col, table[col_name])
241249
elif isinstance(m, RefColumnMetric):
242-
ref_col = _calc_joined_refs(name, get_table, ids, col.ref)
250+
ref_col = _calc_joined_refs(name, get_table, ids, col.ref, table)
243251
m.fit(
244252
name,
245253
col_name,
246254
col,
247-
RefColumnData(data=get_table(name)[col_name], ref=ref_col),
255+
RefColumnData(data=table[col_name], ref=ref_col),
248256
)
249257
elif isinstance(m, SeqColumnMetric):
250-
ref_col = _calc_unjoined_refs(name, get_table, col.ref)
258+
ref_col = _calc_unjoined_refs(name, get_table, col.ref, table)
251259
m.fit(
252260
name,
253261
col_name,
254262
col,
255-
SeqColumnData(data=get_table(name)[col_name], ref=ref_col, ids=ids),
263+
SeqColumnData(data=table[col_name], ref=ref_col, ids=ids),
256264
)
257265
else:
258266
assert False, f"Unknown column metric type: {type(m)}"
@@ -272,10 +280,25 @@ def _preprocess_metrics(
272280
):
273281
get_table_wrk = lazy_load_tables(tables_wrk)
274282
get_table_ref = lazy_load_tables(tables_ref)
283+
table_wrk = get_table_wrk(name)
284+
table_ref = get_table_ref(name)
275285

276286
if ref.table_has_reference():
277287
ids_wrk = ref.find_foreign_ids(name, get_table_wrk)
278288
ids_ref = ref.find_foreign_ids(name, get_table_ref)
289+
290+
if len(table_wrk.index.symmetric_difference(ids_wrk.index)):
291+
old_len = len(table_wrk)
292+
table_wrk = table_wrk.reindex(ids_wrk.index)
293+
logger.warn(
294+
f"There are missing ids for rows in {name}, dropping {old_len-len(table_wrk)}/{old_len} rows with missing ids."
295+
)
296+
if len(table_ref.index.symmetric_difference(ids_ref.index)):
297+
old_len = len(table_ref)
298+
table_ref = table_ref.reindex(ids_ref.index)
299+
logger.warn(
300+
f"There are missing ids for rows in {name}, dropping {old_len-len(table_ref)}/{old_len} rows with missing ids."
301+
)
279302
else:
280303
ids_wrk = None
281304
ids_ref = None
@@ -286,30 +309,38 @@ def _preprocess_metrics(
286309
col = meta[name][col_name]
287310
if isinstance(m, ColumnMetric):
288311
prec = m.preprocess(
289-
get_table_wrk(name)[col_name],
290-
get_table_ref(name)[col_name],
312+
table_wrk[col_name],
313+
table_ref[col_name],
291314
)
292315
elif isinstance(m, RefColumnMetric):
293316
prec = m.preprocess(
294317
RefColumnData(
295-
data=get_table_wrk(name)[col_name],
296-
ref=_calc_joined_refs(name, get_table_wrk, ids_ref, col.ref),
318+
data=table_wrk[col_name],
319+
ref=_calc_joined_refs(
320+
name, get_table_wrk, ids_ref, col.ref, table_wrk
321+
),
297322
),
298323
RefColumnData(
299-
data=get_table_ref(name)[col_name],
300-
ref=_calc_joined_refs(name, get_table_ref, ids_ref, col.ref),
324+
data=table_ref[col_name],
325+
ref=_calc_joined_refs(
326+
name, get_table_ref, ids_ref, col.ref, table_ref
327+
),
301328
),
302329
)
303330
elif isinstance(m, SeqColumnMetric):
304331
prec = m.preprocess(
305332
SeqColumnData(
306-
data=get_table_wrk(name)[col_name],
307-
ref=_calc_unjoined_refs(name, get_table_wrk, col.ref),
333+
data=table_wrk[col_name],
334+
ref=_calc_unjoined_refs(
335+
name, get_table_wrk, col.ref, table_wrk
336+
),
308337
ids=ids_wrk,
309338
),
310339
SeqColumnData(
311-
data=get_table_ref(name)[col_name],
312-
ref=_calc_unjoined_refs(name, get_table_ref, col.ref),
340+
data=table_ref[col_name],
341+
ref=_calc_unjoined_refs(
342+
name, get_table_ref, col.ref, table_ref
343+
),
313344
ids=ids_ref,
314345
),
315346
)
@@ -334,11 +365,33 @@ def _process_metrics(
334365
get_table_wrk = lazy_load_tables(tables_wrk)
335366
get_table_ref = lazy_load_tables(tables_ref)
336367
get_table_syn = lazy_load_tables(tables_syn)
368+
table_wrk = get_table_wrk(name)
369+
table_ref = get_table_ref(name)
370+
table_syn = get_table_syn(name)
337371

338372
if ref.table_has_reference():
339373
ids_wrk = ref.find_foreign_ids(name, get_table_wrk)
340374
ids_ref = ref.find_foreign_ids(name, get_table_ref)
341375
ids_syn = ref.find_foreign_ids(name, get_table_syn)
376+
377+
if len(table_wrk.index.symmetric_difference(ids_wrk.index)):
378+
old_len = len(table_wrk)
379+
table_wrk = table_wrk.reindex(ids_wrk.index)
380+
logger.warn(
381+
f"There are missing ids for rows in {name}, dropping {old_len-len(table_wrk)}/{old_len} rows with missing ids."
382+
)
383+
if len(table_ref.index.symmetric_difference(ids_ref.index)):
384+
old_len = len(table_ref)
385+
table_ref = table_ref.reindex(ids_ref.index)
386+
logger.warn(
387+
f"There are missing ids for rows in {name}, dropping {old_len-len(table_ref)}/{old_len} rows with missing ids."
388+
)
389+
if len(table_syn.index.symmetric_difference(ids_syn.index)):
390+
old_len = len(table_syn)
391+
table_syn = table_syn.reindex(ids_syn.index)
392+
logger.warn(
393+
f"There are missing ids for rows in {name}, dropping {old_len-len(table_syn)}/{old_len} rows with missing ids."
394+
)
342395
else:
343396
ids_wrk = None
344397
ids_ref = None
@@ -358,34 +411,46 @@ def _process_metrics(
358411
elif isinstance(m, RefColumnMetric):
359412
proc = m.process(
360413
RefColumnData(
361-
data=get_table_wrk(name)[col_name],
362-
ref=_calc_joined_refs(name, get_table_wrk, ids_wrk, col.ref),
414+
data=table_wrk[col_name],
415+
ref=_calc_joined_refs(
416+
name, get_table_wrk, ids_wrk, col.ref, table_wrk
417+
),
363418
),
364419
RefColumnData(
365-
data=get_table_ref(name)[col_name],
366-
ref=_calc_joined_refs(name, get_table_ref, ids_ref, col.ref),
420+
data=table_ref[col_name],
421+
ref=_calc_joined_refs(
422+
name, get_table_ref, ids_ref, col.ref, table_ref
423+
),
367424
),
368425
RefColumnData(
369-
data=get_table_syn(name)[col_name],
370-
ref=_calc_joined_refs(name, get_table_syn, ids_syn, col.ref),
426+
data=table_syn[col_name],
427+
ref=_calc_joined_refs(
428+
name, get_table_syn, ids_syn, col.ref, table_syn
429+
),
371430
),
372431
prec,
373432
)
374433
elif isinstance(m, SeqColumnMetric):
375434
proc = m.process(
376435
SeqColumnData(
377-
data=get_table_wrk(name)[col_name],
378-
ref=_calc_unjoined_refs(name, get_table_wrk, col.ref),
436+
data=table_wrk[col_name],
437+
ref=_calc_unjoined_refs(
438+
name, get_table_wrk, col.ref, table_wrk
439+
),
379440
ids=ids_wrk,
380441
),
381442
SeqColumnData(
382-
data=get_table_ref(name)[col_name],
383-
ref=_calc_unjoined_refs(name, get_table_ref, col.ref),
443+
data=table_ref[col_name],
444+
ref=_calc_unjoined_refs(
445+
name, get_table_ref, col.ref, table_ref
446+
),
384447
ids=ids_ref,
385448
),
386449
SeqColumnData(
387-
data=get_table_syn(name)[col_name],
388-
ref=_calc_unjoined_refs(name, get_table_syn, col.ref),
450+
data=table_syn[col_name],
451+
ref=_calc_unjoined_refs(
452+
name, get_table_syn, col.ref, table_syn
453+
),
389454
ids=ids_syn,
390455
),
391456
prec,

src/pasteur/synth.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from functools import partial, wraps
99
from typing import TYPE_CHECKING, Any, Generic, TypeVar
1010

11+
from pasteur.utils import LazyDataset
12+
1113
from .encode import ViewEncoder
1214
from .metadata import Metadata
1315
from .module import ModuleClass, ModuleFactory
@@ -147,7 +149,7 @@ def synth_fit(
147149

148150
tracker = PerformanceTracker.get("synth")
149151

150-
tracker.ensemble("total", "preprocess", "bake", "fit", "sample")
152+
tracker.ensemble("total", "preprocess", "bake", "fit")
151153

152154
meta = encoder.get_metadata()
153155
args = {**metadata.algs.get(factory.name, {}), **metadata.alg_override}
@@ -185,10 +187,10 @@ class IdentSynth(Synth):
185187
def preprocess(self, meta: Any, data: dict[str, LazyDataset]):
186188
pass
187189

188-
def bake(self, meta: Any, data: dict[str, LazyDataset]):
190+
def bake(self, data: dict[str, LazyDataset]):
189191
pass
190192

191-
def fit(self, meta: Any, data: dict[str, LazyDataset]):
193+
def fit(self, data: dict[str, LazyDataset]):
192194
self.data = data
193195

194196
def sample(self, n: int | None = None):

0 commit comments

Comments (0)