Skip to content

Commit 1446ce0

Browse files
committed
feat: add Qwen3_8B
1 parent e7f41c9 commit 1446ce0

7 files changed

Lines changed: 1113 additions & 0 deletions

File tree

CMakeLists.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,14 @@ add_executable(llama3
190190
)
191191
link_infini_train_exe(llama3)
192192

193+
add_executable(qwen3
194+
example/qwen3/main.cc
195+
example/common/tiny_shakespeare_dataset.cc
196+
example/common/utils.cc
197+
example/qwen3/checkpoint_loader.cc
198+
example/common/tokenizer.cc
199+
)
200+
link_infini_train_exe(qwen3)
193201
# Tools
194202
add_subdirectory(tools/infini_run)
195203
set_target_properties(infini_run PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})

example/qwen3/checkpoint_loader.cc

Lines changed: 357 additions & 0 deletions
Large diffs are not rendered by default.

example/qwen3/checkpoint_loader.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#pragma once
2+
3+
#include <memory>
4+
#include <string>
5+
6+
namespace infini_train::nn {
7+
class TransformerModel;
8+
} // namespace infini_train::nn
9+
10+
namespace qwen3 {
11+
std::shared_ptr<infini_train::nn::TransformerModel> LoadFromLLMC(const std::string &filepath);
12+
} // namespace qwen3

example/qwen3/config.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#pragma once
2+
3+
#include "infini_train/include/nn/modules/transformer/transformer_config.h"
4+
5+
namespace nn = infini_train::nn;
6+
namespace qwen3 {
7+
inline nn::TransformerConfig Qwen3Config() {
8+
return {.block_size = 40960,
9+
.vocab_size = 151936,
10+
.original_vocab_size = 151936,
11+
.n_layer = 36,
12+
.n_head = 32,
13+
.n_kv_head = 8,
14+
.n_embd = 4096,
15+
.attention_type = nn::AttentionType::kRoPE,
16+
.activation_type = nn::MLPType::kSwiGLU,
17+
.norm_type = nn::NormType::kRMSNorm,
18+
.add_bias_linear = false,
19+
.add_bias_lm_head = false,
20+
.tie_weights = false,
21+
.ffn_expansion_ratio = 4.5f, // 4096*4.5*2/3 = 12288
22+
.ffn_dim_multiplier = std::nullopt,
23+
.multiple_of = 1,
24+
.rope_theta = 1000000.0f,
25+
.use_scaled_rope = false,
26+
.norm_eps = 1e-6f};
27+
}
28+
} // namespace qwen3

0 commit comments

Comments
 (0)