@@ -10,21 +10,22 @@ Parameter::Parameter()
1010 : Tensor() {
1111}
1212
13- inline Shape get_partipion_shape_ (const Shape &shape, Size tp_dim, Size tp_size) {
13+ inline Shape get_partition_shape_ (const Shape &shape, Size tp_dim, Size tp_size, Size num_shards ) {
1414 if (tp_size <= 1 ) {
1515 return shape;
1616 }
1717 Shape part_shape = shape;
1818 if (tp_dim < shape.size ()) {
19- if (shape[tp_dim] % tp_size != 0 ) {
20- throw std::runtime_error (" Tensor dimension " + std::to_string (tp_dim) + " with size " + std::to_string (shape[tp_dim]) + " is not divisible by tensor parallel size " + std::to_string (tp_size) + " ." );
19+ Size partition_factor = (num_shards > 0 ) ? num_shards : tp_size;
20+ if (shape[tp_dim] % partition_factor != 0 ) {
21+ throw std::runtime_error (" Tensor dimension " + std::to_string (tp_dim) + " with size " + std::to_string (shape[tp_dim]) + " is not divisible by " + (num_shards > 0 ? " num_shards " : " tp_size " ) + std::to_string (partition_factor) + " ." );
2122 }
22- part_shape[tp_dim] = shape[tp_dim] / tp_size ;
23+ part_shape[tp_dim] = shape[tp_dim] / partition_factor ;
2324 }
2425 return part_shape;
2526}
2627
27- Parameter::Parameter (const Tensor &tensor, Size tp_dim, Size tp_rank, Size tp_size) : Tensor(tensor), tp_dim_(tp_dim), tp_rank_(tp_rank), tp_size_(tp_size) {
28+ Parameter::Parameter (const Tensor &tensor, Size tp_dim, Size tp_rank, Size tp_size, Size num_shards ) : Tensor(tensor), tp_dim_(tp_dim), tp_rank_(tp_rank), tp_size_(tp_size), num_shards_(num_shards ) {
2829 if (tp_rank_ >= tp_size_) {
2930 throw std::runtime_error (" Tensor parallel rank " + std::to_string (tp_rank_) + " must be less than tensor parallel size " + std::to_string (tp_size_) + " ." );
3031 }
@@ -36,10 +37,18 @@ Parameter::Parameter(
3637 const Device &device,
3738 Size tp_dim,
3839 Size tp_rank,
39- Size tp_size)
40- : Parameter(Tensor::empty(get_partipion_shape_(shape, tp_dim, tp_size), dtype, device, false ), tp_dim, tp_rank, tp_size) {
40+ Size tp_size,
41+ Size num_shards)
42+ : Parameter(Tensor::empty(get_partition_shape_(shape, tp_dim, tp_size, num_shards), dtype, device, false ), tp_dim, tp_rank, tp_size, num_shards) {
4143}
4244
45+ Parameter::Parameter (const Parameter &other)
46+ : Tensor(other),
47+ tp_dim_ (other.tp_dim_),
48+ tp_rank_(other.tp_rank_),
49+ tp_size_(other.tp_size_),
50+ num_shards_(other.num_shards_) {}
51+
4352void Parameter::load_blob (const void *data) {
4453 Shape expected_shape = Shape (impl_->shape ());
4554 expected_shape[tp_dim_] *= tp_size_;
@@ -49,21 +58,46 @@ void Parameter::load_blob(const void *data) {
4958}
5059
5160void Parameter::load (const Tensor &tensor) {
52- Shape expected_shape = Shape (impl_->shape ());
53- expected_shape[tp_dim_] *= tp_size_;
54-
55- if (expected_shape != tensor->shape ()) {
56- throw std::runtime_error (" Shape mismatch when loading tensor into parameter. Weight: " + impl_->info () + " , Tensor: " + tensor->info () + " ." );
57- }
5861 if (impl_->dtype () != tensor->dtype ()) {
5962 throw std::runtime_error (" Dtype mismatch when loading tensor into parameter. Weight: " + impl_->info () + " , Tensor: " + tensor->info () + " ." );
6063 }
61- if (tp_size_ > 1 ) {
62- impl_->copy_from (tensor->narrow ({{tp_dim_, tp_rank_ * impl_->size (tp_dim_), impl_->size (tp_dim_)}}));
6364
65+ Shape expected_shape = Shape (impl_->shape ());
66+
67+ if (num_shards_ == 0 || num_shards_ >= tp_size_) {
68+ expected_shape[tp_dim_] *= tp_size_;
69+
70+ if (expected_shape != tensor->shape ()) {
71+ throw std::runtime_error (" Shape mismatch when loading tensor into parameter. Weight: " + impl_->info () + " , Tensor: " + tensor->info () + " ." );
72+ }
73+ if (tp_size_ > 1 ) {
74+ impl_->copy_from (tensor->narrow ({{tp_dim_, tp_rank_ * impl_->size (tp_dim_), impl_->size (tp_dim_)}}));
75+ } else {
76+ impl_->copy_from (tensor);
77+ }
6478 } else {
65- impl_->copy_from (tensor);
79+ if (num_shards_ == 0 ) {
80+ throw std::runtime_error (" num_shards_ is 0 but entered new logic branch!" );
81+ }
82+
83+ Size replica_size = tp_size_ / num_shards_;
84+ if (replica_size == 0 ) {
85+ throw std::runtime_error (" replica_size is 0! tp_size_=" + std::to_string (tp_size_) + " , num_shards_=" + std::to_string (num_shards_));
86+ }
87+
88+ Size shard_id = tp_rank_ / replica_size;
89+ Size shard_size = impl_->size (tp_dim_);
90+ Size offset = shard_id * shard_size;
91+
92+ expected_shape[tp_dim_] *= num_shards_;
93+
94+ if (offset + shard_size > tensor->shape ()[tp_dim_]) {
95+ throw std::runtime_error (" Slice out of bounds! offset=" + std::to_string (offset) + " , shard_size=" + std::to_string (shard_size) + " , tensor_dim=" + std::to_string (tensor->shape ()[tp_dim_]));
96+ }
97+
98+ impl_->copy_from (tensor->narrow ({{tp_dim_, offset, shard_size}}));
6699 }
100+
67101 infinicore::context::syncStream ();
68102}
69103} // namespace infinicore::nn
0 commit comments