|
1 | 1 | #pragma once |
| 2 | + |
2 | 3 | #include <cstdint> |
3 | 4 | #include <optional> |
4 | | -#include <string> |
5 | 5 |
|
6 | 6 | namespace infini_train::nn { |
7 | 7 |
|
/// Which transformer family a TransformerConfig describes; selects the
/// attention/MLP/norm variants used when building the model.
enum class ModelType {
    kGPT2,   ///< GPT-2 family (standard attention, GELU, LayerNorm)
    kLLaMA3, ///< LLaMA3 family (RoPE attention, SwiGLU, RMSNorm)
};
| 12 | + |
/// Attention mechanism variant used by a transformer block.
enum class AttentionType {
    kStandard, ///< Standard attention, no rotary embedding (GPT-2 style)
    kRoPE,     ///< Rotary Position Embedding (LLaMA3 style)
};
12 | 17 |
|
/// Activation/MLP variant used by a transformer block's feed-forward layer.
enum class MLPType {
    kGELU,   ///< GELU activation (GPT-2 style)
    kSwiGLU, ///< SwiGLU gated activation (LLaMA3 style)
};
17 | 22 |
|
/// Normalization layer variant used by a transformer block.
enum class NormType {
    kLayerNorm, ///< LayerNorm with learned bias (GPT-2 style)
    kRMSNorm,   ///< RMSNorm, no mean-centering (LLaMA3 style)
};
22 | 27 |
|
23 | 28 | struct TransformerConfig { |
24 | | - static constexpr char kGPT2Name[] = "GPT2"; |
25 | | - static constexpr char kLLaMA3Name[] = "LLaMA3"; |
26 | | - |
27 | | - std::string model_type = ""; |
| 29 | + ModelType model_type = ModelType::kGPT2; |
28 | 30 |
|
29 | 31 | int64_t block_size = 1024; // Max seq_len |
30 | 32 | int64_t vocab_size = 50304; // Vocab size |
@@ -59,42 +61,5 @@ struct TransformerConfig { |
59 | 61 | int64_t max_gen_batch_size = 4; // max batch size during inference |
60 | 62 |
|
61 | 63 | bool UseGQA() const { return n_kv_head < n_head; } |
62 | | - |
63 | | - static TransformerConfig GPT2() { |
64 | | - return {.model_type = kGPT2Name, |
65 | | - .block_size = 1024, |
66 | | - .vocab_size = 50304, |
67 | | - .original_vocab_size = 50257, |
68 | | - .n_layer = 12, |
69 | | - .n_head = 12, |
70 | | - .n_kv_head = 12, |
71 | | - .n_embd = 768, |
72 | | - .attention_type = AttentionType::kStandard, |
73 | | - .activation_type = MLPType::kGELU, |
74 | | - .norm_type = NormType::kLayerNorm, |
75 | | - .use_bias = true, |
76 | | - .tie_weights = true, |
77 | | - .ffn_expansion_ratio = 4.0f, |
78 | | - .ffn_dim_multiplier = std::nullopt, |
79 | | - .multiple_of = 1}; |
80 | | - } |
81 | | - |
82 | | - static TransformerConfig LLaMA3() { |
83 | | - return {.model_type = kLLaMA3Name, |
84 | | - .block_size = 8192, |
85 | | - .vocab_size = 128256, |
86 | | - .n_layer = 16, |
87 | | - .n_head = 32, |
88 | | - .n_kv_head = 8, |
89 | | - .n_embd = 2048, |
90 | | - .attention_type = AttentionType::kRoPE, |
91 | | - .activation_type = MLPType::kSwiGLU, |
92 | | - .norm_type = NormType::kRMSNorm, |
93 | | - .use_bias = false, |
94 | | - .tie_weights = false, |
95 | | - .ffn_expansion_ratio = 4.0f, |
96 | | - .ffn_dim_multiplier = 1.5f, |
97 | | - .multiple_of = 256}; |
98 | | - } |
99 | 64 | }; |
100 | 65 | } // namespace infini_train::nn |
0 commit comments