
Commit 146bd1d

feat: checkpoint save & load

Parent: d278062

78 files changed

Lines changed: 2554 additions & 840 deletions


CMakeLists.txt
Lines changed: 3 additions & 0 deletions

@@ -48,6 +48,9 @@ endif()
 # Framework core sources (*.cc), excluding cpu kernels (they are built separately)
 file(GLOB_RECURSE SRC ${PROJECT_SOURCE_DIR}/infini_train/src/*.cc)
 list(FILTER SRC EXCLUDE REGEX ".*kernels/cpu/.*")
+if(NOT USE_NCCL)
+    list(FILTER SRC EXCLUDE REGEX ".*infini_train/src/core/ccl/cuda/.*")
+endif()
 
 # CPU kernels (*.cc)
 file(GLOB_RECURSE CPU_KERNELS ${PROJECT_SOURCE_DIR}/infini_train/src/kernels/cpu/*.cc)
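
With this guard, the CUDA CCL sources under infini_train/src/core/ccl/cuda/ are compiled only when NCCL is enabled. A minimal configure sketch; the cmake invocations are illustrative and assume USE_NCCL is the repo's existing cache option:

    cmake -B build -DUSE_NCCL=OFF   # CCL CUDA sources filtered out of SRC
    cmake -B build -DUSE_NCCL=ON    # CCL CUDA sources built into the framework core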

example/gpt2/main.cc

Lines changed: 137 additions & 11 deletions
@@ -1,16 +1,20 @@
 #include <chrono>
 #include <cstdlib>
+#include <filesystem>
 #include <format>
+#include <limits>
 #include <memory>
 #include <optional>
+#include <algorithm>
 #include <unordered_map>
 #include <unordered_set>
 
 #include "gflags/gflags.h"
 #include "glog/logging.h"
 
 #include "infini_train/include/autocast.h"
-#include "infini_train/include/core/device_guard.h"
+#include "infini_train/include/core/runtime/device_guard.h"
+#include "infini_train/include/checkpoint.h"
 #include "infini_train/include/dataloader.h"
 #include "infini_train/include/device.h"
 #include "infini_train/include/nn/modules/loss.h"
@@ -74,6 +78,14 @@ DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage.");
 
 // precision
 DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");
+DEFINE_uint32(save_steps, 0, "save checkpoint every N steps; 0 disables saving");
+DEFINE_string(resume_from, "", "checkpoint directory to resume from");
+DEFINE_string(checkpoint_dir, "./checkpoints", "root directory used to store checkpoints");
+DEFINE_uint32(max_checkpoint_keep, 3, "max number of checkpoint steps to keep");
+DEFINE_bool(save_optimizer_state, true, "whether optimizer state is persisted in checkpoints");
+DEFINE_string(checkpoint_format, "bin", "checkpoint format: bin|pth");
+DEFINE_bool(use_llmc_checkpoint_io, false,
+            "whether to use GPT2 LLMC model.bin callback for checkpoint save/load when format=bin");
 // precision check
 DEFINE_string(
     precision_check, "",
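
Taken together, the new flags cover periodic saving, retention, format selection, and resume. A hypothetical invocation; the binary name and paths below are illustrative, not part of this commit:

    ./gpt2 --input_bin=data/train.bin \
        --save_steps=100 --max_checkpoint_keep=3 \
        --checkpoint_dir=./checkpoints --checkpoint_format=bin \
        --save_optimizer_state=true \
        --resume_from=./checkpoints/checkpoint_step_000100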
@@ -140,24 +152,25 @@ void Train(const nn::parallel::Rank &rank) {
 
     if (rank.IsParallel()) {
         device = Device(Device::DeviceType::kCUDA, rank.thread_rank());
+        auto *pg_factory = ProcessGroupFactory::Instance(device.type());
 
         if (ddp_world_size > 1) {
-            ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()),
-                                                                  GetDataParallelGroupRanks(rank.GlobalRank()));
+            ddp_pg = pg_factory->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()),
+                                             GetDataParallelGroupRanks(rank.GlobalRank()));
             ddp_rank = ddp_pg->GetGroupRank(rank.GlobalRank());
         }
 
         if (tp_world_size > 1) {
-            tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()),
-                                                                 GetTensorParallelGroupRanks(rank.GlobalRank()));
+            tp_pg = pg_factory->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()),
+                                            GetTensorParallelGroupRanks(rank.GlobalRank()));
             tp_rank = tp_pg->GetGroupRank(rank.GlobalRank());
             // NOTE(zbl): Reserved for VocabParallelEmbedding
             nn::parallel::tp_rank = tp_rank;
         }
 
         if (pp_world_size > 1) {
-            pp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetPipelineParallelProcessGroupName(rank.GlobalRank()),
-                                                                 GetPipelineParallelGroupRanks(rank.GlobalRank()));
+            pp_pg = pg_factory->GetOrCreate(GetPipelineParallelProcessGroupName(rank.GlobalRank()),
+                                            GetPipelineParallelGroupRanks(rank.GlobalRank()));
            pp_rank = pp_pg->GetGroupRank(rank.GlobalRank());
 
             nn::parallel::pp_rank = pp_rank;
@@ -187,6 +200,8 @@ void Train(const nn::parallel::Rank &rank) {
     } else {
         model = GPT2::FromPretrained(kStrToModelType.at(FLAGS_model));
     }
+    auto llmc_model = std::dynamic_pointer_cast<GPT2>(model);
+    CHECK(llmc_model != nullptr) << "Failed to cast model to GPT2 for LLMC checkpoint I/O.";
 
     model->To(device);
 
@@ -219,8 +234,8 @@ void Train(const nn::parallel::Rank &rank) {
                 = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
             auto *mutable_chunks = dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks();
             for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
-                (*mutable_chunks)[chunk_id] = std::make_shared<DistributedDataParallel>(mutable_chunks->at(chunk_id),
-                                                                                        rank.thread_rank(), ddp_config);
+                (*mutable_chunks)[chunk_id]
+                    = std::make_shared<DistributedDataParallel>(mutable_chunks->at(chunk_id), rank, ddp_config);
             }
         }
     } else if (ddp_world_size > 1) {
@@ -229,7 +244,7 @@ void Train(const nn::parallel::Rank &rank) {
         // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
         // are created during the conversion.
         auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
-        model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank(), ddp_config);
+        model = std::make_shared<DistributedDataParallel>(model, rank, ddp_config);
     }
 
     DistributedDataLoader train_loader(std::make_shared<TinyShakespeareDataset>(FLAGS_input_bin, FLAGS_sequence_length),
@@ -268,6 +283,7 @@ void Train(const nn::parallel::Rank &rank) {
     }
 
     auto train_iter = train_loader.begin();
+    size_t saved_data_batch_idx = train_iter.BatchIndex();
     std::shared_ptr<nn::Module> loss_fn
         = (tp_world_size > 1) ? std::static_pointer_cast<nn::Module>(
               std::make_shared<VocabParallelCrossEntropyLoss>(model_config.original_vocab_size))
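
The new saved_data_batch_idx mirrors the dataloader position so a resume can seek back to the same batch. A minimal sketch of the round-trip, using only the iterator API this diff exercises (BatchIndex on the iterator, IteratorAtBatchIndex on the loader):

    // Save side: after consuming a batch, record the loader position
    // (persisted later as TrainerState::data_batch_idx).
    ++train_iter;
    saved_data_batch_idx = train_iter.BatchIndex();

    // Resume side: rebuild an iterator positioned at the recorded batch.
    train_iter = train_loader.IteratorAtBatchIndex(saved_data_batch_idx);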
@@ -277,9 +293,100 @@ void Train(const nn::parallel::Rank &rank) {
 
     auto impl = core::GetDeviceGuardImpl(device.type());
 
+    int start_step = 0;
+    float best_loss = std::numeric_limits<float>::infinity();
+    if (!FLAGS_resume_from.empty()) {
+        std::filesystem::path resume_dir = FLAGS_resume_from;
+        if (rank.IsParallel()) {
+            const auto rank_dir = resume_dir / std::format("rank_{:06d}", rank.GlobalRank());
+            if (std::filesystem::exists(rank_dir)) {
+                resume_dir = rank_dir;
+            }
+        }
+
+        TrainerState state;
+        CheckpointLoadOptions load_options;
+        load_options.load_optimizer_state = true;
+        if (FLAGS_use_llmc_checkpoint_io) {
+            load_options.model_bin_loader = [](nn::Module *target_model, const std::filesystem::path &model_path) {
+                auto loaded_model = GPT2::FromLLMC(model_path.string());
+                target_model->LoadStateDict(loaded_model->StateDict());
+            };
+        }
+        Checkpoint::Load(resume_dir, model.get(), optimizer.get(), &state, load_options);
+        start_step = static_cast<int>(state.global_step);
+        best_loss = state.best_loss;
+        if (state.data_batch_stride != static_cast<int64_t>(ddp_world_size) && rank.IsMainRank()) {
+            LOG(WARNING) << std::format("Checkpoint data_batch_stride {} mismatches current ddp_world_size {}. "
+                                        "Proceeding with recorded data_batch_idx {}.",
+                                        state.data_batch_stride, ddp_world_size, state.data_batch_idx);
+        }
+        saved_data_batch_idx = static_cast<size_t>(std::max<int64_t>(state.data_batch_idx, 0));
+        train_iter = train_loader.IteratorAtBatchIndex(saved_data_batch_idx);
+        if (rank.IsMainRank()) {
+            LOG(INFO) << std::format(
+                "Resume training from step {} with best_loss {:.6f}, last_lr {:.3e}, data_batch_idx {}",
+                state.global_step, state.best_loss, state.last_lr, state.data_batch_idx);
+            LOG(INFO) << std::format("Checkpoint model I/O mode during resume: {}",
+                                     FLAGS_use_llmc_checkpoint_io ? "llmc-callback" : "native-state-dict");
+        }
+    }
+
     LOG(INFO) << "start training";
 
-    for (int step = 0; step < FLAGS_num_iteration + 1; ++step) {
+    auto save_checkpoint = [&](const std::filesystem::path &save_dir, int64_t global_step, bool prune_step_checkpoints) {
+        const auto ckpt_start = std::chrono::high_resolution_clock::now();
+
+        TrainerState state;
+        state.global_step = global_step;
+        state.data_batch_idx = saved_data_batch_idx;
+        state.data_batch_stride = ddp_world_size;
+        state.best_loss = best_loss;
+        state.last_lr = FLAGS_learning_rate;
+        state.optimizer_type = "SGD";
+        state.checkpoint_format = FLAGS_checkpoint_format;
+        state.ddp_size = ddp_world_size;
+        state.tp_size = tp_world_size;
+        state.sp_size = sp_world_size;
+        state.pp_size = pp_world_size;
+
+        CheckpointOptions options;
+        options.format = FLAGS_checkpoint_format;
+        options.save_optimizer_state = FLAGS_save_optimizer_state;
+        if (FLAGS_use_llmc_checkpoint_io) {
+            options.model_bin_writer = [&](const nn::Module &, const std::filesystem::path &model_path) {
+                llmc_model->SaveAsLLMC(model_path.string());
+            };
+        }
+        Checkpoint::Save(save_dir, *model, *optimizer, state, options);
+
+        const auto ckpt_end = std::chrono::high_resolution_clock::now();
+        const double ckpt_ms = std::chrono::duration<double, std::milli>(ckpt_end - ckpt_start).count();
+
+        if (rank.IsMainRank()) {
+            LOG(INFO) << std::format("Checkpoint saved at: {} ({:.2f} ms)", save_dir.string(), ckpt_ms);
+
+            if (prune_step_checkpoints) {
+                std::vector<std::filesystem::path> ckpts;
+                const auto root = std::filesystem::path(FLAGS_checkpoint_dir);
+                if (std::filesystem::exists(root)) {
+                    for (const auto &entry : std::filesystem::directory_iterator(root)) {
+                        if (entry.is_directory()
+                            && entry.path().filename().string().starts_with("checkpoint_step_")) {
+                            ckpts.push_back(entry.path());
+                        }
+                    }
+                    std::sort(ckpts.begin(), ckpts.end());
+                    while (ckpts.size() > FLAGS_max_checkpoint_keep) {
+                        std::filesystem::remove_all(ckpts.front());
+                        ckpts.erase(ckpts.begin());
+                    }
+                }
+            }
+        }
+    };
+
+    for (int step = start_step; step < FLAGS_num_iteration + 1; ++step) {
         // Reset precision check counters at start of each iteration for file overwrite
         utils::PrecisionChecker::ResetCounters();
 
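TrainerState, CheckpointOptions, and CheckpointLoadOptions come from the new infini_train/include/checkpoint.h header included above. The sketch below reconstructs TrainerState purely from the fields this hunk reads and writes; it is an assumption for illustration, not the header's actual declaration:

    // Hypothetical shape of TrainerState, inferred from usage in this diff;
    // see infini_train/include/checkpoint.h for the authoritative definition.
    struct TrainerState {
        int64_t global_step = 0;        // step training resumes from
        int64_t data_batch_idx = 0;     // dataloader position to restore
        int64_t data_batch_stride = 1;  // ddp_world_size recorded at save time
        float best_loss = std::numeric_limits<float>::infinity();
        float last_lr = 0.0f;           // logged on resume
        std::string optimizer_type;     // e.g. "SGD"
        std::string checkpoint_format;  // "bin" or "pth"
        int ddp_size = 1;               // parallel layout at save time
        int tp_size = 1;
        int sp_size = 1;
        int pp_size = 1;
    };
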
@@ -329,6 +436,7 @@ void Train(const nn::parallel::Rank &rank) {
             // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
             // TODO(dcj): support dataloader.reset() later
             ++train_iter;
+            saved_data_batch_idx = train_iter.BatchIndex();
             x = std::make_shared<Tensor>(x->To(device));
             y = std::make_shared<Tensor>(y->To(device));
 
@@ -358,6 +466,7 @@ void Train(const nn::parallel::Rank &rank) {
             // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
             // TODO(dcj): support dataloader.reset() later
             ++train_iter;
+            saved_data_batch_idx = train_iter.BatchIndex();
             x = std::make_shared<Tensor>(x->To(device));
             y = std::make_shared<Tensor>(y->To(device));
 
@@ -370,6 +479,8 @@ void Train(const nn::parallel::Rank &rank) {
             lossf = static_cast<const float *>(lossf_tensor->To(Device()).DataPtr())[0];
         }
 
+        best_loss = std::min(best_loss, lossf);
+
         const auto iter_end = std::chrono::high_resolution_clock::now();
         const double duration_us = std::chrono::duration<double, std::micro>(iter_end - iter_start).count();
         const double tps = FLAGS_total_batch_size / (duration_us / 1e6);
@@ -392,7 +503,22 @@ void Train(const nn::parallel::Rank &rank) {
                 }
             }
         }
+
+        if (FLAGS_save_steps > 0 && (step + 1) % FLAGS_save_steps == 0) {
+            std::filesystem::path step_dir
+                = std::filesystem::path(FLAGS_checkpoint_dir) / std::format("checkpoint_step_{:06d}", step + 1);
+            if (rank.IsParallel()) {
+                step_dir /= std::format("rank_{:06d}", rank.GlobalRank());
+            }
+            save_checkpoint(step_dir, step + 1, true);
+        }
+    }
+
+    std::filesystem::path final_dir = std::filesystem::path(FLAGS_checkpoint_dir) / "checkpoint_final";
+    if (rank.IsParallel()) {
+        final_dir /= std::format("rank_{:06d}", rank.GlobalRank());
     }
+    save_checkpoint(final_dir, FLAGS_num_iteration, false);
 #ifdef PROFILE_MODE
     Profiler::Instance().Report("gpt2.report", Profiler::SortBy::DeviceTimePercentage);
     Profiler::Instance().PrintRecords("gpt2.records.log");
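
Given the path construction above, a parallel run with --save_steps enabled yields a layout like the following (step numbers hypothetical). Per-rank subdirectories exist only when rank.IsParallel(), and the resume path probes for a matching rank_* subdirectory, falling back to the checkpoint root otherwise. Note that the zero-padded {:06d} step names make lexicographic order match numeric order, which is what the sort-then-prune loop in save_checkpoint relies on when trimming to FLAGS_max_checkpoint_keep:

    checkpoints/
        checkpoint_step_000100/
            rank_000000/
            rank_000001/
        checkpoint_step_000200/
            rank_000000/
            rank_000001/
        checkpoint_final/
            rank_000000/
            rank_000001/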
