95 changes: 95 additions & 0 deletions example/common/utils.cc
@@ -1,5 +1,13 @@
#include "example/common/utils.h"

#include <algorithm>
#include <chrono>
#include <filesystem>
#include <format>

#include "gflags/gflags.h"
#include "gflags/gflags_declare.h"
#include "glog/logging.h"
#include "infini_train/include/nn/parallel/global.h"

namespace infini_train {

float ConvertBF16ToFloat(void *ptr) {
@@ -61,4 +69,91 @@ void ReadVectorShardFloat(std::ifstream &ifs, float *dst, int64_t len, int64_t s
ifs.seekg(base + std::streamoff(len * sizeof(float)));
}

ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs &args) {
ResumeFromCheckpointResult result;
int ddp_world_size = nn::parallel::global::GetDataParallelSize();

if (args.resume_root.empty()) {
LOG(INFO) << "No checkpoint specified for resume. Starting training from scratch.";
return result;
}

std::filesystem::path resume_dir = args.resume_root;
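// Parallel runs save per-rank checkpoints into rank_{:06d} subdirectories; prefer the
// per-rank directory when it exists, otherwise fall back to the checkpoint root.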
if (args.rank.IsParallel()) {
const auto rank_dir = resume_dir / std::format("rank_{:06d}", args.rank.GlobalRank());
if (std::filesystem::exists(rank_dir)) {
resume_dir = rank_dir;
}
}

Checkpoint::Load(resume_dir, args.model.get(), args.optimizer.get(), &args.state, args.load_options);

result.global_step = static_cast<int>(args.state.global_step);
result.best_loss = args.state.best_loss;
if (args.state.data_batch_stride != static_cast<int64_t>(ddp_world_size) && args.rank.IsMainRank()) {
LOG(WARNING) << std::format("Checkpoint data_batch_stride {} mismatches current ddp_world_size {}. "
"Proceeding with recorded data_batch_idx {}.",
args.state.data_batch_stride, ddp_world_size, args.state.data_batch_idx);
}
result.data_batch_idx = static_cast<size_t>(std::max<int64_t>(args.state.data_batch_idx, 0));
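// Fast-forward the data loader to the recorded batch index so the resumed run continues
// from the same position in the data stream.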
args.train_iter = args.train_loader.IteratorAtBatchIndex(result.data_batch_idx);
if (args.rank.IsMainRank()) {
LOG(INFO) << std::format(
"Resume training from step {} with best_loss {:.6f}, last_lr {:.3e}, data_batch_idx {}",
args.state.global_step, args.state.best_loss, args.state.last_lr, args.state.data_batch_idx);
}

return result;
}

void SaveCheckpoint(const SaveCheckpointArgs &args) {
const auto ckpt_start = std::chrono::high_resolution_clock::now();

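// Snapshot the trainer state (step, data position, best loss, learning rate, parallel layout)
// into the checkpoint metadata.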
TrainerState state;
state.global_step = args.global_step;
state.data_batch_idx = static_cast<int64_t>(args.data_batch_idx);
state.data_batch_stride = args.ddp_size;
state.best_loss = args.best_loss;
state.last_lr = args.last_lr;
state.optimizer_type = args.optimizer_type;
state.checkpoint_format = args.checkpoint_format;
state.ddp_size = args.ddp_size;
state.tp_size = args.tp_size;
state.sp_size = args.sp_size;
state.pp_size = args.pp_size;

CheckpointOptions options;
options.format = args.checkpoint_format;
options.save_optimizer_state = args.save_optimizer_state;
options.model_bin_writer = args.model_bin_writer;
Checkpoint::Save(args.save_dir, args.model, args.optimizer, state, options);

const auto ckpt_end = std::chrono::high_resolution_clock::now();
const double ckpt_ms = std::chrono::duration<double, std::milli>(ckpt_end - ckpt_start).count();

if (!args.rank.IsMainRank()) {
return;
}

LOG(INFO) << std::format("Checkpoint saved at: {} ({:.2f} ms)", args.save_dir.string(), ckpt_ms);

if (!args.prune_step_checkpoints) {
return;
}

std::vector<std::filesystem::path> ckpts;
if (std::filesystem::exists(args.checkpoint_root_dir)) {
for (const auto &entry : std::filesystem::directory_iterator(args.checkpoint_root_dir)) {
if (entry.is_directory() && entry.path().filename().string().starts_with("checkpoint_step_")) {
ckpts.push_back(entry.path());
}
}
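// Step directories use zero-padded step numbers, so a lexicographic sort is chronological;
// remove the oldest until at most max_checkpoint_keep remain.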
std::sort(ckpts.begin(), ckpts.end());
while (ckpts.size() > args.max_checkpoint_keep) {
std::filesystem::remove_all(ckpts.front());
ckpts.erase(ckpts.begin());
}
}
}

} // namespace infini_train
55 changes: 55 additions & 0 deletions example/common/utils.h
@@ -1,8 +1,20 @@
#pragma once

#include "infini_train/include/checkpoint.h"
#include "infini_train/include/dataloader.h"
#include "infini_train/include/nn/modules/module.h"
#include "infini_train/include/nn/parallel/rank.h"
#include "infini_train/include/optimizer.h"

#include "gflags/gflags.h"

#include <cstdint>
#include <cstring>
#include <filesystem>
#include <fstream>
#include <functional>
#include <limits>
#include <memory>
#include <string>
#include <vector>

namespace infini_train {
@@ -30,4 +42,47 @@ void ReadVectorAllFloat(std::ifstream &ifs, float *dst, int64_t len);

void ReadVectorShardFloat(std::ifstream &ifs, float *dst, int64_t len, int64_t start, int64_t cnt);

struct ResumeFromCheckpointArgs {
std::string resume_root;
const nn::parallel::Rank &rank;
std::shared_ptr<nn::Module> model;
std::shared_ptr<Optimizer> optimizer;
DistributedDataLoader &train_loader;
TrainerState &state;
DataLoaderIterator &train_iter;
CheckpointLoadOptions load_options;
};

struct ResumeFromCheckpointResult {
int global_step = 0;
float best_loss = std::numeric_limits<float>::infinity();
size_t data_batch_idx = 0;
};

struct SaveCheckpointArgs {
std::filesystem::path save_dir;
int64_t global_step = 0;
size_t data_batch_idx = 0;
float best_loss = std::numeric_limits<float>::infinity();
double last_lr = 0.0;
std::string optimizer_type;
std::string checkpoint_format = "bin";
int ddp_size = 1;
int tp_size = 1;
int sp_size = 1;
int pp_size = 1;
bool save_optimizer_state = true;
bool prune_step_checkpoints = false;
std::filesystem::path checkpoint_root_dir;
size_t max_checkpoint_keep = 0;
const nn::parallel::Rank &rank;
const nn::Module &model;
const Optimizer &optimizer;
std::function<void(const nn::Module &, const std::filesystem::path &)> model_bin_writer;
};

ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs &args);

void SaveCheckpoint(const SaveCheckpointArgs &args);

} // namespace infini_train
86 changes: 83 additions & 3 deletions example/gpt2/main.cc
@@ -1,15 +1,20 @@
#include <algorithm>
#include <chrono>
#include <cstdlib>
#include <filesystem>
#include <format>
#include <limits>
#include <memory>
#include <optional>
#include <unordered_map>
#include <unordered_set>

#include "example/common/utils.h"
#include "gflags/gflags.h"
#include "glog/logging.h"

#include "infini_train/include/autocast.h"
#include "infini_train/include/checkpoint.h"
#include "infini_train/include/core/runtime/device_guard.h"
#include "infini_train/include/dataloader.h"
#include "infini_train/include/device.h"
@@ -75,6 +80,12 @@ DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage.");

// precision
DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");
DEFINE_uint32(save_steps, 0, "save checkpoint every N steps; 0 disables saving");
DEFINE_string(resume_from, "", "checkpoint directory to resume from");
DEFINE_string(checkpoint_dir, "./checkpoints", "root directory used to store checkpoints");
DEFINE_uint32(max_checkpoint_keep, 3, "max number of checkpoint steps to keep");
DEFINE_bool(save_optimizer_state, true, "whether optimizer state is persisted in checkpoints");
DEFINE_string(checkpoint_format, "bin", "checkpoint format: bin|pth");
// precision check
DEFINE_string(
precision_check, "",
@@ -198,6 +209,8 @@ void Train(const nn::parallel::Rank &rank) {
} else {
model = GPT2::FromPretrained(kStrToModelType.at(FLAGS_model));
}
auto llmc_model = std::dynamic_pointer_cast<GPT2>(model);
CHECK(llmc_model != nullptr) << "Failed to cast model to GPT2 for LLMC checkpoint I/O.";

model->To(device);

@@ -311,6 +324,7 @@ void Train(const nn::parallel::Rank &rank) {
}

auto train_iter = train_loader.begin();
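// Track the loader's current batch index so checkpoints can record where the data stream should resume.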
size_t saved_data_batch_idx = train_iter.BatchIndex();
std::shared_ptr<nn::Module> loss_fn
= (tp_world_size > 1) ? std::static_pointer_cast<nn::Module>(
std::make_shared<VocabParallelCrossEntropyLoss>(model_config.original_vocab_size))
@@ -320,9 +334,57 @@

auto impl = core::GetDeviceGuardImpl(device.type());

LOG(INFO) << "start training";

for (int step = 0; step < FLAGS_num_iteration + 1; ++step) {
int start_step = 0;
float best_loss = std::numeric_limits<float>::infinity();
TrainerState state;
CheckpointLoadOptions load_options;
load_options.load_optimizer_state = true;
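// For the LLMC "bin" format, weights are reloaded via GPT2::FromLLMC and copied into the
// live model through its state dict.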
load_options.model_bin_loader = [](nn::Module *target_model, const std::filesystem::path &model_path) {
auto loaded_model = GPT2::FromLLMC(model_path.string());
target_model->LoadStateDict(loaded_model->StateDict());
};
const auto resume_result = infini_train::ResumeFromCheckpoint({
.resume_root = FLAGS_resume_from,
.rank = rank,
.model = model,
.optimizer = optimizer,
.train_loader = train_loader,
.state = state,
.train_iter = train_iter,
.load_options = load_options,
});
start_step = resume_result.global_step;
best_loss = resume_result.best_loss;
saved_data_batch_idx = resume_result.data_batch_idx;

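// Wrap SaveCheckpoint so the periodic and final saves share the same argument plumbing.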
auto save_checkpoint
= [&](const std::filesystem::path &save_dir, int64_t global_step, bool prune_step_checkpoints) {
infini_train::SaveCheckpoint({
.save_dir = save_dir,
.global_step = global_step,
.data_batch_idx = saved_data_batch_idx,
.best_loss = best_loss,
.last_lr = FLAGS_learning_rate,
.optimizer_type = "SGD",
.checkpoint_format = FLAGS_checkpoint_format,
.ddp_size = ddp_world_size,
.tp_size = tp_world_size,
.sp_size = sp_world_size,
.pp_size = pp_world_size,
.save_optimizer_state = FLAGS_save_optimizer_state,
.prune_step_checkpoints = prune_step_checkpoints,
.checkpoint_root_dir = FLAGS_checkpoint_dir,
.max_checkpoint_keep = FLAGS_max_checkpoint_keep,
.rank = rank,
.model = *model,
.optimizer = *optimizer,
.model_bin_writer
= [&](const nn::Module &,
const std::filesystem::path &model_path) { llmc_model->SaveAsLLMC(model_path.string()); },
});
};

for (int step = start_step; step < FLAGS_num_iteration + 1; ++step) {
// Reset precision check counters at start of each iteration for file overwrite
utils::PrecisionChecker::ResetCounters();

@@ -372,6 +434,7 @@
// if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
// TODO(dcj): support dataloader.reset() later
++train_iter;
saved_data_batch_idx = train_iter.BatchIndex();
x = std::make_shared<Tensor>(x->To(device));
y = std::make_shared<Tensor>(y->To(device));

@@ -401,6 +464,7 @@
// if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
// TODO(dcj): support dataloader.reset() later
++train_iter;
saved_data_batch_idx = train_iter.BatchIndex();
x = std::make_shared<Tensor>(x->To(device));
y = std::make_shared<Tensor>(y->To(device));

@@ -413,6 +477,8 @@
lossf = static_cast<const float *>(lossf_tensor->To(Device()).DataPtr())[0];
}

best_loss = std::min(best_loss, lossf);

const auto iter_end = std::chrono::high_resolution_clock::now();
const double duration_us = std::chrono::duration<double, std::micro>(iter_end - iter_start).count();
const double tps = FLAGS_total_batch_size / (duration_us / 1e6);
@@ -435,8 +501,22 @@
}
}
}

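// Periodic checkpointing: every save_steps iterations write a step directory and prune older ones.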
if (FLAGS_save_steps > 0 && (step + 1) % FLAGS_save_steps == 0) {
std::filesystem::path step_dir
= std::filesystem::path(FLAGS_checkpoint_dir) / std::format("checkpoint_step_{:06d}", step + 1);
if (rank.IsParallel()) {
step_dir /= std::format("rank_{:06d}", rank.GlobalRank());
}
save_checkpoint(step_dir, step + 1, true);
}
}

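// Always write a final checkpoint after the last iteration; the final directory is not subject to pruning.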
std::filesystem::path final_dir = std::filesystem::path(FLAGS_checkpoint_dir) / "checkpoint_final";
if (rank.IsParallel()) {
final_dir /= std::format("rank_{:06d}", rank.GlobalRank());
}
save_checkpoint(final_dir, FLAGS_num_iteration, false);
// Save LoRA weights if enabled and path specified
if (lora_enabled && !FLAGS_lora_save_path.empty()) {
LOG(INFO) << "Saving LoRA weights to: " << FLAGS_lora_save_path;