-
Notifications
You must be signed in to change notification settings - Fork 283
Expand file tree
/
Copy pathtunable.py
More file actions
42 lines (37 loc) · 1.3 KB
/
tunable.py
File metadata and controls
42 lines (37 loc) · 1.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from skllm.models._base.classifier import (
BaseTunableClassifier as _BaseTunableClassifier,
SingleLabelMixin as _SingleLabelMixin,
MultiLabelMixin as _MultiLabelMixin,
)
from skllm.llm.vertex.mixin import (
VertexClassifierMixin as _VertexClassifierMixin,
VertexTunableMixin as _VertexTunableMixin,
)
from typing import Optional
class _TunableClassifier(
    _BaseTunableClassifier,
    _VertexClassifierMixin,
    _VertexTunableMixin,
):
    """Private base that composes the tunable-classifier base with the
    Vertex classifier and Vertex tuning mixins.

    It adds no members of its own; the base order above fixes the MRO,
    so ``_BaseTunableClassifier`` is consulted before either mixin.
    """
class VertexClassifier(_TunableClassifier, _SingleLabelMixin):
    """Single-label text classifier built on a tunable Vertex model."""

    def __init__(
        self,
        base_model: str = "gemini-1.5-flash",
        n_update_steps: int = 1,
        default_label: str = "Random",
    ):
        """
        Tunable Vertex-based text classifier.

        Parameters
        ----------
        base_model : str, optional
            identifier of the base model to use, by default "gemini-1.5-flash"
        n_update_steps : int, optional
            how many update steps (described as epochs) to tune for, by default 1
        default_label : str, optional
            label returned when a prediction fails; the special value "Random"
            picks a label at random according to class frequencies,
            by default "Random"
        """
        # Register tuning hyperparameters first; the original also did this
        # before delegating to the base initializer, so keep that order.
        tuning_options = {"base_model": base_model, "n_update_steps": n_update_steps}
        self._set_hyperparameters(**tuning_options)
        # No concrete model is supplied at construction time (presumably the
        # tuned model is produced during fitting - confirm in the base class).
        super().__init__(default_label=default_label, model=None)