1+ """This module implements the base abstractions of Attribute and Value,
2+ which are used to encapsulate the information of complex types.
3+
4+ Values are separated into CatValues (Categorical; defined by an index) and
5+ NumValues (Numerical). In addition, we define StratifiedValue as a special
6+ CatValue that holds a Stratification.
7+
8+ The Stratification represents a tree, created through Grouping nodes.
9+ Groupings are essentially lists with special functions.
10+ We separate groupings based on whether they are categorical (child order irrelevant)
11+ or ordinal (nearby children are more similar to each other).
12+
13+ The leafs of the tree represent the values the Value can take.
14+ They take the form of strings that describe each value, but are essentially placeholders.
15+ The Value is an integer. To map leafs to integers, the Tree is searched with Depth
16+ First Search, respecting child order to be deterministic.
17+
18+ An Attribute holds multiple values and a set of common conditions. When a common
19+ condition is active, all of the Attribute's Values are expected to have the same
20+ value."""
21+
122import numpy as np
223from typing import Any , Literal , TypeVar
324from copy import copy
425
526
627def get_dtype (domain : int ):
28+ """Returns the smallest NumPy unsigned integer dtype that will fit integers
29+ up to `domain` - 1."""
30+
731 # uint16 is 2x as fast as uint32 (5ms -> 3ms), use with marginals.
832 # Marginal domain can not exceed max(uint16) size 65535 + 1
933 if domain <= 1 << 8 :
@@ -15,14 +39,17 @@ def get_dtype(domain: int):
1539 return np .uint64
1640
1741
18- LI = TypeVar ("LI" , "Level" , int )
42+ GI = TypeVar ("GI" , "Grouping" , int )
43+
1944
45+ class Grouping (list [GI ]):
46+ """ An enchanced form of list that holds the type of grouping (categorical, ordinal),
47+ and implements helper functions and an enchanced string representation."""
2048
21- class Level (list [LI ]):
22- def __init__ (self , type : Literal ["cat" , "ord" ], arr : list ["Level" | Any ]):
49+ def __init__ (self , type : Literal ["cat" , "ord" ], arr : list ["Grouping" | Any ]):
2350 lvls = []
2451 for a in arr :
25- if isinstance (a , Level ):
52+ if isinstance (a , Grouping ):
2653 lvls .append (a )
2754 else :
2855 lvls .append (str (a ))
@@ -46,19 +73,19 @@ def __repr__(self) -> str:
4673 def height (self ) -> int :
4774 if not self :
4875 return 0
49- return max (lvl .height if isinstance (lvl , Level ) else 0 for lvl in self ) + 1
76+ return max (g .height if isinstance (g , Grouping ) else 0 for g in self ) + 1
5077
5178 @property
5279 def size (self ) -> int :
53- return sum (lvl .size if isinstance (lvl , Level ) else 1 for lvl in self )
80+ return sum (g .size if isinstance (g , Grouping ) else 1 for g in self )
5481
5582 def get_domain (self , height : int ):
5683 return len (self .get_groups (height ))
5784
5885 def _get_groups_by_level (self , lvl : int , ofs : int = 0 ):
5986 groups : list [list | int ] = []
6087 for l in self :
61- if isinstance (l , Level ):
88+ if isinstance (l , Grouping ):
6289 g , ofs = l ._get_groups_by_level (lvl - 1 , ofs )
6390
6491 if lvl == 0 :
@@ -97,14 +124,16 @@ def get_mapping(self, height: int) -> np.ndarray:
97124 def get_human_values (self ) -> list [str ]:
98125 out = []
99126 for lvl in self :
100- if isinstance (lvl , Level ):
127+ if isinstance (lvl , Grouping ):
101128 out .extend (lvl .get_human_values ())
102129 else :
103130 out .append (str (lvl ))
104131 return out
105132
106133 @staticmethod
107- def from_str (a : str , nullable : bool = False , ukn_val : Any | None = None ) -> "Level" :
134+ def from_str (
135+ a : str , nullable : bool = False , ukn_val : Any | None = None
136+ ) -> "Grouping" :
108137 stack = [[]]
109138 is_ord = [False ]
110139 bracket_closed = False
@@ -141,30 +170,39 @@ def from_str(a: str, nullable: bool = False, ukn_val: Any | None = None) -> "Lev
141170 case "}" :
142171 children = stack .pop ()
143172 assert not is_ord .pop (), "Unmatched '[' bracket, found '}'"
144- stack [- 1 ].append (Level ("cat" , children ))
173+ stack [- 1 ].append (Grouping ("cat" , children ))
145174 case "[" :
146175 stack .append ([])
147176 is_ord .append (True )
148177 case "]" :
149178 children = stack .pop ()
150179 assert is_ord .pop (), "Unmatched '{' bracket, found ']'"
151- stack [- 1 ].append (Level ("ord" , children ))
180+ stack [- 1 ].append (Grouping ("ord" , children ))
152181
153182 lvl_attrs = stack [0 ]
154183 if len (lvl_attrs ) == 1 :
155184 lvl = lvl_attrs [0 ]
156185 else :
157- lvl = Level ("cat" , lvl_attrs )
186+ lvl = Grouping ("cat" , lvl_attrs )
158187
159188 return lvl
160189
161190
162191class Value :
192+ """ Base value class """
163193 name : str | None = None
164194 common : int = 0
165195
166196
167- class IdxValue (Value ):
197+ class CatValue (Value ):
198+ """ Class for a Categorical Value.
199+
200+ Each Categorical Value is represented by an unsigned integer.
201+ It can also group its different values together based on an integer parameter
202+ named height.
203+ The implementation of this class remains abstract, and is expanded in
204+ the StratifiedValue class. """
205+
168206 def get_domain (self , height : int = 0 ) -> int :
169207 """Returns the domain of the attribute in the given height."""
170208 raise NotImplementedError ()
@@ -176,30 +214,39 @@ def get_mapping(self, height: int) -> np.ndarray:
176214
177215 @property
178216 def height (self ) -> int :
179- """Returns the maximum height of this column ."""
217+ """Returns the maximum height of this value ."""
180218 return 0
181219
182220 @property
183221 def domain (self ):
184222 return self .get_domain (0 )
185223
186224 def is_ordinal (self ) -> bool :
187- """Returns whether this column is ordinal, other than for the elements
225+ """Returns whether this value is ordinal, other than for the elements
188226 it shares in common with the other attributes."""
189227 return False
190228
191- def downsample (self , column : np .ndarray , height : int ):
229+ def downsample (self , value : np .ndarray , height : int ):
230+ """ Receives an array named `value` and downsamples it based on the provided
231+ height, by grouping certain values together. The proper implementation
232+ is provided by pasteur.hierarchy."""
192233 if height == 0 :
193- return column
194- return self .get_mapping (height )[column ]
195-
196- def upsample (self , column : np .ndarray , height : int , deterministic : bool = True ):
234+ return value
235+ return self .get_mapping (height )[value ]
236+
237+ def upsample (self , value : np .ndarray , height : int , deterministic : bool = True ):
238+ """Does the opposite of downsample. If deterministic is True, for each
239+ group at a given height one of its values is chosen arbitrarily to represent
240+ all children of the group.
241+
242+ If deterministic is False, the group is sampled based on this Value's
243+ histogram (not implemented in this class; see pasteur.hierarchy)."""
197244 if height == 0 :
198- return column
245+ return value
199246
200247 assert (
201248 deterministic
202- ), "Current column doesn't contain a histogram, can't upsample"
249+ ), "Current value doesn't contain a histogram, can't upsample"
203250
204251 d = self .get_domain (height )
205252 mapping = self .get_mapping (height )
@@ -210,18 +257,22 @@ def upsample(self, column: np.ndarray, height: int, deterministic: bool = True):
210257 c = (mapping == i ).argmax ()
211258 reverse_map [i ] = c
212259
213- return reverse_map [column ]
260+ return reverse_map [value ]
214261
215262 def select_height (self ) -> int :
216263 return 0
217264
218265
219- class LevelValue (IdxValue ):
220- """A specific type of IdxColumn, which contains a hierarchical attribute
221- structure based on a tree."""
266+ class StratifiedValue (CatValue ):
267+ """A version of CategoricalValue which uses a Stratification to represent
268+ the domain knowledge of the Value.
269+
270+ Each unique value is mapped to a tree
271+ with nodes where the child order matters.
272+ By traversing the tree in DFS, each leaf is mapped to an integer."""
222273
223- def __init__ (self , lvl : Level , common : int = 0 ) -> None :
224- self .head = lvl
274+ def __init__ (self , head : Grouping , common : int = 0 ) -> None :
275+ self .head = head
225276 self .common = common
226277
227278 def __str__ (self ) -> str :
@@ -252,47 +303,41 @@ def height(self):
252303 return self .head .height
253304
254305
255- class CatValue (LevelValue ):
256- """Initializer for LevelColumn, which initializes a single level Categorical column."""
306+ def _create_strat_value_cat (vals , na : bool = False , ukn_val : Any | None = None ):
307+ arr = []
308+ common = 0
309+ if na :
310+ arr .append (None )
311+ common += 1
312+ if ukn_val is not None :
313+ arr .append (ukn_val )
314+ common += 1
315+ arr .extend (vals )
316+
317+ return StratifiedValue (Grouping ("cat" , arr ))
257318
258- def __init__ (self , vals , na : bool = False , ukn_val : Any | None = None ):
319+
320+ def _create_strat_value_ord (vals , na : bool = False , ukn_val : Any | None = None ):
321+ g = Grouping ("ord" , vals )
322+ common = 0
323+
324+ if na or ukn_val is not None :
259325 arr = []
260- common = 0
261326 if na :
262- arr .append (None )
263327 common += 1
328+ arr .append (None )
264329 if ukn_val is not None :
265- arr .append (ukn_val )
266330 common += 1
267- arr .extend (vals )
268-
269- super ().__init__ (Level ("cat" , arr ))
270-
271-
272- class OrdValue (LevelValue ):
273- """Initializer for LevelColumn, which initializes a single level Ordinal column, which might have common values."""
274-
275- def __init__ (self , vals , na : bool = False , ukn_val : Any | None = None ):
276- lvl = Level ("ord" , vals )
277- common = 0
278-
279- if na or ukn_val is not None :
280- arr = []
281- if na :
282- common += 1
283- arr .append (None )
284- if ukn_val is not None :
285- common += 1
286- arr .append (ukn_val )
287- arr .append (lvl )
331+ arr .append (ukn_val )
332+ arr .append (g )
288333
289- lvl = Level ("cat" , arr )
334+ g = Grouping ("cat" , arr )
290335
291- super (). __init__ ( lvl , common )
336+ return StratifiedValue ( g , common )
292337
293338
294339class NumValue (Value ):
295- """Numerical Column, its value can be represented with a number, which might be NaN.
340+ """Numerical Value: its value can be represented with a number, which might be NaN.
296341
297342 TODO: handle multiple common values (1 is assumed to be NA), appropriately."""
298343
@@ -317,17 +362,23 @@ def __repr__(self) -> str:
317362
318363
319364class Attribute :
365+ """Attribute class which holds multiple values in a dictionary."""
366+
320367 def __init__ (
321368 self ,
322369 name : str ,
323370 vals : dict [str , V ],
324371 na : bool = False ,
325372 ukn_val : bool = False ,
373+ common : int | None = None ,
326374 ) -> None :
327375 self .name = name
328376 self .na = na
329377 self .ukn_val = ukn_val
330- self .common = self .na + self .ukn_val
378+ if common is None :
379+ self .common = self .na + self .ukn_val
380+ else :
381+ self .common = common
331382
332383 self .update_vals (vals )
333384
@@ -352,31 +403,45 @@ def __getitem__(self, col: str) -> Value:
352403Attributes = dict [str , Attribute ]
353404
354405
355- class OrdAttribute (Attribute ):
356- def __init__ (
357- self , name : str , vals : list [Any ], na : bool = False , ukn_val : Any | None = None
358- ) -> None :
359- cols = {name : OrdValue (vals , na , ukn_val )}
360-
361- super ().__init__ (name , cols , na , ukn_val is not None )
362-
363-
364- class CatAttribute (Attribute ):
365- def __init__ (
366- self , name : str , vals : list [Any ], na : bool = False , ukn_val : Any | None = None
367- ) -> None :
368- cols = {name : OrdValue (vals , na , ukn_val )}
369-
370- super ().__init__ (name , cols , na , ukn_val is not None )
371-
372-
373- class NumAttribute (Attribute ):
374- def __init__ (
375- self ,
376- name : str ,
377- bins : int ,
378- min : int | float | None ,
379- max : int | float | None ,
380- nullable : bool = False ,
381- ) -> None :
382- super ().__init__ (name , {name : NumValue (bins , min , max )}, nullable , False )
406+ def OrdAttribute (
407+ name : str , vals : list [Any ], na : bool = False , ukn_val : Any | None = None
408+ ):
409+ """Returns an Attribute holding a single Stratified Value where its children
410+ are ordinal, based on the provided data."""
411+ cols = {name : _create_strat_value_ord (vals , na , ukn_val )}
412+ return Attribute (name , cols , na , ukn_val is not None )
413+
414+
415+ def CatAttribute (
416+ name : str , vals : list [Any ], na : bool = False , ukn_val : Any | None = None
417+ ):
418+ """Returns an Attribute holding a single Stratified Value where its children
419+ are categorical, based on the provided data."""
420+ cols = {name : _create_strat_value_cat (vals , na , ukn_val )}
421+ return Attribute (name , cols , na , ukn_val is not None )
422+
423+
424+ def NumAttribute (
425+ name : str ,
426+ bins : int ,
427+ min : int | float | None ,
428+ max : int | float | None ,
429+ nullable : bool = False ,
430+ ):
431+ """Returns an Attribute holding a single NumValue with the provided data."""
432+ return Attribute (name , {name : NumValue (bins , min , max )}, nullable , False )
433+
434+
435+ __all__ = [
436+ "get_dtype" ,
437+ "Grouping" ,
438+ "Value" ,
439+ "CatValue" ,
440+ "NumValue" ,
441+ "StratifiedValue" ,
442+ "Attribute" ,
443+ "Attributes" ,
444+ "OrdAttribute" ,
445+ "CatAttribute" ,
446+ "NumAttribute" ,
447+ ]
0 commit comments