Skip to content

Commit 88a4379

Browse files
author
zhangyue
committed
fix(ascend): adapt Memcpy/Memset arity, assert workspace alloc, remove missing include
- Wrap `aclrtMemcpy` (5-arg) and `aclrtMemset` (4-arg) in lambdas to match the generic 4-arg / 3-arg calling convention used by the examples.
- Assert the `aclrtMalloc` return value in `WorkspacePool::ensure()`.
- Remove the `WITH_ASCEND` branch from `examples/runtime_api.h` (the `ascend/gemm/kernel.h` and `ascend/runtime_.h` includes and the `DefaultRuntimeUtils` alias), because `ascend/gemm/kernel.h` does not exist until the kernels commit.
1 parent 08e0d6a commit 88a4379

3 files changed

Lines changed: 11 additions & 8 deletions

File tree

examples/runtime_api.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,6 @@
1919
#elif WITH_MOORE
2020
#include "moore/gemm/mublas.h"
2121
#include "moore/runtime_.h"
22-
#elif WITH_ASCEND
23-
#include "ascend/gemm/kernel.h"
24-
#include "ascend/runtime_.h"
2522
#elif WITH_CPU
2623
#include "cpu/gemm/gemm.h"
2724
#include "cpu/runtime_.h"
@@ -41,8 +38,6 @@ using DefaultRuntimeUtils = Runtime<Device::Type::kMetax>;
4138
using DefaultRuntimeUtils = Runtime<Device::Type::kCambricon>;
4239
#elif WITH_MOORE
4340
using DefaultRuntimeUtils = Runtime<Device::Type::kMoore>;
44-
#elif WITH_ASCEND
45-
using DefaultRuntimeUtils = Runtime<Device::Type::kAscend>;
4641
#elif WITH_CPU
4742
using DefaultRuntimeUtils = Runtime<Device::Type::kCpu>;
4843
#endif

src/ascend/runtime_.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,18 @@ struct Runtime<Device::Type::kAscend>
2323

2424
static constexpr auto Free = aclrtFree;
2525

26-
static constexpr auto Memcpy = aclrtMemcpy;
26+
static constexpr auto Memcpy = [](void* dst, const void* src, size_t count,
27+
aclrtMemcpyKind kind) {
28+
return aclrtMemcpy(dst, count, src, count, kind);
29+
};
2730

2831
static constexpr auto MemcpyHostToDevice = ACL_MEMCPY_HOST_TO_DEVICE;
2932

3033
static constexpr auto MemcpyDeviceToHost = ACL_MEMCPY_DEVICE_TO_HOST;
3134

32-
static constexpr auto Memset = aclrtMemset;
35+
static constexpr auto Memset = [](void* ptr, int value, size_t count) {
36+
return aclrtMemset(ptr, count, value, count);
37+
};
3338
};
3439

3540
static_assert(Runtime<Device::Type::kAscend>::Validate());

src/ascend/workspace_pool_.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef INFINI_OPS_ASCEND_WORKSPACE_POOL__H_
22
#define INFINI_OPS_ASCEND_WORKSPACE_POOL__H_
33

4+
#include <cassert>
45
#include <cstdint>
56
#include <mutex>
67
#include <unordered_map>
@@ -25,7 +26,9 @@ class WorkspacePool {
2526
aclrtFree(arena.buf);
2627
}
2728
if (needed > 0) {
28-
aclrtMalloc(&arena.buf, needed, ACL_MEM_MALLOC_NORMAL_ONLY);
29+
assert(aclrtMalloc(&arena.buf, needed, ACL_MEM_MALLOC_NORMAL_ONLY) ==
30+
ACL_SUCCESS &&
31+
"`WorkspacePool`: `aclrtMalloc` failed");
2932
}
3033
arena.capacity = needed;
3134
return arena;

0 commit comments

Comments
 (0)