
Commit f33fc2b

Chamberlain0w0 authored and kilinchange committed
feat: integrate runtime_common, and modify ProcessGroup related apis
1 parent e281dea · commit f33fc2b

57 files changed

Lines changed: 326 additions & 314 deletions


example/gpt2/main.cc

Lines changed: 8 additions & 7 deletions
@@ -10,7 +10,7 @@
 #include "glog/logging.h"

 #include "infini_train/include/autocast.h"
-#include "infini_train/include/core/device_guard.h"
+#include "infini_train/include/core/runtime/device_guard.h"
 #include "infini_train/include/dataloader.h"
 #include "infini_train/include/device.h"
 #include "infini_train/include/nn/modules/loss.h"
@@ -140,24 +140,25 @@ void Train(const nn::parallel::Rank &rank) {

     if (rank.IsParallel()) {
         device = Device(Device::DeviceType::kCUDA, rank.thread_rank());
+        auto *pg_factory = ProcessGroupFactory::Instance(device.type());

         if (ddp_world_size > 1) {
-            ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()),
-                                                                  GetDataParallelGroupRanks(rank.GlobalRank()));
+            ddp_pg = pg_factory->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()),
+                                             GetDataParallelGroupRanks(rank.GlobalRank()));
             ddp_rank = ddp_pg->GetGroupRank(rank.GlobalRank());
         }

         if (tp_world_size > 1) {
-            tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()),
-                                                                 GetTensorParallelGroupRanks(rank.GlobalRank()));
+            tp_pg = pg_factory->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()),
+                                            GetTensorParallelGroupRanks(rank.GlobalRank()));
             tp_rank = tp_pg->GetGroupRank(rank.GlobalRank());
             // NOTE(zbl): Reserved for VocabParallelEmbedding
             nn::parallel::tp_rank = tp_rank;
         }

         if (pp_world_size > 1) {
-            pp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetPipelineParallelProcessGroupName(rank.GlobalRank()),
-                                                                 GetPipelineParallelGroupRanks(rank.GlobalRank()));
+            pp_pg = pg_factory->GetOrCreate(GetPipelineParallelProcessGroupName(rank.GlobalRank()),
+                                            GetPipelineParallelGroupRanks(rank.GlobalRank()));
             pp_rank = pp_pg->GetGroupRank(rank.GlobalRank());

             nn::parallel::pp_rank = pp_rank;

example/gpt2/net.cc

Lines changed: 2 additions & 2 deletions
@@ -198,8 +198,8 @@ GPT2FirstStage::Forward(const std::vector<std::shared_ptr<infini_train::Tensor>>
     auto sequence_parallel_enabled = nn::parallel::global::GetSequenceParallelEnabled();
     int tp_rank = 0;
     if (tp_world_size > 1) {
-        auto tp_group = nn::parallel::ProcessGroupFactory::Instance()->Get(
-            nn::parallel::GetTensorParallelProcessGroupName(device.Rank().GlobalRank()));
+        auto tp_group = nn::parallel::ProcessGroupFactory::Instance(device.type())
+                            ->Get(nn::parallel::GetTensorParallelProcessGroupName(device.Rank().GlobalRank()));
         tp_rank = tp_group->GetGroupRank(device.Rank().GlobalRank());
     }
     int64_t t_local = sequence_parallel_enabled ? x1->Dims()[1] / tp_world_size : x1->Dims()[1];

example/llama3/main.cc

Lines changed: 8 additions & 7 deletions
@@ -8,7 +8,7 @@
 #include "glog/logging.h"

 #include "infini_train/include/autocast.h"
-#include "infini_train/include/core/device_guard.h"
+#include "infini_train/include/core/runtime/device_guard.h"
 #include "infini_train/include/dataloader.h"
 #include "infini_train/include/device.h"
 #include "infini_train/include/nn/modules/loss.h"
@@ -121,24 +121,25 @@ void Train(const nn::parallel::Rank &rank) {

     if (rank.IsParallel()) {
         device = Device(Device::DeviceType::kCUDA, rank.thread_rank());
+        auto *pg_factory = ProcessGroupFactory::Instance(device.type());

         if (ddp_world_size > 1) {
-            ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()),
-                                                                  GetDataParallelGroupRanks(rank.GlobalRank()));
+            ddp_pg = pg_factory->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()),
+                                             GetDataParallelGroupRanks(rank.GlobalRank()));
             ddp_rank = ddp_pg->GetGroupRank(rank.GlobalRank());
         }

         if (tp_world_size > 1) {
-            tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()),
-                                                                 GetTensorParallelGroupRanks(rank.GlobalRank()));
+            tp_pg = pg_factory->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()),
+                                            GetTensorParallelGroupRanks(rank.GlobalRank()));
             tp_rank = tp_pg->GetGroupRank(rank.GlobalRank());
             // NOTE(zbl): Reserved for VocabParallelEmbedding
             nn::parallel::tp_rank = tp_rank;
         }

         if (pp_world_size > 1) {
-            pp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetPipelineParallelProcessGroupName(rank.GlobalRank()),
-                                                                 GetPipelineParallelGroupRanks(rank.GlobalRank()));
+            pp_pg = pg_factory->GetOrCreate(GetPipelineParallelProcessGroupName(rank.GlobalRank()),
+                                            GetPipelineParallelGroupRanks(rank.GlobalRank()));
             pp_rank = pp_pg->GetGroupRank(rank.GlobalRank());

             nn::parallel::pp_rank = pp_rank;

infini_train/include/core/blas_handle.h

Lines changed: 0 additions & 11 deletions
This file was deleted.

infini_train/include/core/ccl/ccl.h

Lines changed: 7 additions & 7 deletions
@@ -69,15 +69,15 @@ class CclImpl {
     virtual void Recv(void *buff, size_t count, DataType dtype, int peer, const CclComm *comm, Stream *stream) const;
 };

-class Ccl {
+class CclGroupGuard {
 public:
-    explicit Ccl(Device::DeviceType type);
-    ~Ccl();
+    explicit CclGroupGuard(Device::DeviceType type);
+    ~CclGroupGuard();

-    Ccl(const Ccl &) = delete;
-    Ccl &operator=(const Ccl &) = delete;
-    Ccl(Ccl &&) = delete;
-    Ccl &operator=(Ccl &&) = delete;
+    CclGroupGuard(const CclGroupGuard &) = delete;
+    CclGroupGuard &operator=(const CclGroupGuard &) = delete;
+    CclGroupGuard(CclGroupGuard &&) = delete;
+    CclGroupGuard &operator=(CclGroupGuard &&) = delete;

 private:
     CclImpl *impl_ = nullptr;
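
The `Ccl` RAII class is renamed to `CclGroupGuard`, making its scope-guard role explicit. A minimal usage sketch, assuming the constructor opens a grouped-call scope for the given backend and the destructor closes it (in the spirit of ncclGroupStart/ncclGroupEnd), and assuming `CclImpl` declares a `Send` matching the `Recv` shown above; `ExchangeHalo`, the buffers, peers, `comm`, and `stream` are placeholder names, not part of this commit:

#include "infini_train/include/core/ccl/ccl.h"

using namespace infini_train;

// Hypothetical helper: issue a send and a matching receive inside one CCL group so the
// backend can launch them together. Every identifier here is illustrative only.
void ExchangeHalo(const core::CclImpl &impl, void *send_buf, void *recv_buf, size_t count,
                  DataType dtype, int next_peer, int prev_peer, const core::CclComm *comm,
                  core::Stream *stream, Device::DeviceType backend) {
    core::CclGroupGuard group(backend);                          // presumably begins the group
    impl.Send(send_buf, count, dtype, next_peer, comm, stream);  // assumed Send overload
    impl.Recv(recv_buf, count, dtype, prev_peer, comm, stream);
}                                                                // destructor presumably ends the group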

infini_train/include/core/event.h

Lines changed: 0 additions & 11 deletions
This file was deleted.

infini_train/include/core/device_guard.h renamed to infini_train/include/core/runtime/device_guard.h

Lines changed: 3 additions & 2 deletions
@@ -4,8 +4,7 @@
 #include <memory>
 #include <unordered_map>

-#include "infini_train/include/core/event.h"
-#include "infini_train/include/core/runtime_status.h"
+#include "infini_train/include/core/runtime/runtime_common.h"
 #include "infini_train/include/device.h"

 namespace infini_train::core {
@@ -87,6 +86,8 @@ class DeviceGuardImpl {

     virtual void DestroyStream(Stream *) const;

+    virtual void GetStreamPriorityRange(int *low, int *high) const;
+
     // ----------------------------------------------------------------------
     // Event management
     // ----------------------------------------------------------------------
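
The new `GetStreamPriorityRange` hook lets each backend report the stream priorities it supports. A small caller-side sketch, assuming it fills the two out-parameters with the bounds of that range (as cudaDeviceGetStreamPriorityRange does on CUDA, where a numerically smaller value is a higher priority); `LogStreamPriorityRange` and `impl` are illustrative names:

#include "glog/logging.h"

#include "infini_train/include/core/runtime/device_guard.h"

// Hypothetical sketch: query a backend's supported stream-priority bounds before
// deciding which priority to request for a new stream.
void LogStreamPriorityRange(const infini_train::core::DeviceGuardImpl &impl) {
    int low = 0, high = 0;
    impl.GetStreamPriorityRange(&low, &high);    // backend fills in its priority bounds
    LOG(INFO) << "stream priority range: [" << low << ", " << high << "]";
}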

infini_train/include/core/runtime_status.h renamed to infini_train/include/core/runtime/runtime_common.h

Lines changed: 18 additions & 0 deletions
@@ -4,6 +4,24 @@

 namespace infini_train::core {

+class BlasHandle {
+public:
+    BlasHandle() = default;
+    virtual ~BlasHandle() = default;
+};
+
+class Event {
+public:
+    Event() = default;
+    virtual ~Event() = default;
+};
+
+class Stream {
+public:
+    Stream() = default;
+    virtual ~Stream() = default;
+};
+
 // Generic runtime status for backend-agnostic control flow.
 #define INFINI_TRAIN_RUNTIME_STATUS_LIST(X) \
     X(kSuccess, 0) \
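
With this commit the former `blas_handle.h`, `event.h`, and `stream.h` headers are deleted and their `BlasHandle`, `Event`, and `Stream` bases move here, so a single `runtime_common.h` include gives a backend everything it needs to define its runtime objects. A sketch of what a backend-side subclass might look like; `ExampleCudaStream` and its member are hypothetical and not part of this commit:

#include "infini_train/include/core/runtime/runtime_common.h"

namespace infini_train::core {

// Hypothetical CUDA-flavoured stream deriving from the new polymorphic base.
// The real backend classes in the repository may be structured differently.
class ExampleCudaStream : public Stream {
public:
    explicit ExampleCudaStream(void *native_handle) : handle_(native_handle) {}
    void *Handle() const { return handle_; }  // opaque native handle (e.g. a cudaStream_t)

private:
    void *handle_ = nullptr;
};

}  // namespace infini_train::core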

infini_train/include/core/stream.h

Lines changed: 0 additions & 11 deletions
This file was deleted.

infini_train/include/nn/parallel/process_group.h

Lines changed: 4 additions & 1 deletion
@@ -101,6 +101,8 @@ class ProcessGroupFactory {

     static ProcessGroupFactory *Instance();

+    static ProcessGroupFactory *Instance(Device::DeviceType backend);
+
     const ProcessGroup *GetOrCreate(const std::string &name, int comm_size);

     const ProcessGroup *GetOrCreate(const std::string &name, const std::vector<int> &device_indices);
@@ -110,7 +112,7 @@
     const ProcessGroup *GetDefaultProcessGroup() const;

 private:
-    ProcessGroupFactory();
+    explicit ProcessGroupFactory(Device::DeviceType backend);

     template <typename Creator, typename = std::enable_if_t<std::is_invocable_v<Creator>>>
     const ProcessGroup *GetOrCreate(const std::string &name, Creator &&creator) {
@@ -135,5 +137,6 @@
     mutable std::mutex mutex_;
     std::condition_variable cond_;
     std::unordered_map<std::string, std::unique_ptr<ProcessGroup>> name_to_group_;
+    Device::DeviceType backend_ = Device::DeviceType::kInvalid;
 };
 } // namespace infini_train::nn::parallel
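
The factory now carries a backend: `Instance(Device::DeviceType)` presumably returns a factory bound to that backend (recorded in the new `backend_` member), and the example diffs above switch to fetching it once and reusing the pointer. A minimal sketch of that calling pattern; the group name "dp_group" and the rank list are illustrative values only:

#include <vector>

#include "infini_train/include/nn/parallel/process_group.h"

using infini_train::Device;
using infini_train::nn::parallel::ProcessGroupFactory;

// Sketch: fetch the backend-specific factory once, then create or look up groups through it.
void BuildDataParallelGroup(const Device &device) {
    auto *factory = ProcessGroupFactory::Instance(device.type());                    // per-backend factory
    const auto *pg = factory->GetOrCreate("dp_group", std::vector<int>{0, 1, 2, 3});
    (void)pg;  // real code would continue with pg->GetGroupRank(...) etc.
}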
