|
1 | | -from typing import Any, NamedTuple, TypeVar |
| 1 | +from typing import Any, NamedTuple, TypeVar, cast |
2 | 2 |
|
3 | 3 | import matplotlib.pyplot as plt |
4 | 4 | import numpy as np |
|
8 | 8 | from pandas.core.frame import DataFrame |
9 | 9 | from pandas.core.series import Series |
10 | 10 |
|
11 | | -from pasteur.metric import AbstractColumnMetric, RefColumnData, Summaries |
12 | | - |
13 | 11 | from ...metadata import ColumnMeta, Metadata |
14 | | -from ...metric import ColumnMetric, RefColumnMetric, Summaries |
| 12 | +from ...metric import ( |
| 13 | + AbstractColumnMetric, |
| 14 | + ColumnMetric, |
| 15 | + RefColumnData, |
| 16 | + RefColumnMetric, |
| 17 | + Summaries, |
| 18 | +) |
15 | 19 | from ...utils import list_unique |
16 | 20 | from ...utils.mlflow import load_matplotlib_style, mlflow_log_hists |
17 | 21 |
|
@@ -263,9 +267,9 @@ class DateData(NamedTuple): |
263 | 267 | class DateHist(RefColumnMetric[Summaries[DateData], Summaries[DateData]]): |
264 | 268 | name = "date" |
265 | 269 |
|
266 | | - def fit( |
267 | | - self, table: str, col: str, meta: ColumnMeta, data: pd.Series, ref: pd.Series |
268 | | - ): |
| 270 | + def fit(self, table: str, col: str | tuple[str], meta: ColumnMeta, data: RefColumnData): |
| 271 | + ref = data['ref'] |
| 272 | + data = data['data'] |
269 | 273 | self.table = table |
270 | 274 | self.col = col |
271 | 275 |
|
@@ -418,7 +422,7 @@ def process( |
418 | 422 | syn: RefColumnData, |
419 | 423 | pre: Summaries[DateData], |
420 | 424 | ) -> Summaries[DateData]: |
421 | | - return pre.replace(syn=self._process(syn["wrk"], syn["ref"])) # type: ignore |
| 425 | + return pre.replace(syn=self._process(syn["data"], syn["ref"])) # type: ignore |
422 | 426 |
|
423 | 427 | def combine(self, summaries: list[Summaries[DateData]]) -> Summaries[DateData]: |
424 | 428 | return Summaries( |
@@ -596,12 +600,12 @@ def __init__(self, *args, _from_factory: bool = False, **kwargs) -> None: |
596 | 600 | self.time = TimeHist(*args, _from_factory=_from_factory, **kwargs) |
597 | 601 |
|
598 | 602 | def fit( |
599 | | - self, table: str, col: str, meta: ColumnMeta, data: pd.Series, ref: pd.Series |
| 603 | + self, table: str, col: str, meta: ColumnMeta, data: RefColumnData |
600 | 604 | ): |
601 | 605 | self.table = table |
602 | 606 | self.col = col |
603 | | - self.date.fit(table=table, col=col, meta=meta, data=data, ref=ref) |
604 | | - self.time.fit(table=table, col=col, meta=meta, data=data) |
| 607 | + self.date.fit(table=table, col=col, meta=meta, data=data) |
| 608 | + self.time.fit(table=table, col=col, meta=meta, data=cast(pd.Series, data['data'])) |
605 | 609 |
|
606 | 610 | def preprocess( |
607 | 611 | self, wrk: RefColumnData, ref: RefColumnData |
|
0 commit comments