@@ -14,18 +14,100 @@ void teardown(void);
1414
1515#define ga_assert_ok (e ) ck_assert_int_eq(e, GA_NO_ERROR)
1616
17- START_TEST (test_gemmBatch_3d ) {
17+ static inline void ck_assert_fbuf_eq (const float * b , const float * r ,
18+ unsigned int n ) {
19+ unsigned int i ;
20+ for (i = 0 ; i < n ; i ++ ) {
21+ ck_assert_msg (b [i ] == r [i ], "Difference at %u: %f != %f(ref)" , i , b [i ], r [i ]);
22+ }
23+ }
24+
25+ START_TEST (test_gemmBatch_3d_C ) {
1826 GpuArray A ;
1927 GpuArray B ;
2028 GpuArray C ;
2129
22- size_t dims [3 ] = {32 , 32 , 32 };
30+ size_t dims [3 ] = {2 , 3 , 3 };
31+ float data [] = {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ,
32+ 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 };
33+ const float res [] = {30 , 36 , 42 , 66 , 81 , 96 , 102 , 126 , 150 ,
34+ 30 , 36 , 42 , 66 , 81 , 96 , 102 , 126 , 150 };
2335
2436 ga_assert_ok (GpuArray_empty (& A , ctx , GA_FLOAT , 3 , dims , GA_C_ORDER ));
2537 ga_assert_ok (GpuArray_empty (& B , ctx , GA_FLOAT , 3 , dims , GA_C_ORDER ));
2638 ga_assert_ok (GpuArray_empty (& C , ctx , GA_FLOAT , 3 , dims , GA_C_ORDER ));
2739
40+ ga_assert_ok (GpuArray_write (& A , data , sizeof (data )));
41+ ga_assert_ok (GpuArray_write (& B , data , sizeof (data )));
42+
2843 ga_assert_ok (GpuArray_rgemmBatch_3d (cb_no_trans , cb_no_trans , 1 , & A , & B , 0 , & C , 1 ));
44+
45+ ga_assert_ok (GpuArray_read (data , sizeof (data ), & C ));
46+
47+ ck_assert_fbuf_eq (data , res , sizeof (res )/sizeof (float ));
48+ }
49+ END_TEST
50+
51+ START_TEST (test_gemmBatch_3d_F ) {
52+ GpuArray A ;
53+ GpuArray B ;
54+ GpuArray C ;
55+
56+ size_t dims [3 ] = {2 , 3 , 3 };
57+ float data [] = {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ,
58+ 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 };
59+ const float res [] = {42 , 78 , 78 , 60 , 114 , 114 , 51 , 69 , 96 ,
60+ 66 , 39 , 111 , 54 , 54 , 90 , 78 , 78 , 132 };
61+
62+ ga_assert_ok (GpuArray_empty (& A , ctx , GA_FLOAT , 3 , dims , GA_F_ORDER ));
63+ ga_assert_ok (GpuArray_empty (& B , ctx , GA_FLOAT , 3 , dims , GA_F_ORDER ));
64+ ga_assert_ok (GpuArray_empty (& C , ctx , GA_FLOAT , 3 , dims , GA_C_ORDER ));
65+
66+ ga_assert_ok (GpuArray_write (& A , data , sizeof (data )));
67+ ga_assert_ok (GpuArray_write (& B , data , sizeof (data )));
68+
69+ ga_assert_ok (GpuArray_rgemmBatch_3d (cb_no_trans , cb_no_trans , 1 , & A , & B , 0 , & C , 0 ));
70+
71+ ga_assert_ok (GpuArray_read (data , sizeof (data ), & C ));
72+
73+ ck_assert_fbuf_eq (data , res , sizeof (res )/sizeof (float ));
74+ }
75+ END_TEST
76+
77+ START_TEST (test_gemmBatch_3d_S ) {
78+ GpuArray A ;
79+ GpuArray B ;
80+ GpuArray C ;
81+ ssize_t t ;
82+
83+ size_t dims [3 ] = {2 , 3 , 3 };
84+ float data [] = {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ,
85+ 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 };
86+ const float res [] = {14 , 32 , 50 , 50 , 122 , 194 , 32 , 77 , 122 ,
87+ 26 , 62 , 98 , 17 , 53 , 89 , 44 , 107 , 170 };
88+
89+ ga_assert_ok (GpuArray_empty (& A , ctx , GA_FLOAT , 3 , dims , GA_F_ORDER ));
90+ ga_assert_ok (GpuArray_empty (& B , ctx , GA_FLOAT , 3 , dims , GA_C_ORDER ));
91+ ga_assert_ok (GpuArray_empty (& C , ctx , GA_FLOAT , 3 , dims , GA_C_ORDER ));
92+
93+ ga_assert_ok (GpuArray_write (& A , data , sizeof (data )));
94+ ga_assert_ok (GpuArray_write (& B , data , sizeof (data )));
95+
96+ A .strides [0 ] = 8 ;
97+ A .strides [1 ] = 24 ;
98+ A .strides [2 ] = 4 ;
99+ GpuArray_fix_flags (& A );
100+
101+ t = B .strides [1 ];
102+ B .strides [1 ] = B .strides [2 ];
103+ B .strides [2 ] = t ;
104+ GpuArray_fix_flags (& B );
105+
106+ ga_assert_ok (GpuArray_rgemmBatch_3d (cb_no_trans , cb_no_trans , 1 , & A , & B , 0 , & C , 1 ));
107+
108+ ga_assert_ok (GpuArray_read (data , sizeof (data ), & C ));
109+
110+ ck_assert_fbuf_eq (data , res , sizeof (res )/sizeof (float ));
29111}
30112END_TEST
31113
@@ -34,7 +116,9 @@ Suite *get_suite(void) {
34116 TCase * tc = tcase_create ("all" );
35117 tcase_add_checked_fixture (tc , setup , teardown );
36118 tcase_set_timeout (tc , 16.0 );
37- tcase_add_test (tc , test_gemmBatch_3d );
119+ tcase_add_test (tc , test_gemmBatch_3d_C );
120+ tcase_add_test (tc , test_gemmBatch_3d_F );
121+ tcase_add_test (tc , test_gemmBatch_3d_S );
38122 suite_add_tcase (s , tc );
39123 return s ;
40124}
0 commit comments