Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 28 additions & 16 deletions model2vec/inference/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import re
from collections.abc import Sequence
from enum import Enum
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TypeVar, cast
Expand All @@ -10,7 +11,7 @@
import numpy as np
import skops.io
from sklearn.metrics import classification_report
from sklearn.neural_network import MLPClassifier
from sklearn.neural_network import MLPClassifier, MLPRegressor
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MultiLabelBinarizer

Expand All @@ -23,25 +24,29 @@
LabelType = TypeVar("LabelType", list[str], list[list[str]])


class HeadType(str, Enum):
    """
    The kind of head attached to a StaticModelPipeline.

    Members are also plain strings, so they compare equal to their string values.
    """

    # Single-label classification head.
    CLASSIFIER = "classifier"
    # Regression head (an MLPRegressor); it has no classes.
    PROJECTOR = "projector"
    # Multilabel classification head.
    MULTILABEL = "multilabel"


class StaticModelPipeline:
def __init__(self, model: StaticModel, head: Pipeline) -> None:
    """
    Create a pipeline with a StaticModel encoder.

    The kind of head is inferred from the last step of `head` and stored in `classifier_type`.
    `classes_` holds the classes of the head when it exposes them, and is None otherwise.

    :param model: The StaticModel used to encode the inputs.
    :param head: The scikit-learn pipeline that is applied to the encoded inputs.
    """
    self.model = model
    self.head = head

    last_head = self.head[-1]
    self.classes_: None | np.ndarray = None
    if isinstance(last_head, MLPRegressor):
        # A regressor has no classes, so classes_ stays None.
        self.classifier_type = HeadType.PROJECTOR
    elif isinstance(last_head, MLPClassifier):
        # A logistic output activation is treated as the marker for a multilabel head.
        # NOTE(review): scikit-learn appears to report "logistic" for binary targets as well;
        # confirm that binary heads are not misdetected as multilabel here.
        activation = last_head.out_activation_
        self.classifier_type = HeadType.MULTILABEL if activation == "logistic" else HeadType.CLASSIFIER
        self.classes_ = self.head.classes_
    else:
        # Default to classifier: the assumption is the user is unlikely to use multilabel here.
        # Different classifiers, such as OVR wrappers, support multilabel output natively, so
        # predict can be used directly for them.
        self.classifier_type = HeadType.CLASSIFIER
        # Keep exposing the classes of other estimators when they have them.
        self.classes_ = getattr(self.head, "classes_", None)

    # Kept for backward compatibility with code that still reads the boolean flag.
    self.multilabel = self.classifier_type == HeadType.MULTILABEL

@classmethod
def from_pretrained(
Expand Down Expand Up @@ -138,7 +143,8 @@ def predict(
multiprocessing_threshold=multiprocessing_threshold,
)

if self.multilabel:
if self.classifier_type == HeadType.MULTILABEL:
assert self.classes_ is not None
out_labels = []
proba = self.head.predict_proba(encoded)
for vector in proba:
Expand Down Expand Up @@ -166,7 +172,10 @@ def predict_proba(
:param use_multiprocessing: Whether to use multiprocessing for encoding. Defaults to True.
:param multiprocessing_threshold: The threshold for the number of samples to use multiprocessing. Defaults to 10,000.
:return: The predicted labels or probabilities.
:raises ValueError: If the classifier type is projector.
"""
if self.classifier_type == HeadType.PROJECTOR:
raise ValueError("You are using evaluate on a projector model. This is not supported.")
encoded = self._encode_and_coerce_to_2d(
X,
show_progress_bar=show_progress_bar,
Expand All @@ -190,7 +199,10 @@ def evaluate(
:param threshold: The threshold for multilabel classification.
:param output_dict: Whether to output the classification report as a dictionary.
:return: A classification report.
:raises ValueError: If the classifier type is projector.
"""
if self.classifier_type == HeadType.PROJECTOR:
raise ValueError("You are using evaluate on a projector model. This is not supported.")
predictions = self.predict(X, show_progress_bar=True, batch_size=batch_size, threshold=threshold)
report = evaluate_single_or_multi_label(predictions=predictions, y=y, output_dict=output_dict)

Expand Down
9 changes: 8 additions & 1 deletion model2vec/train/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

from model2vec.utils import get_package_extras, importable

_REQUIRED_EXTRA = "train"
Expand All @@ -6,5 +8,10 @@
importable(extra_dependency, _REQUIRED_EXTRA)

from model2vec.train.classifier import StaticModelForClassification
from model2vec.train.similarity import StaticModelForSimilarity
from model2vec.train.utils import TipFilter

# Public API of the training package. Defined once, so that no later assignment drops a name.
__all__ = ["StaticModelForClassification", "StaticModelForSimilarity"]

# Attach the project's TipFilter to Lightning's rank-zero logger.
# NOTE(review): presumably this suppresses Lightning's "tip" log messages; confirm in TipFilter.
logging.getLogger("lightning.pytorch.utilities.rank_zero").addFilter(TipFilter())
Loading
Loading