1- // Copyright 2024 EPFL
2- // Solderpad Hardware License, Version 2.1, see LICENSE.md for details.
3- // SPDX-License-Identifier: Apache-2.0 WITH SHL-2.1
4- //
5- // Author: Danilo Cammarata
6-
7- // Output tile size: 8x8
8- #define FP32 6
9- #define CMUL 2
10- #define RMUL 2
11-
12- void __attribute__ ((noinline )) matrixMul_8x8 (float * addrA , float * addrB , float * addrC , int K , int N , int M , int shift )
13- {
14- int strideA = M * 4 ;
15- int strideB = N * 4 ;
16- int strideC = N * 4 ;
171
2+ void __attribute__ ((noinline )) FUNC_NAME (void * addrA ,void * addrB , void * addrC , int K , int N , int M , int shift )
3+ {
4+ int strideA ;
5+ int strideB ;
6+ int strideC ;
187 asm volatile (
198 // --- PROLOGUE: Save registers ---
209 "addi sp, sp, -0x30 \n\t"
@@ -32,45 +21,45 @@ void __attribute__ ((noinline)) matrixMul_8x8(float* addrA, float* addrB, float*
3221 "sw s11, 0x00(sp) \n\t"
3322
3423 // 1. Data Types Configuration
35- "mmac.dt %[dt_type ], %[dt_type ], %[dt_type ] \n\t"
24+ "mmac.dt %[dt_typeC ], %[dt_typeA ], %[dt_typeB ] \n\t"
3625
3726 "add t0, x0, %[M] \n\t" // t0 = Remaining M
3827 "add s0, x0, %[addrC] \n\t" // s0 = Current row pointer for C
3928 "add s1, x0, %[addrA] \n\t" // s1 = Current row pointer for A (LHS)
4029
41- "loopM_start_dyn: \n\t"
42- "mcfgm t3, t0, 1 \n\t" // t3 = Processed M rows
30+ "sll %[strideA], %[M], %[shift] \n\t" // Compute strideA = M * 2^shift
31+ "sll %[strideB], %[N], %[shift] \n\t" // Compute strideB = M * 2^shift
32+ "sll %[strideC], %[N], %[shift] \n\t" // Compute strideC = M * 2^shift
33+
34+ "1: \n\t"
35+ "mcfgm t3, t0, %[rmul] \n\t" // t3 = Processed M rows
4336
4437 "add t1, x0, %[N] \n\t" // t1 = Remaining N
4538 "add s2, x0, s0 \n\t" // s2 = C tile pointer
46- "add s3, x0, %[addrB] \n\t" // s3 = B (RHS) tile pointer
47-
48- "loopN_start_dyn: \n\t"
49- "mcfgn t4, t1, 1 \n\t" // t4 = Processed N columns
50-
51- // Reset accumulator
52- "mzero.a acc0 \n\t"
39+ "add s3, x0, %[addrB] \n\t" // s3 = B (RHS) tile pointer
5340
41+ "2: \n\t"
5442 "add t2, x0, %[K] \n\t" // t2 = Remaining K
43+ "mcfgk t5, t2 \n\t" // t5 = Processed K depth for a single block
44+ "mcfgn t4, t1, %[cmul] \n\t" // t4 = Processed N columns
45+ "mzero.a acc0 \n\t" // Reset accumulator
5546 "add s4, x0, s1 \n\t" // s4 = A tile pointer
5647 "add s5, x0, s3 \n\t" // s5 = B tile pointer
5748
58- "loopK_start_dyn: \n\t"
59- "mcfgk t5, t2 \n\t" // t5 = Processed K depth for a single block
49+ "3: \n\t"
6050
6151 // 2. Load First Tiles
6252 "mld.lhs m0, (s4), %[strideA] \n\t"
63- "mld.rhs m4, (s5), %[strideB] \n\t"
64- "mmacc acc0, m4, m0 \n\t"
65-
66- "mul s8, t5, %[strideA] \n\t"
53+ "mld.rhs m4, (s5), %[strideB] \n\t"
54+ "mmacc acc0, m4, m0 \n\t"
55+ "mul s8, t5, %[strideA] \n\t"
56+ "mul s9, t5, %[strideB] \n\t"
57+
6758 "add s6, s4, s8 \n\t" // s6 = Pointer to 2nd tile of A
6859 "mld.lhs m2, (s6), %[strideA] \n\t"
69-
70- "mul s9, t5, %[strideB] \n\t"
60+
7161 "add s7, s5, s9 \n\t" // s7 = Pointer to 2nd tile of B
72- "mld.rhs m6, (s7), %[strideB] \n\t"
73- "mmacc acc0, m6, m2 \n\t"
62+ "mld.rhs m6, (s7), %[strideB] \n\t"
7463
7564 // 5. Advance K pointers by TWO blocks
7665 "add t6, s8, s8 \n\t"
@@ -81,11 +70,13 @@ void __attribute__ ((noinline)) matrixMul_8x8(float* addrA, float* addrB, float*
8170
8271 // Decrease remaining K by two blocks
8372 "add t6, t5, t5 \n\t"
84- "sub t2, t2, t6 \n\t"
85- "bgtz t2, loopK_start_dyn \n\t"
73+ "sub t2, t2, t6 \n\t"
74+ "mmacc acc0, m6, m2 \n\t"
75+ "mcfgk t5, t2 \n\t" // t5 = Processed K depth for a single block
76+ "bgtz t2, 3b \n\t"
8677
8778 // 6. Transfer to MR and Store
88- "mmov.am m8, acc0 \n\t"
79+ "mmov.am m8, acc0 \n\t"
8980 "mst m8, (s2), %[strideC] \n\t"
9081
9182 // 7. Advance along N
@@ -94,7 +85,7 @@ void __attribute__ ((noinline)) matrixMul_8x8(float* addrA, float* addrB, float*
9485 "add s3, s3, t6 \n\t"
9586
9687 "sub t1, t1, t4 \n\t"
97- "bgtz t1, loopN_start_dyn \n\t"
88+ "bgtz t1, 2b \n\t"
9889
9990 // 8. Advance along M
10091 "mul t6, t3, %[strideC] \n\t"
@@ -104,7 +95,7 @@ void __attribute__ ((noinline)) matrixMul_8x8(float* addrA, float* addrB, float*
10495 "add s1, s1, t6 \n\t"
10596
10697 "sub t0, t0, t3 \n\t"
107- "bgtz t0, loopM_start_dyn \n\t"
98+ "bgtz t0, 1b \n\t"
10899
109100 // --- EPILOGUE: Restore registers ---
110101 "lw s0 , 0x2c(sp) \n\t"
@@ -124,14 +115,10 @@ void __attribute__ ((noinline)) matrixMul_8x8(float* addrA, float* addrB, float*
124115
125116 :
126117 : [addrA ] "r" (addrA ), [addrB ] "r" (addrB ), [addrC ] "r" (addrC ),
127- [M ] "r" (M ), [N ] "r" (N ), [K ] "r" (K ),
118+ [M ] "r" (M ), [N ] "r" (N ), [K ] "r" (K ), [ shift ] "r" ( shift ),
128119 [strideA ] "r" (strideA ), [strideB ] "r" (strideB ), [strideC ] "r" (strideC ),
129- [dt_type ] "i" (FP32 )
120+ [dt_typeC ] "i" (DTC ), [dt_typeA ] "i" (DTA ), [dt_typeB ] "i" (DTB ),
121+ [rmul ] "i" (RMUL_2 ), [cmul ] "i" (CMUL_2 )
130122 : "t0" , "t1" , "t2" , "t3" , "t4" , "t5" , "t6" , "s0" , "s1" , "s2" , "s3" , "s4" , "s5" , "s6" , "s7" , "s8" , "s9" , "memory"
131123 );
132- }
133-
134- void __attribute__ ((noinline )) matrixMul_16x16 (float * addrA , float * addrB , float * addrC , int K , int N , int M , int shift )
135- {
136-
137124}
0 commit comments