Skip to content

Commit d9c1a79

Browse files
committed
update pipeline with new api
1 parent 4c1f908 commit d9c1a79

13 files changed

Lines changed: 379 additions & 515 deletions

File tree

src/pasteur/encode.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,25 +8,25 @@
88
from .module import ModuleClass, ModuleFactory
99
from .utils import LazyFrame, LazyDataset
1010

11-
ENC = TypeVar("ENC", bound="Encoder")
11+
ENC = TypeVar("ENC", bound=ModuleClass)
12+
META = TypeVar("META")
1213

1314

14-
class EncoderFactory(ModuleFactory[ENC], Generic[ENC]):
15+
class AttributeEncoderFactory(ModuleFactory[ENC], Generic[ENC]):
1516
"""Factory base class for encoders. Use isinstance with this class
1617
to filter the Pasteur module list into only containing Encoders."""
1718

1819
...
1920

2021

21-
META = TypeVar("META")
22-
22+
class EncoderFactory(ModuleFactory[ENC], Generic[ENC]):
23+
"""Factory base class for encoders. Use isinstance with this class
24+
to filter the Pasteur module list into only containing Encoders."""
2325

24-
class Encoder(ModuleClass, Generic[META]):
25-
def get_metadata(self) -> META:
26-
raise NotImplementedError()
26+
...
2727

2828

29-
class AttributeEncoder(Encoder[dict[str | tuple[str], META]], Generic[META]):
29+
class AttributeEncoder(ModuleClass, Generic[META]):
3030
"""Encapsulates a special way to encode an Attribute.
3131
3232
One encoder is instantiated per attribute and its `fit` function is called to
@@ -50,7 +50,7 @@ class AttributeEncoder(Encoder[dict[str | tuple[str], META]], Generic[META]):
5050
"""
5151

5252
name: str = ""
53-
_factory = EncoderFactory["AttributeEncoder"]
53+
_factory = AttributeEncoderFactory["AttributeEncoder"]
5454

5555
def fit(self, attr: Attribute, data: pd.DataFrame | None):
5656
raise NotImplementedError()
@@ -64,10 +64,13 @@ def encode(self, data: pd.DataFrame) -> pd.DataFrame:
6464
def decode(self, enc: pd.DataFrame) -> pd.DataFrame:
6565
raise NotImplementedError()
6666

67+
def get_metadata(self) -> dict[str | tuple[str], META]:
68+
raise NotImplementedError()
6769

68-
class ViewEncoder(Encoder[META], Generic[META]):
70+
71+
class Encoder(ModuleClass, Generic[META]):
6972
name: str = ""
70-
_factory = EncoderFactory["ViewEncoder"]
73+
_factory = EncoderFactory["Encoder"]
7174

7275
def fit(
7376
self,
@@ -97,5 +100,10 @@ def decode(
97100
]:
98101
raise NotImplementedError()
99102

103+
def get_metadata(self) -> META:
104+
raise NotImplementedError()
105+
106+
107+
ViewEncoder = Encoder
100108

101109
__all__ = ["EncoderFactory", "Encoder", "ViewEncoder", "AttributeEncoder"]

src/pasteur/extras/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def get_recommended_datasets() -> list[Dataset | View]:
3535

3636
def get_recommended_system_modules() -> list[Module]:
3737
from .encoders import IdxEncoder, NumEncoder
38-
from .metrics.distr import ChiSquareMetric, KullbackLeiblerMetric
38+
from .metrics.distr import DistributionMetric
3939
from .metrics.visual import (
4040
CategoricalHist,
4141
DateHist,
@@ -83,8 +83,7 @@ def get_recommended_system_modules() -> list[Module]:
8383
# AimSynth.get_factory(),
8484
# PrivMrfSynth.get_factory(),
8585
# Metrics
86-
ChiSquareMetric.get_factory(),
87-
KullbackLeiblerMetric.get_factory(),
86+
DistributionMetric.get_factory(),
8887
NumericalHist.get_factory(),
8988
OrdinalHist.get_factory(),
9089
CategoricalHist.get_factory(),

src/pasteur/extras/metrics/distr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ class DistributionMetric(
192192
Summaries[dict[str, tuple[dict[str, ndarray], dict[tuple[str, str], ndarray]]]],
193193
]
194194
):
195-
name = "cs"
195+
name = "dstr"
196196
encodings = "idx"
197197

198198
def fit(

src/pasteur/extras/synth/pgm/aim.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,9 @@
55

66
import pandas as pd
77

8-
from pasteur.attribute import Attribute
9-
from pasteur.utils import LazyDataset
10-
11-
from ....synth import Synth, data_to_tables, make_deterministic, tables_to_data
12-
from ....utils import LazyFrame
8+
from ....attribute import Attributes
9+
from ....synth import Synth, make_deterministic
10+
from ....utils import LazyFrame, data_to_tables, tables_to_data
1311

1412
if TYPE_CHECKING:
1513
from ....attribute import Attributes

src/pasteur/extras/synth/pgm/mst.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55

66
import pandas as pd
77

8-
from ....synth import Synth, data_to_tables, make_deterministic, tables_to_data
9-
from ....utils import LazyFrame
8+
from ....synth import Synth, make_deterministic
9+
from ....utils import LazyFrame, data_to_tables, tables_to_data
1010

1111
if TYPE_CHECKING:
1212
from ....attribute import Attributes

src/pasteur/extras/synth/privbayes/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from typing import TYPE_CHECKING, Any
66

77
from ....marginal import MarginalOracle
8-
from ....synth import Synth, data_to_tables, make_deterministic, tables_to_data
9-
from ....utils import LazyFrame
8+
from ....synth import Synth, make_deterministic
9+
from ....utils import LazyFrame, data_to_tables, tables_to_data
1010

1111
if TYPE_CHECKING:
1212
import pandas as pd

src/pasteur/kedro/pipelines/main.py

Lines changed: 19 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,17 @@
1212
from ...view import View
1313
from .dataset import create_dataset_pipeline
1414
from .meta import DatasetMeta, PipelineMeta
15-
# from .metrics import (
16-
# create_metrics_ingest_pipeline,
17-
# create_metrics_model_pipeline,
18-
# get_metrics_types,
19-
# )
15+
16+
from .metrics import (
17+
create_metrics_ingest_pipeline,
18+
create_metrics_model_pipeline,
19+
get_metrics_types,
20+
)
2021
from .synth import create_synth_pipeline
2122
from .transform import (
23+
create_fit_pipeline,
2224
create_reverse_pipeline,
2325
create_transform_pipeline,
24-
create_transformer_pipeline,
2526
)
2627
from .utils import list_unique
2728
from .views import (
@@ -97,15 +98,8 @@ def generate_pipelines(
9798
# Wrk, ref splits are transformed to all types
9899
# Synthetic data is transformed only to syn_types (as required by metrics currently)
99100
alg_types = _get_alg_types(algs)
100-
# msr_types = get_metrics_types(modules)
101-
102-
all_types = alg_types# list_unique(alg_types, msr_types)
103-
encoders = {
104-
k: v
105-
for k, v in get_module_dict(EncoderFactory, modules).items()
106-
if k in all_types
107-
}
108-
transformers = get_module_dict(TransformerFactory, modules)
101+
msr_types = get_metrics_types(modules)
102+
all_types = list_unique(alg_types, msr_types)
109103

110104
wrk_split = WRK_SPLIT
111105
ref_split = REF_SPLIT
@@ -124,20 +118,20 @@ def generate_pipelines(
124118
# To make debugging metrics easier, it's bundled with `.measure` pipelines
125119
# as well. That way, only `.measure` needs to run when changes are made
126120
# to fit functions
127-
# pipe_metrics_fit = create_metrics_ingest_pipeline(
128-
# view, modules, wrk_split, ref_split
129-
# )
121+
pipe_metrics_fit = create_metrics_ingest_pipeline(
122+
view, modules, wrk_split, ref_split
123+
)
130124

131125
# Create view transform pipeline that can run as part of ingest
132126
pipe_transform = (
133-
create_transformer_pipeline(view, transformers, encoders, wrk_split)
127+
create_fit_pipeline(view, all_types, modules, wrk_split)
134128
+ create_transform_pipeline(
135129
view,
136130
wrk_split,
137131
all_types,
138132
)
139-
# + create_transform_pipeline(view, ref_split, msr_types)
140-
# + pipe_metrics_fit
133+
+ create_transform_pipeline(view, ref_split, msr_types)
134+
+ pipe_metrics_fit
141135
)
142136

143137
# Metadata needs to be created every time to allow for overrides
@@ -166,11 +160,11 @@ def generate_pipelines(
166160
view, wrk_split, cls
167161
) + create_reverse_pipeline(view, alg, cls.type)
168162

169-
# pipe_measure = create_transform_pipeline(
170-
# view, alg, msr_types, retransform=True
171-
# ) + create_metrics_model_pipeline(view, alg, wrk_split, ref_split, modules)
163+
pipe_measure = create_transform_pipeline(
164+
view, alg, msr_types, retransform=True
165+
) + create_metrics_model_pipeline(view, alg, wrk_split, ref_split, modules)
172166

173-
complete_pipe = pipe_ds_ingest + pipe_ingest + pipe_synth #+ pipe_measure
167+
complete_pipe = pipe_ds_ingest + pipe_ingest + pipe_synth + pipe_measure
174168

175169
if "ident" in alg:
176170
# Hide ident pipelines

0 commit comments

Comments
 (0)