
Commit a47b3a8

fix bugs for adult dataset

1 parent 36c9bdd

18 files changed: 247 additions & 64 deletions

src/pasteur/attribute.py

Lines changed: 1 addition & 1 deletion
@@ -191,7 +191,7 @@ def from_str(

 class Value:
     """ Base value class """
-    name: str | None = None
+    name: str | tuple[str] | None = None
     common: int = 0

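Note: widening `name` to `str | tuple[str] | None` lets a value carry a compound name, matching the `dict[str | tuple[str], META]` metadata keys that appear in encode.py below. A minimal sketch of the two accepted forms (column names are made up):

    flat = Value()
    flat.name = "age"                # plain single-column name
    nested = Value()
    nested.name = ("adult", "age")   # compound name, e.g. table and column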

src/pasteur/encode.py

Lines changed: 7 additions & 7 deletions
@@ -8,24 +8,24 @@
 from .module import ModuleClass, ModuleFactory
 from .utils import LazyFrame, LazyDataset

-ENC = TypeVar("ENC", bound=ModuleClass)
-META = TypeVar("META")
-

-class AttributeEncoderFactory(ModuleFactory[ENC], Generic[ENC]):
+class AttributeEncoderFactory(ModuleFactory):
     """Factory base class for encoders. Use isinstance with this class
     to filter the Pasteur module list into only containing Encoders."""

     ...


-class EncoderFactory(ModuleFactory[ENC], Generic[ENC]):
+class EncoderFactory(ModuleFactory):
     """Factory base class for encoders. Use isinstance with this class
     to filter the Pasteur module list into only containing Encoders."""

     ...


+META = TypeVar("META")
+
+
 class AttributeEncoder(ModuleClass, Generic[META]):
     """Encapsulates a special way to encode an Attribute.

@@ -50,7 +50,7 @@ class AttributeEncoder(ModuleClass, Generic[META]):
     """

     name: str = ""
-    _factory = AttributeEncoderFactory["AttributeEncoder"]
+    _factory = AttributeEncoderFactory

     def fit(self, attr: Attribute, data: pd.DataFrame | None):
         raise NotImplementedError()
@@ -70,7 +70,7 @@ def get_metadata(self) -> dict[str | tuple[str], META]:

 class Encoder(ModuleClass, Generic[META]):
     name: str = ""
-    _factory = EncoderFactory["Encoder"]
+    _factory = EncoderFactory

     def fit(
         self,
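Note: subscripting a `Generic` factory at class-definition time, as the removed `_factory = AttributeEncoderFactory["AttributeEncoder"]` did, yields a `typing._GenericAlias` rather than the class itself, which breaks the `isinstance`-based module filtering the docstrings describe. A standalone sketch of the failure mode (names are illustrative, not Pasteur code):

    from typing import Generic, TypeVar

    T = TypeVar("T")

    class Factory(Generic[T]):
        ...

    alias = Factory["SomeEncoder"]          # a typing._GenericAlias, not a class
    print(alias is Factory)                 # False
    print(isinstance(Factory(), Factory))   # True: works on the plain class
    # isinstance(x, alias) raises TypeError: subscripted generics cannot
    # be used with class and instance checks

Assigning the plain class, as the fix does, keeps `_factory` usable for both instantiation and `isinstance` checks; moving `META = TypeVar("META")` below the factories works because only `AttributeEncoder` and `Encoder` remain generic.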

src/pasteur/extras/metrics/distr.py

Lines changed: 13 additions & 9 deletions
@@ -68,13 +68,15 @@ def _visualise_cs(
     results = {}

     # Add ref split first
+    zfill = lambda x: (x + 1) / np.sum(x + 1)
     name = "ref"
     res = []
     split = next(iter(data.values()))
     for col in domain:
         wrk, syn = split.wrk, split.ref
         assert syn is not None
-        chi, p = chisquare(wrk[col], syn[col])
+
+        chi, p = chisquare(zfill(wrk[col]), zfill(syn[col]))
         res.append([col, chi, p])

     results[name] = pd.DataFrame(res, columns=["col", "X^2", "p"])
@@ -84,7 +86,7 @@ def _visualise_cs(
     for col in domain:
         wrk, syn = split.wrk, split.syn
         assert syn is not None
-        chi, p = chisquare(wrk[col], syn[col])
+        chi, p = chisquare(zfill(wrk[col]), zfill(syn[col]))
         res.append([col, chi, p])

     results[name] = pd.DataFrame(res, columns=["col", "X^2", "p"])
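Note: `scipy.stats.chisquare(f_obs, f_exp)` divides by the expected frequencies and raises when the two inputs do not sum to the same total, so a single empty histogram bin (which the adult dataset produces) broke the statistic. The `zfill` lambda applies add-one smoothing and normalizes both marginals into proper distributions. A standalone sketch with made-up counts:

    import numpy as np
    from scipy.stats import chisquare

    wrk = np.array([100, 10, 45])  # observed counts
    syn = np.array([120, 0, 30])   # synthetic counts with an empty bin

    # chisquare(wrk, syn) fails: the sums differ (155 vs 150) and the
    # zero expected frequency divides by zero.
    zfill = lambda x: (x + 1) / np.sum(x + 1)
    chi, p = chisquare(zfill(wrk), zfill(syn))  # finite statistic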
@@ -125,8 +127,10 @@ def _visualise_kl(
     res = []
     for key in syn:
         col_i, col_j = key
-        k = wrk[key]
-        j = syn[key]
+
+        zfill = lambda x: (x + KL_ZERO_FILL) / np.sum(x + KL_ZERO_FILL)
+        k = zfill(wrk[key])
+        j = zfill(syn[key])

         kl = rel_entr(k, j).sum()
         kl_norm = 1 / (1 + kl)
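Note: `rel_entr(k, j)` returns infinity wherever `k > 0` and `j == 0`, so one empty cell in a synthetic two-way marginal previously made the KL score infinite and `kl_norm` zero. Smoothing both marginals with `KL_ZERO_FILL` (defined elsewhere in this file; a tiny constant is assumed below) keeps the divergence finite:

    import numpy as np
    from scipy.special import rel_entr

    KL_ZERO_FILL = 1e-24  # assumed value for illustration

    wrk = np.array([8, 2, 0])
    syn = np.array([9, 0, 1])  # zero cell where wrk is positive

    print(rel_entr(wrk / wrk.sum(), syn / syn.sum()).sum())  # inf
    zfill = lambda x: (x + KL_ZERO_FILL) / np.sum(x + KL_ZERO_FILL)
    kl = rel_entr(zfill(wrk), zfill(syn)).sum()              # finite
    print(1 / (1 + kl))  # kl_norm, 1.0 only for identical marginals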
@@ -168,19 +172,19 @@ def _process_marginals_chunk(
 ):
     assert not expand_parents, "Expanding parents not supported yet"

-    table = tables[name]()[list(domain)].to_numpy(dtype="uint16")
+    table = tables[name]()[list(domain[name])].to_numpy(dtype="uint16")
     table_domain = domain[name]
     domain_arr = np.array(list(table_domain.values()))

     # One way for CS
     one_way: dict[str, ndarray] = {}
-    for i, name in enumerate(domain):
+    for i, name in enumerate(table_domain):
         one_way[name] = calc_marginal_1way(table, domain_arr, [i], 0)

     # Two way for KL
     two_way: dict[tuple[str, str], ndarray] = {}
-    for i, col_i in enumerate(domain):
-        for j, col_j in enumerate(domain):
+    for i, col_i in enumerate(table_domain):
+        for j, col_j in enumerate(table_domain):
             two_way[(col_i, col_j)] = calc_marginal_1way(table, domain_arr, [i, j], 0)

     return one_way, two_way
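Note: `domain` here maps table names to per-table column domains, so the removed lines indexed columns with table names: `list(domain)` yields table names, not columns. The fix consistently iterates `table_domain = domain[name]`. An illustration of the assumed shapes (values made up):

    domain = {"adult": {"age": 5, "sex": 2, "income": 2}}
    name = "adult"

    table_domain = domain[name]
    print(list(domain))        # ["adult"]                <- old, wrong
    print(list(table_domain))  # ["age", "sex", "income"] <- fixed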
@@ -308,7 +312,7 @@ def process(
     # Intertwine results
     res = defaultdict(list)
     for meta, hist in zip(per_call_meta, out):
-        res[meta["split"]][meta["table"]].append(hist)
+        res[meta["table"]].append(hist)

     ret = {}
     for table, table_hists in res.items():
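Note: with `res = defaultdict(list)`, the removed line first materialized an empty list for the split key and then indexed that list with a string, which raises `TypeError`. The fix groups histograms by table only:

    from collections import defaultdict

    res = defaultdict(list)
    meta = {"split": "syn", "table": "adult"}  # illustrative metadata

    # res[meta["split"]][meta["table"]]  # TypeError: list indices must be
    #                                    # integers or slices, not str
    res[meta["table"]].append("hist")    # one list of histograms per table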

src/pasteur/extras/synth/pgm/aim.py

Lines changed: 2 additions & 0 deletions
@@ -55,6 +55,8 @@ def __init__(
     def preprocess(self, meta: dict[str, Attributes], data: dict[str, LazyFrame]):
         self.table = next(iter(meta))
         self.attrs = meta
+        self._n = data[self.table].shape[0]
+        self._partitions = len(data[self.table])

     @make_deterministic
     def bake(self, data: dict[str, LazyFrame]):

src/pasteur/extras/synth/pgm/mst.py

Lines changed: 2 additions & 0 deletions
@@ -89,6 +89,8 @@ def __init__(
     def preprocess(self, meta: dict[str, Attributes], data: dict[str, LazyFrame]):
         self.table = next(iter(meta))
         self.attrs = meta
+        self._n = data[self.table].shape[0]
+        self._partitions = len(data[self.table])

     @make_deterministic
     def bake(self, data: dict[str, LazyFrame]):

src/pasteur/extras/synth/privbayes/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -66,6 +66,9 @@ def preprocess(self, meta: dict[str, Attributes], data: dict[str, LazyFrame]):
         table = tables[table_name]
         table_attrs = attrs[table_name]

+        self._n = table.shape[0]
+        self._partitions = len(table)
+
         if self.rebalance:
             with MarginalOracle(
                 table_attrs,
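Note: AIM, MST, and PrivBayes now cache the main table's row count and partition count during `preprocess`, while the `LazyFrame` is still in scope. The sketch below restates the shared pattern; it assumes, per its use here, that `LazyFrame.shape` reports total rows without materializing the data and that `len()` reports the number of partitions:

    def preprocess(self, meta, data):
        self.table = next(iter(meta))
        self.attrs = meta
        self._n = data[self.table].shape[0]       # total rows
        self._partitions = len(data[self.table])  # partition count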

src/pasteur/extras/transformers.py

Lines changed: 3 additions & 1 deletion
@@ -52,7 +52,9 @@ def fit(self, data: pd.Series):
         if self.max is None and self.find_edges:
             self.max = data.max()
         self.attr = NumAttribute(self.col, self.bins, self.min, self.max, self.nullable)
-        return self.attr
+
+    def get_attributes(self) -> Attributes:
+        return {self.attr.name: self.attr}

     def transform(self, data: pd.Series) -> pd.DataFrame:
         return pd.DataFrame(data.clip(self.min, self.max).astype("float32"))
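Note: `fit()` previously returned the fitted attribute directly; the new `get_attributes()` accessor returns it as an `Attributes` mapping keyed by name, which composes with the `str | tuple[str]` names introduced in attribute.py. A hypothetical caller after the change (`tf` and its constructor are elided):

    tf.fit(series)               # fit() no longer returns the attribute
    attrs = tf.get_attributes()  # {attr.name: attr}, an Attributes dict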

src/pasteur/kedro/dataset/auto.py

Lines changed: 5 additions & 5 deletions
@@ -112,7 +112,7 @@ def _save_worker(
         w.write(pa.Table.from_pandas(p0, schema=schema))
         del p0

-        for p in chunk: # type: ignore
+        for p in chunk:  # type: ignore
             try:
                 w.write(pa.Table.from_pandas(p, schema=schema))
             except Exception as e:
@@ -208,9 +208,9 @@ def _load_shape_worker(load_path: str, filesystem, *_, **__):

 class AutoDataset(AbstractVersionedDataSet[pd.DataFrame, pd.DataFrame]):
     """Modified kedro parquet dataset that acts similarly to a partitioned dataset
-    and implements lazy loading. 
-    
-    In the future, this dataset will automatically handle pickling, pyarrow 
+    and implements lazy loading.
+
+    In the future, this dataset will automatically handle pickling, pyarrow
     Tables, DataFrames, and Tensors automatically based on what is saved.

     `save()` data can be a table, a callable, or a dictionary combination of both.
@@ -403,4 +403,4 @@ def _save(self, data: pd.DataFrame) -> None:
     def reset(self):
         save_path = get_filepath_str(self._get_save_path(), self._protocol)
         if self._fs.exists(save_path):
-            self._fs.rm(save_path, recursive=True, maxdepth=1)
\ No newline at end of file
+            self._fs.rm(save_path, recursive=True, maxdepth=1)

src/pasteur/kedro/dataset/multi.py

Lines changed: 145 additions & 3 deletions
@@ -1,10 +1,152 @@
-from kedro.io.partitioned_dataset import PartitionedDataSet
+import warnings
+from copy import deepcopy
+from typing import Any, Callable

+from kedro.io.core import (
+    VERSION_KEY,
+    VERSIONED_FLAG_KEY,
+    AbstractDataSet,
+    DatasetError,
+    parse_dataset_definition,
+)
+from kedro.io.partitioned_dataset import S3_PROTOCOLS

-class Multiset(PartitionedDataSet):
-    """Modified Partitioned Dataset for pasteur."""
+
+from urllib.parse import urlparse
+class Multiset(AbstractDataSet):
+    # noqa: too-many-instance-attributes,protected-access
+    """Simplified version of the partitioned dataset. Is not lazy."""
+
+    def __init__(  # noqa: too-many-arguments
+        self,
+        path: str,
+        dataset: str | type[AbstractDataSet] | dict[str, Any],
+        filepath_arg: str = "filepath",
+        filename_suffix: str = "",
+        credentials: dict[str, Any] | None = None,
+        load_args: dict[str, Any] | None = None,
+        metadata: dict[str, Any] | None = None,
+    ):
+        # noqa: import-outside-toplevel
+        from fsspec.utils import infer_storage_options  # for performance reasons
+
+        super().__init__()
+
+        self._path = path
+        self._filename_suffix = filename_suffix
+        self._protocol = infer_storage_options(self._path)["protocol"]
+        self.metadata = metadata
+
+        dataset = dataset if isinstance(dataset, dict) else {"type": dataset}
+        self._dataset_type, self._dataset_config = parse_dataset_definition(dataset)
+        if VERSION_KEY in self._dataset_config:
+            raise DatasetError(
+                f"'{self.__class__.__name__}' does not support versioning of the "
+                f"underlying dataset. Please remove '{VERSIONED_FLAG_KEY}' flag from "
+                f"the dataset definition."
+            )
+
+        self._credentials = deepcopy(credentials) or {}
+        self._filepath_arg = filepath_arg
+        if self._filepath_arg in self._dataset_config:
+            warnings.warn(
+                f"'{self._filepath_arg}' key must not be specified in the dataset "
+                f"definition as it will be overwritten by partition path"
+            )
+
+        self._load_args = deepcopy(load_args) or {}
+        self._sep = self._filesystem.sep
+        # since some filesystem implementations may implement a global cache
+        self._invalidate_caches()
+
+    @property
+    def _filesystem(self):
+        # for performance reasons
+        import fsspec  # noqa: import-outside-toplevel
+
+        protocol = "s3" if self._protocol in S3_PROTOCOLS else self._protocol
+        return fsspec.filesystem(protocol, **self._credentials)
+
+    @property
+    def _normalized_path(self) -> str:
+        if self._protocol in S3_PROTOCOLS:
+            return urlparse(self._path)._replace(scheme="s3").geturl()
+        return self._path
+
+    def _list_partitions(self) -> list[str]:
+        return [
+            path
+            for path in self._filesystem.find(self._normalized_path, **self._load_args)
+            if path.endswith(self._filename_suffix)
+        ]
+
+    def _join_protocol(self, path: str) -> str:
+        protocol_prefix = f"{self._protocol}://"
+        if self._path.startswith(protocol_prefix) and not path.startswith(
+            protocol_prefix
+        ):
+            return f"{protocol_prefix}{path}"
+        return path
+
+    def _partition_to_path(self, path: str):
+        dir_path = self._path.rstrip(self._sep)
+        path = path.lstrip(self._sep)
+        full_path = self._sep.join([dir_path, path]) + self._filename_suffix
+        return full_path
+
+    def _path_to_partition(self, path: str) -> str:
+        dir_path = self._filesystem._strip_protocol(self._normalized_path)
+        path = path.split(dir_path, 1).pop().lstrip(self._sep)
+        if self._filename_suffix and path.endswith(self._filename_suffix):
+            path = path[: -len(self._filename_suffix)]
+        return path
+
+    def _load(self) -> dict[str, Callable[[], Any]]:
+        partitions = {}
+
+        for partition in self._list_partitions():
+            kwargs = deepcopy(self._dataset_config)
+            # join the protocol back since PySpark may rely on it
+            kwargs[self._filepath_arg] = self._join_protocol(partition)
+            dataset = self._dataset_type(**kwargs)  # type: ignore
+            partition_id = self._path_to_partition(partition)
+            partitions[partition_id] = dataset.load()
+
+        return partitions
+
+    def _save(self, data: dict[str, Any]) -> None:
+        for partition_id, partition_data in sorted(data.items()):
+            kwargs = deepcopy(self._dataset_config)
+            partition = self._partition_to_path(partition_id)
+            # join the protocol back since tools like PySpark may rely on it
+            kwargs[self._filepath_arg] = self._join_protocol(partition)
+            dataset = self._dataset_type(**kwargs)  # type: ignore
+            if callable(partition_data):
+                partition_data = partition_data()  # noqa: redefined-loop-name
+            dataset.save(partition_data)
+
+        self._invalidate_caches()
+
+    def _describe(self) -> dict[str, Any]:
+        clean_dataset_config = (
+            {k: v for k, v in self._dataset_config.items()}
+            if isinstance(self._dataset_config, dict)
+            else self._dataset_config
+        )
+        return {
+            "path": self._path,
+            "dataset_type": self._dataset_type.__name__,
+            "dataset_config": clean_dataset_config,
+        }
+
+    def _invalidate_caches(self):
+        self._filesystem.invalidate_cache(self._normalized_path)

     def reset(self):
         """Removes the dataset from disk so that there are no stray partitions in subsequent runs."""
         if self._filesystem.exists(self._normalized_path):
             self._filesystem.rm(self._normalized_path, recursive=True, maxdepth=1)
+
+    def _release(self) -> None:
+        super()._release()
+        self._invalidate_caches()
src/pasteur/kedro/pipelines/meta.py

Lines changed: 7 additions & 4 deletions
@@ -74,11 +74,13 @@ def _flatten_outputs(
     if isinstance(nested, dict):
         assert isinstance(outputs, dict)
         for idx, vals in nested.items():
-            assert idx in outputs
-            data = _flatten_outputs(vals, outputs[idx])
-            out.update(data)
+            if idx in outputs:
+                data = _flatten_outputs(vals, outputs[idx])
+                out.update(data)
     else:
-        assert isinstance(outputs, list) and isinstance(nested, list)
+        assert (isinstance(outputs, list) or isinstance(outputs, tuple)) and (
+            isinstance(nested, list) or isinstance(nested, tuple)
+        )
         assert len(outputs) == len(nested)
         for vals, outs in zip(nested, outputs):
             data = _flatten_outputs(vals, outs)
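Note: the relaxed checks let `_flatten_outputs` skip nested keys that have no catalog counterpart instead of asserting, and accept tuples wherever lists were required. Illustrative inputs (structure assumed from the recursion above):

    nested = {"tables": {"adult": "adult_out"}, "extra": {"x": "x_out"}}
    outputs = {"tables": {"adult": "catalog.adult"}}  # "extra" absent

    # before: `assert idx in outputs` fails on "extra"
    # after: "extra" is skipped; tuple pairs like ("a",) vs ("catalog.a",)
    # are also accepted in the list branch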
@@ -323,6 +325,7 @@ def node(
         namespace=namespace,
     )

+
 # Tag each node in the pipeline based on its use
 TAG_VIEW = "view"
 TAG_DATASET = "dataset"
