-
Notifications
You must be signed in to change notification settings - Fork 63
Expand file tree
/
Copy pathcoll.h
More file actions
224 lines (177 loc) · 5.96 KB
/
coll.h
File metadata and controls
224 lines (177 loc) · 5.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
/* Copyright 2022 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
#pragma once
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <vector>
#ifdef LEGATE_USE_GASNET
#include <mpi.h>
#else
// If we aren't using GASNet, we'll use pthread_barrier to
// construct a communicator for thread-local communication. Mac OS
// does not implement pthread barriers, so we need to include an
// implementation in case they are not defined. We also need to
// include unistd.h since that defines _POSIX_BARRIERS.
#include <unistd.h>
#if !defined(_POSIX_BARRIERS) || (_POSIX_BARRIERS < 0)
#include "pthread_barrier.h"
#endif
#endif
namespace legate {
namespace comm {
namespace coll {
#ifdef LEGATE_USE_GASNET
// Evaluates the MPI call `expr` exactly once and hands its return code,
// together with the call site's file and line, to check_mpi().
//
// The return code is passed straight through as a function argument instead
// of being stored in a macro-local variable: a local declared inside the
// macro would shadow any caller identifier of the same name that appears in
// `expr` (e.g. a caller variable called `result`).
#define CHECK_MPI(expr)                    \
  do {                                     \
    check_mpi((expr), __FILE__, __LINE__); \
  } while (false)
// Pair of integer arrays carried by Coll_Comm::mapping_table in
// LEGATE_USE_GASNET (MPI) builds.
// NOTE(review): array lengths, indexing and ownership are not visible in this
// header — confirm against the implementation before relying on them.
struct RankMappingTable {
// presumably the MPI rank associated with each entry — TODO confirm
int* mpi_rank;
// presumably the global rank associated with each entry — TODO confirm
int* global_rank;
};
#else
// Communicator state for builds without LEGATE_USE_GASNET, where a
// pthread barrier is used to build a communicator for thread-local
// communication. Coll_Comm::comm points at one of these in that build.
struct ThreadComm {
// Barrier object backing the communicator.
pthread_barrier_t barrier;
// NOTE(review): presumably signals that this structure has been set up — confirm.
bool ready_flag;
// NOTE(review): presumably one published buffer pointer per participant — confirm.
const void** buffers;
// NOTE(review): presumably one published displacement array per participant — confirm.
const int** displs;
int* buffer_ready; // used for point-to-point transfers; size = comm_size * comm_size
};
#endif
// Element-type tag passed as the `type` argument of the communication
// functions declared in this header. It is also the input of
// dtypeToMPIDtype() (MPI build) and getDtypeSize() (thread-local build).
// Every enumerator has an explicit int value.
// NOTE(review): the explicit numbering suggests the values may be relied on
// outside this header — confirm before renumbering or reordering.
enum class CollDataType : int {
CollInt8 = 0,
CollChar = 1,
CollUint8 = 2,
CollInt = 3,
CollUint32 = 4,
CollInt64 = 5,
CollUint64 = 6,
CollFloat = 7,
CollDouble = 8,
};
// Status codes with an int underlying type. This is an unscoped enum, so
// CollSuccess and CollError are visible directly in the enclosing namespace
// and convert implicitly to int.
// NOTE(review): the functions in this header return plain int; presumably
// they return these values — confirm against the implementation.
enum CollStatus : int {
CollSuccess = 0,
CollError = 1,
};
// Communicator object handed to every communication function in this header
// (through the CollComm pointer typedef that follows the struct).
// The type of `comm` depends on the build: an MPI communicator plus a rank
// mapping table when LEGATE_USE_GASNET is defined, otherwise a pointer to the
// shared ThreadComm state.
// NOTE(review): the meaning of the integer fields below is inferred from their
// names only; none of them is read or written in this header — confirm.
struct Coll_Comm {
#ifdef LEGATE_USE_GASNET
MPI_Comm comm;
RankMappingTable mapping_table;
#else
volatile ThreadComm* comm;
#endif
// presumably the rank / size at the MPI level — TODO confirm
int mpi_rank;
int mpi_comm_size;
// NOTE(review): difference from mpi_comm_size is not visible here — confirm.
int mpi_comm_size_actual;
// same names as the global_rank / global_comm_size parameters of collCommCreate()
int global_rank;
int global_comm_size;
// presumably the number of participating threads — TODO confirm
int nb_threads;
// same name as the unique_id parameter of collCommCreate()
int unique_id;
// NOTE(review): presumably true while the communicator is usable — confirm.
bool status;
};
// Handle type: a plain pointer to Coll_Comm.
typedef Coll_Comm* CollComm;
// Communicator lifecycle.
//
// collCommCreate takes a caller-provided Coll_Comm through `global_comm`;
// `global_comm_size`, `global_rank` and `unique_id` share their names with
// fields of Coll_Comm.
// NOTE(review): presumably those fields are filled from the arguments, and
// `mapping_table` is a per-rank table used in the MPI build — confirm.
// NOTE(review): the int return convention is not visible here; presumably
// CollSuccess / CollError — confirm.
int collCommCreate(CollComm global_comm,
int global_comm_size,
int global_rank,
int unique_id,
const int* mapping_table);
// NOTE(review): presumably releases whatever collCommCreate set up in
// `global_comm`; whether the Coll_Comm object itself is freed is not visible — confirm.
int collCommDestroy(CollComm global_comm);
// Public collective operations. Each has an *MPI and a *Local counterpart
// with the same parameter list declared further down in this header.
// NOTE(review): presumably each function forwards to the counterpart for the
// active build — confirm.
//
// collAlltoallv: the parameter list has the shape of MPI_Alltoallv with a
// single element type `type` shared by the send and receive sides.
// NOTE(review): presumably counts and displacements are in elements and are
// indexed by peer rank, as in MPI — confirm.
int collAlltoallv(const void* sendbuf,
const int sendcounts[],
const int sdispls[],
void* recvbuf,
const int recvcounts[],
const int rdispls[],
CollDataType type,
CollComm global_comm);
// collAlltoall: single `count` instead of per-peer counts and displacements.
// NOTE(review): presumably `count` is the number of elements per peer — confirm.
int collAlltoall(
const void* sendbuf, void* recvbuf, int count, CollDataType type, CollComm global_comm);
// collAllgather: same parameter list as collAlltoall.
// NOTE(review): presumably `count` is the number of elements contributed by
// each rank — confirm.
int collAllgather(
const void* sendbuf, void* recvbuf, int count, CollDataType type, CollComm global_comm);
// Public point-to-point operations; sendMPI/recvMPI and sendLocal/recvLocal
// declared later in this header take the same parameter lists.
// NOTE(review): presumably `dest` / `source` are global ranks and `tag` must
// match between the paired send and receive — confirm.
int collSend(
const void* sendbuf, int count, CollDataType type, int dest, int tag, CollComm global_comm);
int collRecv(
void* recvbuf, int count, CollDataType type, int source, int tag, CollComm global_comm);
// Library-level setup and teardown.
// NOTE(review): required call order and the int return convention are not
// visible in this header — confirm against the implementation.
//
// collInit receives the program's argc / argv.
int collInit(int argc, char* argv[]);
int collFinalize();
// Writes its result through `id`.
// NOTE(review): presumably the value is meant for collCommCreate's
// `unique_id` argument — confirm.
int collGetUniqueId(int* id);
int collInitComm();
// The following functions should not be called by users
#ifdef LEGATE_USE_GASNET
// MPI back-end entry points, declared only when LEGATE_USE_GASNET is defined.
// alltoallvMPI, alltoallMPI, allgatherMPI, sendMPI and recvMPI take the same
// parameter lists as the public coll* functions of the same name.
int alltoallvMPI(const void* sendbuf,
const int sendcounts[],
const int sdispls[],
void* recvbuf,
const int recvcounts[],
const int rdispls[],
CollDataType type,
CollComm global_comm);
int alltoallMPI(
const void* sendbuf, void* recvbuf, int count, CollDataType type, CollComm global_comm);
// gatherMPI and bcastMPI take an additional `root` argument and have no
// public coll* counterpart in this header.
// NOTE(review): presumably `root` is the rank that collects / broadcasts — confirm.
int gatherMPI(
const void* sendbuf, void* recvbuf, int count, CollDataType type, int root, CollComm global_comm);
int allgatherMPI(
const void* sendbuf, void* recvbuf, int count, CollDataType type, CollComm global_comm);
int bcastMPI(void* buf, int count, CollDataType type, int root, CollComm global_comm);
int sendMPI(
const void* sendbuf, int count, CollDataType type, int dest, int tag, CollComm global_comm);
int recvMPI(void* recvbuf, int count, CollDataType type, int source, int tag, CollComm global_comm);
// Takes a CollDataType tag and returns an MPI_Datatype (MPI build only).
// NOTE(review): behaviour for a value outside the enum is not visible here — confirm.
MPI_Datatype dtypeToMPIDtype(CollDataType dtype);
// Tag helpers for the MPI back end; each returns an int.
// NOTE(review): presumably the result is used as the MPI message tag for the
// corresponding operation and is derived from the rank arguments and the
// communicator — the formula is not visible in this header, confirm.
int generateAlltoallTag(int rank1, int rank2, CollComm global_comm);
int generateAlltoallvTag(int rank1, int rank2, CollComm global_comm);
int generateBcastTag(int rank, CollComm global_comm);
int generateGatherTag(int rank, CollComm global_comm);
// Takes only the caller's tag; no communicator argument.
int generateP2PTag(int user_tag);
#else
// Takes a CollDataType tag and returns a size_t (thread-local build only).
// NOTE(review): presumably the size in bytes of one element of that type — confirm.
size_t getDtypeSize(CollDataType dtype);
// Thread-local back-end entry points, declared only when LEGATE_USE_GASNET is
// not defined. Each takes the same parameter list as the public coll*
// function of the same name.
// NOTE(review): presumably these exchange data through the shared ThreadComm
// that Coll_Comm::comm points to — confirm.
int alltoallvLocal(const void* sendbuf,
const int sendcounts[],
const int sdispls[],
void* recvbuf,
const int recvcounts[],
const int rdispls[],
CollDataType type,
CollComm global_comm);
int alltoallLocal(
const void* sendbuf, void* recvbuf, int count, CollDataType type, CollComm global_comm);
int allgatherLocal(
const void* sendbuf, void* recvbuf, int count, CollDataType type, CollComm global_comm);
int sendLocal(
const void* sendbuf, int count, CollDataType type, int dest, int tag, CollComm global_comm);
int recvLocal(
void* recvbuf, int count, CollDataType type, int source, int tag, CollComm global_comm);
// Helpers for the thread-local back end; neither returns a value.
// NOTE(review): presumably resetLocalBuffer clears this participant's
// published entries in ThreadComm — confirm.
void resetLocalBuffer(CollComm global_comm);
// NOTE(review): presumably waits on ThreadComm::barrier — confirm.
void barrierLocal(CollComm global_comm);
#endif
// Declared for both builds (outside the LEGATE_USE_GASNET conditionals).
// NOTE(review): presumably returns a temporary buffer of `size` bytes for
// operations whose send and receive buffers coincide; whether `recvbuf` is
// copied into it and who frees the result is not visible here — confirm.
void* allocateInplaceBuffer(const void* recvbuf, size_t size);
#ifdef LEGATE_USE_GASNET
inline void check_mpi(int error, const char* file, int line)
{
if (error != MPI_SUCCESS) {
fprintf(
stderr, "Internal MPI failure with error code %d in file %s at line %d\n", error, file, line);
#ifdef DEBUG_LEGATE
assert(false);
#else
exit(error);
#endif
}
}
#endif
} // namespace coll
} // namespace comm
} // namespace legate