Skip to content

Commit 440b428

Browse files
author
zhangyue
committed
fix(ascend): adapt Memcpy/Memset arity, assert workspace alloc, remove missing include
- Wrap `aclrtMemcpy` (5-arg) and `aclrtMemset` (4-arg) in lambdas to match the generic 4-arg / 3-arg calling convention used by examples. - Assert `aclrtMalloc` return value in `WorkspacePool::ensure()`. - Remove `ascend/gemm/kernel.h` include from `runtime_api.h` (file does not exist until the kernels commit).
1 parent 08e0d6a commit 440b428

3 files changed

Lines changed: 11 additions & 4 deletions

File tree

examples/runtime_api.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
#include "moore/gemm/mublas.h"
2121
#include "moore/runtime_.h"
2222
#elif WITH_ASCEND
23-
#include "ascend/gemm/kernel.h"
2423
#include "ascend/runtime_.h"
2524
#elif WITH_CPU
2625
#include "cpu/gemm/gemm.h"

src/ascend/runtime_.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,18 @@ struct Runtime<Device::Type::kAscend>
2323

2424
static constexpr auto Free = aclrtFree;
2525

26-
static constexpr auto Memcpy = aclrtMemcpy;
26+
static constexpr auto Memcpy = [](void* dst, const void* src, size_t count,
27+
aclrtMemcpyKind kind) {
28+
return aclrtMemcpy(dst, count, src, count, kind);
29+
};
2730

2831
static constexpr auto MemcpyHostToDevice = ACL_MEMCPY_HOST_TO_DEVICE;
2932

3033
static constexpr auto MemcpyDeviceToHost = ACL_MEMCPY_DEVICE_TO_HOST;
3134

32-
static constexpr auto Memset = aclrtMemset;
35+
static constexpr auto Memset = [](void* ptr, int value, size_t count) {
36+
return aclrtMemset(ptr, count, value, count);
37+
};
3338
};
3439

3540
static_assert(Runtime<Device::Type::kAscend>::Validate());

src/ascend/workspace_pool_.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef INFINI_OPS_ASCEND_WORKSPACE_POOL__H_
22
#define INFINI_OPS_ASCEND_WORKSPACE_POOL__H_
33

4+
#include <cassert>
45
#include <cstdint>
56
#include <mutex>
67
#include <unordered_map>
@@ -25,7 +26,9 @@ class WorkspacePool {
2526
aclrtFree(arena.buf);
2627
}
2728
if (needed > 0) {
28-
aclrtMalloc(&arena.buf, needed, ACL_MEM_MALLOC_NORMAL_ONLY);
29+
assert(aclrtMalloc(&arena.buf, needed, ACL_MEM_MALLOC_NORMAL_ONLY) ==
30+
ACL_SUCCESS &&
31+
"`WorkspacePool`: `aclrtMalloc` failed");
2932
}
3033
arena.capacity = needed;
3134
return arena;

0 commit comments

Comments
 (0)