Skip to content

Commit 05096ea

Browse files
committed
Issue/888 - Refactor: integrate exp and hardswish operators into unified unary framework.
1 parent dcea337 commit 05096ea

19 files changed

Lines changed: 337 additions & 1113 deletions

File tree

include/infiniop/ops/exp.h

Lines changed: 0 additions & 24 deletions
This file was deleted.

include/infiniop/ops/hardswish.h

Lines changed: 0 additions & 24 deletions
This file was deleted.

include/infiniop/ops/unary_ops_api.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -35,5 +35,7 @@ UNARY_OP_API_DECLARE(erf, Erf)
3535
UNARY_OP_API_DECLARE(atan, Atan)
3636
UNARY_OP_API_DECLARE(acos, Acos)
3737
UNARY_OP_API_DECLARE(ceil, Ceil)
38+
UNARY_OP_API_DECLARE(exp, Exp)
39+
UNARY_OP_API_DECLARE(hardswish, Hardswish)
3840

3941
#endif // __INFINIOP_UNARY_OPS_API_H__

src/infiniop/elementwise/cpu/elementwise_cpu_impl.h

Lines changed: 107 additions & 78 deletions
Original file line number | Diff line number | Diff line change
@@ -25,8 +25,74 @@
2525
* }
2626
*/
2727

28+
// =========================================================================
29+
// Internal Helpers (Private Macros to reduce duplication)
30+
// =========================================================================
31+
32+
/**
33+
* @brief Common Calculate Switch Cases (F16 & F32)
34+
*/
35+
#define _IMPL_CALC_CASES_COMMON \
36+
case INFINI_DTYPE_F16: \
37+
return _device_info->template calculate<Op, fp16_t>(_info, output, inputs, stream); \
38+
case INFINI_DTYPE_F32: \
39+
return _device_info->template calculate<Op, float>(_info, output, inputs, stream);
40+
2841
/**
29-
* @brief Macro to generate binary operator implementation.
42+
* @brief Extended Calculate Switch Cases (Adds F64 & BF16)
43+
*/
44+
#define _IMPL_CALC_CASES_EXTENDED \
45+
_IMPL_CALC_CASES_COMMON \
46+
case INFINI_DTYPE_F64: \
47+
return _device_info->template calculate<Op, double>(_info, output, inputs, stream); \
48+
case INFINI_DTYPE_BF16: \
49+
return _device_info->template calculate<Op, bf16_t>(_info, output, inputs, stream);
50+
51+
/**
52+
* @brief Generic macro that expands to the Descriptor::calculate definition
53+
* @param CASES_MACRO The macro containing the switch cases to use
54+
*/
55+
#define _IMPL_CALCULATE_METHOD(CASES_MACRO) \
56+
infiniStatus_t Descriptor::calculate( \
57+
void *workspace, \
58+
size_t workspace_size, \
59+
void *output, \
60+
std::vector<const void *> inputs, \
61+
void *stream) const { \
62+
switch (_dtype) { \
63+
CASES_MACRO \
64+
default: \
65+
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
66+
} \
67+
}
68+
69+
/**
70+
* @brief Generic macro that expands to the defaulted Descriptor destructor and the Descriptor::create definition
71+
* @param SHAPE_CHECK_BLOCK Statements pasted into create() before CHECK_DTYPE to validate shapes; must not contain unparenthesized commas
72+
* @param ... Variadic arguments for allowed data types in CHECK_DTYPE
73+
*/
74+
#define _IMPL_CREATE_METHOD(SHAPE_CHECK_BLOCK, ...) \
75+
Descriptor::~Descriptor() = default; \
76+
infiniStatus_t Descriptor::create( \
77+
infiniopHandle_t handle_, \
78+
Descriptor **desc_ptr, \
79+
infiniopTensorDescriptor_t out_desc, \
80+
std::vector<infiniopTensorDescriptor_t> input_desc_vec) { \
81+
auto handle = reinterpret_cast<device::cpu::Handle *>(handle_); \
82+
auto dtype = out_desc->dtype(); \
83+
const auto &out_shape = out_desc->shape(); \
84+
SHAPE_CHECK_BLOCK \
85+
CHECK_DTYPE(dtype, __VA_ARGS__); \
86+
CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec); \
87+
return INFINI_STATUS_SUCCESS; \
88+
}
89+
90+
// =========================================================================
91+
// Public API Implementation Macros
92+
// =========================================================================
93+
94+
/**
95+
* @brief Implementation for Binary Operators (F16, F32)
3096
*
3197
* This macro generates the Descriptor destructor, create, and calculate methods
3298
* for binary operators, using the generic implementation.
@@ -37,48 +103,19 @@
37103
* ELEMENTWISE_CPU_IMPL_BINARY(pow)
38104
* }
39105
*/
40-
#define ELEMENTWISE_CPU_IMPL_BINARY(OP) \
41-
\
42-
Descriptor::~Descriptor() = default; \
43-
\
44-
infiniStatus_t Descriptor::create( \
45-
infiniopHandle_t handle_, \
46-
Descriptor **desc_ptr, \
47-
infiniopTensorDescriptor_t out_desc, \
48-
std::vector<infiniopTensorDescriptor_t> input_desc_vec) { \
49-
auto handle = reinterpret_cast<device::cpu::Handle *>(handle_); \
50-
auto dtype = out_desc->dtype(); \
51-
const auto &a_desc = input_desc_vec.at(0); \
52-
const auto &b_desc = input_desc_vec.at(1); \
53-
const auto &out_shape = out_desc->shape(); \
54-
const auto &a_shape = a_desc->shape(); \
55-
const auto &b_shape = b_desc->shape(); \
56-
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32); \
57-
CHECK_SAME_SHAPE(out_shape, a_shape, b_shape); \
58-
CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec); \
59-
return INFINI_STATUS_SUCCESS; \
60-
} \
61-
\
62-
infiniStatus_t Descriptor::calculate( \
63-
void *workspace, \
64-
size_t workspace_size, \
65-
void *output, \
66-
std::vector<const void *> inputs, \
67-
void *stream) const { \
68-
switch (_dtype) { \
69-
case INFINI_DTYPE_F16: \
70-
return _device_info->template calculate<Op, fp16_t>( \
71-
_info, output, inputs, stream); \
72-
case INFINI_DTYPE_F32: \
73-
return _device_info->template calculate<Op, float>( \
74-
_info, output, inputs, stream); \
75-
default: \
76-
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
77-
} \
78-
}
106+
#define ELEMENTWISE_CPU_IMPL_BINARY(OP) \
107+
_IMPL_CREATE_METHOD( \
108+
const auto &a_desc = input_desc_vec.at(0); \
109+
const auto &b_desc = input_desc_vec.at(1); \
110+
const auto &a_shape = a_desc->shape(); \
111+
const auto &b_shape = b_desc->shape(); \
112+
CHECK_SAME_SHAPE(out_shape, a_shape, b_shape);, \
113+
INFINI_DTYPE_F16, INFINI_DTYPE_F32 \
114+
) \
115+
_IMPL_CALCULATE_METHOD(_IMPL_CALC_CASES_COMMON)
79116

80117
/**
81-
* @brief Macro to generate unary operator implementation.
118+
* @brief Implementation for Unary Operators (F16, F32)
82119
*
83120
* This macro generates the Descriptor destructor, create, and calculate methods
84121
* for unary operators, using the generic implementation.
@@ -89,42 +126,34 @@
89126
* ELEMENTWISE_CPU_IMPL_UNARY(sqrt)
90127
* }
91128
*/
92-
#define ELEMENTWISE_CPU_IMPL_UNARY(OP) \
93-
\
94-
Descriptor::~Descriptor() = default; \
95-
\
96-
infiniStatus_t Descriptor::create( \
97-
infiniopHandle_t handle_, \
98-
Descriptor **desc_ptr, \
99-
infiniopTensorDescriptor_t out_desc, \
100-
std::vector<infiniopTensorDescriptor_t> input_desc_vec) { \
101-
auto handle = reinterpret_cast<device::cpu::Handle *>(handle_); \
102-
auto dtype = out_desc->dtype(); \
103-
const auto &x_desc = input_desc_vec.at(0); \
104-
const auto &y_shape = out_desc->shape(); \
105-
const auto &x_shape = x_desc->shape(); \
106-
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32); \
107-
CHECK_SAME_SHAPE(y_shape, x_shape); \
108-
CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec); \
109-
return INFINI_STATUS_SUCCESS; \
110-
} \
111-
\
112-
infiniStatus_t Descriptor::calculate( \
113-
void *workspace, \
114-
size_t workspace_size, \
115-
void *output, \
116-
std::vector<const void *> inputs, \
117-
void *stream) const { \
118-
switch (_dtype) { \
119-
case INFINI_DTYPE_F16: \
120-
return _device_info->template calculate<Op, fp16_t>( \
121-
_info, output, inputs, stream); \
122-
case INFINI_DTYPE_F32: \
123-
return _device_info->template calculate<Op, float>( \
124-
_info, output, inputs, stream); \
125-
default: \
126-
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
127-
} \
128-
}
129+
#define ELEMENTWISE_CPU_IMPL_UNARY(OP) \
130+
_IMPL_CREATE_METHOD( \
131+
const auto &x_desc = input_desc_vec.at(0); \
132+
const auto &x_shape = x_desc->shape(); \
133+
CHECK_SAME_SHAPE(out_shape, x_shape);, \
134+
INFINI_DTYPE_F16, INFINI_DTYPE_F32 \
135+
) \
136+
_IMPL_CALCULATE_METHOD(_IMPL_CALC_CASES_COMMON)
137+
138+
/**
139+
* @brief Implementation for Unary Operators Extended (F16, F32, F64, BF16)
140+
*
141+
* This macro generates the Descriptor destructor, create, and calculate methods
142+
* for unary operators supporting F16, F32, F64, and BF16 data types.
143+
*
144+
* Usage:
145+
* namespace op::exp::cpu {
146+
* using Op = op::elementwise::unary::UnaryOp<UnaryMode::Exp>;
147+
* ELEMENTWISE_CPU_IMPL_UNARY_EXTENDED(exp)
148+
* }
149+
*/
150+
#define ELEMENTWISE_CPU_IMPL_UNARY_EXTENDED(OP) \
151+
_IMPL_CREATE_METHOD( \
152+
const auto &x_desc = input_desc_vec.at(0); \
153+
const auto &x_shape = x_desc->shape(); \
154+
CHECK_SAME_SHAPE(out_shape, x_shape);, \
155+
INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16 \
156+
) \
157+
_IMPL_CALCULATE_METHOD(_IMPL_CALC_CASES_EXTENDED)
129158

130159
#endif // __INFINIOP_ELEMENTWISE_CPU_IMPL_H__

0 commit comments

Comments
 (0)