11from copy import copy
2+ from typing import cast
23
34import numpy as np
45import pandas as pd
56
6- from ..attribute import Attribute , IdxValue , NumValue , OrdValue , get_dtype
7+ from ..attribute import (
8+ Attribute ,
9+ CatValue ,
10+ NumValue ,
11+ _create_strat_value_ord ,
12+ get_dtype ,
13+ )
714from ..encode import Encoder
815
916
1017class DiscretizationColumnTransformer :
1118 """Converts a numerical column into an ordinal one using histograms."""
1219
13- def fit (self , attr : NumValue , data : pd .Series ) -> IdxValue :
20+ def fit (self , attr : NumValue , data : pd .Series ) -> CatValue :
1421 self .in_attr = attr
1522 assert data .name
16- self .col = data .name
23+ self .col = cast ( str , data .name )
1724
1825 rng = (
1926 (attr .min , attr .max )
@@ -26,7 +33,7 @@ def fit(self, attr: NumValue, data: pd.Series) -> IdxValue:
2633 self .vals = ((self .edges [:- 1 ] + self .edges [1 :]) / 2 ).astype (np .float32 )
2734
2835 if attr .common <= 1 :
29- self .attr = OrdValue (self .vals , na = attr .common == 1 )
36+ self .attr = _create_strat_value_ord (self .vals , na = attr .common == 1 )
3037 else :
3138 assert (
3239 False
@@ -117,7 +124,7 @@ def fit(self, attr: Attribute, data: pd.DataFrame) -> Attribute:
117124 skip_common = False
118125 if len (attr .vals ) == 1 :
119126 v = next (iter (attr .vals .values ()))
120- if isinstance (v , IdxValue ) and v .is_ordinal :
127+ if isinstance (v , CatValue ) and v .is_ordinal :
121128 skip_common = True
122129
123130 if not skip_common :
@@ -127,7 +134,7 @@ def fit(self, attr: Attribute, data: pd.DataFrame) -> Attribute:
127134 for name , col in attr .vals .items ():
128135 if isinstance (col , NumValue ):
129136 cols [name ] = col
130- elif isinstance (col , IdxValue ):
137+ elif isinstance (col , CatValue ):
131138 if col .is_ordinal ():
132139 cols [name ] = NumValue ()
133140 else :
@@ -150,14 +157,14 @@ def encode(self, data: pd.DataFrame) -> pd.DataFrame:
150157 skip_common = False
151158 if len (a .vals ) == 1 :
152159 v = next (iter (a .vals .values ()))
153- if isinstance (v , IdxValue ) and v .is_ordinal :
160+ if isinstance (v , CatValue ) and v .is_ordinal :
154161 skip_common = True
155162
156163 for i in range (a .common ) if not skip_common else []:
157164 cmn_col = pd .Series (False , index = data .index , name = f"{ a .name } _cmn_{ i } " , dtype = np .float32 )
158165
159166 for name , col in a .vals .items ():
160- if isinstance (col , IdxValue ):
167+ if isinstance (col , CatValue ):
161168 cmn_col += data [name ] == i
162169 elif isinstance (col , NumValue ) and only_has_na :
163170 # Numerical values are expected to be NA for all common values
@@ -170,7 +177,7 @@ def encode(self, data: pd.DataFrame) -> pd.DataFrame:
170177 for name , col in a .vals .items ():
171178 if isinstance (col , NumValue ):
172179 cols .append (data [name ])
173- elif isinstance (col , IdxValue ):
180+ elif isinstance (col , CatValue ):
174181 # TODO add proper encodings other than one hot
175182
176183 # Handle ordinal values
0 commit comments