Skip to content

Commit 46fbfdd

Browse files
committed
polish transformers and add reductions
1 parent e7024a1 commit 46fbfdd

1 file changed

Lines changed: 69 additions & 26 deletions

File tree

src/pasteur/extras/transformers.py

Lines changed: 69 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,20 @@
55
from pandas.api.types import is_categorical_dtype
66

77
from pasteur.attribute import Attributes
8+
from pasteur.transform import RefTransformer, Transformer
89

910
from ..attribute import (
1011
Attribute,
1112
CatAttribute,
1213
Grouping,
13-
StratifiedValue,
1414
NumAttribute,
1515
NumValue,
1616
OrdAttribute,
17-
_create_strat_value_ord as OrdValue,
18-
get_dtype,
17+
StratifiedValue,
1918
)
20-
from ..transform import RefTransformer, Transformer
19+
from ..attribute import _create_strat_value_ord as OrdValue
20+
from ..attribute import get_dtype
21+
from ..transform import RefTransformer, SeqTransformer, Transformer
2122
from ..utils import list_unique
2223

2324

@@ -52,7 +53,20 @@ def fit(self, data: pd.Series):
5253
if self.max is None and self.find_edges:
5354
self.max = data.max()
5455
self.attr = NumAttribute(self.col, self.bins, self.min, self.max, self.nullable)
55-
56+
57+
def reduce(self, other: "NumericalTransformer"):
58+
if self.min is not None and other.min is not None:
59+
self.min = min(self.min, other.min)
60+
elif other.min is not None:
61+
self.min = other.min
62+
63+
if self.max is not None and other.max is not None:
64+
self.min = min(self.max, other.max)
65+
elif other.max is not None:
66+
self.max = other.max
67+
68+
self.attr = NumAttribute(self.col, self.bins, self.min, self.max, self.nullable)
69+
5670
def get_attributes(self) -> Attributes:
5771
return {self.attr.name: self.attr}
5872

@@ -85,37 +99,45 @@ def __init__(self, unknown_value=None, nullable: bool = False, **_):
8599
def fit(self, data: pd.Series):
86100
# Makes fit run out of core by storing the unique values seen previously in `raw_vals`
87101
new_vals = [v for v in data.unique() if not pd.isna(v)]
88-
vals = list_unique(new_vals, self.raw_vals)
89-
self.raw_vals = vals
102+
self.raw_vals = list_unique(new_vals, self.raw_vals)
103+
104+
ofs = 0
105+
if self.nullable:
106+
ofs += 1
107+
if self.unknown_value is not None:
108+
ofs += 1
109+
110+
self.ofs = ofs
111+
self.col = cast(str, data.name)
112+
self.type = data.dtype
113+
self._finalize_props()
114+
115+
def reduce(self, other: "IdxTransformer"):
116+
self.raw_vals = list_unique(self.raw_vals, other.raw_vals)
117+
self._finalize_props()
90118

119+
def _finalize_props(self):
91120
# Try to sort vals
121+
vals = self.raw_vals
92122
try:
93123
vals = sorted(vals)
94124
except Exception:
95125
assert not self.ordinal, "Ordinal Array is not sortable"
96126

97127
vals = list(vals)
98-
ofs = 0
99-
if self.nullable:
100-
ofs += 1
101-
if self.unknown_value is not None:
102-
ofs += 1
103-
128+
ofs = self.ofs
104129
self.mapping = {val: i + ofs for i, val in enumerate(vals)}
105130
self.vals = {i + ofs: val for i, val in enumerate(vals)}
106-
self.col = cast(str, data.name)
107131
self.domain = ofs + len(vals)
108-
self.ofs = ofs
109132

110133
# FIXME: If a column is empty it causes problems for the algorithm
111134
# add 1 fake value as fix
112135
if not vals:
113136
vals = [7777777]
114137

115-
self.type = data.dtype
116138
cls = OrdAttribute if self.ordinal else CatAttribute
117-
self.attr = cls(cast(str, data.name), vals, self.nullable, self.unknown_value)
118-
139+
self.attr = cls(self.col, vals, self.nullable, self.unknown_value)
140+
119141
def get_attributes(self) -> Attributes:
120142
return {self.attr.name: self.attr}
121143

@@ -217,10 +239,18 @@ def fit(
217239
self.ref = data.min()
218240
else:
219241
self.ref = min(data.min(), self.ref)
242+
self.col = cast(str, data.name)
243+
self._finalize_props()
220244

221-
col = cast(str, data.name) # type: ignore
222-
self.col = col
245+
def reduce(self, other: "DateTransformer"):
246+
if self.ref is not None and other.ref is not None:
247+
self.ref = min(other.ref, self.ref)
248+
elif other.ref is not None:
249+
self.ref = other.ref
250+
self._finalize_props()
223251

252+
def _finalize_props(self):
253+
col = self.col
224254
# Generate constraints for columns
225255
days = [
226256
"Monday",
@@ -305,7 +335,7 @@ def transform(self, data: pd.Series, ref: pd.Series | None = None) -> pd.DataFra
305335
# When using a ref column accessing the date parameters is done by the dt member.
306336
# When self referencing to the minimum value, its type is a Timestamp
307337
# which doesn't have the dt member and requires direct access.
308-
rf_dt = rf if isinstance(rf, pd.Timestamp) else rf.dt
338+
rf_dt = rf if isinstance(rf, pd.Timestamp) else cast(pd.Series, rf).dt
309339

310340
iso = vals.dt.isocalendar()
311341
iso_rf = rf_dt.isocalendar()
@@ -390,9 +420,9 @@ def reverse(self, data: pd.DataFrame, ref: pd.Series | None = None) -> pd.Series
390420
iso_rf = rf.isocalendar()
391421
rf_day = iso_rf.weekday # type: ignore
392422
else:
393-
rf_dt = rf.dt
423+
rf_dt = cast(pd.Series, rf).dt
394424
rf_year = rf_dt.year
395-
iso_rf = rf.dt.isocalendar()
425+
iso_rf = rf_dt.isocalendar()
396426
rf_day = iso_rf["day"]
397427

398428
match self.span:
@@ -413,7 +443,7 @@ def reverse(self, data: pd.DataFrame, ref: pd.Series | None = None) -> pd.Series
413443
+ 1
414444
).clip(0),
415445
unit="days",
416-
)
446+
) # type: ignore
417447
case "day":
418448
# TODO: fix negative spans
419449
out = rf_dt.normalize() + pd.to_timedelta(
@@ -441,8 +471,11 @@ def fit(
441471
self,
442472
data: pd.Series,
443473
):
474+
self.col = cast(str, data.name)
475+
self._finalize_props()
476+
477+
def _finalize_props(self):
444478
span = self.span
445-
self.col = data.name
446479

447480
hours = []
448481
for hour in range(24):
@@ -484,7 +517,9 @@ def fit(
484517
self.domain = lvl.size
485518

486519
self.attr = Attribute(
487-
cast(str, data.name), {f"{data.name}_time": StratifiedValue(lvl)}, self.nullable
520+
self.col,
521+
{f"{self.col}_time": StratifiedValue(lvl)},
522+
self.nullable,
488523
)
489524

490525
def get_attributes(self) -> Attributes:
@@ -592,6 +627,14 @@ def fit(
592627
self.dt.fit(data, ref)
593628
self.tt.fit(data)
594629

630+
self._finalize_props()
631+
632+
def reduce(self, other: "DatetimeTransformer"):
633+
self.dt.reduce(other.dt)
634+
self.tt.reduce(other.tt)
635+
self._finalize_props()
636+
637+
def _finalize_props(self):
595638
cdt = next(iter(self.dt.get_attributes().values()))
596639
ctt = next(iter(self.tt.get_attributes().values()))
597640
self.attr = Attribute(self.col, vals={**cdt.vals, **ctt.vals}, na=self.nullable)

0 commit comments

Comments
 (0)