-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathcommon.h
More file actions
56 lines (47 loc) · 1.89 KB
/
common.h
File metadata and controls
56 lines (47 loc) · 1.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#ifndef INFINI_OPS_ASCEND_COMMON_H_
#define INFINI_OPS_ASCEND_COMMON_H_
#include <cstdint>
#include <vector>
#include "acl/acl.h"
#include "aclnn/acl_meta.h"
#include "ascend/data_type_.h"
#include "tensor.h"
namespace infini::ops::ascend {

// Creates an ACL tensor descriptor over the memory of an InfiniOps `Tensor`.
// No data is copied: the descriptor is handed `t.data()` directly.
//
// `transpose_last2`: when true and the tensor has rank >= 2, the last two
// entries of both the shape and the stride arrays are exchanged. The
// descriptor then presents a transposed view of the same memory. GEMM and
// Matmul use this to express a transpose without materializing it. For
// rank < 2 the flag is ignored.
//
// The storage shape given to ACL is one-dimensional. Its single extent is the
// number of elements the strided view can reach:
//   1 + sum_i (extent_i - 1) * stride_i   over dims with extent_i > 1 and
//                                          stride_i > 0,
// or 0 if any extent is 0. How this compares to numel():
//   - contiguous tensors: equal to numel();
//   - gapped (non-contiguous) views: larger;
//   - broadcast (stride-0) views: smaller.
// Why not use the view shape: passing the view shape as the storage shape
// makes ACLNN report "ViewShape overlap" for non-contiguous inputs.
//
// NOTE(review): negative strides are skipped in the span computation and the
// storage offset is always 0. Confirm callers never pass negative strides.
//
// Returns the result of `aclCreateTensor`. Presumably the caller owns the
// descriptor and must release it (e.g. via aclDestroyTensor). TODO confirm.
inline aclTensor* buildAclTensor(const Tensor& t,
                                 bool transpose_last2 = false) {
  std::vector<int64_t> view_dims(t.shape().begin(), t.shape().end());
  std::vector<int64_t> view_steps(t.strides().begin(), t.strides().end());
  const std::size_t rank = view_dims.size();

  const bool swap_tail = transpose_last2 && rank >= 2;
  if (swap_tail) {
    const std::size_t lo = rank - 2;
    const std::size_t hi = rank - 1;
    std::swap(view_dims[lo], view_dims[hi]);
    std::swap(view_steps[lo], view_steps[hi]);
  }

  // Reachable element count of the strided view (see header comment).
  int64_t span = 1;
  for (std::size_t d = 0; d < rank; ++d) {
    const int64_t extent = view_dims[d];
    if (extent == 0) {
      // An empty dimension means the view touches no memory at all.
      span = 0;
      break;
    }
    const int64_t step = view_steps[d];
    const bool reaches_further = extent > 1 && step > 0;
    if (reaches_further) {
      span += (extent - 1) * step;
    }
  }

  constexpr std::size_t kStorageRank = 1;
  int64_t storage_dims[kStorageRank] = {span};

  void* raw = const_cast<void*>(t.data());
  const int64_t view_rank = static_cast<int64_t>(rank);
  const int64_t storage_rank = static_cast<int64_t>(kStorageRank);

  return aclCreateTensor(view_dims.data(), view_rank, toAclDtype(t.dtype()),
                         view_steps.data(),
                         /*storageOffset=*/0, ACL_FORMAT_ND, storage_dims,
                         storage_rank, raw);
}

}  // namespace infini::ops::ascend
#endif