Skip to content

Commit 4833eb9

Browse files
author
zhangyue
committed
feat(ascend): add GEMM kernel, NPU test infra, and example integration
- Add Ascend GEMM specialization using `aclnnAddmm`/`aclnnBaddbmm`. - Add `get_npu_stream()` helper and NPU device detection in test utils. - Add `skip_unsupported_dtype` fixture for Ascend in conftest. - Update `runtime_api.h` with Ascend backend entry.
1 parent 88a4379 commit 4833eb9

5 files changed

Lines changed: 135 additions & 11 deletions

File tree

examples/runtime_api.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
#elif WITH_MOORE
2020
#include "moore/gemm/mublas.h"
2121
#include "moore/runtime_.h"
22+
#elif WITH_ASCEND
23+
#include "ascend/gemm/kernel.h"
24+
#include "ascend/runtime_.h"
2225
#elif WITH_CPU
2326
#include "cpu/gemm/gemm.h"
2427
#include "cpu/runtime_.h"
@@ -38,6 +41,8 @@ using DefaultRuntimeUtils = Runtime<Device::Type::kMetax>;
3841
using DefaultRuntimeUtils = Runtime<Device::Type::kCambricon>;
3942
#elif WITH_MOORE
4043
using DefaultRuntimeUtils = Runtime<Device::Type::kMoore>;
44+
#elif WITH_ASCEND
45+
using DefaultRuntimeUtils = Runtime<Device::Type::kAscend>;
4146
#elif WITH_CPU
4247
using DefaultRuntimeUtils = Runtime<Device::Type::kCpu>;
4348
#endif

src/ascend/gemm/kernel.h

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
#ifndef INFINI_OPS_ASCEND_GEMM_KERNEL_H_
2+
#define INFINI_OPS_ASCEND_GEMM_KERNEL_H_
3+
4+
#include "acl/acl.h"
5+
#include "aclnn/aclnn_base.h"
6+
#include "aclnnop/aclnn_addmm.h"
7+
#include "aclnnop/aclnn_baddbmm.h"
8+
#include "ascend/common.h"
9+
#include "ascend/workspace_pool_.h"
10+
#include "base/gemm.h"
11+
#include "operator.h"
12+
13+
namespace infini::ops {
14+
15+
template <>
16+
class Operator<Gemm, Device::Type::kAscend> : public Gemm {
17+
public:
18+
Operator(const Tensor a, const Tensor b, std::optional<float> alpha,
19+
std::optional<float> beta, std::optional<int> trans_a,
20+
std::optional<int> trans_b, Tensor c)
21+
: Gemm(a, b, alpha, beta, trans_a, trans_b, c),
22+
batched_{batch_count_ > 1},
23+
alpha_val_{alpha.value_or(1.0f)},
24+
beta_val_{beta.value_or(1.0f)} {
25+
alpha_scalar_ = aclCreateScalar(&alpha_val_, ACL_FLOAT);
26+
beta_scalar_ = aclCreateScalar(&beta_val_, ACL_FLOAT);
27+
}
28+
29+
~Operator() {
30+
aclDestroyScalar(alpha_scalar_);
31+
aclDestroyScalar(beta_scalar_);
32+
}
33+
34+
void operator()(const Tensor a, const Tensor b, std::optional<float> alpha,
35+
std::optional<float> beta, std::optional<int> trans_a,
36+
std::optional<int> trans_b, Tensor c) const override {
37+
auto stream = static_cast<aclrtStream>(stream_);
38+
39+
auto t_self = ascend::buildAclTensor(c);
40+
auto t_a = ascend::buildAclTensor(a, trans_a_);
41+
auto t_b = ascend::buildAclTensor(b, trans_b_);
42+
auto t_out = ascend::buildAclTensor(c);
43+
44+
uint64_t ws_needed = 0;
45+
aclOpExecutor* executor = nullptr;
46+
47+
if (batched_) {
48+
aclnnBaddbmmGetWorkspaceSize(t_self, t_a, t_b, beta_scalar_,
49+
alpha_scalar_, t_out, 0, &ws_needed,
50+
&executor);
51+
} else {
52+
aclnnAddmmGetWorkspaceSize(t_self, t_a, t_b, beta_scalar_, alpha_scalar_,
53+
t_out, 0, &ws_needed, &executor);
54+
}
55+
56+
auto& arena = ascend::workspacePool().ensure(stream, ws_needed);
57+
58+
if (batched_) {
59+
aclnnBaddbmm(arena.buf, ws_needed, executor, stream);
60+
} else {
61+
aclnnAddmm(arena.buf, ws_needed, executor, stream);
62+
}
63+
64+
aclDestroyTensor(t_self);
65+
aclDestroyTensor(t_a);
66+
aclDestroyTensor(t_b);
67+
aclDestroyTensor(t_out);
68+
}
69+
70+
private:
71+
bool batched_;
72+
float alpha_val_;
73+
float beta_val_;
74+
aclScalar* alpha_scalar_ = nullptr;
75+
aclScalar* beta_scalar_ = nullptr;
76+
};
77+
78+
} // namespace infini::ops
79+
80+
#endif

tests/conftest.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,25 @@ def set_seed_per_test(request):
3838
_set_random_seed(_hash(_test_case_path_from_request(request)))
3939

4040

41+
_NPU_UNSUPPORTED_DTYPES = {torch.float64}
42+
43+
# torch_npu does not implement random number generation for uint16/uint32/uint64.
44+
for _bits in (16, 32, 64):
45+
_t = getattr(torch, f"uint{_bits}", None)
46+
if _t is not None:
47+
_NPU_UNSUPPORTED_DTYPES.add(_t)
48+
49+
50+
@pytest.fixture(autouse=True)
def skip_unsupported_dtype(request):
    """Auto-skip parametrized cases whose dtype the Ascend NPU cannot run."""
    callspec = getattr(request.node, "callspec", None)
    if callspec is None:
        # Test is not parametrized; nothing to inspect.
        return

    device = callspec.params.get("device")
    dtype = callspec.params.get("dtype")
    if device == "npu" and dtype in _NPU_UNSUPPORTED_DTYPES:
        pytest.skip(f"{dtype} not supported on Ascend 910B")
4160
def _set_random_seed(seed):
4261
random.seed(seed)
4362
torch.manual_seed(seed)

tests/test_gemm.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import pytest
33
import torch
44

5-
from tests.utils import Payload, randn_strided
5+
from tests.utils import Payload, get_npu_stream, randn_strided
66

77

88
@pytest.mark.auto_act_and_assert
@@ -84,16 +84,22 @@ def test_gemm(
8484

8585

8686
def _gemm(a, b, alpha, beta, trans_a, trans_b, c, implementation_index=0):
    """Invoke infini.ops.gemm on the given operands and return `c`.

    NPU tensors additionally need the current torch_npu stream handle passed
    explicitly so the kernel launches on the stream PyTorch is using.
    """
    if a.device.type == "npu":
        # Forward implementation_index here too — the original NPU branch
        # silently dropped it, which would make implementation-parametrized
        # tests vacuous on Ascend. Default 0 preserves existing behavior.
        infini.ops.gemm(
            a, b, alpha, beta, trans_a, trans_b, c,
            stream=get_npu_stream(a),
            implementation_index=implementation_index,
        )
    else:
        infini.ops.gemm(
            a,
            b,
            alpha,
            beta,
            trans_a,
            trans_b,
            c,
            implementation_index=implementation_index,
        )
    return c
99105

tests/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,18 @@ def get_available_devices():
3232
if hasattr(torch, "musa") and torch.musa.is_available():
3333
devices.append("musa")
3434

35+
if hasattr(torch, "npu") and torch.npu.is_available():
36+
devices.append("npu")
37+
3538
return tuple(devices)
3639

3740

3841
with contextlib.suppress(ImportError, ModuleNotFoundError):
3942
import torch_mlu # noqa: F401
4043

44+
with contextlib.suppress(ImportError, ModuleNotFoundError):
45+
import torch_npu # noqa: F401
46+
4147

4248
def empty_strided(shape, strides, *, dtype=None, device=None):
4349
if strides is None:
@@ -76,6 +82,14 @@ def randint_strided(low, high, shape, strides, *, dtype=None, device=None):
7682
return output
7783

7884

85+
def get_npu_stream(tensor):
86+
"""Return the current NPU stream handle for `tensor`, or 0 on other devices."""
87+
if tensor.device.type != "npu":
88+
return 0
89+
90+
return torch.npu.current_stream().npu_stream
91+
92+
7993
def clone_strided(input):
8094
output = empty_strided(
8195
input.size(), input.stride(), dtype=input.dtype, device=input.device

0 commit comments

Comments
 (0)