Skip to content

Commit 708ce91

Browse files
authored
Fix the output error check to use relative error; fix cuBLAS time reporting, which included one-time startup cost, by adding an untimed warm-up run; add a Makefile option to specify the SM architecture at build time (#51)
1 parent 2d8ba9f commit 708ce91

2 files changed

Lines changed: 26 additions & 12 deletions

File tree

posts/tensor-cores/Makefile

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@
2424
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2525
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626

27+
SM_ARCH ?= 70
2728

2829
all: simpleTensorCoreGEMM.cu
29-
nvcc -o TCGemm -arch=sm_70 -lcublas -lcurand simpleTensorCoreGEMM.cu
30+
nvcc -o TCGemm -arch=sm_$(SM_ARCH) -lcublas -lcurand simpleTensorCoreGEMM.cu
3031

3132
clean:
3233
rm -f TCGemm

posts/tensor-cores/simpleTensorCoreGEMM.cu

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,6 @@ using namespace nvcuda;
6060
#define MATRIX_N 16384
6161
#define MATRIX_K 16384
6262

63-
64-
6563
// The only dimensions currently supported by WMMA
6664
const int WMMA_M = 16;
6765
const int WMMA_N = 16;
@@ -105,7 +103,7 @@ __global__ void wmma_example(half *a, half *b, float *c, int M, int N, int K, fl
105103
// Load the inputs
106104
wmma::load_matrix_sync(a_frag, a + aRow + aCol * lda, lda);
107105
wmma::load_matrix_sync(b_frag, b + bRow + bCol * ldb, ldb);
108-
106+
109107
// Perform the matrix multiplication
110108
wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
111109

@@ -119,7 +117,7 @@ __global__ void wmma_example(half *a, half *b, float *c, int M, int N, int K, fl
119117
if (cRow < M && cCol < N) {
120118
wmma::load_matrix_sync(c_frag, c + cRow + cCol * ldc, ldc, wmma::mem_col_major);
121119

122-
120+
#pragma unroll
123121
for(int i=0; i < c_frag.num_elements; i++) {
124122
c_frag.x[i] = alpha * acc_frag.x[i] + beta * c_frag.x[i];
125123
}
@@ -221,21 +219,35 @@ int main(int argc, char* argv[]) {
221219
cudaErrCheck(cudaEventRecord(startWMMA));
222220
wmma_example <<< gridDim, blockDim >>> (a_fp16, b_fp16, c_wmma, MATRIX_M, MATRIX_N, MATRIX_K, alpha, beta);
223221
cudaErrCheck(cudaEventRecord(stopWMMA));
222+
cudaErrCheck(cudaEventSynchronize(stopWMMA));
224223

225-
226-
227224
// Now using cuBLAS
228225
printf("Running with cuBLAS...\n");
229-
cudaErrCheck(cudaEventRecord(startcublas));
226+
// Begin untimed cuBLAS warm-up run, so startup cost is excluded from the timed run below
230227
cublasErrCheck(cublasGemmEx(cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N,
231228
MATRIX_M, MATRIX_N, MATRIX_K,
232229
&alpha,
233230
a_fp16, CUDA_R_16F, MATRIX_M,
234231
b_fp16, CUDA_R_16F, MATRIX_K,
235232
&beta,
236233
c_cublas, CUDA_R_32F, MATRIX_M,
237-
CUDA_R_32F, CUBLAS_GEMM_DFALT_TENSOR_OP));
234+
CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
235+
// End of untimed cuBLAS warm-up run
236+
237+
// Reset c_cublas from c, since the warm-up GEMM above overwrote it
238+
cudaErrCheck(cudaMemcpy(c_cublas, c, MATRIX_M * MATRIX_N * sizeof(float), cudaMemcpyDeviceToDevice));
239+
240+
cudaErrCheck(cudaEventRecord(startcublas));
241+
cublasErrCheck(cublasGemmEx(cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N,
242+
MATRIX_M, MATRIX_N, MATRIX_K,
243+
&alpha,
244+
a_fp16, CUDA_R_16F, MATRIX_M,
245+
b_fp16, CUDA_R_16F, MATRIX_K,
246+
&beta,
247+
c_cublas, CUDA_R_32F, MATRIX_M,
248+
CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
238249
cudaErrCheck(cudaEventRecord(stopcublas));
250+
cudaErrCheck(cudaEventSynchronize(stopcublas));
239251

240252
// Error checking
241253
printf("\nChecking results...\n");
@@ -247,7 +259,10 @@ int main(int argc, char* argv[]) {
247259
for (int i = 0; i < MATRIX_M * MATRIX_N; i++) {
248260
float v1 = c_host_wmma[i];
249261
float v2 = c_host_cublas[i];
250-
if (v1 / v2 > 1.0001 || v2 / v1 > 1.0001 || abs(v1 - v2) > 1e-5) {
262+
float diff = fabs(v1 - v2);
263+
float relative_err = diff / v2;
264+
float eps = 1e-4;
265+
if ((relative_err >= eps)) {
251266
errors++;
252267
if (errors < 10) printf("%f %f\n", v1, v2);
253268
}
@@ -260,8 +275,6 @@ int main(int argc, char* argv[]) {
260275
printf("Results verified: cublas and WMMA agree.\n\n");
261276
float wmmaTime;
262277
float cublasTime;
263-
cudaErrCheck(cudaEventSynchronize(stopWMMA));
264-
cudaErrCheck(cudaEventSynchronize(stopcublas));
265278
cudaErrCheck(cudaEventElapsedTime(&wmmaTime, startWMMA, stopWMMA));
266279
cudaErrCheck(cudaEventElapsedTime(&cublasTime, startcublas, stopcublas));
267280
printf("wmma took %fms\n", wmmaTime);

0 commit comments

Comments
 (0)