@@ -60,8 +60,6 @@ using namespace nvcuda;
6060#define MATRIX_N 16384
6161#define MATRIX_K 16384
6262
63-
64-
6563// The only dimensions currently supported by WMMA
6664const int WMMA_M = 16 ;
6765const int WMMA_N = 16 ;
@@ -105,7 +103,7 @@ __global__ void wmma_example(half *a, half *b, float *c, int M, int N, int K, fl
105103 // Load the inputs
106104 wmma::load_matrix_sync (a_frag, a + aRow + aCol * lda, lda);
107105 wmma::load_matrix_sync (b_frag, b + bRow + bCol * ldb, ldb);
108-
106+
109107 // Perform the matrix multiplication
110108 wmma::mma_sync (acc_frag, a_frag, b_frag, acc_frag);
111109
@@ -119,7 +117,7 @@ __global__ void wmma_example(half *a, half *b, float *c, int M, int N, int K, fl
119117 if (cRow < M && cCol < N) {
120118 wmma::load_matrix_sync (c_frag, c + cRow + cCol * ldc, ldc, wmma::mem_col_major);
121119
122-
120+ # pragma unroll
123121 for (int i=0 ; i < c_frag.num_elements ; i++) {
124122 c_frag.x [i] = alpha * acc_frag.x [i] + beta * c_frag.x [i];
125123 }
@@ -221,21 +219,35 @@ int main(int argc, char* argv[]) {
221219 cudaErrCheck (cudaEventRecord (startWMMA));
222220 wmma_example <<< gridDim , blockDim >>> (a_fp16, b_fp16, c_wmma, MATRIX_M, MATRIX_N, MATRIX_K, alpha, beta);
223221 cudaErrCheck (cudaEventRecord (stopWMMA));
222+ cudaErrCheck (cudaEventSynchronize (stopWMMA));
224223
225-
226-
227224 // Now using cuBLAS
228225 printf (" Running with cuBLAS...\n " );
229- cudaErrCheck ( cudaEventRecord (startcublas));
226+ // Warm up cuBLAS run starts
230227 cublasErrCheck (cublasGemmEx (cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N,
231228 MATRIX_M, MATRIX_N, MATRIX_K,
232229 &alpha,
233230 a_fp16, CUDA_R_16F, MATRIX_M,
234231 b_fp16, CUDA_R_16F, MATRIX_K,
235232 &beta,
236233 c_cublas, CUDA_R_32F, MATRIX_M,
237- CUDA_R_32F, CUBLAS_GEMM_DFALT_TENSOR_OP));
234+ CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
235+ // Warm up cuBLAS run ends
236+
237+ // reset the c_cublas buffer
238+ cudaErrCheck (cudaMemcpy (c_cublas, c, MATRIX_M * MATRIX_N * sizeof (float ), cudaMemcpyDeviceToDevice));
239+
240+ cudaErrCheck (cudaEventRecord (startcublas));
241+ cublasErrCheck (cublasGemmEx (cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N,
242+ MATRIX_M, MATRIX_N, MATRIX_K,
243+ &alpha,
244+ a_fp16, CUDA_R_16F, MATRIX_M,
245+ b_fp16, CUDA_R_16F, MATRIX_K,
246+ &beta,
247+ c_cublas, CUDA_R_32F, MATRIX_M,
248+ CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
238249 cudaErrCheck (cudaEventRecord (stopcublas));
250+ cudaErrCheck (cudaEventSynchronize (stopcublas));
239251
240252 // Error checking
241253 printf (" \n Checking results...\n " );
@@ -247,7 +259,10 @@ int main(int argc, char* argv[]) {
247259 for (int i = 0 ; i < MATRIX_M * MATRIX_N; i++) {
248260 float v1 = c_host_wmma[i];
249261 float v2 = c_host_cublas[i];
250- if (v1 / v2 > 1.0001 || v2 / v1 > 1.0001 || abs (v1 - v2) > 1e-5 ) {
262+ float diff = fabs (v1 - v2);
263+ float relative_err = diff / v2;
264+ float eps = 1e-4 ;
265+ if ((relative_err >= eps)) {
251266 errors++;
252267 if (errors < 10 ) printf (" %f %f\n " , v1, v2);
253268 }
@@ -260,8 +275,6 @@ int main(int argc, char* argv[]) {
260275 printf (" Results verified: cublas and WMMA agree.\n\n " );
261276 float wmmaTime;
262277 float cublasTime;
263- cudaErrCheck (cudaEventSynchronize (stopWMMA));
264- cudaErrCheck (cudaEventSynchronize (stopcublas));
265278 cudaErrCheck (cudaEventElapsedTime (&wmmaTime, startWMMA, stopWMMA));
266279 cudaErrCheck (cudaEventElapsedTime (&cublasTime, startcublas, stopcublas));
267280 printf (" wmma took %fms\n " , wmmaTime);
0 commit comments