Skip to content

Commit 4e53d3f

Browse files
authored
Issue/1127: kv replica when tp_size > num_kv_heads (#1128)
1 parent fa3a233 commit 4e53d3f

2 files changed

Lines changed: 60 additions & 21 deletions

File tree

include/infinicore/nn/parameter.hpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,23 +10,28 @@ class Parameter : public Tensor {
1010
Parameter(const Tensor &tensor,
1111
Size tp_dim = 0,
1212
Size tp_rank = 0,
13-
Size tp_size = 1);
13+
Size tp_size = 1,
14+
Size num_shards = 0);
1415

1516
Parameter(const Shape &shape,
1617
const DataType &dtype,
1718
const Device &device,
1819
Size tp_dim = 0,
1920
Size tp_rank = 0,
20-
Size tp_size = 1);
21+
Size tp_size = 1,
22+
Size num_shards = 0);
23+
24+
Parameter(const Parameter &other);
2125

2226
void load_blob(const void *data);
2327

2428
void load(const Tensor &tensor);
2529

2630
protected:
2731
// Tensor parallel configs
28-
Size tp_dim_; // dimension partitioned
29-
Size tp_rank_; // rank of this partition among tp group
30-
Size tp_size_; // total number of partitions
32+
Size tp_dim_; // dimension partitioned
33+
Size tp_rank_; // rank of this partition among tp group
34+
Size tp_size_; // total number of partitions
35+
Size num_shards_ = 0; // number of logical shards, used when tp_size > num_kv_head
3136
};
3237
} // namespace infinicore::nn

src/infinicore/nn/parameter.cc

Lines changed: 50 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,22 @@ Parameter::Parameter()
1010
: Tensor() {
1111
}
1212

13-
inline Shape get_partipion_shape_(const Shape &shape, Size tp_dim, Size tp_size) {
13+
inline Shape get_partition_shape_(const Shape &shape, Size tp_dim, Size tp_size, Size num_shards) {
1414
if (tp_size <= 1) {
1515
return shape;
1616
}
1717
Shape part_shape = shape;
1818
if (tp_dim < shape.size()) {
19-
if (shape[tp_dim] % tp_size != 0) {
20-
throw std::runtime_error("Tensor dimension " + std::to_string(tp_dim) + " with size " + std::to_string(shape[tp_dim]) + " is not divisible by tensor parallel size " + std::to_string(tp_size) + ".");
19+
Size partition_factor = (num_shards > 0) ? num_shards : tp_size;
20+
if (shape[tp_dim] % partition_factor != 0) {
21+
throw std::runtime_error("Tensor dimension " + std::to_string(tp_dim) + " with size " + std::to_string(shape[tp_dim]) + " is not divisible by " + (num_shards > 0 ? "num_shards " : "tp_size ") + std::to_string(partition_factor) + ".");
2122
}
22-
part_shape[tp_dim] = shape[tp_dim] / tp_size;
23+
part_shape[tp_dim] = shape[tp_dim] / partition_factor;
2324
}
2425
return part_shape;
2526
}
2627

27-
Parameter::Parameter(const Tensor &tensor, Size tp_dim, Size tp_rank, Size tp_size) : Tensor(tensor), tp_dim_(tp_dim), tp_rank_(tp_rank), tp_size_(tp_size) {
28+
Parameter::Parameter(const Tensor &tensor, Size tp_dim, Size tp_rank, Size tp_size, Size num_shards) : Tensor(tensor), tp_dim_(tp_dim), tp_rank_(tp_rank), tp_size_(tp_size), num_shards_(num_shards) {
2829
if (tp_rank_ >= tp_size_) {
2930
throw std::runtime_error("Tensor parallel rank " + std::to_string(tp_rank_) + " must be less than tensor parallel size " + std::to_string(tp_size_) + ".");
3031
}
@@ -36,10 +37,18 @@ Parameter::Parameter(
3637
const Device &device,
3738
Size tp_dim,
3839
Size tp_rank,
39-
Size tp_size)
40-
: Parameter(Tensor::empty(get_partipion_shape_(shape, tp_dim, tp_size), dtype, device, false), tp_dim, tp_rank, tp_size) {
40+
Size tp_size,
41+
Size num_shards)
42+
: Parameter(Tensor::empty(get_partition_shape_(shape, tp_dim, tp_size, num_shards), dtype, device, false), tp_dim, tp_rank, tp_size, num_shards) {
4143
}
4244

45+
Parameter::Parameter(const Parameter &other)
46+
: Tensor(other),
47+
tp_dim_(other.tp_dim_),
48+
tp_rank_(other.tp_rank_),
49+
tp_size_(other.tp_size_),
50+
num_shards_(other.num_shards_) {}
51+
4352
void Parameter::load_blob(const void *data) {
4453
Shape expected_shape = Shape(impl_->shape());
4554
expected_shape[tp_dim_] *= tp_size_;
@@ -49,21 +58,46 @@ void Parameter::load_blob(const void *data) {
4958
}
5059

5160
void Parameter::load(const Tensor &tensor) {
52-
Shape expected_shape = Shape(impl_->shape());
53-
expected_shape[tp_dim_] *= tp_size_;
54-
55-
if (expected_shape != tensor->shape()) {
56-
throw std::runtime_error("Shape mismatch when loading tensor into parameter. Weight: " + impl_->info() + ", Tensor: " + tensor->info() + ".");
57-
}
5861
if (impl_->dtype() != tensor->dtype()) {
5962
throw std::runtime_error("Dtype mismatch when loading tensor into parameter. Weight: " + impl_->info() + ", Tensor: " + tensor->info() + ".");
6063
}
61-
if (tp_size_ > 1) {
62-
impl_->copy_from(tensor->narrow({{tp_dim_, tp_rank_ * impl_->size(tp_dim_), impl_->size(tp_dim_)}}));
6364

65+
Shape expected_shape = Shape(impl_->shape());
66+
67+
if (num_shards_ == 0 || num_shards_ >= tp_size_) {
68+
expected_shape[tp_dim_] *= tp_size_;
69+
70+
if (expected_shape != tensor->shape()) {
71+
throw std::runtime_error("Shape mismatch when loading tensor into parameter. Weight: " + impl_->info() + ", Tensor: " + tensor->info() + ".");
72+
}
73+
if (tp_size_ > 1) {
74+
impl_->copy_from(tensor->narrow({{tp_dim_, tp_rank_ * impl_->size(tp_dim_), impl_->size(tp_dim_)}}));
75+
} else {
76+
impl_->copy_from(tensor);
77+
}
6478
} else {
65-
impl_->copy_from(tensor);
79+
if (num_shards_ == 0) {
80+
throw std::runtime_error("num_shards_ is 0 but entered new logic branch!");
81+
}
82+
83+
Size replica_size = tp_size_ / num_shards_;
84+
if (replica_size == 0) {
85+
throw std::runtime_error("replica_size is 0! tp_size_=" + std::to_string(tp_size_) + ", num_shards_=" + std::to_string(num_shards_));
86+
}
87+
88+
Size shard_id = tp_rank_ / replica_size;
89+
Size shard_size = impl_->size(tp_dim_);
90+
Size offset = shard_id * shard_size;
91+
92+
expected_shape[tp_dim_] *= num_shards_;
93+
94+
if (offset + shard_size > tensor->shape()[tp_dim_]) {
95+
throw std::runtime_error("Slice out of bounds! offset=" + std::to_string(offset) + ", shard_size=" + std::to_string(shard_size) + ", tensor_dim=" + std::to_string(tensor->shape()[tp_dim_]));
96+
}
97+
98+
impl_->copy_from(tensor->narrow({{tp_dim_, offset, shard_size}}));
6699
}
100+
67101
infinicore::context::syncStream();
68102
}
69103
} // namespace infinicore::nn

0 commit comments

Comments (0)