29 changes: 21 additions & 8 deletions csrc/cache/kv_cache.cpp
@@ -49,13 +49,15 @@ StaticKVCache::StaticKVCache(
: Cache(),
k_dim_(k_dim),
v_dim_(v_dim),
num_rank_k_heads_(num_k_heads / rank_info.tp_size),
num_rank_v_heads_(num_v_heads / rank_info.tp_size),
rank_batch_size_(config.max_batch_size()),
cache_len_(config.max_cache_len() == std::numeric_limits<infinicore::Size>::max() || config.max_cache_len() == 0 ? max_positional_embedding : config.max_cache_len()),
rank_num_layers_(num_layers),
dtype_(dtype) {

bool is_kv_replica = (num_k_heads < rank_info.tp_size && num_v_heads < rank_info.tp_size && num_k_heads == num_v_heads && rank_info.tp_size % num_k_heads == 0);

num_rank_k_heads_ = is_kv_replica ? 1 : (num_k_heads / rank_info.tp_size);
num_rank_v_heads_ = is_kv_replica ? 1 : (num_v_heads / rank_info.tp_size);
// Allocate K cache
k_caches_ = infinicore::Tensor::empty(
{rank_num_layers_,
@@ -90,15 +92,20 @@ infinicore::Tensor StaticKVCache::create_layer_kv_cache(
const engine::distributed::RankInfo &rank_info = infinilm::global_state::get_tensor_model_parallel_rank_info();

size_t rank_batch_size = (config.max_batch_size());
size_t num_rank_kv_heads = (num_k_heads / rank_info.tp_size);
size_t kv_dim = k_dim;

bool is_kv_replica = (num_k_heads < rank_info.tp_size && num_v_heads < rank_info.tp_size && num_k_heads == num_v_heads && rank_info.tp_size % num_k_heads == 0);

size_t num_rank_k_heads = is_kv_replica ? 1 : (num_k_heads / rank_info.tp_size);
size_t num_rank_v_heads = is_kv_replica ? 1 : (num_v_heads / rank_info.tp_size);

size_t cache_len = (config.max_cache_len() == std::numeric_limits<infinicore::Size>::max() || config.max_cache_len() == 0 ? max_positional_embedding : config.max_cache_len());

// Allocate KV cache
infinicore::Tensor kv_cache = infinicore::Tensor::empty(
{2,
rank_batch_size,
num_rank_kv_heads,
num_rank_k_heads,
cache_len,
kv_dim},
dtype,
@@ -186,12 +193,15 @@ PagedKVCache::PagedKVCache(
: Cache(),
k_dim_(k_dim),
v_dim_(v_dim),
num_rank_k_heads_(num_k_heads / rank_info.tp_size),
num_rank_v_heads_(num_v_heads / rank_info.tp_size),
rank_num_layers_(num_layers),
dtype_(dtype),
num_blocks_per_layer_(config.num_blocks()),
block_size_(config.block_size()) {

bool is_kv_replica = (num_k_heads < rank_info.tp_size && num_v_heads < rank_info.tp_size && num_k_heads == num_v_heads && rank_info.tp_size % num_k_heads == 0);

num_rank_k_heads_ = is_kv_replica ? 1 : (num_k_heads / rank_info.tp_size);
num_rank_v_heads_ = is_kv_replica ? 1 : (num_v_heads / rank_info.tp_size);
// [num_layers, num_blocks, num_rank_k_heads, block_size, k_dim]
k_caches_ = infinicore::Tensor::empty(
{rank_num_layers_,
@@ -224,8 +234,11 @@ infinicore::Tensor PagedKVCache::create_layer_kv_cache(

const engine::distributed::RankInfo &rank_info = infinilm::global_state::get_tensor_model_parallel_rank_info();

size_t num_rank_kv_heads(num_k_heads / rank_info.tp_size);
size_t kv_dim = k_dim;
bool is_kv_replica = (num_k_heads < rank_info.tp_size && num_v_heads < rank_info.tp_size && num_k_heads == num_v_heads && rank_info.tp_size % num_k_heads == 0);

size_t num_rank_k_heads = is_kv_replica ? 1 : (num_k_heads / rank_info.tp_size);
size_t num_rank_v_heads = is_kv_replica ? 1 : (num_v_heads / rank_info.tp_size);

size_t num_blocks_per_layer = config.num_blocks();
size_t block_size = config.block_size();
@@ -234,7 +247,7 @@ infinicore::Tensor PagedKVCache::create_layer_kv_cache(
infinicore::Tensor kv_cache = infinicore::Tensor::empty(
{2,
num_blocks_per_layer,
num_rank_kv_heads,
num_rank_k_heads,
block_size,
kv_dim},
dtype,
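Both cache constructors and both create_layer_kv_cache helpers now gate the per-rank head count on the same replication predicate. The sketch below restates that logic as a standalone program: is_kv_replica copies the predicate from kv_cache.cpp, while num_rank_kv_heads and the sample configurations in main() are illustrative stand-ins, not code from this PR.

#include <cstddef>
#include <iostream>

// Replication applies when there are fewer KV heads than tensor-parallel
// ranks, K and V head counts match, and tp_size is an exact multiple of the
// head count; otherwise heads are sharded by plain division.
bool is_kv_replica(std::size_t num_k_heads, std::size_t num_v_heads, std::size_t tp_size) {
    return num_k_heads < tp_size && num_v_heads < tp_size
        && num_k_heads == num_v_heads && tp_size % num_k_heads == 0;
}

std::size_t num_rank_kv_heads(std::size_t num_kv_heads, std::size_t tp_size) {
    return is_kv_replica(num_kv_heads, num_kv_heads, tp_size) ? 1 : num_kv_heads / tp_size;
}

int main() {
    std::cout << num_rank_kv_heads(8, 4) << "\n"; // 2: sharded, two heads per rank
    std::cout << num_rank_kv_heads(2, 8) << "\n"; // 1: replicated, four ranks share each head
    return 0;
}

Note that when the predicate is false and the head count is not divisible by tp_size, the division here truncates silently; the throwing path for that malformed case lives in calculate_kv_replicas in fused_linear.hpp below.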
5 changes: 1 addition & 4 deletions csrc/layers/attention/attention.cpp
@@ -22,12 +22,9 @@ Attention::Attention(std::shared_ptr<infinilm::config::ModelConfig> model_config
const engine::distributed::RankInfo &rank_info = infinilm::global_state::get_tensor_model_parallel_rank_info();
int tp_rank = infinilm::global_state::get_tensor_model_parallel_rank();
int tp_size = infinilm::global_state::get_tensor_model_parallel_world_size();
if ((total_num_kv_heads < tp_size) || (0 != (total_num_kv_heads % tp_size))) {
throw std::runtime_error("infinilm::layers::attention::Attention: num_key_value_heads must be divisible by tp_size");
}

num_attention_heads_ = total_num_heads / tp_size;
num_key_value_heads_ = total_num_kv_heads / tp_size;
num_key_value_heads_ = total_num_kv_heads < tp_size ? 1 : total_num_kv_heads / tp_size;

auto quant_scheme = model_config->get_quant_scheme();
auto quantization_method = model_config->get_quantization_method();
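The constructor previously rejected any model whose KV head count was smaller than, or not divisible by, tp_size; it now falls back to one KV head per rank. A minimal before/after sketch, using local stand-in functions rather than the real constructor:

#include <cstddef>
#include <stdexcept>

// Old behavior: reject whenever KV heads cannot be evenly sharded.
std::size_t kv_heads_per_rank_old(std::size_t total_num_kv_heads, std::size_t tp_size) {
    if (total_num_kv_heads < tp_size || total_num_kv_heads % tp_size != 0) {
        throw std::runtime_error("num_key_value_heads must be divisible by tp_size");
    }
    return total_num_kv_heads / tp_size;
}

// New behavior: fewer KV heads than ranks means each rank keeps one replica.
std::size_t kv_heads_per_rank_new(std::size_t total_num_kv_heads, std::size_t tp_size) {
    return total_num_kv_heads < tp_size ? 1 : total_num_kv_heads / tp_size;
}

int main() {
    // A GQA model with 4 KV heads on tp_size = 8: the old path throws,
    // the new path gives each rank a single (shared) KV head.
    return kv_heads_per_rank_new(4, 8) == 1 ? 0 : 1;
}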
16 changes: 7 additions & 9 deletions csrc/layers/linear/fused_linear.cpp
@@ -95,7 +95,7 @@ QKVParallelLinear::QKVParallelLinear(size_t hidden_size,
engine::distributed::RankInfo rank_info)
: infinicore::nn::ColumnParallelLinear(
hidden_size,
num_q_head * q_dim + num_k_head * k_dim + num_v_head * v_dim,
calculate_out_feature_size(num_q_head, q_dim, num_k_head, k_dim, num_v_head, v_dim, rank_info),
quantization,
(q_bias || k_bias || v_bias),
dtype,
@@ -110,18 +110,16 @@ QKVParallelLinear::QKVParallelLinear(size_t hidden_size,
num_v_head_(num_v_head),
q_bias_(q_bias),
k_bias_(k_bias),
v_bias_(v_bias) {
if (num_q_head % tp_size_ != 0 || num_k_head % tp_size_ != 0 || num_v_head % tp_size_ != 0) {
throw std::runtime_error("QKVParallelLinear: num_[q|k|v]_head must be divisible by tp_size");
}
v_bias_(v_bias),
num_kv_head_replicas_(calculate_kv_replicas(num_k_head, rank_info.tp_size)) {

if ((q_bias_ != k_bias_) || (k_bias_ != v_bias_)) {
throw std::runtime_error("q_bias, k_bias, v_bias must all match");
}

q_out_size_ = num_q_head_ * q_dim_ / tp_size_;
k_out_size_ = num_k_head_ * k_dim_ / tp_size_;
v_out_size_ = num_v_head_ * v_dim_ / tp_size_;
k_out_size_ = num_kv_head_replicas_ * num_k_head_ * k_dim_ / tp_size_;
v_out_size_ = num_kv_head_replicas_ * num_v_head_ * v_dim_ / tp_size_;
}

std::tuple<infinicore::Tensor, infinicore::Tensor, infinicore::Tensor>
@@ -144,13 +142,13 @@ infinicore::nn::Parameter QKVParallelLinear::get_q_weight() const {
infinicore::nn::Parameter QKVParallelLinear::get_k_weight() const {
return infinicore::nn::Parameter(
weight_->narrow({{0, q_out_size_, k_out_size_}}),
0, tp_rank_, tp_size_);
0, tp_rank_, tp_size_, num_k_head_);
}

infinicore::nn::Parameter QKVParallelLinear::get_v_weight() const {
return infinicore::nn::Parameter(
weight_->narrow({{0, q_out_size_ + k_out_size_, v_out_size_}}),
0, tp_rank_, tp_size_);
0, tp_rank_, tp_size_, num_v_head_);
}

infinicore::nn::Parameter QKVParallelLinear::get_q_weight_scale() const {
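With replication, the K/V output slices are scaled by num_kv_head_replicas_ so that each rank still ends up with at least one full head after the division by tp_size. The extra num_k_head_/num_v_head_ argument now passed to infinicore::nn::Parameter presumably lets the weight loader map replicated ranks back to the same source head; the worked check below covers only the size arithmetic, under an assumed configuration (32 query heads, 2 KV heads, head dim 128, tp_size 8; plain arithmetic, not the real class):

#include <cassert>
#include <cstddef>

int main() {
    std::size_t num_q_head = 32, num_k_head = 2, num_v_head = 2;
    std::size_t q_dim = 128, k_dim = 128, v_dim = 128;
    std::size_t tp_size = 8;

    // calculate_kv_replicas: tp_size % num_k_head == 0, so 8 / 2 = 4 replicas.
    std::size_t replicas = (tp_size + num_k_head - 1) / num_k_head;

    std::size_t q_out = num_q_head * q_dim / tp_size;            // 4096 / 8 = 512
    std::size_t k_out = replicas * num_k_head * k_dim / tp_size; // 1024 / 8 = 128, one full head
    std::size_t v_out = replicas * num_v_head * v_dim / tp_size; // 1024 / 8 = 128, one full head

    assert(q_out == 512 && k_out == 128 && v_out == 128);
    return 0;
}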
18 changes: 18 additions & 0 deletions csrc/layers/linear/fused_linear.hpp
@@ -81,6 +81,22 @@ class QKVParallelLinear : public infinicore::nn::ColumnParallelLinear {
bool has_k_bias() const;
bool has_v_bias() const;

private:
static size_t calculate_kv_replicas(size_t num_k_head, size_t tp_size) {
if (num_k_head % tp_size == 0) {
return 1;
}
if (tp_size % num_k_head == 0) {
return (tp_size + num_k_head - 1) / num_k_head;
}
throw std::runtime_error("Invalid KV head configuration");
}

static size_t
calculate_out_feature_size(size_t num_q_head, size_t q_dim, size_t num_k_head, size_t k_dim, size_t num_v_head, size_t v_dim, engine::distributed::RankInfo rank_info) {
return num_q_head * q_dim + num_k_head * k_dim * calculate_kv_replicas(num_k_head, rank_info.tp_size) + num_v_head * v_dim * calculate_kv_replicas(num_v_head, rank_info.tp_size);
}

private:
size_t q_dim_;
size_t k_dim_;
@@ -94,6 +110,8 @@ class QKVParallelLinear : public infinicore::nn::ColumnParallelLinear {
size_t q_out_size_; // num_q_head * q_dim / tp_size
size_t k_out_size_; // num_kv_head_replicas_ * num_k_head * k_dim / tp_size
size_t v_out_size_; // num_kv_head_replicas_ * num_v_head * v_dim / tp_size

size_t num_kv_head_replicas_ = 1;
};

class GateUpParallelLinear : public infinicore::nn::ColumnParallelLinear {
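The two new helpers are small enough to check in isolation. The sketch below copies calculate_kv_replicas verbatim from the header and verifies that the global out-feature size it feeds into calculate_out_feature_size is consistent with the per-rank q/k/v slices computed in fused_linear.cpp; the main() and its numbers are illustrative:

#include <cassert>
#include <cstddef>
#include <stdexcept>

std::size_t calculate_kv_replicas(std::size_t num_k_head, std::size_t tp_size) {
    if (num_k_head % tp_size == 0) {
        return 1; // evenly sharded: each rank gets num_k_head / tp_size heads
    }
    if (tp_size % num_k_head == 0) {
        return (tp_size + num_k_head - 1) / num_k_head; // tp_size / num_k_head copies
    }
    throw std::runtime_error("Invalid KV head configuration");
}

int main() {
    std::size_t num_q_head = 32, num_k_head = 2, num_v_head = 2;
    std::size_t q_dim = 128, k_dim = 128, v_dim = 128, tp_size = 8;

    std::size_t r = calculate_kv_replicas(num_k_head, tp_size); // 4
    std::size_t out = num_q_head * q_dim
                    + num_k_head * k_dim * r
                    + num_v_head * v_dim * r;                   // 4096 + 1024 + 1024 = 6144
    // Dividing across 8 ranks reproduces the per-rank slice sizes from the .cpp:
    assert(out / tp_size == 512 + 128 + 128);
    return 0;
}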