
Commit 43b182a

Commit message: temp
1 parent: 52eca96

16 files changed
Lines changed: 325 additions & 798 deletions


example/gpt2/checkpoint_loader.cc

Lines changed: 18 additions & 18 deletions

@@ -17,7 +17,6 @@
#include "infini_train/include/nn/modules/normalization.h"
#include "infini_train/include/nn/modules/sparse.h"
#include "infini_train/include/nn/modules/transformer/causal_self_attention.h"
-#include "infini_train/include/nn/modules/transformer/layer_specs.h"
#include "infini_train/include/nn/modules/transformer/mlp.h"
#include "infini_train/include/nn/modules/transformer/transformer.h"
#include "infini_train/include/nn/parallel/global.h"
@@ -96,10 +95,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
    gpt2_config.n_layer = n_layer;
    gpt2_config.n_head = n_head;
    gpt2_config.n_embd = n_embd;
-    auto local_gpt2 = std::make_shared<nn::TransformerModel>(
-        gpt2_config,
-        nn::BuildTransformerSpec(gpt2_config, nn::BuildFirstStageSpec(gpt2_config),
-                                 nn::BuildTransformerLayerSpec(gpt2_config), nn::BuildLastStageSpec(gpt2_config)));
+    auto local_gpt2 = std::make_shared<nn::TransformerModel>(gpt2_config);

    LOG(INFO) << "magic: " << magic << " version: " << version << " block_size: " << block_size
              << " vocab_size: " << vocab_size << " n_layer: " << n_layer << " n_head: " << n_head
@@ -140,6 +136,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)

    auto state_dict = local_gpt2->StateDict();

+    printf("===============Model Config:===============\n");
    // transformer.wte.weight (also transformer.lm_head.weight)
    // full: (model_vocab_size, n_embd)
    // local: (vocab_size_per_partition, n_embd)
@@ -158,7 +155,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
        size_t wte_bytes = model_vocab_size * n_embd * sizeof(float);
        ifs.seekg(wte_bytes, std::ios::cur);
    }
-
+    printf("Loading wte.weight...\n");
    if (tp_size == 1) {
        // Skip padded vocab part when TP is not enabled
        ifs.ignore((padded_vocab_size - model_vocab_size) * n_embd * sizeof(float));
@@ -174,7 +171,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
        size_t wpe_bytes = block_size * n_embd * sizeof(float);
        ifs.seekg(wpe_bytes, std::ios::cur);
    }
-
+    printf("Loading wpe.weight...\n");
    // transformer.h.{i}.ln_1.weight
    int local_layer_index = 0;
    for (int idx = 0; idx < n_layer; ++idx) {
@@ -190,7 +187,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
            ifs.seekg(ln_1_w_bytes, std::ios::cur);
        }
    }
-
+    printf("Loading ln_1.weight...\n");
    // transformer.h.{i}.ln_1.bias
    local_layer_index = 0;
    for (int idx = 0; idx < n_layer; ++idx) {
@@ -205,7 +202,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
            ifs.seekg(ln_1_b_bytes, std::ios::cur);
        }
    }
-
+    printf("Loading ln_1.bias...\n");
    // transformer.h.{i}.attn.c_attn.weight (ColumnParallelLinear, but actually applies on "rows")
    local_layer_index = 0;
    for (int idx = 0; idx < n_layer; ++idx) {
@@ -248,7 +245,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
            ifs.seekg(c_attn_w_bytes, std::ios::cur);
        }
    }
-
+    printf("Loading c_attn.weight...\n");
    // transformer.h.{i}.attn.c_attn.bias (ColumnParallelLinear)
    local_layer_index = 0;
    for (int idx = 0; idx < n_layer; ++idx) {
@@ -290,7 +287,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
            ifs.seekg(c_attn_b_bytes, std::ios::cur);
        }
    }
-
+    printf("Loading c_attn.bias...\n");
    // transformer.h.{i}.attn.c_proj.weight (RowParallelLinear, but actually applies on "columns")
    local_layer_index = 0;
    for (int idx = 0; idx < n_layer; ++idx) {
@@ -307,7 +304,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
            ifs.seekg(c_proj_w_bytes, std::ios::cur);
        }
    }
-
+    printf("Loading c_proj.weight...\n");
    // transformer.h.{i}.attn.c_proj.bias (RowParallelLinear, no shard on bias)
    local_layer_index = 0;
    for (int idx = 0; idx < n_layer; ++idx) {
@@ -323,7 +320,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
            ifs.seekg(c_proj_b_bytes, std::ios::cur);
        }
    }
-
+    printf("Loading ln_2.weight...\n");
    // transformer.h.{i}.ln_2.weight
    local_layer_index = 0;
    for (int idx = 0; idx < n_layer; ++idx) {
@@ -339,7 +336,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
            ifs.seekg(ln_2_w_bytes, std::ios::cur);
        }
    }
-
+    printf("Loading ln_2.bias...\n");
    // transformer.h.{i}.ln_2.bias
    local_layer_index = 0;
    for (int idx = 0; idx < n_layer; ++idx) {
@@ -354,7 +351,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
            ifs.seekg(ln_2_b_bytes, std::ios::cur);
        }
    }
-
+    printf("Loading mlp.c_fc.weight...\n");
    // transformer.h.{i}.mlp.c_fc.weight (ColumnParallelLinear, but actually applies on "rows")
    local_layer_index = 0;
    for (int idx = 0; idx < n_layer; ++idx) {
@@ -370,7 +367,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
            ifs.seekg(c_fc_w_bytes, std::ios::cur);
        }
    }
-
+    printf("Loading mlp.c_fc.bias...\n");
    // transformer.h.{i}.mlp.c_fc.bias (ColumnParallelLinear)
    local_layer_index = 0;
    for (int idx = 0; idx < n_layer; ++idx) {
@@ -386,7 +383,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
            ifs.seekg(c_fc_b_bytes, std::ios::cur);
        }
    }
-
+    printf("Loading mlp.c_proj.weight...\n");
    // transformer.h.{i}.mlp.c_proj.weight (RowParallelLinear, but actually applies on "columns")
    local_layer_index = 0;
    for (int idx = 0; idx < n_layer; ++idx) {
@@ -403,7 +400,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
            ifs.seekg(c_proj_w_bytes, std::ios::cur);
        }
    }
-
+    printf("Loading mlp.c_proj.bias...\n");
    // transformer.h.{i}.mlp.c_proj.bias (RowParallelLinear, no shard on bias)
    local_layer_index = 0;
    for (int idx = 0; idx < n_layer; ++idx) {
@@ -420,6 +417,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
        }
    }

+    printf("Loading is_last_stage...\n");
    if (is_last_stage) {
        // transformer.ln_f.weight
        auto &transformer_ln_f_weight
@@ -436,6 +434,8 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
        size_t ln_f_b_bytes = n_embd * sizeof(float);
        ifs.seekg(ln_f_w_bytes + ln_f_b_bytes, std::ios::cur);
    }
+
+    printf("Finished loading checkpoint from %s\n", filepath.c_str());
    return local_gpt2;
}
} // namespace gpt2
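
Note on the sharding comments in this loader: ColumnParallelLinear layers (c_attn, mlp.c_fc) split the output dimension, which in the stored (out_features, in_features) row-major layout is a contiguous block of rows, while RowParallelLinear layers (c_proj) split the input dimension, i.e. columns, and their biases are not sharded at all. As a rough illustration of the contiguous row-shard case only, here is a hypothetical standalone helper (not code from this repository): a tensor-parallel rank reads just its slice and seeks past the rest so the stream stays aligned for the next tensor.

#include <cstdint>
#include <fstream>
#include <vector>

// Hypothetical sketch: read the row-slice of a (rows x cols) float32 tensor that
// belongs to tp_rank, assuming rows % tp_size == 0 and row-major storage.
std::vector<float> ReadRowShard(std::ifstream &ifs, int64_t rows, int64_t cols, int tp_rank, int tp_size) {
    const int64_t rows_per_rank = rows / tp_size;
    const int64_t row_bytes = cols * static_cast<int64_t>(sizeof(float));
    // Skip the rows owned by lower-numbered ranks.
    ifs.seekg(tp_rank * rows_per_rank * row_bytes, std::ios::cur);
    std::vector<float> shard(static_cast<size_t>(rows_per_rank * cols));
    ifs.read(reinterpret_cast<char *>(shard.data()), rows_per_rank * row_bytes);
    // Skip the rows owned by higher-numbered ranks so the next tensor starts at the right offset.
    ifs.seekg((rows - (tp_rank + 1) * rows_per_rank) * row_bytes, std::ios::cur);
    return shard;
}

A column shard (the RowParallelLinear weight case) is not contiguous in this layout and would need one partial read per row, which is presumably why the comments spell out which storage axis each parallel layer actually splits.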

example/gpt2/main.cc

Lines changed: 2 additions & 0 deletions

@@ -188,6 +188,7 @@ void Train(const nn::parallel::Rank &rank) {

    if (!FLAGS_llmc_filepath.empty()) {
        model = gpt2::LoadFromLLMC(FLAGS_llmc_filepath);
+        printf("Loaded model from LLMC checkpoint: %s\n", FLAGS_llmc_filepath.c_str());
    } else if (kModelToConfigs.count(FLAGS_model)) {
        model_config = kModelToConfigs.at(FLAGS_model);
        model = std::make_shared<nn::TransformerModel>(model_config);
@@ -370,6 +371,7 @@ void Train(const nn::parallel::Rank &rank) {
        y = std::make_shared<Tensor>(y->To(device));

        LOG(INFO) << "Rank " << rank.GlobalRank() << ": start forward";
+
        // (bs, seq_len, vocab_size)
        auto logits = (*model)({x, y})[0];
        LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish model forward, start loss forward";

example/llama3/checkpoint_loader.cc

Lines changed: 1 addition & 5 deletions

@@ -16,7 +16,6 @@
#include "example/llama3/config.h"
#include "infini_train/include/nn/modules/normalization.h"
#include "infini_train/include/nn/modules/transformer/causal_self_attention.h"
-#include "infini_train/include/nn/modules/transformer/layer_specs.h"
#include "infini_train/include/nn/modules/transformer/mlp.h"
#include "infini_train/include/nn/modules/transformer/transformer.h"
#include "infini_train/include/nn/parallel/global.h"
@@ -90,10 +89,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
    llama3_config.use_scaled_rope = static_cast<bool>(use_scaled_rope);
    llama3_config.norm_eps = norm_eps;
    llama3_config.max_gen_batch_size = max_gen_bs;
-    auto llama3 = std::make_shared<nn::TransformerModel>(
-        llama3_config,
-        nn::BuildTransformerSpec(llama3_config, nn::BuildFirstStageSpec(llama3_config),
-                                 nn::BuildTransformerLayerSpec(llama3_config), nn::BuildLastStageSpec(llama3_config)));
+    auto llama3 = std::make_shared<nn::TransformerModel>(llama3_config);

    // ========== pp_size:num_stages; vpp_size: num_chunks_per_stage ==========
    int pp_size = nn::parallel::global::GetPipelineParallelSize();
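
Both loaders now follow the same pattern: the BuildTransformerSpec / BuildFirstStageSpec / BuildTransformerLayerSpec / BuildLastStageSpec chain is removed along with the layer_specs.h include (the deleted headers appear further down), and TransformerModel is constructed from the config alone, presumably assembling its per-stage layers internally. A minimal before/after sketch of the call site, taken from the two diffs above ("config" stands for gpt2_config or llama3_config):

// Before: callers assembled the layer specs explicitly.
// auto model = std::make_shared<nn::TransformerModel>(
//     config, nn::BuildTransformerSpec(config, nn::BuildFirstStageSpec(config),
//                                      nn::BuildTransformerLayerSpec(config), nn::BuildLastStageSpec(config)));

// After: config-only construction.
auto model = std::make_shared<nn::TransformerModel>(config);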

infini_train/include/nn/modules/transformer/causal_self_attention.h

Lines changed: 1 addition & 2 deletions

@@ -5,7 +5,6 @@
#include <vector>

#include "infini_train/include/nn/modules/module.h"
-#include "infini_train/include/nn/modules/transformer/spec_utils.h"
#include "infini_train/include/nn/modules/transformer/transformer_config.h"

namespace infini_train::nn {
@@ -18,7 +17,7 @@ class CausalSelfAttention : public infini_train::nn::CloneableModule<CausalSelfAttention> {

    static constexpr char kParamBiasName[] = "bias";

-    explicit CausalSelfAttention(const TransformerConfig &config, const ModuleSpec &spec = {});
+    explicit CausalSelfAttention(const TransformerConfig &config);

    std::vector<std::shared_ptr<infini_train::Tensor>>
    Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

infini_train/include/nn/modules/transformer/layer_specs.h

Lines changed: 0 additions & 55 deletions
This file was deleted.

infini_train/include/nn/modules/transformer/mlp.h

Lines changed: 4 additions & 2 deletions

@@ -3,7 +3,6 @@
#include <vector>

#include "infini_train/include/nn/modules/module.h"
-#include "infini_train/include/nn/modules/transformer/spec_utils.h"
#include "infini_train/include/nn/modules/transformer/transformer_config.h"

namespace infini_train::nn {
@@ -17,9 +16,12 @@ class MLP : public infini_train::nn::CloneableModule<MLP> {
    static constexpr char kCFc2LayerName[] = "c_fc2";
    static constexpr char kSiluLayerName[] = "silu";

-    explicit MLP(const TransformerConfig &config, const ModuleSpec &spec = {});
+    explicit MLP(const TransformerConfig &config);

    std::vector<std::shared_ptr<infini_train::Tensor>>
    Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;
+
+private:
+    int64_t hidden_dim_ = 0;
};
} // namespace infini_train::nn
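
The MLP constructor likewise drops the ModuleSpec parameter, and a private hidden_dim_ member appears, so the hidden width presumably now comes from the config rather than from the removed spec. Purely as an illustrative guess (DemoConfig and its ffn_hidden_dim field are hypothetical; the real derivation lives in mlp.cc and transformer_config.h), the resolution might look like:

#include <cstdint>

// Hypothetical stand-in for the relevant TransformerConfig fields.
struct DemoConfig {
    int64_t n_embd = 768;       // model width (field name taken from the GPT-2 config above)
    int64_t ffn_hidden_dim = 0; // assumed: 0 means "use the default 4x expansion"
};

// Sketch of how a hidden_dim_ member could be resolved from the config alone.
int64_t ResolveHiddenDim(const DemoConfig &config) {
    return config.ffn_hidden_dim > 0 ? config.ffn_hidden_dim : 4 * config.n_embd;
}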

infini_train/include/nn/modules/transformer/spec_utils.h

Lines changed: 0 additions & 93 deletions
This file was deleted.
