Skip to content

Commit a8dd683

Browse files
committed
document attribute
1 parent 2f8c578 commit a8dd683

2 files changed

Lines changed: 153 additions & 88 deletions

File tree

docs/source/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def skip(app, what, name, obj, skip, options):
239239

240240
def setup(app):
241241
app.connect("autodoc-process-docstring", autodoc_process_docstring)
242-
app.connect("autodoc-skip-member", skip)
242+
# app.connect("autodoc-skip-member", skip)
243243
# enable rendering RST tables in Markdown
244244
app.add_config_value("recommonmark_config", {"enable_eval_rst": True}, True)
245245
app.add_transform(AutoStructify)

src/pasteur/attribute.py

Lines changed: 152 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,33 @@
1+
"""This module implements the base abstractions of Attribute and Value,
2+
which are used to encapsulate the information of complex types.
3+
4+
Values are separated into CatValues (Categorical; defined by an index) and
5+
NumValues (Numerical). In addition, we define StratifiedValue as a special
6+
CatValue that holds a Stratification.
7+
8+
The Stratification represents a tree, created through Grouping nodes.
9+
Groupings are essentially lists with special functions.
10+
We separate groupings based on whether they are categorical (child order irrelevant)
11+
or ordinal (nearby children are more similar to each other).
12+
13+
The leaves of the tree represent the values the Value can take.
14+
They take the form of strings that describe each value, but are essentially placeholders.
15+
The Value is an integer. To map leaves to integers, the Tree is searched with Depth
16+
First Search, respecting child order to be deterministic.
17+
18+
An Attribute holds multiple values and a set of common conditions. When a common
19+
condition is active, all of the Attribute's Values are expected to have the same
20+
value."""
21+
122
import numpy as np
223
from typing import Any, Literal, TypeVar
324
from copy import copy
425

526

627
def get_dtype(domain: int):
28+
"""Returns the smallest NumPy unsigned integer dtype that will fit integers
29+
up to `domain` - 1."""
30+
731
# uint16 is 2x as fast as uint32 (5ms -> 3ms), use with marginals.
832
# Marginal domain can not exceed max(uint16) size 65535 + 1
933
if domain <= 1 << 8:
@@ -15,14 +39,17 @@ def get_dtype(domain: int):
1539
return np.uint64
1640

1741

18-
LI = TypeVar("LI", "Level", int)
42+
GI = TypeVar("GI", "Grouping", int)
43+
1944

45+
class Grouping(list[GI]):
46+
""" An enhanced form of list that holds the type of grouping (categorical, ordinal),
47+
and implements helper functions and an enhanced string representation."""
2048

21-
class Level(list[LI]):
22-
def __init__(self, type: Literal["cat", "ord"], arr: list["Level" | Any]):
49+
def __init__(self, type: Literal["cat", "ord"], arr: list["Grouping" | Any]):
2350
lvls = []
2451
for a in arr:
25-
if isinstance(a, Level):
52+
if isinstance(a, Grouping):
2653
lvls.append(a)
2754
else:
2855
lvls.append(str(a))
@@ -46,19 +73,19 @@ def __repr__(self) -> str:
4673
def height(self) -> int:
4774
if not self:
4875
return 0
49-
return max(lvl.height if isinstance(lvl, Level) else 0 for lvl in self) + 1
76+
return max(g.height if isinstance(g, Grouping) else 0 for g in self) + 1
5077

5178
@property
5279
def size(self) -> int:
53-
return sum(lvl.size if isinstance(lvl, Level) else 1 for lvl in self)
80+
return sum(g.size if isinstance(g, Grouping) else 1 for g in self)
5481

5582
def get_domain(self, height: int):
5683
return len(self.get_groups(height))
5784

5885
def _get_groups_by_level(self, lvl: int, ofs: int = 0):
5986
groups: list[list | int] = []
6087
for l in self:
61-
if isinstance(l, Level):
88+
if isinstance(l, Grouping):
6289
g, ofs = l._get_groups_by_level(lvl - 1, ofs)
6390

6491
if lvl == 0:
@@ -97,14 +124,16 @@ def get_mapping(self, height: int) -> np.ndarray:
97124
def get_human_values(self) -> list[str]:
98125
out = []
99126
for lvl in self:
100-
if isinstance(lvl, Level):
127+
if isinstance(lvl, Grouping):
101128
out.extend(lvl.get_human_values())
102129
else:
103130
out.append(str(lvl))
104131
return out
105132

106133
@staticmethod
107-
def from_str(a: str, nullable: bool = False, ukn_val: Any | None = None) -> "Level":
134+
def from_str(
135+
a: str, nullable: bool = False, ukn_val: Any | None = None
136+
) -> "Grouping":
108137
stack = [[]]
109138
is_ord = [False]
110139
bracket_closed = False
@@ -141,30 +170,39 @@ def from_str(a: str, nullable: bool = False, ukn_val: Any | None = None) -> "Lev
141170
case "}":
142171
children = stack.pop()
143172
assert not is_ord.pop(), "Unmatched '[' bracket, found '}'"
144-
stack[-1].append(Level("cat", children))
173+
stack[-1].append(Grouping("cat", children))
145174
case "[":
146175
stack.append([])
147176
is_ord.append(True)
148177
case "]":
149178
children = stack.pop()
150179
assert is_ord.pop(), "Unmatched '{' bracket, found ']'"
151-
stack[-1].append(Level("ord", children))
180+
stack[-1].append(Grouping("ord", children))
152181

153182
lvl_attrs = stack[0]
154183
if len(lvl_attrs) == 1:
155184
lvl = lvl_attrs[0]
156185
else:
157-
lvl = Level("cat", lvl_attrs)
186+
lvl = Grouping("cat", lvl_attrs)
158187

159188
return lvl
160189

161190

162191
class Value:
192+
""" Base value class """
163193
name: str | None = None
164194
common: int = 0
165195

166196

167-
class IdxValue(Value):
197+
class CatValue(Value):
198+
""" Class for a Categorical Value.
199+
200+
Each Categorical Value is represented by an unsigned integer.
201+
It can also group its different values together based on an integer parameter
202+
named height.
203+
The implementation of this class remains abstract, and is expanded in
204+
the StratifiedValue class. """
205+
168206
def get_domain(self, height: int = 0) -> int:
169207
"""Returns the domain of the attribute in the given height."""
170208
raise NotImplementedError()
@@ -176,30 +214,39 @@ def get_mapping(self, height: int) -> np.ndarray:
176214

177215
@property
178216
def height(self) -> int:
179-
"""Returns the maximum height of this column."""
217+
"""Returns the maximum height of this value."""
180218
return 0
181219

182220
@property
183221
def domain(self):
184222
return self.get_domain(0)
185223

186224
def is_ordinal(self) -> bool:
187-
"""Returns whether this column is ordinal, other than for the elements
225+
"""Returns whether this value is ordinal, other than for the elements
188226
it shares in common with the other attributes."""
189227
return False
190228

191-
def downsample(self, column: np.ndarray, height: int):
229+
def downsample(self, value: np.ndarray, height: int):
230+
""" Receives an array named `value` and downsamples it based on the provided
231+
height, by grouping certain values together. The proper implementation
232+
is provided by pasteur.hierarchy."""
192233
if height == 0:
193-
return column
194-
return self.get_mapping(height)[column]
195-
196-
def upsample(self, column: np.ndarray, height: int, deterministic: bool = True):
234+
return value
235+
return self.get_mapping(height)[value]
236+
237+
def upsample(self, value: np.ndarray, height: int, deterministic: bool = True):
238+
"""Does the opposite of downsample. If deterministic is True, for each
239+
group at a given height one of its values is chosen arbitrarily to represent
240+
all children of the group.
241+
242+
If deterministic is False, the group is sampled based on this Value's
243+
histogram (not implemented in this class; see pasteur.hierarchy)."""
197244
if height == 0:
198-
return column
245+
return value
199246

200247
assert (
201248
deterministic
202-
), "Current column doesn't contain a histogram, can't upsample"
249+
), "Current value doesn't contain a histogram, can't upsample"
203250

204251
d = self.get_domain(height)
205252
mapping = self.get_mapping(height)
@@ -210,18 +257,22 @@ def upsample(self, column: np.ndarray, height: int, deterministic: bool = True):
210257
c = (mapping == i).argmax()
211258
reverse_map[i] = c
212259

213-
return reverse_map[column]
260+
return reverse_map[value]
214261

215262
def select_height(self) -> int:
216263
return 0
217264

218265

219-
class LevelValue(IdxValue):
220-
"""A specific type of IdxColumn, which contains a hierarchical attribute
221-
structure based on a tree."""
266+
class StratifiedValue(CatValue):
267+
"""A version of CatValue which uses a Stratification to represent
268+
the domain knowledge of the Value.
269+
270+
Each unique value is mapped to a tree
271+
with nodes where the child order matters.
272+
By traversing the tree in DFS, each leaf is mapped to an integer."""
222273

223-
def __init__(self, lvl: Level, common: int = 0) -> None:
224-
self.head = lvl
274+
def __init__(self, head: Grouping, common: int = 0) -> None:
275+
self.head = head
225276
self.common = common
226277

227278
def __str__(self) -> str:
@@ -252,47 +303,41 @@ def height(self):
252303
return self.head.height
253304

254305

255-
class CatValue(LevelValue):
256-
"""Initializer for LevelColumn, which initializes a single level Categorical column."""
306+
def _create_strat_value_cat(vals, na: bool = False, ukn_val: Any | None = None):
    """Create a single-level categorical StratifiedValue.

    The resulting Grouping is flat and of type "cat". Its children are laid
    out as: an optional NA entry, an optional unknown-value entry, then
    every element of `vals`.

    Args:
        vals: Iterable with the regular values of the Value.
        na: When True, a leading `None` entry is added for missing data.
        ukn_val: When not None, this entry is added (after the NA entry,
            if any) to stand in for unknown values.

    Returns:
        A StratifiedValue whose `common` equals the number of NA/unknown
        entries that were prepended (0, 1 or 2).
    """
    arr = []
    common = 0
    if na:
        arr.append(None)
        common += 1
    if ukn_val is not None:
        arr.append(ukn_val)
        common += 1
    arr.extend(vals)

    # Forward `common` so the Value reports how many of its leading entries
    # are the NA/unknown ones, matching _create_strat_value_ord. Previously
    # the count was computed but dropped, so `common` was always 0.
    return StratifiedValue(Grouping("cat", arr), common)
257318

258-
def __init__(self, vals, na: bool = False, ukn_val: Any | None = None):
319+
320+
def _create_strat_value_ord(vals, na: bool = False, ukn_val: Any | None = None):
321+
g = Grouping("ord", vals)
322+
common = 0
323+
324+
if na or ukn_val is not None:
259325
arr = []
260-
common = 0
261326
if na:
262-
arr.append(None)
263327
common += 1
328+
arr.append(None)
264329
if ukn_val is not None:
265-
arr.append(ukn_val)
266330
common += 1
267-
arr.extend(vals)
268-
269-
super().__init__(Level("cat", arr))
270-
271-
272-
class OrdValue(LevelValue):
273-
"""Initializer for LevelColumn, which initializes a single level Ordinal column, which might have common values."""
274-
275-
def __init__(self, vals, na: bool = False, ukn_val: Any | None = None):
276-
lvl = Level("ord", vals)
277-
common = 0
278-
279-
if na or ukn_val is not None:
280-
arr = []
281-
if na:
282-
common += 1
283-
arr.append(None)
284-
if ukn_val is not None:
285-
common += 1
286-
arr.append(ukn_val)
287-
arr.append(lvl)
331+
arr.append(ukn_val)
332+
arr.append(g)
288333

289-
lvl = Level("cat", arr)
334+
g = Grouping("cat", arr)
290335

291-
super().__init__(lvl, common)
336+
return StratifiedValue(g, common)
292337

293338

294339
class NumValue(Value):
295-
"""Numerical Column, its value can be represented with a number, which might be NaN.
340+
"""Numerical Value: its value can be represented with a number, which might be NaN.
296341
297342
TODO: handle multiple common values (1 is assumed to be NA), appropriately."""
298343

@@ -317,17 +362,23 @@ def __repr__(self) -> str:
317362

318363

319364
class Attribute:
365+
"""Attribute class which holds multiple values in a dictionary."""
366+
320367
def __init__(
321368
self,
322369
name: str,
323370
vals: dict[str, V],
324371
na: bool = False,
325372
ukn_val: bool = False,
373+
common: int | None = None,
326374
) -> None:
327375
self.name = name
328376
self.na = na
329377
self.ukn_val = ukn_val
330-
self.common = self.na + self.ukn_val
378+
if common is None:
379+
self.common = self.na + self.ukn_val
380+
else:
381+
self.common = common
331382

332383
self.update_vals(vals)
333384

@@ -352,31 +403,45 @@ def __getitem__(self, col: str) -> Value:
352403
Attributes = dict[str, Attribute]
353404

354405

355-
class OrdAttribute(Attribute):
356-
def __init__(
357-
self, name: str, vals: list[Any], na: bool = False, ukn_val: Any | None = None
358-
) -> None:
359-
cols = {name: OrdValue(vals, na, ukn_val)}
360-
361-
super().__init__(name, cols, na, ukn_val is not None)
362-
363-
364-
class CatAttribute(Attribute):
365-
def __init__(
366-
self, name: str, vals: list[Any], na: bool = False, ukn_val: Any | None = None
367-
) -> None:
368-
cols = {name: OrdValue(vals, na, ukn_val)}
369-
370-
super().__init__(name, cols, na, ukn_val is not None)
371-
372-
373-
class NumAttribute(Attribute):
374-
def __init__(
375-
self,
376-
name: str,
377-
bins: int,
378-
min: int | float | None,
379-
max: int | float | None,
380-
nullable: bool = False,
381-
) -> None:
382-
super().__init__(name, {name: NumValue(bins, min, max)}, nullable, False)
406+
def OrdAttribute(
407+
name: str, vals: list[Any], na: bool = False, ukn_val: Any | None = None
408+
):
409+
"""Returns an Attribute holding a single Stratified Value where its children
410+
are ordinal, based on the provided data."""
411+
cols = {name: _create_strat_value_ord(vals, na, ukn_val)}
412+
return Attribute(name, cols, na, ukn_val is not None)
413+
414+
415+
def CatAttribute(
416+
name: str, vals: list[Any], na: bool = False, ukn_val: Any | None = None
417+
):
418+
"""Returns an Attribute holding a single Stratified Value where its children
419+
are categorical, based on the provided data."""
420+
cols = {name: _create_strat_value_cat(vals, na, ukn_val)}
421+
return Attribute(name, cols, na, ukn_val is not None)
422+
423+
424+
def NumAttribute(
425+
name: str,
426+
bins: int,
427+
min: int | float | None,
428+
max: int | float | None,
429+
nullable: bool = False,
430+
):
431+
"""Returns an Attribute holding a single NumValue with the provided data."""
432+
return Attribute(name, {name: NumValue(bins, min, max)}, nullable, False)
433+
434+
435+
__all__ = [
436+
"get_dtype",
437+
"Grouping",
438+
"Value",
439+
"CatValue",
440+
"NumValue",
441+
"StratifiedValue",
442+
"Attribute",
443+
"Attributes",
444+
"OrdAttribute",
445+
"CatAttribute",
446+
"NumAttribute",
447+
]

0 commit comments

Comments
 (0)