Skip to content

Commit 794b5aa

Browse files
zhuyuegongchensu
authored andcommitted
减少cpu opertaors部分的重复代码
1 parent 9846ab5 commit 794b5aa

28 files changed

Lines changed: 183 additions & 1060 deletions

File tree

src/infiniop/elementwise/binary.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ struct BinaryOp {
227227
ELEMENTWISE_DESCRIPTOR(OP, NAMESPACE) \
228228
\
229229
namespace op::OP::NAMESPACE { \
230-
using Op = op::elementwise::binary::BinaryOp<MODE>; \
230+
using Op = op::elementwise::binary::BinaryOp<MODE>; \
231231
}
232232

233233
/**
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
#ifndef __INFINIOP_ELEMENTWISE_CPU_IMPL_H__
2+
#define __INFINIOP_ELEMENTWISE_CPU_IMPL_H__
3+
4+
#include "elementwise_cpu.h"
5+
#include "../../devices/cpu/common_cpu.h"
6+
#include "../../../utils/check.h"
7+
#include "../../../utils/result.hpp"
8+
9+
/**
10+
* @brief Generic implementation for elementwise CPU operators.
11+
*
12+
* This file provides a generic implementation template that can be used
13+
* by all binary and unary operators to reduce code duplication.
14+
*
15+
* Usage:
16+
* #include "elementwise_cpu_impl.h"
17+
* namespace op::pow::cpu {
18+
* using Op = op::elementwise::binary::BinaryOp<BinaryMode::Pow>;
19+
* ELEMENTWISE_CPU_IMPL_BINARY(pow)
20+
* }
21+
*
22+
* namespace op::sqrt::cpu {
23+
* using Op = op::elementwise::unary::UnaryOp<UnaryMode::Sqrt>;
24+
* ELEMENTWISE_CPU_IMPL_UNARY(sqrt)
25+
* }
26+
*/
27+
28+
29+
/**
30+
* @brief Macro to generate binary operator implementation.
31+
*
32+
* This macro generates the Descriptor destructor, create, and calculate methods
33+
* for binary operators, using the generic implementation.
34+
*
35+
* Usage:
36+
* namespace op::pow::cpu {
37+
* using Op = op::elementwise::binary::BinaryOp<BinaryMode::Pow>;
38+
* ELEMENTWISE_CPU_IMPL_BINARY(pow)
39+
* }
40+
*/
41+
#define ELEMENTWISE_CPU_IMPL_BINARY(OP) \
42+
\
43+
Descriptor::~Descriptor() = default; \
44+
\
45+
infiniStatus_t Descriptor::create( \
46+
infiniopHandle_t handle_, \
47+
Descriptor **desc_ptr, \
48+
infiniopTensorDescriptor_t out_desc, \
49+
std::vector<infiniopTensorDescriptor_t> input_desc_vec) { \
50+
auto handle = reinterpret_cast<device::cpu::Handle *>(handle_); \
51+
auto dtype = out_desc->dtype(); \
52+
const auto &a_desc = input_desc_vec.at(0); \
53+
const auto &b_desc = input_desc_vec.at(1); \
54+
const auto &out_shape = out_desc->shape(); \
55+
const auto &a_shape = a_desc->shape(); \
56+
const auto &b_shape = b_desc->shape(); \
57+
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32); \
58+
CHECK_SAME_SHAPE(out_shape, a_shape, b_shape); \
59+
CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec); \
60+
return INFINI_STATUS_SUCCESS; \
61+
} \
62+
\
63+
infiniStatus_t Descriptor::calculate( \
64+
void *workspace, \
65+
size_t workspace_size, \
66+
void *output, \
67+
std::vector<const void *> inputs, \
68+
void *stream) const { \
69+
switch (_dtype) { \
70+
case INFINI_DTYPE_F16: \
71+
return _device_info->template calculate<Op, fp16_t>( \
72+
_info, output, inputs, stream); \
73+
case INFINI_DTYPE_F32: \
74+
return _device_info->template calculate<Op, float>( \
75+
_info, output, inputs, stream); \
76+
default: \
77+
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
78+
} \
79+
}
80+
81+
/**
82+
* @brief Macro to generate unary operator implementation.
83+
*
84+
* This macro generates the Descriptor destructor, create, and calculate methods
85+
* for unary operators, using the generic implementation.
86+
*
87+
* Usage:
88+
* namespace op::sqrt::cpu {
89+
* using Op = op::elementwise::unary::UnaryOp<UnaryMode::Sqrt>;
90+
* ELEMENTWISE_CPU_IMPL_UNARY(sqrt)
91+
* }
92+
*/
93+
#define ELEMENTWISE_CPU_IMPL_UNARY(OP) \
94+
\
95+
Descriptor::~Descriptor() = default; \
96+
\
97+
infiniStatus_t Descriptor::create( \
98+
infiniopHandle_t handle_, \
99+
Descriptor **desc_ptr, \
100+
infiniopTensorDescriptor_t out_desc, \
101+
std::vector<infiniopTensorDescriptor_t> input_desc_vec) { \
102+
auto handle = reinterpret_cast<device::cpu::Handle *>(handle_); \
103+
auto dtype = out_desc->dtype(); \
104+
const auto &x_desc = input_desc_vec.at(0); \
105+
const auto &y_shape = out_desc->shape(); \
106+
const auto &x_shape = x_desc->shape(); \
107+
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32); \
108+
CHECK_SAME_SHAPE(y_shape, x_shape); \
109+
CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec); \
110+
return INFINI_STATUS_SUCCESS; \
111+
} \
112+
\
113+
infiniStatus_t Descriptor::calculate( \
114+
void *workspace, \
115+
size_t workspace_size, \
116+
void *output, \
117+
std::vector<const void *> inputs, \
118+
void *stream) const { \
119+
switch (_dtype) { \
120+
case INFINI_DTYPE_F16: \
121+
return _device_info->template calculate<Op, fp16_t>( \
122+
_info, output, inputs, stream); \
123+
case INFINI_DTYPE_F32: \
124+
return _device_info->template calculate<Op, float>( \
125+
_info, output, inputs, stream); \
126+
default: \
127+
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
128+
} \
129+
}
130+
131+
#endif // __INFINIOP_ELEMENTWISE_CPU_IMPL_H__

src/infiniop/elementwise/unary.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,7 @@ struct UnaryOp {
508508
ELEMENTWISE_DESCRIPTOR(OP, NAMESPACE) \
509509
\
510510
namespace op::OP::NAMESPACE { \
511-
using Op = op::elementwise::unary::UnaryOp<MODE>; \
511+
using Op = op::elementwise::unary::UnaryOp<MODE>; \
512512
}
513513

514514
} // namespace op::elementwise::unary
Lines changed: 2 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,8 @@
11
#include "abs_cpu.h"
2+
#include "../../../elementwise/cpu/elementwise_cpu_impl.h"
23

34
namespace op::abs::cpu {
45

5-
Descriptor::~Descriptor() = default;
6+
ELEMENTWISE_CPU_IMPL_UNARY(abs)
67

7-
infiniStatus_t Descriptor::create(
8-
infiniopHandle_t handle_,
9-
Descriptor **desc_ptr,
10-
infiniopTensorDescriptor_t out_desc,
11-
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
12-
13-
auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
14-
auto dtype = out_desc->dtype();
15-
16-
const auto &x_desc = input_desc_vec.at(0);
17-
const auto &y_shape = out_desc->shape();
18-
const auto &x_shape = x_desc->shape();
19-
20-
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32);
21-
22-
CHECK_SAME_SHAPE(y_shape, x_shape);
23-
24-
// create CPU elementwise descriptor
25-
CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec);
26-
27-
return INFINI_STATUS_SUCCESS;
28-
}
29-
30-
infiniStatus_t Descriptor::calculate(
31-
void *workspace,
32-
size_t workspace_size,
33-
void *output,
34-
std::vector<const void *> inputs,
35-
void *stream) const {
36-
37-
switch (_dtype) {
38-
case INFINI_DTYPE_F16:
39-
return _device_info->calculate<Op, fp16_t>(_info, output, inputs, stream);
40-
case INFINI_DTYPE_F32:
41-
return _device_info->calculate<Op, float>(_info, output, inputs, stream);
42-
default:
43-
return INFINI_STATUS_BAD_TENSOR_DTYPE;
44-
}
45-
46-
return INFINI_STATUS_SUCCESS;
47-
}
488
} // namespace op::abs::cpu
Lines changed: 2 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,8 @@
11
#include "acos_cpu.h"
2+
#include "../../../elementwise/cpu/elementwise_cpu_impl.h"
23

34
namespace op::acos::cpu {
45

5-
Descriptor::~Descriptor() = default;
6+
ELEMENTWISE_CPU_IMPL_UNARY(acos)
67

7-
infiniStatus_t Descriptor::create(
8-
infiniopHandle_t handle_,
9-
Descriptor **desc_ptr,
10-
infiniopTensorDescriptor_t out_desc,
11-
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
12-
13-
auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
14-
auto dtype = out_desc->dtype();
15-
16-
const auto &x_desc = input_desc_vec.at(0);
17-
const auto &y_shape = out_desc->shape();
18-
const auto &x_shape = x_desc->shape();
19-
20-
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32);
21-
22-
CHECK_SAME_SHAPE(y_shape, x_shape);
23-
24-
// create CPU elementwise descriptor
25-
CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec);
26-
27-
return INFINI_STATUS_SUCCESS;
28-
}
29-
30-
infiniStatus_t Descriptor::calculate(
31-
void *workspace,
32-
size_t workspace_size,
33-
void *output,
34-
std::vector<const void *> inputs,
35-
void *stream) const {
36-
37-
switch (_dtype) {
38-
case INFINI_DTYPE_F16:
39-
return _device_info->calculate<Op, fp16_t>(_info, output, inputs, stream);
40-
case INFINI_DTYPE_F32:
41-
return _device_info->calculate<Op, float>(_info, output, inputs, stream);
42-
default:
43-
return INFINI_STATUS_BAD_TENSOR_DTYPE;
44-
}
45-
46-
return INFINI_STATUS_SUCCESS;
47-
}
488
} // namespace op::acos::cpu
Lines changed: 2 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,8 @@
11
#include "acosh_cpu.h"
2+
#include "../../../elementwise/cpu/elementwise_cpu_impl.h"
23

34
namespace op::acosh::cpu {
45

5-
Descriptor::~Descriptor() = default;
6+
ELEMENTWISE_CPU_IMPL_UNARY(acosh)
67

7-
infiniStatus_t Descriptor::create(
8-
infiniopHandle_t handle_,
9-
Descriptor **desc_ptr,
10-
infiniopTensorDescriptor_t out_desc,
11-
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
12-
13-
auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
14-
auto dtype = out_desc->dtype();
15-
16-
const auto &x_desc = input_desc_vec.at(0);
17-
const auto &y_shape = out_desc->shape();
18-
const auto &x_shape = x_desc->shape();
19-
20-
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32);
21-
22-
CHECK_SAME_SHAPE(y_shape, x_shape);
23-
24-
// create CPU elementwise descriptor
25-
CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec);
26-
27-
return INFINI_STATUS_SUCCESS;
28-
}
29-
30-
infiniStatus_t Descriptor::calculate(
31-
void *workspace,
32-
size_t workspace_size,
33-
void *output,
34-
std::vector<const void *> inputs,
35-
void *stream) const {
36-
37-
switch (_dtype) {
38-
case INFINI_DTYPE_F16:
39-
return _device_info->calculate<Op, fp16_t>(_info, output, inputs, stream);
40-
case INFINI_DTYPE_F32:
41-
return _device_info->calculate<Op, float>(_info, output, inputs, stream);
42-
default:
43-
return INFINI_STATUS_BAD_TENSOR_DTYPE;
44-
}
45-
46-
return INFINI_STATUS_SUCCESS;
47-
}
488
} // namespace op::acosh::cpu
Lines changed: 2 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,8 @@
11
#include "asin_cpu.h"
2+
#include "../../../elementwise/cpu/elementwise_cpu_impl.h"
23

34
namespace op::asin::cpu {
45

5-
Descriptor::~Descriptor() = default;
6+
ELEMENTWISE_CPU_IMPL_UNARY(asin)
67

7-
infiniStatus_t Descriptor::create(
8-
infiniopHandle_t handle_,
9-
Descriptor **desc_ptr,
10-
infiniopTensorDescriptor_t out_desc,
11-
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
12-
13-
auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
14-
auto dtype = out_desc->dtype();
15-
16-
const auto &x_desc = input_desc_vec.at(0);
17-
const auto &y_shape = out_desc->shape();
18-
const auto &x_shape = x_desc->shape();
19-
20-
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32);
21-
22-
CHECK_SAME_SHAPE(y_shape, x_shape);
23-
24-
// create CPU elementwise descriptor
25-
CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec);
26-
27-
return INFINI_STATUS_SUCCESS;
28-
}
29-
30-
infiniStatus_t Descriptor::calculate(
31-
void *workspace,
32-
size_t workspace_size,
33-
void *output,
34-
std::vector<const void *> inputs,
35-
void *stream) const {
36-
37-
switch (_dtype) {
38-
case INFINI_DTYPE_F16:
39-
return _device_info->calculate<Op, fp16_t>(_info, output, inputs, stream);
40-
case INFINI_DTYPE_F32:
41-
return _device_info->calculate<Op, float>(_info, output, inputs, stream);
42-
default:
43-
return INFINI_STATUS_BAD_TENSOR_DTYPE;
44-
}
45-
46-
return INFINI_STATUS_SUCCESS;
47-
}
488
} // namespace op::asin::cpu

0 commit comments

Comments
 (0)