
Commit d7b1513

feat: added LinearWarmupCosineAnnealingLRScheduler
1 parent 20c678e

4 files changed: 99 additions & 8 deletions
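
For context, a minimal sketch of how the new factory could be called directly (the factory and its signature come from the diff below; the model and optimizer here are placeholders, and in normal use the scheduler would be built through the component registry instead):

import torch

from modalities.optimizers.lr_schedulers import LRSchedulerFactory

# toy setup; the optimizer lr must equal max_lr, which the factory checks
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

scheduler = LRSchedulerFactory.get_linear_warmup_cosine_annealing_lr_scheduler(
    optimizer=optimizer,
    warmup_steps=100,   # linear ramp from initial_lr up to max_lr
    total_steps=1000,   # warmup plus cosine decay
    initial_lr=3e-5,
    final_lr=3e-6,
    max_lr=3e-4,
)

for _ in range(1000):
    optimizer.step()    # the actual training step would go here
    scheduler.step()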


src/modalities/config/config.py

Lines changed: 16 additions & 0 deletions
@@ -227,6 +227,22 @@ class CosineAnnealingLRSchedulerConfig(BaseModel):
     last_epoch: Annotated[int, Field(strict=True, ge=-1)] = -1


+class LinearWarmupCosineAnnealingLRSchedulerConfig(BaseModel):
+    optimizer: PydanticOptimizerIFType
+    warmup_steps: Annotated[int, Field(strict=True, gt=0)]
+    total_steps: Annotated[int, Field(strict=True, gt=0)]
+    initial_lr: Annotated[float, Field(strict=True, ge=0.0)]
+    final_lr: Annotated[float, Field(strict=True, ge=0.0)]
+    max_lr: Annotated[float, Field(strict=True, ge=0.0)]
+    last_epoch: Annotated[int, Field(strict=True, ge=-1)] = -1
+
+    @model_validator(mode="after")
+    def check_total_steps_greater_than_warmup_steps(self) -> "LinearWarmupCosineAnnealingLRSchedulerConfig":
+        if self.total_steps <= self.warmup_steps:
+            raise ValueError("total_steps must be greater than warmup_steps.")
+        return self
+
+
 class FSDP1CheckpointedOptimizerConfig(BaseModel):
     checkpoint_loading: PydanticFSDP1CheckpointLoadingIFType
     checkpoint_path: Path
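
The model_validator(mode="after") hook runs after all fields have been parsed, so the cross-field constraint total_steps > warmup_steps is rejected at config-load time rather than deep inside training. A self-contained sketch of the same pattern (a hypothetical stand-in model without the optimizer field, for illustration only):

from typing import Annotated

from pydantic import BaseModel, Field, ValidationError, model_validator


class WarmupScheduleDemoConfig(BaseModel):
    # hypothetical stand-in mirroring the validator in LinearWarmupCosineAnnealingLRSchedulerConfig
    warmup_steps: Annotated[int, Field(strict=True, gt=0)]
    total_steps: Annotated[int, Field(strict=True, gt=0)]

    @model_validator(mode="after")
    def check_total_steps_greater_than_warmup_steps(self) -> "WarmupScheduleDemoConfig":
        if self.total_steps <= self.warmup_steps:
            raise ValueError("total_steps must be greater than warmup_steps.")
        return self


WarmupScheduleDemoConfig(warmup_steps=100, total_steps=1000)  # passes validation

try:
    WarmupScheduleDemoConfig(warmup_steps=1000, total_steps=100)
except ValidationError as err:
    print(err)  # reports: total_steps must be greater than warmup_steps.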
src/modalities/optimizers/lr_schedulers.py

Lines changed: 49 additions & 6 deletions

@@ -1,21 +1,64 @@
 import warnings
-from typing import Optional

+from torch import Tensor
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import LRScheduler
+from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, LRScheduler, SequentialLR


 class DummyLRScheduler(LRScheduler):
-    def __init__(self, optimizer: Optimizer, last_epoch: Optional[int] = -1):
+    def __init__(self, optimizer: Optimizer, last_epoch: int = -1):
         super().__init__(optimizer, last_epoch)

-    def get_lr(self) -> list[float]:
+    def get_lr(self) -> list[float | Tensor]:
         if not self._get_lr_called_within_step:  # type error expected due to internal pytorch implementation
             warnings.warn(
-                "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", UserWarning
+                "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.",
+                UserWarning,
             )

         return [group["lr"] for group in self.optimizer.param_groups]

-    def _get_closed_form_lr(self) -> list[float]:
+    def _get_closed_form_lr(self) -> list[float | Tensor]:
         return self.base_lrs
+
+
+class LRSchedulerFactory:
+    @staticmethod
+    def get_linear_warmup_cosine_annealing_lr_scheduler(
+        optimizer: Optimizer,
+        warmup_steps: int,
+        total_steps: int,
+        initial_lr: float,
+        final_lr: float,
+        max_lr: float,
+        last_epoch: int = -1,
+    ) -> SequentialLR:
+        if warmup_steps <= 0:
+            raise ValueError("warmup_steps must be greater than 0.")
+        if total_steps <= warmup_steps:
+            raise ValueError("total_steps must be greater than warmup_steps.")
+
+        if not all(base_lr == max_lr for base_lr in [group["lr"] for group in optimizer.param_groups]):
+            raise ValueError(
+                "All parameter groups must have the same lr, "
+                "and it must be equal to the max_lr passed to the LR scheduler factory."
+            )
+
+        warmup_scheduler = LinearLR(
+            optimizer=optimizer,
+            start_factor=initial_lr / max_lr,
+            end_factor=1,
+            total_iters=warmup_steps,
+        )
+        cosine_scheduler = CosineAnnealingLR(
+            optimizer=optimizer,
+            T_max=total_steps - warmup_steps,
+            eta_min=final_lr,
+        )
+
+        return SequentialLR(
+            optimizer=optimizer,
+            schedulers=[warmup_scheduler, cosine_scheduler],
+            milestones=[warmup_steps],
+            last_epoch=last_epoch,
+        )
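
The composed schedule ramps linearly from initial_lr up to max_lr over the first warmup_steps steps (LinearLR with start_factor = initial_lr / max_lr), then follows PyTorch's cosine annealing curve eta_min + (eta_max - eta_min) * (1 + cos(pi * t / T_max)) / 2 from max_lr down to final_lr over the remaining total_steps - warmup_steps steps. A rough closed-form sketch of the expected trajectory (illustrative only; exact intermediate values depend on PyTorch's scheduler step indexing):

import math


def expected_lr(step: int, warmup_steps: int, total_steps: int,
                initial_lr: float, final_lr: float, max_lr: float) -> float:
    """Closed-form approximation of the warmup + cosine schedule built above."""
    if step < warmup_steps:
        # linear ramp from initial_lr to max_lr
        return initial_lr + (max_lr - initial_lr) * step / warmup_steps
    # cosine decay from max_lr down to final_lr
    t = step - warmup_steps
    t_max = total_steps - warmup_steps
    return final_lr + (max_lr - final_lr) * (1 + math.cos(math.pi * t / t_max)) / 2


for step in range(7):
    print(step, round(expected_lr(step, 2, 6, 0.1, 0.2, 1.0), 4))
# 0 -> 0.1 (initial_lr), 2 -> 1.0 (max_lr at the end of warmup), 6 -> 0.2 (final_lr)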

src/modalities/registry/components.py

Lines changed: 8 additions & 1 deletion
@@ -48,6 +48,7 @@
     GPT2MFUCalculatorConfig,
     GPT2ModelTPConfig,
     LinearLRSchedulerConfig,
+    LinearWarmupCosineAnnealingLRSchedulerConfig,
     LLMDataLoaderConfig,
     MemMapDatasetConfig,
     OneCycleLRSchedulerConfig,
@@ -108,7 +109,7 @@
     ComposedInitializationRoutines,
     ComposedModelInitializationConfig,
 )
-from modalities.optimizers.lr_schedulers import DummyLRScheduler
+from modalities.optimizers.lr_schedulers import DummyLRScheduler, LRSchedulerFactory
 from modalities.optimizers.optimizer_factory import OptimizerFactory
 from modalities.optimizers.optimizer_list import OptimizersList
 from modalities.optimizers.scheduler_list import SchedulerList
@@ -285,6 +286,12 @@ class ComponentEntity:
         maybe_optimizer_list(torch.optim.lr_scheduler.CosineAnnealingLR),
         CosineAnnealingLRSchedulerConfig,
     ),
+    ComponentEntity(
+        "scheduler",
+        "linear_warmup_cosine_annealing_lr",
+        maybe_optimizer_list(LRSchedulerFactory.get_linear_warmup_cosine_annealing_lr_scheduler),
+        LinearWarmupCosineAnnealingLRSchedulerConfig,
+    ),
     # tokenizers
     ComponentEntity("tokenizer", "pretrained_hf_tokenizer", PreTrainedHFTokenizer, PreTrainedHFTokenizerConfig),
     ComponentEntity("tokenizer", "pretrained_sp_tokenizer", PreTrainedSPTokenizer, PreTrainedSPTokenizerConfig),

tests/test_lr_scheduler.py

Lines changed: 26 additions & 1 deletion
@@ -1,14 +1,15 @@
 from unittest.mock import MagicMock, call

 import numpy as np
+import torch

 from modalities.checkpointing.checkpoint_saving import CheckpointSaving
 from modalities.checkpointing.stateful.app_state import AppState
 from modalities.dataloader.dataloader import LLMDataLoader
 from modalities.evaluator import Evaluator
 from modalities.gym import Gym
 from modalities.loss_functions import Loss
-from modalities.optimizers.lr_schedulers import DummyLRScheduler
+from modalities.optimizers.lr_schedulers import DummyLRScheduler, LRSchedulerFactory
 from modalities.trainer import Trainer
 from tests.utility import configure_dataloader_mock

@@ -76,3 +77,27 @@ def test_dummy_lr_scheduler(optimizer_with_param_groups_mock: MagicMock):
     assert np.allclose(scheduler.get_lr(), [0.08, 0.18, 0.28], atol=1e-6)
     assert scheduler._get_closed_form_lr() == [0.1, 0.2, 0.3]
     assert np.allclose(scheduler.get_last_lr(), [0.08, 0.18, 0.28], atol=1e-6)
+
+
+def test_linear_warmup_cosine_annealing_lr_scheduler():
+    parameter = torch.nn.Parameter(torch.tensor([1.0]))
+    optimizer = torch.optim.SGD([parameter], lr=1.0)
+    scheduler = LRSchedulerFactory.get_linear_warmup_cosine_annealing_lr_scheduler(
+        optimizer=optimizer,
+        warmup_steps=2,
+        total_steps=6,
+        initial_lr=0.1,
+        final_lr=0.2,
+        max_lr=1.0,
+    )
+
+    learning_rates = [scheduler.get_last_lr()[0]]
+    for _ in range(6):
+        optimizer.step()
+        scheduler.step()
+        learning_rates.append(scheduler.get_last_lr()[0])
+
+    assert learning_rates[0] < learning_rates[1] < learning_rates[2]
+    assert np.isclose(learning_rates[2], 1.0, atol=1e-6)
+    assert learning_rates[2] > learning_rates[3] > learning_rates[4] > learning_rates[5] > learning_rates[6]
+    assert np.isclose(learning_rates[6], 0.2, atol=1e-6)
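
The recorded list holds seven values: the learning rate right after construction plus one per scheduler step. With warmup_steps=2, total_steps=6, initial_lr=0.1, final_lr=0.2 and max_lr=1.0, the trajectory should start at initial_lr (0.1), peak at max_lr (1.0) after the two warmup steps, and decay along the cosine curve to final_lr (0.2) at step six, which is exactly what the assertions check. The new test should be runnable in isolation with, e.g., pytest tests/test_lr_scheduler.py -k linear_warmup_cosine_annealing.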
