Skip to content

Commit ff85027

Browse files
author
echo_stone
committed
Fix: stabilize HCCL communicator initialization
- pass device_id through comm_init so communicator setup binds the correct physical device for each rank
- align a2a3 HCCL init with the pto-isa sequence by linking hcomm, running rank-0 rtSetDevice, and synchronizing rootinfo exchange
- move distributed worker device setup after comm init and keep the A5/sim backends signature-compatible with the new API

Made-with: Cursor
1 parent 9478e0f commit ff85027

9 files changed

Lines changed: 72 additions & 75 deletions

File tree

examples/scripts/distributed_worker.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,14 +105,12 @@ def buf_bytes(b):
105105

106106
lib_path = artifact_dir / "libhost_runtime.so"
107107
Runtime = bind_host_binary(str(lib_path))
108-
set_device(args.device_id)
109-
110-
sys.stderr.write(f"[rank {args.rank}] Library loaded, device {args.device_id} set\n")
108+
sys.stderr.write(f"[rank {args.rank}] Library loaded\n")
111109

112110
# ----------------------------------------------------------------
113111
# 2. Comm init + alloc windows
114112
# ----------------------------------------------------------------
115-
comm = comm_init(args.rank, args.nranks, args.rootinfo_file)
113+
comm = comm_init(args.rank, args.nranks, args.device_id, args.rootinfo_file)
116114

117115
total_win = args.win_sync_prefix
118116
for b in buffers:
@@ -124,6 +122,9 @@ def buf_bytes(b):
124122

125123
sys.stderr.write(f"[rank {args.rank}] Comm initialized, local_base=0x{local_base:x}\n")
126124

125+
set_device(args.device_id)
126+
sys.stderr.write(f"[rank {args.rank}] Device {args.device_id} set for runtime\n")
127+
127128
# ----------------------------------------------------------------
128129
# 3. Allocate buffers
129130
# ----------------------------------------------------------------

python/bindings.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def _setup_functions(self):
181181
self.lib.enable_runtime_profiling.restype = c_int
182182

183183
# --- Distributed communication API (comm_*) ---
184-
self.lib.comm_init.argtypes = [c_int, c_int, c_char_p]
184+
self.lib.comm_init.argtypes = [c_int, c_int, c_int, c_char_p]
185185
self.lib.comm_init.restype = c_void_p
186186

187187
self.lib.comm_alloc_windows.argtypes = [c_void_p, c_size_t, POINTER(c_uint64)]
@@ -543,13 +543,14 @@ def launch_runtime(
543543
# ============================================================================
544544

545545

546-
def comm_init(rank: int, nranks: int, rootinfo_path: str) -> int:
546+
def comm_init(rank: int, nranks: int, device_id: int, rootinfo_path: str) -> int:
547547
"""
548548
Initialize a distributed communicator for the given rank.
549549
550550
Args:
551551
rank: This process's rank (0-based)
552552
nranks: Total number of ranks
553+
device_id: Physical device ID used by this process
553554
rootinfo_path: Filesystem path for root info exchange
554555
555556
Returns:
@@ -562,7 +563,7 @@ def comm_init(rank: int, nranks: int, rootinfo_path: str) -> int:
562563
if _lib is None:
563564
raise RuntimeError("Runtime not loaded. Call bind_host_binary() first.")
564565

565-
handle = _lib.comm_init(rank, nranks, rootinfo_path.encode('utf-8'))
566+
handle = _lib.comm_init(rank, nranks, device_id, rootinfo_path.encode('utf-8'))
566567
if not handle:
567568
raise RuntimeError(f"comm_init failed for rank {rank}")
568569
return handle

src/a2a3/platform/include/host/comm.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,12 @@ typedef struct CommHandle_* CommHandle;
3232
*
3333
* @param rank This process's rank (0-based).
3434
* @param nranks Total number of ranks.
35+
* @param device_id Physical device ID used by this process.
3536
* @param rootinfo_path Filesystem path used to exchange root info between
3637
* ranks (rank 0 writes, others read).
3738
* @return Opaque handle, or NULL on failure.
3839
*/
39-
CommHandle comm_init(int rank, int nranks, const char* rootinfo_path);
40+
CommHandle comm_init(int rank, int nranks, int device_id, const char* rootinfo_path);
4041

4142
/**
4243
* Allocate RDMA / shared-memory windows and populate the device context.

src/a2a3/platform/onboard/host/CMakeLists.txt

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -81,22 +81,15 @@ target_link_directories(host_runtime
8181
${ASCEND_HOME_PATH}/runtime/lib64
8282
)
8383

84-
# Detect HCCL library version: CANN 9.x ships hccl_v2 instead of hccl
85-
# Prioritize hccl_v2 since CANN 9.x may have both hccl.so and hccl_v2.so,
86-
# and hccl_v2 is the complete, actively-maintained library.
87-
find_library(HCCL_V2_LIB NAMES hccl_v2 PATHS ${ASCEND_HOME_PATH}/lib64 NO_DEFAULT_PATH)
88-
find_library(HCCL_LIB NAMES hccl PATHS ${ASCEND_HOME_PATH}/lib64 NO_DEFAULT_PATH)
89-
90-
if(HCCL_V2_LIB)
91-
set(HCCL_LINK_TARGETS hccl_v2 hccl_plf)
92-
target_compile_definitions(host_runtime PRIVATE HCCL_USE_V2_API=1)
93-
message(STATUS "Using HCCL library: hccl_v2 + hccl_plf (CANN 9.x)")
94-
elseif(HCCL_LIB)
95-
set(HCCL_LINK_TARGETS hccl)
96-
message(STATUS "Using HCCL library: hccl")
84+
# CANN 9.x exposes the working non-V2 HCCL entry points through libhcomm.
85+
# Link it explicitly so comm_hccl.cpp can follow the same initialization path
86+
# as the pto-isa communication tests.
87+
find_library(HCOMM_LIB NAMES hcomm PATHS ${ASCEND_HOME_PATH}/lib64 NO_DEFAULT_PATH)
88+
if(HCOMM_LIB)
89+
set(HCCL_LINK_TARGETS hcomm)
90+
message(STATUS "Using HCCL library: hcomm")
9791
else()
98-
message(WARNING "Neither hccl_v2 nor hccl found, linking against hccl and hoping for the best")
99-
set(HCCL_LINK_TARGETS hccl)
92+
message(FATAL_ERROR "libhcomm not found under ${ASCEND_HOME_PATH}/lib64")
10093
endif()
10194

10295
# Optionally link nnopbase (provides aclCreateTensor/aclDestroyTensor for SdmaWorkspaceManager)
@@ -115,7 +108,6 @@ target_link_libraries(host_runtime
115108
runtime
116109
ascendcl
117110
${HCCL_LINK_TARGETS}
118-
hccl_fwk
119111
${NNOPBASE_LINK}
120112
dl
121113
)

src/a2a3/platform/onboard/host/comm_hccl.cpp

Lines changed: 36 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -24,45 +24,14 @@
2424

2525
using CommTopo = uint32_t;
2626

27-
// Internal HCCL APIs (not in public headers).
28-
// CANN 9.x renames all HCCL symbols with a V2 suffix and ships libhccl_v2.so.
29-
// The public header still declares weak non-V2 symbols, so we declare V2 variants
30-
// separately and dispatch via inline wrappers.
31-
#ifdef HCCL_USE_V2_API
32-
extern "C" HcclResult HcclGetRootInfoV2(HcclRootInfo* rootInfo);
33-
extern "C" HcclResult HcclCommInitRootInfoV2(uint32_t nRanks, const HcclRootInfo* rootInfo,
34-
uint32_t rank, HcclComm* comm);
35-
extern "C" HcclResult HcclGetCommNameV2(HcclComm comm, char* commName);
36-
extern "C" HcclResult HcclBarrierV2(HcclComm comm, aclrtStream stream);
37-
extern "C" HcclResult HcclCommDestroyV2(HcclComm comm);
38-
extern "C" HcclResult HcclAllocComResourceByTilingV2(HcclComm comm, void* stream,
39-
void* mc2Tiling, void** commContext);
40-
extern "C" HcclResult HcomGetCommHandleByGroupV2(const char* group, HcclComm* commHandle);
41-
extern "C" HcclResult HcomGetL0TopoTypeExV2(const char* group, CommTopo* topoType,
42-
uint32_t isSetDevice);
43-
44-
static inline HcclResult hccl_get_root_info(HcclRootInfo* ri)
45-
{ return HcclGetRootInfoV2(ri); }
46-
static inline HcclResult hccl_comm_init_root_info(uint32_t n, const HcclRootInfo* ri, uint32_t r, HcclComm* c)
47-
{ return HcclCommInitRootInfoV2(n, ri, r, c); }
48-
static inline HcclResult hccl_get_comm_name(HcclComm c, char* name)
49-
{ return HcclGetCommNameV2(c, name); }
50-
static inline HcclResult hccl_barrier(HcclComm c, aclrtStream s)
51-
{ return HcclBarrierV2(c, s); }
52-
static inline HcclResult hccl_comm_destroy(HcclComm c)
53-
{ return HcclCommDestroyV2(c); }
54-
static inline HcclResult hccl_alloc_com_resource(HcclComm c, void* s, void* t, void** ctx)
55-
{ return HcclAllocComResourceByTilingV2(c, s, t, ctx); }
56-
static inline HcclResult hccl_get_comm_handle_by_group(const char* g, HcclComm* c)
57-
{ return HcomGetCommHandleByGroupV2(g, c); }
58-
static inline HcclResult hccl_get_l0_topo_type_ex(const char* g, CommTopo* t, uint32_t f)
59-
{ return HcomGetL0TopoTypeExV2(g, t, f); }
60-
#else
27+
// Internal HCCL helpers are exported by libhcomm on CANN 9.x. The public
28+
// HCCL APIs below intentionally use the standard, non-V2 entry points to match
29+
// the working pto-isa initialization sequence.
6130
extern "C" HcclResult HcclAllocComResourceByTiling(HcclComm comm, void* stream,
62-
void* mc2Tiling, void** commContext);
31+
void* mc2Tiling, void** commContext);
6332
extern "C" HcclResult HcomGetCommHandleByGroup(const char* group, HcclComm* commHandle);
6433
extern "C" HcclResult HcomGetL0TopoTypeEx(const char* group, CommTopo* topoType,
65-
uint32_t isSetDevice);
34+
uint32_t isSetDevice);
6635

6736
static inline HcclResult hccl_get_root_info(HcclRootInfo* ri)
6837
{ return HcclGetRootInfo(ri); }
@@ -80,13 +49,13 @@ static inline HcclResult hccl_get_comm_handle_by_group(const char* g, HcclComm*
8049
{ return HcomGetCommHandleByGroup(g, c); }
8150
static inline HcclResult hccl_get_l0_topo_type_ex(const char* g, CommTopo* t, uint32_t f)
8251
{ return HcomGetL0TopoTypeEx(g, t, f); }
83-
#endif
8452

8553
static constexpr uint32_t COMM_IS_NOT_SET_DEVICE = 0;
8654
static constexpr uint32_t COMM_TOPO_MESH = 0b1u;
8755

8856
using rtStream_t = void*;
8957
static constexpr int32_t RT_STREAM_PRIORITY_DEFAULT = 0;
58+
extern "C" int32_t rtSetDevice(int32_t device);
9059
extern "C" int32_t rtStreamCreate(rtStream_t* stream, int32_t priority);
9160
extern "C" int32_t rtStreamDestroy(rtStream_t stream);
9261

@@ -325,7 +294,7 @@ static void file_barrier(const std::string& dir, int rank, int nranks, const std
325294
// API implementation
326295
// ============================================================================
327296

328-
extern "C" CommHandle comm_init(int rank, int nranks, const char* rootinfo_path) {
297+
extern "C" CommHandle comm_init(int rank, int nranks, int device_id, const char* rootinfo_path) {
329298
auto* h = new (std::nothrow) CommHandle_{};
330299
if (!h) return nullptr;
331300

@@ -342,10 +311,26 @@ extern "C" CommHandle comm_init(int rank, int nranks, const char* rootinfo_path)
342311
return nullptr;
343312
}
344313

345-
// NOTE: Do NOT call aclrtSetDevice here — the caller (distributed_worker)
346-
// already set the correct physical device via set_device(device_id).
347-
// Calling aclrtSetDevice(rank) would override the context when
348-
// rank != device_id (e.g. devices=[2,4,5,7]).
314+
if (rank == 0) {
315+
int32_t rtRet = rtSetDevice(device_id);
316+
if (rtRet != 0) {
317+
fprintf(stderr, "[comm rank %d] rtSetDevice(%d) failed: %d\n",
318+
rank, device_id, rtRet);
319+
delete h;
320+
return nullptr;
321+
}
322+
}
323+
324+
// HCCL requires an ACL runtime context bound to the physical device.
325+
// This cannot be inferred from rank because distributed runs may map
326+
// ranks to arbitrary device lists (for example devices=[2,4,5,7]).
327+
aRet = aclrtSetDevice(device_id);
328+
if (aRet != ACL_SUCCESS) {
329+
fprintf(stderr, "[comm rank %d] aclrtSetDevice(%d) failed: %d\n",
330+
rank, device_id, (int)aRet);
331+
delete h;
332+
return nullptr;
333+
}
349334

350335
// RootInfo exchange
351336
HcclRootInfo rootInfo{};
@@ -369,6 +354,13 @@ extern "C" CommHandle comm_init(int rank, int nranks, const char* rootinfo_path)
369354
fin.read(rootInfo.internal, HCCL_ROOT_INFO_BYTES);
370355
}
371356

357+
std::string barrier_dir = h->rootinfo_path;
358+
auto last_slash = barrier_dir.rfind('/');
359+
if (last_slash != std::string::npos) {
360+
barrier_dir = barrier_dir.substr(0, last_slash);
361+
}
362+
file_barrier(barrier_dir, h->rank, h->nranks, "rootinfo_ready");
363+
372364
// Create stream for HCCL operations
373365
rtStreamCreate(&h->stream, RT_STREAM_PRIORITY_DEFAULT);
374366

@@ -403,8 +395,9 @@ extern "C" int comm_alloc_windows(CommHandle h, size_t /*win_size*/, uint64_t* d
403395
// File barrier so all ranks have completed HcclCommInitRootInfo
404396
std::string barrier_dir = h->rootinfo_path;
405397
auto last_slash = barrier_dir.rfind('/');
406-
if (last_slash != std::string::npos)
398+
if (last_slash != std::string::npos) {
407399
barrier_dir = barrier_dir.substr(0, last_slash);
400+
}
408401
file_barrier(barrier_dir, h->rank, h->nranks, "hccl_init");
409402

410403
// Tiling configuration for HcclAllocComResourceByTiling

src/a2a3/platform/sim/host/comm_sim.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,10 @@ struct CommHandle_ {
6868
// API implementation
6969
// ============================================================================
7070

71-
extern "C" CommHandle comm_init(int rank, int nranks, const char* rootinfo_path) {
71+
extern "C" CommHandle comm_init(int rank, int nranks, int device_id, const char* rootinfo_path) {
7272
auto* h = new (std::nothrow) CommHandle_{};
7373
if (!h) return nullptr;
74+
(void)device_id;
7475

7576
h->rank = rank;
7677
h->nranks = nranks;

src/a5/platform/include/host/comm.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,12 @@ typedef struct CommHandle_* CommHandle;
3232
*
3333
* @param rank This process's rank (0-based).
3434
* @param nranks Total number of ranks.
35+
* @param device_id Physical device ID used by this process.
3536
* @param rootinfo_path Filesystem path used to exchange root info between
3637
* ranks (rank 0 writes, others read).
3738
* @return Opaque handle, or NULL on failure.
3839
*/
39-
CommHandle comm_init(int rank, int nranks, const char* rootinfo_path);
40+
CommHandle comm_init(int rank, int nranks, int device_id, const char* rootinfo_path);
4041

4142
/**
4243
* Allocate RDMA / shared-memory windows and populate the device context.

src/a5/platform/onboard/host/comm_hccl.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ static void file_barrier(const std::string& dir, int rank, int nranks, const std
274274
// API implementation
275275
// ============================================================================
276276

277-
extern "C" CommHandle comm_init(int rank, int nranks, const char* rootinfo_path) {
277+
extern "C" CommHandle comm_init(int rank, int nranks, int device_id, const char* rootinfo_path) {
278278
auto* h = new (std::nothrow) CommHandle_{};
279279
if (!h) return nullptr;
280280

@@ -291,10 +291,16 @@ extern "C" CommHandle comm_init(int rank, int nranks, const char* rootinfo_path)
291291
return nullptr;
292292
}
293293

294-
// NOTE: Do NOT call aclrtSetDevice here — the caller (distributed_worker)
295-
// already set the correct physical device via set_device(device_id).
296-
// Calling aclrtSetDevice(rank) would override the context when
297-
// rank != device_id (e.g. devices=[2,4,5,7]).
294+
// HCCL requires an ACL runtime context bound to the physical device.
295+
// This cannot be inferred from rank because distributed runs may map
296+
// ranks to arbitrary device lists (for example devices=[2,4,5,7]).
297+
aRet = aclrtSetDevice(device_id);
298+
if (aRet != ACL_SUCCESS) {
299+
fprintf(stderr, "[comm rank %d] aclrtSetDevice(%d) failed: %d\n",
300+
rank, device_id, (int)aRet);
301+
delete h;
302+
return nullptr;
303+
}
298304

299305
// RootInfo exchange
300306
HcclRootInfo rootInfo{};

src/a5/platform/sim/host/comm_sim.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,10 @@ struct CommHandle_ {
6868
// API implementation
6969
// ============================================================================
7070

71-
extern "C" CommHandle comm_init(int rank, int nranks, const char* rootinfo_path) {
71+
extern "C" CommHandle comm_init(int rank, int nranks, int device_id, const char* rootinfo_path) {
7272
auto* h = new (std::nothrow) CommHandle_{};
7373
if (!h) return nullptr;
74+
(void)device_id;
7475

7576
h->rank = rank;
7677
h->nranks = nranks;

0 commit comments

Comments
 (0)