Skip to content

Commit b026e4e

Browse files
Support VLM finetuning
1 parent 367f606 commit b026e4e

5 files changed

Lines changed: 72 additions & 48 deletions

File tree

src/together/cli/api/finetune.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import json
4-
import re
54
from datetime import datetime, timezone
65
from textwrap import wrap
76
from typing import Any, Literal
@@ -14,18 +13,11 @@
1413

1514
from together import Together
1615
from together.cli.api.utils import BOOL_WITH_AUTO, INT_WITH_MAX, generate_progress_bar
17-
from together.types.finetune import (
18-
DownloadCheckpointType,
19-
FinetuneEventType,
20-
FinetuneTrainingLimits,
21-
FullTrainingType,
22-
LoRATrainingType,
23-
)
16+
from together.types.finetune import DownloadCheckpointType, FinetuneTrainingLimits
2417
from together.utils import (
2518
finetune_price_to_dollars,
2619
format_timestamp,
2720
log_warn,
28-
log_warn_once,
2921
parse_timestamp,
3022
)
3123

@@ -258,6 +250,7 @@ def create(
258250
lora_dropout: float,
259251
lora_alpha: float,
260252
lora_trainable_modules: str,
253+
train_vision: bool,
261254
suffix: str,
262255
wandb_api_key: str,
263256
wandb_base_url: str,
@@ -299,6 +292,7 @@ def create(
299292
lora_dropout=lora_dropout,
300293
lora_alpha=lora_alpha,
301294
lora_trainable_modules=lora_trainable_modules,
295+
train_vision=train_vision,
302296
suffix=suffix,
303297
wandb_api_key=wandb_api_key,
304298
wandb_base_url=wandb_base_url,
@@ -368,6 +362,10 @@ def create(
368362
"You have specified a number of evaluation loops but no validation file."
369363
)
370364

365+
if model_limits.supports_vision:
366+
# Don't show price estimation for multimodal models
367+
confirm = True
368+
371369
finetune_price_estimation_result = client.fine_tuning.estimate_price(
372370
training_file=training_file,
373371
validation_file=validation_file,

src/together/resources/finetune.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import re
44
from pathlib import Path
5-
from typing import List, Dict, Literal
5+
from typing import Dict, List, Literal
66

77
from rich import print as rprint
88

@@ -18,10 +18,11 @@
1818
FinetuneList,
1919
FinetuneListEvents,
2020
FinetuneLRScheduler,
21-
FinetuneRequest,
22-
FinetuneResponse,
21+
FinetuneMultimodalParams,
2322
FinetunePriceEstimationRequest,
2423
FinetunePriceEstimationResponse,
24+
FinetuneRequest,
25+
FinetuneResponse,
2526
FinetuneTrainingLimits,
2627
FullTrainingType,
2728
LinearLRScheduler,
@@ -73,6 +74,7 @@ def create_finetune_request(
7374
lora_dropout: float | None = 0,
7475
lora_alpha: float | None = None,
7576
lora_trainable_modules: str | None = "all-linear",
77+
train_vision: bool = False,
7678
suffix: str | None = None,
7779
wandb_api_key: str | None = None,
7880
wandb_base_url: str | None = None,
@@ -252,6 +254,15 @@ def create_finetune_request(
252254
simpo_gamma=simpo_gamma,
253255
)
254256

257+
if model_limits.supports_vision:
258+
multimodal_params = FinetuneMultimodalParams(train_vision=train_vision)
259+
elif train_vision:
260+
raise ValueError(
261+
f"Vision encoder training is not supported for the non-multimodal model `{model}`"
262+
)
263+
else:
264+
multimodal_params = None
265+
255266
finetune_request = FinetuneRequest(
256267
model=model,
257268
training_file=training_file,
@@ -272,6 +283,7 @@ def create_finetune_request(
272283
wandb_project_name=wandb_project_name,
273284
wandb_name=wandb_name,
274285
training_method=training_method_cls,
286+
multimodal_params=multimodal_params,
275287
from_checkpoint=from_checkpoint,
276288
from_hf_model=from_hf_model,
277289
hf_model_revision=hf_model_revision,
@@ -342,6 +354,7 @@ def create(
342354
lora_dropout: float | None = 0,
343355
lora_alpha: float | None = None,
344356
lora_trainable_modules: str | None = "all-linear",
357+
train_vision: bool = False,
345358
suffix: str | None = None,
346359
wandb_api_key: str | None = None,
347360
wandb_base_url: str | None = None,
@@ -387,6 +400,7 @@ def create(
387400
lora_dropout (float, optional): Dropout rate for LoRA adapters. Defaults to 0.
388401
lora_alpha (float, optional): Alpha for LoRA adapters. Defaults to 8.
389402
lora_trainable_modules (str, optional): Trainable modules for LoRA adapters. Defaults to "all-linear".
403+
train_vision (bool, optional): Whether to train vision encoder in multimodal models. Defaults to False.
390404
suffix (str, optional): Up to 40 character suffix that will be added to your fine-tuned model name.
391405
Defaults to None.
392406
wandb_api_key (str, optional): API key for Weights & Biases integration.
@@ -464,6 +478,7 @@ def create(
464478
lora_dropout=lora_dropout,
465479
lora_alpha=lora_alpha,
466480
lora_trainable_modules=lora_trainable_modules,
481+
train_vision=train_vision,
467482
suffix=suffix,
468483
wandb_api_key=wandb_api_key,
469484
wandb_base_url=wandb_base_url,
@@ -906,6 +921,7 @@ async def create(
906921
lora_dropout: float | None = 0,
907922
lora_alpha: float | None = None,
908923
lora_trainable_modules: str | None = "all-linear",
924+
train_vision: bool = False,
909925
suffix: str | None = None,
910926
wandb_api_key: str | None = None,
911927
wandb_base_url: str | None = None,
@@ -951,6 +967,7 @@ async def create(
951967
lora_dropout (float, optional): Dropout rate for LoRA adapters. Defaults to 0.
952968
lora_alpha (float, optional): Alpha for LoRA adapters. Defaults to 8.
953969
lora_trainable_modules (str, optional): Trainable modules for LoRA adapters. Defaults to "all-linear".
970+
train_vision (bool, optional): Whether to train vision encoder in multimodal models. Defaults to False.
954971
suffix (str, optional): Up to 40 character suffix that will be added to your fine-tuned model name.
955972
Defaults to None.
956973
wandb_api_key (str, optional): API key for Weights & Biases integration.
@@ -1028,6 +1045,7 @@ async def create(
10281045
lora_dropout=lora_dropout,
10291046
lora_alpha=lora_alpha,
10301047
lora_trainable_modules=lora_trainable_modules,
1048+
train_vision=train_vision,
10311049
suffix=suffix,
10321050
wandb_api_key=wandb_api_key,
10331051
wandb_base_url=wandb_base_url,

src/together/types/__init__.py

Lines changed: 29 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,18 @@
77
AudioSpeechStreamChunk,
88
AudioSpeechStreamEvent,
99
AudioSpeechStreamResponse,
10+
AudioTimestampGranularities,
1011
AudioTranscriptionRequest,
11-
AudioTranslationRequest,
1212
AudioTranscriptionResponse,
13+
AudioTranscriptionResponseFormat,
1314
AudioTranscriptionVerboseResponse,
15+
AudioTranslationRequest,
1416
AudioTranslationResponse,
1517
AudioTranslationVerboseResponse,
16-
AudioTranscriptionResponseFormat,
17-
AudioTimestampGranularities,
1818
ModelVoices,
1919
VoiceListResponse,
2020
)
21+
from together.types.batch import BatchEndpoint, BatchJob, BatchJobStatus
2122
from together.types.chat_completions import (
2223
ChatCompletionChunk,
2324
ChatCompletionRequest,
@@ -31,6 +32,19 @@
3132
)
3233
from together.types.embeddings import EmbeddingRequest, EmbeddingResponse
3334
from together.types.endpoints import Autoscaling, DedicatedEndpoint, ListEndpoint
35+
from together.types.evaluation import (
36+
ClassifyParameters,
37+
CompareParameters,
38+
EvaluationCreateResponse,
39+
EvaluationJob,
40+
EvaluationRequest,
41+
EvaluationStatus,
42+
EvaluationStatusResponse,
43+
EvaluationType,
44+
JudgeModelConfig,
45+
ModelRequest,
46+
ScoreParameters,
47+
)
3448
from together.types.files import (
3549
FileDeleteResponse,
3650
FileList,
@@ -41,49 +55,32 @@
4155
FileType,
4256
)
4357
from together.types.finetune import (
44-
TrainingMethodDPO,
45-
TrainingMethodSFT,
46-
FinetuneCheckpoint,
4758
CosineLRScheduler,
4859
CosineLRSchedulerArgs,
60+
FinetuneCheckpoint,
61+
FinetuneDeleteResponse,
4962
FinetuneDownloadResult,
50-
LinearLRScheduler,
51-
LinearLRSchedulerArgs,
52-
FinetuneLRScheduler,
5363
FinetuneList,
5464
FinetuneListEvents,
55-
FinetuneRequest,
56-
FinetuneResponse,
65+
FinetuneLRScheduler,
66+
FinetuneMultimodalParams,
5767
FinetunePriceEstimationRequest,
5868
FinetunePriceEstimationResponse,
59-
FinetuneDeleteResponse,
69+
FinetuneRequest,
70+
FinetuneResponse,
6071
FinetuneTrainingLimits,
6172
FullTrainingType,
73+
LinearLRScheduler,
74+
LinearLRSchedulerArgs,
6275
LoRATrainingType,
76+
TrainingMethodDPO,
77+
TrainingMethodSFT,
6378
TrainingType,
6479
)
6580
from together.types.images import ImageRequest, ImageResponse
6681
from together.types.models import ModelObject, ModelUploadRequest, ModelUploadResponse
6782
from together.types.rerank import RerankRequest, RerankResponse
68-
from together.types.batch import BatchJob, BatchJobStatus, BatchEndpoint
69-
from together.types.evaluation import (
70-
EvaluationType,
71-
EvaluationStatus,
72-
JudgeModelConfig,
73-
ModelRequest,
74-
ClassifyParameters,
75-
ScoreParameters,
76-
CompareParameters,
77-
EvaluationRequest,
78-
EvaluationCreateResponse,
79-
EvaluationJob,
80-
EvaluationStatusResponse,
81-
)
82-
from together.types.videos import (
83-
CreateVideoBody,
84-
CreateVideoResponse,
85-
VideoJob,
86-
)
83+
from together.types.videos import CreateVideoBody, CreateVideoResponse, VideoJob
8784

8885

8986
__all__ = [
@@ -131,6 +128,7 @@
131128
"RerankRequest",
132129
"RerankResponse",
133130
"FinetuneTrainingLimits",
131+
"FinetuneMultimodalParams",
134132
"AudioSpeechRequest",
135133
"AudioResponseFormat",
136134
"AudioLanguage",

src/together/types/finetune.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
from __future__ import annotations
22

33
from enum import Enum
4-
from typing import List, Literal, Any
4+
from typing import Any, List, Literal
55

66
from pydantic import Field, StrictBool, field_validator
77

88
from together.types.abstract import BaseModel
9-
from together.types.common import (
10-
ObjectType,
11-
)
9+
from together.types.common import ObjectType
1210

1311

1412
class FinetuneJobStatus(str, Enum):
@@ -175,6 +173,14 @@ class TrainingMethodDPO(TrainingMethod):
175173
simpo_gamma: float | None = None
176174

177175

176+
class FinetuneMultimodalParams(BaseModel):
177+
"""
178+
Multimodal parameters
179+
"""
180+
181+
train_vision: bool = False
182+
183+
178184
class FinetuneProgress(BaseModel):
179185
"""
180186
Fine-tune job progress
@@ -231,6 +237,8 @@ class FinetuneRequest(BaseModel):
231237
)
232238
# from step
233239
from_checkpoint: str | None = None
240+
# multimodal parameters
241+
multimodal_params: FinetuneMultimodalParams | None = None
234242
# hf related fields
235243
hf_api_token: str | None = None
236244
hf_output_repo_name: str | None = None
@@ -409,6 +417,7 @@ class FinetuneTrainingLimits(BaseModel):
409417
min_learning_rate: float
410418
full_training: FinetuneFullTrainingLimits | None = None
411419
lora_training: FinetuneLoraTrainingLimits | None = None
420+
supports_vision: bool = False
412421

413422

414423
class LinearLRSchedulerArgs(BaseModel):

tests/unit/test_finetune_resources.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
from unittest.mock import MagicMock, Mock
2+
13
import pytest
2-
from unittest.mock import MagicMock, Mock, patch
34

45
from together.client import Together
56
from together.resources.finetune import create_finetune_request

0 commit comments

Comments (0)