-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
704 lines (614 loc) · 26.9 KB
/
train.py
File metadata and controls
704 lines (614 loc) · 26.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os
import time
from datetime import timedelta
from typing import Any, Generator, Iterable, Optional
import torch
from torch.distributed.elastic.multiprocessing.errors import record
import torchtitan.protocols.train_spec as train_spec_module
from torchtitan.components.checkpoint import CheckpointManager
from torchtitan.components.dataloader import DataloaderExhaustedError
from torchtitan.components.ft import FTManager, maybe_semi_sync_training
from torchtitan.components.loss import rescale_accumulated_loss
from torchtitan.components.metrics import (
build_metrics_processor,
ensure_pp_loss_visible,
)
from torchtitan.config import ConfigManager, JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims, utils as dist_utils
from torchtitan.protocols.model_converter import build_model_converters
from torchtitan.tools import utils
from torchtitan.tools.logging import init_logger, logger
from torchtitan.tools.profiling import (
maybe_enable_memory_snapshot,
maybe_enable_profiling,
)
class Trainer(torch.distributed.checkpoint.stateful.Stateful):
    """End-to-end training driver.

    ``__init__`` sets up the distributed environment and builds every training
    component (tokenizer, dataloader, model parts, loss, optimizers, LR
    schedulers, metrics processor, checkpointer and, when enabled, a
    validator) from ``job_config`` and the registered ``TrainSpec``.
    ``train`` runs the training loop.

    As a ``Stateful``, the trainer itself is registered with the checkpointer
    under ``"train_state"`` and contributes ``step`` and ``ntokens_seen``.
    """

    # core configs
    job_config: JobConfig
    parallel_dims: ParallelDims
    train_spec: train_spec_module.TrainSpec

    # swappable training components in TrainSpec
    tokenizer: train_spec_module.BaseTokenizer | None
    dataloader: train_spec_module.BaseDataLoader
    model_parts: list[torch.nn.Module]
    loss_fn: train_spec_module.LossFunction
    optimizers: train_spec_module.OptimizersContainer
    lr_schedulers: train_spec_module.LRSchedulersContainer
    validator: train_spec_module.BaseValidator
    metrics_processor: train_spec_module.MetricsProcessor
    model_args: train_spec_module.BaseModelArgs

    # non-swappable training components
    checkpointer: CheckpointManager
    ft_manager: FTManager

    # runtime utilities
    device: torch.device
    gc_handler: utils.GarbageCollection
    train_context: Generator[None, None, None]
    gradient_accumulation_steps: int
    pp_has_first_stage: bool
    pp_has_last_stage: bool

    # additional training states
    step: int
    ntokens_seen: int

    # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
    @record
    def __init__(self, job_config: JobConfig):
        """Build all training components described by ``job_config``.

        Reads ``LOCAL_RANK`` and ``WORLD_SIZE`` (and, for logging only,
        ``LOCAL_WORLD_SIZE``) from the environment, so the process is expected
        to be started by a distributed launcher.

        Args:
            job_config: Parsed job configuration.

        Raises:
            RuntimeError: If pipeline parallelism is enabled but the train
                spec provides no ``pipelining_fn``.
        """
        torch._C._log_api_usage_once("torchtitan.train")

        self.job_config = job_config

        logger.info(f"Starting job: {job_config.job.description}")

        if job_config.experimental.custom_import:
            importlib.import_module(job_config.experimental.custom_import)

        device_module, device_type = utils.device_module, utils.device_type
        self.device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
        # Device has to be set before creating TorchFT manager.
        device_module.set_device(self.device)

        # init distributed and build meshes
        dist_utils.init_distributed(
            job_config.comm,
            enable_cpu_backend=job_config.training.enable_cpu_offload,
            base_folder=job_config.job.dump_folder,
        )

        job_config.maybe_log()

        world_size = int(os.environ["WORLD_SIZE"])
        parallelism_config = job_config.parallelism
        self.parallel_dims = parallel_dims = self._create_parallel_dims(
            parallelism_config, world_size
        )

        world_mesh = parallel_dims.world_mesh
        if parallel_dims.dp_enabled:
            dp_mesh = world_mesh["dp"]
            dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
        else:
            dp_degree, dp_rank = 1, 0

        self.ft_manager = FTManager(job_config.fault_tolerance)
        # the fault-tolerance manager may override the data-parallel degree/rank
        dp_degree, dp_rank = self.ft_manager.get_dp_info(dp_degree, dp_rank)

        # take control of garbage collection to avoid stragglers
        self.gc_handler = utils.GarbageCollection(
            gc_freq=job_config.training.gc_freq, debug=job_config.training.gc_debug
        )

        # Set random seed, and maybe enable deterministic mode
        # (mainly for debugging, expect perf loss).
        dist_utils.set_determinism(
            world_mesh,
            self.device,
            job_config.training.seed,
            job_config.training.deterministic,
        )
        self.train_spec = train_spec_module.get_train_spec(job_config.model.name)

        # build tokenizer and dataloader
        self.tokenizer = (
            self.train_spec.build_tokenizer_fn(job_config)
            if self.train_spec.build_tokenizer_fn is not None
            else None
        )

        self.dataloader = self.train_spec.build_dataloader_fn(
            dp_world_size=dp_degree,
            dp_rank=dp_rank,
            tokenizer=self.tokenizer,
            job_config=job_config,
        )

        # build model (using meta init)
        model_args = self.train_spec.model_args[job_config.model.flavor]
        # set the model args from training job configs
        model_args.update_from_config(job_config)
        self.model_args = model_args

        logger.info(
            f"Building {job_config.model.name} {job_config.model.flavor} with {model_args}"
        )
        with (
            torch.device("meta"),
            utils.set_default_dtype(TORCH_DTYPE_MAP[job_config.training.dtype]),
        ):
            model = self.train_spec.model_cls(model_args)

        # Build the collection of model converters. No-op if `model.converters` empty
        model_converters = build_model_converters(job_config, parallel_dims)
        model_converters.convert(model)

        # metrics logging
        build_metrics_processor_fn = (
            build_metrics_processor
            if self.train_spec.build_metrics_processor_fn is None
            else self.train_spec.build_metrics_processor_fn
        )
        self.metrics_processor = build_metrics_processor_fn(
            job_config, parallel_dims, model_args
        )
        color = self.metrics_processor.color

        # calculate model size and flops per token
        (
            model_param_count,
            self.metrics_processor.num_flops_per_token,
        ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len)

        logger.info(
            f"{color.blue}Model {job_config.model.name} {job_config.model.flavor} "
            f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
        )

        # move sharded model to CPU/GPU and initialize weights via DTensor
        if job_config.checkpoint.create_seed_checkpoint:
            init_device = "cpu"
            buffer_device = None
        elif job_config.training.enable_cpu_offload:
            init_device = "cpu"
            buffer_device = device_type
        else:
            init_device = device_type
            buffer_device = None

        self.loss_fn = self.train_spec.build_loss_fn(
            job_config, parallel_dims=parallel_dims, ft_manager=self.ft_manager
        )

        # verify batch sizes
        global_batch_size = job_config.training.global_batch_size
        if global_batch_size < 0:
            # This global batch size results in 1 gradient accumulation
            # step.
            global_batch_size = job_config.training.local_batch_size * dp_degree
        assert global_batch_size > 0
        assert (
            global_batch_size % (job_config.training.local_batch_size * dp_degree) == 0
        ), (
            f"global batch size must be multiple of local batch size times "
            f"data-parallel degree ({global_batch_size} "
            f"% ({job_config.training.local_batch_size} * {dp_degree}) != 0)"
        )

        # calculate gradient accumulation steps
        self.gradient_accumulation_steps = global_batch_size // (
            job_config.training.local_batch_size * dp_degree
        )
        assert self.gradient_accumulation_steps > 0
        # wrap the loss so it is rescaled for the number of accumulation steps
        self.loss_fn = rescale_accumulated_loss(
            self.loss_fn, self.gradient_accumulation_steps
        )

        # apply parallelisms and initialization
        if parallel_dims.pp_enabled:
            if not self.train_spec.pipelining_fn:
                raise RuntimeError(
                    f"Pipeline Parallel is enabled but {job_config.model.name} "
                    f"does not support pipelining"
                )

            # apply both PT-D Pipeline Parallel and SPMD-style PT-D techniques
            (
                self.pp_schedule,
                self.model_parts,
                self.pp_has_first_stage,
                self.pp_has_last_stage,
            ) = self.train_spec.pipelining_fn(
                model,
                parallel_dims,
                job_config,
                self.device,
                model_args,
                self.train_spec.parallelize_fn,
                self.loss_fn,
            )
            # when PP is enabled, `model` obj is no longer used after this point,
            # model_parts is used instead
            del model

            for m in self.model_parts:
                m.to_empty(device=init_device)
                with torch.no_grad():
                    m.init_weights(buffer_device=buffer_device)
                m.train()

            # confirm that user will be able to view loss metrics on the console
            ensure_pp_loss_visible(parallel_dims, job_config, color)
        else:
            # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
            model = self.train_spec.parallelize_fn(model, parallel_dims, job_config)

            model.to_empty(device=init_device)
            with torch.no_grad():
                model.init_weights(buffer_device=buffer_device)
            model.train()

            self.model_parts = [model]

        self.ft_manager.maybe_set_all_reduce_hook(self.model_parts)

        # initialize device memory monitor and get peak flops for MFU calculation
        device_memory_monitor = self.metrics_processor.device_memory_monitor
        gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name)
        # Derive the node count from the launcher's per-node process count
        # instead of assuming a fixed number of devices per node. If
        # LOCAL_WORLD_SIZE is not set, treat the whole job as a single node.
        local_world_size = max(int(os.environ.get("LOCAL_WORLD_SIZE", world_size)), 1)
        logger.info(f"WORLD_SIZE: {world_size}")
        logger.info(f"NUM NODES: {world_size // local_world_size}")
        logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}")
        device_mem_stats = device_memory_monitor.get_peak_stats()
        logger.info(
            f"{device_type.upper()} memory usage for model: "
            f"{device_mem_stats.max_reserved_gib:.2f}GiB"
            f"({device_mem_stats.max_reserved_pct:.2f}%)"
        )

        # build optimizer after applying parallelisms to the model
        self.optimizers = self.train_spec.build_optimizers_fn(
            self.model_parts, job_config.optimizer, parallel_dims, self.ft_manager
        )
        self.lr_schedulers = self.train_spec.build_lr_schedulers_fn(
            self.optimizers, job_config.lr_scheduler, job_config.training.steps
        )
        # Post optimizer step model converters hook.
        # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
        # where it issues a single all-reduce for all parameters at once for better performance
        self.optimizers.register_step_post_hook(
            lambda *args, **kwargs: model_converters.post_optimizer_hook(
                self.model_parts
            )
        )
        self.metrics_processor.optimizers = self.optimizers
        self.metrics_processor.model_parts = self.model_parts

        # Initialize trainer states that will be saved in checkpoint.
        # These attributes must be initialized before checkpoint loading.
        self.step = 0
        self.ntokens_seen = 0

        self.checkpointer = CheckpointManager(
            dataloader=self.dataloader,
            model_parts=self.model_parts,
            optimizers=self.optimizers,
            lr_schedulers=self.lr_schedulers,
            states={"train_state": self},
            checkpoint_config=job_config.checkpoint,
            sd_adapter=(
                self.train_spec.state_dict_adapter(
                    model_args, job_config.model.hf_assets_path
                )
                if self.train_spec.state_dict_adapter
                else None
            ),
            base_folder=job_config.job.dump_folder,
            ft_manager=self.ft_manager,
        )

        loss_parallel_enabled = (
            parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel
        )
        self.train_context = dist_utils.get_train_context(
            loss_parallel_enabled,
            parallelism_config.enable_compiled_autograd,
        )
        self.maybe_enable_amp = dist_utils.maybe_enable_amp(
            parallel_dims,
            job_config.training.mixed_precision_param,
            device_type,
        )

        # Build validator if validation is configured
        if job_config.validation.enable:
            assert self.train_spec.build_validator_fn is not None

            # pipeline objects only exist when PP is enabled
            pp_schedule, pp_has_first_stage, pp_has_last_stage = (
                (
                    self.pp_schedule,
                    self.pp_has_first_stage,
                    self.pp_has_last_stage,
                )
                if parallel_dims.pp_enabled
                else (None, None, None)
            )

            self.validator = self.train_spec.build_validator_fn(
                job_config=job_config,
                dp_world_size=dp_degree,
                dp_rank=dp_rank,
                tokenizer=self.tokenizer,
                parallel_dims=parallel_dims,
                loss_fn=self.loss_fn,
                validation_context=self.train_context,
                maybe_enable_amp=self.maybe_enable_amp,
                metrics_processor=self.metrics_processor,
                pp_schedule=pp_schedule,
                pp_has_first_stage=pp_has_first_stage,
                pp_has_last_stage=pp_has_last_stage,
            )

        logger.info(
            "Trainer is initialized with "
            f"local batch size {job_config.training.local_batch_size}, "
            f"global batch size {global_batch_size}, "
            f"gradient accumulation steps {self.gradient_accumulation_steps}, "
            f"sequence length {job_config.training.seq_len}, "
            f"total steps {job_config.training.steps} "
            f"(warmup {job_config.lr_scheduler.warmup_steps})"
        )

    def _create_parallel_dims(self, parallelism_config, world_size) -> ParallelDims:
        """Build ``ParallelDims`` from the parallelism section of the job config."""
        return ParallelDims(
            dp_shard=parallelism_config.data_parallel_shard_degree,
            dp_replicate=parallelism_config.data_parallel_replicate_degree,
            cp=parallelism_config.context_parallel_degree,
            tp=parallelism_config.tensor_parallel_degree,
            pp=parallelism_config.pipeline_parallel_degree,
            ep=parallelism_config.expert_parallel_degree,
            etp=parallelism_config.expert_tensor_parallel_degree,
            world_size=world_size,
        )

    def batch_generator(
        self, data_iterable: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
    ) -> Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]:
        """Returns an iterator that processes batches from the data iterator.

        Each yielded ``(input_dict, labels)`` pair has its tensors moved to the
        current device type. As a side effect, ``self.ntokens_seen`` and the
        metrics processor's token count and data-loading times are updated.

        Raises:
            DataloaderExhaustedError: When ``data_iterable`` runs out.
        """
        device_type = utils.device_type
        data_iterator = iter(data_iterable)

        while True:
            data_load_start = time.perf_counter()
            try:
                batch = next(data_iterator)
            except StopIteration as ex:
                # If data runs out during gradient accumulation, that
                # entire step will not be executed.
                raise DataloaderExhaustedError() from ex
            input_dict, labels = batch
            ntokens_batch = labels.numel()
            self.ntokens_seen += ntokens_batch
            self.metrics_processor.ntokens_since_last_log += ntokens_batch
            self.metrics_processor.data_loading_times.append(
                time.perf_counter() - data_load_start
            )

            # Move tensors to the appropriate device
            for k, v in input_dict.items():
                if isinstance(v, torch.Tensor):
                    input_dict[k] = v.to(device_type)
            labels = labels.to(device_type)

            yield input_dict, labels

    def forward_backward_step(
        self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
    ) -> torch.Tensor:
        """Run forward and backward for one microbatch and return its loss.

        ``input_dict["input"]`` is the primary model input. All other entries
        are passed to the model as extra keyword inputs.

        With pipeline parallelism, forward/backward happen inside
        ``pp_schedule.step``. The returned loss is the sum of per-microbatch
        losses on the last stage, and a ``[-1.0]`` placeholder tensor on the
        other stages. Without PP, the single model part runs under
        ``maybe_enable_amp`` and ``loss.backward()`` is called here.
        """
        model_parts = self.model_parts
        parallel_dims = self.parallel_dims

        inputs = input_dict["input"]
        extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}
        # For arguments, like attention_masks, we have to put them in a separate
        # dict as extra_inputs are not forwarded to other stages in PP, but
        # extra_kwargs are.
        extra_kwargs = {}

        if getattr(self.model_args, "use_flex_attn", False):
            extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks(
                input_batch=inputs,
                tokenizer=self.tokenizer,
                extra_inputs=extra_inputs,
            )

        # apply context parallelism if cp is enabled
        # ensure CP handles the separate freqs_cis buffer for each pp stage
        optional_context_parallel_ctx = None
        if parallel_dims.cp_enabled:
            optional_context_parallel_ctx = dist_utils.create_context_parallel_ctx(
                cp_mesh=parallel_dims.world_mesh["cp"],
                cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
                cp_seq_dims=[1, 1] + [0 for _ in model_parts],
                cp_no_restore_buffers={inputs, labels},
                cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
            )

        if parallel_dims.pp_enabled:
            # Pipeline Parallel forward / backward inside step() call
            with self.train_context(optional_context_parallel_ctx):
                # only the last stage computes losses against the labels
                targets, losses = (
                    (labels, []) if self.pp_has_last_stage else (None, None)
                )
                if self.pp_has_first_stage:
                    self.pp_schedule.step(
                        inputs,
                        **extra_inputs,
                        **extra_kwargs,
                        target=targets,
                        losses=losses,
                        return_outputs=False,
                    )
                else:
                    self.pp_schedule.step(
                        **extra_kwargs,
                        target=targets,
                        losses=losses,
                        return_outputs=False,
                    )

            # accumulate losses across pipeline microbatches
            # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
            loss = (
                # using sum instead of mean because we already rescale the
                # loss_fn down by a factor of n_microbatches in
                # torchtitan/distributed/pipeline_parallel.py
                torch.sum(torch.stack(losses)).to(self.device)
                if self.pp_has_last_stage
                else torch.tensor([-1.0], device=self.device)
            )
        else:
            # Non-PP forward / backward
            with self.train_context(optional_context_parallel_ctx):
                assert len(model_parts) == 1
                with self.maybe_enable_amp:
                    pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs)
                    loss = self.loss_fn(pred, labels)
                # need to free pred before bwd to avoid peaking memory
                del pred
                loss.backward()

        return loss

    def train_step(
        self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
    ):
        """Run one optimizer step over ``gradient_accumulation_steps`` microbatches.

        Zeroes gradients, runs forward/backward per microbatch, clips the
        gradient norm, then steps the optimizers and LR schedulers. Metrics
        are reduced and logged only when
        ``metrics_processor.should_log(self.step)`` is true.

        Propagates whatever ``next(data_iterator)`` raises. This is
        ``DataloaderExhaustedError`` when the iterator comes from
        ``batch_generator``.
        """
        self.optimizers.zero_grad()
        # Save the current step learning rate for logging
        lr = self.lr_schedulers.schedulers[0].get_last_lr()[0]

        # Keep these variables local to shorten the code as these are
        # the major variables that are used in the training loop.
        parallel_dims = self.parallel_dims

        accumulated_losses = []
        # If data runs out during gradient accumulation, that
        # entire step will not be executed.
        for _microbatch in range(self.gradient_accumulation_steps):
            input_dict, labels = next(data_iterator)
            loss = self.forward_backward_step(input_dict, labels)
            accumulated_losses.append(loss.detach())

        grad_norm = dist_utils.clip_grad_norm_(
            [p for m in self.model_parts for p in m.parameters()],
            self.job_config.training.max_norm,
            foreach=True,
            pp_mesh=(
                parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None
            ),
            ep_enabled=parallel_dims.ep_enabled,
        )
        self.checkpointer.maybe_wait_for_staging()
        self.optimizers.step()
        self.lr_schedulers.step()

        # Reduce the data collected over gradient accumulation steps.
        loss = torch.sum(torch.stack(accumulated_losses))

        # log metrics
        if not self.metrics_processor.should_log(self.step):
            return

        if parallel_dims.dp_cp_enabled:
            loss = loss.detach()
            ft_pg = self.ft_manager.loss_sync_pg
            global_avg_loss, global_max_loss, global_ntokens_seen = (
                dist_utils.dist_mean(loss, parallel_dims.world_mesh["dp_cp"], ft_pg),
                dist_utils.dist_max(loss, parallel_dims.world_mesh["dp_cp"], ft_pg),
                # we are getting the global tokens seen from a dist sum
                dist_utils.dist_sum(
                    torch.tensor(
                        self.ntokens_seen, dtype=torch.int64, device=self.device
                    ),
                    parallel_dims.world_mesh["dp_cp"],
                    ft_pg,
                ),
            )
        else:
            global_avg_loss = global_max_loss = loss.detach().item()
            global_ntokens_seen = self.ntokens_seen

        extra_metrics = {
            "n_tokens_seen": global_ntokens_seen,
            "lr": lr,
        }
        self.metrics_processor.log(
            self.step,
            global_avg_loss,
            global_max_loss,
            grad_norm.item(),
            extra_metrics=extra_metrics,
        )

    @record
    def train(self):
        """Load any checkpoint, then run training steps until done.

        Stops when ``should_continue_training`` returns false or the
        dataloader is exhausted. After each step it saves a checkpoint (via
        ``checkpointer.save``), runs validation when configured, and advances
        the profilers. After the first step it lowers the process-group
        timeouts to ``comm.train_timeout_seconds``.
        """
        job_config = self.job_config

        self.checkpointer.load(step=job_config.checkpoint.load_step)
        logger.info(f"Training starts at step {self.step + 1}")

        leaf_folder = (
            ""
            if not self.ft_manager.enabled
            else f"replica_{self.ft_manager.replica_id}"
        )
        with (
            maybe_enable_profiling(
                job_config.profiling,
                global_step=self.step,
                base_folder=job_config.job.dump_folder,
                leaf_folder=leaf_folder,
            ) as torch_profiler,
            maybe_enable_memory_snapshot(
                job_config.profiling,
                global_step=self.step,
                base_folder=job_config.job.dump_folder,
                leaf_folder=leaf_folder,
            ) as memory_profiler,
            maybe_semi_sync_training(
                job_config.fault_tolerance,
                ft_manager=self.ft_manager,
                model=self.model_parts[0],
                n_layers=(
                    self.model_args.n_layers
                    if hasattr(self.model_args, "n_layers")
                    else 0
                ),
                optimizer=self.optimizers,
                fragment_fn=(
                    self.train_spec.fragment_fn
                    if hasattr(self.train_spec, "fragment_fn")
                    else None
                ),
            ),
        ):
            data_iterator = self.batch_generator(self.dataloader)
            while self.should_continue_training():
                self.step += 1
                self.gc_handler.run(self.step)
                try:
                    self.train_step(data_iterator)
                except DataloaderExhaustedError:
                    logger.warning("Ran out of data; last step was canceled.")
                    break

                self.checkpointer.save(
                    self.step, last_step=(self.step == job_config.training.steps)
                )

                # Run validation if validator is available
                if (
                    self.job_config.validation.enable
                    and self.validator.should_validate(self.step)
                ):
                    # validation loss should not be scaled by gradient accumulation
                    with self.loss_fn.no_rescale():
                        self.validator.validate(self.model_parts, self.step)

                # signal the profiler that the next profiling step has started
                if torch_profiler:
                    torch_profiler.step()
                if memory_profiler:
                    memory_profiler.step()

                # reduce timeout after first train step for faster signal
                # (assuming lazy init and compilation are finished)
                if self.step == 1:
                    dist_utils.set_pg_timeouts(
                        timeout=timedelta(
                            seconds=job_config.comm.train_timeout_seconds
                        ),
                        world_mesh=self.parallel_dims.world_mesh,
                    )

        if torch.distributed.get_rank() == 0:
            logger.info("Sleeping 2 seconds for other ranks to complete")
            time.sleep(2)

        logger.info("Training completed")

    def should_continue_training(self) -> bool:
        """Return True while fewer than ``training.steps`` steps have run."""
        return self.step < self.job_config.training.steps

    def state_dict(self) -> dict[str, Any]:
        """Return the trainer state saved in checkpoints."""
        return {"step": self.step, "ntokens_seen": self.ntokens_seen}

    def load_state_dict(self, state_dict: dict[str, Any]):
        """Restore ``step`` and ``ntokens_seen`` from a checkpoint state dict."""
        self.step = state_dict["step"]
        self.ntokens_seen = state_dict["ntokens_seen"]

    def close(self) -> None:
        """Close the checkpointer and metrics processor, if they were created.

        Uses ``getattr`` so that closing a partially-initialized trainer does
        not raise ``AttributeError``.
        """
        checkpointer = getattr(self, "checkpointer", None)
        if checkpointer:
            checkpointer.close()
        metrics_processor = getattr(self, "metrics_processor", None)
        if metrics_processor:
            metrics_processor.close()
if __name__ == "__main__":
    # Script entry point: parse the job config, then either create a seed
    # checkpoint (single device, no training) or run training.
    init_logger()
    config_manager = ConfigManager()
    config = config_manager.parse_args()
    trainer: Optional[Trainer] = None

    try:
        trainer = Trainer(config)

        if config.checkpoint.create_seed_checkpoint:
            # Explicit checks rather than `assert`, so these preconditions
            # are still enforced when Python runs with -O (which strips asserts).
            if int(os.environ["WORLD_SIZE"]) != 1:
                raise ValueError(
                    "Must create seed checkpoint using a single device, to disable sharding."
                )
            if not config.checkpoint.enable:
                raise ValueError(
                    "Must enable checkpointing when creating a seed checkpoint."
                )
            trainer.checkpointer.save(curr_step=0, last_step=True)
            logger.info("Created seed checkpoint")
        else:
            trainer.train()
    except Exception:
        # release trainer resources before propagating the failure
        if trainer is not None:
            trainer.close()
        raise
    else:
        trainer.close()
        torch.distributed.destroy_process_group()
        logger.info("Process group destroyed")