Skip to content

Commit e4f1f33

Browse files
committed
fix out-of-core and references for mimic
1 parent a47b3a8 commit e4f1f33

15 files changed

Lines changed: 190 additions & 206 deletions

File tree

src/pasteur/extras/datasets/mimic/__init__.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from ....dataset import Dataset
99
from ....utils import (
1010
LazyChunk,
11-
LazyDataset,
1211
LazyFrame,
1312
gen_closure,
1413
get_relative_fn,
@@ -20,24 +19,31 @@
2019

2120

2221
def _split_table(
23-
chunksize: int, keys: np.ndarray, table: "Callable[..., TextFileReader]"
22+
name: str, chunksize: int, keys: np.ndarray, table: "Callable[..., TextFileReader]"
2423
):
2524
pd_keys = pd.DataFrame(index=keys)
2625
del keys
2726

2827
for chunk in table(chunksize=chunksize):
29-
yield chunk.join(pd_keys, on="subject_id", how="inner")
28+
c = chunk.join(pd_keys, on="subject_id", how="inner")
29+
30+
# Fix poe id
31+
if name == 'hosp_pharmacy':
32+
c['poe_seq'] = c['poe_id'].str[1].astype('Int16')
33+
c = c.drop(columns=['poe_id'])
34+
35+
yield c
3036

3137

3238
def _partition_table(
33-
table: Callable, patients: LazyFrame, n_partition: int, chunksize: int
39+
name: str, table: Callable, patients: LazyFrame, n_partition: int, chunksize: int
3440
):
3541
# Deterministic loading = all tables have the same split
3642
keys = patients(["subject_id"]).index.to_numpy()
3743
partitions = np.array_split(keys, n_partition)
3844

3945
return {
40-
str(i): gen_closure(_split_table, chunksize, part, table)
46+
str(i): gen_closure(_split_table, name, chunksize, part, table)
4147
for i, part in enumerate(partitions)
4248
}
4349

@@ -97,6 +103,7 @@ def ingest(self, name, **tables: LazyFrame | Callable[[], TextFileReader]):
97103
if name in self._mimic_tables_partitioned:
98104
chunksize = self._mimic_tables_partitioned[name]
99105
return _partition_table(
106+
name,
100107
cast("Callable[[], TextFileReader]", tables[name]),
101108
cast("LazyFrame", tables["core_patients"]),
102109
self._n_partitions,

src/pasteur/extras/datasets/mimic/catalog.yml

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ _mimic_in_csv: &mimic_csv
77
sep: ","
88
engine: "c"
99
header: 0
10-
infer_datetime_format: True
10+
date_format: "%Y-%m-%d %H:%M:%S"
1111

1212
_mimic_in_chunked: &mimic_chunked # Currently a placeholder
1313
<<: *mimic_csv
@@ -18,6 +18,7 @@ mimic.raw@core_patients:
1818
load_args:
1919
<<: *mimic_csv_load
2020
index_col: subject_id
21+
date_format: "%Y-%m-%d"
2122
parse_dates: [dod]
2223
dtype:
2324
subject_id: int32
@@ -255,17 +256,29 @@ mimic.raw@hosp_poe:
255256
filepath: ${location}/hosp/poe.csv.gz
256257
load_args:
257258
<<: *mimic_csv_load
258-
index_col: poe_id
259+
index_col: [subject_id, poe_seq] #poe_id
260+
usecols:
261+
- subject_id
262+
- hadm_id
263+
- poe_seq
264+
- order_status
265+
- transaction_type
266+
- order_subtype
267+
- order_type
268+
# - poe_id
269+
- discontinue_of_poe_id
270+
- discontinued_by_poe_id
271+
- ordertime
259272
parse_dates: [ordertime]
260273
dtype:
261274
subject_id: int32
262275
hadm_id: int32
263-
poe_sec: int16
276+
poe_seq: int16
264277
order_status: category
265278
transaction_type: category
266279
order_subtype: category
267280
order_type: category
268-
poe_id: object
281+
# poe_id: object
269282
discontinue_of_poe_id: object
270283
discontinued_by_poe_id: object
271284

src/pasteur/extras/metrics/distr.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def preprocess(
223223
per_call_meta = []
224224
base_args = {"domain": self.domain}
225225

226-
for pid, (cwrk, cref) in LazyDataset.zip([wrk, ref]).items():
226+
for cwrk, cref in LazyDataset.zip_values([wrk, ref]):
227227
for split, split_data in [("wrk", cwrk), ("ref", cref)]:
228228
ids, tables = data_to_tables(split_data)
229229

@@ -236,7 +236,7 @@ def preprocess(
236236
"tables": tables,
237237
}
238238
)
239-
per_call_meta.append({"split": split, "table": table, "pid": pid})
239+
per_call_meta.append({"split": split, "table": table})
240240

241241
# Process marginals
242242
out = process_in_parallel(
@@ -287,7 +287,7 @@ def process(
287287
per_call_meta = []
288288
base_args = {"domain": self.domain}
289289

290-
for pid, csyn in LazyDataset.zip(syn).items():
290+
for csyn in LazyDataset.zip_values(syn):
291291
ids, tables = data_to_tables(csyn)
292292

293293
for table in self.domain:
@@ -299,7 +299,7 @@ def process(
299299
"tables": tables,
300300
}
301301
)
302-
per_call_meta.append({"table": table, "pid": pid})
302+
per_call_meta.append({"table": table})
303303

304304
# Process marginals
305305
out = process_in_parallel(

src/pasteur/extras/metrics/visual.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, NamedTuple, TypeVar
1+
from typing import Any, NamedTuple, TypeVar, cast
22

33
import matplotlib.pyplot as plt
44
import numpy as np
@@ -8,10 +8,14 @@
88
from pandas.core.frame import DataFrame
99
from pandas.core.series import Series
1010

11-
from pasteur.metric import AbstractColumnMetric, RefColumnData, Summaries
12-
1311
from ...metadata import ColumnMeta, Metadata
14-
from ...metric import ColumnMetric, RefColumnMetric, Summaries
12+
from ...metric import (
13+
AbstractColumnMetric,
14+
ColumnMetric,
15+
RefColumnData,
16+
RefColumnMetric,
17+
Summaries,
18+
)
1519
from ...utils import list_unique
1620
from ...utils.mlflow import load_matplotlib_style, mlflow_log_hists
1721

@@ -263,9 +267,9 @@ class DateData(NamedTuple):
263267
class DateHist(RefColumnMetric[Summaries[DateData], Summaries[DateData]]):
264268
name = "date"
265269

266-
def fit(
267-
self, table: str, col: str, meta: ColumnMeta, data: pd.Series, ref: pd.Series
268-
):
270+
def fit(self, table: str, col: str | tuple[str], meta: ColumnMeta, data: RefColumnData):
271+
ref = data['ref']
272+
data = data['data']
269273
self.table = table
270274
self.col = col
271275

@@ -418,7 +422,7 @@ def process(
418422
syn: RefColumnData,
419423
pre: Summaries[DateData],
420424
) -> Summaries[DateData]:
421-
return pre.replace(syn=self._process(syn["wrk"], syn["ref"])) # type: ignore
425+
return pre.replace(syn=self._process(syn["data"], syn["ref"])) # type: ignore
422426

423427
def combine(self, summaries: list[Summaries[DateData]]) -> Summaries[DateData]:
424428
return Summaries(
@@ -596,12 +600,12 @@ def __init__(self, *args, _from_factory: bool = False, **kwargs) -> None:
596600
self.time = TimeHist(*args, _from_factory=_from_factory, **kwargs)
597601

598602
def fit(
599-
self, table: str, col: str, meta: ColumnMeta, data: pd.Series, ref: pd.Series
603+
self, table: str, col: str, meta: ColumnMeta, data: RefColumnData
600604
):
601605
self.table = table
602606
self.col = col
603-
self.date.fit(table=table, col=col, meta=meta, data=data, ref=ref)
604-
self.time.fit(table=table, col=col, meta=meta, data=data)
607+
self.date.fit(table=table, col=col, meta=meta, data=data)
608+
self.time.fit(table=table, col=col, meta=meta, data=cast(pd.Series, data['data']))
605609

606610
def preprocess(
607611
self, wrk: RefColumnData, ref: RefColumnData

src/pasteur/extras/transformers.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -589,8 +589,11 @@ def fit(
589589
):
590590
self.col = cast(str, data.name)
591591

592-
cdt = self.dt.fit(data, ref)
593-
ctt = self.tt.fit(data)
592+
self.dt.fit(data, ref)
593+
self.tt.fit(data)
594+
595+
cdt = next(iter(self.dt.get_attributes().values()))
596+
ctt = next(iter(self.tt.get_attributes().values()))
594597
self.attr = Attribute(self.col, vals={**cdt.vals, **ctt.vals}, na=self.nullable)
595598

596599
def get_attributes(self) -> Attributes:

src/pasteur/kedro/dataset/auto.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def _save_worker(
6767
dtypes = p0.dtypes
6868
for field in old_schema:
6969
if (
70-
isinstance(field.type, pa.dictionaryType)
70+
isinstance(field.type, pa.DictionaryType)
7171
and field.type.index_type.bit_width == 8
7272
):
7373
# Expand uint8 dictionaries to uint16
@@ -114,7 +114,7 @@ def _save_worker(
114114

115115
for p in chunk: # type: ignore
116116
try:
117-
w.write(pa.Table.from_pandas(p, schema=schema))
117+
w.write(pa.Table.from_pandas(p, schema=schema, preserve_index=True))
118118
except Exception as e:
119119
logger.error(f"Error writing chunk:\n{e}")
120120
else:

src/pasteur/kedro/dataset/multi.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,14 @@ def _normalized_path(self) -> str:
7474
return self._path
7575

7676
def _list_partitions(self) -> list[str]:
77+
if not self._filesystem.isdir(self._normalized_path, **self._load_args):
78+
# If the path does not exist, i.e. no datasets were saved before,
79+
# return no partitions instead of crashing
80+
return []
7781
return [
78-
path
79-
for path in self._filesystem.find(self._normalized_path, **self._load_args)
80-
if path.endswith(self._filename_suffix)
82+
path['name']
83+
for path in self._filesystem.listdir(self._normalized_path, **self._load_args)
84+
if path['name'].endswith(self._filename_suffix)
8185
]
8286

8387
def _join_protocol(self, path: str) -> str:

src/pasteur/kedro/runner/common.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ def run_expanded_node(
3333
node_name = node.name.split("(")[0]
3434
set_node_name(node_name)
3535
try:
36-
3736
t = PerformanceTracker.get("nodes")
3837
t.log_to_file()
3938
t.start(node_name)

src/pasteur/kedro/runner/sequential.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
class SimpleSequentialRunner(AbstractRunner):
2323
"""``SimpleSequentialRunner`` is a modification of ``SequentialRunner`` that uses a TQDM
24-
loading bar. It also force enables async save of datasets.
24+
loading bar.
2525
"""
2626

2727
def __init__(

src/pasteur/metadata.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def __init__(self, **kwargs):
4848

4949
# Ref can be set either by the ref keyword or by the extended syntax
5050
ref = type_ref[1] if len(type_ref) > 1 else None
51-
refs = kwargs.get("ref", kwargs.get("refs", ref))
51+
ref = kwargs.get("ref", kwargs.get("refs", ref))
5252

5353
# Basic type and dtype data
5454
self.type = type

0 commit comments

Comments
 (0)