2424
// Integer type used for the topology value that HcomGetL0TopoTypeEx writes
// through its `topoType` out-parameter.
2525using CommTopo = uint32_t ;
2626
27- // Internal HCCL APIs (not in public headers).
28- // CANN 9.x renames all HCCL symbols with a V2 suffix and ships libhccl_v2.so.
29- // The public header still declares weak non-V2 symbols, so we declare V2 variants
30- // separately and dispatch via inline wrappers.
31- #ifdef HCCL_USE_V2_API
32- extern " C" HcclResult HcclGetRootInfoV2 (HcclRootInfo* rootInfo);
33- extern " C" HcclResult HcclCommInitRootInfoV2 (uint32_t nRanks, const HcclRootInfo* rootInfo,
34- uint32_t rank, HcclComm* comm);
35- extern " C" HcclResult HcclGetCommNameV2 (HcclComm comm, char * commName);
36- extern " C" HcclResult HcclBarrierV2 (HcclComm comm, aclrtStream stream);
37- extern " C" HcclResult HcclCommDestroyV2 (HcclComm comm);
38- extern " C" HcclResult HcclAllocComResourceByTilingV2 (HcclComm comm, void * stream,
39- void * mc2Tiling, void ** commContext);
40- extern " C" HcclResult HcomGetCommHandleByGroupV2 (const char * group, HcclComm* commHandle);
41- extern " C" HcclResult HcomGetL0TopoTypeExV2 (const char * group, CommTopo* topoType,
42- uint32_t isSetDevice);
43-
44- static inline HcclResult hccl_get_root_info (HcclRootInfo* ri)
45- { return HcclGetRootInfoV2 (ri); }
46- static inline HcclResult hccl_comm_init_root_info (uint32_t n, const HcclRootInfo* ri, uint32_t r, HcclComm* c)
47- { return HcclCommInitRootInfoV2 (n, ri, r, c); }
48- static inline HcclResult hccl_get_comm_name (HcclComm c, char * name)
49- { return HcclGetCommNameV2 (c, name); }
50- static inline HcclResult hccl_barrier (HcclComm c, aclrtStream s)
51- { return HcclBarrierV2 (c, s); }
52- static inline HcclResult hccl_comm_destroy (HcclComm c)
53- { return HcclCommDestroyV2 (c); }
54- static inline HcclResult hccl_alloc_com_resource (HcclComm c, void * s, void * t, void ** ctx)
55- { return HcclAllocComResourceByTilingV2 (c, s, t, ctx); }
56- static inline HcclResult hccl_get_comm_handle_by_group (const char * g, HcclComm* c)
57- { return HcomGetCommHandleByGroupV2 (g, c); }
58- static inline HcclResult hccl_get_l0_topo_type_ex (const char * g, CommTopo* t, uint32_t f)
59- { return HcomGetL0TopoTypeExV2 (g, t, f); }
60- #else
27+ // Internal HCCL helpers are exported by libhcomm on CANN 9.x. The public
28+ // HCCL APIs below intentionally use the standard, non-V2 entry points to match
29+ // the working pto-isa initialization sequence.
// These three entry points are declared by hand in this file, so the
// signatures here must be kept in sync with the library that exports them.
6130extern " C" HcclResult HcclAllocComResourceByTiling (HcclComm comm, void * stream,
62- void * mc2Tiling, void ** commContext);
31+ void * mc2Tiling, void ** commContext);
6332extern " C" HcclResult HcomGetCommHandleByGroup (const char * group, HcclComm* commHandle);
6433extern " C" HcclResult HcomGetL0TopoTypeEx (const char * group, CommTopo* topoType,
65- uint32_t isSetDevice);
34+ uint32_t isSetDevice);
6635
// Thin wrapper: forwards `ri` unchanged to HcclGetRootInfo and returns its
// HcclResult as-is.
6736static inline HcclResult hccl_get_root_info (HcclRootInfo* ri)
6837 { return HcclGetRootInfo (ri); }
@@ -80,13 +49,13 @@ static inline HcclResult hccl_get_comm_handle_by_group(const char* g, HcclComm*
8049 { return HcomGetCommHandleByGroup (g, c); }
// Thin wrapper: forwards the group name `g`, the topology out-parameter `t`
// and the flag `f` unchanged to HcomGetL0TopoTypeEx and returns its result.
8150static inline HcclResult hccl_get_l0_topo_type_ex (const char * g, CommTopo* t, uint32_t f)
8251 { return HcomGetL0TopoTypeEx (g, t, f); }
83- #endif
8452
// NOTE(review): presumably the `isSetDevice` argument and a CommTopo value
// for HcomGetL0TopoTypeEx; their use sites are outside this view -- confirm.
8553static constexpr uint32_t COMM_IS_NOT_SET_DEVICE = 0 ;
8654static constexpr uint32_t COMM_TOPO_MESH = 0b1u ;
8755
// Runtime stream handle, declared locally as an untyped pointer.
8856using rtStream_t = void *;
// Priority value that comm_init passes to rtStreamCreate.
8957static constexpr int32_t RT_STREAM_PRIORITY_DEFAULT = 0 ;
// Runtime entry points declared by hand in this file. comm_init treats a
// non-zero return from rtSetDevice as a failure.
58+ extern " C" int32_t rtSetDevice (int32_t device);
9059extern " C" int32_t rtStreamCreate (rtStream_t* stream, int32_t priority);
9160extern " C" int32_t rtStreamDestroy (rtStream_t stream);
9261
@@ -325,7 +294,7 @@ static void file_barrier(const std::string& dir, int rank, int nranks, const std
325294// API implementation
326295// ============================================================================
327296
328- extern " C" CommHandle comm_init (int rank, int nranks, const char * rootinfo_path) {
297+ extern " C" CommHandle comm_init (int rank, int nranks, int device_id, const char * rootinfo_path) {
329298 auto * h = new (std::nothrow) CommHandle_{};
330299 if (!h) return nullptr ;
331300
@@ -342,10 +311,26 @@ extern "C" CommHandle comm_init(int rank, int nranks, const char* rootinfo_path)
342311 return nullptr ;
343312 }
344313
345- // NOTE: Do NOT call aclrtSetDevice here — the caller (distributed_worker)
346- // already set the correct physical device via set_device(device_id).
347- // Calling aclrtSetDevice(rank) would override the context when
348- // rank != device_id (e.g. devices=[2,4,5,7]).
314+ if (rank == 0 ) {
315+ int32_t rtRet = rtSetDevice (device_id);
316+ if (rtRet != 0 ) {
317+ fprintf (stderr, " [comm rank %d] rtSetDevice(%d) failed: %d\n " ,
318+ rank, device_id, rtRet);
319+ delete h;
320+ return nullptr ;
321+ }
322+ }
323+
324+ // HCCL requires an ACL runtime context bound to the physical device.
325+ // This cannot be inferred from rank because distributed runs may map
326+ // ranks to arbitrary device lists (for example devices=[2,4,5,7]).
327+ aRet = aclrtSetDevice (device_id);
328+ if (aRet != ACL_SUCCESS) {
329+ fprintf (stderr, " [comm rank %d] aclrtSetDevice(%d) failed: %d\n " ,
330+ rank, device_id, (int )aRet);
331+ delete h;
332+ return nullptr ;
333+ }
349334
350335 // RootInfo exchange
351336 HcclRootInfo rootInfo{};
@@ -369,6 +354,13 @@ extern "C" CommHandle comm_init(int rank, int nranks, const char* rootinfo_path)
369354 fin.read (rootInfo.internal , HCCL_ROOT_INFO_BYTES);
370355 }
371356
357+ std::string barrier_dir = h->rootinfo_path ;
358+ auto last_slash = barrier_dir.rfind (' /' );
359+ if (last_slash != std::string::npos) {
360+ barrier_dir = barrier_dir.substr (0 , last_slash);
361+ }
362+ file_barrier (barrier_dir, h->rank , h->nranks , " rootinfo_ready" );
363+
372364 // Create stream for HCCL operations
373365 rtStreamCreate (&h->stream , RT_STREAM_PRIORITY_DEFAULT);
374366
@@ -403,8 +395,9 @@ extern "C" int comm_alloc_windows(CommHandle h, size_t /*win_size*/, uint64_t* d
403395 // File barrier: wait here until every rank has completed HcclCommInitRootInfo
404396 std::string barrier_dir = h->rootinfo_path ;
405397 auto last_slash = barrier_dir.rfind (' /' );
406- if (last_slash != std::string::npos)
398+ if (last_slash != std::string::npos) {
407399 barrier_dir = barrier_dir.substr (0 , last_slash);
400+ }
408401 file_barrier (barrier_dir, h->rank , h->nranks , " hccl_init" );
409402
410403 // Tiling configuration for HcclAllocComResourceByTiling
0 commit comments