-
Notifications
You must be signed in to change notification settings - Fork 43
【训练营】Checkpoint 读取工具 #129
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
【训练营】Checkpoint 读取工具 #129
Changes from 3 commits
146bd1d
3b13af4
4b248b6
08ed56b
97bd747
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,9 @@ | ||
| #include <algorithm> | ||
| #include <chrono> | ||
| #include <cstdlib> | ||
| #include <filesystem> | ||
| #include <format> | ||
| #include <limits> | ||
| #include <memory> | ||
| #include <optional> | ||
| #include <unordered_map> | ||
|
|
@@ -10,6 +13,7 @@ | |
| #include "glog/logging.h" | ||
|
|
||
| #include "infini_train/include/autocast.h" | ||
| #include "infini_train/include/checkpoint.h" | ||
| #include "infini_train/include/core/runtime/device_guard.h" | ||
| #include "infini_train/include/dataloader.h" | ||
| #include "infini_train/include/device.h" | ||
|
|
@@ -75,6 +79,14 @@ DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage."); | |
|
|
||
| // precision | ||
| DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)"); | ||
| DEFINE_uint32(save_steps, 0, "save checkpoint every N steps; 0 disables saving"); | ||
| DEFINE_string(resume_from, "", "checkpoint directory to resume from"); | ||
| DEFINE_string(checkpoint_dir, "./checkpoints", "root directory used to store checkpoints"); | ||
| DEFINE_uint32(max_checkpoint_keep, 3, "max number of checkpoint steps to keep"); | ||
| DEFINE_bool(save_optimizer_state, true, "whether optimizer state is persisted in checkpoints"); | ||
| DEFINE_string(checkpoint_format, "bin", "checkpoint format: bin|pth"); | ||
| DEFINE_bool(use_llmc_checkpoint_io, false, | ||
| "whether to use GPT2 LLMC model.bin callback for checkpoint save/load when format=bin"); | ||
| // precision check | ||
| DEFINE_string( | ||
| precision_check, "", | ||
|
|
@@ -198,6 +210,8 @@ void Train(const nn::parallel::Rank &rank) { | |
| } else { | ||
| model = GPT2::FromPretrained(kStrToModelType.at(FLAGS_model)); | ||
| } | ||
| auto llmc_model = std::dynamic_pointer_cast<GPT2>(model); | ||
| CHECK(llmc_model != nullptr) << "Failed to cast model to GPT2 for LLMC checkpoint I/O."; | ||
|
|
||
| model->To(device); | ||
|
|
||
|
|
@@ -311,6 +325,7 @@ void Train(const nn::parallel::Rank &rank) { | |
| } | ||
|
|
||
| auto train_iter = train_loader.begin(); | ||
| size_t saved_data_batch_idx = train_iter.BatchIndex(); | ||
| std::shared_ptr<nn::Module> loss_fn | ||
| = (tp_world_size > 1) ? std::static_pointer_cast<nn::Module>( | ||
| std::make_shared<VocabParallelCrossEntropyLoss>(model_config.original_vocab_size)) | ||
|
|
@@ -320,9 +335,100 @@ void Train(const nn::parallel::Rank &rank) { | |
|
|
||
| auto impl = core::GetDeviceGuardImpl(device.type()); | ||
|
|
||
| int start_step = 0; | ||
| float best_loss = std::numeric_limits<float>::infinity(); | ||
| if (!FLAGS_resume_from.empty()) { | ||
| std::filesystem::path resume_dir = FLAGS_resume_from; | ||
| if (rank.IsParallel()) { | ||
| const auto rank_dir = resume_dir / std::format("rank_{:06d}", rank.GlobalRank()); | ||
| if (std::filesystem::exists(rank_dir)) { | ||
| resume_dir = rank_dir; | ||
| } | ||
| } | ||
|
|
||
| TrainerState state; | ||
| CheckpointLoadOptions load_options; | ||
| load_options.load_optimizer_state = true; | ||
| if (FLAGS_use_llmc_checkpoint_io) { | ||
| load_options.model_bin_loader = [](nn::Module *target_model, const std::filesystem::path &model_path) { | ||
| auto loaded_model = GPT2::FromLLMC(model_path.string()); | ||
| target_model->LoadStateDict(loaded_model->StateDict()); | ||
| }; | ||
| } | ||
| Checkpoint::Load(resume_dir, model.get(), optimizer.get(), &state, load_options); | ||
| start_step = static_cast<int>(state.global_step); | ||
| best_loss = state.best_loss; | ||
| if (state.data_batch_stride != static_cast<int64_t>(ddp_world_size) && rank.IsMainRank()) { | ||
| LOG(WARNING) << std::format("Checkpoint data_batch_stride {} mismatches current ddp_world_size {}. " | ||
| "Proceeding with recorded data_batch_idx {}.", | ||
| state.data_batch_stride, ddp_world_size, state.data_batch_idx); | ||
| } | ||
| saved_data_batch_idx = static_cast<size_t>(std::max<int64_t>(state.data_batch_idx, 0)); | ||
| train_iter = train_loader.IteratorAtBatchIndex(saved_data_batch_idx); | ||
| if (rank.IsMainRank()) { | ||
| LOG(INFO) << std::format( | ||
| "Resume training from step {} with best_loss {:.6f}, last_lr {:.3e}, data_batch_idx {}", | ||
| state.global_step, state.best_loss, state.last_lr, state.data_batch_idx); | ||
| LOG(INFO) << std::format("Checkpoint model I/O mode during resume: {}", | ||
| FLAGS_use_llmc_checkpoint_io ? "llmc-callback" : "native-state-dict"); | ||
| } | ||
| } | ||
|
|
||
| LOG(INFO) << "start training"; | ||
|
|
||
| for (int step = 0; step < FLAGS_num_iteration + 1; ++step) { | ||
| auto save_checkpoint = [&](const std::filesystem::path &save_dir, int64_t global_step, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个表达式内部也可以提一个函数(类似 SaveCheckpoint)到utils.cc,内部只构造一个参数的struct,然后调用SaveCheckpoint
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 将 save_checkpoint 的逻辑提取为 |
||
| bool prune_step_checkpoints) { | ||
| const auto ckpt_start = std::chrono::high_resolution_clock::now(); | ||
|
|
||
| TrainerState state; | ||
| state.global_step = global_step; | ||
| state.data_batch_idx = saved_data_batch_idx; | ||
| state.data_batch_stride = ddp_world_size; | ||
| state.best_loss = best_loss; | ||
| state.last_lr = FLAGS_learning_rate; | ||
| state.optimizer_type = "SGD"; | ||
| state.checkpoint_format = FLAGS_checkpoint_format; | ||
| state.ddp_size = ddp_world_size; | ||
| state.tp_size = tp_world_size; | ||
| state.sp_size = sp_world_size; | ||
| state.pp_size = pp_world_size; | ||
|
|
||
| CheckpointOptions options; | ||
| options.format = FLAGS_checkpoint_format; | ||
| options.save_optimizer_state = FLAGS_save_optimizer_state; | ||
| if (FLAGS_use_llmc_checkpoint_io) { | ||
| options.model_bin_writer = [&](const nn::Module &, const std::filesystem::path &model_path) { | ||
| llmc_model->SaveAsLLMC(model_path.string()); | ||
| }; | ||
| } | ||
| Checkpoint::Save(save_dir, *model, *optimizer, state, options); | ||
|
|
||
| const auto ckpt_end = std::chrono::high_resolution_clock::now(); | ||
| const double ckpt_ms = std::chrono::duration<double, std::milli>(ckpt_end - ckpt_start).count(); | ||
|
|
||
| if (rank.IsMainRank()) { | ||
| LOG(INFO) << std::format("Checkpoint saved at: {} ({:.2f} ms)", save_dir.string(), ckpt_ms); | ||
|
|
||
| if (prune_step_checkpoints) { | ||
| std::vector<std::filesystem::path> ckpts; | ||
| const auto root = std::filesystem::path(FLAGS_checkpoint_dir); | ||
| if (std::filesystem::exists(root)) { | ||
| for (const auto &entry : std::filesystem::directory_iterator(root)) { | ||
| if (entry.is_directory() && entry.path().filename().string().starts_with("checkpoint_step_")) { | ||
| ckpts.push_back(entry.path()); | ||
| } | ||
| } | ||
| std::sort(ckpts.begin(), ckpts.end()); | ||
| while (ckpts.size() > FLAGS_max_checkpoint_keep) { | ||
| std::filesystem::remove_all(ckpts.front()); | ||
| ckpts.erase(ckpts.begin()); | ||
| } | ||
| } | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| for (int step = start_step; step < FLAGS_num_iteration + 1; ++step) { | ||
| // Reset precision check counters at start of each iteration for file overwrite | ||
| utils::PrecisionChecker::ResetCounters(); | ||
|
|
||
|
|
@@ -372,6 +478,7 @@ void Train(const nn::parallel::Rank &rank) { | |
| // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below | ||
| // TODO(dcj): support dataloader.reset() later | ||
| ++train_iter; | ||
| saved_data_batch_idx = train_iter.BatchIndex(); | ||
| x = std::make_shared<Tensor>(x->To(device)); | ||
| y = std::make_shared<Tensor>(y->To(device)); | ||
|
|
||
|
|
@@ -401,6 +508,7 @@ void Train(const nn::parallel::Rank &rank) { | |
| // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below | ||
| // TODO(dcj): support dataloader.reset() later | ||
| ++train_iter; | ||
| saved_data_batch_idx = train_iter.BatchIndex(); | ||
| x = std::make_shared<Tensor>(x->To(device)); | ||
| y = std::make_shared<Tensor>(y->To(device)); | ||
|
|
||
|
|
@@ -413,6 +521,8 @@ void Train(const nn::parallel::Rank &rank) { | |
| lossf = static_cast<const float *>(lossf_tensor->To(Device()).DataPtr())[0]; | ||
| } | ||
|
|
||
| best_loss = std::min(best_loss, lossf); | ||
|
|
||
| const auto iter_end = std::chrono::high_resolution_clock::now(); | ||
| const double duration_us = std::chrono::duration<double, std::micro>(iter_end - iter_start).count(); | ||
| const double tps = FLAGS_total_batch_size / (duration_us / 1e6); | ||
|
|
@@ -435,8 +545,22 @@ void Train(const nn::parallel::Rank &rank) { | |
| } | ||
| } | ||
| } | ||
|
|
||
| if (FLAGS_save_steps > 0 && (step + 1) % FLAGS_save_steps == 0) { | ||
| std::filesystem::path step_dir | ||
| = std::filesystem::path(FLAGS_checkpoint_dir) / std::format("checkpoint_step_{:06d}", step + 1); | ||
| if (rank.IsParallel()) { | ||
| step_dir /= std::format("rank_{:06d}", rank.GlobalRank()); | ||
| } | ||
| save_checkpoint(step_dir, step + 1, true); | ||
| } | ||
| } | ||
|
|
||
| std::filesystem::path final_dir = std::filesystem::path(FLAGS_checkpoint_dir) / "checkpoint_final"; | ||
| if (rank.IsParallel()) { | ||
| final_dir /= std::format("rank_{:06d}", rank.GlobalRank()); | ||
| } | ||
| save_checkpoint(final_dir, FLAGS_num_iteration, false); | ||
| // Save LoRA weights if enabled and path specified | ||
| if (lora_enabled && !FLAGS_lora_save_path.empty()) { | ||
| LOG(INFO) << "Saving LoRA weights to: " << FLAGS_lora_save_path; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议把主流程中恢复、保存、清理旧的Checkpoint提成公共函数,尽量让主流程简洁,另外各个训练入口可以复用。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
参数如果太多可以用struct整合在一起
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
将
main.cc中的恢复过程提取为infini_train::ResumeFromCheckpoint(),并通过std::tie()获取 start_step 等信息,通过引用传递实现参数恢复。使用 llama3 进行简单测试,loss 可以复现.