Skip to content

Commit bf2eae5

Browse files
Chamberlain0w0kilinchange
authored and committed
fix: resolve requested changes, remove unnecessary api, remove nccl macros, mv unique_id file helper functions to utils
1 parent a1bea05 commit bf2eae5

15 files changed

Lines changed: 162 additions & 144 deletions

File tree

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ endif()
4848
# Framework core sources (*.cc), excluding cpu kernels (they are built separately)
4949
file(GLOB_RECURSE SRC ${PROJECT_SOURCE_DIR}/infini_train/src/*.cc)
5050
list(FILTER SRC EXCLUDE REGEX ".*kernels/cpu/.*")
51+
# The sources under core/ccl/cuda include <nccl.h> unconditionally, so drop them from SRC when USE_NCCL is off.
if(NOT USE_NCCL)
52+
list(FILTER SRC EXCLUDE REGEX ".*infini_train/src/core/ccl/cuda/.*")
53+
endif()
5154

5255
# CPU kernels (*.cc)
5356
file(GLOB_RECURSE CPU_KERNELS ${PROJECT_SOURCE_DIR}/infini_train/src/kernels/cpu/*.cc)

infini_train/include/core/ccl/ccl.h

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,23 +26,13 @@ class CclImpl {
2626

2727
virtual void GroupEnd() const;
2828

29-
virtual void CommGetAsyncError(const CclComm *comm, CclStatus *async_error) const;
29+
virtual void GetAsyncError(const CclComm *comm, CclStatus *async_error) const;
3030

31-
virtual void CreateComm(CclComm **comm) const;
32-
33-
virtual void CreateUniqueId(CclUniqueId **unique_id) const;
34-
35-
virtual void GetUniqueId(CclUniqueId *unique_id) const;
36-
37-
virtual void WriteUniqueId(const CclUniqueId &unique_id, const std::string &pg_name) const;
38-
39-
virtual void ReadUniqueId(CclUniqueId *unique_id, const std::string &pg_name) const;
40-
41-
virtual void CleanupUniqueIdFile(const std::string &pg_name) const;
31+
virtual void GetUniqueId(CclUniqueId **unique_id) const;
4232

4333
virtual void CommInitAll(CclComm **comms, int ndev, const int *devlist) const;
4434

45-
virtual void CommInitRank(CclComm *comm, int nranks, const CclUniqueId &unique_id, int rank) const;
35+
virtual void CommInitRank(CclComm **comm, int nranks, const CclUniqueId &unique_id, int rank) const;
4636

4737
virtual void CommDestroy(CclComm *comm) const;
4838

infini_train/include/core/ccl/ccl_common.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
#pragma once
22

3+
#include <cstddef>
34
#include <cstdint>
45

6+
#include "glog/logging.h"
7+
58
namespace infini_train::core {
69

710
#define INFINI_TRAIN_CCL_STATUS_LIST(X) \
@@ -29,7 +32,8 @@ inline const char *CclStatusToString(CclStatus status) {
2932
INFINI_TRAIN_CCL_STATUS_LIST(INFINI_TRAIN_CCL_STATUS_CASE)
3033
#undef INFINI_TRAIN_CCL_STATUS_CASE
3134
default:
32-
return "Unknown";
35+
// Not one of the listed CclStatus values: report the raw value and abort.
LOG(FATAL) << "Unsupported CclStatus type: " << static_cast<int>(status);
36+
return "";
3337
}
3438
}
3539

@@ -45,6 +49,10 @@ class CclUniqueId {
4549
public:
4650
CclUniqueId() = default;
4751
virtual ~CclUniqueId() = default;
52+
53+
virtual size_t Size() const = 0;
54+
virtual const void *Data() const = 0;
55+
virtual void Load(const void *src, size_t size) = 0;
4856
};
4957

5058
} // namespace infini_train::core
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#pragma once
2+
3+
#include <string>
4+
5+
#include "infini_train/include/core/ccl/ccl_common.h"
6+
7+
namespace infini_train::core {
8+
9+
// Publishes `unique_id` for process group `pg_name`: its Size() bytes starting at Data() are
// written to a temporary file, which is then renamed to the final "cclUniqueId_<pg_name>.bin".
void WriteUniqueIdFile(const CclUniqueId &unique_id, const std::string &pg_name);
10+
11+
// Blocks (polling, no timeout) until "cclUniqueId_<pg_name>.bin" exists, then reads it and
// passes the bytes to unique_id->Load(). `unique_id` must be non-null and the file size must
// equal unique_id->Size(); both are enforced with CHECKs.
void ReadUniqueIdFile(CclUniqueId *unique_id, const std::string &pg_name);
12+
13+
// Removes "cclUniqueId_<pg_name>.bin" if it exists.
void CleanupUniqueIdFile(const std::string &pg_name);
14+
15+
} // namespace infini_train::core

infini_train/include/core/runtime/runtime_common.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
#include <cstdint>
44

5+
#include "glog/logging.h"
6+
57
namespace infini_train::core {
68

79
class BlasHandle {
@@ -58,7 +60,8 @@ inline const char *RuntimeStatusToString(RuntimeStatus s) {
5860
INFINI_TRAIN_RUNTIME_STATUS_LIST(INFINI_TRAIN_RUNTIME_STATUS_CASE)
5961
#undef INFINI_TRAIN_RUNTIME_STATUS_CASE
6062
default:
61-
return "Unknown";
63+
LOG(FATAL) << "Unsupported RuntimeStatus type: " << static_cast<int>(s);
64+
return "";
6265
}
6366
}
6467

infini_train/src/core/ccl/ccl.cc

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,35 +12,17 @@ void CclImpl::GroupStart() const { LOG(FATAL) << "CclImpl::GroupStart is not imp
1212

1313
void CclImpl::GroupEnd() const { LOG(FATAL) << "CclImpl::GroupEnd is not implemented."; }
1414

15-
void CclImpl::CommGetAsyncError(const CclComm *comm, CclStatus *async_error) const {
16-
LOG(FATAL) << "CclImpl::CommGetAsyncError is not implemented.";
15+
void CclImpl::GetAsyncError(const CclComm *comm, CclStatus *async_error) const {
16+
LOG(FATAL) << "CclImpl::GetAsyncError is not implemented.";
1717
}
1818

19-
void CclImpl::CreateComm(CclComm **comm) const { LOG(FATAL) << "CclImpl::CreateComm is not implemented."; }
20-
21-
void CclImpl::CreateUniqueId(CclUniqueId **unique_id) const {
22-
LOG(FATAL) << "CclImpl::CreateUniqueId is not implemented.";
23-
}
24-
25-
void CclImpl::GetUniqueId(CclUniqueId *unique_id) const { LOG(FATAL) << "CclImpl::GetUniqueId is not implemented."; }
26-
27-
void CclImpl::WriteUniqueId(const CclUniqueId &unique_id, const std::string &pg_name) const {
28-
LOG(FATAL) << "CclImpl::WriteUniqueId is not implemented.";
29-
}
30-
31-
void CclImpl::ReadUniqueId(CclUniqueId *unique_id, const std::string &pg_name) const {
32-
LOG(FATAL) << "CclImpl::ReadUniqueId is not implemented.";
33-
}
34-
35-
void CclImpl::CleanupUniqueIdFile(const std::string &pg_name) const {
36-
LOG(FATAL) << "CclImpl::CleanupUniqueIdFile is not implemented.";
37-
}
19+
void CclImpl::GetUniqueId(CclUniqueId **unique_id) const { LOG(FATAL) << "CclImpl::GetUniqueId is not implemented."; }
3820

3921
void CclImpl::CommInitAll(CclComm **comms, int ndev, const int *devlist) const {
4022
LOG(FATAL) << "CclImpl::CommInitAll is not implemented.";
4123
}
4224

43-
void CclImpl::CommInitRank(CclComm *comm, int nranks, const CclUniqueId &unique_id, int rank) const {
25+
void CclImpl::CommInitRank(CclComm **comm, int nranks, const CclUniqueId &unique_id, int rank) const {
4426
LOG(FATAL) << "CclImpl::CommInitRank is not implemented.";
4527
}
4628

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#include "infini_train/include/core/ccl/ccl_utils.h"
2+
3+
#include <chrono>
4+
#include <cstdio>
5+
#include <filesystem>
6+
#include <fstream>
7+
#include <iterator>
8+
#include <thread>
9+
10+
#include "glog/logging.h"
11+
12+
namespace infini_train::core {
13+
namespace {
14+
// Builds the unique-id file name for process group `name`:
// "cclUniqueId_<name>.tmp" when `tmp` is true, "cclUniqueId_<name>.bin" otherwise.
std::string UniqueIdFileName(const std::string &name, bool tmp = false) {
    std::string file_name = "cclUniqueId_";
    file_name += name;
    if (tmp) {
        file_name += ".tmp";
    } else {
        file_name += ".bin";
    }
    return file_name;
}
17+
} // namespace
18+
19+
void WriteUniqueIdFile(const CclUniqueId &unique_id, const std::string &pg_name) {
20+
const std::string tmp_path = UniqueIdFileName(pg_name, true);
21+
22+
std::ofstream ofs(tmp_path, std::ios::binary);
23+
CHECK(ofs.good()) << "Failed to open unique_id tmp file for write: " << tmp_path;
24+
const size_t size = unique_id.Size();
25+
ofs.write(reinterpret_cast<const char *>(unique_id.Data()), static_cast<std::streamsize>(size));
26+
ofs.close();
27+
28+
std::rename(tmp_path.c_str(), UniqueIdFileName(pg_name).c_str());
29+
}
30+
31+
// Loads the unique id published for process group `pg_name` into `*unique_id`.
//
// Waits until the final ".bin" file exists, re-checking every millisecond. There is no
// timeout, so this blocks for as long as the file is missing. The whole file is then read
// and, after verifying that its length equals unique_id->Size(), handed to unique_id->Load().
void ReadUniqueIdFile(CclUniqueId *unique_id, const std::string &pg_name) {
    CHECK_NOTNULL(unique_id);
    const std::string file_path = UniqueIdFileName(pg_name);

    // Poll for the published file.
    for (;;) {
        if (std::filesystem::exists(file_path)) {
            break;
        }
        std::this_thread::sleep_for(std::chrono::milliseconds(1));
    }

    std::ifstream ifs(file_path, std::ios::binary);
    CHECK(ifs.good()) << "Failed to open unique_id file for read: " << file_path;

    // Slurp the entire file; its size is validated below rather than assumed.
    std::string bytes;
    bytes.assign(std::istreambuf_iterator<char>(ifs), std::istreambuf_iterator<char>());
    ifs.close();

    CHECK_EQ(bytes.size(), unique_id->Size())
        << "Mismatched unique_id size in file. expected=" << unique_id->Size() << ", got=" << bytes.size();
    unique_id->Load(bytes.data(), bytes.size());
}
47+
48+
void CleanupUniqueIdFile(const std::string &pg_name) {
49+
const std::string file_path = UniqueIdFileName(pg_name);
50+
if (std::filesystem::exists(file_path)) {
51+
std::filesystem::remove(file_path);
52+
}
53+
}
54+
55+
} // namespace infini_train::core
Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,35 @@
11
#include "infini_train/src/core/ccl/cuda/nccl_common.h"
22

3+
#include <cstring>
4+
5+
#include "glog/logging.h"
6+
37
namespace infini_train::core {
48

5-
#ifdef USE_NCCL
69
NcclComm::NcclComm() = default;
710

8-
NcclComm::NcclComm(ncclComm_t comm) : comm_(comm) {}
11+
NcclComm::NcclComm(ncclComm_t comm) : nccl_comm_(comm) {}
912

10-
ncclComm_t NcclComm::nccl_comm() const { return comm_; }
13+
ncclComm_t NcclComm::nccl_comm() const { return nccl_comm_; }
1114

12-
void NcclComm::set_nccl_comm(ncclComm_t comm) { comm_ = comm; }
15+
void NcclComm::set_nccl_comm(ncclComm_t comm) { nccl_comm_ = comm; }
1316

1417
NcclUniqueId::NcclUniqueId() = default;
1518

1619
NcclUniqueId::NcclUniqueId(const ncclUniqueId &id) : id_(id) {}
1720

21+
// Returns the size in bytes of the wrapped ncclUniqueId.
size_t NcclUniqueId::Size() const { return sizeof(id_); }
22+
23+
// Returns the address of the wrapped ncclUniqueId member; the pointer refers into this object.
const void *NcclUniqueId::Data() const { return &id_; }
24+
25+
// Overwrites the wrapped ncclUniqueId with `size` bytes copied from `src`.
// `src` must be non-null and `size` must equal the size of the wrapped id; both are CHECKed.
void NcclUniqueId::Load(const void *src, size_t size) {
    const size_t id_bytes = sizeof(id_);
    CHECK_NOTNULL(src);
    CHECK_EQ(size, id_bytes);
    std::memcpy(&id_, src, id_bytes);
}
30+
1831
ncclUniqueId *NcclUniqueId::nccl_unique_id() { return &id_; }
1932

2033
const ncclUniqueId *NcclUniqueId::nccl_unique_id() const { return &id_; }
21-
#endif
2234

2335
} // namespace infini_train::core

infini_train/src/core/ccl/cuda/nccl_common.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#pragma once
22

3-
#ifdef USE_NCCL
43
#include <nccl.h>
54

65
#include "infini_train/include/core/ccl/ccl_common.h"
@@ -16,20 +15,23 @@ class NcclComm final : public CclComm {
1615
void set_nccl_comm(ncclComm_t comm);
1716

1817
private:
19-
ncclComm_t comm_ = nullptr;
18+
ncclComm_t nccl_comm_ = nullptr;
2019
};
2120

2221
// CclUniqueId implementation that wraps an ncclUniqueId value.
class NcclUniqueId final : public CclUniqueId {
2322
public:
2423
NcclUniqueId();
2524
explicit NcclUniqueId(const ncclUniqueId &id);
2625

26+
// Byte-level view of the wrapped id: its size, a pointer to its bytes, and a
// way to overwrite it from a raw buffer (interface declared by CclUniqueId).
size_t Size() const override;
27+
const void *Data() const override;
28+
void Load(const void *src, size_t size) override;
29+
2730
// Typed access to the wrapped ncclUniqueId.
ncclUniqueId *nccl_unique_id();
2831
const ncclUniqueId *nccl_unique_id() const;
2932

3033
private:
3134
// NOTE(review): no default member initializer here; presumably left unset by the
// default constructor until Load() or an NCCL call fills it in — confirm.
ncclUniqueId id_;
3235
};
33-
#endif
3436

3537
} // namespace infini_train::core

0 commit comments

Comments
 (0)