"""This module provides the definitions for Metric Modules.
Metric modules can fit to a column, a table, or a whole View.
In each case, modules are instantiated as required (for columns one is instantiated
per column type, for tables one per table and View metrics are instantiated once)."""
5+
16import logging
2- from collections import defaultdict
3- from typing import Generic , TypeVar , cast , TypedDict , NamedTuple , Any
7+ from typing import Generic , TypeVar , TypedDict , Any
48
59import pandas as pd
610
913from .module import ModuleClass , ModuleFactory
1014from .table import TransformHolder
1115from .utils import LazyChunk , LazyFrame
12- from .utils .progress import process , process_in_parallel
16+ from .utils .progress import process_in_parallel
1317
1418logger = logging .getLogger (__name__ )
1519
@@ -44,9 +48,9 @@ def __init__(
4448 self .encodings = cls .encodings
4549
4650
47- class DatasetMetricFactory (ModuleFactory ["DatasetMetric " ]):
51+ class ViewMetricFactory (ModuleFactory ["ViewMetric " ]):
4852 def __init__ (
49- self , cls : type ["DatasetMetric " ], * args , name : str | None = None , ** kwargs
53+ self , cls : type ["ViewMetric " ], * args , name : str | None = None , ** kwargs
5054 ) -> None :
5155 super ().__init__ (cls , * args , name = name , ** kwargs )
5256 self .encodings = cls .encodings
@@ -228,8 +232,8 @@ def fit(self, table: str, meta: Metadata, data: ColumnData):
228232 ids = data ["ids" ]
229233 tables = data ["tables" ].copy ()
230234 tables ["ids" ] = ids
231- part = next (iter (LazyFrame .zip_values (** tables ))) # FIXME: incorrect type
232- self ._fit_chunk (table , meta , part , part ["ids" ]) # type: ignore
235+ part = next (iter (LazyFrame .zip_values (** tables ))) # FIXME: incorrect type
236+ self ._fit_chunk (table , meta , part , part ["ids" ]) # type: ignore
233237
234238 def _process_chunk (
235239 self ,
@@ -276,12 +280,12 @@ def preprocess(self, wrk: ColumnData, ref: ColumnData) -> Summaries:
276280 wrk_sum [name ] = []
277281 ref_sum [name ] = []
278282 for i , metric in enumerate (metrics ):
279- wrk_sum [name ].append (metric . combine (
280- [chunk [name ][i ] for chunk in summaries_wrk ]
281- ))
282- ref_sum [name ].append (metric . combine (
283- [chunk [name ][i ] for chunk in summaries_ref ]
284- ))
283+ wrk_sum [name ].append (
284+ metric . combine ( [chunk [name ][i ] for chunk in summaries_wrk ])
285+ )
286+ ref_sum [name ].append (
287+ metric . combine ( [chunk [name ][i ] for chunk in summaries_ref ])
288+ )
285289
286290 return Summaries (wrk_sum , ref_sum )
287291
@@ -299,9 +303,9 @@ def process(
299303 for name , metrics in self .metrics .items ():
300304 syn_sum [name ] = []
301305 for i , metric in enumerate (metrics ):
302- syn_sum [name ].append (metric . combine (
303- [chunk [name ][i ] for chunk in summaries ]
304- ))
306+ syn_sum [name ].append (
307+ metric . combine ( [chunk [name ][i ] for chunk in summaries ])
308+ )
305309
306310 return pre .replace (syn = syn_sum )
307311
@@ -357,19 +361,19 @@ def unique_name(self) -> str:
357361 return f"{ self .type } _{ self .name } _{ self .table } "
358362
359363
class ViewData(TypedDict):
    """Typed dictionary describing the data passed to a ``ViewMetric``.

    It is the data type parameter of ``ViewMetric`` and the type of the
    ``data`` argument accepted by ``ViewMetric.fit``.

    NOTE(review): the outer keys of both mappings are presumably table names,
    and the inner keys of ``tables`` presumably identify the individual frames
    of each table -- confirm against the callers that build this dictionary.
    """

    # Two-level mapping (str -> str -> LazyFrame).
    tables: dict[str, dict[str, LazyFrame]]
    # Single-level mapping (str -> LazyFrame).
    ids: dict[str, LazyFrame]
363367
364368
365- class DatasetMetric (Metric [DatasetData , _INGEST , _SUMMARY ], Generic [_INGEST , _SUMMARY ]):
366- _factory = DatasetMetricFactory
369+ class ViewMetric (Metric [ViewData , _INGEST , _SUMMARY ], Generic [_INGEST , _SUMMARY ]):
370+ _factory = ViewMetricFactory
367371 type = "dst"
368372 table : str
369373 encodings : list [str ] = ["raw" ]
370374
371375 def fit (
372- self , meta : Metadata , attrs : dict [str , dict [str , Attributes ]], data : DatasetData
376+ self , meta : Metadata , attrs : dict [str , dict [str , Attributes ]], data : ViewData
373377 ):
374378 raise NotImplementedError ()
375379
@@ -407,10 +411,10 @@ def fit_table_metric(
407411
408412
409413def fit_dataset_metric (
410- fs : DatasetMetricFactory ,
414+ fs : ViewMetricFactory ,
411415 meta : Metadata ,
412416 trns : dict [str , TransformHolder ],
413- data : DatasetData ,
417+ data : ViewData ,
414418):
415419 enc = fs .encodings
416420 attrs = {
@@ -432,3 +436,20 @@ def log_metric(metric: Metric[Any, Any, _SUMMARY], summary: _SUMMARY):
432436 mlflow_log_artifacts (
433437 "metrics" , metric .unique_name (), metric = metric , summary = summary
434438 )
439+
440+
# Backward-compatible aliases: the "Dataset" metric classes were renamed to
# "View". The old names stay importable so existing callers keep working.
# They are not listed in ``__all__``, so star-imports expose only the new names.
DatasetMetric = ViewMetric
DatasetMetricFactory = ViewMetricFactory
# ``DatasetData`` was renamed to ``ViewData`` together with the classes above;
# alias it as well so the old name does not raise ImportError.
DatasetData = ViewData

# Public API of this module.
__all__ = [
    "ColumnMetricFactory",
    "RefColumnMetricFactory",
    "TableMetricFactory",
    "ViewMetricFactory",
    "Metric",
    "Summaries",
    "ColumnMetric",
    "RefColumnMetric",
    "TableMetric",
    "ViewMetric",
]
0 commit comments