11#pragma once
22
33#include " ../ops.hpp"
4+ #include " ../quantization.hpp"
45#include " module.hpp"
56#include < infiniccl.h>
7+ #include < optional>
68
79namespace infinicore ::nn {
810
@@ -11,6 +13,9 @@ class BaseLinear : public Module {
1113 BaseLinear (size_t in_features, size_t out_features, bool bias = true ,
1214 const DataType &dtype = DataType::F32, const Device &device = Device());
1315
16+ BaseLinear (size_t in_features, size_t out_features, std::shared_ptr<infinicore::quantization::BaseQuantization> quantization, bool bias = true ,
17+ const DataType &dtype = DataType::F32, const Device &device = Device());
18+
1419 // Forward pass: output = input @ weight.T + bias
1520 Tensor forward (Tensor &input) const ;
1621
@@ -27,12 +32,17 @@ class BaseLinear : public Module {
2732 // Accessors for parameters
2833 Tensor weight () const { return weight_; }
2934 Tensor bias () const { return bias_; }
35+ Tensor weight_scale () const { return weight_scale_; }
36+ Tensor weight_zeros () const { return weight_zeros_; }
3037
3138protected:
3239 // Parameters
3340 INFINICORE_NN_PARAMETER (weight);
3441 INFINICORE_NN_PARAMETER (bias);
3542
43+ INFINICORE_NN_PARAMETER (weight_scale);
44+ INFINICORE_NN_PARAMETER (weight_zeros);
45+
3646protected:
3747 // Helper method for common forward computation
3848 Tensor compute_linear (Tensor &input) const ;
@@ -41,6 +51,7 @@ class BaseLinear : public Module {
4151 size_t out_features_;
4252 bool has_bias_;
4353 DataType dtype_;
54+ std::shared_ptr<infinicore::quantization::BaseQuantization> quantization_ = std::make_shared<infinicore::quantization::NoneQuantization>(nullptr );
4455};
4556
4657} // namespace infinicore::nn
@@ -52,6 +63,9 @@ class Linear : public BaseLinear {
5263 Linear (size_t in_features, size_t out_features, bool bias = true ,
5364 const DataType &dtype = DataType::F32, const Device &device = Device());
5465
66+ Linear (size_t in_features, size_t out_features, std::shared_ptr<infinicore::quantization::BaseQuantization> quantization, bool bias = true ,
67+ const DataType &dtype = DataType::F32, const Device &device = Device());
68+
5569 // Forward pass: output = input @ weight.T + bias
5670 Tensor forward (Tensor &input) const ;
5771
@@ -65,6 +79,10 @@ class ColumnParallelLinear : public BaseLinear {
6579 const DataType &dtype = DataType::F32, const Device &device = Device(),
6680 Size tp_rank = 0 , Size tp_size = 1 );
6781
82+ ColumnParallelLinear (size_t in_features, size_t out_features, std::shared_ptr<infinicore::quantization::BaseQuantization> quantization, bool bias = true ,
83+ const DataType &dtype = DataType::F32, const Device &device = Device(),
84+ Size tp_rank = 0 , Size tp_size = 1 );
85+
6886 // Forward pass: output = input @ weight.T + bias
6987 Tensor forward (Tensor &input) const ;
7088
@@ -82,6 +100,10 @@ class RowParallelLinear : public BaseLinear {
82100 const DataType &dtype = DataType::F32, const Device &device = Device(),
83101 Size tp_rank = 0 , Size tp_size = 1 , infinicclComm_t communicator = nullptr );
84102
103+ RowParallelLinear (size_t in_features, size_t out_features, std::shared_ptr<infinicore::quantization::BaseQuantization> quantization, bool bias = true ,
104+ const DataType &dtype = DataType::F32, const Device &device = Device(),
105+ Size tp_rank = 0 , Size tp_size = 1 , infinicclComm_t communicator = nullptr );
106+
85107 // Forward pass: output = input @ weight.T + bias
86108 Tensor forward (Tensor &input) const ;
87109