|
25 | 25 | * } |
26 | 26 | */ |
27 | 27 |
|
| 28 | +// ========================================================================= |
| 29 | +// Internal Helpers (Private Macros to reduce duplication) |
| 30 | +// ========================================================================= |
| 31 | + |
/**
 * @brief Switch cases shared by every elementwise CPU operator (F16 and F32).
 *
 * Expands to two `case` labels. Each one returns the result of
 * `_device_info->calculate<Op, T>(_info, output, inputs, stream)`, with T
 * being `fp16_t` for INFINI_DTYPE_F16 and `float` for INFINI_DTYPE_F32.
 * It is meant to be expanded inside a `switch` where `Op`, `_device_info`,
 * `_info`, `output`, `inputs` and `stream` are all in scope.
 *
 * NOTE(review): identifiers that begin with an underscore followed by an
 * uppercase letter are reserved for the implementation in C++; renaming this
 * helper would have to be coordinated with every macro that uses it.
 */
#define _IMPL_CALC_CASES_COMMON                               \
    case INFINI_DTYPE_F16:                                    \
        return _device_info->template calculate<Op, fp16_t>(  \
            _info, output, inputs, stream);                   \
    case INFINI_DTYPE_F32:                                    \
        return _device_info->template calculate<Op, float>(   \
            _info, output, inputs, stream);
| 40 | + |
/**
 * @brief Extended switch cases: the common F16/F32 cases plus F64 and BF16.
 *
 * Expands `_IMPL_CALC_CASES_COMMON` first and then appends two further `case`
 * labels that dispatch to `_device_info->calculate<Op, double>` for
 * INFINI_DTYPE_F64 and `_device_info->calculate<Op, bf16_t>` for
 * INFINI_DTYPE_BF16. Must be expanded inside a `switch`, under the same scope
 * requirements as the common cases.
 */
#define _IMPL_CALC_CASES_EXTENDED                             \
    _IMPL_CALC_CASES_COMMON                                   \
    case INFINI_DTYPE_F64:                                    \
        return _device_info->template calculate<Op, double>(  \
            _info, output, inputs, stream);                   \
    case INFINI_DTYPE_BF16:                                   \
        return _device_info->template calculate<Op, bf16_t>(  \
            _info, output, inputs, stream);
| 50 | + |
/**
 * @brief Generic template for the `Descriptor::calculate` definition.
 *
 * Emits the out-of-class definition of `Descriptor::calculate`, whose body is
 * a `switch` on `_dtype`. The supplied case list is pasted into that switch
 * and is followed by a `default` branch that returns
 * INFINI_STATUS_BAD_TENSOR_DTYPE. The `workspace` and `workspace_size`
 * parameters are part of the signature but are not referenced by the
 * generated body.
 *
 * @param CASES_MACRO Macro (or token sequence) that expands to the `case`
 *                    labels to place inside the switch, e.g.
 *                    `_IMPL_CALC_CASES_COMMON`.
 */
#define _IMPL_CALCULATE_METHOD(CASES_MACRO)                                \
    infiniStatus_t Descriptor::calculate(void *workspace,                  \
                                         size_t workspace_size,            \
                                         void *output,                     \
                                         std::vector<const void *> inputs, \
                                         void *stream) const {             \
        switch (_dtype) {                                                  \
            CASES_MACRO                                                    \
        default:                                                           \
            return INFINI_STATUS_BAD_TENSOR_DTYPE;                         \
        }                                                                  \
    }
| 68 | + |
/**
 * @brief Generic template for the Descriptor destructor and `create` method.
 *
 * Expands to the defaulted `Descriptor::~Descriptor()` followed by the
 * out-of-class definition of `Descriptor::create`. The generated `create`:
 *   1. casts `handle_` to `device::cpu::Handle *`;
 *   2. reads the output dtype and binds `out_shape` to the output shape;
 *   3. validates the dtype via `CHECK_DTYPE(dtype, <allowed dtypes>)`;
 *   4. runs SHAPE_CHECK_BLOCK;
 *   5. invokes `CREATE_ELEMENTWISE_CPU_DESCRIPTOR` and returns
 *      INFINI_STATUS_SUCCESS.
 *
 * The dtype check intentionally runs before the shape check. That is the
 * order the hand-written binary/unary implementations used before they were
 * folded into this helper, so when both checks would fail the caller keeps
 * seeing the dtype failure first, and the input descriptors are not touched
 * for an unsupported dtype.
 * NOTE(review): this assumes CHECK_DTYPE / CHECK_SAME_SHAPE return early from
 * `create` on failure - confirm against their definitions.
 *
 * @param SHAPE_CHECK_BLOCK Statements pasted verbatim into `create`; they may
 *        refer to `out_shape`, `out_desc` and `input_desc_vec`. The block is
 *        a single macro argument, so it must not contain a comma outside of
 *        parentheses (such a comma would split it into several arguments).
 * @param ... One or more dtypes accepted by the operator, forwarded to
 *        CHECK_DTYPE.
 */
#define _IMPL_CREATE_METHOD(SHAPE_CHECK_BLOCK, ...)                                 \
    Descriptor::~Descriptor() = default;                                            \
    infiniStatus_t Descriptor::create(                                              \
        infiniopHandle_t handle_,                                                   \
        Descriptor **desc_ptr,                                                      \
        infiniopTensorDescriptor_t out_desc,                                        \
        std::vector<infiniopTensorDescriptor_t> input_desc_vec) {                   \
        auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);             \
        auto dtype = out_desc->dtype();                                             \
        const auto &out_shape = out_desc->shape();                                  \
        CHECK_DTYPE(dtype, __VA_ARGS__);                                            \
        SHAPE_CHECK_BLOCK                                                           \
        CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec); \
        return INFINI_STATUS_SUCCESS;                                               \
    }
| 89 | + |
| 90 | +// ========================================================================= |
| 91 | +// Public API Implementation Macros |
| 92 | +// ========================================================================= |
| 93 | + |
| 94 | +/** |
| 95 | + * @brief Implementation for Binary Operators (F16, F32) |
30 | 96 | * |
31 | 97 | * This macro generates the Descriptor destructor, create, and calculate methods |
32 | 98 | * for binary operators, using the generic implementation. |
|
37 | 103 | * ELEMENTWISE_CPU_IMPL_BINARY(pow) |
38 | 104 | * } |
39 | 105 | */ |
// Expands to the Descriptor destructor and create() (via _IMPL_CREATE_METHOD)
// and to calculate() (via _IMPL_CALCULATE_METHOD) for a two-input operator.
// create() requires the output and both inputs to pass CHECK_SAME_SHAPE and
// accepts F16 / F32 only. The OP argument is not referenced by the expansion.
#define ELEMENTWISE_CPU_IMPL_BINARY(OP)                          \
    _IMPL_CREATE_METHOD(                                         \
        const auto &a_shape = input_desc_vec.at(0)->shape();     \
        const auto &b_shape = input_desc_vec.at(1)->shape();     \
        CHECK_SAME_SHAPE(out_shape, a_shape, b_shape);           \
        ,                                                        \
        INFINI_DTYPE_F16, INFINI_DTYPE_F32)                      \
    _IMPL_CALCULATE_METHOD(_IMPL_CALC_CASES_COMMON)
79 | 116 |
|
80 | 117 | /** |
81 | | - * @brief Macro to generate unary operator implementation. |
| 118 | + * @brief Implementation for Unary Operators (F16, F32) |
82 | 119 | * |
83 | 120 | * This macro generates the Descriptor destructor, create, and calculate methods |
84 | 121 | * for unary operators, using the generic implementation. |
|
89 | 126 | * ELEMENTWISE_CPU_IMPL_UNARY(sqrt) |
90 | 127 | * } |
91 | 128 | */ |
// Expands to the Descriptor destructor and create() (via _IMPL_CREATE_METHOD)
// and to calculate() (via _IMPL_CALCULATE_METHOD) for a single-input operator.
// create() requires the output and the input to pass CHECK_SAME_SHAPE and
// accepts F16 / F32 only. The OP argument is not referenced by the expansion.
#define ELEMENTWISE_CPU_IMPL_UNARY(OP)                           \
    _IMPL_CREATE_METHOD(                                         \
        const auto &x_shape = input_desc_vec.at(0)->shape();     \
        CHECK_SAME_SHAPE(out_shape, x_shape);                    \
        ,                                                        \
        INFINI_DTYPE_F16, INFINI_DTYPE_F32)                      \
    _IMPL_CALCULATE_METHOD(_IMPL_CALC_CASES_COMMON)
| 137 | + |
/**
 * @brief Unary operator implementation with the extended dtype set
 *        (F16, F32, F64, BF16).
 *
 * Generates the Descriptor destructor plus the `create` and `calculate`
 * methods for a single-input operator. `create` runs CHECK_SAME_SHAPE on the
 * output and input shapes and passes the four dtypes above to CHECK_DTYPE;
 * `calculate` dispatches through `_IMPL_CALC_CASES_EXTENDED`. The OP argument
 * is not referenced by the expansion.
 *
 * Usage:
 *   namespace op::exp::cpu {
 *       using Op = op::elementwise::unary::UnaryOp<UnaryMode::Exp>;
 *       ELEMENTWISE_CPU_IMPL_UNARY_EXTENDED(exp)
 *   }
 */
#define ELEMENTWISE_CPU_IMPL_UNARY_EXTENDED(OP)                  \
    _IMPL_CREATE_METHOD(                                         \
        const auto &x_shape = input_desc_vec.at(0)->shape();     \
        CHECK_SAME_SHAPE(out_shape, x_shape);                    \
        ,                                                        \
        INFINI_DTYPE_F16, INFINI_DTYPE_F32,                      \
        INFINI_DTYPE_F64, INFINI_DTYPE_BF16)                     \
    _IMPL_CALCULATE_METHOD(_IMPL_CALC_CASES_EXTENDED)
129 | 158 |
|
130 | 159 | #endif // __INFINIOP_ELEMENTWISE_CPU_IMPL_H__ |
0 commit comments