@@ -76,22 +76,55 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #(
7676`endif
7777
7878 // -----------------------------------------------------------------------
79- // Operand data mux: WGMMA uses tile buffer, WMMA uses register file
79+ // WGMMA / WMMA abstraction layer
8080 // -----------------------------------------------------------------------
81+ // All WGMMA-vs-WMMA runtime differences are resolved here behind a
82+ // common interface. Downstream code uses only these wires and never
83+ // references tbuf_* or is_wgmma directly.
8184
8285 wire [TCU_BLOCK_CAP - 1 : 0 ][`XLEN - 1 : 0 ] rs1_data;
83- wire [TCU_BLOCK_CAP - 1 : 0 ][`XLEN - 1 : 0 ] rs2_data;
86+ wire [TCU_WG_RS2_WIDTH - 1 : 0 ][`XLEN - 1 : 0 ] rs2_data;
87+ wire exe_ready_extra; // additional ready gating (tbuf_ready)
88+ wire [3 : 0 ] last_k_steps; // final step_k value for is_last_k
89+
90+ // WMMA K-step limit (needed before the mux region for last_k_steps)
91+ `ifdef TCU_SPARSE_ENABLE
92+ wire [3 : 0 ] k_steps_val = is_sparse ? 4 '((TCU_K_STEPS / 2 ) - 1 ) : 4 '(TCU_K_STEPS - 1 );
93+ `else
94+ wire [3 : 0 ] k_steps_val = 4 '(TCU_K_STEPS - 1 );
95+ `endif
8496
8597`ifdef TCU_WGMMA_ENABLE
8698 wire is_wgmma = (execute_if.data.op_type == INST_TCU_WGMMA );
8799 wire wg_a_smem = execute_if.data.op_args.tcu.a_from_smem;
88- // A source: tile buffer (smem) or register file
100+
101+ // A/B operand mux: tile buffer (smem) or register file
89102 assign rs1_data = (is_wgmma && wg_a_smem) ? tbuf_rs1_data : execute_if.data.rs1_data;
90- // B source: always tile buffer (smem) for WGMMA
91- assign rs2_data = is_wgmma ? tbuf_rs2_data[TCU_BLOCK_CAP - 1 : 0 ] : execute_if.data.rs2_data;
103+ /* verilator lint_off WIDTHEXPAND */
104+ assign rs2_data = is_wgmma ? tbuf_rs2_data
105+ : TCU_WG_RS2_WIDTH ' (execute_if.data.rs2_data);
106+ /* verilator lint_on WIDTHEXPAND */
107+
108+ `ifdef TCU_SPARSE_ENABLE
109+ // Sparse metadata mux: tile-buffer vs register-file metadata
110+ wire [TCU_MAX_META_BLOCK_WIDTH - 1 : 0 ] vld_meta_block = is_wgmma ? tbuf_sp_meta : wmma_sp_meta;
111+ // K-step limit: WGMMA and WMMA have different tile-K sizes
112+ wire [3 : 0 ] wg_k_steps_val = execute_if.data.op_args.tcu.is_sparse
113+ ? 4 '((TCU_WG_K_STEPS / 2 ) - 1 ) : 4 '(TCU_WG_K_STEPS - 1 );
114+ `else
115+ wire [3 : 0 ] wg_k_steps_val = 4 '(TCU_WG_K_STEPS - 1 );
116+ `endif
117+
118+ assign last_k_steps = is_wgmma ? wg_k_steps_val : k_steps_val;
119+ assign exe_ready_extra = ~ is_wgmma || tbuf_ready;
92120`else
93121 assign rs1_data = execute_if.data.rs1_data;
94122 assign rs2_data = execute_if.data.rs2_data;
123+ `ifdef TCU_SPARSE_ENABLE
124+ wire [TCU_MAX_META_BLOCK_WIDTH - 1 : 0 ] vld_meta_block = wmma_sp_meta;
125+ `endif
126+ assign last_k_steps = k_steps_val;
127+ assign exe_ready_extra = 1'b1 ;
95128`endif
96129
97130 wire [3 : 0 ] step_m = execute_if.data.op_args.tcu.step_m;
@@ -138,25 +171,7 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #(
138171 // Accumulator is a multi-tile LUTRAM indexed by {wid, step_m, step_n}.
139172
140173 wire is_first_k = (step_k == '0 );
141-
142- `ifdef TCU_SPARSE_ENABLE
143- wire [3 : 0 ] k_steps_val = is_sparse ? 4 '((TCU_K_STEPS / 2 ) - 1 ) : 4 '(TCU_K_STEPS - 1 );
144- `else
145- wire [3 : 0 ] k_steps_val = 4 '(TCU_K_STEPS - 1 );
146- `endif
147-
148- `ifdef TCU_WGMMA_ENABLE
149- wire [3 : 0 ] wg_k_steps_val;
150- `ifdef TCU_SPARSE_ENABLE
151- assign wg_k_steps_val = execute_if.data.op_args.tcu.is_sparse
152- ? 4 '((TCU_WG_K_STEPS / 2 ) - 1 ) : 4 '(TCU_WG_K_STEPS - 1 );
153- `else
154- assign wg_k_steps_val = 4 '(TCU_WG_K_STEPS - 1 );
155- `endif
156- wire is_last_k = is_wgmma ? (step_k == wg_k_steps_val) : (step_k == k_steps_val);
157- `else
158- wire is_last_k = (step_k == k_steps_val);
159- `endif
174+ wire is_last_k = (step_k == last_k_steps);
160175
161176 // -----------------------------------------------------------------------
162177 // Multi-tile accumulator (LUTRAM, async read, read-first)
@@ -266,11 +281,7 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #(
266281
267282 assign result_if.valid = fedp_done;
268283 assign fedp_enable = ~ fedp_done || result_if.ready;
269- `ifdef TCU_WGMMA_ENABLE
270- assign execute_if.ready = ~ mdata_queue_full && fedp_enable && ! k_stall && (~ is_wgmma || tbuf_ready);
271- `else
272- assign execute_if.ready = ~ mdata_queue_full && fedp_enable && ! k_stall;
273- `endif
284+ assign execute_if.ready = ~ mdata_queue_full && fedp_enable && ! k_stall && exe_ready_extra;
274285
275286 // All uops push to the metadata queue; non-last-k have wb=0 in their
276287 // header so the writeback stage skips the RF write.
@@ -330,12 +341,7 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #(
330341 .vld_block (wmma_sp_meta)
331342 );
332343
333- wire [TCU_MAX_META_BLOCK_WIDTH - 1 : 0 ] vld_meta_block;
334- `ifdef TCU_WGMMA_ENABLE
335- assign vld_meta_block = is_wgmma ? tbuf_sp_meta : wmma_sp_meta;
336- `else
337- assign vld_meta_block = wmma_sp_meta;
338- `endif
344+ // vld_meta_block is muxed in the operand mux region above
339345`endif
340346
341347 // -----------------------------------------------------------------------
@@ -355,22 +361,9 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #(
355361 assign a_row[k_idx] = 32 '(rs1_data[a_off + i * TCU_TC_K + k_idx]);
356362 `ifdef TCU_SPARSE_ENABLE
357363 assign b_col_dense[k_idx] = 32 '(rs2_data[b_off + j * TCU_TC_K + k_idx]);
358- // WGMMA_SP: tbuf_rs2_data is wide (TCU_WG_RS2_WIDTH lanes);
359- // use j directly — gather already placed each column's pair at j*tcK*2.
360- // WMMA_SP: rs2_data comes from the register file (TCU_BLOCK_CAP lanes);
361- // SYM_SPARSE folds j to the packed column-pair layout.
362364 localparam J_SP = SYM_SPARSE ? (j % (TCU_TC_N / 2 )) : j;
363- `ifdef TCU_WGMMA_ENABLE
364- assign b_col_1[k_idx] = 32 '(is_wgmma
365- ? tbuf_rs2_data[j * TCU_TC_K * 2 + k_idx * 2 ]
366- : rs2_data[b_off + J_SP * TCU_TC_K * 2 + k_idx * 2 ]);
367- assign b_col_2[k_idx] = 32 '(is_wgmma
368- ? tbuf_rs2_data[j * TCU_TC_K * 2 + k_idx * 2 + 1 ]
369- : rs2_data[b_off + J_SP * TCU_TC_K * 2 + k_idx * 2 + 1 ]);
370- `else
371365 assign b_col_1[k_idx] = 32 '(rs2_data[b_off + J_SP * TCU_TC_K * 2 + k_idx * 2 ]);
372366 assign b_col_2[k_idx] = 32 '(rs2_data[b_off + J_SP * TCU_TC_K * 2 + k_idx * 2 + 1 ]);
373- `endif
374367 `else
375368 assign b_col[k_idx] = 32 '(rs2_data[b_off + j * TCU_TC_K + k_idx]);
376369 `endif
0 commit comments