|
#include "example/common/utils.h"

#include <algorithm>
#include <cstdint>
#include <filesystem>
#include <format>
#include <limits>
#include <memory>
#include <tuple>

#include "gflags/gflags.h"
#include "gflags/gflags_declare.h"
#include "glog/logging.h"

#include "infini_train/include/nn/parallel/global.h"
3 | 8 | namespace infini_train { |
4 | 9 |
|
5 | 10 | float ConvertBF16ToFloat(void *ptr) { |
@@ -61,4 +66,53 @@ void ReadVectorShardFloat(std::ifstream &ifs, float *dst, int64_t len, int64_t s |
61 | 66 | ifs.seekg(base + std::streamoff(len * sizeof(float))); |
62 | 67 | } |
63 | 68 |
|
/// Restores training state from a checkpoint directory and repositions the
/// dataloader iterator accordingly.
///
/// Returns {global_step, best_loss, data_batch_idx} — all defaults (0, +inf, 0)
/// when `flag_resume_root` is empty, i.e. training starts from scratch.
std::tuple<int, float, size_t> ResumeFromCheckpoint(
    const fLS::clstring &flag_resume_root, // resume from this checkpoint directory
    const nn::parallel::Rank &rank,        // rank info for distributed training
    std::shared_ptr<nn::Module> model,     // model to be loaded with checkpoint state
    std::shared_ptr<Optimizer> optimizer,  // some optimizer may not have state, but others may have
    DistributedDataLoader &train_loader,   // distributed dataloader to be resumed
    TrainerState &state,                   // trainer state to be loaded from checkpoint
    DataLoaderIterator
        &train_iter, // dataloader iterator to be set to the correct position according to checkpoint state
    CheckpointLoadOptions model_bin_loader) {
    // Scratch-start defaults, overwritten below when a checkpoint is loaded.
    int resumed_step = 0;
    float resumed_best_loss = std::numeric_limits<float>::infinity();
    size_t resumed_batch_idx = 0;

    const int ddp_world_size = nn::parallel::global::GetDataParallelSize();

    if (flag_resume_root.empty()) {
        LOG(INFO) << "No checkpoint specified for resume. Starting training from scratch.";
        return {resumed_step, resumed_best_loss, resumed_batch_idx};
    }

    // Prefer the per-rank subdirectory ("rank_NNNNNN") when one exists;
    // otherwise fall back to the shared root directory.
    std::filesystem::path checkpoint_dir = flag_resume_root;
    if (rank.IsParallel()) {
        const auto per_rank_dir = checkpoint_dir / std::format("rank_{:06d}", rank.GlobalRank());
        if (std::filesystem::exists(per_rank_dir)) {
            checkpoint_dir = per_rank_dir;
        }
    }

    Checkpoint::Load(checkpoint_dir, model.get(), optimizer.get(), &state, model_bin_loader);

    resumed_step = static_cast<int>(state.global_step);
    resumed_best_loss = state.best_loss;
    // A stride mismatch means the data-parallel world size changed since the
    // checkpoint was written; keep the recorded batch index and warn once.
    if (state.data_batch_stride != static_cast<int64_t>(ddp_world_size) && rank.IsMainRank()) {
        LOG(WARNING) << std::format("Checkpoint data_batch_stride {} mismatches current ddp_world_size {}. "
                                    "Proceeding with recorded data_batch_idx {}.",
                                    state.data_batch_stride, ddp_world_size, state.data_batch_idx);
    }
    // Clamp a possibly-negative recorded index to zero before reuse.
    resumed_batch_idx = static_cast<size_t>(std::max<int64_t>(state.data_batch_idx, 0));
    train_iter = train_loader.IteratorAtBatchIndex(resumed_batch_idx);
    if (rank.IsMainRank()) {
        LOG(INFO) << std::format(
            "Resume training from step {} with best_loss {:.6f}, last_lr {:.3e}, data_batch_idx {}", state.global_step,
            state.best_loss, state.last_lr, state.data_batch_idx);
    }

    return {resumed_step, resumed_best_loss, resumed_batch_idx};
}
| 117 | + |
64 | 118 | } // namespace infini_train |
0 commit comments