
Commit 08e0d6a

Author: zhangyue

style(ascend): apply clang-format to framework headers

1 parent 38e0330 commit 08e0d6a

7 files changed

Lines changed: 118 additions & 121 deletions

src/ascend/common.h

Lines changed: 29 additions & 31 deletions
@@ -18,39 +18,37 @@ namespace infini::ops::ascend {
 // and Matmul to express a transpose via the view.
 inline aclTensor* buildAclTensor(const Tensor& t,
                                  bool transpose_last2 = false) {
-    std::vector<int64_t> shape(t.shape().begin(), t.shape().end());
-    std::vector<int64_t> strides(t.strides().begin(), t.strides().end());
-
-    if (transpose_last2 && shape.size() >= 2) {
-        auto n = shape.size();
-        std::swap(shape[n - 2], shape[n - 1]);
-        std::swap(strides[n - 2], strides[n - 1]);
+  std::vector<int64_t> shape(t.shape().begin(), t.shape().end());
+  std::vector<int64_t> strides(t.strides().begin(), t.strides().end());
+
+  if (transpose_last2 && shape.size() >= 2) {
+    auto n = shape.size();
+    std::swap(shape[n - 2], shape[n - 1]);
+    std::swap(strides[n - 2], strides[n - 1]);
+  }
+
+  // Compute the minimum physical storage needed for this strided view.
+  // For contiguous tensors this equals numel(); for non-contiguous (gapped)
+  // tensors it may be larger; for broadcast (stride-0) tensors it may be
+  // smaller. Passing the view shape as the storage shape causes
+  // "ViewShape overlap" errors in ACLNN for non-contiguous inputs.
+  int64_t storage_elems = 1;
+  for (size_t i = 0; i < shape.size(); ++i) {
+    if (shape[i] == 0) {
+      storage_elems = 0;
+      break;
     }
-
-    // Compute the minimum physical storage needed for this strided view.
-    // For contiguous tensors this equals numel(); for non-contiguous (gapped)
-    // tensors it may be larger; for broadcast (stride-0) tensors it may be
-    // smaller. Passing the view shape as the storage shape causes
-    // "ViewShape overlap" errors in ACLNN for non-contiguous inputs.
-    int64_t storage_elems = 1;
-    for (size_t i = 0; i < shape.size(); ++i) {
-        if (shape[i] == 0) { storage_elems = 0; break; }
-        if (strides[i] > 0 && shape[i] > 1) {
-            storage_elems += static_cast<int64_t>(shape[i] - 1) * strides[i];
-        }
+    if (strides[i] > 0 && shape[i] > 1) {
+      storage_elems += static_cast<int64_t>(shape[i] - 1) * strides[i];
     }
-    std::vector<int64_t> storage_shape = {storage_elems};
-
-    return aclCreateTensor(
-        shape.data(),
-        static_cast<int64_t>(shape.size()),
-        toAclDtype(t.dtype()),
-        strides.data(),
-        /*storageOffset=*/0,
-        ACL_FORMAT_ND,
-        storage_shape.data(),
-        static_cast<int64_t>(storage_shape.size()),
-        const_cast<void*>(t.data()));
+  }
+  std::vector<int64_t> storage_shape = {storage_elems};
+
+  return aclCreateTensor(
+      shape.data(), static_cast<int64_t>(shape.size()), toAclDtype(t.dtype()),
+      strides.data(),
+      /*storageOffset=*/0, ACL_FORMAT_ND, storage_shape.data(),
+      static_cast<int64_t>(storage_shape.size()), const_cast<void*>(t.data()));
 }
 
 } // namespace infini::ops::ascend
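
Note: the storage computation reflowed above deserves a worked example. Below is a minimal standalone sketch of the same formula (minStorageElems is an illustrative name, not part of the header):

#include <cstdint>
#include <vector>

// Minimum storage, in elements, reachable by a strided view:
// 1 + sum_i (shape[i] - 1) * strides[i], counting only dimensions with a
// positive stride; zero if any extent is zero.
int64_t minStorageElems(const std::vector<int64_t>& shape,
                        const std::vector<int64_t>& strides) {
  int64_t elems = 1;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == 0) return 0;
    if (strides[i] > 0 && shape[i] > 1) {
      elems += (shape[i] - 1) * strides[i];
    }
  }
  return elems;
}

For a contiguous 3x4 tensor (strides {4, 1}) this gives 1 + 2*4 + 3*1 = 12, equal to numel(); for a 3x2 column slice of the same buffer (strides {4, 1}) it gives 1 + 2*4 + 1*1 = 10, larger than the view's numel() of 6; for a row broadcast to {3, 4} (strides {0, 1}) it gives 1 + 3*1 = 4, smaller than numel(). That spread is exactly why the view shape cannot double as the storage shape passed to aclCreateTensor.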

src/ascend/data_type_.h

Lines changed: 40 additions & 29 deletions
@@ -10,39 +10,50 @@
 namespace infini::ops::ascend {
 
 inline aclDataType toAclDtype(DataType dt) {
-    switch (dt) {
-        case DataType::kFloat16: return ACL_FLOAT16;
-        case DataType::kBFloat16: return ACL_BF16;
-        case DataType::kFloat32: return ACL_FLOAT;
-        case DataType::kInt8: return ACL_INT8;
-        case DataType::kInt16: return ACL_INT16;
-        case DataType::kInt32: return ACL_INT32;
-        case DataType::kInt64: return ACL_INT64;
-        case DataType::kUInt8: return ACL_UINT8;
-        case DataType::kUInt16: return ACL_UINT16;
-        case DataType::kUInt32: return ACL_UINT32;
-        case DataType::kUInt64: return ACL_UINT64;
-        default:
-            assert(false && "unsupported dtype for Ascend backend");
-            return ACL_DT_UNDEFINED;
-    }
+  switch (dt) {
+    case DataType::kFloat16:
+      return ACL_FLOAT16;
+    case DataType::kBFloat16:
+      return ACL_BF16;
+    case DataType::kFloat32:
+      return ACL_FLOAT;
+    case DataType::kInt8:
+      return ACL_INT8;
+    case DataType::kInt16:
+      return ACL_INT16;
+    case DataType::kInt32:
+      return ACL_INT32;
+    case DataType::kInt64:
+      return ACL_INT64;
+    case DataType::kUInt8:
+      return ACL_UINT8;
+    case DataType::kUInt16:
+      return ACL_UINT16;
+    case DataType::kUInt32:
+      return ACL_UINT32;
+    case DataType::kUInt64:
+      return ACL_UINT64;
+    default:
+      assert(false && "unsupported dtype for Ascend backend");
+      return ACL_DT_UNDEFINED;
+  }
 }
 
 // Returns true for integer (signed or unsigned) DataType values.
 inline bool isIntegerDtype(DataType dt) {
-    switch (dt) {
-        case DataType::kInt8:
-        case DataType::kInt16:
-        case DataType::kInt32:
-        case DataType::kInt64:
-        case DataType::kUInt8:
-        case DataType::kUInt16:
-        case DataType::kUInt32:
-        case DataType::kUInt64:
-            return true;
-        default:
-            return false;
-    }
+  switch (dt) {
+    case DataType::kInt8:
+    case DataType::kInt16:
+    case DataType::kInt32:
+    case DataType::kInt64:
+    case DataType::kUInt8:
+    case DataType::kUInt16:
+    case DataType::kUInt32:
+    case DataType::kUInt64:
+      return true;
+    default:
+      return false;
+  }
 }
 
 } // namespace infini::ops::ascend
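
Note: only formatting changes here; toAclDtype remains the single point where framework dtypes cross into ACL, with isIntegerDtype as a convenience predicate. A hypothetical guard built on it (checkFloatInput is an illustrative helper, not part of the header):

// Hypothetical dispatch guard: reject integer tensors before handing them
// to a floating-point-only ACLNN kernel.
inline void checkFloatInput(const Tensor& t) {
  assert(!isIntegerDtype(t.dtype()) &&
         "this op expects a floating-point tensor on the Ascend backend");
}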

src/ascend/workspace_pool_.h

Lines changed: 23 additions & 23 deletions
@@ -10,42 +10,42 @@
 namespace infini::ops::ascend {
 
 struct WorkspaceArena {
-    void* buf = nullptr;
-    uint64_t capacity = 0;
+  void* buf = nullptr;
+  uint64_t capacity = 0;
 };
 
 class WorkspacePool {
 public:
-    WorkspaceArena& ensure(aclrtStream stream, uint64_t needed) {
-        std::lock_guard<std::mutex> lock(mutex_);
-        auto& arena = arenas_[stream];
-        if (needed <= arena.capacity) return arena;
-        if (arena.capacity > 0) {
-            aclrtSynchronizeStream(stream);
-            aclrtFree(arena.buf);
-        }
-        if (needed > 0) {
-            aclrtMalloc(&arena.buf, needed, ACL_MEM_MALLOC_NORMAL_ONLY);
-        }
-        arena.capacity = needed;
-        return arena;
+  WorkspaceArena& ensure(aclrtStream stream, uint64_t needed) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    auto& arena = arenas_[stream];
+    if (needed <= arena.capacity) return arena;
+    if (arena.capacity > 0) {
+      aclrtSynchronizeStream(stream);
+      aclrtFree(arena.buf);
     }
+    if (needed > 0) {
+      aclrtMalloc(&arena.buf, needed, ACL_MEM_MALLOC_NORMAL_ONLY);
+    }
+    arena.capacity = needed;
+    return arena;
+  }
 
-    ~WorkspacePool() {
-        for (auto& [stream, arena] : arenas_) {
-            if (arena.capacity > 0) aclrtFree(arena.buf);
-        }
+  ~WorkspacePool() {
+    for (auto& [stream, arena] : arenas_) {
+      if (arena.capacity > 0) aclrtFree(arena.buf);
     }
+  }
 
 private:
-    std::unordered_map<aclrtStream, WorkspaceArena> arenas_;
+  std::unordered_map<aclrtStream, WorkspaceArena> arenas_;
 
-    std::mutex mutex_;
+  std::mutex mutex_;
 };
 
 inline WorkspacePool& workspacePool() {
-    static WorkspacePool pool;
-    return pool;
+  static WorkspacePool pool;
+  return pool;
 }
 
 } // namespace infini::ops::ascend
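
Note: the reformatted pool reads more clearly now: ensure() is grow-only per stream, and it synchronizes the stream before freeing a too-small buffer so in-flight kernels never lose their workspace. A hypothetical call site, following the two-phase ACLNN launch pattern (the aclnnXxx names are placeholders for a concrete op):

uint64_t ws_size = 0;
aclOpExecutor* executor = nullptr;
// Phase 1: the op reports how much workspace it needs.
// aclnnXxxGetWorkspaceSize(/*op inputs...*/, &ws_size, &executor);

// Reuse (or grow) the per-stream arena instead of a fresh aclrtMalloc.
auto& arena = workspacePool().ensure(stream, ws_size);

// Phase 2: launch with the pooled buffer as workspace.
// aclnnXxx(arena.buf, ws_size, executor, stream);

Capacity never shrinks while the process runs, trading peak device memory for allocation-free steady-state launches; the destructor frees each arena without a final synchronize, which assumes streams are drained before teardown.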

src/base/add_rms_norm.h

Lines changed: 2 additions & 3 deletions
@@ -26,9 +26,8 @@ class AddRmsNorm : public Operator<AddRmsNorm> {
     assert(x1.dtype() == x_out.dtype());
   }
 
-  virtual void operator()(const Tensor x1, const Tensor x2,
-                          const Tensor gamma, float eps, Tensor y_out,
-                          Tensor x_out) const = 0;
+  virtual void operator()(const Tensor x1, const Tensor x2, const Tensor gamma,
+                          float eps, Tensor y_out, Tensor x_out) const = 0;
 
 protected:
   Tensor::Shape input_shape_;

src/base/flash_attention.h

Lines changed: 16 additions & 24 deletions
@@ -11,18 +11,13 @@ namespace infini::ops {
 
 class FlashAttention : public Operator<FlashAttention> {
 public:
-  FlashAttention(
-      const Tensor query, const Tensor key, const Tensor value,
-      std::optional<Tensor> cu_seqlens_q,
-      std::optional<Tensor> cu_seqlens_kv,
-      std::optional<Tensor> block_table,
-      int64_t num_heads, int64_t num_kv_heads, int64_t head_size,
-      double scale,
-      bool causal,
-      int64_t window_left,
-      int64_t window_right,
-      int64_t block_size,
-      Tensor output)
+  FlashAttention(const Tensor query, const Tensor key, const Tensor value,
+                 std::optional<Tensor> cu_seqlens_q,
+                 std::optional<Tensor> cu_seqlens_kv,
+                 std::optional<Tensor> block_table, int64_t num_heads,
+                 int64_t num_kv_heads, int64_t head_size, double scale,
+                 bool causal, int64_t window_left, int64_t window_right,
+                 int64_t block_size, Tensor output)
       : num_tokens_{query.size(0)},
         num_heads_{num_heads},
         num_kv_heads_{num_kv_heads},
@@ -50,18 +45,15 @@ class FlashAttention : public Operator<FlashAttention> {
            "`FlashAttention` requires query to be 3D [T, N, D]");
   }
 
-  virtual void operator()(
-      const Tensor query, const Tensor key, const Tensor value,
-      std::optional<Tensor> cu_seqlens_q,
-      std::optional<Tensor> cu_seqlens_kv,
-      std::optional<Tensor> block_table,
-      int64_t num_heads, int64_t num_kv_heads, int64_t head_size,
-      double scale,
-      bool causal,
-      int64_t window_left,
-      int64_t window_right,
-      int64_t block_size,
-      Tensor output) const = 0;
+  virtual void operator()(const Tensor query, const Tensor key,
+                          const Tensor value,
+                          std::optional<Tensor> cu_seqlens_q,
+                          std::optional<Tensor> cu_seqlens_kv,
+                          std::optional<Tensor> block_table, int64_t num_heads,
+                          int64_t num_kv_heads, int64_t head_size, double scale,
+                          bool causal, int64_t window_left,
+                          int64_t window_right, int64_t block_size,
+                          Tensor output) const = 0;
 
 protected:
   Tensor::Size num_tokens_{0};

src/base/matmul.h

Lines changed: 1 addition & 2 deletions
@@ -11,8 +11,7 @@ class Matmul : public Operator<Matmul> {
   // trans_a / trans_b: if true, transpose the last two dims of a / b before
   // multiplying. These are constructor parameters so the CacheKey encodes
   // the transposition and distinct descriptors are cached for each combination.
-  Matmul(const Tensor a, const Tensor b, Tensor c,
-         bool trans_a, bool trans_b)
+  Matmul(const Tensor a, const Tensor b, Tensor c, bool trans_a, bool trans_b)
       : a_shape_{a.shape()},
         b_shape_{b.shape()},
         c_shape_{c.shape()},
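
Note: the trans_a / trans_b flags in the comment connect back to buildAclTensor's transpose_last2 in src/ascend/common.h: swapping the last two shape/stride entries expresses the transpose as a view over the same buffer, with no copy. A small illustration (the numbers are arbitrary):

// A row-major 4 x 8 matrix: shape {4, 8}, strides {8, 1}.
std::vector<int64_t> shape = {4, 8};
std::vector<int64_t> strides = {8, 1};
// Swap the last two dims: shape {8, 4}, strides {1, 8}.
std::swap(shape[0], shape[1]);
std::swap(strides[0], strides[1]);
// Element (i, j) of the view reads buffer[i * 1 + j * 8], i.e. original
// element (j, i) -- the transpose, with zero data movement.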

src/base/reshape_and_cache.h

Lines changed: 7 additions & 9 deletions
@@ -10,10 +10,8 @@ namespace infini::ops {
 
 class ReshapeAndCache : public Operator<ReshapeAndCache> {
 public:
-  ReshapeAndCache(
-      const Tensor key, const Tensor value,
-      const Tensor kv_cache, const Tensor slot_mapping,
-      Tensor kv_cache_out)
+  ReshapeAndCache(const Tensor key, const Tensor value, const Tensor kv_cache,
+                  const Tensor slot_mapping, Tensor kv_cache_out)
       : num_tokens_{key.size(0)},
         num_kv_heads_{key.size(1)},
         head_size_{key.size(2)},
@@ -30,15 +28,15 @@ class ReshapeAndCache : public Operator<ReshapeAndCache> {
     assert(key.shape() == value.shape() &&
            "`ReshapeAndCache` requires key and value same shape");
     assert(kv_cache.ndim() == 5 &&
-           "`ReshapeAndCache` requires kv_cache to be 5D [2, num_blocks, block_size, num_kv_heads, head_size]");
+           "`ReshapeAndCache` requires kv_cache to be 5D [2, num_blocks, "
+           "block_size, num_kv_heads, head_size]");
     assert(slot_mapping.ndim() == 1 &&
            "`ReshapeAndCache` requires slot_mapping to be 1D");
   }
 
-  virtual void operator()(
-      const Tensor key, const Tensor value,
-      const Tensor kv_cache, const Tensor slot_mapping,
-      Tensor kv_cache_out) const = 0;
+  virtual void operator()(const Tensor key, const Tensor value,
+                          const Tensor kv_cache, const Tensor slot_mapping,
+                          Tensor kv_cache_out) const = 0;
 
 protected:
   Tensor::Size num_tokens_{0};
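
Note: the reflowed assert documents the cache layout: kv_cache is [2, num_blocks, block_size, num_kv_heads, head_size]. The slot arithmetic below is the conventional paged-KV reading of that shape (with keys in plane 0 and values in plane 1), not something this diff itself states:

// Each slot_mapping entry names a global slot; it splits into a block
// index and a row within that block.
int64_t slot = slot_mapping_entry;  // hypothetical: one element of slot_mapping
int64_t block = slot / block_size;
int64_t row = slot % block_size;
// Token i's key then lands in kv_cache[0][block][row], its value in
// kv_cache[1][block][row], each a [num_kv_heads, head_size] tile.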
