@@ -142,6 +142,23 @@ static void compute_reference(
142142 }
143143}
144144
// Computes the host reference for a "TN" matrix multiply: C = transpose(A) * B.
//
// Layout (all row-major, flat vectors):
//   A: K rows of M elements (the left operand stored transposed), K * M entries.
//   B: K rows of N elements, K * N entries.
//   C: M rows of N elements; must already hold at least M * N entries.
//
// Each product is formed in the promoted type of SrcT and accumulated into a
// DstT running sum that starts at zero.
template <typename DstT, typename SrcT>
static void compute_reference_TN(
    std::vector<DstT>& C,
    const std::vector<SrcT>& A, const std::vector<SrcT>& B,
    size_t M, size_t N, size_t K)
{
    for (size_t m = 0; m < M; m++) {
        for (size_t n = 0; n < N; n++) {
            DstT sum = 0;
            for (size_t k = 0; k < K; k++) {
                // Row k of the transposed A holds M elements, so its row
                // pitch is M. Using K here is only correct when M == K and
                // reads past the end of A when K > M.
                sum = A[k * M + m] * B[k * N + n] + sum;
            }
            C[m * N + n] = sum;
        }
    }
}
161+
145162template <typename T>
146163void check_results (
147164 size_t M,
@@ -660,6 +677,107 @@ static void i8_dpas_blockread_vnni_tiled(
660677 }
661678}
662679
// Runs and times the "i8_naive_TN" kernel and optionally validates its output.
//
// Prints the test name, then either "unsupported." when the kernel object
// could not be created, or the best time over testIterations launches together
// with the derived gops figure. When validate is set, C is read back and
// compared against C_ref through check_results().
//
//   context  - not referenced in this function body.
//   program  - program the kernel object is created from.
//   queue    - command queue used for the fill, the launches and the read-back.
//   C, A, B  - buffers bound to kernel arguments 0, 1 and 2.
//   M, N, K  - problem sizes; only K is passed to the kernel (argument 3, as cl_int).
//   C_ref    - expected contents of C; its size also sets the fill/read-back size.
//
// NOTE(review): presumably the kernel derives M and N from the global range,
// since neither is passed as an argument - confirm against the kernel source.
680+ static void i8_naive_TN (
681+ cl::Context& context, cl::Program& program, cl::CommandQueue& queue,
682+ cl::Buffer& C, cl::Buffer& A, cl::Buffer& B,
683+ size_t M, size_t N, size_t K,
684+ const std::vector<int >& C_ref)
685+ {
686+ printf (" %80s: " , makeTestName (__FUNCTION__, M, N, K).c_str ()); fflush (stdout);
687+
688+ cl::Kernel kernel{program, " i8_naive_TN" };
// A null kernel handle is reported as "unsupported" instead of being treated as an error.
689+ if (kernel () == nullptr ) {
690+ printf (" unsupported.\n " );
691+ } else {
692+ kernel.setArg (0 , C);
693+ kernel.setArg (1 , A);
694+ kernel.setArg (2 , B);
695+ kernel.setArg (3 , static_cast <cl_int>(K));
696+
// Zero-fill the whole of C before the timed runs unless skipinit is set.
697+ if (!skipinit) {
698+ queue.enqueueFillBuffer (C, 0 , 0 , C_ref.size () * sizeof (C_ref[0 ]));
699+ }
700+
// Keep the fastest of testIterations launches; 999.0f is only the starting upper bound.
701+ float best = 999 .0f ;
702+ for (int test = 0 ; test < testIterations; test++) {
703+ cl::Event event;
704+ auto start = test_clock::now ();
// Global range is N x M work-items; the local size is left unspecified (cl::NullRange).
705+ queue.enqueueNDRangeKernel (kernel, cl::NullRange,
706+ cl::NDRange{N, M}, cl::NullRange, nullptr , &event);
707+ queue.finish ();
708+ auto end = test_clock::now ();
709+ std::chrono::duration<float > sw_time = end - start;
// wallclock selects the host-measured duration; otherwise the value from hw_time(event) is used.
710+ auto elapsed = wallclock ? sw_time.count () : hw_time (event);
711+ best = std::min (best, elapsed);
712+ }
// Throughput is reported as 2 * M * N * K operations (presumably one multiply and one add per term).
713+ auto gops = 2.0 * M * N * K / best / 1e9 ;
714+ printf (" Best in %f seconds (%f gops)\n " , best, gops);
715+
716+ if (validate) {
717+ printf (" Checking results... " ); fflush (stdout);
718+ std::vector<int > C_check (C_ref.size ());
// Blocking read (CL_TRUE), so C_check is fully populated before the comparison.
719+ queue.enqueueReadBuffer (C, CL_TRUE, 0 , C_check.size () * sizeof (C_check[0 ]), C_check.data ());
720+ check_results (M, N, C_check, C_ref);
721+ printf (" done!\n " );
722+ }
723+ }
724+ }
725+
// Runs and times one tM x tN variant of the "i8_dpas_blockread_rowmajor_TN"
// kernel and optionally validates its output.
//
// The kernel name is the base name with "_m<tM>" and "_n<tN>" appended. The
// run is skipped with a message when the kernel object could not be created,
// or when K < 64 or N < 64/4. Otherwise the best time over testIterations
// launches and the derived gops figure are printed, and, when validate is
// set, C is read back and compared against C_ref through check_results().
//
//   tM, tN   - template parameters; used in the test name, the kernel name
//              and (tM only) the global range.
//   context  - not referenced in this function body.
//   program  - program the kernel object is created from.
//   queue    - command queue used for the fill, the launches and the read-back.
//   C, A, B  - buffers bound to kernel arguments 0, 1 and 2.
//   M, N, K  - problem sizes; only K is passed to the kernel (argument 3, as cl_int).
//   C_ref    - expected contents of C; its size also sets the fill/read-back size.
726+ template <int tM, int tN>
727+ static void i8_dpas_blockread_rowmajor_TN (
728+ cl::Context& context, cl::Program& program, cl::CommandQueue& queue,
729+ cl::Buffer& C, cl::Buffer& A, cl::Buffer& B,
730+ size_t M, size_t N, size_t K,
731+ const std::vector<int >& C_ref)
732+ {
733+ printf (" %80s: " , makeTestName (__FUNCTION__, tM, tN, M, N, K).c_str ()); fflush (stdout);
734+
735+ std::string kernelName = " i8_dpas_blockread_rowmajor_TN" ;
736+ kernelName += " _m" + std::to_string (tM);
737+ kernelName += " _n" + std::to_string (tN);
738+ cl::Kernel kernel{program, kernelName.c_str ()};
// A null kernel handle is reported as "unsupported" instead of being treated as an error.
739+ if (kernel () == nullptr ) {
740+ printf (" unsupported.\n " );
// NOTE(review): the message speaks of a 64-byte pitch, but which matrix each
// bound belongs to is not visible here; for a transposed A the row pitch is
// presumably M rather than K - confirm this guard against the kernel.
741+ } else if (K < 64 || N < 64 /4 ) {
742+ printf (" matrix pitch for block reads must be >= 64 bytes.\n " );
743+ } else {
744+ kernel.setArg (0 , C);
745+ kernel.setArg (1 , A);
746+ kernel.setArg (2 , B);
747+ kernel.setArg (3 , static_cast <cl_int>(K));
748+ if (roundRobin) {
749+ setRoundRobin (kernel);
750+ }
751+
// Zero-fill the whole of C before the timed runs unless skipinit is set.
752+ if (!skipinit) {
753+ queue.enqueueFillBuffer (C, 0 , 0 , C_ref.size () * sizeof (C_ref[0 ]));
754+ }
755+
// Keep the fastest of testIterations launches; 999.0f is only the starting upper bound.
756+ float best = 999 .0f ;
757+ for (int test = 0 ; test < testIterations; test++) {
758+ cl::Event event;
759+ auto start = test_clock::now ();
// Global range is N x (M / tM); the division truncates, so
// NOTE(review): M is presumably expected to be a multiple of tM - confirm.
760+ queue.enqueueNDRangeKernel (kernel, cl::NullRange,
761+ cl::NDRange{N, M/tM}, cl::NullRange, nullptr , &event);
762+ queue.finish ();
763+ auto end = test_clock::now ();
764+ std::chrono::duration<float > sw_time = end - start;
// wallclock selects the host-measured duration; otherwise the value from hw_time(event) is used.
765+ auto elapsed = wallclock ? sw_time.count () : hw_time (event);
766+ best = std::min (best, elapsed);
767+ }
// Throughput is reported as 2 * M * N * K operations (presumably one multiply and one add per term).
768+ auto gops = 2.0 * M * N * K / best / 1e9 ;
769+ printf (" Best in %f seconds (%f gops)\n " , best, gops);
770+
771+ if (validate) {
772+ printf (" Checking results... " ); fflush (stdout);
773+ std::vector<int > C_check (C_ref.size ());
// Blocking read (CL_TRUE), so C_check is fully populated before the comparison.
774+ queue.enqueueReadBuffer (C, CL_TRUE, 0 , C_check.size () * sizeof (C_check[0 ]), C_check.data ());
775+ check_results (M, N, C_check, C_ref);
776+ printf (" done!\n " );
777+ }
778+ }
779+ }
780+
663781int main (int argc, char ** argv)
664782{
665783 int platformIndex = 0 ;
@@ -784,6 +902,7 @@ int main(int argc, char** argv)
784902 std::vector<int8_t > Bvnni_vec (K * N);
785903
786904 std::vector<int > C_ref (M * N);
905+ std::vector<int > C_TN_ref (M * N);
787906
788907 printf (" Initializing source matrices...\n " );
789908 fill_matrix (A_vec, M, K);
@@ -794,6 +913,8 @@ int main(int argc, char** argv)
794913 if (validate) {
795914 printf (" Computing reference...\n " );
796915 compute_reference (C_ref, A_vec, B_vec, M, N, K);
916+ printf (" Computing transposed reference...\n " );
917+ compute_reference_TN (C_TN_ref, A_vec, B_vec, M, N, K);
797918 }
798919
799920 printf (" Creating source buffers...\n " );
@@ -910,6 +1031,11 @@ int main(int argc, char** argv)
9101031 i8_dpas_blockread_vnni_tiled<8 , 16 , 4 , 4 >(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
9111032 }
9121033
1034+ if (mask & 0x2000 ) {
1035+ // i8_naive_TN(context, program, queue, C, A, B, M, N, K, C_TN_ref);
1036+ i8_dpas_blockread_rowmajor_TN<4 , 16 >(context, program, queue, C, A, B, M, N, K, C_TN_ref);
1037+ }
1038+
9131039 printf (" Done.\n " );
9141040
9151041 return 0 ;
0 commit comments