-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathconfig.py
More file actions
488 lines (366 loc) · 16.1 KB
/
config.py
File metadata and controls
488 lines (366 loc) · 16.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
import os
from functools import partial
from pathlib import Path
from typing import Annotated, Any, Callable, Literal, Optional
import torch
from omegaconf import OmegaConf
from pydantic import BaseModel, ConfigDict, Field, FilePath, PositiveInt, field_validator, model_validator
from torch.distributed.fsdp import ShardingStrategy
from transformers import GPT2TokenizerFast
from transformers.models.llama.tokenization_llama_fast import LlamaTokenizerFast
from typing_extensions import deprecated
from modalities.config.lookup_enum import LookupEnum
from modalities.config.pydantic_if_types import (
PydanticAppStateType,
PydanticCheckpointSavingExecutionIFType,
PydanticCheckpointSavingStrategyIFType,
PydanticCollatorIFType,
PydanticDatasetIFType,
PydanticDeviceMeshIFType,
PydanticFSDP1CheckpointLoadingIFType,
PydanticFSDP1ModuleType,
PydanticFSDP2ModuleType,
PydanticLLMDataLoaderIFType,
PydanticLRSchedulerIFType,
PydanticModelInitializationIFType,
PydanticOptimizerIFType,
PydanticPytorchDeviceType,
PydanticPytorchModuleType,
PydanticSamplerIFType,
PydanticTokenizerIFType,
)
from modalities.config.utils import parse_torch_device
from modalities.running_env.env_utils import (
FSDP2MixedPrecisionSettings,
MixedPrecisionSettings,
PyTorchDtypes,
has_bfloat_support,
)
from modalities.running_env.fsdp.device_mesh import ParallelismDegrees
from modalities.util import parse_enum_by_name
class ProcessGroupBackendType(LookupEnum):
    """Names of the process group backends that can be selected; only ``nccl`` is declared."""

    nccl = "nccl"
class TokenizerTypes(LookupEnum):
    """Lookup of tokenizer classes by name; each member's value is the tokenizer class itself."""

    GPT2TokenizerFast = GPT2TokenizerFast
    LlamaTokenizerFast = LlamaTokenizerFast
class PassType(LookupEnum):
    """The two ways a referenced instance can be passed: by value or by reference."""

    BY_VALUE = "by_value"
    BY_REFERENCE = "by_reference"
class WandbMode(LookupEnum):
    """Selectable Weights & Biases modes (used by ``WandBEvaluationResultSubscriberConfig.mode``)."""

    ONLINE = "ONLINE"
    OFFLINE = "OFFLINE"
    DISABLED = "DISABLED"
class PrecisionEnum(LookupEnum):
    """Maps precision names to the corresponding ``torch`` floating point dtypes."""

    FP32 = torch.float32
    FP16 = torch.float16
    BF16 = torch.bfloat16
class ReferenceConfig(BaseModel):
    """Schema for a reference to another instance, made of a key and a pass type."""

    # Key of the referenced instance (presumably resolved by the component registry/builder — confirm).
    instance_key: str
    # Whether the instance is handed over by value or by reference (see PassType).
    pass_type: PassType
class CLMCrossEntropyLossConfig(BaseModel):
    """Schema for the CLM cross entropy loss: the two string keys it is configured with."""

    # Key under which the targets are found (presumably in the batch — confirm against the loss implementation).
    target_key: str
    # Key under which the model predictions are found (presumably in the model output — confirm).
    prediction_key: str
# Checkpointing
class SaveEveryKStepsCheckpointingStrategyConfig(BaseModel):
    """Schema for the save-every-k-steps checkpointing strategy."""

    # Must be a positive integer (pydantic PositiveInt, i.e. > 0).
    k: PositiveInt
class SaveKMostRecentCheckpointsStrategyConfig(BaseModel):
    """Schema for the save-k-most-recent-checkpoints strategy."""

    # Strict int (no coercion from str/float), must be >= -1.
    # NOTE(review): the meaning of the special values -1 and 0 is defined by the consuming strategy — confirm there.
    k: Annotated[int, Field(strict=True, ge=-1)]
class TorchCheckpointLoadingConfig(BaseModel):
    """Schema for torch checkpoint loading settings.

    Attributes:
        device: Target device. The raw input is converted with ``parse_torch_device``
            before pydantic validates the field.
        precision: Optional precision selected from ``PrecisionEnum``; defaults to ``None``.
    """

    device: PydanticPytorchDeviceType
    precision: Optional[PrecisionEnum] = None

    @field_validator("device", mode="before")
    @classmethod
    def parse_device(cls, device) -> PydanticPytorchDeviceType:
        """Convert the raw ``device`` value with ``parse_torch_device`` prior to field validation."""
        # Explicit @classmethod (below @field_validator) is the form recommended by pydantic v2;
        # runtime behaviour is the same as the implicit wrapping pydantic applies otherwise.
        return parse_torch_device(device)
class FSDP1CheckpointLoadingConfig(BaseModel):
    """Schema for FSDP1 checkpoint loading settings.

    Attributes:
        global_rank: Strict, non-negative integer rank.
        block_names: List of block names (strings).
        mixed_precision_settings: Member of ``MixedPrecisionSettings``; given by name and
            parsed in ``parse_mixed_precision_setting_by_name``.
        sharding_strategy: Member of torch's ``ShardingStrategy``; given by name and parsed
            in ``parse_sharding_strategy_by_name``.
    """

    global_rank: Annotated[int, Field(strict=True, ge=0)]
    block_names: list[str]
    mixed_precision_settings: MixedPrecisionSettings
    sharding_strategy: ShardingStrategy

    @field_validator("mixed_precision_settings", mode="before")
    @classmethod
    def parse_mixed_precision_setting_by_name(cls, name):
        """Resolve the mixed precision setting by name and reject BF16 variants if unsupported.

        Raises:
            ValueError: If a BF16 setting is requested but ``has_bfloat_support()`` is false.
        """
        mixed_precision_settings: MixedPrecisionSettings = parse_enum_by_name(
            name=name, enum_type=MixedPrecisionSettings
        )
        if not has_bfloat_support() and (
            mixed_precision_settings == MixedPrecisionSettings.BF_16
            or mixed_precision_settings == MixedPrecisionSettings.BF_16_WORKING
        ):
            raise ValueError("BF16 not supported in the current environment")
        return mixed_precision_settings

    @field_validator("sharding_strategy", mode="before")
    @classmethod
    def parse_sharding_strategy_by_name(cls, name):
        """Resolve the sharding strategy by name via ``parse_enum_by_name``."""
        return parse_enum_by_name(name=name, enum_type=ShardingStrategy)
class DCPCheckpointLoadingConfig(BaseModel):
    """Schema for DCP checkpoint loading settings."""

    # Strict int (no coercion), must be >= 0.
    global_rank: Annotated[int, Field(strict=True, ge=0)]
class FSDP1CheckpointSavingConfig(BaseModel):
    """Schema for FSDP1 checkpoint saving settings."""

    # Path used for checkpoints; plain Path, so existence is not checked at validation time.
    checkpoint_path: Path
    # Strict int (no coercion), must be >= 0.
    global_rank: Annotated[int, Field(strict=True, ge=0)]
    # Identifier of the experiment the checkpoints belong to.
    experiment_id: str
class DCPCheckpointSavingConfig(BaseModel):
    """Schema for DCP checkpoint saving settings; same fields as ``FSDP1CheckpointSavingConfig``."""

    # Path used for checkpoints; plain Path, so existence is not checked at validation time.
    checkpoint_path: Path
    # Strict int (no coercion), must be >= 0.
    global_rank: Annotated[int, Field(strict=True, ge=0)]
    # Identifier of the experiment the checkpoints belong to.
    experiment_id: str
class CheckpointSavingConfig(BaseModel):
    """Schema combining a checkpoint saving strategy with a checkpoint saving execution."""

    # Strategy component (interface-typed).
    checkpoint_saving_strategy: PydanticCheckpointSavingStrategyIFType
    # Execution component (interface-typed).
    checkpoint_saving_execution: PydanticCheckpointSavingExecutionIFType
class AdamOptimizerConfig(BaseModel):
    """Schema for the Adam optimizer settings. All fields are required; no value ranges are enforced."""

    lr: float
    # Model whose parameters are to be optimized.
    wrapped_model: PydanticPytorchModuleType
    betas: tuple[float, float]
    eps: float
    weight_decay: float
    # Names of parameter groups excluded from weight decay
    # (NOTE(review): valid group names are defined by the model — confirm).
    weight_decay_groups_excluded: list[str]
class AdamWOptimizerConfig(BaseModel):
    """Schema for the AdamW optimizer settings; field set is identical to ``AdamOptimizerConfig``."""

    lr: float
    # Model whose parameters are to be optimized.
    wrapped_model: PydanticPytorchModuleType
    betas: tuple[float, float]
    eps: float
    weight_decay: float
    # Names of parameter groups excluded from weight decay
    # (NOTE(review): valid group names are defined by the model — confirm).
    weight_decay_groups_excluded: list[str]
class DummyLRSchedulerConfig(BaseModel):
    """Schema for the dummy LR scheduler; it only takes the optimizer."""

    optimizer: PydanticOptimizerIFType
class StepLRSchedulerConfig(BaseModel):
    """Schema for the StepLR scheduler settings (field names mirror ``torch.optim.lr_scheduler.StepLR``)."""

    optimizer: PydanticOptimizerIFType
    # Strict int, must be > 0.
    step_size: Annotated[int, Field(strict=True, gt=0)]
    # Strict float (an int literal is rejected), must be >= 0.0.
    gamma: Annotated[float, Field(strict=True, ge=0.0)]
    # Strict int, must be >= -1; defaults to -1.
    last_epoch: Annotated[int, Field(strict=True, ge=-1)] = -1
class OneCycleLRSchedulerConfig(BaseModel):
    """Schema for the OneCycleLR scheduler settings.

    Numeric fields are strict (no type coercion) and range-checked via ``Field``.
    The schedule length has to be given either through ``total_steps`` or through
    the pair ``epochs`` and ``steps_per_epoch``; this is enforced after field validation.
    """

    optimizer: PydanticOptimizerIFType
    # A single positive float or a list of positive floats.
    max_lr: Annotated[float, Field(strict=True, gt=0.0)] | list[Annotated[float, Field(strict=True, gt=0.0)]]
    total_steps: Optional[Annotated[int, Field(strict=True, gt=0)]] = None
    epochs: Optional[Annotated[int, Field(strict=True, gt=0)]] = None
    steps_per_epoch: Optional[Annotated[int, Field(strict=True, gt=0)]] = None
    # Must lie in (0.0, 1.0].
    pct_start: Annotated[float, Field(strict=True, gt=0.0, le=1.0)]
    anneal_strategy: str
    cycle_momentum: bool = True
    base_momentum: Annotated[float, Field(strict=True, gt=0)] | list[
        Annotated[float, Field(strict=True, gt=0.0)]
    ] = 0.85
    max_momentum: Annotated[float, Field(strict=True, gt=0.0)] | list[
        Annotated[float, Field(strict=True, gt=0.0)]
    ] = 0.95
    div_factor: Annotated[float, Field(strict=True, gt=0.0)]
    final_div_factor: Annotated[float, Field(strict=True, gt=0.0)]
    three_phase: bool = False
    last_epoch: Annotated[int, Field(strict=True, ge=-1)] = -1

    @model_validator(mode="after")
    def check_totals_steps_and_epchs(self) -> "OneCycleLRSchedulerConfig":
        """Ensure the schedule length is derivable from the given fields.

        Raises:
            ValueError: If ``total_steps`` is unset and ``epochs``/``steps_per_epoch`` are not both set.
        """
        epoch_based_length_given = self.epochs is not None and self.steps_per_epoch is not None
        if self.total_steps is None and not epoch_based_length_given:
            raise ValueError("Please define total_steps or (epochs and steps_per_epoch).")
        return self
class ConstantLRSchedulerConfig(BaseModel):
    """Schema for the ConstantLR scheduler settings."""

    optimizer: PydanticOptimizerIFType
    # Strict float, must lie in [0.0, 1.0].
    factor: Annotated[float, Field(strict=True, ge=0.0, le=1.0)]
    # Strict int, must be > 0.
    total_iters: Annotated[int, Field(strict=True, gt=0)]
    # Strict int, must be >= -1; defaults to -1.
    last_epoch: Annotated[int, Field(strict=True, ge=-1)] = -1
class LinearLRSchedulerConfig(BaseModel):
    """Schema for the LinearLR scheduler settings."""

    optimizer: PydanticOptimizerIFType
    # Strict float, must lie in (0.0, 1.0] — note that 0.0 itself is excluded.
    start_factor: Annotated[float, Field(strict=True, gt=0.0, le=1.0)]
    # Strict float, must lie in [0.0, 1.0].
    end_factor: Annotated[float, Field(strict=True, ge=0.0, le=1.0)]
    # Strict int, must be > 0.
    total_iters: Annotated[int, Field(strict=True, gt=0)]
    # Strict int, must be >= -1; defaults to -1.
    last_epoch: Annotated[int, Field(strict=True, ge=-1)] = -1
class CosineAnnealingLRSchedulerConfig(BaseModel):
    """Schema for the CosineAnnealingLR scheduler settings."""

    optimizer: PydanticOptimizerIFType
    # Strict int, must be > 0.
    t_max: Annotated[int, Field(strict=True, gt=0)]
    # Strict float, must be >= 0.0.
    eta_min: Annotated[float, Field(strict=True, ge=0.0)]
    # Strict int, must be >= -1; defaults to -1.
    last_epoch: Annotated[int, Field(strict=True, ge=-1)] = -1
class FSDP1CheckpointedOptimizerConfig(BaseModel):
    """Schema for an optimizer restored from an FSDP1 checkpoint."""

    # Component performing the checkpoint loading (interface-typed).
    checkpoint_loading: PydanticFSDP1CheckpointLoadingIFType
    # Location of the checkpoint; plain Path, so existence is not checked at validation time.
    checkpoint_path: Path
    wrapped_model: PydanticPytorchModuleType
    optimizer: PydanticOptimizerIFType
class FSDP1CheckpointedModelConfig(BaseModel):
    """Schema for a model restored from an FSDP1 checkpoint."""

    # Component performing the checkpoint loading (interface-typed).
    checkpoint_loading: PydanticFSDP1CheckpointLoadingIFType
    # Location of the checkpoint; plain Path, so existence is not checked at validation time.
    checkpoint_path: Path
    model: PydanticPytorchModuleType
@deprecated(
    "With version 0.4, we upgraded FSDP to FSDP 2.0. Use get_fsdp_2_wrapped_model(...) "
    "and FSDP2WrappedModelConfig instead.",
    category=FutureWarning,
)
class FSDPWrappedModelConfig(BaseModel):
    """Schema for an FSDP (version 1) wrapped model. Deprecated in favour of ``FSDP2WrappedModelConfig``.

    Attributes:
        model: The module to wrap.
        sync_module_states: Boolean flag forwarded to the wrapping logic.
        mixed_precision_settings: Member of ``MixedPrecisionSettings``; given by name and
            parsed in ``parse_mixed_precision_setting_by_name``.
        sharding_strategy: Member of torch's ``ShardingStrategy``; given by name and parsed
            in ``parse_sharding_strategy_by_name``.
        block_names: List of block names (strings).
    """

    model: PydanticPytorchModuleType
    sync_module_states: bool
    mixed_precision_settings: MixedPrecisionSettings
    sharding_strategy: ShardingStrategy
    block_names: list[str]

    @field_validator("mixed_precision_settings", mode="before")
    @classmethod
    def parse_mixed_precision_setting_by_name(cls, name):
        """Resolve the mixed precision setting by name and reject BF16 variants if unsupported.

        Raises:
            ValueError: If a BF16 setting is requested but ``has_bfloat_support()`` is false.
        """
        mixed_precision_settings: MixedPrecisionSettings = parse_enum_by_name(
            name=name, enum_type=MixedPrecisionSettings
        )
        if not has_bfloat_support() and (
            mixed_precision_settings == MixedPrecisionSettings.BF_16
            or mixed_precision_settings == MixedPrecisionSettings.BF_16_WORKING
        ):
            raise ValueError("BF16 not supported in the current environment")
        return mixed_precision_settings

    @field_validator("sharding_strategy", mode="before")
    @classmethod
    def parse_sharding_strategy_by_name(cls, name):
        """Resolve the sharding strategy by name via ``parse_enum_by_name``."""
        return parse_enum_by_name(name=name, enum_type=ShardingStrategy)
class FSDP2WrappedModelConfig(BaseModel):
    """Schema for an FSDP2 wrapped model.

    Attributes:
        model: The module to wrap.
        block_names: List of block names (strings).
        mixed_precision_settings: ``FSDP2MixedPrecisionSettings`` carrying ``param_dtype`` and ``reduce_dtype``.
        reshard_after_forward: Boolean flag, defaults to ``True``.
        device_mesh: Device mesh; must expose the data parallel shard dimension by name.
    """

    model: PydanticPytorchModuleType
    block_names: list[str]
    mixed_precision_settings: FSDP2MixedPrecisionSettings
    reshard_after_forward: bool = True
    device_mesh: PydanticDeviceMeshIFType

    @model_validator(mode="after")
    def validate_mixed_precision_settings(self):
        """Reject BF16 as reduce or param dtype when the environment lacks bfloat support.

        Raises:
            ValueError: If BF16 is requested but ``has_bfloat_support()`` is false.
        """
        if not has_bfloat_support() and (
            self.mixed_precision_settings.reduce_dtype == PyTorchDtypes.BF_16
            or self.mixed_precision_settings.param_dtype == PyTorchDtypes.BF_16
        ):
            raise ValueError("BF16 not supported in the current environment")
        return self

    @model_validator(mode="after")
    def validate_dp_mesh_existence(self):
        """Ensure the device mesh has a dimension named after ``ParallelismDegrees.DP_SHARD``.

        Raises:
            ValueError: If the data parallel shard key is not among the mesh dimension names,
                including the case where the mesh has no dimension names at all.
        """
        # torch allows a DeviceMesh without names (mesh_dim_names is None). A membership test
        # against None raises TypeError, which pydantic does not turn into a validation error,
        # so normalise to an empty tuple and report the missing key as the intended ValueError.
        mesh_dim_names = self.device_mesh.mesh_dim_names or ()
        if ParallelismDegrees.DP_SHARD.value not in mesh_dim_names:
            raise ValueError(f"Data parallelism key '{ParallelismDegrees.DP_SHARD.value}' not in {self.device_mesh=}")
        return self
class CompiledModelConfig(BaseModel):
    """Schema for a compiled model."""

    model: PydanticPytorchModuleType
    # List of block names (presumably the blocks that get compiled — confirm against the consumer).
    block_names: list[str]
    # Defaults to True; None is also accepted by the schema.
    fullgraph: Optional[bool] = True
    # Defaults to False; None is also accepted by the schema.
    debug: Optional[bool] = False
class WeightInitializedModelConfig(BaseModel):
    """Schema for a model together with the initializer that is applied to it."""

    model: PydanticPytorchModuleType
    # Component implementing the model initialization interface.
    model_initializer: PydanticModelInitializationIFType

    # avoid warning about protected namespace 'model_', see
    # https://docs.pydantic.dev/2.7/api/config/#pydantic.config.ConfigDict.protected_namespaces
    model_config = ConfigDict(protected_namespaces=())
class ActivationCheckpointedModelConfig(BaseModel):
    """Schema for an FSDP1 model with activation checkpointing."""

    model: PydanticFSDP1ModuleType
    # Names of the modules to apply activation checkpointing to; defaults to a fresh empty list per instance.
    activation_checkpointing_modules: Optional[list[str]] = Field(default_factory=list)
class RawAppStateConfig(BaseModel):
    """Schema for the raw app state: model, optimizer and an optional LR scheduler."""

    model: PydanticPytorchModuleType
    optimizer: PydanticOptimizerIFType
    # Optional; None when no LR scheduler is part of the app state.
    lr_scheduler: Optional[PydanticLRSchedulerIFType] = None
class DCPAppStateConfig(BaseModel):
    """Schema for an app state loaded from a DCP checkpoint directory."""

    # The app state object to populate.
    raw_app_state: PydanticAppStateType
    # Checkpoint directory; plain Path, so existence is not checked at validation time.
    checkpoint_dir_path: Path
    # Per-component switches selecting which parts are loaded; all default to True.
    load_model_checkpoint: bool = True
    load_optimizer_checkpoint: bool = True
    load_lr_scheduler_checkpoint: bool = True
class PreTrainedHFTokenizerConfig(BaseModel):
    """Schema for a pretrained Hugging Face tokenizer."""

    # Model name or path of the pretrained tokenizer.
    pretrained_model_name_or_path: str
    # Optional strict int >= 0; None by default.
    max_length: Optional[Annotated[int, Field(strict=True, ge=0)]] = None
    truncation: bool = False
    # Either a bool or a string (NOTE(review): accepted string values are defined by the tokenizer — confirm).
    padding: bool | str = False
    # Optional mapping from special token name to a token string or a list/tuple of tokens.
    special_tokens: Optional[dict[str, str | list | tuple]] = None
class PreTrainedSPTokenizerConfig(BaseModel):
    """Schema for a pretrained SP tokenizer (presumably SentencePiece — confirm)."""

    # Path to the tokenizer model file, kept as a plain string (no existence check).
    tokenizer_model_file: str
class SequentialSamplerConfig(BaseModel):
    """Schema for a sequential sampler over a dataset."""

    # Dataset to sample from.
    data_source: PydanticDatasetIFType
class DistributedSamplerConfig(BaseModel):
    """Schema for a distributed sampler."""

    # Strict int, must be >= 0.
    rank: Annotated[int, Field(strict=True, ge=0)]
    # Strict int, must be >= 0.
    # NOTE(review): 0 passes validation although zero replicas looks meaningless — confirm intent.
    num_replicas: Annotated[int, Field(strict=True, ge=0)]
    shuffle: bool
    dataset: PydanticDatasetIFType
    seed: Optional[int] = 0
    # Only True is accepted; any other value fails validation.
    drop_last: Literal[True] = True
class ResumableDistributedSamplerConfig(BaseModel):
    """Schema for a resumable distributed sampler.

    Attributes:
        dataset: Dataset to sample from.
        rank: Strict, non-negative integer rank.
        num_replicas: Strict, non-negative integer or ``None`` (the default).
        epoch: Strict, non-negative integer; defaults to 0.
        shuffle: Defaults to ``False``.
        seed: Defaults to 0.
        drop_last: Only ``True`` is accepted.
        skip_num_global_samples: Strict, non-negative integer; defaults to 0.
    """

    dataset: PydanticDatasetIFType
    rank: Annotated[int, Field(strict=True, ge=0)]
    # The default has always been None, but the annotation used to be a plain int: omitting the
    # field passed (pydantic does not validate defaults) while an explicit ``null`` was rejected.
    # Optional[...] makes both spellings valid and the annotation truthful.
    # NOTE(review): how the sampler interprets None (e.g. falling back to the world size) is
    # defined by the sampler implementation — confirm there.
    num_replicas: Optional[Annotated[int, Field(strict=True, ge=0)]] = None
    epoch: Annotated[int, Field(strict=True, ge=0)] = 0
    shuffle: Optional[bool] = False
    seed: Optional[int] = 0
    drop_last: Literal[True] = True
    skip_num_global_samples: Annotated[int, Field(strict=True, ge=0)] = 0
class MemMapDatasetConfig(BaseModel):
    """Schema for a memory-mapped dataset built from a raw data file."""

    # pydantic FilePath: the path must point to an existing file at validation time.
    raw_data_path: FilePath
    # Optional index file; if given it must exist as well (FilePath).
    index_path: Optional[FilePath] = None
    tokenizer: PydanticTokenizerIFType
    # jq pattern string (presumably selecting the text field of each raw record — confirm).
    jq_pattern: str
    # Key under which a sample is stored/returned (confirm exact use in the dataset implementation).
    sample_key: str
class PackedMemMapDatasetContinuousConfig(BaseModel):
    """Schema for the continuous packed memory-mapped dataset."""

    # Plain Path: existence is not checked at validation time.
    raw_data_path: Path
    # Strict int, must be > 1.
    sequence_length: Annotated[int, Field(strict=True, gt=1)]
    sample_key: str
    # Defaults to True (NOTE(review): exact semantics live in the dataset implementation — confirm).
    reuse_last_target: bool = Field(default=True)
class MemMapDatasetIterativeConfig(BaseModel):
    """Schema for the iterative memory-mapped dataset."""

    # Plain Path: existence is not checked at validation time.
    raw_data_path: Path
    sample_key: str
class PackedMemMapDatasetMegatronConfig(BaseModel):
    """Schema for the Megatron-style packed memory-mapped dataset."""

    # Plain Path: existence is not checked at validation time.
    raw_data_path: Path
    # Strict int, must be > 1.
    block_size: Annotated[int, Field(strict=True, gt=1)]
    sample_key: str
class CombinedDatasetConfig(BaseModel):
    """Schema for a dataset composed of several datasets."""

    # The datasets to combine; the schema itself does not enforce a minimum length.
    datasets: list[PydanticDatasetIFType]
class BatchSamplerConfig(BaseModel):
    """Schema for a batch sampler wrapping another sampler."""

    sampler: PydanticSamplerIFType
    # Strict int, must be > 0.
    batch_size: Annotated[int, Field(strict=True, gt=0)]
    # Only True is accepted; any other value fails validation.
    drop_last: Literal[True] = True
class LLMDataLoaderConfig(BaseModel):
    """Schema for the LLM data loader."""

    # Tag naming this dataloader (``RichProgressSubscriberConfig`` refers to a train dataloader by tag).
    dataloader_tag: str
    dataset: PydanticDatasetIFType
    # Typed with the generic sampler interface.
    batch_sampler: PydanticSamplerIFType
    # Optional collator; None by default.
    collator: Optional[PydanticCollatorIFType] = None
    # Strict int, must be >= 0.
    num_workers: Annotated[int, Field(strict=True, ge=0)]
    pin_memory: bool
class DummyProgressSubscriberConfig(BaseModel):
    """Schema for the dummy progress subscriber; it declares no fields."""

    pass
class RichProgressSubscriberConfig(BaseModel):
    """Schema for the rich progress subscriber."""

    # Evaluation dataloaders; defaults to a fresh empty list per instance.
    eval_dataloaders: Optional[list[PydanticLLMDataLoaderIFType]] = Field(default_factory=list)
    # Tag of the training dataloader.
    train_dataloader_tag: str
    # Strict int, must be >= 0.
    num_seen_steps: Annotated[int, Field(strict=True, ge=0)]
    # Strict int, must be > 0.
    num_target_steps: Annotated[int, Field(strict=True, gt=0)]
    # Strict int, must be >= 0.
    global_rank: Annotated[int, Field(strict=True, ge=0)]
class DummyResultSubscriberConfig(BaseModel):
    """Schema for the dummy result subscriber; it declares no fields."""

    pass
class WandBEvaluationResultSubscriberConfig(BaseModel):
    """Schema for the Weights & Biases evaluation result subscriber."""

    # Plain int here: unlike other configs in this module, it is neither strict nor range-checked.
    global_rank: int
    project: str
    experiment_id: str
    # One of WandbMode (ONLINE / OFFLINE / DISABLED).
    mode: WandbMode
    # Plain Path: existence is not checked at validation time.
    directory: Path
    # Path of the config file (presumably logged alongside the run — confirm against the subscriber).
    config_file_path: Path
class RichResultSubscriberConfig(BaseModel):
    """Schema for the rich result subscriber."""

    # Both fields are plain ints: neither strict nor range-checked.
    num_ranks: int
    global_rank: int
class GPT2MFUCalculatorConfig(BaseModel):
    """Schema for the GPT2 MFU calculator."""

    # All four integer fields are strict ints and must be > 0.
    n_layer: Annotated[int, Field(strict=True, gt=0)]
    sequence_length: Annotated[int, Field(strict=True, gt=0)]
    n_embd: Annotated[int, Field(strict=True, gt=0)]
    world_size: Annotated[int, Field(strict=True, gt=0)]
    # Accepts either an FSDP1 or an FSDP2 wrapped module.
    wrapped_model: PydanticFSDP1ModuleType | PydanticFSDP2ModuleType
def load_app_config_dict(
    config_file_path: Path,
    experiment_id: Optional[str] = None,
    additional_resolver_funs: Optional[dict[str, Callable]] = None,
) -> dict:
    """Load the application configuration from the given YAML file.

    The function defines custom resolvers for the OmegaConf library to resolve environment variables and
    Modalities-specific variables. The resolvers ``cuda_env``, ``modalities_env`` and ``node_env`` (plus any
    additional ones) are registered globally on OmegaConf with ``replace=True``, i.e. previously registered
    resolvers with the same name are overwritten.

    Args:
        config_file_path (Path): YAML config file.
        experiment_id (str, optional): The experiment_id of the current run. Defaults to None.
        additional_resolver_funs (dict[str, Callable], optional): Additional resolver functions. Defaults to None.

    Returns:
        dict: Dictionary representation of the config file.

    Raises:
        ValueError: Raised by the resolvers while interpolations are resolved (OmegaConf may wrap it in
            its own interpolation error) if a required integer environment variable is unset, or if an
            unknown ``modalities_env`` / ``node_env`` variable is referenced.
    """
    int_env_variable_names = ("LOCAL_RANK", "WORLD_SIZE", "RANK")

    def cuda_env_resolver_fun(var_name: str) -> int | str | None:
        # Known rank/world-size variables are returned as int, everything else as the raw value.
        value = os.getenv(var_name)
        if var_name not in int_env_variable_names:
            return value
        if value is None:
            # Without this check int(None) would fail with an uninformative TypeError.
            raise ValueError(f"Environment variable {var_name} is not set, but is referenced via cuda_env.")
        return int(value)

    def modalities_env_resolver_fun(var_name: str, kwargs: dict[str, Any]) -> str | Path:
        if var_name in kwargs:
            return kwargs[var_name]
        else:
            raise ValueError(f"Unknown modalities_env variable: {var_name}.")

    def node_env_resolver_fun(var_name: str) -> int | None:
        if var_name == "num_cpus":
            return os.cpu_count()
        # Fail loudly on typos instead of silently resolving to None (consistent with modalities_env).
        raise ValueError(f"Unknown node_env variable: {var_name}.")

    OmegaConf.register_new_resolver("cuda_env", cuda_env_resolver_fun, replace=True)

    # experiment_id is only resolvable when it was actually provided.
    modalities_env_kwargs = {"config_file_path": config_file_path}
    if experiment_id is not None:
        modalities_env_kwargs["experiment_id"] = experiment_id
    OmegaConf.register_new_resolver(
        "modalities_env", partial(modalities_env_resolver_fun, kwargs=modalities_env_kwargs), replace=True
    )
    OmegaConf.register_new_resolver("node_env", node_env_resolver_fun, replace=True)

    if additional_resolver_funs is not None:
        for resolver_name, resolver_fun in additional_resolver_funs.items():
            OmegaConf.register_new_resolver(resolver_name, resolver_fun, replace=True)

    cfg = OmegaConf.load(config_file_path)
    config_dict = OmegaConf.to_container(cfg, resolve=True)
    return config_dict