-
Notifications
You must be signed in to change notification settings - Fork 62
Expand file tree
/
Copy pathllama_decoder_layer.hpp
More file actions
84 lines (73 loc) · 2.9 KB
/
llama_decoder_layer.hpp
File metadata and controls
84 lines (73 loc) · 2.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#pragma once
#include "infinicore/device.hpp"
#include "infinicore/nn/module.hpp"
#include "infinicore/nn/rmsnorm.hpp"
#include "infinicore/tensor.hpp"
#include "llama_attention.hpp"
#include "llama_config.hpp"
#include "llama_mlp.hpp"
#include "../../engine/distributed/distributed.hpp"
#include <cstddef>
#include <memory>
#include <optional>
#include <utility>
namespace infinilm::models::llama {
/**
 * @brief Single decoder layer (transformer block) for Llama
 *
 * Each decoder layer consists of:
 * - Input layer normalization (RMSNorm)
 * - Self-attention mechanism
 * - Post-attention layer normalization (RMSNorm)
 * - MLP feed-forward network
 *
 * Residual connections are applied around both attention and MLP blocks.
 */
class LlamaDecoderLayer : public infinicore::nn::Module {
public:
    /**
     * @brief Construct LlamaDecoderLayer module
     *
     * @param config Model configuration
     * @param device Device to create tensors on
     * @param layer_idx Layer index for cache management and debugging
     * @param rank_info Rank information from engine::distributed; when omitted,
     *                  a default-constructed RankInfo is used. How it is consumed
     *                  is determined by the out-of-line constructor definition.
     */
    LlamaDecoderLayer(const LlamaConfig &config,
                      const infinicore::Device &device,
                      size_t layer_idx,
                      engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());

    /**
     * @brief Forward pass: process one decoder layer
     *
     * @param hidden_states Input tensor of shape [batch, seq_len, hidden_size]
     * @param position_ids Position IDs tensor of shape [batch, seq_len] or [seq_len]
     * @param kv_cache Optional KV cache for incremental decoding
     * @param cache_positions Cache positions tensor
     * @param residual Optional residual tensor from previous layer (for MLP residual connection);
     *                 defaults to std::nullopt
     * @return Pair of (output, residual) tensors, where residual can be reused by next layer
     */
    std::pair<infinicore::Tensor, infinicore::Tensor> forward(
        const infinicore::Tensor &hidden_states,
        const infinicore::Tensor &position_ids,
        std::shared_ptr<infinilm::cache::Cache> kv_cache,
        const infinicore::Tensor &cache_positions,
        const std::optional<infinicore::Tensor> &residual = std::nullopt) const;

    /**
     * @brief Get the layer index
     *
     * @return The value stored in layer_idx_
     */
    size_t layer_idx() const { return layer_idx_; }

    /**
     * @brief Hand a rotary embedding module to the self-attention submodule
     *
     * The call is forwarded to self_attn_->set_rotary_emb(). If self_attn_
     * tests false (not set), the call is a no-op.
     *
     * @param rotary_emb Rotary embedding module to pass on
     */
    void set_rotary_emb(const std::shared_ptr<infinicore::nn::RoPE> &rotary_emb) {
        if (!self_attn_) {
            return; // No attention submodule to receive the embedding
        }
        self_attn_->set_rotary_emb(rotary_emb);
    }

protected:
    // Layer normalization submodules (applied before attention and before the MLP)
    INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, input_layernorm);
    INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, post_attention_layernorm);

    // Attention and MLP submodules; the macro provides the trailing-underscore
    // member (e.g. self_attn_) used by set_rotary_emb() above
    INFINICORE_NN_MODULE(LlamaAttention, self_attn);
    INFINICORE_NN_MODULE(LlamaMLP, mlp);

    // Distributed rank information; presumably taken from the constructor's
    // rank_info argument -- confirm against the constructor definition
    engine::distributed::RankInfo rank_info_;

private:
    size_t layer_idx_; // Layer index for cache management and debugging
};
} // namespace infinilm::models::llama