Skip to content

Commit 0330a98

Browse files
authored
CusolverMP re-org to accommodate updated version 0.7 (nv-legate#1013)
* CusolverMP re-org to accommodate ncclComm_t. * CusolverMP re-org clean-up. * CAL dependencies clean-up. * build deps clean-up. * Addressed reviews on meta.yaml and task.add_cal_com... removal. * Removed cal.h inclusion.
1 parent 5890e1d commit 0330a98

8 files changed

Lines changed: 44 additions & 77 deletions

File tree

conda/conda-build/meta.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,7 @@ requirements:
161161
- libcurand-dev
162162
- libcufile-dev
163163
- cuda-version ={{ cuda_version }}
164-
- libcusolvermp-dev
165-
- libcal-dev
164+
- libcusolvermp-dev >=0.7
166165
{% endif %}
167166

168167
run:

cupynumeric/linalg/_cholesky.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@ def mp_potrf(
9595
task.add_scalar_arg(n, ty.int64)
9696
task.add_scalar_arg(nb, ty.int64)
9797
task.add_nccl_communicator() # for repartitioning
98-
task.add_cal_communicator()
9998
task.execute()
10099

101100

cupynumeric/linalg/_solve.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@ def mp_solve(
9292
task.add_scalar_arg(nrhs, ty.int64)
9393
task.add_scalar_arg(nb, ty.int64)
9494
task.add_nccl_communicator() # for repartitioning
95-
task.add_cal_communicator()
9695
task.execute()
9796

9897

src/cupynumeric/cuda_help.h

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
#include <cusolverDn.h>
2828
#if LEGATE_DEFINED(CUPYNUMERIC_USE_CUSOLVERMP)
2929
#include <cusolverMp.h>
30-
#include <cal.h>
3130
#endif
3231
#include <cuda_runtime.h>
3332
#include <cufft.h>
@@ -108,24 +107,6 @@ __host__ inline void check_cusolver(cusolverStatus_t status, const char* file, i
108107
}
109108
}
110109

111-
#if LEGATE_DEFINED(CUPYNUMERIC_USE_CUSOLVERMP)
112-
__host__ inline void check_cal(calError_t status, const char* file, int line)
113-
{
114-
if (status != CAL_OK) {
115-
fprintf(stderr,
116-
"Internal libcal failure with error code %d in file %s at line %d\n",
117-
status,
118-
file,
119-
line);
120-
#ifdef DEBUG_CUPYNUMERIC
121-
assert(false);
122-
#else
123-
exit(status);
124-
#endif
125-
}
126-
}
127-
#endif
128-
129110
__host__ inline void check_cutensor(cutensorStatus_t result, const char* file, int line)
130111
{
131112
if (result != CUTENSOR_STATUS_SUCCESS) {
@@ -179,12 +160,6 @@ __host__ inline void check_nccl(ncclResult_t error, const char* file, int line)
179160
cupynumeric::check_cusolver(__result__, __FILE__, __LINE__); \
180161
} while (false)
181162

182-
#define CHECK_CAL(expr) \
183-
do { \
184-
calError_t __result__ = (expr); \
185-
cupynumeric::check_cal(__result__, __FILE__, __LINE__); \
186-
} while (false)
187-
188163
#define CHECK_CUTENSOR(expr) \
189164
do { \
190165
cutensorStatus_t __result__ = (expr); \

src/cupynumeric/matrix/mp_potrf.cu

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ namespace cupynumeric {
2424
using namespace Legion;
2525
using namespace legate;
2626

27-
template <typename VAL>
28-
static inline void mp_potrf_template(cal_comm_t comm,
27+
template <typename VAL, typename comm_t>
28+
static inline void mp_potrf_template(comm_t comm,
2929
int nprow,
3030
int npcol,
3131
int64_t n,
@@ -92,7 +92,7 @@ static inline void mp_potrf_template(cal_comm_t comm,
9292
info.ptr(0)));
9393

9494
// TODO: We need a deferred exception to avoid this synchronization
95-
CHECK_CAL(cal_stream_sync(comm, stream));
95+
CUPYNUMERIC_CHECK_CUDA(cudaStreamSynchronize(stream));
9696
CUPYNUMERIC_CHECK_CUDA_STREAM(stream);
9797

9898
CHECK_CUSOLVER(cusolverMpDestroyMatrixDesc(desc));
@@ -108,8 +108,9 @@ struct MpPotrfImplBody<VariantKind::GPU, Type::Code::FLOAT32> {
108108
TaskContext context;
109109
explicit MpPotrfImplBody(TaskContext context) : context(context) {}
110110

111+
template <typename comm_t>
111112
void operator()(
112-
cal_comm_t comm, int nprow, int npcol, int64_t n, int64_t nb, float* array, int64_t lld)
113+
comm_t comm, int nprow, int npcol, int64_t n, int64_t nb, float* array, int64_t lld)
113114
{
114115
auto stream = context.get_task_stream();
115116
mp_potrf_template(comm, nprow, npcol, n, nb, array, lld, stream);
@@ -121,8 +122,9 @@ struct MpPotrfImplBody<VariantKind::GPU, Type::Code::FLOAT64> {
121122
TaskContext context;
122123
explicit MpPotrfImplBody(TaskContext context) : context(context) {}
123124

125+
template <typename comm_t>
124126
void operator()(
125-
cal_comm_t comm, int nprow, int npcol, int64_t n, int64_t nb, double* array, int64_t lld)
127+
comm_t comm, int nprow, int npcol, int64_t n, int64_t nb, double* array, int64_t lld)
126128
{
127129
auto stream = context.get_task_stream();
128130
mp_potrf_template(comm, nprow, npcol, n, nb, array, lld, stream);
@@ -134,13 +136,9 @@ struct MpPotrfImplBody<VariantKind::GPU, Type::Code::COMPLEX64> {
134136
TaskContext context;
135137
explicit MpPotrfImplBody(TaskContext context) : context(context) {}
136138

137-
void operator()(cal_comm_t comm,
138-
int nprow,
139-
int npcol,
140-
int64_t n,
141-
int64_t nb,
142-
complex<float>* array,
143-
int64_t lld)
139+
template <typename comm_t>
140+
void operator()(
141+
comm_t comm, int nprow, int npcol, int64_t n, int64_t nb, complex<float>* array, int64_t lld)
144142
{
145143
auto stream = context.get_task_stream();
146144
mp_potrf_template(comm, nprow, npcol, n, nb, reinterpret_cast<cuComplex*>(array), lld, stream);
@@ -152,13 +150,9 @@ struct MpPotrfImplBody<VariantKind::GPU, Type::Code::COMPLEX128> {
152150
TaskContext context;
153151
explicit MpPotrfImplBody(TaskContext context) : context(context) {}
154152

155-
void operator()(cal_comm_t comm,
156-
int nprow,
157-
int npcol,
158-
int64_t n,
159-
int64_t nb,
160-
complex<double>* array,
161-
int64_t lld)
153+
template <typename comm_t>
154+
void operator()(
155+
comm_t comm, int nprow, int npcol, int64_t n, int64_t nb, complex<double>* array, int64_t lld)
162156
{
163157
auto stream = context.get_task_stream();
164158
mp_potrf_template(

src/cupynumeric/matrix/mp_potrf_template.inl

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@
2525
#include "cupynumeric/cuda_help.h"
2626
#include "cupynumeric/utilities/repartition.h"
2727

28-
#include <cal.h>
29-
3028
namespace cupynumeric {
3129

3230
using namespace Legion;
@@ -64,12 +62,13 @@ struct MpPotrfImpl {
6462
auto input_shape = input_array.shape<2>();
6563
auto output_shape = output_array.shape<2>();
6664

65+
auto* p_nccl_comm = comms[0].get<ncclComm_t*>();
6766
int rank, num_ranks;
68-
auto nccl_comm = comms[0];
69-
auto cal_comm = comms[1].get<cal_comm_t>();
70-
assert(cal_comm);
71-
CHECK_CAL(cal_comm_get_rank(cal_comm, &rank));
72-
CHECK_CAL(cal_comm_get_size(cal_comm, &num_ranks));
67+
assert(p_nccl_comm);
68+
auto nccl_comm = *p_nccl_comm;
69+
CHECK_NCCL(ncclCommUserRank(nccl_comm, &rank));
70+
CHECK_NCCL(ncclCommCount(nccl_comm, &num_ranks));
71+
7372
assert(launch_domain.get_volume() == num_ranks);
7473
assert(launch_domain.get_dim() <= 2);
7574

@@ -115,10 +114,10 @@ struct MpPotrfImpl {
115114
auto volume = num_rows * num_cols;
116115

117116
auto [buffer_2dbc, volume_2dbc, lld_2dbc] = repartition_matrix_2dbc(
118-
input_arr, volume, false, offset_r, offset_c, lld, nprow, npcol, nb, nb, nccl_comm, context);
117+
input_arr, volume, false, offset_r, offset_c, lld, nprow, npcol, nb, nb, comms[0], context);
119118

120119
MpPotrfImplBody<KIND, CODE>{context}(
121-
cal_comm, nprow, npcol, n, nb, buffer_2dbc.ptr(0), lld_2dbc);
120+
nccl_comm, nprow, npcol, n, nb, buffer_2dbc.ptr(0), lld_2dbc);
122121

123122
repartition_matrix_block(buffer_2dbc,
124123
volume_2dbc,
@@ -136,7 +135,7 @@ struct MpPotrfImpl {
136135
false,
137136
offset_r,
138137
offset_c,
139-
nccl_comm,
138+
comms[0],
140139
context);
141140
}
142141

src/cupynumeric/matrix/mp_solve.cu

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ namespace cupynumeric {
2424
using namespace Legion;
2525
using namespace legate;
2626

27-
template <typename VAL>
28-
static inline void mp_solve_template(cal_comm_t comm,
27+
template <typename VAL, typename comm_t>
28+
static inline void mp_solve_template(comm_t comm,
2929
int nprow,
3030
int npcol,
3131
int64_t n,
@@ -145,7 +145,7 @@ static inline void mp_solve_template(cal_comm_t comm,
145145
info.ptr(0)));
146146

147147
// TODO: We need a deferred exception to avoid this synchronization
148-
CHECK_CAL(cal_stream_sync(comm, stream));
148+
CUPYNUMERIC_CHECK_CUDA(cudaStreamSynchronize(stream));
149149
CUPYNUMERIC_CHECK_CUDA_STREAM(stream);
150150

151151
CHECK_CUSOLVER(cusolverMpDestroyMatrixDesc(a_desc));
@@ -163,7 +163,8 @@ struct MpSolveImplBody<VariantKind::GPU, Type::Code::FLOAT32> {
163163
TaskContext context;
164164
explicit MpSolveImplBody(TaskContext context) : context(context) {}
165165

166-
void operator()(cal_comm_t comm,
166+
template <typename comm_t>
167+
void operator()(comm_t comm,
167168
int nprow,
168169
int npcol,
169170
int64_t n,
@@ -184,7 +185,8 @@ struct MpSolveImplBody<VariantKind::GPU, Type::Code::FLOAT64> {
184185
TaskContext context;
185186
explicit MpSolveImplBody(TaskContext context) : context(context) {}
186187

187-
void operator()(cal_comm_t comm,
188+
template <typename comm_t>
189+
void operator()(comm_t comm,
188190
int nprow,
189191
int npcol,
190192
int64_t n,
@@ -205,7 +207,8 @@ struct MpSolveImplBody<VariantKind::GPU, Type::Code::COMPLEX64> {
205207
TaskContext context;
206208
explicit MpSolveImplBody(TaskContext context) : context(context) {}
207209

208-
void operator()(cal_comm_t comm,
210+
template <typename comm_t>
211+
void operator()(comm_t comm,
209212
int nprow,
210213
int npcol,
211214
int64_t n,
@@ -236,7 +239,8 @@ struct MpSolveImplBody<VariantKind::GPU, Type::Code::COMPLEX128> {
236239
TaskContext context;
237240
explicit MpSolveImplBody(TaskContext context) : context(context) {}
238241

239-
void operator()(cal_comm_t comm,
242+
template <typename comm_t>
243+
void operator()(comm_t comm,
240244
int nprow,
241245
int npcol,
242246
int64_t n,

src/cupynumeric/matrix/mp_solve_template.inl

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@
2525
#include "cupynumeric/cuda_help.h"
2626
#include "cupynumeric/utilities/repartition.h"
2727

28-
#include <cal.h>
29-
3028
namespace cupynumeric {
3129

3230
using namespace Legion;
@@ -63,13 +61,13 @@ struct MpSolveImpl {
6361
{
6462
using VAL = type_of<CODE>;
6563

66-
auto nccl_comm = comms[0];
67-
auto cal_comm = comms[1].get<cal_comm_t>();
68-
64+
auto* p_nccl_comm = comms[0].get<ncclComm_t*>();
6965
int rank, num_ranks;
70-
assert(cal_comm);
71-
CHECK_CAL(cal_comm_get_rank(cal_comm, &rank));
72-
CHECK_CAL(cal_comm_get_size(cal_comm, &num_ranks));
66+
assert(p_nccl_comm);
67+
auto nccl_comm = *p_nccl_comm;
68+
CHECK_NCCL(ncclCommUserRank(nccl_comm, &rank));
69+
CHECK_NCCL(ncclCommCount(nccl_comm, &num_ranks));
70+
7371
assert(launch_domain.get_volume() == num_ranks);
7472
assert(launch_domain.get_dim() <= 2);
7573

@@ -127,7 +125,7 @@ struct MpSolveImpl {
127125
npcol,
128126
nb,
129127
nb,
130-
nccl_comm,
128+
comms[0],
131129
context);
132130

133131
auto b_offset_r = b_shape.lo[0];
@@ -144,10 +142,10 @@ struct MpSolveImpl {
144142
npcol,
145143
nb,
146144
nb,
147-
nccl_comm,
145+
comms[0],
148146
context);
149147

150-
MpSolveImplBody<KIND, CODE>{context}(cal_comm,
148+
MpSolveImplBody<KIND, CODE>{context}(nccl_comm,
151149
nprow,
152150
npcol,
153151
n,
@@ -177,7 +175,7 @@ struct MpSolveImpl {
177175
false, // x_shape is enforced col-major
178176
b_offset_r,
179177
b_offset_c,
180-
nccl_comm,
178+
comms[0],
181179
context);
182180
}
183181

0 commit comments

Comments
 (0)