Skip to content

Commit cf702b5

Browse files
committed
Make all-dims-reduced usecase work.
All-dims-reduced will be slow but does work now without errors. Added testcase to ensure this remains the case.
1 parent 06602be commit cf702b5

2 files changed

Lines changed: 107 additions & 15 deletions

File tree

src/gpuarray_reduction.c

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -364,24 +364,26 @@ static void maxandargmaxAppendIndexDeclarations(maxandargmax_ctx* ctx){
364364
strb_appends(&ctx->s, "\tX bd0 = LDIM_0, bd1 = LDIM_1, bd2 = LDIM_2;\n");
365365
strb_appends(&ctx->s, "\tX ti0 = LID_0, ti1 = LID_1, ti2 = LID_2;\n");
366366
strb_appends(&ctx->s, "\tX gi0 = bi0*bd0+ti0, gi1 = bi1*bd1+ti1, gi2 = bi2*bd2+ti2;\n");
367-
strb_appends(&ctx->s, "\tX ");
368-
for(i=0;i<ctx->ndh;i++){
369-
strb_appendf(&ctx->s, "ci%u = chunkSize[%u]%s",
370-
i, i, (i==ctx->ndh-1) ? ";\n" : ", ");
367+
if(ctx->ndh>0){
368+
strb_appends(&ctx->s, "\tX ");
369+
for(i=0;i<ctx->ndh;i++){
370+
strb_appendf(&ctx->s, "ci%u = chunkSize[%u]%s",
371+
i, i, (i==ctx->ndh-1) ? ";\n" : ", ");
372+
}
371373
}
372374

373375
strb_appends(&ctx->s, "\t\n");
374376
strb_appends(&ctx->s, "\t\n");
375377
strb_appends(&ctx->s, "\t/* Free indices & Reduction indices */\n");
376378

377-
appendIdxes (&ctx->s, "\tX ", "i", 0, ctx->nds, "", ";\n");
378-
appendIdxes (&ctx->s, "\tX ", "i", 0, ctx->nds, "Dim", ";\n");
379-
appendIdxes (&ctx->s, "\tX ", "i", 0, ctx->nds, "Start", ";\n");
380-
appendIdxes (&ctx->s, "\tX ", "i", 0, ctx->nds, "End", ";\n");
381-
appendIdxes (&ctx->s, "\tX ", "i", 0, ctx->nds, "SStep", ";\n");
382-
appendIdxes (&ctx->s, "\tX ", "i", 0, ctx->ndd, "MStep", ";\n");
383-
appendIdxes (&ctx->s, "\tX ", "i", 0, ctx->ndd, "AStep", ";\n");
384-
appendIdxes (&ctx->s, "\tX ", "i", ctx->ndd, ctx->nds, "PDim", ";\n");
379+
if(ctx->nds > 0){appendIdxes (&ctx->s, "\tX ", "i", 0, ctx->nds, "", ";\n");}
380+
if(ctx->nds > 0){appendIdxes (&ctx->s, "\tX ", "i", 0, ctx->nds, "Dim", ";\n");}
381+
if(ctx->nds > 0){appendIdxes (&ctx->s, "\tX ", "i", 0, ctx->nds, "Start", ";\n");}
382+
if(ctx->nds > 0){appendIdxes (&ctx->s, "\tX ", "i", 0, ctx->nds, "End", ";\n");}
383+
if(ctx->nds > 0){appendIdxes (&ctx->s, "\tX ", "i", 0, ctx->nds, "SStep", ";\n");}
384+
if(ctx->ndd > 0){appendIdxes (&ctx->s, "\tX ", "i", 0, ctx->ndd, "MStep", ";\n");}
385+
if(ctx->ndd > 0){appendIdxes (&ctx->s, "\tX ", "i", 0, ctx->ndd, "AStep", ";\n");}
386+
if(ctx->nds > ctx->ndd){appendIdxes (&ctx->s, "\tX ", "i", ctx->ndd, ctx->nds, "PDim", ";\n");}
385387

386388
strb_appends(&ctx->s, "\t\n");
387389
strb_appends(&ctx->s, "\t\n");
@@ -725,8 +727,10 @@ static int maxandargmaxSchedule (maxandargmax_ctx* ctx){
725727
}
726728
}
727729

728-
dims[bestWarpAxis] = (dims[bestWarpAxis] + warpSize - 1)/warpSize;
729-
gaIFactorize(warpSize, 0, 0, &factBS[bestWarpAxis]);
730+
if(ctx->ndh > 0){
731+
dims[bestWarpAxis] = (dims[bestWarpAxis] + warpSize - 1)/warpSize;
732+
gaIFactorize(warpSize, 0, 0, &factBS[bestWarpAxis]);
733+
}
730734

731735
/**
732736
* Factorization job. We'll steadily increase the slack in case of failure
@@ -806,7 +810,7 @@ static int maxandargmaxInvoke (maxandargmax_ctx* ctx){
806810
ctx->dstMaxStepsGD &&
807811
ctx->dstArgmaxStepsGD){
808812
ctx->ret = GpuKernel_call(&ctx->kernel,
809-
ctx->ndh,
813+
ctx->ndh>0 ? ctx->ndh : 1,
810814
ctx->blockSize,
811815
ctx->gridSize,
812816
0,

tests/check_reduction.c

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,93 @@ START_TEST(test_veryhighrank){
348348
GpuArray_clear(&gaArgmax);
349349
}END_TEST
350350

351+
START_TEST(test_alldimsreduced){
352+
pcgSeed(1);
353+
354+
/**
355+
* We test here a reduction of some random 3D tensor on all dimensions.
356+
*/
357+
358+
size_t i,j,k;
359+
size_t dims[3] = {32,50,79};
360+
size_t prodDims = dims[0]*dims[1]*dims[2];
361+
const unsigned reduxList[] = {0,1,2};
362+
363+
float* pSrc = calloc(1, sizeof(*pSrc) * dims[0]*dims[1]*dims[2]);
364+
float* pMax = calloc(1, sizeof(*pMax) );
365+
size_t* pArgmax = calloc(1, sizeof(*pArgmax) );
366+
367+
ck_assert_ptr_ne(pSrc, NULL);
368+
ck_assert_ptr_ne(pMax, NULL);
369+
ck_assert_ptr_ne(pArgmax, NULL);
370+
371+
372+
/**
373+
* Initialize source data.
374+
*/
375+
376+
for(i=0;i<prodDims;i++){
377+
pSrc[i] = pcgRand01();
378+
}
379+
380+
381+
/**
382+
* Run the kernel.
383+
*/
384+
385+
GpuArray gaSrc;
386+
GpuArray gaMax;
387+
GpuArray gaArgmax;
388+
389+
ga_assert_ok(GpuArray_empty(&gaSrc, ctx, GA_FLOAT, 3, &dims[0], GA_C_ORDER));
390+
ga_assert_ok(GpuArray_empty(&gaMax, ctx, GA_FLOAT, 0, NULL, GA_C_ORDER));
391+
ga_assert_ok(GpuArray_empty(&gaArgmax, ctx, GA_SIZE, 0, NULL, GA_C_ORDER));
392+
393+
ga_assert_ok(GpuArray_write(&gaSrc, pSrc, sizeof(*pSrc)*prodDims));
394+
ga_assert_ok(GpuArray_memset(&gaMax, -1)); /* 0xFFFFFFFF is a qNaN. */
395+
ga_assert_ok(GpuArray_memset(&gaArgmax, -1));
396+
397+
ga_assert_ok(GpuArray_maxandargmax(&gaMax, &gaArgmax, &gaSrc, 3, reduxList));
398+
399+
ga_assert_ok(GpuArray_read(pMax, sizeof(*pMax), &gaMax));
400+
ga_assert_ok(GpuArray_read(pArgmax, sizeof(*pArgmax), &gaArgmax));
401+
402+
403+
/**
404+
* Check that the destination tensors are correct.
405+
*/
406+
407+
size_t gtArgmax = 0;
408+
float gtMax = pSrc[0];
409+
410+
for(i=0;i<dims[0];i++){
411+
for(j=0;j<dims[1];j++){
412+
for(k=0;k<dims[2];k++){
413+
float v = pSrc[(i*dims[1] + j)*dims[2] + k];
414+
415+
if(v > gtMax){
416+
gtMax = v;
417+
gtArgmax = (i*dims[1] + j)*dims[2] + k;
418+
}
419+
}
420+
}
421+
}
422+
423+
ck_assert_msg(gtMax == pMax[0], "Max value mismatch!");
424+
ck_assert_msg(gtArgmax == pArgmax[0], "Argmax value mismatch!");
425+
426+
/**
427+
* Deallocate.
428+
*/
429+
430+
free(pSrc);
431+
free(pMax);
432+
free(pArgmax);
433+
GpuArray_clear(&gaSrc);
434+
GpuArray_clear(&gaMax);
435+
GpuArray_clear(&gaArgmax);
436+
}END_TEST
437+
351438
Suite *get_suite(void) {
352439
Suite *s = suite_create("reduction");
353440
TCase *tc = tcase_create("basic");
@@ -357,6 +444,7 @@ Suite *get_suite(void) {
357444
tcase_add_test(tc, test_reduction);
358445
tcase_add_test(tc, test_idxtranspose);
359446
tcase_add_test(tc, test_veryhighrank);
447+
tcase_add_test(tc, test_alldimsreduced);
360448

361449
suite_add_tcase(s, tc);
362450
return s;

0 commit comments

Comments
 (0)