55from pandas .api .types import is_categorical_dtype
66
77from pasteur .attribute import Attributes
8+ from pasteur .transform import RefTransformer , Transformer
89
910from ..attribute import (
1011 Attribute ,
1112 CatAttribute ,
1213 Grouping ,
13- StratifiedValue ,
1414 NumAttribute ,
1515 NumValue ,
1616 OrdAttribute ,
17- _create_strat_value_ord as OrdValue ,
18- get_dtype ,
17+ StratifiedValue ,
1918)
20- from ..transform import RefTransformer , Transformer
19+ from ..attribute import _create_strat_value_ord as OrdValue
20+ from ..attribute import get_dtype
21+ from ..transform import RefTransformer , SeqTransformer , Transformer
2122from ..utils import list_unique
2223
2324
@@ -52,7 +53,20 @@ def fit(self, data: pd.Series):
5253 if self .max is None and self .find_edges :
5354 self .max = data .max ()
5455 self .attr = NumAttribute (self .col , self .bins , self .min , self .max , self .nullable )
55-
56+
def reduce(self, other: "NumericalTransformer"):
    """Merge the value range held by ``other`` into this transformer.

    The lower edge becomes the smaller of the two known minimums and the
    upper edge the larger of the two known maximums. An edge that is
    ``None`` on this side is taken from ``other``; an edge that is ``None``
    on both sides stays ``None``. ``self.attr`` is then rebuilt from the
    merged range.

    Args:
        other: Transformer whose ``min`` and ``max`` are folded into this one.
    """
    if self.min is not None and other.min is not None:
        self.min = min(self.min, other.min)
    elif other.min is not None:
        self.min = other.min

    # The upper edge has to widen: keep the larger maximum and store it in
    # `self.max` so the lower edge computed above is left intact.
    if self.max is not None and other.max is not None:
        self.max = max(self.max, other.max)
    elif other.max is not None:
        self.max = other.max

    self.attr = NumAttribute(self.col, self.bins, self.min, self.max, self.nullable)
69+
def get_attributes(self) -> Attributes:
    """Return the stored attribute in a single-entry dict keyed by its name."""
    attribute = self.attr
    return {attribute.name: attribute}
5872
@@ -85,37 +99,45 @@ def __init__(self, unknown_value=None, nullable: bool = False, **_):
def fit(self, data: pd.Series):
    """Record the distinct non-null values of ``data`` and rebuild derived state.

    Args:
        data: Column to fit on; its name and dtype are stored on the instance.
    """
    # Keeping the values seen so far in `raw_vals` lets fit run out of core.
    seen_now = [item for item in data.unique() if not pd.isna(item)]
    self.raw_vals = list_unique(seen_now, self.raw_vals)

    # Index offset: +1 when nullable, +1 when an unknown value is configured.
    self.ofs = int(bool(self.nullable)) + int(self.unknown_value is not None)
    self.col = cast(str, data.name)
    self.type = data.dtype
    self._finalize_props()
114+
def reduce(self, other: "IdxTransformer"):
    """Fold the raw values collected by ``other`` into this transformer.

    Args:
        other: Transformer whose ``raw_vals`` are merged into ``self.raw_vals``.
    """
    merged = list_unique(self.raw_vals, other.raw_vals)
    self.raw_vals = merged
    # Mappings, domain and attribute all depend on `raw_vals`, so rebuild them.
    self._finalize_props()
90118
def _finalize_props(self):
    """Rebuild ``mapping``, ``vals``, ``domain`` and ``attr`` from ``raw_vals`` and ``ofs``."""
    raw = self.raw_vals

    # Sorting is attempted for every column but only has to succeed for ordinal ones.
    try:
        ordered = sorted(raw)
    except Exception:
        assert not self.ordinal, "Ordinal Array is not sortable"
        ordered = list(raw)

    shift = self.ofs
    forward = {}
    backward = {}
    for idx, item in enumerate(ordered, start=shift):
        forward[item] = idx
        backward[idx] = item
    self.mapping = forward
    self.vals = backward
    self.domain = shift + len(ordered)

    # FIXME: If a column is empty it causes problems for the algorithm
    # add 1 fake value as fix
    attr_vals = ordered if ordered else [7777777]

    attr_cls = OrdAttribute if self.ordinal else CatAttribute
    self.attr = attr_cls(self.col, attr_vals, self.nullable, self.unknown_value)
140+
def get_attributes(self) -> Attributes:
    """Return the stored attribute in a single-entry dict keyed by its name."""
    attribute = self.attr
    return {attribute.name: attribute}
121143
@@ -217,10 +239,18 @@ def fit(
217239 self .ref = data .min ()
218240 else :
219241 self .ref = min (data .min (), self .ref )
242+ self .col = cast (str , data .name )
243+ self ._finalize_props ()
220244
221- col = cast (str , data .name ) # type: ignore
222- self .col = col
def reduce(self, other: "DateTransformer"):
    """Keep the smaller of the two ``ref`` values, then rebuild derived state.

    A ``ref`` of ``None`` on either side is ignored; when ``other.ref`` is
    ``None`` this transformer's ``ref`` is left unchanged.

    Args:
        other: Transformer whose ``ref`` is merged into this one.
    """
    theirs = other.ref
    if theirs is not None:
        mine = self.ref
        self.ref = theirs if mine is None else min(theirs, mine)
    self._finalize_props()
223251
252+ def _finalize_props (self ):
253+ col = self .col
224254 # Generate constraints for columns
225255 days = [
226256 "Monday" ,
@@ -305,7 +335,7 @@ def transform(self, data: pd.Series, ref: pd.Series | None = None) -> pd.DataFra
305335 # When using a ref column accessing the date parameters is done by the dt member.
306336 # When self referencing to the minimum value, its type is a Timestamp
307337 # which doesn't have the dt member and requires direct access.
308- rf_dt = rf if isinstance (rf , pd .Timestamp ) else rf .dt
338+ rf_dt = rf if isinstance (rf , pd .Timestamp ) else cast ( pd . Series , rf ) .dt
309339
310340 iso = vals .dt .isocalendar ()
311341 iso_rf = rf_dt .isocalendar ()
@@ -390,9 +420,9 @@ def reverse(self, data: pd.DataFrame, ref: pd.Series | None = None) -> pd.Series
390420 iso_rf = rf .isocalendar ()
391421 rf_day = iso_rf .weekday # type: ignore
392422 else :
393- rf_dt = rf .dt
423+ rf_dt = cast ( pd . Series , rf ) .dt
394424 rf_year = rf_dt .year
395- iso_rf = rf . dt .isocalendar ()
425+ iso_rf = rf_dt .isocalendar ()
396426 rf_day = iso_rf ["day" ]
397427
398428 match self .span :
@@ -413,7 +443,7 @@ def reverse(self, data: pd.DataFrame, ref: pd.Series | None = None) -> pd.Series
413443 + 1
414444 ).clip (0 ),
415445 unit = "days" ,
416- )
446+ ) # type: ignore
417447 case "day" :
418448 # TODO: fix negative spans
419449 out = rf_dt .normalize () + pd .to_timedelta (
@@ -441,8 +471,11 @@ def fit(
441471 self ,
442472 data : pd .Series ,
443473 ):
474+ self .col = cast (str , data .name )
475+ self ._finalize_props ()
476+
477+ def _finalize_props (self ):
444478 span = self .span
445- self .col = data .name
446479
447480 hours = []
448481 for hour in range (24 ):
@@ -484,7 +517,9 @@ def fit(
484517 self .domain = lvl .size
485518
486519 self .attr = Attribute (
487- cast (str , data .name ), {f"{ data .name } _time" : StratifiedValue (lvl )}, self .nullable
520+ self .col ,
521+ {f"{ self .col } _time" : StratifiedValue (lvl )},
522+ self .nullable ,
488523 )
489524
490525 def get_attributes (self ) -> Attributes :
@@ -592,6 +627,14 @@ def fit(
592627 self .dt .fit (data , ref )
593628 self .tt .fit (data )
594629
630+ self ._finalize_props ()
631+
def reduce(self, other: "DatetimeTransformer"):
    """Reduce the ``dt`` and ``tt`` sub-transformers with those of ``other``, then rebuild.

    Args:
        other: Transformer whose ``dt`` and ``tt`` members are merged into this one.
    """
    # `dt` is reduced before `tt`, and each pair is looked up just before its call.
    for part in ("dt", "tt"):
        getattr(self, part).reduce(getattr(other, part))
    self._finalize_props()
636+
637+ def _finalize_props (self ):
595638 cdt = next (iter (self .dt .get_attributes ().values ()))
596639 ctt = next (iter (self .tt .get_attributes ().values ()))
597640 self .attr = Attribute (self .col , vals = {** cdt .vals , ** ctt .vals }, na = self .nullable )
0 commit comments