@@ -434,6 +434,60 @@ START_TEST(test_basic_scalar) {
434434}
435435END_TEST
436436
437+ START_TEST (test_basic_scalar_dtype ) {
438+ GpuArray x ;
439+ GpuArray y ;
440+ float a = 1.1f ;
441+
442+ GpuElemwise * ge ;
443+
444+ static const int32_t data1 [4 ] = {0 , 1 , 2 , 3 };
445+ static const float data2 [4 ] = {2.0 , 2.0 , 2.0 , 2.0 };
446+ float data3 [4 ];
447+
448+ size_t dims [2 ] = {2 , 2 };
449+
450+ gpuelemwise_arg args [3 ] = {{0 }};
451+ void * rargs [3 ];
452+
453+ ga_assert_ok (GpuArray_empty (& x , ctx , GA_UINT , 2 , dims , GA_C_ORDER ));
454+ ga_assert_ok (GpuArray_write (& x , data1 , sizeof (data1 )));
455+
456+ ga_assert_ok (GpuArray_empty (& y , ctx , GA_FLOAT , 2 , dims , GA_F_ORDER ));
457+ ga_assert_ok (GpuArray_write (& y , data2 , sizeof (data2 )));
458+
459+ args [0 ].name = "a" ;
460+ args [0 ].typecode = GA_FLOAT ;
461+ args [0 ].flags = GE_SCALAR ;
462+
463+ args [1 ].name = "x" ;
464+ args [1 ].typecode = GA_INT ;
465+ args [1 ].flags = GE_READ ;
466+
467+ args [2 ].name = "y" ;
468+ args [2 ].typecode = GA_FLOAT ;
469+ args [2 ].flags = GE_READ |GE_WRITE ;
470+
471+ ge = GpuElemwise_new (ctx , "" , "y = a * x + y" , 3 , args , 2 , 0 );
472+
473+ ck_assert_ptr_ne (ge , NULL );
474+
475+ rargs [0 ] = & a ;
476+ rargs [1 ] = & x ;
477+ rargs [2 ] = & y ;
478+
479+ ga_assert_ok (GpuElemwise_call (ge , rargs , 0 ));
480+
481+ ga_assert_ok (GpuArray_read (data3 , sizeof (data3 ), & y ));
482+
483+ ck_assert_float_eq (data3 [0 ], 2.0f );
484+ ck_assert_float_eq (data3 [1 ], 4.2f );
485+
486+ ck_assert_float_eq (data3 [2 ], 3.1f );
487+ ck_assert_float_eq (data3 [3 ], 5.3f );
488+ }
489+ END_TEST
490+
437491START_TEST (test_basic_remove1 ) {
438492 GpuArray a ;
439493 GpuArray b ;
@@ -820,6 +874,7 @@ Suite *get_suite(void) {
820874 tcase_add_test (tc , test_basic_simple );
821875 tcase_add_test (tc , test_basic_f16 );
822876 tcase_add_test (tc , test_basic_scalar );
877+ tcase_add_test (tc , test_basic_scalar_dtype );
823878 tcase_add_test (tc , test_basic_offset );
824879 tcase_add_test (tc , test_basic_remove1 );
825880 tcase_add_test (tc , test_basic_broadcast );
0 commit comments