Skip to content

Commit a872438

Browse files
Arsh ZahedArsh Zahed
authored andcommitted
Port cosine lr scheduler init
1 parent 3c87636 commit a872438

1 file changed

Lines changed: 12 additions & 3 deletions

File tree

src/together/types/finetune.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -365,20 +365,29 @@ class FinetuneCosineLRSchedulerArgs(BaseModel):
365365

366366

367367
class FinetuneLRScheduler(BaseModel):
368-
lr_scheduler_type: Literal["linear", "cosine"]
368+
lr_scheduler_type: str
369369
lr_scheduler_args: FinetuneLRSchedulerArgs | None = None
370370

371+
@field_validator("lr_scheduler_type")
372+
@classmethod
373+
def validate_scheduler_type(cls, v: str) -> str:
374+
if v not in LRSchedulerTypeToArgs:
375+
raise ValueError(
376+
f"Scheduler type must be one of: {LRSchedulerTypeToArgs.keys()}"
377+
)
378+
return v
379+
371380
@field_validator("lr_scheduler_args")
372381
@classmethod
373382
def validate_scheduler_args(
374383
cls, v: FinetuneLRSchedulerArgs, info: ValidationInfo
375384
) -> FinetuneLRSchedulerArgs:
376-
scheduler_type = info.data.get("lr_scheduler_type")
385+
scheduler_type = str(info.data.get("lr_scheduler_type"))
377386

378387
if v is None:
379388
return v
380389

381-
expected_type = LRSchedulerTypeToArgs[str(scheduler_type)]
390+
expected_type = LRSchedulerTypeToArgs[scheduler_type]
382391
if not isinstance(v, expected_type):
383392
raise ValueError(f"Expected {expected_type}, got {type(v)}")
384393

0 commit comments

Comments
 (0)