 #include <chrono>
 #include <cstdlib>
+#include <filesystem>
 #include <format>
+#include <limits>
 #include <memory>
 #include <optional>
+#include <algorithm>
 #include <unordered_map>
 #include <unordered_set>

 #include "gflags/gflags.h"
 #include "glog/logging.h"

 #include "infini_train/include/autocast.h"
-#include "infini_train/include/core/device_guard.h"
+#include "infini_train/include/core/runtime/device_guard.h"
+#include "infini_train/include/checkpoint.h"
 #include "infini_train/include/dataloader.h"
 #include "infini_train/include/device.h"
 #include "infini_train/include/nn/modules/loss.h"
@@ -74,6 +78,14 @@ DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage.");

 // precision
 DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");
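+// checkpointing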
+DEFINE_uint32(save_steps, 0, "save checkpoint every N steps; 0 disables saving");
+DEFINE_string(resume_from, "", "checkpoint directory to resume from");
+DEFINE_string(checkpoint_dir, "./checkpoints", "root directory used to store checkpoints");
+DEFINE_uint32(max_checkpoint_keep, 3, "max number of checkpoint steps to keep");
+DEFINE_bool(save_optimizer_state, true, "whether optimizer state is persisted in checkpoints");
+DEFINE_string(checkpoint_format, "bin", "checkpoint format: bin|pth");
+DEFINE_bool(use_llmc_checkpoint_io, false,
+            "whether to use GPT2 LLMC model.bin callback for checkpoint save/load when format=bin");
 // precision check
 DEFINE_string(
     precision_check, "",
@@ -140,24 +152,25 @@ void Train(const nn::parallel::Rank &rank) {

     if (rank.IsParallel()) {
         device = Device(Device::DeviceType::kCUDA, rank.thread_rank());
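+        // Resolve the process-group factory for this device type once and reuse it for the DDP/TP/PP groups below.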
+        auto *pg_factory = ProcessGroupFactory::Instance(device.type());

         if (ddp_world_size > 1) {
-            ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()),
-                                                                  GetDataParallelGroupRanks(rank.GlobalRank()));
+            ddp_pg = pg_factory->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()),
+                                             GetDataParallelGroupRanks(rank.GlobalRank()));
             ddp_rank = ddp_pg->GetGroupRank(rank.GlobalRank());
         }

         if (tp_world_size > 1) {
-            tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()),
-                                                                 GetTensorParallelGroupRanks(rank.GlobalRank()));
+            tp_pg = pg_factory->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()),
+                                            GetTensorParallelGroupRanks(rank.GlobalRank()));
             tp_rank = tp_pg->GetGroupRank(rank.GlobalRank());
             // NOTE(zbl): Reserved for VocabParallelEmbedding
             nn::parallel::tp_rank = tp_rank;
         }

         if (pp_world_size > 1) {
-            pp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetPipelineParallelProcessGroupName(rank.GlobalRank()),
-                                                                 GetPipelineParallelGroupRanks(rank.GlobalRank()));
+            pp_pg = pg_factory->GetOrCreate(GetPipelineParallelProcessGroupName(rank.GlobalRank()),
+                                            GetPipelineParallelGroupRanks(rank.GlobalRank()));
             pp_rank = pp_pg->GetGroupRank(rank.GlobalRank());

             nn::parallel::pp_rank = pp_rank;
@@ -187,6 +200,8 @@ void Train(const nn::parallel::Rank &rank) {
     } else {
         model = GPT2::FromPretrained(kStrToModelType.at(FLAGS_model));
     }
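+    // Keep a GPT2-typed handle so the optional llm.c (LLMC) model.bin callbacks can use GPT2-specific I/O.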
+    auto llmc_model = std::dynamic_pointer_cast<GPT2>(model);
+    CHECK(llmc_model != nullptr) << "Failed to cast model to GPT2 for LLMC checkpoint I/O.";

     model->To(device);

@@ -219,8 +234,8 @@ void Train(const nn::parallel::Rank &rank) {
                 = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
             auto *mutable_chunks = dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks();
             for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
-                (*mutable_chunks)[chunk_id] = std::make_shared<DistributedDataParallel>(mutable_chunks->at(chunk_id),
-                                                                                        rank.thread_rank(), ddp_config);
+                (*mutable_chunks)[chunk_id]
+                    = std::make_shared<DistributedDataParallel>(mutable_chunks->at(chunk_id), rank, ddp_config);
             }
         }
     } else if (ddp_world_size > 1) {
@@ -229,7 +244,7 @@ void Train(const nn::parallel::Rank &rank) {
         // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
         // are created during the conversion.
         auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
-        model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank(), ddp_config);
+        model = std::make_shared<DistributedDataParallel>(model, rank, ddp_config);
     }

     DistributedDataLoader train_loader(std::make_shared<TinyShakespeareDataset>(FLAGS_input_bin, FLAGS_sequence_length),
@@ -268,6 +283,7 @@ void Train(const nn::parallel::Rank &rank) {
     }

     auto train_iter = train_loader.begin();
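+    // Track the dataloader position so checkpoints can record where to resume reading batches.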
+    size_t saved_data_batch_idx = train_iter.BatchIndex();
     std::shared_ptr<nn::Module> loss_fn
         = (tp_world_size > 1) ? std::static_pointer_cast<nn::Module>(
               std::make_shared<VocabParallelCrossEntropyLoss>(model_config.original_vocab_size))
@@ -277,9 +293,100 @@ void Train(const nn::parallel::Rank &rank) {

     auto impl = core::GetDeviceGuardImpl(device.type());

+    int start_step = 0;
+    float best_loss = std::numeric_limits<float>::infinity();
+    if (!FLAGS_resume_from.empty()) {
+        std::filesystem::path resume_dir = FLAGS_resume_from;
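+        // In parallel runs, each rank resumes from its own rank_XXXXXX sub-directory when one exists.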
+        if (rank.IsParallel()) {
+            const auto rank_dir = resume_dir / std::format("rank_{:06d}", rank.GlobalRank());
+            if (std::filesystem::exists(rank_dir)) {
+                resume_dir = rank_dir;
+            }
+        }
+
+        TrainerState state;
+        CheckpointLoadOptions load_options;
+        load_options.load_optimizer_state = true;
+        if (FLAGS_use_llmc_checkpoint_io) {
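+            // llm.c-style path: rebuild a GPT2 from the LLMC model.bin and copy its weights into the live model.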
+            load_options.model_bin_loader = [](nn::Module *target_model, const std::filesystem::path &model_path) {
+                auto loaded_model = GPT2::FromLLMC(model_path.string());
+                target_model->LoadStateDict(loaded_model->StateDict());
+            };
+        }
+        Checkpoint::Load(resume_dir, model.get(), optimizer.get(), &state, load_options);
+        start_step = static_cast<int>(state.global_step);
+        best_loss = state.best_loss;
+        if (state.data_batch_stride != static_cast<int64_t>(ddp_world_size) && rank.IsMainRank()) {
+            LOG(WARNING) << std::format("Checkpoint data_batch_stride {} mismatches current ddp_world_size {}. "
+                                        "Proceeding with recorded data_batch_idx {}.",
+                                        state.data_batch_stride, ddp_world_size, state.data_batch_idx);
+        }
+        saved_data_batch_idx = static_cast<size_t>(std::max<int64_t>(state.data_batch_idx, 0));
+        train_iter = train_loader.IteratorAtBatchIndex(saved_data_batch_idx);
+        if (rank.IsMainRank()) {
+            LOG(INFO) << std::format(
+                "Resume training from step {} with best_loss {:.6f}, last_lr {:.3e}, data_batch_idx {}",
+                state.global_step, state.best_loss, state.last_lr, state.data_batch_idx);
+            LOG(INFO) << std::format("Checkpoint model I/O mode during resume: {}",
+                                     FLAGS_use_llmc_checkpoint_io ? "llmc-callback" : "native-state-dict");
+        }
+    }
+
     LOG(INFO) << "start training";

-    for (int step = 0; step < FLAGS_num_iteration + 1; ++step) {
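+    // Persist model/optimizer/trainer state to save_dir; on the main rank, optionally prune old step checkpoints.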
+    auto save_checkpoint = [&](const std::filesystem::path &save_dir, int64_t global_step, bool prune_step_checkpoints) {
+        const auto ckpt_start = std::chrono::high_resolution_clock::now();
+
+        TrainerState state;
+        state.global_step = global_step;
+        state.data_batch_idx = saved_data_batch_idx;
+        state.data_batch_stride = ddp_world_size;
+        state.best_loss = best_loss;
+        state.last_lr = FLAGS_learning_rate;
+        state.optimizer_type = "SGD";
+        state.checkpoint_format = FLAGS_checkpoint_format;
+        state.ddp_size = ddp_world_size;
+        state.tp_size = tp_world_size;
+        state.sp_size = sp_world_size;
+        state.pp_size = pp_world_size;
+
+        CheckpointOptions options;
+        options.format = FLAGS_checkpoint_format;
+        options.save_optimizer_state = FLAGS_save_optimizer_state;
+        if (FLAGS_use_llmc_checkpoint_io) {
+            options.model_bin_writer = [&](const nn::Module &, const std::filesystem::path &model_path) {
+                llmc_model->SaveAsLLMC(model_path.string());
+            };
+        }
+        Checkpoint::Save(save_dir, *model, *optimizer, state, options);
+
+        const auto ckpt_end = std::chrono::high_resolution_clock::now();
+        const double ckpt_ms = std::chrono::duration<double, std::milli>(ckpt_end - ckpt_start).count();
+
+        if (rank.IsMainRank()) {
+            LOG(INFO) << std::format("Checkpoint saved at: {} ({:.2f} ms)", save_dir.string(), ckpt_ms);
+
+            if (prune_step_checkpoints) {
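+                // Keep only the newest FLAGS_max_checkpoint_keep checkpoint_step_* directories;
+                // zero-padded step ids make lexicographic order match step order.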
+                std::vector<std::filesystem::path> ckpts;
+                const auto root = std::filesystem::path(FLAGS_checkpoint_dir);
+                if (std::filesystem::exists(root)) {
+                    for (const auto &entry : std::filesystem::directory_iterator(root)) {
+                        if (entry.is_directory()
+                            && entry.path().filename().string().starts_with("checkpoint_step_")) {
+                            ckpts.push_back(entry.path());
+                        }
+                    }
+                    std::sort(ckpts.begin(), ckpts.end());
+                    while (ckpts.size() > FLAGS_max_checkpoint_keep) {
+                        std::filesystem::remove_all(ckpts.front());
+                        ckpts.erase(ckpts.begin());
+                    }
+                }
+            }
+        }
+    };
+
+    for (int step = start_step; step < FLAGS_num_iteration + 1; ++step) {
         // Reset precision check counters at start of each iteration for file overwrite
         utils::PrecisionChecker::ResetCounters();

@@ -329,6 +436,7 @@ void Train(const nn::parallel::Rank &rank) {
         // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
         // TODO(dcj): support dataloader.reset() later
         ++train_iter;
+        saved_data_batch_idx = train_iter.BatchIndex();
         x = std::make_shared<Tensor>(x->To(device));
         y = std::make_shared<Tensor>(y->To(device));

@@ -358,6 +466,7 @@ void Train(const nn::parallel::Rank &rank) {
         // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
         // TODO(dcj): support dataloader.reset() later
         ++train_iter;
+        saved_data_batch_idx = train_iter.BatchIndex();
         x = std::make_shared<Tensor>(x->To(device));
         y = std::make_shared<Tensor>(y->To(device));

@@ -370,6 +479,8 @@ void Train(const nn::parallel::Rank &rank) {
             lossf = static_cast<const float *>(lossf_tensor->To(Device()).DataPtr())[0];
         }

+        best_loss = std::min(best_loss, lossf);
+
         const auto iter_end = std::chrono::high_resolution_clock::now();
         const double duration_us = std::chrono::duration<double, std::micro>(iter_end - iter_start).count();
         const double tps = FLAGS_total_batch_size / (duration_us / 1e6);
@@ -392,7 +503,22 @@ void Train(const nn::parallel::Rank &rank) {
                 }
             }
         }
+
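+        // Periodic checkpointing: every FLAGS_save_steps steps, write checkpoint_step_NNNNNN
+        // (one rank_XXXXXX sub-directory per rank when running in parallel).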
+        if (FLAGS_save_steps > 0 && (step + 1) % FLAGS_save_steps == 0) {
+            std::filesystem::path step_dir
+                = std::filesystem::path(FLAGS_checkpoint_dir) / std::format("checkpoint_step_{:06d}", step + 1);
+            if (rank.IsParallel()) {
+                step_dir /= std::format("rank_{:06d}", rank.GlobalRank());
+            }
+            save_checkpoint(step_dir, step + 1, true);
+        }
+    }
+
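+    // Always write a final checkpoint after the training loop, regardless of FLAGS_save_steps.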
+    std::filesystem::path final_dir = std::filesystem::path(FLAGS_checkpoint_dir) / "checkpoint_final";
+    if (rank.IsParallel()) {
+        final_dir /= std::format("rank_{:06d}", rank.GlobalRank());
     }
+    save_checkpoint(final_dir, FLAGS_num_iteration, false);
 #ifdef PROFILE_MODE
     Profiler::Instance().Report("gpt2.report", Profiler::SortBy::DeviceTimePercentage);
     Profiler::Instance().PrintRecords("gpt2.records.log");