|
2 | 2 | #define INFINI_OPS_CAMBRICON_COMMON_H_ |
3 | 3 |
|
4 | 4 | #include <cnnl.h> |
| 5 | +#include <cnrt.h> |
5 | 6 |
|
6 | 7 | #include "data_type.h" |
| 8 | +#include "device.h" |
| 9 | + |
| 10 | +#define NRAM_MAX_SIZE (1024 * 240) |
| 11 | + |
| 12 | +#ifdef __BANG__ |
| 13 | + |
| 14 | +namespace infini::ops::reduce { |
| 15 | + |
| 16 | +constexpr int batch_size = 128 / sizeof(float); |
| 17 | + |
| 18 | +__mlu_func__ void SumInternal(float* dst, float* src, int max_batch) { |
| 19 | + const int width = max_batch / batch_size; |
| 20 | + |
| 21 | + if (width >= 4) { |
| 22 | + __bang_sumpool(dst, src, batch_size, 1, width, 1, width, 1, 1); |
| 23 | + __bang_reduce_sum(dst, dst, batch_size); |
| 24 | + } else { |
| 25 | + float sum = 0.0f; |
| 26 | + for (int i = 0; i < max_batch; ++i) { |
| 27 | + sum += src[i]; |
| 28 | + } |
| 29 | + dst[0] = sum; |
| 30 | + } |
| 31 | +} |
| 32 | + |
| 33 | +} // namespace infini::ops::reduce |
| 34 | + |
| 35 | +#endif // __BANG__ |
7 | 36 |
|
8 | 37 | namespace infini::ops::cnnl_utils { |
9 | 38 |
|
10 | 39 | inline cnnlDataType_t GetDataType(DataType dtype) { |
11 | 40 | switch (dtype) { |
| 41 | + case DataType::kInt8: |
| 42 | + return CNNL_DTYPE_INT8; |
| 43 | + case DataType::kUInt8: |
| 44 | + return CNNL_DTYPE_UINT8; |
12 | 45 | case DataType::kInt32: |
13 | 46 | return CNNL_DTYPE_INT32; |
| 47 | + case DataType::kInt64: |
| 48 | + return CNNL_DTYPE_INT64; |
14 | 49 | case DataType::kFloat16: |
15 | 50 | return CNNL_DTYPE_HALF; |
16 | 51 | case DataType::kFloat32: |
17 | 52 | return CNNL_DTYPE_FLOAT; |
| 53 | + case DataType::kBFloat16: |
| 54 | + return CNNL_DTYPE_BFLOAT16; |
| 55 | + case DataType::kFloat64: |
| 56 | + return CNNL_DTYPE_DOUBLE; |
18 | 57 | default: |
19 | 58 | return CNNL_DTYPE_INVALID; |
20 | 59 | } |
21 | 60 | } |
22 | 61 |
|
23 | 62 | } // namespace infini::ops::cnnl_utils |
24 | 63 |
|
| 64 | +namespace infini::ops::cnrt_utils { |
| 65 | + |
| 66 | +inline void GetLaunchConfig(const Device& device, int* core_per_cluster, |
| 67 | + int* cluster_count) { |
| 68 | + int device_id = device.index(); |
| 69 | + cnrtDeviceGetAttribute(cluster_count, cnrtAttrClusterCount, device_id); |
| 70 | + cnrtDeviceGetAttribute(core_per_cluster, cnrtAttrMcorePerCluster, device_id); |
| 71 | +} |
| 72 | + |
| 73 | +} // namespace infini::ops::cnrt_utils |
| 74 | + |
25 | 75 | #endif |
0 commit comments