-
Notifications
You must be signed in to change notification settings - Fork 283
Expand file tree
/
Copy pathtunable.py
More file actions
29 lines (26 loc) · 890 Bytes
/
tunable.py
File metadata and controls
29 lines (26 loc) · 890 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from skllm.llm.vertex.mixin import (
VertexTunableMixin as _VertexTunableMixin,
VertexTextCompletionMixin as _VertexTextCompletionMixin,
)
from skllm.models._base.text2text import (
BaseTunableText2TextModel as _BaseTunableText2TextModel,
)
class TunableVertexText2Text(
    _BaseTunableText2TextModel,
    _VertexTextCompletionMixin,
    _VertexTunableMixin,
):
    """Tunable text-to-text model backed by Vertex.

    Combines the generic tunable text-to-text base class with the Vertex
    text-completion and Vertex tuning mixins. The order of the base classes
    determines method resolution and must be preserved.
    """

    def __init__(
        self,
        base_model: str = "gemini-1.5-flash",
        n_update_steps: int = 1,
    ) -> None:
        """Create an untuned Vertex text-to-text model.

        Parameters
        ----------
        base_model : str, optional
            Name of the Vertex base model to tune, by default
            "gemini-1.5-flash".
        n_update_steps : int, optional
            Number of tuning updates (documented upstream as the number of
            epochs), by default 1.
            NOTE(review): the name says "update steps" while the prior
            docstring said "epochs" -- confirm the unit against
            ``_set_hyperparameters``.
        """
        # Start without a model; presumably the tuning step fills this in
        # later -- confirm against the base class.
        self.model = None
        # Store the hyperparameters through the inherited helper (defined
        # outside this file); keyword arguments keep the mapping explicit.
        self._set_hyperparameters(
            base_model=base_model,
            n_update_steps=n_update_steps,
        )