@@ -17,7 +17,6 @@
 #include "infini_train/include/nn/modules/normalization.h"
 #include "infini_train/include/nn/modules/sparse.h"
 #include "infini_train/include/nn/modules/transformer/causal_self_attention.h"
-#include "infini_train/include/nn/modules/transformer/layer_specs.h"
 #include "infini_train/include/nn/modules/transformer/mlp.h"
 #include "infini_train/include/nn/modules/transformer/transformer.h"
 #include "infini_train/include/nn/parallel/global.h"
@@ -96,10 +95,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
     gpt2_config.n_layer = n_layer;
     gpt2_config.n_head = n_head;
     gpt2_config.n_embd = n_embd;
-    auto local_gpt2 = std::make_shared<nn::TransformerModel>(
-        gpt2_config,
-        nn::BuildTransformerSpec(gpt2_config, nn::BuildFirstStageSpec(gpt2_config),
-                                 nn::BuildTransformerLayerSpec(gpt2_config), nn::BuildLastStageSpec(gpt2_config)));
+    auto local_gpt2 = std::make_shared<nn::TransformerModel>(gpt2_config);
 
     LOG(INFO) << "magic: " << magic << " version: " << version << " block_size: " << block_size
               << " vocab_size: " << vocab_size << " n_layer: " << n_layer << " n_head: " << n_head
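
The fields logged above come from the llm.c checkpoint preamble. A minimal sketch of reading such a header, assuming llm.c's documented layout of 256 little-endian int32 values (magic 20240326 first, then version and the model dimensions; the field order is taken from llm.c's train_gpt2 code and should be treated as an assumption, not this repo's exact implementation):

    #include <cstdint>
    #include <fstream>
    #include <optional>

    struct LLMCHeader {
        int32_t version, block_size, vocab_size, n_layer, n_head, n_embd, padded_vocab_size;
    };

    // Sketch: parse the 256-int32 llm.c preamble (field order assumed from llm.c).
    std::optional<LLMCHeader> ReadLLMCHeader(std::ifstream &ifs) {
        int32_t h[256];
        ifs.read(reinterpret_cast<char *>(h), sizeof(h));
        if (!ifs || h[0] != 20240326) { // magic for llm.c GPT-2 checkpoints
            return std::nullopt;
        }
        // h[1]: version (3 == fp32); h[2]: block_size; h[3]: vocab_size (50257 for GPT-2);
        // h[4]: n_layer; h[5]: n_head; h[6]: n_embd; h[7]: padded vocab (e.g. 50304)
        return LLMCHeader{h[1], h[2], h[3], h[4], h[5], h[6], h[7]};
    }
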
@@ -140,6 +136,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
 
     auto state_dict = local_gpt2->StateDict();
 
+    printf("===============Model Config:===============\n");
     // transformer.wte.weight (also transformer.lm_head.weight)
     // full: (model_vocab_size, n_embd)
     // local: (vocab_size_per_partition, n_embd)
@@ -158,7 +155,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
         size_t wte_bytes = model_vocab_size * n_embd * sizeof(float);
         ifs.seekg(wte_bytes, std::ios::cur);
     }
-
+    printf("Loading wte.weight...\n");
     if (tp_size == 1) {
         // Skip padded vocab part when TP is not enabled
         ifs.ignore((padded_vocab_size - model_vocab_size) * n_embd * sizeof(float));
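
Because wte is tied to lm_head, this single tensor serves as both the input embedding and the output projection. Under tensor parallelism it is vocab-partitioned: TP rank r owns a contiguous block of rows. A buffer-based sketch of the shard read, assuming the file stores padded_vocab_size row-major rows of n_embd floats and that padded_vocab_size divides evenly by tp_size (the real code reads into the state-dict tensor instead of a scratch vector):

    #include <cstddef>
    #include <fstream>
    #include <vector>

    // Sketch: read TP rank `tp_rank`'s contiguous row block of wte, leaving the
    // stream positioned at the start of the next tensor (wpe).
    std::vector<float> ReadWteShard(std::ifstream &ifs, size_t padded_vocab_size,
                                    size_t n_embd, int tp_size, int tp_rank) {
        const size_t rows_per_rank = padded_vocab_size / tp_size; // assumes exact division
        const size_t row_bytes = n_embd * sizeof(float);
        std::vector<float> shard(rows_per_rank * n_embd);
        ifs.seekg(static_cast<std::streamoff>(tp_rank * rows_per_rank * row_bytes), std::ios::cur);
        ifs.read(reinterpret_cast<char *>(shard.data()), shard.size() * sizeof(float));
        // Skip the rows owned by higher-numbered ranks.
        ifs.seekg(static_cast<std::streamoff>((tp_size - 1 - tp_rank) * rows_per_rank * row_bytes),
                  std::ios::cur);
        return shard;
    }
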
@@ -174,7 +171,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
         size_t wpe_bytes = block_size * n_embd * sizeof(float);
         ifs.seekg(wpe_bytes, std::ios::cur);
     }
-
+    printf("Loading wpe.weight...\n");
     // transformer.h.{i}.ln_1.weight
     int local_layer_index = 0;
     for (int idx = 0; idx < n_layer; ++idx) {
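
wpe, like every LayerNorm parameter below, is not sharded by tensor parallelism: positions and normalization statistics are identical on all TP ranks. Each rank therefore either reads the full vector or, when the parameter lives on another pipeline stage, seeks past the same number of bytes so later offsets stay aligned, as the else branch above does. A sketch of that read-or-skip choice, with `owns` standing in for whatever stage check the code uses (hypothetical name):

    #include <cstddef>
    #include <fstream>
    #include <vector>

    // Sketch: replicated parameter -- read in full on the owning stage,
    // otherwise skip byte-for-byte. For wpe, count = block_size * n_embd.
    std::vector<float> ReadOrSkipReplicated(std::ifstream &ifs, size_t count, bool owns) {
        std::vector<float> buf;
        if (owns) { // hypothetical: e.g. true only on the first pipeline stage for wpe
            buf.resize(count);
            ifs.read(reinterpret_cast<char *>(buf.data()), count * sizeof(float));
        } else {
            ifs.seekg(static_cast<std::streamoff>(count * sizeof(float)), std::ios::cur);
        }
        return buf;
    }
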
@@ -190,7 +187,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
             ifs.seekg(ln_1_w_bytes, std::ios::cur);
         }
     }
-
+    printf("Loading ln_1.weight...\n");
     // transformer.h.{i}.ln_1.bias
     local_layer_index = 0;
     for (int idx = 0; idx < n_layer; ++idx) {
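
Every transformer.h.{i} section below follows the same pipeline-parallel pattern: the file stores all n_layer layers back to back, so each stage walks the full range, reads the layers it owns under locally re-indexed names (local_layer_index), and seeks past the rest. A condensed sketch of the pattern; `StageOwnsLayer` and `DataPtr` are assumptions for illustration, not this repo's exact API:

    // Sketch: the read-or-skip loop shared by all per-layer tensors
    // (shown for ln_1.weight, which is n_embd floats per layer).
    int local_layer_index = 0;
    for (int idx = 0; idx < n_layer; ++idx) {
        const size_t bytes = static_cast<size_t>(n_embd) * sizeof(float);
        if (StageOwnsLayer(idx)) { // hypothetical ownership predicate
            const std::string name
                = "transformer.h." + std::to_string(local_layer_index++) + ".ln_1.weight";
            ifs.read(reinterpret_cast<char *>(state_dict[name]->DataPtr()), bytes); // DataPtr(): assumed accessor
        } else {
            ifs.seekg(bytes, std::ios::cur); // layer belongs to another stage
        }
    }
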
@@ -205,7 +202,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
             ifs.seekg(ln_1_b_bytes, std::ios::cur);
         }
     }
-
+    printf("Loading ln_1.bias...\n");
     // transformer.h.{i}.attn.c_attn.weight (ColumnParallelLinear, but actually applies on "rows")
     local_layer_index = 0;
     for (int idx = 0; idx < n_layer; ++idx) {
@@ -248,7 +245,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
             ifs.seekg(c_attn_w_bytes, std::ios::cur);
         }
     }
-
+    printf("Loading c_attn.weight...\n");
     // transformer.h.{i}.attn.c_attn.bias (ColumnParallelLinear)
     local_layer_index = 0;
     for (int idx = 0; idx < n_layer; ++idx) {
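
The "applies on rows" remark is about storage order: llm.c stores linear weights as (out_features, in_features) row-major, so a ColumnParallelLinear shard (a slice of the output dimension) is a contiguous row block. c_attn adds a twist: its (3 * n_embd, n_embd) weight is Q, K, V stacked, and each rank needs its slice of each of the three, so the read becomes three seek+read pairs rather than one. A buffer-based sketch under those assumptions:

    #include <cstddef>
    #include <fstream>
    #include <vector>

    // Sketch: gather one TP rank's slice of the fused QKV weight.
    std::vector<float> ReadQkvShard(std::ifstream &ifs, size_t n_embd, int tp_size, int tp_rank) {
        const size_t E = n_embd;
        const size_t rows_per_rank = E / tp_size; // per Q/K/V part, assumes exact division
        const size_t row_bytes = E * sizeof(float);
        std::vector<float> shard(3 * rows_per_rank * E);
        const std::streampos base = ifs.tellg();
        for (size_t part = 0; part < 3; ++part) { // 0 = Q, 1 = K, 2 = V
            ifs.seekg(base + static_cast<std::streamoff>((part * E + tp_rank * rows_per_rank) * row_bytes));
            ifs.read(reinterpret_cast<char *>(shard.data() + part * rows_per_rank * E),
                     rows_per_rank * row_bytes);
        }
        ifs.seekg(base + static_cast<std::streamoff>(3 * E * row_bytes)); // step past the whole tensor
        return shard;
    }
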
@@ -290,7 +287,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
             ifs.seekg(c_attn_b_bytes, std::ios::cur);
         }
     }
-
+    printf("Loading c_attn.bias...\n");
     // transformer.h.{i}.attn.c_proj.weight (RowParallelLinear, but actually applies on "columns")
     local_layer_index = 0;
     for (int idx = 0; idx < n_layer; ++idx) {
@@ -307,7 +304,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
             ifs.seekg(c_proj_w_bytes, std::ios::cur);
         }
     }
-
+    printf("Loading c_proj.weight...\n");
     // transformer.h.{i}.attn.c_proj.bias (RowParallelLinear, no shard on bias)
     local_layer_index = 0;
     for (int idx = 0; idx < n_layer; ++idx) {
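
Conversely, RowParallelLinear shards the input dimension, which in (out, in) row-major storage is a column block ("applies on columns"), so the shard is not contiguous in the file: each of the n_embd output rows contributes n_embd / tp_size floats at this rank's column offset. The bias is left whole, matching the "no shard on bias" note, since in the usual Megatron-style scheme it is added after the TP all-reduce. A sketch of the strided read:

    #include <cstddef>
    #include <fstream>
    #include <vector>

    // Sketch: strided read of a RowParallelLinear weight shard, i.e. a column
    // block of a row-major (n_embd, n_embd) matrix -- one short read per row.
    std::vector<float> ReadRowParallelShard(std::ifstream &ifs, size_t n_embd,
                                            int tp_size, int tp_rank) {
        const size_t E = n_embd;
        const size_t cols_per_rank = E / tp_size; // assumes exact division
        std::vector<float> shard(E * cols_per_rank);
        const std::streampos base = ifs.tellg();
        for (size_t row = 0; row < E; ++row) {
            ifs.seekg(base
                      + static_cast<std::streamoff>((row * E + tp_rank * cols_per_rank) * sizeof(float)));
            ifs.read(reinterpret_cast<char *>(shard.data() + row * cols_per_rank),
                     cols_per_rank * sizeof(float));
        }
        ifs.seekg(base + static_cast<std::streamoff>(E * E * sizeof(float))); // end of full tensor
        return shard;
    }
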
@@ -323,7 +320,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
             ifs.seekg(c_proj_b_bytes, std::ios::cur);
         }
     }
-
+    printf("Loading c_proj.bias...\n");
     // transformer.h.{i}.ln_2.weight
     local_layer_index = 0;
     for (int idx = 0; idx < n_layer; ++idx) {
@@ -339,7 +336,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
             ifs.seekg(ln_2_w_bytes, std::ios::cur);
         }
     }
-
+    printf("Loading ln_2.weight...\n");
     // transformer.h.{i}.ln_2.bias
     local_layer_index = 0;
     for (int idx = 0; idx < n_layer; ++idx) {
@@ -354,7 +351,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
             ifs.seekg(ln_2_b_bytes, std::ios::cur);
         }
     }
-
+    printf("Loading ln_2.bias...\n");
     // transformer.h.{i}.mlp.c_fc.weight (ColumnParallelLinear, but actually applies on "rows")
     local_layer_index = 0;
     for (int idx = 0; idx < n_layer; ++idx) {
@@ -370,7 +367,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
             ifs.seekg(c_fc_w_bytes, std::ios::cur);
         }
     }
-
+    printf("Loading mlp.c_fc.weight...\n");
     // transformer.h.{i}.mlp.c_fc.bias (ColumnParallelLinear)
     local_layer_index = 0;
     for (int idx = 0; idx < n_layer; ++idx) {
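
For sanity-checking the seek arithmetic it helps to plug in concrete numbers. For GPT-2 124M (n_embd = 768, block_size = 1024, padded vocab 50304; published config values, used here purely as an illustration) the fp32 byte counts are:

    // Illustrative fp32 tensor sizes for GPT-2 124M; handy when eyeballing offsets.
    constexpr size_t E = 768, T = 1024, Vp = 50304, F = sizeof(float);
    constexpr size_t wte_bytes   = Vp * E * F;    // 154,533,888 B (~147 MiB)
    constexpr size_t wpe_bytes   = T * E * F;     //   3,145,728 B (3 MiB)
    constexpr size_t qkv_w_bytes = 3 * E * E * F; //   7,077,888 B per layer
    constexpr size_t c_fc_bytes  = 4 * E * E * F; //   9,437,184 B per layer
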
@@ -386,7 +383,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
             ifs.seekg(c_fc_b_bytes, std::ios::cur);
         }
     }
-
+    printf("Loading mlp.c_fc.bias...\n");
     // transformer.h.{i}.mlp.c_proj.weight (RowParallelLinear, but actually applies on "columns")
     local_layer_index = 0;
     for (int idx = 0; idx < n_layer; ++idx) {
@@ -403,7 +400,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
             ifs.seekg(c_proj_w_bytes, std::ios::cur);
         }
     }
-
+    printf("Loading mlp.c_proj.weight...\n");
     // transformer.h.{i}.mlp.c_proj.bias (RowParallelLinear, no shard on bias)
     local_layer_index = 0;
     for (int idx = 0; idx < n_layer; ++idx) {
@@ -420,6 +417,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
         }
     }
 
+    printf("Loading last stage (ln_f)...\n");
     if (is_last_stage) {
         // transformer.ln_f.weight
         auto &transformer_ln_f_weight
@@ -436,6 +434,8 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
         size_t ln_f_b_bytes = n_embd * sizeof(float);
         ifs.seekg(ln_f_w_bytes + ln_f_b_bytes, std::ios::cur);
     }
+
+    printf("Finished loading checkpoint from %s\n", filepath.c_str());
     return local_gpt2;
 }
 } // namespace gpt2