Skip to content

Commit a2d0dd9

Browse files
committed
Add test for the scalar dtype error.
1 parent 1c1e068 commit a2d0dd9

1 file changed

Lines changed: 55 additions & 0 deletions

File tree

tests/check_elemwise.c

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,60 @@ START_TEST(test_basic_scalar) {
434434
}
435435
END_TEST
436436

437+
START_TEST(test_basic_scalar_dtype) {
438+
GpuArray x;
439+
GpuArray y;
440+
float a = 1.1f;
441+
442+
GpuElemwise *ge;
443+
444+
static const int32_t data1[4] = {0, 1, 2, 3};
445+
static const float data2[4] = {2.0, 2.0, 2.0, 2.0};
446+
float data3[4];
447+
448+
size_t dims[2] = {2, 2};
449+
450+
gpuelemwise_arg args[3] = {{0}};
451+
void *rargs[3];
452+
453+
ga_assert_ok(GpuArray_empty(&x, ctx, GA_UINT, 2, dims, GA_C_ORDER));
454+
ga_assert_ok(GpuArray_write(&x, data1, sizeof(data1)));
455+
456+
ga_assert_ok(GpuArray_empty(&y, ctx, GA_FLOAT, 2, dims, GA_F_ORDER));
457+
ga_assert_ok(GpuArray_write(&y, data2, sizeof(data2)));
458+
459+
args[0].name = "a";
460+
args[0].typecode = GA_FLOAT;
461+
args[0].flags = GE_SCALAR;
462+
463+
args[1].name = "x";
464+
args[1].typecode = GA_INT;
465+
args[1].flags = GE_READ;
466+
467+
args[2].name = "y";
468+
args[2].typecode = GA_FLOAT;
469+
args[2].flags = GE_READ|GE_WRITE;
470+
471+
ge = GpuElemwise_new(ctx, "", "y = a * x + y", 3, args, 2, 0);
472+
473+
ck_assert_ptr_ne(ge, NULL);
474+
475+
rargs[0] = &a;
476+
rargs[1] = &x;
477+
rargs[2] = &y;
478+
479+
ga_assert_ok(GpuElemwise_call(ge, rargs, 0));
480+
481+
ga_assert_ok(GpuArray_read(data3, sizeof(data3), &y));
482+
483+
ck_assert_float_eq(data3[0], 2.0f);
484+
ck_assert_float_eq(data3[1], 4.2f);
485+
486+
ck_assert_float_eq(data3[2], 3.1f);
487+
ck_assert_float_eq(data3[3], 5.3f);
488+
}
489+
END_TEST
490+
437491
START_TEST(test_basic_remove1) {
438492
GpuArray a;
439493
GpuArray b;
@@ -820,6 +874,7 @@ Suite *get_suite(void) {
820874
tcase_add_test(tc, test_basic_simple);
821875
tcase_add_test(tc, test_basic_f16);
822876
tcase_add_test(tc, test_basic_scalar);
877+
tcase_add_test(tc, test_basic_scalar_dtype);
823878
tcase_add_test(tc, test_basic_offset);
824879
tcase_add_test(tc, test_basic_remove1);
825880
tcase_add_test(tc, test_basic_broadcast);

0 commit comments

Comments
 (0)