@@ -24,8 +24,8 @@ namespace cupynumeric {
2424using namespace Legion ;
2525using namespace legate ;
2626
27- template <typename VAL>
28- static inline void mp_potrf_template (cal_comm_t comm,
27+ template <typename VAL, typename comm_t >
28+ static inline void mp_potrf_template (comm_t comm,
2929 int nprow,
3030 int npcol,
3131 int64_t n,
@@ -92,7 +92,7 @@ static inline void mp_potrf_template(cal_comm_t comm,
9292 info.ptr (0 )));
9393
9494 // TODO: We need a deferred exception to avoid this synchronization
95- CHECK_CAL ( cal_stream_sync (comm, stream));
95+ CUPYNUMERIC_CHECK_CUDA ( cudaStreamSynchronize ( stream));
9696 CUPYNUMERIC_CHECK_CUDA_STREAM (stream);
9797
9898 CHECK_CUSOLVER (cusolverMpDestroyMatrixDesc (desc));
@@ -108,8 +108,9 @@ struct MpPotrfImplBody<VariantKind::GPU, Type::Code::FLOAT32> {
108108 TaskContext context;
109109 explicit MpPotrfImplBody (TaskContext context) : context(context) {}
110110
111+ template <typename comm_t >
111112 void operator ()(
112- cal_comm_t comm, int nprow, int npcol, int64_t n, int64_t nb, float * array, int64_t lld)
113+ comm_t comm, int nprow, int npcol, int64_t n, int64_t nb, float * array, int64_t lld)
113114 {
114115 auto stream = context.get_task_stream ();
115116 mp_potrf_template (comm, nprow, npcol, n, nb, array, lld, stream);
@@ -121,8 +122,9 @@ struct MpPotrfImplBody<VariantKind::GPU, Type::Code::FLOAT64> {
121122 TaskContext context;
122123 explicit MpPotrfImplBody (TaskContext context) : context(context) {}
123124
125+ template <typename comm_t >
124126 void operator ()(
125- cal_comm_t comm, int nprow, int npcol, int64_t n, int64_t nb, double * array, int64_t lld)
127+ comm_t comm, int nprow, int npcol, int64_t n, int64_t nb, double * array, int64_t lld)
126128 {
127129 auto stream = context.get_task_stream ();
128130 mp_potrf_template (comm, nprow, npcol, n, nb, array, lld, stream);
@@ -134,13 +136,9 @@ struct MpPotrfImplBody<VariantKind::GPU, Type::Code::COMPLEX64> {
134136 TaskContext context;
135137 explicit MpPotrfImplBody (TaskContext context) : context(context) {}
136138
137- void operator ()(cal_comm_t comm,
138- int nprow,
139- int npcol,
140- int64_t n,
141- int64_t nb,
142- complex <float >* array,
143- int64_t lld)
139+ template <typename comm_t >
140+ void operator ()(
141+ comm_t comm, int nprow, int npcol, int64_t n, int64_t nb, complex <float >* array, int64_t lld)
144142 {
145143 auto stream = context.get_task_stream ();
146144 mp_potrf_template (comm, nprow, npcol, n, nb, reinterpret_cast <cuComplex*>(array), lld, stream);
@@ -152,13 +150,9 @@ struct MpPotrfImplBody<VariantKind::GPU, Type::Code::COMPLEX128> {
152150 TaskContext context;
153151 explicit MpPotrfImplBody (TaskContext context) : context(context) {}
154152
155- void operator ()(cal_comm_t comm,
156- int nprow,
157- int npcol,
158- int64_t n,
159- int64_t nb,
160- complex <double >* array,
161- int64_t lld)
153+ template <typename comm_t >
154+ void operator ()(
155+ comm_t comm, int nprow, int npcol, int64_t n, int64_t nb, complex <double >* array, int64_t lld)
162156 {
163157 auto stream = context.get_task_stream ();
164158 mp_potrf_template (
0 commit comments