Skip to content

Commit 3d1c382

Browse files
authored
Merge pull request #530 from abergeron/fix_fix
Fix potential race conditions
2 parents a18251f + 957301f commit 3d1c382

2 files changed

Lines changed: 24 additions & 4 deletions

File tree

src/gpuarray_blas_cuda_cublas.c

Lines changed: 22 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -835,14 +835,23 @@ static int sgemmBatch(cb_order order, cb_transpose transA, cb_transpose transB,
835835
return ctx->err->code;
836836
}
837837

838-
GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(Ta, CUDA_WAIT_READ));
838+
if (cuda_wait(Ta, CUDA_WAIT_READ) != GA_NO_ERROR) {
839+
gpudata_release(Ta);
840+
cuda_exit(ctx);
841+
return ctx->err->code;
842+
}
839843

840844
err = cublasSgemmBatched(h->h,
841845
convT(transA), convT(transB),
842846
M, N, K, &alpha,
843847
(const float **)Aa, lda,
844848
(const float **)Ba, ldb, &beta,
845849
(float **)Ca, ldc, batchCount);
850+
if (cuda_record(Ta, CUDA_WAIT_READ) != GA_NO_ERROR) {
851+
gpudata_release(Ta);
852+
cuda_exit(ctx);
853+
return ctx->err->code;
854+
}
846855
gpudata_release(Ta);
847856
if (err != CUBLAS_STATUS_SUCCESS) {
848857
cuda_exit(ctx);
@@ -964,15 +973,26 @@ static int dgemmBatch(cb_order order, cb_transpose transA, cb_transpose transB,
964973
return ctx->err->code;
965974
}
966975

967-
GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(Ta, CUDA_WAIT_READ));
976+
if (cuda_wait(Ta, CUDA_WAIT_READ) != GA_NO_ERROR) {
977+
gpudata_release(Ta);
978+
cuda_exit(ctx);
979+
return ctx->err->code;
980+
}
968981

969982
err = cublasDgemmBatched(h->h,
970983
convT(transA), convT(transB),
971984
M, N, K, &alpha,
972985
(const double **)Aa, lda,
973986
(const double **)Ba, ldb, &beta,
974987
(double **)Ca, ldc, batchCount);
988+
989+
if (cuda_record(Ta, CUDA_WAIT_READ) != GA_NO_ERROR) {
990+
gpudata_release(Ta);
991+
cuda_exit(ctx);
992+
return ctx->err->code;
993+
}
975994
gpudata_release(Ta);
995+
976996
if (err != CUBLAS_STATUS_SUCCESS) {
977997
cuda_exit(ctx);
978998
return error_cublas(ctx->err, "cublasDgemmBatched", err);

src/gpuarray_buffer_cuda.c

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -851,8 +851,8 @@ static void cuda_free(gpudata *d) {
851851
d->ptr + d->sz == next->ptr) {
852852
d->sz = d->sz + next->sz;
853853
d->next = next->next;
854-
cuda_wait(next, CUDA_WAIT_ALL);
855-
cuda_record(d, CUDA_WAIT_ALL);
854+
cuda_waits(next, CUDA_WAIT_ALL, d->ls);
855+
cuda_records(d, CUDA_WAIT_ALL, d->ls);
856856
deallocate(next);
857857
} else {
858858
d->next = next;

0 commit comments

Comments
 (0)