
Commit 3c87636

Arsh Zahed authored and committed
Port cosine lr scheduler init
1 parent 084d2d1 · commit 3c87636

4 files changed

Lines changed: 85 additions & 10 deletions


src/together/cli/api/finetune.py

Lines changed: 16 additions & 0 deletions
@@ -71,12 +71,24 @@ def fine_tuning(ctx: click.Context) -> None:
 )
 @click.option("--batch-size", type=INT_WITH_MAX, default="max", help="Train batch size")
 @click.option("--learning-rate", type=float, default=1e-5, help="Learning rate")
+@click.option(
+    "--lr-scheduler-type",
+    type=click.Choice(["linear", "cosine"]),
+    default="linear",
+    help="Learning rate scheduler type",
+)
 @click.option(
     "--min-lr-ratio",
     type=float,
     default=0.0,
     help="The ratio of the final learning rate to the peak learning rate",
 )
+@click.option(
+    "--num-cycles",
+    type=float,
+    default=0.5,
+    help="Number of cycles for cosine learning rate scheduler.",
+)
 @click.option(
     "--warmup-ratio",
     type=float,
@@ -162,7 +174,9 @@ def create(
     n_checkpoints: int,
     batch_size: int | Literal["max"],
     learning_rate: float,
+    lr_scheduler_type: Literal["linear", "cosine"],
     min_lr_ratio: float,
+    num_cycles: float,
     warmup_ratio: float,
     max_grad_norm: float,
     weight_decay: float,
@@ -194,7 +208,9 @@ def create(
         n_checkpoints=n_checkpoints,
         batch_size=batch_size,
         learning_rate=learning_rate,
+        lr_scheduler_type=lr_scheduler_type,
         min_lr_ratio=min_lr_ratio,
+        num_cycles=num_cycles,
         warmup_ratio=warmup_ratio,
         max_grad_norm=max_grad_norm,
         weight_decay=weight_decay,

src/together/resources/finetune.py

Lines changed: 30 additions & 4 deletions
@@ -23,6 +23,7 @@
     TrainingType,
     FinetuneLRScheduler,
     FinetuneLinearLRSchedulerArgs,
+    FinetuneCosineLRSchedulerArgs,
     TrainingMethodDPO,
     TrainingMethodSFT,
     FinetuneCheckpoint,
@@ -57,7 +58,9 @@ def createFinetuneRequest(
     n_checkpoints: int | None = 1,
     batch_size: int | Literal["max"] = "max",
     learning_rate: float | None = 0.00001,
+    lr_scheduler_type: Literal["linear", "cosine"] = "linear",
     min_lr_ratio: float = 0.0,
+    num_cycles: float = 0.5,
     warmup_ratio: float = 0.0,
     max_grad_norm: float = 1.0,
     weight_decay: float = 0.0,
@@ -129,10 +132,21 @@ def createFinetuneRequest(
             f"training_method must be one of {', '.join(AVAILABLE_TRAINING_METHODS)}"
         )

-    lrScheduler = FinetuneLRScheduler(
-        lr_scheduler_type="linear",
-        lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
-    )
+    if lr_scheduler_type == "cosine":
+        if num_cycles <= 0.0:
+            raise ValueError("Number of cycles should be greater than 0")
+
+        lrScheduler = FinetuneLRScheduler(
+            lr_scheduler_type="cosine",
+            lr_scheduler_args=FinetuneCosineLRSchedulerArgs(
+                min_lr_ratio=min_lr_ratio, num_cycles=num_cycles
+            ),
+        )
+    else:
+        lrScheduler = FinetuneLRScheduler(
+            lr_scheduler_type="linear",
+            lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
+        )

     training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT()
     if training_method == "dpo":
@@ -244,7 +258,9 @@ def create(
     n_checkpoints: int | None = 1,
     batch_size: int | Literal["max"] = "max",
     learning_rate: float | None = 0.00001,
+    lr_scheduler_type: Literal["linear", "cosine"] = "linear",
     min_lr_ratio: float = 0.0,
+    num_cycles: float = 0.5,
     warmup_ratio: float = 0.0,
     max_grad_norm: float = 1.0,
     weight_decay: float = 0.0,
@@ -279,8 +295,10 @@ def create(
             batch_size (int or "max"): Batch size for fine-tuning. Defaults to max.
             learning_rate (float, optional): Learning rate multiplier to use for training
                 Defaults to 0.00001.
+            lr_scheduler_type (Literal["linear", "cosine"]): Learning rate scheduler type. Defaults to "linear".
             min_lr_ratio (float, optional): Min learning rate ratio of the initial learning rate for
                 the learning rate scheduler. Defaults to 0.0.
+            num_cycles (float, optional): Number of cycles for cosine learning rate scheduler. Defaults to 0.5.
             warmup_ratio (float, optional): Warmup ratio for learning rate scheduler.
             max_grad_norm (float, optional): Max gradient norm. Defaults to 1.0, set to 0 to disable.
             weight_decay (float, optional): Weight decay. Defaults to 0.0.
@@ -336,7 +354,9 @@ def create(
             n_checkpoints=n_checkpoints,
             batch_size=batch_size,
             learning_rate=learning_rate,
+            lr_scheduler_type=lr_scheduler_type,
             min_lr_ratio=min_lr_ratio,
+            num_cycles=num_cycles,
             warmup_ratio=warmup_ratio,
             max_grad_norm=max_grad_norm,
             weight_decay=weight_decay,
@@ -617,7 +637,9 @@ async def create(
     n_checkpoints: int | None = 1,
     batch_size: int | Literal["max"] = "max",
     learning_rate: float | None = 0.00001,
+    lr_scheduler_type: Literal["linear", "cosine"] = "linear",
     min_lr_ratio: float = 0.0,
+    num_cycles: float = 0.5,
     warmup_ratio: float = 0.0,
     max_grad_norm: float = 1.0,
     weight_decay: float = 0.0,
@@ -652,8 +674,10 @@ async def create(
             batch_size (int, optional): Batch size for fine-tuning. Defaults to max.
             learning_rate (float, optional): Learning rate multiplier to use for training
                 Defaults to 0.00001.
+            lr_scheduler_type (Literal["linear", "cosine"]): Learning rate scheduler type. Defaults to "linear".
             min_lr_ratio (float, optional): Min learning rate ratio of the initial learning rate for
                 the learning rate scheduler. Defaults to 0.0.
+            num_cycles (float, optional): Number of cycles for cosine learning rate scheduler. Defaults to 0.5.
             warmup_ratio (float, optional): Warmup ratio for learning rate scheduler.
             max_grad_norm (float, optional): Max gradient norm. Defaults to 1.0, set to 0 to disable.
             weight_decay (float, optional): Weight decay. Defaults to 0.0.
@@ -710,7 +734,9 @@ async def create(
             n_checkpoints=n_checkpoints,
             batch_size=batch_size,
             learning_rate=learning_rate,
+            lr_scheduler_type=lr_scheduler_type,
             min_lr_ratio=min_lr_ratio,
+            num_cycles=num_cycles,
             warmup_ratio=warmup_ratio,
             max_grad_norm=max_grad_norm,
             weight_decay=weight_decay,

src/together/types/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,7 @@
     TrainingMethodDPO,
     TrainingMethodSFT,
     FinetuneCheckpoint,
+    FinetuneCosineLRSchedulerArgs,
     FinetuneDownloadResult,
     FinetuneLinearLRSchedulerArgs,
     FinetuneList,
@@ -70,6 +71,7 @@
     "FinetuneDownloadResult",
     "FinetuneLRScheduler",
     "FinetuneLinearLRSchedulerArgs",
+    "FinetuneCosineLRSchedulerArgs",
     "FileRequest",
     "FileResponse",
     "FileList",

src/together/types/finetune.py

Lines changed: 37 additions & 6 deletions
@@ -1,9 +1,9 @@
 from __future__ import annotations

 from enum import Enum
-from typing import List, Literal
+from typing import List, Literal, Union

-from pydantic import StrictBool, Field, validator, field_validator
+from pydantic import StrictBool, Field, validator, field_validator, ValidationInfo

 from together.types.abstract import BaseModel
 from together.types.common import (
@@ -345,13 +345,44 @@ class FinetuneTrainingLimits(BaseModel):
     lora_training: FinetuneLoraTrainingLimits | None = None


-class FinetuneLRScheduler(BaseModel):
-    lr_scheduler_type: str
-    lr_scheduler_args: FinetuneLinearLRSchedulerArgs | None = None
+class FinetuneLinearLRSchedulerArgs(BaseModel):
+    min_lr_ratio: float | None = 0.0


-class FinetuneLinearLRSchedulerArgs(BaseModel):
+class FinetuneCosineLRSchedulerArgs(BaseModel):
     min_lr_ratio: float | None = 0.0
+    num_cycles: float | None = 0.5
+
+
+LRSchedulerTypeToArgs = {
+    "linear": FinetuneLinearLRSchedulerArgs,
+    "cosine": FinetuneCosineLRSchedulerArgs,
+}
+
+FinetuneLRSchedulerArgs = Union[
+    FinetuneLinearLRSchedulerArgs, FinetuneCosineLRSchedulerArgs, None
+]
+
+
+class FinetuneLRScheduler(BaseModel):
+    lr_scheduler_type: Literal["linear", "cosine"]
+    lr_scheduler_args: FinetuneLRSchedulerArgs | None = None
+
+    @field_validator("lr_scheduler_args")
+    @classmethod
+    def validate_scheduler_args(
+        cls, v: FinetuneLRSchedulerArgs, info: ValidationInfo
+    ) -> FinetuneLRSchedulerArgs:
+        scheduler_type = info.data.get("lr_scheduler_type")
+
+        if v is None:
+            return v
+
+        expected_type = LRSchedulerTypeToArgs[str(scheduler_type)]
+        if not isinstance(v, expected_type):
+            raise ValueError(f"Expected {expected_type}, got {type(v)}")
+
+        return v


 class FinetuneCheckpoint(BaseModel):
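
After this change, `lr_scheduler_args` is a union type that the new `validate_scheduler_args` validator checks against `lr_scheduler_type` (both scheduler-args classes are re-exported from `together.types` by the `__init__.py` change above). A minimal sketch of that behavior, using only names defined in this diff; the exact error text is illustrative:

    from together.types import (
        FinetuneCosineLRSchedulerArgs,
        FinetuneLinearLRSchedulerArgs,
        FinetuneLRScheduler,
    )

    # A matching type/args pair validates cleanly.
    cosine = FinetuneLRScheduler(
        lr_scheduler_type="cosine",
        lr_scheduler_args=FinetuneCosineLRSchedulerArgs(min_lr_ratio=0.0, num_cycles=0.5),
    )

    # A mismatched pair is rejected by validate_scheduler_args; the ValueError it
    # raises surfaces as a pydantic validation error (a ValueError subclass).
    try:
        FinetuneLRScheduler(
            lr_scheduler_type="cosine",
            lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=0.0),
        )
    except ValueError as err:
        print(err)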
