Skip to content

Commit 992e2c6

Browse files
committed
minor refactor + docs for main modules
1 parent c969d9e commit 992e2c6

17 files changed

Lines changed: 653 additions & 567 deletions

src/pasteur/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .kedro.utils import get_pasteur_modules
1+
# from .kedro.utils import get_pasteur_modules
22

33
# if get_pasteur_modules():
44
from .kedro.cli import cli

src/pasteur/dataset.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
""" This module holds the definitions for the Dataset module, the initial entrypoint
2+
for data in Pasteur. """
3+
14
from __future__ import annotations
25

36
from typing import TYPE_CHECKING, Any, Callable
@@ -99,7 +102,7 @@ def ingest(self, name, **tables: Any) -> LazyFrame:
99102
100103
@warning: all partitioned tables should have the same partitions.
101104
Some tables may not be partitioned.
102-
105+
103106
Tip: use a `match` statement to fork based on table name to per-table functions."""
104107
raise NotImplemented()
105108

src/pasteur/encode.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,24 @@
55

66

77
class EncoderFactory(ModuleFactory["Encoder"]):
8+
""" Factory base class for encoders. Use isinstance with this class
9+
to filter the Pasteur module list down to only Encoders. """
810
...
911

12+
1013
class Encoder(ModuleClass):
11-
"""Encapsulates a special way to encode an Attribute."""
14+
"""Encapsulates a special way to encode an Attribute.
15+
16+
One encoder is instantiated per module and its `fit` function is called to
17+
fit it to the base layer data.
18+
19+
After that, the module may be serialized, deserialized, and its encode
20+
and decode methods may be called arbitrarily from different processes to encode
21+
and decode sets of columns.
22+
23+
The `data` value may contain a superset of the columns used by the encoder.
24+
It is up to the encoder to filter it prior to processing. `data` should
25+
not be mutated."""
1226

1327
name: str
1428
attr: Attribute
@@ -21,4 +35,7 @@ def encode(self, data: pd.DataFrame) -> pd.DataFrame:
2135
raise NotImplementedError()
2236

2337
def decode(self, enc: pd.DataFrame) -> pd.DataFrame:
24-
raise NotImplementedError()
38+
raise NotImplementedError()
39+
40+
41+
__all__ = ["EncoderFactory", "Encoder"]

src/pasteur/hierarchy.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
1+
""" Highly experimental and unpublished class for rebalancing Stratified Values
2+
with Differential Privacy.
3+
4+
@TODO: Documentation."""
5+
16
import logging
27
from itertools import combinations
38
from math import ceil, log
49
from typing import TypeVar
510

611
import numpy as np
7-
import pandas as pd
812

9-
from .attribute import Attributes, IdxValue, Level, LevelValue, get_dtype
13+
from .attribute import Attributes, CatValue, Grouping, StratifiedValue, get_dtype
1014

1115
logger = logging.getLogger(__name__)
1216

@@ -22,7 +26,7 @@ class CatNode(list):
2226

2327

2428
def create_tree(
25-
node: Level, common: int = 0, ofs: int = 0, n: int | None = None
29+
node: Grouping, common: int = 0, ofs: int = 0, n: int | None = None
2630
) -> list:
2731
"""Receives the top node of the tree of a hierarchical attribute and
2832
converts it into the same tree structure, where the leaves have been
@@ -37,7 +41,7 @@ def create_tree(
3741
for child in node:
3842
if ofs < common:
3943
out.append(None)
40-
elif isinstance(child, Level):
44+
elif isinstance(child, Grouping):
4145
out.append(create_tree(child, common, ofs, n))
4246
else:
4347
out.append(set([ofs]))
@@ -220,7 +224,7 @@ def create_node_to_group_map(tree: list, grouping: np.ndarray, ofs: int = 0):
220224
return ofs
221225

222226

223-
def make_grouping(counts: np.ndarray, head: Level, common: int = 0) -> np.ndarray:
227+
def make_grouping(counts: np.ndarray, head: Grouping, common: int = 0) -> np.ndarray:
224228
"""Converts the hierarchical attribute level tree provided into a node-to-group
225229
mapping, where `group[i][j] = z`, where `i` is the height of the mapping
226230
`j` is node `j` and `z` is the group the node is associated at that height.
@@ -305,11 +309,11 @@ def generate_domain_list(
305309
return new_domains
306310

307311

308-
class RebalancedValue(IdxValue):
312+
class RebalancedValue(CatValue):
309313
def __init__(
310314
self,
311315
counts: np.ndarray,
312-
col: LevelValue,
316+
col: StratifiedValue,
313317
reshape_domain: bool = True,
314318
u: float = 1.3,
315319
fixed: list[int] = [2, 4, 5, 8, 12],
@@ -406,7 +410,7 @@ def upsample(self, column: np.ndarray, height: int, deterministic: bool = True):
406410

407411
def rebalance_value(
408412
counts: np.ndarray,
409-
col: LevelValue,
413+
col: StratifiedValue,
410414
num_cols: int = 1,
411415
ep: float | None = None,
412416
gaussian: bool = False,
@@ -422,7 +426,7 @@ def rebalance_value(
422426
noise = np.random.laplace(scale=noise_scale, size=counts.shape)
423427
counts = counts + noise
424428

425-
assert isinstance(col, LevelValue)
429+
assert isinstance(col, StratifiedValue)
426430
return RebalancedValue(counts, col, **kwargs)
427431

428432

@@ -449,7 +453,7 @@ def rebalance_attributes(
449453
for name, attr in attrs.items():
450454
cols = {}
451455
for col_name, col in attr.vals.items():
452-
assert isinstance(col, LevelValue)
456+
assert isinstance(col, StratifiedValue)
453457
cols[col_name] = rebalance_value(
454458
counts[col_name],
455459
col,

src/pasteur/metadata.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
1+
""" This module contains a base class `Metadata` which is used to wrap, type,
2+
and check all View parameters provided to kedro.
3+
4+
@TODO: refactor this file. """
5+
16
from __future__ import annotations
27

38
from typing import TYPE_CHECKING, NamedTuple, overload
49

510
if TYPE_CHECKING:
611
import pandas as pd
7-
from .utils import LazyFrame
812

913
import logging
1014

@@ -173,7 +177,7 @@ def __str__(self) -> str:
173177
return self.__dict__.__str__()
174178

175179

176-
class DatasetMeta:
180+
class ViewMeta:
177181
TABLE_CLS = TableMeta
178182

179183
def __init__(
@@ -226,5 +230,5 @@ def __str__(self) -> str:
226230
return self.__dict__.__str__()
227231

228232

229-
class Metadata(DatasetMeta):
233+
class Metadata(ViewMeta):
230234
pass

src/pasteur/metric.py

Lines changed: 43 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
""" This module provides the definitions for Metric Modules.
2+
Metric modules can fit to a column, a table, or a whole View.
3+
In each case, modules are instantiated as required (for columns one is instantiated
4+
per column type, for tables one per table and View metrics are instantiated once)."""
5+
16
import logging
2-
from collections import defaultdict
3-
from typing import Generic, TypeVar, cast, TypedDict, NamedTuple, Any
7+
from typing import Generic, TypeVar, TypedDict, Any
48

59
import pandas as pd
610

@@ -9,7 +13,7 @@
913
from .module import ModuleClass, ModuleFactory
1014
from .table import TransformHolder
1115
from .utils import LazyChunk, LazyFrame
12-
from .utils.progress import process, process_in_parallel
16+
from .utils.progress import process_in_parallel
1317

1418
logger = logging.getLogger(__name__)
1519

@@ -44,9 +48,9 @@ def __init__(
4448
self.encodings = cls.encodings
4549

4650

47-
class DatasetMetricFactory(ModuleFactory["DatasetMetric"]):
51+
class ViewMetricFactory(ModuleFactory["ViewMetric"]):
4852
def __init__(
49-
self, cls: type["DatasetMetric"], *args, name: str | None = None, **kwargs
53+
self, cls: type["ViewMetric"], *args, name: str | None = None, **kwargs
5054
) -> None:
5155
super().__init__(cls, *args, name=name, **kwargs)
5256
self.encodings = cls.encodings
@@ -228,8 +232,8 @@ def fit(self, table: str, meta: Metadata, data: ColumnData):
228232
ids = data["ids"]
229233
tables = data["tables"].copy()
230234
tables["ids"] = ids
231-
part = next(iter(LazyFrame.zip_values(**tables))) # FIXME: incorrect type
232-
self._fit_chunk(table, meta, part, part["ids"]) #type: ignore
235+
part = next(iter(LazyFrame.zip_values(**tables))) # FIXME: incorrect type
236+
self._fit_chunk(table, meta, part, part["ids"]) # type: ignore
233237

234238
def _process_chunk(
235239
self,
@@ -276,12 +280,12 @@ def preprocess(self, wrk: ColumnData, ref: ColumnData) -> Summaries:
276280
wrk_sum[name] = []
277281
ref_sum[name] = []
278282
for i, metric in enumerate(metrics):
279-
wrk_sum[name].append(metric.combine(
280-
[chunk[name][i] for chunk in summaries_wrk]
281-
))
282-
ref_sum[name].append(metric.combine(
283-
[chunk[name][i] for chunk in summaries_ref]
284-
))
283+
wrk_sum[name].append(
284+
metric.combine([chunk[name][i] for chunk in summaries_wrk])
285+
)
286+
ref_sum[name].append(
287+
metric.combine([chunk[name][i] for chunk in summaries_ref])
288+
)
285289

286290
return Summaries(wrk_sum, ref_sum)
287291

@@ -299,9 +303,9 @@ def process(
299303
for name, metrics in self.metrics.items():
300304
syn_sum[name] = []
301305
for i, metric in enumerate(metrics):
302-
syn_sum[name].append(metric.combine(
303-
[chunk[name][i] for chunk in summaries]
304-
))
306+
syn_sum[name].append(
307+
metric.combine([chunk[name][i] for chunk in summaries])
308+
)
305309

306310
return pre.replace(syn=syn_sum)
307311

@@ -357,19 +361,19 @@ def unique_name(self) -> str:
357361
return f"{self.type}_{self.name}_{self.table}"
358362

359363

360-
class DatasetData(TypedDict):
364+
class ViewData(TypedDict):
361365
tables: dict[str, dict[str, LazyFrame]]
362366
ids: dict[str, LazyFrame]
363367

364368

365-
class DatasetMetric(Metric[DatasetData, _INGEST, _SUMMARY], Generic[_INGEST, _SUMMARY]):
366-
_factory = DatasetMetricFactory
369+
class ViewMetric(Metric[ViewData, _INGEST, _SUMMARY], Generic[_INGEST, _SUMMARY]):
370+
_factory = ViewMetricFactory
367371
type = "dst"
368372
table: str
369373
encodings: list[str] = ["raw"]
370374

371375
def fit(
372-
self, meta: Metadata, attrs: dict[str, dict[str, Attributes]], data: DatasetData
376+
self, meta: Metadata, attrs: dict[str, dict[str, Attributes]], data: ViewData
373377
):
374378
raise NotImplementedError()
375379

@@ -407,10 +411,10 @@ def fit_table_metric(
407411

408412

409413
def fit_dataset_metric(
410-
fs: DatasetMetricFactory,
414+
fs: ViewMetricFactory,
411415
meta: Metadata,
412416
trns: dict[str, TransformHolder],
413-
data: DatasetData,
417+
data: ViewData,
414418
):
415419
enc = fs.encodings
416420
attrs = {
@@ -432,3 +436,20 @@ def log_metric(metric: Metric[Any, Any, _SUMMARY], summary: _SUMMARY):
432436
mlflow_log_artifacts(
433437
"metrics", metric.unique_name(), metric=metric, summary=summary
434438
)
439+
440+
441+
DatasetMetric = ViewMetric
442+
DatasetMetricFactory = ViewMetricFactory
443+
444+
__all__ = [
445+
"ColumnMetricFactory",
446+
"RefColumnMetricFactory",
447+
"TableMetricFactory",
448+
"ViewMetricFactory",
449+
"Metric",
450+
"Summaries",
451+
"ColumnMetric",
452+
"RefColumnMetric",
453+
"TableMetric",
454+
"ViewMetric",
455+
]

src/pasteur/module.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
""" Contains the module definitions in Pasteur, the base classes all
2+
Pasteur modules extend from.
3+
4+
In Pasteur, all functionality is achieved through the use of modules.
5+
You should not interact with this module directly, but rather through its children."""
6+
17
from collections import defaultdict
28
from typing import Generic, TypeVar
39

@@ -17,7 +23,11 @@ class Module:
1723

1824
class ModuleClass:
1925
"""Modules which need to be instantiated multiple times extend from ModuleClass and define
20-
a Factory to act as their module"""
26+
a Factory to act as their module.
27+
28+
For the module types provided by pasteur, you can call the classmethod `get_factory()`.
29+
`get_factory()` also acts as a closure, allowing you to provide parameters to
30+
the module's init function."""
2131

2232
name: str
2333
_factory: type["ModuleFactory"]
@@ -43,7 +53,7 @@ class ModuleFactory(Module, Generic[A]):
4353
"""Some modules (such as transformers) require multiple instances in the system. In this case,
4454
it's not possible to provide a module instance for them.
4555
46-
`ModuleFactory` is used to provide a wrapper instance to that module class."""
56+
For those types, their instance is based on `ModuleFactory`."""
4757

4858
def __init__(self, cls: type[A], *args, name: str | None = None, **kwargs) -> None:
4959
self._cls = cls
@@ -82,3 +92,12 @@ def get_module_dict_multiple(
8292
if isinstance(module, parent):
8393
out[module.name].append(module)
8494
return out
95+
96+
97+
__all__ = [
98+
"Module",
99+
"ModuleClass",
100+
"ModuleFactory",
101+
"get_module_dict",
102+
"get_module_dict_multiple",
103+
]

0 commit comments

Comments
 (0)