Skip to content

Commit 9144c5e

Browse files
committed
fix refactor introduced errors
1 parent 992e2c6 commit 9144c5e

11 files changed

Lines changed: 60 additions & 34 deletions

File tree

src/pasteur/attribute.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ def upsample(self, value: np.ndarray, height: int, deterministic: bool = True):
262262
def select_height(self) -> int:
263263
return 0
264264

265+
IdxValue = CatValue
265266

266267
class StratifiedValue(CatValue):
267268
"""A version of CategoricalValue which uses a Stratification to represent

src/pasteur/dataset.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def ingest(self, name, **tables: Any) -> LazyFrame:
102102
103103
@warning: all partitioned tables should have the same partitions.
104104
Some tables may not be partitioned.
105-
105+
106106
Tip: use a `match` statement to fork based on table name to per-table functions."""
107107
raise NotImplemented()
108108

@@ -149,5 +149,18 @@ def keys(self, **tables: LazyChunk) -> pd.DataFrame:
149149

150150
return tables["table"]()
151151

152+
class TypedDataset(Dataset):
153+
"""Extend from to create an intermediary step in ingestion, where the table
154+
is loaded from `<dataset>.raw@<table>` to a parquet one `<dataset>.typed.<table>.
155+
156+
Useful for multiple reads to raw tables. You can also override the `type()` function to make
157+
minor changes to the dataset. By default it's the identity.
158+
159+
Since parquet files don't support chunked loading it's unused."""
160+
161+
def type(self, table: Any):
162+
return table
163+
164+
152165

153166
__all__ = ["Dataset", "TabularDataset"]

src/pasteur/extras/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
""" This package contains reference implementations for Pasteur modules, which
2+
may be extracted to a separate package in the future."""
3+
14
from __future__ import annotations
25

36
from typing import TYPE_CHECKING

src/pasteur/extras/encoders.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,26 @@
11
from copy import copy
2+
from typing import cast
23

34
import numpy as np
45
import pandas as pd
56

6-
from ..attribute import Attribute, IdxValue, NumValue, OrdValue, get_dtype
7+
from ..attribute import (
8+
Attribute,
9+
CatValue,
10+
NumValue,
11+
_create_strat_value_ord,
12+
get_dtype,
13+
)
714
from ..encode import Encoder
815

916

1017
class DiscretizationColumnTransformer:
1118
"""Converts a numerical column into an ordinal one using histograms."""
1219

13-
def fit(self, attr: NumValue, data: pd.Series) -> IdxValue:
20+
def fit(self, attr: NumValue, data: pd.Series) -> CatValue:
1421
self.in_attr = attr
1522
assert data.name
16-
self.col = data.name
23+
self.col = cast(str, data.name)
1724

1825
rng = (
1926
(attr.min, attr.max)
@@ -26,7 +33,7 @@ def fit(self, attr: NumValue, data: pd.Series) -> IdxValue:
2633
self.vals = ((self.edges[:-1] + self.edges[1:]) / 2).astype(np.float32)
2734

2835
if attr.common <= 1:
29-
self.attr = OrdValue(self.vals, na=attr.common == 1)
36+
self.attr = _create_strat_value_ord(self.vals, na=attr.common == 1)
3037
else:
3138
assert (
3239
False
@@ -117,7 +124,7 @@ def fit(self, attr: Attribute, data: pd.DataFrame) -> Attribute:
117124
skip_common = False
118125
if len(attr.vals) == 1:
119126
v = next(iter(attr.vals.values()))
120-
if isinstance(v, IdxValue) and v.is_ordinal:
127+
if isinstance(v, CatValue) and v.is_ordinal:
121128
skip_common = True
122129

123130
if not skip_common:
@@ -127,7 +134,7 @@ def fit(self, attr: Attribute, data: pd.DataFrame) -> Attribute:
127134
for name, col in attr.vals.items():
128135
if isinstance(col, NumValue):
129136
cols[name] = col
130-
elif isinstance(col, IdxValue):
137+
elif isinstance(col, CatValue):
131138
if col.is_ordinal():
132139
cols[name] = NumValue()
133140
else:
@@ -150,14 +157,14 @@ def encode(self, data: pd.DataFrame) -> pd.DataFrame:
150157
skip_common = False
151158
if len(a.vals) == 1:
152159
v = next(iter(a.vals.values()))
153-
if isinstance(v, IdxValue) and v.is_ordinal:
160+
if isinstance(v, CatValue) and v.is_ordinal:
154161
skip_common = True
155162

156163
for i in range(a.common) if not skip_common else []:
157164
cmn_col = pd.Series(False, index=data.index, name=f"{a.name}_cmn_{i}", dtype=np.float32)
158165

159166
for name, col in a.vals.items():
160-
if isinstance(col, IdxValue):
167+
if isinstance(col, CatValue):
161168
cmn_col += data[name] == i
162169
elif isinstance(col, NumValue) and only_has_na:
163170
# Numerical values are expected to be NA for all common values
@@ -170,7 +177,7 @@ def encode(self, data: pd.DataFrame) -> pd.DataFrame:
170177
for name, col in a.vals.items():
171178
if isinstance(col, NumValue):
172179
cols.append(data[name])
173-
elif isinstance(col, IdxValue):
180+
elif isinstance(col, CatValue):
174181
# TODO add proper encodings other than one hot
175182

176183
# Handle ordinal values

src/pasteur/extras/metrics/distr.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from scipy.special import rel_entr
1010
from scipy.stats import chisquare
1111

12-
from ...attribute import Attributes, IdxValue, get_dtype
12+
from ...attribute import Attributes, CatValue, get_dtype
1313
from ...metric import Summaries, TableData, TableMetric
1414
from ...utils.progress import process_in_parallel
1515

@@ -70,7 +70,7 @@ def fit(
7070
self.domain = {}
7171
for attr in table_attrs.values():
7272
for name, val in attr.vals.items():
73-
assert isinstance(val, IdxValue)
73+
assert isinstance(val, CatValue)
7474
self.domain[name] = val.domain
7575

7676
def process_chunk(
@@ -187,7 +187,7 @@ def fit(
187187
self.domain = {}
188188
for attr in table_attrs.values():
189189
for name, val in attr.vals.items():
190-
assert isinstance(val, IdxValue)
190+
assert isinstance(val, CatValue)
191191
self.domain[name] = val.domain
192192

193193
def process_chunk(

src/pasteur/extras/synth/privbayes/implementation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66
import pandas as pd
77

8-
from ....attribute import Attributes, IdxValue, get_dtype
8+
from ....attribute import Attributes, CatValue, get_dtype
99
from ....marginal import (
1010
ZERO_FILL,
1111
AttrSelector,
@@ -230,7 +230,7 @@ def greedy_bayes(
230230
for i, (an, a) in enumerate(attrs.items()):
231231
group_names.append(an)
232232
for c_n, c in a.vals.items():
233-
c = cast(IdxValue, c)
233+
c = cast(CatValue, c)
234234
col_names.append(c_n)
235235
groups.append(i)
236236
heights.append(c.height)
@@ -245,7 +245,7 @@ def greedy_bayes(
245245

246246
for i, (an, a) in enumerate(attrs.items()):
247247
for c_n, c in a.vals.items():
248-
c = cast(IdxValue, c)
248+
c = cast(CatValue, c)
249249

250250
doms = []
251251
for i in range(c.height):
@@ -660,7 +660,7 @@ def sample_rows(
660660
p_partial = partial and attr_name == x_attr
661661
for i, (col_name, h) in enumerate(attr.cols.items()):
662662
col = attrs[attr_name].vals[col_name]
663-
col = cast(IdxValue, col)
663+
col = cast(CatValue, col)
664664
mapping = np.array(col.get_mapping(h), dtype=dtype)
665665
domain = col.get_domain(h)
666666

src/pasteur/extras/transformers.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77
from ..attribute import (
88
Attribute,
99
CatAttribute,
10-
Level,
11-
LevelValue,
10+
Grouping,
11+
CatValue,
1212
NumAttribute,
1313
NumValue,
1414
OrdAttribute,
15-
OrdValue,
15+
_create_strat_value_ord as OrdValue,
1616
get_dtype,
1717
)
1818
from ..transform import RefTransformer, Transformer
@@ -443,7 +443,7 @@ def fit(
443443
hours.append(f"{hour:02d}:00")
444444
elif span == "halfhour":
445445
hours.append(
446-
Level(
446+
Grouping(
447447
"ord",
448448
[f"{hour:02d}:00", f"{hour:02d}:30"],
449449
)
@@ -455,7 +455,7 @@ def fit(
455455
mins.append(f"{hour:02d}:{min:02d}")
456456
if span == "halfminute":
457457
mins.append(
458-
Level(
458+
Grouping(
459459
"ord",
460460
[
461461
f"{hour:02d}:{min:02d}:00",
@@ -467,17 +467,17 @@ def fit(
467467
secs = []
468468
for sec in range(60):
469469
secs.append(f"{hour:02d}:{min:02d}:{sec:02d}")
470-
mins.append(Level("ord", secs))
470+
mins.append(Grouping("ord", secs))
471471

472-
hours.append(Level("ord", mins))
473-
lvl = Level("ord", hours)
472+
hours.append(Grouping("ord", mins))
473+
lvl = Grouping("ord", hours)
474474
if self.nullable:
475-
lvl = Level("cat", [None, lvl])
475+
lvl = Grouping("cat", [None, lvl])
476476

477477
self.domain = lvl.size
478478

479479
self.attr = Attribute(
480-
cast(str, data.name), {f"{data.name}_time": LevelValue(lvl)}, self.nullable
480+
cast(str, data.name), {f"{data.name}_time": CatValue(lvl)}, self.nullable
481481
)
482482
return self.attr
483483

src/pasteur/kedro/pipelines/synth.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ def create_synth_pipeline(
2525
tables = view.tables
2626

2727
tags: list[str] = list(TAGS_SYNTH)
28-
if fr.gpu:
29-
tags.append(TAG_GPU)
28+
# if fr.gpu:
29+
# tags.append(TAG_GPU)
3030

3131
pipe = pipeline(
3232
[

src/pasteur/marginal/memory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import numpy as np
55

6-
from ..attribute import Attributes, get_dtype, IdxValue
6+
from ..attribute import Attributes, get_dtype, CatValue
77
from ..utils import LazyFrame
88

99
class ArrayInfo(NamedTuple):
@@ -36,7 +36,7 @@ def allocate_memory(data: LazyFrame, attrs: Attributes, *, common: bool = False)
3636
continue
3737

3838
for name, col in attr.vals.items():
39-
col = cast(IdxValue, col)
39+
col = cast(CatValue, col)
4040
info[name] = []
4141
for height in range(col.height):
4242
shape = (n, )

src/pasteur/marginal/numpy.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
import pandas as pd
55

6-
from ..attribute import Attributes, get_dtype, IdxValue
6+
from ..attribute import Attributes, get_dtype, CatValue
77

88
ZERO_FILL = 1e-24
99

@@ -63,7 +63,7 @@ def expand_table(
6363
if name not in table:
6464
continue
6565

66-
col = cast(IdxValue, col)
66+
col = cast(CatValue, col)
6767
col_hier = []
6868
col_noncommon = []
6969
col_dom = []
@@ -98,7 +98,7 @@ def get_domains(attrs: Attributes) -> dict[str, list[int]]:
9898
domains = {}
9999
for attr in attrs.values():
100100
for name, col in attr.vals.items():
101-
col = cast(IdxValue, col)
101+
col = cast(CatValue, col)
102102
col_dom = []
103103

104104
for height in range(col.height):

0 commit comments

Comments
 (0)