@@ -23,6 +23,7 @@
     TrainingType,
     FinetuneLRScheduler,
     FinetuneLinearLRSchedulerArgs,
+    FinetuneCosineLRSchedulerArgs,
     TrainingMethodDPO,
     TrainingMethodSFT,
     FinetuneCheckpoint,
@@ -57,7 +58,9 @@ def createFinetuneRequest(
     n_checkpoints: int | None = 1,
     batch_size: int | Literal["max"] = "max",
     learning_rate: float | None = 0.00001,
+    lr_scheduler_type: Literal["linear", "cosine"] = "linear",
     min_lr_ratio: float = 0.0,
+    num_cycles: float = 0.5,
     warmup_ratio: float = 0.0,
     max_grad_norm: float = 1.0,
     weight_decay: float = 0.0,
@@ -129,10 +132,21 @@ def createFinetuneRequest(
             f"training_method must be one of {', '.join(AVAILABLE_TRAINING_METHODS)}"
         )
 
-    lrScheduler = FinetuneLRScheduler(
-        lr_scheduler_type="linear",
-        lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
-    )
+    if lr_scheduler_type == "cosine":
+        if num_cycles <= 0.0:
+            raise ValueError("Number of cycles should be greater than 0")
+
+        lrScheduler = FinetuneLRScheduler(
+            lr_scheduler_type="cosine",
+            lr_scheduler_args=FinetuneCosineLRSchedulerArgs(
+                min_lr_ratio=min_lr_ratio, num_cycles=num_cycles
+            ),
+        )
+    else:
+        lrScheduler = FinetuneLRScheduler(
+            lr_scheduler_type="linear",
+            lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
+        )
 
     training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT()
     if training_method == "dpo":
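The request builder above only selects and packages the scheduler type; the actual learning-rate curve is computed by the fine-tuning backend. As a rough illustration of what `min_lr_ratio` and `num_cycles` mean for a cosine schedule, here is a minimal sketch assuming the widely used Hugging Face-style formulation (an assumption; the helper name `cosine_lr` and the exact formula are not taken from this change):

import math

def cosine_lr(step: int, total_steps: int, base_lr: float,
              min_lr_ratio: float = 0.0, num_cycles: float = 0.5) -> float:
    # Illustrative only: mirrors the common Hugging Face-style cosine decay;
    # the fine-tuning service may implement the curve differently.
    progress = min(step / max(total_steps, 1), 1.0)
    cosine = 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress))
    # Never decay below base_lr * min_lr_ratio.
    return base_lr * (min_lr_ratio + (1.0 - min_lr_ratio) * max(cosine, 0.0))

# With num_cycles=0.5 the curve is a single half-cosine: the rate starts at
# base_lr and tapers to base_lr * min_lr_ratio by the final step.
print(cosine_lr(0, 1000, 1e-5, min_lr_ratio=0.1))     # ~1e-05
print(cosine_lr(1000, 1000, 1e-5, min_lr_ratio=0.1))  # ~1e-06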
@@ -244,7 +258,9 @@ def create(
         n_checkpoints: int | None = 1,
         batch_size: int | Literal["max"] = "max",
         learning_rate: float | None = 0.00001,
+        lr_scheduler_type: Literal["linear", "cosine"] = "linear",
         min_lr_ratio: float = 0.0,
+        num_cycles: float = 0.5,
         warmup_ratio: float = 0.0,
         max_grad_norm: float = 1.0,
         weight_decay: float = 0.0,
@@ -279,8 +295,10 @@ def create(
             batch_size (int or "max"): Batch size for fine-tuning. Defaults to max.
             learning_rate (float, optional): Learning rate multiplier to use for training
                 Defaults to 0.00001.
+            lr_scheduler_type (Literal["linear", "cosine"]): Learning rate scheduler type. Defaults to "linear".
             min_lr_ratio (float, optional): Min learning rate ratio of the initial learning rate for
                 the learning rate scheduler. Defaults to 0.0.
+            num_cycles (float, optional): Number of cycles for cosine learning rate scheduler. Defaults to 0.5.
             warmup_ratio (float, optional): Warmup ratio for learning rate scheduler.
             max_grad_norm (float, optional): Max gradient norm. Defaults to 1.0, set to 0 to disable.
             weight_decay (float, optional): Weight decay. Defaults to 0.0.
@@ -336,7 +354,9 @@ def create(
             n_checkpoints=n_checkpoints,
             batch_size=batch_size,
             learning_rate=learning_rate,
+            lr_scheduler_type=lr_scheduler_type,
             min_lr_ratio=min_lr_ratio,
+            num_cycles=num_cycles,
             warmup_ratio=warmup_ratio,
             max_grad_norm=max_grad_norm,
             weight_decay=weight_decay,
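For callers, the only visible change is the two new keyword arguments on `create`, which are forwarded straight to the request builder. A hypothetical usage sketch follows, assuming the SDK's usual `Together()` client and `fine_tuning.create` entry point; the file ID and model name are placeholders rather than values taken from this change:

from together import Together

client = Together()  # assumes TOGETHER_API_KEY is set in the environment

# Placeholder file ID and model name; swap in your own uploaded file and base model.
job = client.fine_tuning.create(
    training_file="file-abc123",
    model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",
    learning_rate=1e-5,
    lr_scheduler_type="cosine",  # new: select the cosine schedule instead of linear
    min_lr_ratio=0.1,            # decay down to 10% of the initial learning rate
    num_cycles=0.5,              # half a cosine period, i.e. a single smooth decay
    warmup_ratio=0.03,
)
print(job)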
@@ -617,7 +637,9 @@ async def create(
         n_checkpoints: int | None = 1,
         batch_size: int | Literal["max"] = "max",
         learning_rate: float | None = 0.00001,
+        lr_scheduler_type: Literal["linear", "cosine"] = "linear",
         min_lr_ratio: float = 0.0,
+        num_cycles: float = 0.5,
         warmup_ratio: float = 0.0,
         max_grad_norm: float = 1.0,
         weight_decay: float = 0.0,
@@ -652,8 +674,10 @@ async def create(
             batch_size (int, optional): Batch size for fine-tuning. Defaults to max.
             learning_rate (float, optional): Learning rate multiplier to use for training
                 Defaults to 0.00001.
+            lr_scheduler_type (Literal["linear", "cosine"]): Learning rate scheduler type. Defaults to "linear".
             min_lr_ratio (float, optional): Min learning rate ratio of the initial learning rate for
                 the learning rate scheduler. Defaults to 0.0.
+            num_cycles (float, optional): Number of cycles for cosine learning rate scheduler. Defaults to 0.5.
             warmup_ratio (float, optional): Warmup ratio for learning rate scheduler.
             max_grad_norm (float, optional): Max gradient norm. Defaults to 1.0, set to 0 to disable.
             weight_decay (float, optional): Weight decay. Defaults to 0.0.
@@ -710,7 +734,9 @@ async def create(
             n_checkpoints=n_checkpoints,
             batch_size=batch_size,
             learning_rate=learning_rate,
+            lr_scheduler_type=lr_scheduler_type,
             min_lr_ratio=min_lr_ratio,
+            num_cycles=num_cycles,
             warmup_ratio=warmup_ratio,
             max_grad_norm=max_grad_norm,
             weight_decay=weight_decay,
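The async path mirrors the synchronous one, so the same two keyword arguments apply there as well. A brief sketch, again assuming the SDK's `AsyncTogether` client and placeholder identifiers:

import asyncio
from together import AsyncTogether

async def main() -> None:
    client = AsyncTogether()  # assumes TOGETHER_API_KEY is set in the environment
    job = await client.fine_tuning.create(
        training_file="file-abc123",  # placeholder file ID
        model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",  # placeholder model
        lr_scheduler_type="cosine",
        num_cycles=0.5,
    )
    print(job)

asyncio.run(main())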