From 1b1e0b68420ac95618eb2a63c7fbca3a29f5fcb5 Mon Sep 17 00:00:00 2001 From: aoli26 Date: Sat, 23 May 2026 16:31:50 +0000 Subject: [PATCH 01/19] Fix async scale scheduling --- kernels/gemm_fp8fp4_gfx1250.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index ee09dc7a4..13fcc44a2 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -95,6 +95,7 @@ def compile_mxscale_gemm( raise ValueError(f"scale_load_path must be one of {scale_load_paths}, got {scale_load_path!r}") use_scale_buffer_load = scale_load_path != "tdm" use_ab_split_scale_buffer_load = scale_load_path == "buffer_lds_stage_ab_split" + effective_expert_sched_mode = bool(expert_sched_mode) and not use_scale_buffer_load if num_buffers not in (2, 3, 4): raise ValueError(f"num_buffers must be 2, 3, or 4, got {num_buffers}") @@ -301,6 +302,17 @@ def _align_up(value: int, align: int) -> int: TDM_LOADS_PER_STEP = 1 else: TDM_LOADS_PER_STEP = 2 if use_scale_buffer_load else 4 + if use_scale_buffer_load: + _scale_async_batch_bytes = block_threads * _scale_dma_bytes + _scale_async_ops_per_stage = ( + (tile_m * scale_k_per_tile + _scale_async_batch_bytes - 1) // _scale_async_batch_bytes + + (tile_n * scale_k_per_tile + _scale_async_batch_bytes - 1) // _scale_async_batch_bytes + ) + # Wait only for the scale stage that is about to be consumed; keep + # later buffered scale async copies in flight. + _scale_async_future_wait_count = _scale_async_ops_per_stage * max(num_buffers - 2, 0) + else: + _scale_async_future_wait_count = 0 tail_plan = [(ls, cs, o * TDM_LOADS_PER_STEP // 2 if o > 0 else o) for ls, cs, o in _base_tail_plan] # Pre-compute epilogue sub-tile layout (unified for FP4 vec16 and FP8 vec8) @@ -1709,16 +1721,18 @@ def _issue_scale_buffer_loads(stage_idx, k_base): tile_n * scale_k_per_tile, ) - def _wait_scale_buffer_loads(): + def _wait_scale_all(): if const_expr(use_scale_buffer_load): rocdl.s_wait_asynccnt(0) + def _wait_scale_for_compute_stage(): + if const_expr(use_scale_buffer_load): + rocdl.s_wait_asynccnt(_scale_async_future_wait_count) + def _pipeline_fence(outstanding=0): - _wait_scale_buffer_loads() pipeline_fence(outstanding=outstanding, use_cluster=use_cluster) def _pipeline_fence_signal(outstanding=0): - _wait_scale_buffer_loads() pipeline_fence_signal(outstanding=outstanding, use_cluster=use_cluster) def _issue_ab_tdm(load_stage, addr_a, addr_b): @@ -1784,6 +1798,7 @@ def _issue_active_tdm(load_stage, addr_box, scale_k_box=None, k_prefetch=None): if const_expr(use_scale_buffer_load): scale_next_k_base = split_k_base + arith.index(pre_loaded * tile_k) + _wait_scale_for_compute_stage() _pipeline_fence(outstanding=TDM_LOADS_PER_STEP * (num_buffers - 2)) # Main loop — acc_mixed style: fence at top, TDM_load mid-compute. @@ -1844,6 +1859,7 @@ def _mid_tdm_ws( for buf_idx in range_constexpr(num_buffers): load_stage = (buf_idx + num_buffers - 1) % num_buffers + _wait_scale_for_compute_stage() _pipeline_fence_signal(outstanding=_fence_outstanding) pipeline_fence_wait(use_cluster=use_cluster) @@ -1892,6 +1908,7 @@ def _mid_tdm_split_scale_dma( for buf_idx in range_constexpr(num_buffers): load_stage = (buf_idx + num_buffers - 1) % num_buffers + _wait_scale_for_compute_stage() _pipeline_fence_signal(outstanding=_fence_outstanding) pipeline_fence_wait(use_cluster=use_cluster) @@ -2003,6 +2020,7 @@ def _mid_tdm_nws( addr_lo_bs = results[n_accs + 3] # Tail — same acc_mixed pattern: fence at top, TDM mid-compute. + _wait_scale_all() if const_expr(loop_iters > 0): _pipeline_fence(outstanding=0) elif const_expr(use_cluster): @@ -2012,6 +2030,7 @@ def _mid_tdm_nws( for _load_stage, _compute_stage, _outstanding in tail_plan: if const_expr(_outstanding == -1): if const_expr(_tail_had_load): + _wait_scale_all() _pipeline_fence(outstanding=0) if const_expr(use_tdm_store): accs = compute_tile_scheduled( @@ -2035,6 +2054,7 @@ def _emit_epi_addrs(): emit_filler=_emit_epi_addrs, ) else: + _wait_scale_all() _pipeline_fence_signal(outstanding=_outstanding) pipeline_fence_wait(use_cluster=use_cluster) @@ -2204,7 +2224,7 @@ def launch_mxscale_gemm( cluster=cluster_arg, ) - if expert_sched_mode: + if effective_expert_sched_mode: launch_mxscale_gemm.compile_hints["llvm_options"] = { "amdgpu-expert-scheduling-mode": True, } From efb7cf9e6cac9e8e9bb9f3b230774bfb95bcf411 Mon Sep 17 00:00:00 2001 From: aoli26 Date: Sun, 24 May 2026 04:28:06 +0000 Subject: [PATCH 02/19] Add FP8 deep pipeline schedule --- kernels/gemm_fp8fp4_gfx1250.py | 407 ++++++++++++++++++++++++++++++--- 1 file changed, 374 insertions(+), 33 deletions(-) diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index 13fcc44a2..ac4cd41a9 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -66,6 +66,7 @@ def compile_mxscale_gemm( atomic_barrier_enable: bool = False, b_streaming: bool = False, scale_load_path: str = "tdm", + fp8_schedule: str = "auto", ): """Compile an MXFP4 or MXFP8 GEMM kernel with TDM async copy. @@ -93,6 +94,13 @@ def compile_mxscale_gemm( scale_load_paths = ("tdm", "buffer_lds_stage", "buffer_lds_stage_ab_split") if scale_load_path not in scale_load_paths: raise ValueError(f"scale_load_path must be one of {scale_load_paths}, got {scale_load_path!r}") + fp8_schedule_modes = ("auto", "quadrant", "deep-pipeline") + if fp8_schedule not in fp8_schedule_modes: + raise ValueError(f"fp8_schedule must be one of {fp8_schedule_modes}, got {fp8_schedule!r}") + if fp8_schedule != "auto" and data_format != "fp8": + raise ValueError(f"fp8_schedule={fp8_schedule!r} is only valid for data_format='fp8'") + if fp8_schedule != "auto" and b_streaming: + raise ValueError("fp8_schedule cannot be combined with b_streaming=True") use_scale_buffer_load = scale_load_path != "tdm" use_ab_split_scale_buffer_load = scale_load_path == "buffer_lds_stage_ab_split" effective_expert_sched_mode = bool(expert_sched_mode) and not use_scale_buffer_load @@ -337,8 +345,29 @@ def _align_up(value: int, align: int) -> int: COMPUTE_SCHEDULE_ROW_MAJOR_STREAMING = "row_major_streaming" COMPUTE_SCHEDULE_FP4_COL_BAND = "fp4_col_band" COMPUTE_SCHEDULE_FP8_QUADRANT = "fp8_quadrant" + COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE = "fp8_deep_pipeline" COMPUTE_SCHEDULE_B_STREAMING = "b_streaming" + fp8_deep_pipeline_eligible = ( + data_format == "fp8" + and tile_m == 256 + and tile_n == 256 + and tile_k == 128 + and m_warp == 2 + and n_warp == 2 + and num_buffers == 4 + and wave_specialized_tdm + and scale_load_path == "tdm" + and out_dtype == "bf16" + and not use_scale_opsel + ) + if fp8_schedule == "deep-pipeline" and not fp8_deep_pipeline_eligible: + raise ValueError( + "fp8_schedule='deep-pipeline' requires fp8 256x256x128, " + "m_warp=n_warp=2, num_buffers=4, wave_specialized_tdm=True, " + "scale_load_path='tdm', out_dtype='bf16', and use_scale_opsel=False" + ) + def _pick_compute_schedule_kind(): if b_streaming: return COMPUTE_SCHEDULE_B_STREAMING @@ -351,12 +380,29 @@ def _pick_compute_schedule_kind(): if is_fp4: return COMPUTE_SCHEDULE_FP4_COL_BAND if data_format == "fp8": + if fp8_schedule == "deep-pipeline" or (fp8_schedule == "auto" and fp8_deep_pipeline_eligible): + return COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE return COMPUTE_SCHEDULE_FP8_QUADRANT return COMPUTE_SCHEDULE_ROW_MAJOR_STREAMING compute_schedule_kind = _pick_compute_schedule_kind() use_fp4_bank_friendly_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_FP4_COL_BAND use_fp8_quadrant_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_FP8_QUADRANT + use_fp8_deep_pipeline_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE + use_b_streaming_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_B_STREAMING + use_ws_tdm_split_signal_overlap = ( + wave_specialized_tdm + and not use_scale_buffer_load + and (use_fp8_quadrant_schedule or use_fp8_deep_pipeline_schedule) + and num_buffers == 4 + and use_cluster + ) + + if use_b_streaming_schedule: + print( + f"[b_streaming] {data_format} tile=({tile_m},{tile_n},{tile_k}) " f"M_r={wmma_m_rep} N_r={wmma_n_rep}", + flush=True, + ) if use_fp4_bank_friendly_schedule: _bank_half_wm = wmma_m_rep // 2 @@ -377,11 +423,19 @@ def _pick_compute_schedule_kind(): for _wn in range(_bank_half_wn, wmma_n_rep): _bank_group_to_row_major.append(_wm * wmma_n_rep + _wn) - if use_fp8_quadrant_schedule: + if use_fp8_quadrant_schedule or use_fp8_deep_pipeline_schedule: _fp8_half_wm = wmma_m_rep // 2 _fp8_half_wn = wmma_n_rep // 2 _fp8_group_size = _fp8_half_wm * _fp8_half_wn _fp8_b_scale_loads = (b_scale_load_rep + 3) // 4 + if use_fp8_deep_pipeline_schedule: + _fp8_pair_wm = 2 + _fp8_pair_wn = 2 + _fp8_wm_pairs = wmma_m_rep // _fp8_pair_wm + _fp8_wn_pairs = wmma_n_rep // _fp8_pair_wn + _fp8_pair_a_loads = _fp8_pair_wm * DS_LOADS_PER_A_FRAG + _fp8_pair_b_loads = _fp8_pair_wn * _b_frag_loads_per_wn + _fp8_scale_loads = (wmma_m_rep + 3) // 4 + (b_scale_load_rep + 3) // 4 @flyc.kernel(known_block_size=[block_threads, 1, 1]) def kernel_mxscale_gemm( @@ -1096,6 +1150,7 @@ def compute_tile_fp8_quadrant( lds_bs, emit_filler=None, mid_compute_callback=None, + late_compute_callback=None, ): current_accs = list(accs_in) a_buf, a_bases = _precompute_a_lane_bases(lds_a) @@ -1130,11 +1185,22 @@ def _load_b_scales(ks): def _load_b_left_bundle(ks): return _load_b_half(0, ks), _load_b_scales(ks) - def _emit_group(wm_base, wn_base, a_frags, b_frags, a_scales, b_scales, emit_filler_now=False): + def _emit_group_rows( + wm_base, + wn_base, + a_frags, + b_frags, + a_scales, + b_scales, + row_start, + row_count, + emit_filler_now=False, + ): if const_expr(emit_filler_now and emit_filler is not None): rocdl.sched_barrier(0) emit_filler() - for wm_local in range_constexpr(_fp8_half_wm): + for row_offset in range_constexpr(row_count): + wm_local = row_start + row_offset global_wm = wm_base + wm_local for wn_local in range_constexpr(_fp8_half_wn): global_wn = wn_base + wn_local @@ -1148,24 +1214,67 @@ def _emit_group(wm_base, wn_base, a_frags, b_frags, a_scales, b_scales, emit_fil b_scales, ) + def _emit_group(wm_base, wn_base, a_frags, b_frags, a_scales, b_scales, emit_filler_now=False): + _emit_group_rows( + wm_base, + wn_base, + a_frags, + b_frags, + a_scales, + b_scales, + 0, + _fp8_half_wm, + emit_filler_now=emit_filler_now, + ) + + def _emit_group_col(wm_base, wn_base, a_frags, b_frags, a_scales, b_scales, wn_local): + global_wn = wn_base + wn_local + for wm_local in range_constexpr(_fp8_half_wm): + global_wm = wm_base + wm_local + _emit_wmma( + current_accs, + global_wm, + global_wn, + a_frags[wm_local], + b_frags[wn_local], + a_scales, + b_scales, + ) + b_left_frags, b_scales = _load_b_left_bundle(0) + _first_top_row_keep = max((_fp8_half_wm - 1) * DS_LOADS_PER_A_FRAG - _fp8_b_scale_loads, 0) + _bottom_left_keep = max(_b_half_loads - DS_LOADS_PER_A_FRAG, 0) for ks in range_constexpr(k_wmma_steps): is_last_ks = ks == k_wmma_steps - 1 a_scales = _load_a_scales(ks) a_top_frags = _load_a_group(0, _fp8_half_wm, ks) - a_bottom_frags = _load_a_group(_fp8_half_wm, _fp8_half_wm, ks) - # Keep bottom A outstanding while the first quadrant consumes top A. - rocdl.s_wait_dscnt(_fp8_half_wm * DS_LOADS_PER_A_FRAG) + # Consume the first top-left row before issuing bottom-A. + # The barriers only constrain LLVM scheduling; they are not + # hardware synchronization points. + rocdl.s_wait_dscnt(_first_top_row_keep) + rocdl.sched_barrier(0) + _emit_group_rows(0, 0, a_top_frags, b_left_frags, a_scales, b_scales, 0, 1) + rocdl.sched_barrier(0) - _emit_group(0, 0, a_top_frags, b_left_frags, a_scales, b_scales) + a_bottom_frags = _load_a_group(_fp8_half_wm, _fp8_half_wm, ks) + if const_expr(_fp8_half_wm > 1): + _emit_group_rows( + 0, + 0, + a_top_frags, + b_left_frags, + a_scales, + b_scales, + 1, + _fp8_half_wm - 1, + ) b_right_frags = _load_b_half(_fp8_half_wn, ks) - # Keep the newly issued right-half B loads outstanding while - # bottom A becomes ready for the second quadrant. - rocdl.s_wait_dscnt(_b_half_loads) + # Drain bottom-A while keeping most right-half B in flight. + rocdl.s_wait_dscnt(_bottom_left_keep) _emit_group(_fp8_half_wm, 0, a_bottom_frags, b_left_frags, a_scales, b_scales) @@ -1175,22 +1284,33 @@ def _emit_group(wm_base, wn_base, a_frags, b_frags, a_scales, b_scales, emit_fil if const_expr(not is_last_ks): next_left_frags, next_b_scales = _load_b_left_bundle(ks + 1) - # Current right-half B must be ready before Q2/Q3, while - # the next ks left-half bundle stays in flight. - rocdl.s_wait_dscnt(_b_left_bundle_loads) - else: - rocdl.s_wait_dscnt(0) - _emit_group(0, _fp8_half_wn, a_top_frags, b_right_frags, a_scales, b_scales) - _emit_group( - _fp8_half_wm, - _fp8_half_wn, - a_bottom_frags, - b_right_frags, - a_scales, - b_scales, - emit_filler_now=is_last_ks, - ) + for wn_local in range_constexpr(_fp8_half_wn): + if const_expr(not is_last_ks): + _right_keep = _b_left_bundle_loads + (_fp8_half_wn - wn_local - 1) * _b_frag_loads_per_wn + else: + _right_keep = (_fp8_half_wn - wn_local - 1) * _b_frag_loads_per_wn + rocdl.s_wait_dscnt(_right_keep) + _emit_group_col(0, _fp8_half_wn, a_top_frags, b_right_frags, a_scales, b_scales, wn_local) + + if const_expr(is_last_ks and late_compute_callback is not None): + rocdl.sched_barrier(0) + late_compute_callback() + + if const_expr(is_last_ks and emit_filler is not None): + rocdl.sched_barrier(0) + emit_filler() + + for wn_local in range_constexpr(_fp8_half_wn): + _emit_group_col( + _fp8_half_wm, + _fp8_half_wn, + a_bottom_frags, + b_right_frags, + a_scales, + b_scales, + wn_local, + ) if const_expr(not is_last_ks): b_left_frags = next_left_frags @@ -1198,6 +1318,149 @@ def _emit_group(wm_base, wn_base, a_frags, b_frags, a_scales, b_scales, emit_fil return current_accs + def compute_tile_fp8_deep_pipeline( + accs_in, + lds_a, + lds_b, + lds_as, + lds_bs, + emit_filler=None, + mid_compute_callback=None, + late_compute_callback=None, + ): + current_accs = list(accs_in) + a_buf, a_bases = _precompute_a_lane_bases(lds_a) + b_buf, b_bases = _precompute_b_lane_bases(lds_b) + as_buf, as_bases = _precompute_scale_lane_bases(lds_as, warp_m_base, wmma_m_rep, interleaved_scale_cols_a) + bs_buf, bs_bases = _precompute_scale_lane_bases( + lds_bs, warp_n_base, b_scale_load_rep, interleaved_scale_cols_b + ) + + def load_a_pair(wm_pair, ks): + wm_base = wm_pair * _fp8_pair_wm + return [ + load_a_frag(a_buf, a_bases[wm_base + wm_local], ks) + for wm_local in range_constexpr(_fp8_pair_wm) + ] + + def load_b_pair(wn_pair, ks): + wn_base = wn_pair * _fp8_pair_wn + return [ + load_b_frag(b_buf, b_bases, wn_base + wn_local, ks) + for wn_local in range_constexpr(_fp8_pair_wn) + ] + + def _load_a_scales(ks): + return load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks) + + def _load_b_scales(ks): + return load_scale_b128(bs_buf, bs_bases[0], b_scale_load_rep, ks) + + def emit_panel_2x2(wm_pair, wn_pair, a_pair, b_pair, scale_pair, prefetch_after_first_row=None): + a_scales, b_scales = scale_pair + wm_base = wm_pair * _fp8_pair_wm + wn_base = wn_pair * _fp8_pair_wn + for wn_local in range_constexpr(_fp8_pair_wn): + _emit_wmma( + current_accs, + wm_base, + wn_base + wn_local, + a_pair[0], + b_pair[wn_local], + a_scales, + b_scales, + ) + if const_expr(prefetch_after_first_row is not None): + prefetch_after_first_row() + for wn_local in range_constexpr(_fp8_pair_wn): + _emit_wmma( + current_accs, + wm_base + 1, + wn_base + wn_local, + a_pair[1], + b_pair[wn_local], + a_scales, + b_scales, + ) + + _pair_loads = _fp8_pair_a_loads + _two_pair_loads = _fp8_pair_a_loads + _fp8_pair_b_loads + _three_pair_loads = _fp8_pair_b_loads + _fp8_pair_a_loads * 2 + + for ks in range_constexpr(k_wmma_steps): + is_last_ks = ks == k_wmma_steps - 1 + a_scales = _load_a_scales(ks) + b_scales = _load_b_scales(ks) + scale_pair = (a_scales, b_scales) + + b0 = load_b_pair(0, ks) + a0 = load_a_pair(0, ks) + b1 = load_b_pair(1, ks) + + a1_box = [None] + b2_box = [None] + b3_box = [None] + a2_box = [None] + a3_box = [None] + + def _prefetch_a1(): + a1_box[0] = load_a_pair(1, ks) + + rocdl.s_wait_dscnt(_fp8_pair_b_loads) + emit_panel_2x2(0, 0, a0, b0, scale_pair, prefetch_after_first_row=_prefetch_a1) + + if const_expr(ks == 0 and mid_compute_callback is not None): + rocdl.sched_barrier(0) + mid_compute_callback() + + def _prefetch_b2(): + b2_box[0] = load_b_pair(2, ks) + + rocdl.s_wait_dscnt(_pair_loads) + emit_panel_2x2(0, 1, a0, b1, scale_pair, prefetch_after_first_row=_prefetch_b2) + + def _prefetch_b3(): + b3_box[0] = load_b_pair(3, ks) + + rocdl.s_wait_dscnt(_fp8_pair_b_loads) + emit_panel_2x2(1, 0, a1_box[0], b0, scale_pair, prefetch_after_first_row=_prefetch_b3) + + def _prefetch_a2_a3(): + a2_box[0] = load_a_pair(2, ks) + a3_box[0] = load_a_pair(3, ks) + + emit_panel_2x2(1, 1, a1_box[0], b1, scale_pair, prefetch_after_first_row=_prefetch_a2_a3) + + rocdl.s_wait_dscnt(_three_pair_loads) + emit_panel_2x2(0, 2, a0, b2_box[0], scale_pair) + rocdl.s_wait_dscnt(_two_pair_loads) + emit_panel_2x2(0, 3, a0, b3_box[0], scale_pair) + emit_panel_2x2(1, 2, a1_box[0], b2_box[0], scale_pair) + emit_panel_2x2(1, 3, a1_box[0], b3_box[0], scale_pair) + + rocdl.s_wait_dscnt(_fp8_pair_a_loads) + emit_panel_2x2(2, 0, a2_box[0], b0, scale_pair) + emit_panel_2x2(2, 1, a2_box[0], b1, scale_pair) + + rocdl.s_wait_dscnt(0) + emit_panel_2x2(3, 0, a3_box[0], b0, scale_pair) + emit_panel_2x2(3, 1, a3_box[0], b1, scale_pair) + + if const_expr(is_last_ks and late_compute_callback is not None): + rocdl.sched_barrier(0) + late_compute_callback() + + if const_expr(is_last_ks and emit_filler is not None): + rocdl.sched_barrier(0) + emit_filler() + + emit_panel_2x2(2, 2, a2_box[0], b2_box[0], scale_pair) + emit_panel_2x2(2, 3, a2_box[0], b3_box[0], scale_pair) + emit_panel_2x2(3, 2, a3_box[0], b2_box[0], scale_pair) + emit_panel_2x2(3, 3, a3_box[0], b3_box[0], scale_pair) + + return current_accs + def compute_tile_b_streaming( accs_in, lds_a, lds_b, lds_as, lds_bs, emit_filler=None, mid_compute_callback=None ): @@ -1289,27 +1552,75 @@ def hot_loop_scheduler_fp4_bank_friendly(): rocdl.sched_barrier(0) def hot_loop_scheduler_fp8_quadrant(): - _a_all_loads = wmma_m_rep * DS_LOADS_PER_A_FRAG _a_scale_loads = (wmma_m_rep + 3) // 4 + _a_top_loads = _fp8_half_wm * DS_LOADS_PER_A_FRAG + _a_bottom_loads = _a_top_loads _b_half_loads = _fp8_half_wn * _b_frag_loads_per_wn _b_left_bundle_loads = _b_half_loads + _fp8_b_scale_loads _group_wmma = _fp8_group_size + _first_row_wmma = _fp8_half_wn + _remaining_top_left_wmma = (_fp8_half_wm - 1) * _fp8_half_wn for _ks in range_constexpr(k_wmma_steps): if const_expr(_ks == 0): - rocdl.sched_dsrd(_b_left_bundle_loads + _a_scale_loads + _a_all_loads) + rocdl.sched_dsrd(_b_left_bundle_loads + _a_scale_loads + _a_top_loads) else: - rocdl.sched_dsrd(_a_scale_loads + _a_all_loads) - rocdl.sched_mfma(_group_wmma) + rocdl.sched_dsrd(_a_scale_loads + _a_top_loads) + rocdl.sched_mfma(_first_row_wmma) + rocdl.sched_dsrd(_a_bottom_loads) + if const_expr(_remaining_top_left_wmma > 0): + rocdl.sched_mfma(_remaining_top_left_wmma) rocdl.sched_dsrd(_b_half_loads) rocdl.sched_mfma(_group_wmma) if const_expr(_ks < k_wmma_steps - 1): rocdl.sched_dsrd(_b_left_bundle_loads) - rocdl.sched_mfma(_group_wmma) - rocdl.sched_mfma(_group_wmma) + for _wn_local in range_constexpr(_fp8_half_wn): + rocdl.sched_mfma(_fp8_half_wm) + for _wn_local in range_constexpr(_fp8_half_wn): + rocdl.sched_mfma(_fp8_half_wm) rocdl.sched_barrier(0) - def compute_tile_scheduled(accs_in, lds_a, lds_b, lds_as, lds_bs, emit_filler=None, mid_compute_callback=None): + def hot_loop_scheduler_fp8_deep_pipeline(): + def _sched_panel_2x2(prefetch_loads=0): + if const_expr(prefetch_loads > 0): + rocdl.sched_mfma(_fp8_pair_wn) + rocdl.sched_dsrd(prefetch_loads) + rocdl.sched_mfma(_fp8_pair_wn) + else: + rocdl.sched_mfma(_fp8_pair_wm * _fp8_pair_wn) + + _initial_loads = _fp8_scale_loads + _fp8_pair_b_loads + _fp8_pair_a_loads + _fp8_pair_b_loads + + for _ks in range_constexpr(k_wmma_steps): + rocdl.sched_dsrd(_initial_loads) + _sched_panel_2x2(_fp8_pair_a_loads) + _sched_panel_2x2(_fp8_pair_b_loads) + _sched_panel_2x2(_fp8_pair_b_loads) + _sched_panel_2x2(_fp8_pair_a_loads * 2) + _sched_panel_2x2() + _sched_panel_2x2() + _sched_panel_2x2() + _sched_panel_2x2() + _sched_panel_2x2() + _sched_panel_2x2() + _sched_panel_2x2() + _sched_panel_2x2() + _sched_panel_2x2() + _sched_panel_2x2() + _sched_panel_2x2() + _sched_panel_2x2() + rocdl.sched_barrier(0) + + def compute_tile_scheduled( + accs_in, + lds_a, + lds_b, + lds_as, + lds_bs, + emit_filler=None, + mid_compute_callback=None, + late_compute_callback=None, + ): if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_B_STREAMING): return compute_tile_b_streaming( accs_in, @@ -1339,6 +1650,18 @@ def compute_tile_scheduled(accs_in, lds_a, lds_b, lds_as, lds_bs, emit_filler=No lds_bs, emit_filler=emit_filler, mid_compute_callback=mid_compute_callback, + late_compute_callback=late_compute_callback, + ) + if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE): + return compute_tile_fp8_deep_pipeline( + accs_in, + lds_a, + lds_b, + lds_as, + lds_bs, + emit_filler=emit_filler, + mid_compute_callback=mid_compute_callback, + late_compute_callback=late_compute_callback, ) return compute_tile( accs_in, @@ -1377,6 +1700,8 @@ def hot_loop_scheduler_scheduled(): hot_loop_scheduler_b_streaming() elif const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP4_COL_BAND): hot_loop_scheduler_fp4_bank_friendly() + elif const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE): + hot_loop_scheduler_fp8_deep_pipeline() elif const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP8_QUADRANT): hot_loop_scheduler_fp8_quadrant() else: @@ -1805,6 +2130,9 @@ def _issue_active_tdm(load_stage, addr_box, scale_k_box=None, k_prefetch=None): # This overlaps TDM DMA with the remaining WMMA instructions, _fence_outstanding = TDM_LOADS_PER_STEP * (num_buffers - 2) + if const_expr(loop_iters > 0 and use_ws_tdm_split_signal_overlap): + _pipeline_fence_signal(outstanding=_fence_outstanding) + if const_expr(loop_iters > 0): if const_expr(wave_specialized_tdm and not use_scale_buffer_load): init_args = list(accs) + [active_addr_lo] @@ -1816,7 +2144,8 @@ def _issue_active_tdm(load_stage, addr_box, scale_k_box=None, k_prefetch=None): for buf_idx in range_constexpr(num_buffers): load_stage = (buf_idx + num_buffers - 1) % num_buffers - _pipeline_fence_signal(outstanding=_fence_outstanding) + if const_expr(not use_ws_tdm_split_signal_overlap): + _pipeline_fence_signal(outstanding=_fence_outstanding) pipeline_fence_wait(use_cluster=use_cluster) addr_box = [cur_addr_lo] @@ -1832,6 +2161,14 @@ def _mid_tdm_ws( ): _issue_active_tdm(_ls, _ab, k_prefetch=_k_off) + _late_tdm_ws_fence_signal = None + if const_expr(use_ws_tdm_split_signal_overlap): + + def _late_tdm_ws_split_signal(): + _pipeline_fence_signal(outstanding=_fence_outstanding) + + _late_tdm_ws_fence_signal = _late_tdm_ws_split_signal + rocdl.sched_barrier(0) accs_in = compute_tile_scheduled( accs_in, @@ -1840,6 +2177,7 @@ def _mid_tdm_ws( stages_as_idx[buf_idx], stages_bs_idx[buf_idx], mid_compute_callback=_mid_tdm_ws, + late_compute_callback=_late_tdm_ws_fence_signal, ) cur_addr_lo = addr_box[0] hot_loop_scheduler_scheduled() @@ -2020,6 +2358,8 @@ def _mid_tdm_nws( addr_lo_bs = results[n_accs + 3] # Tail — same acc_mixed pattern: fence at top, TDM mid-compute. + if const_expr(loop_iters > 0 and use_ws_tdm_split_signal_overlap): + pipeline_fence_wait(use_cluster=use_cluster) _wait_scale_all() if const_expr(loop_iters > 0): _pipeline_fence(outstanding=0) @@ -2181,6 +2521,7 @@ def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): atomic_barrier_enable, b_streaming, scale_load_path, + fp8_schedule, ) @flyc.jit From e20d06a3f789e6c7a054ff5f6d70b9875d7747f3 Mon Sep 17 00:00:00 2001 From: aoli26 Date: Sun, 24 May 2026 13:28:17 +0000 Subject: [PATCH 03/19] Optimize FP8 deep pipeline waits --- kernels/gemm_fp8fp4_gfx1250.py | 29 ++++++++++++----------------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index ac4cd41a9..7a047ec22 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -1396,9 +1396,9 @@ def emit_panel_2x2(wm_pair, wn_pair, a_pair, b_pair, scale_pair, prefetch_after_ b0 = load_b_pair(0, ks) a0 = load_a_pair(0, ks) b1 = load_b_pair(1, ks) + b2 = load_b_pair(2, ks) a1_box = [None] - b2_box = [None] b3_box = [None] a2_box = [None] a3_box = [None] @@ -1406,23 +1406,20 @@ def emit_panel_2x2(wm_pair, wn_pair, a_pair, b_pair, scale_pair, prefetch_after_ def _prefetch_a1(): a1_box[0] = load_a_pair(1, ks) - rocdl.s_wait_dscnt(_fp8_pair_b_loads) + rocdl.s_wait_dscnt(_two_pair_loads + 3) emit_panel_2x2(0, 0, a0, b0, scale_pair, prefetch_after_first_row=_prefetch_a1) if const_expr(ks == 0 and mid_compute_callback is not None): rocdl.sched_barrier(0) mid_compute_callback() - def _prefetch_b2(): - b2_box[0] = load_b_pair(2, ks) - - rocdl.s_wait_dscnt(_pair_loads) - emit_panel_2x2(0, 1, a0, b1, scale_pair, prefetch_after_first_row=_prefetch_b2) + rocdl.s_wait_dscnt(_pair_loads + _fp8_pair_b_loads) + emit_panel_2x2(0, 1, a0, b1, scale_pair) def _prefetch_b3(): b3_box[0] = load_b_pair(3, ks) - rocdl.s_wait_dscnt(_fp8_pair_b_loads) + rocdl.s_wait_dscnt(_fp8_pair_b_loads + 2) emit_panel_2x2(1, 0, a1_box[0], b0, scale_pair, prefetch_after_first_row=_prefetch_b3) def _prefetch_a2_a3(): @@ -1431,14 +1428,12 @@ def _prefetch_a2_a3(): emit_panel_2x2(1, 1, a1_box[0], b1, scale_pair, prefetch_after_first_row=_prefetch_a2_a3) - rocdl.s_wait_dscnt(_three_pair_loads) - emit_panel_2x2(0, 2, a0, b2_box[0], scale_pair) - rocdl.s_wait_dscnt(_two_pair_loads) + emit_panel_2x2(0, 2, a0, b2, scale_pair) + rocdl.s_wait_dscnt(_pair_loads) emit_panel_2x2(0, 3, a0, b3_box[0], scale_pair) - emit_panel_2x2(1, 2, a1_box[0], b2_box[0], scale_pair) + emit_panel_2x2(1, 2, a1_box[0], b2, scale_pair) emit_panel_2x2(1, 3, a1_box[0], b3_box[0], scale_pair) - rocdl.s_wait_dscnt(_fp8_pair_a_loads) emit_panel_2x2(2, 0, a2_box[0], b0, scale_pair) emit_panel_2x2(2, 1, a2_box[0], b1, scale_pair) @@ -1454,9 +1449,9 @@ def _prefetch_a2_a3(): rocdl.sched_barrier(0) emit_filler() - emit_panel_2x2(2, 2, a2_box[0], b2_box[0], scale_pair) + emit_panel_2x2(2, 2, a2_box[0], b2, scale_pair) emit_panel_2x2(2, 3, a2_box[0], b3_box[0], scale_pair) - emit_panel_2x2(3, 2, a3_box[0], b2_box[0], scale_pair) + emit_panel_2x2(3, 2, a3_box[0], b2, scale_pair) emit_panel_2x2(3, 3, a3_box[0], b3_box[0], scale_pair) return current_accs @@ -1589,12 +1584,12 @@ def _sched_panel_2x2(prefetch_loads=0): else: rocdl.sched_mfma(_fp8_pair_wm * _fp8_pair_wn) - _initial_loads = _fp8_scale_loads + _fp8_pair_b_loads + _fp8_pair_a_loads + _fp8_pair_b_loads + _initial_loads = _fp8_scale_loads + _fp8_pair_b_loads * 3 + _fp8_pair_a_loads for _ks in range_constexpr(k_wmma_steps): rocdl.sched_dsrd(_initial_loads) _sched_panel_2x2(_fp8_pair_a_loads) - _sched_panel_2x2(_fp8_pair_b_loads) + _sched_panel_2x2() _sched_panel_2x2(_fp8_pair_b_loads) _sched_panel_2x2(_fp8_pair_a_loads * 2) _sched_panel_2x2() From aa0469180a644a69e504aa3952f8e31ddd50aea8 Mon Sep 17 00:00:00 2001 From: aoli26 Date: Tue, 26 May 2026 06:04:59 +0000 Subject: [PATCH 04/19] tuning pipeline barrier --- kernels/gemm_fp8fp4_gfx1250.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index 7a047ec22..0319da6f1 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -1435,16 +1435,15 @@ def _prefetch_a2_a3(): emit_panel_2x2(1, 3, a1_box[0], b3_box[0], scale_pair) emit_panel_2x2(2, 0, a2_box[0], b0, scale_pair) + if const_expr(is_last_ks and late_compute_callback is not None): + rocdl.sched_barrier(0) + late_compute_callback() emit_panel_2x2(2, 1, a2_box[0], b1, scale_pair) rocdl.s_wait_dscnt(0) emit_panel_2x2(3, 0, a3_box[0], b0, scale_pair) emit_panel_2x2(3, 1, a3_box[0], b1, scale_pair) - if const_expr(is_last_ks and late_compute_callback is not None): - rocdl.sched_barrier(0) - late_compute_callback() - if const_expr(is_last_ks and emit_filler is not None): rocdl.sched_barrier(0) emit_filler() From 56d30c12df742245250af7b19e1e454cb114d11f Mon Sep 17 00:00:00 2001 From: aoli26 Date: Tue, 26 May 2026 12:55:44 +0000 Subject: [PATCH 05/19] fix inst prefetch bug --- kernels/gemm_fp8fp4_gfx1250.py | 2 +- python/flydsl/expr/rocdl/inline_asm.py | 18 +++++------------- 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index 0319da6f1..7d5dd08f0 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -452,7 +452,7 @@ def kernel_mxscale_gemm( if const_expr(inst_prefetch): if rocdl.wave_id() == fx.Int32(0): - rocdl.s_prefetch_inst_burst(num_pages=10) + rocdl.s_prefetch_inst_burst(num_pages=4) tx = gpu.thread_id("x") bx = gpu.block_id("x") diff --git a/python/flydsl/expr/rocdl/inline_asm.py b/python/flydsl/expr/rocdl/inline_asm.py index 357af0d21..f16ffacb5 100644 --- a/python/flydsl/expr/rocdl/inline_asm.py +++ b/python/flydsl/expr/rocdl/inline_asm.py @@ -67,23 +67,15 @@ def cvt_pk_bf16_f32(src_a_f32, src_b_f32): ) -def s_prefetch_inst_burst(num_pages: int = 10, page_bytes: int = 4096): - """gfx1250: prefetch ``num_pages`` cache lines of instructions ahead of PC. +def s_prefetch_inst_burst(num_pages: int = 3, page_bytes: int = 4096): + """gfx1250: prefetch ``num_pages`` × 4 KB of instructions ahead of PC. - Sets HW_REG_WAVE_MODE bit 8 (allow s_prefetch) then issues - ``s_prefetch_inst_pc_rel offset, s0, 31`` for ``offset = 0, page_bytes, - 2*page_bytes, ...``. Used by GEMM kernels to warm the I-cache before the - main loop starts so the first few iterations don't stall on instruction - fetch. - - Wraps the inline-asm sequence so callers do not need to import the raw - ``llvm`` dialect. + Caller must keep ``num_pages * page_bytes`` within shader bounds; over-reach + page-faults. """ from ..._mlir.dialects import llvm as _llvm - lines = ["s_setreg_imm32_b32 hwreg(HW_REG_WAVE_MODE, 8, 1), 1"] - for pg in range(num_pages): - lines.append(f"s_prefetch_inst_pc_rel {pg * page_bytes}, s0, 31") + lines = [f"s_prefetch_inst_pc_rel {pg * page_bytes}, null, 31" for pg in range(num_pages)] _llvm.inline_asm( None, [], From 717510d49cc0070b5a14ae48b09ead2c552914c1 Mon Sep 17 00:00:00 2001 From: aoli26 Date: Tue, 26 May 2026 14:12:45 +0000 Subject: [PATCH 06/19] Optimize FP8/FP4 GEMM panel scheduling --- kernels/gemm_fp8fp4_gfx1250.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index 7d5dd08f0..feadc021c 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -1383,6 +1383,21 @@ def emit_panel_2x2(wm_pair, wn_pair, a_pair, b_pair, scale_pair, prefetch_after_ b_scales, ) + def emit_panel_2x2_row(wm_pair, wn_pair, row_local, a_pair, b_pair, scale_pair): + a_scales, b_scales = scale_pair + wm_base = wm_pair * _fp8_pair_wm + wn_base = wn_pair * _fp8_pair_wn + for wn_local in range_constexpr(_fp8_pair_wn): + _emit_wmma( + current_accs, + wm_base + row_local, + wn_base + wn_local, + a_pair[row_local], + b_pair[wn_local], + a_scales, + b_scales, + ) + _pair_loads = _fp8_pair_a_loads _two_pair_loads = _fp8_pair_a_loads + _fp8_pair_b_loads _three_pair_loads = _fp8_pair_b_loads + _fp8_pair_a_loads * 2 @@ -1429,9 +1444,10 @@ def _prefetch_a2_a3(): emit_panel_2x2(1, 1, a1_box[0], b1, scale_pair, prefetch_after_first_row=_prefetch_a2_a3) emit_panel_2x2(0, 2, a0, b2, scale_pair) + emit_panel_2x2_row(1, 2, 0, a1_box[0], b2, scale_pair) rocdl.s_wait_dscnt(_pair_loads) emit_panel_2x2(0, 3, a0, b3_box[0], scale_pair) - emit_panel_2x2(1, 2, a1_box[0], b2, scale_pair) + emit_panel_2x2_row(1, 2, 1, a1_box[0], b2, scale_pair) emit_panel_2x2(1, 3, a1_box[0], b3_box[0], scale_pair) emit_panel_2x2(2, 0, a2_box[0], b0, scale_pair) From 30a97e4b70e231e0ebea0f6b3058b35a2a18a777 Mon Sep 17 00:00:00 2001 From: aoli26 Date: Thu, 28 May 2026 09:41:13 +0000 Subject: [PATCH 07/19] Improve gfx1250 GEMM graph benchmarking --- tests/kernels/benchmark_common.py | 11 + tests/kernels/test_gemm_fp8fp4_gfx1250.py | 428 +++++++++++++++++----- 2 files changed, 345 insertions(+), 94 deletions(-) diff --git a/tests/kernels/benchmark_common.py b/tests/kernels/benchmark_common.py index 2f5749361..06192c0f0 100644 --- a/tests/kernels/benchmark_common.py +++ b/tests/kernels/benchmark_common.py @@ -468,6 +468,17 @@ def bench_kernel_us(run_fn, warmup=10, iters=50, flush_l2=True, prep_fn=None): run_fn() torch.cuda.synchronize() + if flush_buf is None and prep_fn is None: + # Single event pair preserves back-to-back launch pipelining. + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(iters): + run_fn() + end.record() + torch.cuda.synchronize() + return start.elapsed_time(end) * 1e3 / iters + start_ev = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] end_ev = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] diff --git a/tests/kernels/test_gemm_fp8fp4_gfx1250.py b/tests/kernels/test_gemm_fp8fp4_gfx1250.py index acabe17e0..59de0224a 100644 --- a/tests/kernels/test_gemm_fp8fp4_gfx1250.py +++ b/tests/kernels/test_gemm_fp8fp4_gfx1250.py @@ -4,6 +4,7 @@ Kernel implementation: kernels/gemm_fp8fp4_gfx1250.py """ +import math import os import re import sys @@ -53,6 +54,101 @@ def random_fp8_data(rows: int, cols: int, *, device="cpu") -> torch.Tensor: return torch.randint(0, 126, (rows, cols), dtype=torch.uint8, device=device) +def _fp8_e4m3fn_byte(value: float) -> int: + """Return torch's FP8 E4M3FN byte encoding for a finite scalar.""" + t = torch.tensor([float(value)], dtype=torch.float8_e4m3fn) + byte = int(t.view(torch.uint8).item()) + if (byte & 0x7F) == 0x7F: + raise SystemExit(f"--fill-mode constant {value:g} is outside the finite FP8 E4M3FN range") + return byte + + +def _parse_fill_mode(arg: str): + """Parse --fill-mode as ('random',) or ('const', value).""" + if arg == "random": + return ("random",) + if arg == "zero": + return ("const", 0.0) + try: + value = float(arg) + except ValueError as e: + raise SystemExit(f"--fill-mode must be 'random' or a finite float constant, got {arg!r}") from e + if not math.isfinite(value): + raise SystemExit(f"--fill-mode constant must be finite, got {arg!r}") + return ("const", value) + + +def _fp4_e2m1_packed_fill(rows: int, cols: int, value: float) -> torch.Tensor: + dense = torch.full((rows, cols), float(value), dtype=torch.float32) + return fp4_utils.f32_to_mxfp4(dense).view(torch.uint8) + + +def _random_mxscale_inputs(M: int, N: int, K: int, data_format: str): + if data_format == "a8w4": + a = random_fp8_data(M, K) + b = fp4_utils.random_fp4_packed(N, K) + elif data_format == "fp4": + a = fp4_utils.random_fp4_packed(M, K) + b = fp4_utils.random_fp4_packed(N, K) + elif data_format == "fp8": + a = random_fp8_data(M, K) + b = random_fp8_data(N, K) + else: + raise ValueError(f"unsupported data_format={data_format!r}") + return a, b, fp4_utils.random_e8m0(M, K // SCALE_BLOCK), fp4_utils.random_e8m0(N, K // SCALE_BLOCK) + + +def _const_fill_inputs(M, N, K, data_format: str, value: float): + """Build constant A/B tensors with neutral E8M0 scales for CLI runs.""" + if data_format == "fp4": + a = _fp4_e2m1_packed_fill(M, K, value) + b = _fp4_e2m1_packed_fill(N, K, value) + elif data_format == "a8w4": + fp8_byte = _fp8_e4m3fn_byte(value) + a = torch.full((M, K), fp8_byte, dtype=torch.uint8) + b = _fp4_e2m1_packed_fill(N, K, value) + elif data_format == "fp8": + fp8_byte = _fp8_e4m3fn_byte(value) + a = torch.full((M, K), fp8_byte, dtype=torch.uint8) + b = torch.full((N, K), fp8_byte, dtype=torch.uint8) + else: + raise ValueError(f"unsupported data_format={data_format!r}") + a_scale = torch.full((M, K // SCALE_BLOCK), 127, dtype=torch.uint8) + b_scale = torch.full((N, K // SCALE_BLOCK), 127, dtype=torch.uint8) + return a, b, a_scale, b_scale + + +def _fill_mode_inputs(M: int, N: int, K: int, data_format: str, fill_mode: str): + fill_spec = _parse_fill_mode(fill_mode) + if fill_spec[0] == "const": + a, b, a_scale, b_scale = _const_fill_inputs(M, N, K, data_format, fill_spec[1]) + else: + a, b, a_scale, b_scale = _random_mxscale_inputs(M, N, K, data_format) + return a, b, a_scale, b_scale, fill_spec + + +def _fill_mode_label(fill_spec, data_format: str) -> str: + if fill_spec[0] == "random": + return "random (seed=0)" + label = f"const={fill_spec[1]:g}, E8M0 byte=127" + if data_format in ("fp8", "a8w4"): + label += f", FP8 byte=0x{_fp8_e4m3fn_byte(fill_spec[1]):02x}" + return label + + +def _has_nonzero_quantized_values(tensor: torch.Tensor, data_format: str) -> bool: + convert = fp4_utils.mxfp4_to_f32 if data_format == "fp4" else fp4_utils.fp8_e4m3_to_f32 + return bool(convert(tensor.view(torch.uint8)).abs().max().item() > 0) + + +def _expect_nonzero_graph_output(a: torch.Tensor, b: torch.Tensor, data_format: str, fill_spec) -> bool: + if fill_spec[0] == "random": + return True + a_format = "fp4" if data_format == "fp4" else "fp8" + b_format = "fp8" if data_format == "fp8" else "fp4" + return _has_nonzero_quantized_values(a, a_format) and _has_nonzero_quantized_values(b, b_format) + + def _reference_scaled_gemm(a, b, a_scale, b_scale, M, N, K, convert_fn, convert_fn_b=None): """Reference scaled GEMM: D = (A * A_scale) @ (B * B_scale)^T.""" a_f32 = convert_fn(a.view(torch.uint8))[:M, :K] @@ -327,18 +423,12 @@ def _run_mxscale_gemm_test( scale_load_path=scale_load_path, ) - # Pre-bind via flyc.compile so the launch goes through the CompiledFunction - # ctypes fast path (matches test_blockscale_preshuffle_gemm.py and any - # production caller that bench-times this kernel). The slow JitFunction - # path adds ~17us of inspect.Signature.bind + _resolve_and_make_cache_key per call, - # which would mask genuine kernel timing differences in the bench path. - # flyc.compile() launches the kernel once internally to trigger - # compilation, so no separate eager call is needed for correctness. - c_flat = c_gpu.contiguous().view(-1) - a_flat = a_gpu.contiguous().view(-1) - b_flat = b_gpu.contiguous().view(-1) - as_flat = as_gpu.contiguous().view(-1) - bs_flat = bs_gpu.contiguous().view(-1) + # Keep 2D — dynamic_layout=True packs shape as i32; flattening overflows for M*K >= 2^31. + c_flat = c_gpu.contiguous() + a_flat = a_gpu.contiguous() + b_flat = b_gpu.contiguous() + as_flat = as_gpu.contiguous() + bs_flat = bs_gpu.contiguous() flyc.compile( launch_fn, @@ -855,11 +945,11 @@ def test_mxscale_gemm_cudagraph(data_format, M, N, K, tile_m, tile_n, tile_k, m_ split_k=1, ) - c_flat = c_gpu.contiguous().view(-1) - a_flat = a_gpu.contiguous().view(-1) - b_flat = b_gpu.contiguous().view(-1) - as_flat = as_gpu.contiguous().view(-1) - bs_flat = bs_gpu.contiguous().view(-1) + c_flat = c_gpu.contiguous() + a_flat = a_gpu.contiguous() + b_flat = b_gpu.contiguous() + as_flat = as_gpu.contiguous() + bs_flat = bs_gpu.contiguous() compiled_exe = flyc.compile( launch_fn, c_flat, @@ -914,27 +1004,11 @@ def launch(): ) -def _bench_kernel_us_cudagraph(run_fn, warmup=10, iters=100, prep_fn=None): - """Per-launch timer that strips host launch overhead via hipGraph. - - How it works: - - Capture a single kernel launch into a hipGraph - - Replay it N times in one stream submission burst - - Per-launch time = total_burst_time / N - - NB: stream-ordered execution guarantees the N replays serialise — each - g.replay() is one submission to the stream and the next one cannot - start until the previous one finishes. - - NB: no L2 flush between replays. The whole point of the graph is to - measure the kernel in a hot, back-to-back launch scenario (which is - what production serving looks like). For cold-cache numbers use the - regular _bench_kernel_us with flush_l2=True. - """ +def _bench_kernel_us_cudagraph(run_fn, warmup=10, iters=100, prep_fn=None, n_per_graph=20): + """Per-launch timer via hipGraph: capture n_per_graph launches, replay iters times, single event pair around the whole replay loop.""" capture_stream = torch.cuda.Stream() capture_stream.wait_stream(torch.cuda.current_stream()) - # Warmup on the capture stream so the allocator / JIT cache is settled. with torch.cuda.stream(capture_stream): for _ in range(warmup): if prep_fn is not None: @@ -943,15 +1017,46 @@ def _bench_kernel_us_cudagraph(run_fn, warmup=10, iters=100, prep_fn=None): torch.cuda.current_stream().wait_stream(capture_stream) torch.cuda.synchronize() - # Capture exactly one kernel launch into the graph. g = torch.cuda.CUDAGraph() if prep_fn is not None: prep_fn() with torch.cuda.graph(g, stream=capture_stream): + for _ in range(n_per_graph): + run_fn() + torch.cuda.synchronize() + + # Sanity guard against empty graph capture. + ref_start = torch.cuda.Event(enable_timing=True) + ref_end = torch.cuda.Event(enable_timing=True) + ref_start.record() + for _ in range(n_per_graph): run_fn() + ref_end.record() + torch.cuda.synchronize() + ref_per_launch_us = ref_start.elapsed_time(ref_end) * 1e3 / n_per_graph + + rep_start = torch.cuda.Event(enable_timing=True) + rep_end = torch.cuda.Event(enable_timing=True) + rep_start.record() + g.replay() + rep_end.record() torch.cuda.synchronize() + first_replay_per_launch_us = rep_start.elapsed_time(rep_end) * 1e3 / n_per_graph + + print( + f"SANITY_GRAPH,n_per_graph={n_per_graph}," + f"ref_per_launch_us={ref_per_launch_us:.3f}," + f"first_replay_per_launch_us={first_replay_per_launch_us:.3f}", + file=sys.stderr, + flush=True, + ) + if first_replay_per_launch_us < 1.0 and ref_per_launch_us > 2.0: + raise RuntimeError( + f"hipGraph replay per-launch={first_replay_per_launch_us:.3f}us " + f"<< ref direct-launch={ref_per_launch_us:.3f}us. " + f"Graph capture likely empty (stream mismatch?)." + ) - # Time iters replays as one batch. start_ev = torch.cuda.Event(enable_timing=True) end_ev = torch.cuda.Event(enable_timing=True) start_ev.record() @@ -959,18 +1064,11 @@ def _bench_kernel_us_cudagraph(run_fn, warmup=10, iters=100, prep_fn=None): g.replay() end_ev.record() torch.cuda.synchronize() - total_us = start_ev.elapsed_time(end_ev) * 1e3 - return total_us / iters + return start_ev.elapsed_time(end_ev) * 1e3 / (iters * n_per_graph) def _bench_kernel_us(run_fn, warmup=10, iters=50, flush_l2=True, prep_fn=None): - """Per-iteration CUDA events timer with L2 flush, IQR outlier removal, median. - - Follows the golden practices from benchmarks/gemm_gfx1250_benchmark.py: - - L2 flush between iterations via zero-fill of 2× L2 buffer - - Per-iteration event pairs (no inter-iteration sync on kernel) - - IQR outlier removal (if n >= 8), median latency - """ + """Per-iter CUDA events with L2 flush + IQR-trimmed median; fast path uses a single event pair when no flush/prep is requested (preserves back-to-back launch pipelining).""" flush_buf = None if flush_l2: l2_bytes = getattr( @@ -987,6 +1085,17 @@ def _bench_kernel_us(run_fn, warmup=10, iters=50, flush_l2=True, prep_fn=None): run_fn() torch.cuda.synchronize() + if flush_buf is None and prep_fn is None: + # Single event pair preserves back-to-back launch pipelining. + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(iters): + run_fn() + end.record() + torch.cuda.synchronize() + return start.elapsed_time(end) * 1e3 / iters + start_ev = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] end_ev = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] @@ -1057,24 +1166,14 @@ def _run_benchmark(args): ) if args.split_k > 1: print(f" Split-K={args.split_k} (atomic accumulate, buffer-store epilogue)") - print(f" Warmup={args.warmup}, Iters={args.iters}, " f"L2 flush={'ON' if not args.no_flush_l2 else 'OFF'}") - print(" Zero fill: ON (outside timing)") + l2_flush_label = "OFF (graph)" if getattr(args, "use_graph", False) else ("OFF" if args.no_flush_l2 else "ON") + print(f" Warmup={args.warmup}, Iters={args.iters}, L2 flush={l2_flush_label}") + print(" Output init: zero before warmup") print("=" * 72) torch.manual_seed(0) - - if is_a8w4: - a = random_fp8_data(M, K) - b = fp4_utils.random_fp4_packed(N, K) - elif is_fp4: - a = fp4_utils.random_fp4_packed(M, K) - b = fp4_utils.random_fp4_packed(N, K) - else: - a = random_fp8_data(M, K) - b = random_fp8_data(N, K) - - a_scale = fp4_utils.random_e8m0(M, K // SCALE_BLOCK) - b_scale = fp4_utils.random_e8m0(N, K // SCALE_BLOCK) + a, b, a_scale, b_scale, fill_spec = _fill_mode_inputs(M, N, K, data_format, getattr(args, "fill_mode", "random")) + print(f" Fill mode: {_fill_mode_label(fill_spec, data_format)}") a, b, a_scale, b_scale = _pad_mxscale_inputs(a, b, a_scale, b_scale, padded_shape) @@ -1126,23 +1225,13 @@ def _run_benchmark(args): scale_load_path=args.scale_load_path, ) - c_flat = c_gpu.view(-1) - a_flat = a_gpu.view(-1) - b_flat = b_gpu.view(-1) - as_flat = as_gpu.view(-1) - bs_flat = bs_gpu.view(-1) - - # Pre-bind via flyc.compile so the bench loop calls go through the - # CompiledFunction ctypes fast path. The slow JitFunction path adds - # ~17us of inspect.Signature.bind + _resolve_and_make_cache_key per call, which - # would dominate per-launch latency for short kernels. compiled_exe = flyc.compile( launch_fn, - c_flat, - a_flat, - b_flat, - as_flat, - bs_flat, + c_gpu, + a_gpu, + b_gpu, + as_gpu, + bs_gpu, padded_m, padded_n, torch.cuda.current_stream(), @@ -1151,16 +1240,13 @@ def _run_benchmark(args): def prep_kernel(): c_gpu.zero_() - # Resolve the stream lazily inside the closure so the graph-bench path - # captures on the active capture stream rather than the stream bound - # before capture. Same value on the eager path. def run_kernel(): compiled_exe( - c_flat, - a_flat, - b_flat, - as_flat, - bs_flat, + c_gpu, + a_gpu, + b_gpu, + as_gpu, + bs_gpu, padded_m, padded_n, torch.cuda.current_stream(), @@ -1174,12 +1260,7 @@ def run_kernel(): use_graph = getattr(args, "use_graph", False) if use_graph: - print(f"[2/3] Warming up ({args.warmup} iters) + bench via hipGraph ({args.iters} replays)...") - # Graph mode prep: don't zero c_gpu inside the captured kernel - # (zero would be baked into the graph and runs every replay, but - # that's also fine — it would add a trivial memset per replay). - # We omit prep_fn here because the c_gpu state across replays - # doesn't matter for timing. + print(f"[2/3] Warming up ({args.warmup} iters) + bench via hipGraph " f"({args.iters} replays)...") us = _bench_kernel_us_cudagraph(run_kernel, warmup=args.warmup, iters=args.iters) else: print(f"[2/3] Warming up ({args.warmup} iters) + benchmarking ({args.iters} iters)...") @@ -1249,17 +1330,158 @@ def run_kernel(): return us, reported_tflops, bw_gbs +def _run_graph_verify(args): + """Compare eager launch and hipGraph replay for the CLI-selected shape.""" + arch = str(get_rocm_arch()) + if arch != "gfx1250": + raise SystemExit(f"WMMA_SCALE requires gfx1250, got {arch}") + + data_format = args.data_format + M, N, K = args.M, args.N, args.K + tile_m, tile_n, tile_k = args.tile_m, args.tile_n, args.tile_k + if K % SCALE_BLOCK != 0: + raise SystemExit(f"K={K} must be divisible by SCALE_BLOCK={SCALE_BLOCK}") + + padded_shape = _get_padded_problem_shape(data_format, M, N, K, tile_m, tile_n, tile_k, args.split_k) + padded_m = padded_shape["M"] + padded_n = padded_shape["N"] + padded_k = padded_shape["K"] + + print("=" * 72) + print(f" Graph functional verification ({data_format}) on gfx1250") + print(f" Shape: M={M}, N={N}, K={K} (padded {padded_m}x{padded_n}x{padded_k})") + print( + f" Tile: ({tile_m},{tile_n},{tile_k}) warps=({args.m_warp}x{args.n_warp}) " + f"nb={args.num_buffers} sk={args.split_k} " + f"cluster=({args.cluster_m},{args.cluster_n})" + ) + print("=" * 72) + + torch.manual_seed(0) + a, b, a_scale, b_scale, fill_spec = _fill_mode_inputs(M, N, K, data_format, getattr(args, "fill_mode", "random")) + expect_nonzero_output = _expect_nonzero_graph_output(a, b, data_format, fill_spec) + print(f" Fill: {_fill_mode_label(fill_spec, data_format)}") + + a, b, a_scale, b_scale = _pad_mxscale_inputs(a, b, a_scale, b_scale, padded_shape) + + skt = tile_k // SCALE_BLOCK + warp_tile_m = tile_m // args.m_warp + warp_tile_n = tile_n // args.n_warp + a_scale = preshuffle_e8m0_scale(a_scale, warp_tile_m, scale_k_per_tile=skt) + b_scale = preshuffle_e8m0_scale(b_scale, warp_tile_n, scale_k_per_tile=skt) + K_packed = padded_k // padded_shape["pack_b"] + b = fp4_utils.preshuffle_b_16x16(b, padded_n, K_packed) + + a_gpu = a.cuda() + b_gpu = b.cuda() + as_gpu = a_scale.cuda() + bs_gpu = b_scale.cuda() + torch_out_dtype = {"f32": torch.float32, "bf16": torch.bfloat16, "f16": torch.float16}[args.out_dtype] + c_gpu = torch.zeros(padded_m, padded_n, dtype=torch_out_dtype, device="cuda") + + use_tdm_store = not args.no_tdm_store and args.split_k == 1 + launch_fn = compile_mxscale_gemm( + data_format=data_format, + M=padded_m, + N=padded_n, + K=padded_k, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + m_warp=args.m_warp, + n_warp=args.n_warp, + num_buffers=args.num_buffers, + waves_per_eu=args.waves_per_eu, + l2_prefetch_distance=args.l2_prefetch_distance, + cluster_m=args.cluster_m, + cluster_n=args.cluster_n, + use_tdm_store=use_tdm_store, + out_dtype=args.out_dtype, + inst_prefetch=args.inst_prefetch, + wave_specialized_tdm=args.wave_spec_tdm, + split_k=args.split_k, + use_scale_opsel=args.use_scale_opsel, + expert_sched_mode=args.expert_sched_mode, + atomic_barrier_enable=args.atomic_barrier_enable, + b_streaming=args.b_streaming, + scale_load_path=args.scale_load_path, + ) + + c_flat = c_gpu.contiguous() + a_flat = a_gpu.contiguous() + b_flat = b_gpu.contiguous() + as_flat = as_gpu.contiguous() + bs_flat = bs_gpu.contiguous() + compiled_exe = flyc.compile( + launch_fn, + c_flat, + a_flat, + b_flat, + as_flat, + bs_flat, + padded_m, + padded_n, + torch.cuda.current_stream(), + ) + + def launch(): + compiled_exe(c_flat, a_flat, b_flat, as_flat, bs_flat, padded_m, padded_n, torch.cuda.current_stream()) + + c_gpu.zero_() + launch() + torch.cuda.synchronize() + eager_result = c_gpu.clone() + + g = torch.cuda.CUDAGraph() + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + launch() + torch.cuda.current_stream().wait_stream(s) + torch.cuda.synchronize() + + c_gpu.zero_() + with torch.cuda.graph(g, stream=s): + launch() + torch.cuda.synchronize() + + c_gpu.zero_() + g.replay() + torch.cuda.synchronize() + graph_result = c_gpu.clone() + + if expect_nonzero_output: + if eager_result.abs().max().item() == 0: + raise SystemExit( + "FAIL: eager run produced all zeros -- kernel did not execute (unexpected for non-zero fill)." + ) + if graph_result.abs().max().item() == 0: + raise SystemExit( + "FAIL: hipGraph replay produced all zeros -- kernel was NOT captured (stream mismatch suspected)." + ) + if not torch.equal(eager_result, graph_result): + diff = (eager_result.float() - graph_result.float()).abs().max().item() + raise SystemExit(f"FAIL: eager vs hipGraph result mismatch, max abs diff = {diff:.6f}") + + sample_max = eager_result.abs().max().item() + print( + f" Eager output |max| = {sample_max:.6g}" + + ("" if expect_nonzero_output else " (zero is expected for this fill)") + ) + print(" PASS: eager == hipGraph replay (bit-exact)") + + if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() - parser.add_argument("--data-format", type=str, default="fp4", choices=["fp4", "fp8", "a8w4"]) + parser.add_argument("--data-format", type=str, default="fp8", choices=["fp4", "fp8", "a8w4"]) parser.add_argument("-M", type=int, default=1024) parser.add_argument("-N", type=int, default=1024) parser.add_argument("-K", type=int, default=2048) - parser.add_argument("--tile-m", type=int, default=128) + parser.add_argument("--tile-m", type=int, default=256) parser.add_argument("--tile-n", type=int, default=256) - parser.add_argument("--tile-k", type=int, default=256) + parser.add_argument("--tile-k", type=int, default=128) parser.add_argument("--m-warp", type=int, default=2) parser.add_argument("--n-warp", type=int, default=2) parser.add_argument("--num-buffers", type=int, default=4, choices=[2, 3, 4]) @@ -1270,7 +1492,7 @@ def run_kernel(): parser.add_argument("--no-tdm-store", action="store_true", default=False) parser.add_argument("--out-dtype", type=str, default="bf16", choices=["f32", "bf16", "f16"]) parser.add_argument("--inst-prefetch", action="store_true", default=False) - parser.add_argument("--wave-spec-tdm", action="store_true", default=False) + parser.add_argument("--no-wave-spec-tdm", dest="wave_spec_tdm", action="store_false", default=True) parser.add_argument("--waves-per-eu", type=int, default=None) parser.add_argument("--use-scale-opsel", action="store_true", default=False) parser.add_argument( @@ -1303,8 +1525,26 @@ def run_kernel(): "Implicitly disables L2 flush (graph replays " "are back-to-back, hot-cache).", ) + parser.add_argument( + "--verify-graph", + action="store_true", + default=False, + help="Functional verification: capture the kernel in a hipGraph, " + "replay once, assert bit-exact match against an eager launch. " + ) + parser.add_argument( + "--fill-mode", + type=str, + default="random", + help="Input fill mode: 'random', 'zero', or a finite float. Constant " + "mode uses FP8/FP4 encodings for A/B and neutral E8M0 scales.", + ) args = parser.parse_args() + if args.verify_graph: + _run_graph_verify(args) + if not args.benchmark: + sys.exit(0) if args.benchmark: _run_benchmark(args) else: From f65745534c176ebf9fb475a2f5d0855296ea6afe Mon Sep 17 00:00:00 2001 From: aoli26 Date: Fri, 29 May 2026 03:37:28 +0000 Subject: [PATCH 08/19] Optimize gfx1250 FP8 GEMM schedule --- kernels/gemm_fp8fp4_gfx1250.py | 71 +++++++++++++++++++++++++++------- 1 file changed, 56 insertions(+), 15 deletions(-) diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index feadc021c..34653472f 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -1327,6 +1327,7 @@ def compute_tile_fp8_deep_pipeline( emit_filler=None, mid_compute_callback=None, late_compute_callback=None, + a0_prefetch=None, ): current_accs = list(accs_in) a_buf, a_bases = _precompute_a_lane_bases(lds_a) @@ -1400,7 +1401,6 @@ def emit_panel_2x2_row(wm_pair, wn_pair, row_local, a_pair, b_pair, scale_pair): _pair_loads = _fp8_pair_a_loads _two_pair_loads = _fp8_pair_a_loads + _fp8_pair_b_loads - _three_pair_loads = _fp8_pair_b_loads + _fp8_pair_a_loads * 2 for ks in range_constexpr(k_wmma_steps): is_last_ks = ks == k_wmma_steps - 1 @@ -1409,7 +1409,10 @@ def emit_panel_2x2_row(wm_pair, wn_pair, row_local, a_pair, b_pair, scale_pair): scale_pair = (a_scales, b_scales) b0 = load_b_pair(0, ks) - a0 = load_a_pair(0, ks) + if const_expr(ks == 0 and a0_prefetch is not None): + a0 = [a0_prefetch[0], load_a_frag(a_buf, a_bases[1], ks)] + else: + a0 = load_a_pair(0, ks) b1 = load_b_pair(1, ks) b2 = load_b_pair(2, ks) @@ -1421,29 +1424,34 @@ def emit_panel_2x2_row(wm_pair, wn_pair, row_local, a_pair, b_pair, scale_pair): def _prefetch_a1(): a1_box[0] = load_a_pair(1, ks) - rocdl.s_wait_dscnt(_two_pair_loads + 3) + first_wait_keep = _two_pair_loads + 3 + if const_expr(ks == 0 and a0_prefetch is not None): + first_wait_keep += DS_LOADS_PER_A_FRAG + rocdl.s_wait_dscnt(first_wait_keep) emit_panel_2x2(0, 0, a0, b0, scale_pair, prefetch_after_first_row=_prefetch_a1) if const_expr(ks == 0 and mid_compute_callback is not None): rocdl.sched_barrier(0) mid_compute_callback() - rocdl.s_wait_dscnt(_pair_loads + _fp8_pair_b_loads) - emit_panel_2x2(0, 1, a0, b1, scale_pair) - def _prefetch_b3(): b3_box[0] = load_b_pair(3, ks) + def _prefetch_a3(): + a3_box[0] = load_a_pair(3, ks) + + rocdl.s_wait_dscnt(_pair_loads + _fp8_pair_b_loads) + emit_panel_2x2(0, 1, a0, b1, scale_pair, prefetch_after_first_row=_prefetch_b3) + rocdl.s_wait_dscnt(_fp8_pair_b_loads + 2) - emit_panel_2x2(1, 0, a1_box[0], b0, scale_pair, prefetch_after_first_row=_prefetch_b3) + emit_panel_2x2(1, 0, a1_box[0], b0, scale_pair, prefetch_after_first_row=_prefetch_a3) - def _prefetch_a2_a3(): + def _prefetch_a2(): a2_box[0] = load_a_pair(2, ks) - a3_box[0] = load_a_pair(3, ks) - emit_panel_2x2(1, 1, a1_box[0], b1, scale_pair, prefetch_after_first_row=_prefetch_a2_a3) + emit_panel_2x2(1, 1, a1_box[0], b1, scale_pair) - emit_panel_2x2(0, 2, a0, b2, scale_pair) + emit_panel_2x2(0, 2, a0, b2, scale_pair, prefetch_after_first_row=_prefetch_a2) emit_panel_2x2_row(1, 2, 0, a1_box[0], b2, scale_pair) rocdl.s_wait_dscnt(_pair_loads) emit_panel_2x2(0, 3, a0, b3_box[0], scale_pair) @@ -1599,17 +1607,24 @@ def _sched_panel_2x2(prefetch_loads=0): else: rocdl.sched_mfma(_fp8_pair_wm * _fp8_pair_wn) + def _sched_panel_row(): + rocdl.sched_mfma(_fp8_pair_wn) + _initial_loads = _fp8_scale_loads + _fp8_pair_b_loads * 3 + _fp8_pair_a_loads for _ks in range_constexpr(k_wmma_steps): - rocdl.sched_dsrd(_initial_loads) + _ks_initial_loads = _initial_loads + if const_expr(_ks == 0): + _ks_initial_loads -= DS_LOADS_PER_A_FRAG + rocdl.sched_dsrd(_ks_initial_loads) _sched_panel_2x2(_fp8_pair_a_loads) - _sched_panel_2x2() _sched_panel_2x2(_fp8_pair_b_loads) - _sched_panel_2x2(_fp8_pair_a_loads * 2) - _sched_panel_2x2() + _sched_panel_2x2(_fp8_pair_a_loads) _sched_panel_2x2() + _sched_panel_2x2(_fp8_pair_a_loads) + _sched_panel_row() _sched_panel_2x2() + _sched_panel_row() _sched_panel_2x2() _sched_panel_2x2() _sched_panel_2x2() @@ -1630,6 +1645,7 @@ def compute_tile_scheduled( emit_filler=None, mid_compute_callback=None, late_compute_callback=None, + a0_prefetch=None, ): if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_B_STREAMING): return compute_tile_b_streaming( @@ -1672,6 +1688,7 @@ def compute_tile_scheduled( emit_filler=emit_filler, mid_compute_callback=mid_compute_callback, late_compute_callback=late_compute_callback, + a0_prefetch=a0_prefetch, ) return compute_tile( accs_in, @@ -1717,6 +1734,16 @@ def hot_loop_scheduler_scheduled(): else: hot_loop_scheduler() + def prefetch_fp8_deep_a0_frags(lds_a): + a_buf, a_bases = _precompute_a_lane_bases(lds_a) + return [load_a_frag(a_buf, a_bases[0], 0)] + + def maybe_prefetch_fp8_deep_a0(lds_a): + # Call only after the TDM fence for this stage; pre-fence LDS reads can race multicast delivery. + if const_expr(use_fp8_deep_pipeline_schedule): + return prefetch_fp8_deep_a0_frags(lds_a) + return None + # ── Epilogue (unified via _sub_tiles) ── def _get_acc_sub8(accs, acc_idx, vec_base): """Extract 8-element sub-vector from accumulator.""" @@ -2179,6 +2206,7 @@ def _late_tdm_ws_split_signal(): _late_tdm_ws_fence_signal = _late_tdm_ws_split_signal + a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[buf_idx]) rocdl.sched_barrier(0) accs_in = compute_tile_scheduled( accs_in, @@ -2188,6 +2216,7 @@ def _late_tdm_ws_split_signal(): stages_bs_idx[buf_idx], mid_compute_callback=_mid_tdm_ws, late_compute_callback=_late_tdm_ws_fence_signal, + a0_prefetch=a0_prefetch, ) cur_addr_lo = addr_box[0] hot_loop_scheduler_scheduled() @@ -2226,6 +2255,7 @@ def _mid_tdm_split_scale_dma( ): _issue_active_tdm(_ls, _ab, scale_k_box=_scale_k, k_prefetch=_k_off) + a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[buf_idx]) rocdl.sched_barrier(0) accs_in = compute_tile_scheduled( accs_in, @@ -2234,6 +2264,7 @@ def _mid_tdm_split_scale_dma( stages_as_idx[buf_idx], stages_bs_idx[buf_idx], mid_compute_callback=_mid_tdm_split_scale_dma, + a0_prefetch=a0_prefetch, ) cur_addr_lo = addr_box[0] cur_scale_k = scale_k_box[0] @@ -2280,6 +2311,7 @@ def _mid_tdm_scale_dma( _scale_k[0] = _scale_k[0] + arith.index(tile_k) _l2_prefetch(_k_off) + a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[buf_idx]) rocdl.sched_barrier(0) accs_in = compute_tile_scheduled( accs_in, @@ -2288,6 +2320,7 @@ def _mid_tdm_scale_dma( stages_as_idx[buf_idx], stages_bs_idx[buf_idx], mid_compute_callback=_mid_tdm_scale_dma, + a0_prefetch=a0_prefetch, ) cur_lo_a = addr_boxes[0][0] cur_lo_b = addr_boxes[1][0] @@ -2344,6 +2377,7 @@ def _mid_tdm_nws( _ab[3][0] = _ab[3][0] + adv_bs_i32 _l2_prefetch(_k_off) + a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[buf_idx]) rocdl.sched_barrier(0) accs_in = compute_tile_scheduled( accs_in, @@ -2352,6 +2386,7 @@ def _mid_tdm_nws( stages_as_idx[buf_idx], stages_bs_idx[buf_idx], mid_compute_callback=_mid_tdm_nws, + a0_prefetch=a0_prefetch, ) cur_lo_a = addr_boxes[0][0] cur_lo_b = addr_boxes[1][0] @@ -2383,18 +2418,21 @@ def _mid_tdm_nws( _wait_scale_all() _pipeline_fence(outstanding=0) if const_expr(use_tdm_store): + a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[_compute_stage]) accs = compute_tile_scheduled( accs, stages_a_idx[_compute_stage], stages_b_idx[_compute_stage], stages_as_idx[_compute_stage], stages_bs_idx[_compute_stage], + a0_prefetch=a0_prefetch, ) else: def _emit_epi_addrs(): epi_addrs_box[0] = epilogue_prepare_addrs() + a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[_compute_stage]) accs = compute_tile_scheduled( accs, stages_a_idx[_compute_stage], @@ -2402,6 +2440,7 @@ def _emit_epi_addrs(): stages_as_idx[_compute_stage], stages_bs_idx[_compute_stage], emit_filler=_emit_epi_addrs, + a0_prefetch=a0_prefetch, ) else: _wait_scale_all() @@ -2460,6 +2499,7 @@ def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): _tail_mid_cb = _tail_mid_nws + a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[_compute_stage]) rocdl.sched_barrier(0) accs = compute_tile_scheduled( accs, @@ -2468,6 +2508,7 @@ def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): stages_as_idx[_compute_stage], stages_bs_idx[_compute_stage], mid_compute_callback=_tail_mid_cb, + a0_prefetch=a0_prefetch, ) if const_expr(_load_stage is not None): From 3a7684d947ccf2128d28c169ffacb16c4fa27675 Mon Sep 17 00:00:00 2001 From: aoli26 Date: Sun, 31 May 2026 03:33:17 +0000 Subject: [PATCH 09/19] Tune FP8 deep pipeline scheduling --- kernels/gemm_fp8fp4_gfx1250.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index 34653472f..6f995eff1 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -470,7 +470,12 @@ def kernel_mxscale_gemm( a_mcast_mask = 0 b_mcast_mask = 0 - layout_thr = fx.make_layout((m_warp, n_warp, 2, 16), (n_warp * WAVE_SIZE, WAVE_SIZE, 16, 1)) + # The FP8 deep pipeline runs cleaner when adjacent wave ids advance M + # first; keep the default mapping for the other schedules. + if const_expr(use_fp8_deep_pipeline_schedule): + layout_thr = fx.make_layout((m_warp, n_warp, 2, 16), (WAVE_SIZE, m_warp * WAVE_SIZE, 16, 1)) + else: + layout_thr = fx.make_layout((m_warp, n_warp, 2, 16), (n_warp * WAVE_SIZE, WAVE_SIZE, 16, 1)) thr_coord = idx2crd(tx, layout_thr) wave_m_idx, wave_n_idx, lane_kgrp, lane16 = ( fx.get(thr_coord, 0), @@ -1414,7 +1419,6 @@ def emit_panel_2x2_row(wm_pair, wn_pair, row_local, a_pair, b_pair, scale_pair): else: a0 = load_a_pair(0, ks) b1 = load_b_pair(1, ks) - b2 = load_b_pair(2, ks) a1_box = [None] b3_box = [None] @@ -1429,6 +1433,7 @@ def _prefetch_a1(): first_wait_keep += DS_LOADS_PER_A_FRAG rocdl.s_wait_dscnt(first_wait_keep) emit_panel_2x2(0, 0, a0, b0, scale_pair, prefetch_after_first_row=_prefetch_a1) + b2 = load_b_pair(2, ks) if const_expr(ks == 0 and mid_compute_callback is not None): rocdl.sched_barrier(0) @@ -1610,7 +1615,7 @@ def _sched_panel_2x2(prefetch_loads=0): def _sched_panel_row(): rocdl.sched_mfma(_fp8_pair_wn) - _initial_loads = _fp8_scale_loads + _fp8_pair_b_loads * 3 + _fp8_pair_a_loads + _initial_loads = _fp8_scale_loads + _fp8_pair_b_loads * 2 + _fp8_pair_a_loads for _ks in range_constexpr(k_wmma_steps): _ks_initial_loads = _initial_loads @@ -1618,6 +1623,7 @@ def _sched_panel_row(): _ks_initial_loads -= DS_LOADS_PER_A_FRAG rocdl.sched_dsrd(_ks_initial_loads) _sched_panel_2x2(_fp8_pair_a_loads) + rocdl.sched_dsrd(_fp8_pair_b_loads) _sched_panel_2x2(_fp8_pair_b_loads) _sched_panel_2x2(_fp8_pair_a_loads) _sched_panel_2x2() @@ -1904,9 +1910,17 @@ def _l2_prefetch(k_base): warp_lds_off + lane16 * arith.index(_lds_d_stride_elems) + lane_kgrp * arith.index(4 * elem_bytes_d) ) wave_id_idx = arith.index_cast(T.index, rocdl.wave_id()) - d_warp_off_sgpr = wave_id_idx * arith.index(warp_d_bytes) + arith.index(d_output_off) - warp_m_off_sgpr = (wave_id_idx / arith.index(n_warp)) * arith.index(warp_tile_m) - warp_n_off_sgpr = (wave_id_idx % arith.index(n_warp)) * arith.index(warp_tile_n) + # Match the TDM-store descriptor offsets to the compute wave mapping. + if const_expr(use_fp8_deep_pipeline_schedule): + wave_m_sgpr = wave_id_idx % arith.index(m_warp) + wave_n_sgpr = wave_id_idx / arith.index(m_warp) + else: + wave_m_sgpr = wave_id_idx / arith.index(n_warp) + wave_n_sgpr = wave_id_idx % arith.index(n_warp) + d_warp_linear_sgpr = wave_m_sgpr * arith.index(n_warp) + wave_n_sgpr + d_warp_off_sgpr = d_warp_linear_sgpr * arith.index(warp_d_bytes) + arith.index(d_output_off) + warp_m_off_sgpr = wave_m_sgpr * arith.index(warp_tile_m) + warp_n_off_sgpr = wave_n_sgpr * arith.index(warp_tile_n) d_desc = tdm_ops.make_tensor_descriptor_2d( global_ptr=arg_c, lds_memref=d_lds_base_ptr, From 4cf9dced1f25efb29882d90d6eab4e9d5215ca47 Mon Sep 17 00:00:00 2001 From: aoli26 Date: Sun, 31 May 2026 15:34:05 +0000 Subject: [PATCH 10/19] gfx1250 FP8 GEMM LDS segment layout --- kernels/gemm_fp8fp4_gfx1250.py | 112 ++++++++++++++++++++++++--------- 1 file changed, 83 insertions(+), 29 deletions(-) diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index 6f995eff1..83cfda72b 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -37,6 +37,8 @@ LDS_PAD_A_BYTES = 16 LDS_PAD_D_BYTES = 16 +LDS_SEGMENT_BYTES = 64 * 1024 +LDS_GFX1250_MAX_BYTES = 5 * LDS_SEGMENT_BYTES @functools.lru_cache(maxsize=256) @@ -240,9 +242,9 @@ def _align_up(value: int, align: int) -> int: # active loader wave must issue a full-tile descriptor by itself. tdm_desc_num_warps = 1 if wave_specialized_tdm else num_warps - # All pipeline stages share the same intra-stage layout. Keep that layout - # unchanged and only remap each logical stage to a physical base inside one - # LDS arena so TDM epilogue can alias the dead prefix of the arena. + # All pipeline stages share the same intra-stage layout in the generic + # arena path. The active gfx1250 FP8 TDM tile uses a separate reference + # pool layout below. stage_layout = SmemAllocator(None, arch=gpu_arch, global_sym_name=f"mxscale_{data_format}_layout") stage_a_data_rel_off = stage_layout._align(stage_layout.ptr, 16) stage_layout.ptr = stage_a_data_rel_off + lds_a_data_bytes @@ -271,24 +273,73 @@ def _align_up(value: int, align: int) -> int: ), ) - stage_phys_order = [i for i in range(num_buffers) if i != _last_compute_stage] - stage_phys_order.append(_last_compute_stage) - stage_base_off = [0] * num_buffers - for phys_i, logical_i in enumerate(stage_phys_order): - stage_base_off[logical_i] = phys_i * stage_pitch_bytes - arena_alloc.ptr = stage_pitch_bytes * num_buffers - arena_total_bytes = arena_alloc.ptr - epilogue_fence_threshold_bytes = tdm_epilogue_fence_threshold_bytes( - stage_base_off=stage_base_off, - tail_plan=_base_tail_plan, - loop_iters=loop_iters, - extra=extra, + use_ref_segmented_lds_layout = ( + data_format == "fp8" + and tile_m == 256 + and tile_n == 256 + and tile_k == 128 + and m_warp == 2 + and n_warp == 2 + and num_buffers == 4 + and split_k == 1 + and wave_specialized_tdm + and not use_scale_buffer_load + and not use_scale_opsel ) - stage_a_data_off = [stage_base_off[i] + stage_a_data_rel_off for i in range(num_buffers)] - stage_b_data_off = [stage_base_off[i] + stage_b_data_rel_off for i in range(num_buffers)] - stage_a_scale_off = [stage_base_off[i] + stage_a_scale_rel_off for i in range(num_buffers)] - stage_b_scale_off = [stage_base_off[i] + stage_b_scale_rel_off for i in range(num_buffers)] + if use_ref_segmented_lds_layout: + # The A/B data pools are no longer packed into the same per-stage + # 64KiB segment window. Scale pools keep the reference 0x800 stride so + # every TDM LDS target remains 2KiB-aligned. + ref_a_stage_stride = 0x9000 + ref_b_stage_stride = 0x8000 + ref_scale_stage_stride = 0x800 + if lds_a_data_bytes > ref_a_stage_stride: + raise RuntimeError( + "reference segmented LDS layout requires A stage <= 0x9000 bytes, " + f"got {lds_a_data_bytes}" + ) + if lds_b_data_bytes > ref_b_stage_stride: + raise RuntimeError( + "reference segmented LDS layout requires B stage <= 0x8000 bytes, " + f"got {lds_b_data_bytes}" + ) + if lds_a_scale_bytes > ref_scale_stage_stride or lds_b_scale_bytes > ref_scale_stage_stride: + raise RuntimeError( + "reference segmented LDS layout requires scale stage <= 0x800 bytes, " + f"got A={lds_a_scale_bytes} B={lds_b_scale_bytes}" + ) + + stage_a_data_off = [0x00000, 0x09000, 0x16000, 0x1F000] + stage_a_scale_off = [0x12000 + i * ref_scale_stage_stride for i in range(num_buffers)] + stage_b_scale_off = [0x28000 + i * ref_scale_stage_stride for i in range(num_buffers)] + stage_b_data_off = [0x30000 + i * ref_b_stage_stride for i in range(num_buffers)] + arena_alloc.ptr = LDS_GFX1250_MAX_BYTES + arena_total_bytes = arena_alloc.ptr + + # The epilogue may reuse the prefix only after all main/tail TDM traffic + # is fully fenced. This is outside the hot loop and avoids assuming a + # single monotonic per-stage base for the segmented pool layout. + epilogue_fence_threshold_bytes = 0 + else: + stage_phys_order = [i for i in range(num_buffers) if i != _last_compute_stage] + stage_phys_order.append(_last_compute_stage) + stage_base_off = [0] * num_buffers + for phys_i, logical_i in enumerate(stage_phys_order): + stage_base_off[logical_i] = phys_i * stage_pitch_bytes + arena_alloc.ptr = stage_pitch_bytes * num_buffers + arena_total_bytes = arena_alloc.ptr + epilogue_fence_threshold_bytes = tdm_epilogue_fence_threshold_bytes( + stage_base_off=stage_base_off, + tail_plan=_base_tail_plan, + loop_iters=loop_iters, + extra=extra, + ) + + stage_a_data_off = [stage_base_off[i] + stage_a_data_rel_off for i in range(num_buffers)] + stage_b_data_off = [stage_base_off[i] + stage_b_data_rel_off for i in range(num_buffers)] + stage_a_scale_off = [stage_base_off[i] + stage_a_scale_rel_off for i in range(num_buffers)] + stage_b_scale_off = [stage_base_off[i] + stage_b_scale_rel_off for i in range(num_buffers)] if use_tdm_store: lds_d_row_stride = warp_tile_n * elem_bytes_d + LDS_PAD_D_BYTES @@ -397,7 +448,6 @@ def _pick_compute_schedule_kind(): and num_buffers == 4 and use_cluster ) - if use_b_streaming_schedule: print( f"[b_streaming] {data_format} tile=({tile_m},{tile_n},{tile_k}) " f"M_r={wmma_m_rep} N_r={wmma_n_rep}", @@ -1414,11 +1464,14 @@ def emit_panel_2x2_row(wm_pair, wn_pair, row_local, a_pair, b_pair, scale_pair): scale_pair = (a_scales, b_scales) b0 = load_b_pair(0, ks) - if const_expr(ks == 0 and a0_prefetch is not None): + if const_expr(ks == 0 and a0_prefetch is not None and len(a0_prefetch) == _fp8_pair_wm): + a0 = list(a0_prefetch) + elif const_expr(ks == 0 and a0_prefetch is not None): a0 = [a0_prefetch[0], load_a_frag(a_buf, a_bases[1], ks)] else: a0 = load_a_pair(0, ks) b1 = load_b_pair(1, ks) + b2 = load_b_pair(2, ks) a1_box = [None] b3_box = [None] @@ -1430,10 +1483,9 @@ def _prefetch_a1(): first_wait_keep = _two_pair_loads + 3 if const_expr(ks == 0 and a0_prefetch is not None): - first_wait_keep += DS_LOADS_PER_A_FRAG + first_wait_keep += DS_LOADS_PER_A_FRAG * len(a0_prefetch) rocdl.s_wait_dscnt(first_wait_keep) emit_panel_2x2(0, 0, a0, b0, scale_pair, prefetch_after_first_row=_prefetch_a1) - b2 = load_b_pair(2, ks) if const_expr(ks == 0 and mid_compute_callback is not None): rocdl.sched_barrier(0) @@ -1458,9 +1510,9 @@ def _prefetch_a2(): emit_panel_2x2(0, 2, a0, b2, scale_pair, prefetch_after_first_row=_prefetch_a2) emit_panel_2x2_row(1, 2, 0, a1_box[0], b2, scale_pair) + emit_panel_2x2_row(1, 2, 1, a1_box[0], b2, scale_pair) rocdl.s_wait_dscnt(_pair_loads) emit_panel_2x2(0, 3, a0, b3_box[0], scale_pair) - emit_panel_2x2_row(1, 2, 1, a1_box[0], b2, scale_pair) emit_panel_2x2(1, 3, a1_box[0], b3_box[0], scale_pair) emit_panel_2x2(2, 0, a2_box[0], b0, scale_pair) @@ -1615,21 +1667,19 @@ def _sched_panel_2x2(prefetch_loads=0): def _sched_panel_row(): rocdl.sched_mfma(_fp8_pair_wn) - _initial_loads = _fp8_scale_loads + _fp8_pair_b_loads * 2 + _fp8_pair_a_loads + _initial_loads = _fp8_scale_loads + _fp8_pair_b_loads * 3 + _fp8_pair_a_loads for _ks in range_constexpr(k_wmma_steps): _ks_initial_loads = _initial_loads if const_expr(_ks == 0): - _ks_initial_loads -= DS_LOADS_PER_A_FRAG + _ks_initial_loads -= _fp8_pair_a_loads rocdl.sched_dsrd(_ks_initial_loads) _sched_panel_2x2(_fp8_pair_a_loads) - rocdl.sched_dsrd(_fp8_pair_b_loads) _sched_panel_2x2(_fp8_pair_b_loads) _sched_panel_2x2(_fp8_pair_a_loads) _sched_panel_2x2() _sched_panel_2x2(_fp8_pair_a_loads) _sched_panel_row() - _sched_panel_2x2() _sched_panel_row() _sched_panel_2x2() _sched_panel_2x2() @@ -1640,6 +1690,7 @@ def _sched_panel_row(): _sched_panel_2x2() _sched_panel_2x2() _sched_panel_2x2() + _sched_panel_2x2() rocdl.sched_barrier(0) def compute_tile_scheduled( @@ -1742,7 +1793,10 @@ def hot_loop_scheduler_scheduled(): def prefetch_fp8_deep_a0_frags(lds_a): a_buf, a_bases = _precompute_a_lane_bases(lds_a) - return [load_a_frag(a_buf, a_bases[0], 0)] + return [ + load_a_frag(a_buf, a_bases[wm_local], 0) + for wm_local in range_constexpr(_fp8_pair_wm) + ] def maybe_prefetch_fp8_deep_a0(lds_a): # Call only after the TDM fence for this stage; pre-fence LDS reads can race multicast delivery. From a7e2a075ed99a23100e97e340e8d29f882533175 Mon Sep 17 00:00:00 2001 From: aoli26 Date: Mon, 1 Jun 2026 05:39:48 +0000 Subject: [PATCH 11/19] Add buffer_load->VGPR scale path for gfx1250 FP8 GEMM Replace the buffer_lds_stage scale paths with vgpr/vgpr_ab_split, loading scale global->VGPR via buffer_load (off the LDS/TDM/barrier path) with a coalesced lane-major host layout. Add opt-in overlay-chunks and ab-half-fence schedules, and trim verbose comments. --- kernels/gemm_common_gfx1250.py | 27 + kernels/gemm_fp8fp4_gfx1250.py | 898 +++++++++++++--------- tests/kernels/test_gemm_fp8fp4_gfx1250.py | 74 +- 3 files changed, 596 insertions(+), 403 deletions(-) diff --git a/kernels/gemm_common_gfx1250.py b/kernels/gemm_common_gfx1250.py index b269192d3..3d4227828 100644 --- a/kernels/gemm_common_gfx1250.py +++ b/kernels/gemm_common_gfx1250.py @@ -154,6 +154,33 @@ def pipeline_fence_wait(use_cluster=False): cluster.cluster_wait() +def pipeline_fence_partial(outstanding=0, first_half=True, n_loader_waves=4, use_cluster=False): + """Half fence for the A/B half-split schedule. + + Only the loader waves owning the targeted half tensor_wait (wave0->A0, + wave1->B0, wave2->A1, wave3->B1), then *all* waves run a full barrier: + first_half=True fences {A0,B0}, first_half=False fences {A1,B1}. + + SAFETY: the cluster barrier is never weakened -- each call is a full + ``cluster_barrier``. It fences only the halves tensor_waited before its + signal; the un-waited half stays in flight until the matching second call. + """ + half = n_loader_waves // 2 + wid = rocdl.wave_id() + if first_half: + cond = arith.cmpi(arith.CmpIPredicate.ult, wid, arith.constant(half, type=T.i32)) + else: + cond = arith.cmpi(arith.CmpIPredicate.uge, wid, arith.constant(half, type=T.i32)) + if_op = scf.IfOp(cond) + with ir.InsertionPoint(if_op.then_block): + tdm_ops.tensor_wait(outstanding) + scf.YieldOp([]) + if use_cluster: + cluster.cluster_barrier() + else: + gpu.barrier() + + def issue_tdm_loads(*descs, wave_specialized=False, wave_id=None): """Emit one or more TDM loads, optionally one descriptor per loader wave.""" if wave_specialized: diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index 83cfda72b..933f523da 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -6,6 +6,7 @@ """ import functools +import os import flydsl.compiler as flyc import flydsl.expr as fx @@ -22,6 +23,7 @@ issue_tdm_loads, lds_load_b128_raw, pipeline_fence, + pipeline_fence_partial, pipeline_fence_signal, pipeline_fence_wait, store_acc_vec8_to_buffer, @@ -93,7 +95,10 @@ def compile_mxscale_gemm( if out_dtype not in ("f32", "bf16", "f16"): raise ValueError(f"out_dtype must be 'f32', 'bf16', or 'f16', got {out_dtype!r}") elem_bytes_d = 2 if out_dtype in ("bf16", "f16") else 4 - scale_load_paths = ("tdm", "buffer_lds_stage", "buffer_lds_stage_ab_split") + # scale_load_path: "tdm" = TDM->LDS (default); "vgpr" = buffer_load->VGPR, + # off the LDS/TDM/barrier path; "vgpr_ab_split" = "vgpr" plus repurposing the + # idle scale waves 2,3 to load the second A/B halves. + scale_load_paths = ("tdm", "vgpr", "vgpr_ab_split") if scale_load_path not in scale_load_paths: raise ValueError(f"scale_load_path must be one of {scale_load_paths}, got {scale_load_path!r}") fp8_schedule_modes = ("auto", "quadrant", "deep-pipeline") @@ -103,9 +108,7 @@ def compile_mxscale_gemm( raise ValueError(f"fp8_schedule={fp8_schedule!r} is only valid for data_format='fp8'") if fp8_schedule != "auto" and b_streaming: raise ValueError("fp8_schedule cannot be combined with b_streaming=True") - use_scale_buffer_load = scale_load_path != "tdm" - use_ab_split_scale_buffer_load = scale_load_path == "buffer_lds_stage_ab_split" - effective_expert_sched_mode = bool(expert_sched_mode) and not use_scale_buffer_load + effective_expert_sched_mode = bool(expert_sched_mode) if num_buffers not in (2, 3, 4): raise ValueError(f"num_buffers must be 2, 3, or 4, got {num_buffers}") @@ -125,8 +128,6 @@ def compile_mxscale_gemm( if wave_specialized_tdm and num_warps < 4: raise ValueError(f"wave_specialized_tdm requires at least 4 waves, got {num_warps}") - if use_ab_split_scale_buffer_load and not wave_specialized_tdm: - raise ValueError("scale_load_path='buffer_lds_stage_ab_split' requires wave_specialized_tdm=True") # ── Format-dependent compile-time constants ── # A8W4: activation is FP8 (PACK_FACTOR_A=1), weight is FP4 (PACK_FACTOR_B=2) @@ -203,11 +204,11 @@ def compile_mxscale_gemm( _as_ds_loads = wmma_m_rep * _a_frag_loads_per_wm + _scale_ds_loads lds_a_stride_bytes = packed_tile_k_a + LDS_PAD_A_BYTES - if use_ab_split_scale_buffer_load: + if scale_load_path == "vgpr_ab_split": if tile_m % 2 != 0: - raise ValueError(f"buffer_lds_stage_ab_split requires even tile_m, got {tile_m}") + raise ValueError(f"scale_load_path='vgpr_ab_split' requires even tile_m, got {tile_m}") if tile_n % 32 != 0: - raise ValueError(f"buffer_lds_stage_ab_split requires tile_n divisible by 32, got {tile_n}") + raise ValueError(f"scale_load_path='vgpr_ab_split' requires tile_n divisible by 32, got {tile_n}") lds_a_data_bytes = tile_m * lds_a_stride_bytes lds_b_data_bytes = tile_n * packed_tile_k_b @@ -218,18 +219,6 @@ def compile_mxscale_gemm( lds_b_scale_bytes = tile_n * scale_k_per_tile + _scale_guard_bytes interleaved_scale_cols_a = wmma_m_rep * scale_k_per_tile interleaved_scale_cols_b = b_scale_load_rep * scale_k_per_tile - _scale_dma_bytes = 16 - if use_scale_buffer_load: - if interleaved_scale_cols_a % _scale_dma_bytes != 0: - raise ValueError( - "buffer_lds_stage scale loads require A scale rows to be 16-byte aligned, " - f"got interleaved_scale_cols_a={interleaved_scale_cols_a}" - ) - if interleaved_scale_cols_b % _scale_dma_bytes != 0: - raise ValueError( - "buffer_lds_stage scale loads require B scale rows to be 16-byte aligned, " - f"got interleaved_scale_cols_b={interleaved_scale_cols_b}" - ) def _align_up(value: int, align: int) -> int: if value % align == 0: @@ -283,10 +272,32 @@ def _align_up(value: int, align: int) -> int: and num_buffers == 4 and split_k == 1 and wave_specialized_tdm - and not use_scale_buffer_load and not use_scale_opsel ) + # "vgpr"/"vgpr_ab_split": load scale global->VGPR via buffer_load, bypassing + # TDM+LDS entirely. Requires the reference segmented LDS layout. + use_buffer_vgpr_scale = scale_load_path in ("vgpr", "vgpr_ab_split") + if use_buffer_vgpr_scale and not use_ref_segmented_lds_layout: + raise ValueError( + f"scale_load_path={scale_load_path!r} requires the reference segmented " + "LDS layout (not active for this tile/format configuration)" + ) + # Scale prefetch depth (K-tiles ahead) for the buffer->VGPR path. D=1 is the + # sweet spot; D=2 doubles scale VGPRs -> spill + ~18% regression. + _bvs_D = int(os.environ.get("FLYDSL_BUFFER_VGPR_SCALE_DEPTH", "1")) + # overlay-chunks: issue one fragment load after each WMMA (vs batch-after-row) + # for finer LDS-load pipelining. OPT-IN: FLYDSL_OVERLAY_CHUNKS=1. + use_overlay_chunks = use_ref_segmented_lds_layout and (os.environ.get("FLYDSL_OVERLAY_CHUNKS", "0") == "1") + # ab_half_split: repurpose the (under "vgpr") idle scale waves 2,3 as the + # second halves of A/B, so all 4 waves share the A/B TDM (wave0=A0, wave1=B0, + # wave2=A1, wave3=B1). Measured wall-neutral. + use_ab_half_split = scale_load_path == "vgpr_ab_split" + # ab_half_fence: fence only {A0,B0}, compute the (m0-3,n0-3) quadrant while + # {A1,B1} TDM is still in flight, then fence {A1,B1} and compute the rest. + # Requires the half-split. OPT-IN: FLYDSL_AB_HALF_FENCE=1. + use_ab_half_fence = use_ab_half_split and (os.environ.get("FLYDSL_AB_HALF_FENCE", "0") == "1") + if use_ref_segmented_lds_layout: # The A/B data pools are no longer packed into the same per-stage # 64KiB segment window. Scale pools keep the reference 0x800 stride so @@ -296,13 +307,11 @@ def _align_up(value: int, align: int) -> int: ref_scale_stage_stride = 0x800 if lds_a_data_bytes > ref_a_stage_stride: raise RuntimeError( - "reference segmented LDS layout requires A stage <= 0x9000 bytes, " - f"got {lds_a_data_bytes}" + "reference segmented LDS layout requires A stage <= 0x9000 bytes, " f"got {lds_a_data_bytes}" ) if lds_b_data_bytes > ref_b_stage_stride: raise RuntimeError( - "reference segmented LDS layout requires B stage <= 0x8000 bytes, " - f"got {lds_b_data_bytes}" + "reference segmented LDS layout requires B stage <= 0x8000 bytes, " f"got {lds_b_data_bytes}" ) if lds_a_scale_bytes > ref_scale_stage_stride or lds_b_scale_bytes > ref_scale_stage_stride: raise RuntimeError( @@ -355,23 +364,12 @@ def _align_up(value: int, align: int) -> int: arena_alloc.ptr = total_d_bytes check_smem_capacity(arena_total_bytes, gpu_arch) - # TENSORcnt is tracked per-wave in hardware. When scale is loaded through - # buffer_load_lds, TDM only carries A/B data. + # TENSORcnt is tracked per-wave in hardware. Wave-specialized TDM issues one + # tensor_load per wave per step; otherwise all 4 (A/B/A_scale/B_scale). if wave_specialized_tdm: TDM_LOADS_PER_STEP = 1 else: - TDM_LOADS_PER_STEP = 2 if use_scale_buffer_load else 4 - if use_scale_buffer_load: - _scale_async_batch_bytes = block_threads * _scale_dma_bytes - _scale_async_ops_per_stage = ( - (tile_m * scale_k_per_tile + _scale_async_batch_bytes - 1) // _scale_async_batch_bytes - + (tile_n * scale_k_per_tile + _scale_async_batch_bytes - 1) // _scale_async_batch_bytes - ) - # Wait only for the scale stage that is about to be consumed; keep - # later buffered scale async copies in flight. - _scale_async_future_wait_count = _scale_async_ops_per_stage * max(num_buffers - 2, 0) - else: - _scale_async_future_wait_count = 0 + TDM_LOADS_PER_STEP = 4 tail_plan = [(ls, cs, o * TDM_LOADS_PER_STEP // 2 if o > 0 else o) for ls, cs, o in _base_tail_plan] # Pre-compute epilogue sub-tile layout (unified for FP4 vec16 and FP8 vec8) @@ -408,7 +406,6 @@ def _align_up(value: int, align: int) -> int: and n_warp == 2 and num_buffers == 4 and wave_specialized_tdm - and scale_load_path == "tdm" and out_dtype == "bf16" and not use_scale_opsel ) @@ -416,7 +413,7 @@ def _align_up(value: int, align: int) -> int: raise ValueError( "fp8_schedule='deep-pipeline' requires fp8 256x256x128, " "m_warp=n_warp=2, num_buffers=4, wave_specialized_tdm=True, " - "scale_load_path='tdm', out_dtype='bf16', and use_scale_opsel=False" + "out_dtype='bf16', and use_scale_opsel=False" ) def _pick_compute_schedule_kind(): @@ -443,10 +440,10 @@ def _pick_compute_schedule_kind(): use_b_streaming_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_B_STREAMING use_ws_tdm_split_signal_overlap = ( wave_specialized_tdm - and not use_scale_buffer_load and (use_fp8_quadrant_schedule or use_fp8_deep_pipeline_schedule) and num_buffers == 4 and use_cluster + and not use_ab_half_fence ) if use_b_streaming_schedule: print( @@ -537,6 +534,35 @@ def kernel_mxscale_gemm( warp_m_base = wave_m_idx * arith.index(warp_tile_m) warp_n_base = wave_n_idx * arith.index(warp_tile_n) + if const_expr(use_buffer_vgpr_scale): + # Direct global->VGPR scale load (no TDM/LDS). Coalesced lane-major + # host layout [M_block(128), K_tile, group(2), lane16(16), 4 i32], so + # each buffer_load_b128's 16 lanes read 256 contiguous bytes: + # i32_off(group) = (mb*Kt + kt)*128 + group*64 + lane16*4 + _bvs_a_rsrc = buffer_ops.create_buffer_resource(arg_a_scale, max_size=False) + _bvs_b_rsrc = buffer_ops.create_buffer_resource(arg_b_scale, max_size=False) + _bvs_Kt = K // tile_k # total K-tiles + _bvs_mb_a = blk_m / arith.index(128) + wave_m_idx + _bvs_mb_b = blk_n / arith.index(128) + wave_n_idx + _bvs_lane4 = lane16 * arith.index(4) + + def _bvs_load_scales(rsrc, mb, rep, k_base): + kt = k_base / arith.index(tile_k) + tile_i32 = (mb * arith.index(_bvs_Kt) + kt) * arith.index(128) + vals = [] + for ld in range_constexpr(rep // 4): # rep=8 -> 2 groups of 4 i32 + off = arith.index_cast(T.i32, tile_i32 + arith.index(ld * 64) + _bvs_lane4) + v = fx.Vector(buffer_ops.buffer_load(rsrc, off, vec_width=4, dtype=T.i32)) + for j in range_constexpr(4): + vals.append(v[j]) + return vals + + def _bvs_prefetch(k_base): + # Issue scale buffer_load for one K-tile; returns (a[8], b[8]) VGPR. + a = _bvs_load_scales(_bvs_a_rsrc, _bvs_mb_a, wmma_m_rep, k_base) + b = _bvs_load_scales(_bvs_b_rsrc, _bvs_mb_b, b_scale_load_rep, k_base) + return a, b + m_idx = fx.Index(i32_m) n_stride = arith.index(N) c_nrec = m_idx * n_stride * arith.index(elem_bytes_d) @@ -1383,6 +1409,9 @@ def compute_tile_fp8_deep_pipeline( mid_compute_callback=None, late_compute_callback=None, a0_prefetch=None, + scale_k_base=None, + pf_a_scales=None, + pf_b_scales=None, ): current_accs = list(accs_in) a_buf, a_bases = _precompute_a_lane_bases(lds_a) @@ -1395,24 +1424,39 @@ def compute_tile_fp8_deep_pipeline( def load_a_pair(wm_pair, ks): wm_base = wm_pair * _fp8_pair_wm return [ - load_a_frag(a_buf, a_bases[wm_base + wm_local], ks) - for wm_local in range_constexpr(_fp8_pair_wm) + load_a_frag(a_buf, a_bases[wm_base + wm_local], ks) for wm_local in range_constexpr(_fp8_pair_wm) ] def load_b_pair(wn_pair, ks): wn_base = wn_pair * _fp8_pair_wn return [ - load_b_frag(b_buf, b_bases, wn_base + wn_local, ks) - for wn_local in range_constexpr(_fp8_pair_wn) + load_b_frag(b_buf, b_bases, wn_base + wn_local, ks) for wn_local in range_constexpr(_fp8_pair_wn) ] def _load_a_scales(ks): + if const_expr(use_buffer_vgpr_scale): + if const_expr(pf_a_scales is not None): + return pf_a_scales # prefetched (issued in the prior compute tile) + return _bvs_load_scales(_bvs_a_rsrc, _bvs_mb_a, wmma_m_rep, scale_k_base) return load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks) def _load_b_scales(ks): + if const_expr(use_buffer_vgpr_scale): + if const_expr(pf_b_scales is not None): + return pf_b_scales + return _bvs_load_scales(_bvs_b_rsrc, _bvs_mb_b, b_scale_load_rep, scale_k_base) return load_scale_b128(bs_buf, bs_bases[0], b_scale_load_rep, ks) - def emit_panel_2x2(wm_pair, wn_pair, a_pair, b_pair, scale_pair, prefetch_after_first_row=None): + def emit_panel_2x2( + wm_pair, + wn_pair, + a_pair, + b_pair, + scale_pair, + prefetch_after_first_row=None, + prefetch_after_first_row_wmma=None, + prefetch_after_second_row_wmma=None, + ): a_scales, b_scales = scale_pair wm_base = wm_pair * _fp8_pair_wm wn_base = wn_pair * _fp8_pair_wn @@ -1426,6 +1470,9 @@ def emit_panel_2x2(wm_pair, wn_pair, a_pair, b_pair, scale_pair, prefetch_after_ a_scales, b_scales, ) + # [overlay-chunks] one fragment load interleaved after each WMMA + if const_expr(prefetch_after_first_row_wmma is not None): + prefetch_after_first_row_wmma(wn_local) if const_expr(prefetch_after_first_row is not None): prefetch_after_first_row() for wn_local in range_constexpr(_fp8_pair_wn): @@ -1438,6 +1485,8 @@ def emit_panel_2x2(wm_pair, wn_pair, a_pair, b_pair, scale_pair, prefetch_after_ a_scales, b_scales, ) + if const_expr(prefetch_after_second_row_wmma is not None): + prefetch_after_second_row_wmma(wn_local) def emit_panel_2x2_row(wm_pair, wn_pair, row_local, a_pair, b_pair, scale_pair): a_scales, b_scales = scale_pair @@ -1471,68 +1520,286 @@ def emit_panel_2x2_row(wm_pair, wn_pair, row_local, a_pair, b_pair, scale_pair): else: a0 = load_a_pair(0, ks) b1 = load_b_pair(1, ks) - b2 = load_b_pair(2, ks) + if const_expr(use_overlay_chunks): + # overlay-chunks: defer b2/b3/a1/a2/a3 and issue one fragment + # load after each WMMA for finer LDS-load pipelining. + a1 = [None for _ in range_constexpr(_fp8_pair_wm)] + b2 = [None for _ in range_constexpr(_fp8_pair_wn)] + b3 = [None for _ in range_constexpr(_fp8_pair_wn)] + a2 = [None for _ in range_constexpr(_fp8_pair_wm)] + a3 = [None for _ in range_constexpr(_fp8_pair_wm)] + + def _pf_a1_chunk(ci): + a1[ci] = load_a_frag(a_buf, a_bases[_fp8_pair_wm + ci], ks) + + def _pf_b2_chunk(ci): + b2[ci] = load_b_frag(b_buf, b_bases, 2 * _fp8_pair_wn + ci, ks) + + first_wait_keep = _two_pair_loads + 3 + if const_expr(ks == 0 and a0_prefetch is not None): + first_wait_keep += DS_LOADS_PER_A_FRAG * len(a0_prefetch) + rocdl.s_wait_dscnt(first_wait_keep) + emit_panel_2x2( + 0, + 0, + a0, + b0, + scale_pair, + prefetch_after_first_row_wmma=_pf_a1_chunk, + prefetch_after_second_row_wmma=_pf_b2_chunk, + ) - a1_box = [None] - b3_box = [None] - a2_box = [None] - a3_box = [None] + if const_expr(ks == 0 and mid_compute_callback is not None): + rocdl.sched_barrier(0) + mid_compute_callback() - def _prefetch_a1(): - a1_box[0] = load_a_pair(1, ks) + def _pf_b3_chunk(ci): + b3[ci] = load_b_frag(b_buf, b_bases, 3 * _fp8_pair_wn + ci, ks) - first_wait_keep = _two_pair_loads + 3 - if const_expr(ks == 0 and a0_prefetch is not None): - first_wait_keep += DS_LOADS_PER_A_FRAG * len(a0_prefetch) - rocdl.s_wait_dscnt(first_wait_keep) - emit_panel_2x2(0, 0, a0, b0, scale_pair, prefetch_after_first_row=_prefetch_a1) + def _pf_a3_chunk(ci): + a3[ci] = load_a_frag(a_buf, a_bases[3 * _fp8_pair_wm + ci], ks) - if const_expr(ks == 0 and mid_compute_callback is not None): - rocdl.sched_barrier(0) - mid_compute_callback() + rocdl.s_wait_dscnt(_pair_loads + _fp8_pair_b_loads) + emit_panel_2x2(0, 1, a0, b1, scale_pair, prefetch_after_first_row_wmma=_pf_b3_chunk) - def _prefetch_b3(): - b3_box[0] = load_b_pair(3, ks) + rocdl.s_wait_dscnt(_fp8_pair_b_loads + 2) + emit_panel_2x2(1, 0, a1, b0, scale_pair, prefetch_after_first_row_wmma=_pf_a3_chunk) - def _prefetch_a3(): - a3_box[0] = load_a_pair(3, ks) + def _pf_a2_chunk(ci): + a2[ci] = load_a_frag(a_buf, a_bases[2 * _fp8_pair_wm + ci], ks) - rocdl.s_wait_dscnt(_pair_loads + _fp8_pair_b_loads) - emit_panel_2x2(0, 1, a0, b1, scale_pair, prefetch_after_first_row=_prefetch_b3) + emit_panel_2x2(1, 1, a1, b1, scale_pair) + emit_panel_2x2(0, 2, a0, b2, scale_pair, prefetch_after_first_row_wmma=_pf_a2_chunk) + emit_panel_2x2_row(1, 2, 0, a1, b2, scale_pair) + rocdl.s_wait_dscnt(_pair_loads) + emit_panel_2x2(0, 3, a0, b3, scale_pair) + emit_panel_2x2_row(1, 2, 1, a1, b2, scale_pair) + emit_panel_2x2(1, 3, a1, b3, scale_pair) + emit_panel_2x2(2, 0, a2, b0, scale_pair) + if const_expr(is_last_ks and late_compute_callback is not None): + rocdl.sched_barrier(0) + late_compute_callback() + emit_panel_2x2(2, 1, a2, b1, scale_pair) + rocdl.s_wait_dscnt(0) + emit_panel_2x2(3, 0, a3, b0, scale_pair) + emit_panel_2x2(3, 1, a3, b1, scale_pair) + if const_expr(is_last_ks and emit_filler is not None): + rocdl.sched_barrier(0) + emit_filler() + emit_panel_2x2(2, 2, a2, b2, scale_pair) + emit_panel_2x2(2, 3, a2, b3, scale_pair) + emit_panel_2x2(3, 2, a3, b2, scale_pair) + emit_panel_2x2(3, 3, a3, b3, scale_pair) + else: + b2 = load_b_pair(2, ks) - rocdl.s_wait_dscnt(_fp8_pair_b_loads + 2) - emit_panel_2x2(1, 0, a1_box[0], b0, scale_pair, prefetch_after_first_row=_prefetch_a3) + a1_box = [None] + b3_box = [None] + a2_box = [None] + a3_box = [None] - def _prefetch_a2(): - a2_box[0] = load_a_pair(2, ks) + def _prefetch_a1(): + a1_box[0] = load_a_pair(1, ks) - emit_panel_2x2(1, 1, a1_box[0], b1, scale_pair) + first_wait_keep = _two_pair_loads + 3 + if const_expr(ks == 0 and a0_prefetch is not None): + first_wait_keep += DS_LOADS_PER_A_FRAG * len(a0_prefetch) + rocdl.s_wait_dscnt(first_wait_keep) + emit_panel_2x2(0, 0, a0, b0, scale_pair, prefetch_after_first_row=_prefetch_a1) - emit_panel_2x2(0, 2, a0, b2, scale_pair, prefetch_after_first_row=_prefetch_a2) - emit_panel_2x2_row(1, 2, 0, a1_box[0], b2, scale_pair) - emit_panel_2x2_row(1, 2, 1, a1_box[0], b2, scale_pair) - rocdl.s_wait_dscnt(_pair_loads) - emit_panel_2x2(0, 3, a0, b3_box[0], scale_pair) - emit_panel_2x2(1, 3, a1_box[0], b3_box[0], scale_pair) + if const_expr(ks == 0 and mid_compute_callback is not None): + rocdl.sched_barrier(0) + mid_compute_callback() - emit_panel_2x2(2, 0, a2_box[0], b0, scale_pair) - if const_expr(is_last_ks and late_compute_callback is not None): - rocdl.sched_barrier(0) - late_compute_callback() - emit_panel_2x2(2, 1, a2_box[0], b1, scale_pair) + def _prefetch_b3(): + b3_box[0] = load_b_pair(3, ks) + + def _prefetch_a3(): + a3_box[0] = load_a_pair(3, ks) + + rocdl.s_wait_dscnt(_pair_loads + _fp8_pair_b_loads) + emit_panel_2x2(0, 1, a0, b1, scale_pair, prefetch_after_first_row=_prefetch_b3) + + rocdl.s_wait_dscnt(_fp8_pair_b_loads + 2) + emit_panel_2x2(1, 0, a1_box[0], b0, scale_pair, prefetch_after_first_row=_prefetch_a3) + + def _prefetch_a2(): + a2_box[0] = load_a_pair(2, ks) + + emit_panel_2x2(1, 1, a1_box[0], b1, scale_pair) + + emit_panel_2x2(0, 2, a0, b2, scale_pair, prefetch_after_first_row=_prefetch_a2) + emit_panel_2x2_row(1, 2, 0, a1_box[0], b2, scale_pair) + emit_panel_2x2_row(1, 2, 1, a1_box[0], b2, scale_pair) + rocdl.s_wait_dscnt(_pair_loads) + emit_panel_2x2(0, 3, a0, b3_box[0], scale_pair) + emit_panel_2x2(1, 3, a1_box[0], b3_box[0], scale_pair) + + emit_panel_2x2(2, 0, a2_box[0], b0, scale_pair) + if const_expr(is_last_ks and late_compute_callback is not None): + rocdl.sched_barrier(0) + late_compute_callback() + emit_panel_2x2(2, 1, a2_box[0], b1, scale_pair) + + rocdl.s_wait_dscnt(0) + emit_panel_2x2(3, 0, a3_box[0], b0, scale_pair) + emit_panel_2x2(3, 1, a3_box[0], b1, scale_pair) + + if const_expr(is_last_ks and emit_filler is not None): + rocdl.sched_barrier(0) + emit_filler() + + emit_panel_2x2(2, 2, a2_box[0], b2, scale_pair) + emit_panel_2x2(2, 3, a2_box[0], b3_box[0], scale_pair) + emit_panel_2x2(3, 2, a3_box[0], b2, scale_pair) + emit_panel_2x2(3, 3, a3_box[0], b3_box[0], scale_pair) + + return current_accs + + def compute_tile_fp8_half_fence( + accs_in, + lds_a, + lds_b, + lds_as, + lds_bs, + mid_compute_callback=None, + quadrant_fence_callback=None, + scale_k_base=None, + pf_a_scales=None, + pf_b_scales=None, + ): + """Two-phase compute for the A/B half-fence schedule. + + Phase 1 reads the first A/B half (A0/B0 = frags 0..3) and computes the + (0..3, 0..3) quadrant; phase 2 reads A1/B1 and computes the remaining + 12 panels. ``quadrant_fence_callback`` (the {A1,B1} fence) runs between + the phases, at ks==0 only. No second-half ds_load is issued before that + fence -- the invariant the deep-pipeline schedule can't honour. + """ + current_accs = list(accs_in) + a_buf, a_bases = _precompute_a_lane_bases(lds_a) + b_buf, b_bases = _precompute_b_lane_bases(lds_b) + as_buf, as_bases = _precompute_scale_lane_bases(lds_as, warp_m_base, wmma_m_rep, interleaved_scale_cols_a) + bs_buf, bs_bases = _precompute_scale_lane_bases( + lds_bs, warp_n_base, b_scale_load_rep, interleaved_scale_cols_b + ) + + def load_a_pair(wm_pair, ks): + wm_base = wm_pair * _fp8_pair_wm + return [ + load_a_frag(a_buf, a_bases[wm_base + wm_local], ks) for wm_local in range_constexpr(_fp8_pair_wm) + ] + + def load_b_pair(wn_pair, ks): + wn_base = wn_pair * _fp8_pair_wn + return [ + load_b_frag(b_buf, b_bases, wn_base + wn_local, ks) for wn_local in range_constexpr(_fp8_pair_wn) + ] + + def _load_a_scales(ks): + if const_expr(pf_a_scales is not None): + return pf_a_scales + if const_expr(use_buffer_vgpr_scale): + return _bvs_load_scales(_bvs_a_rsrc, _bvs_mb_a, wmma_m_rep, scale_k_base) + return load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks) + + def _load_b_scales(ks): + if const_expr(pf_b_scales is not None): + return pf_b_scales + if const_expr(use_buffer_vgpr_scale): + return _bvs_load_scales(_bvs_b_rsrc, _bvs_mb_b, b_scale_load_rep, scale_k_base) + return load_scale_b128(bs_buf, bs_bases[0], b_scale_load_rep, ks) + + def emit_panel(wm_pair, wn_pair, a_pair, b_pair, scale_pair, prefetch_after_first_row=None): + a_scales, b_scales = scale_pair + wm_base = wm_pair * _fp8_pair_wm + wn_base = wn_pair * _fp8_pair_wn + for wn_local in range_constexpr(_fp8_pair_wn): + _emit_wmma( + current_accs, + wm_base, + wn_base + wn_local, + a_pair[0], + b_pair[wn_local], + a_scales, + b_scales, + ) + if const_expr(prefetch_after_first_row is not None): + prefetch_after_first_row() + for wn_local in range_constexpr(_fp8_pair_wn): + _emit_wmma( + current_accs, + wm_base + 1, + wn_base + wn_local, + a_pair[1], + b_pair[wn_local], + a_scales, + b_scales, + ) + + _pair_loads = _fp8_pair_a_loads + _two_pair_loads = _fp8_pair_a_loads + _fp8_pair_b_loads + + for ks in range_constexpr(k_wmma_steps): + scale_pair = (_load_a_scales(ks), _load_b_scales(ks)) + + # ---- Phase 1: first half {A0,B0} -> quadrant (m0-3, n0-3) ---- + # Issue a0,b0,b1 up front; prefetch a1 mid-panel(0,0). NO + # second-half (a2/a3/b2/b3) load may be issued before the fence. + b0 = load_b_pair(0, ks) + a0 = load_a_pair(0, ks) + b1 = load_b_pair(1, ks) + a1_box = [None] + + def _prefetch_a1(): + a1_box[0] = load_a_pair(1, ks) + rocdl.s_wait_dscnt(_two_pair_loads + 3) + emit_panel(0, 0, a0, b0, scale_pair, prefetch_after_first_row=_prefetch_a1) + if const_expr(ks == 0 and mid_compute_callback is not None): + rocdl.sched_barrier(0) + mid_compute_callback() + rocdl.s_wait_dscnt(_pair_loads + _fp8_pair_b_loads) + emit_panel(0, 1, a0, b1, scale_pair) + rocdl.s_wait_dscnt(_fp8_pair_b_loads) + emit_panel(1, 0, a1_box[0], b0, scale_pair) + emit_panel(1, 1, a1_box[0], b1, scale_pair) rocdl.s_wait_dscnt(0) - emit_panel_2x2(3, 0, a3_box[0], b0, scale_pair) - emit_panel_2x2(3, 1, a3_box[0], b1, scale_pair) - if const_expr(is_last_ks and emit_filler is not None): + # ---- {A1,B1} half-fence (once, before any second-half read) ---- + if const_expr(ks == 0 and quadrant_fence_callback is not None): rocdl.sched_barrier(0) - emit_filler() + quadrant_fence_callback() - emit_panel_2x2(2, 2, a2_box[0], b2, scale_pair) - emit_panel_2x2(2, 3, a2_box[0], b3_box[0], scale_pair) - emit_panel_2x2(3, 2, a3_box[0], b2, scale_pair) - emit_panel_2x2(3, 3, a3_box[0], b3_box[0], scale_pair) + # ---- Phase 2: second half {A1,B1} -> remaining 12 panels ---- + a1 = a1_box[0] + b2 = load_b_pair(2, ks) + b3 = load_b_pair(3, ks) + a2_box = [None] + a3_box = [None] + + def _prefetch_a2(): + a2_box[0] = load_a_pair(2, ks) + + def _prefetch_a3(): + a3_box[0] = load_a_pair(3, ks) + + # panels reading only a0/a1 + b2/b3 first, prefetch a2,a3 meanwhile + rocdl.s_wait_dscnt(_two_pair_loads) + emit_panel(0, 2, a0, b2, scale_pair, prefetch_after_first_row=_prefetch_a2) + emit_panel(0, 3, a0, b3, scale_pair, prefetch_after_first_row=_prefetch_a3) + emit_panel(1, 2, a1, b2, scale_pair) + emit_panel(1, 3, a1, b3, scale_pair) + rocdl.s_wait_dscnt(_fp8_pair_a_loads) + emit_panel(2, 0, a2_box[0], b0, scale_pair) + emit_panel(2, 1, a2_box[0], b1, scale_pair) + rocdl.s_wait_dscnt(0) + emit_panel(3, 0, a3_box[0], b0, scale_pair) + emit_panel(3, 1, a3_box[0], b1, scale_pair) + emit_panel(2, 2, a2_box[0], b2, scale_pair) + emit_panel(2, 3, a2_box[0], b3, scale_pair) + emit_panel(3, 2, a3_box[0], b2, scale_pair) + emit_panel(3, 3, a3_box[0], b3, scale_pair) return current_accs @@ -1703,6 +1970,9 @@ def compute_tile_scheduled( mid_compute_callback=None, late_compute_callback=None, a0_prefetch=None, + scale_k_base=None, + pf_a_scales=None, + pf_b_scales=None, ): if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_B_STREAMING): return compute_tile_b_streaming( @@ -1746,6 +2016,9 @@ def compute_tile_scheduled( mid_compute_callback=mid_compute_callback, late_compute_callback=late_compute_callback, a0_prefetch=a0_prefetch, + scale_k_base=scale_k_base, + pf_a_scales=pf_a_scales, + pf_b_scales=pf_b_scales, ) return compute_tile( accs_in, @@ -1793,10 +2066,7 @@ def hot_loop_scheduler_scheduled(): def prefetch_fp8_deep_a0_frags(lds_a): a_buf, a_bases = _precompute_a_lane_bases(lds_a) - return [ - load_a_frag(a_buf, a_bases[wm_local], 0) - for wm_local in range_constexpr(_fp8_pair_wm) - ] + return [load_a_frag(a_buf, a_bases[wm_local], 0) for wm_local in range_constexpr(_fp8_pair_wm)] def maybe_prefetch_fp8_deep_a0(lds_a): # Call only after the TDM fence for this stage; pre-fence LDS reads can race multicast delivery. @@ -2012,7 +2282,7 @@ def _pack_dg0(pred, lds_addr, addr_lo, addr_hi): desc_b_init = make_desc_b(stages_b_mem[0], split_k_base) desc_as_init = make_desc_as(stages_as_mem[0], split_k_base) desc_bs_init = make_desc_bs(stages_bs_mem[0], split_k_base) - if const_expr(use_ab_split_scale_buffer_load): + if const_expr(use_ab_half_split): stages_a0_lds_addr = [] stages_b0_lds_addr = [] stages_a1_lds_addr = [] @@ -2035,7 +2305,10 @@ def _pack_dg0(pred, lds_addr, addr_lo, addr_hi): pred_const = fx.Int32(1) if const_expr(wave_specialized_tdm): - active_pred_const = arith.select(tdm_wave_id < fx.Int32(4), fx.Int32(1), fx.Int32(0)) + # With scale on the VGPR path, drop scale waves 2,3 from the active TDM + # path -- unless ab-half-split repurposes them as the second A/B halves. + _active_wave_limit = 2 if (use_buffer_vgpr_scale and not use_ab_half_split) else 4 + active_pred_const = arith.select(tdm_wave_id < fx.Int32(_active_wave_limit), fx.Int32(1), fx.Int32(0)) def _select4(values): return _select_wave_tdm_value(values[0], values[1], values[2], values[3]) @@ -2064,18 +2337,20 @@ def _select_active_tdm(stage_lds_addrs, descs, advs): else: active_pred_const = pred_const - if const_expr(wave_specialized_tdm and not use_scale_buffer_load): - active_stage_lds_addr, active_addr_lo, active_addr_hi, active_dgroup1, active_adv_i32 = _select_active_tdm( - (stages_a_lds_addr, stages_b_lds_addr, stages_as_lds_addr, stages_bs_lds_addr), - (desc_a_init, desc_b_init, desc_as_init, desc_bs_init), - (adv_a_i32, adv_b_i32, adv_as_i32, adv_bs_i32), - ) - elif const_expr(use_ab_split_scale_buffer_load): + if const_expr(use_ab_half_split): + # All 4 waves load A/B halves: wave0=A0, wave1=B0, wave2=A1, wave3=B1. + # Both halves of A share adv_a (same K-step); both halves of B share adv_b. active_stage_lds_addr, active_addr_lo, active_addr_hi, active_dgroup1, active_adv_i32 = _select_active_tdm( (stages_a0_lds_addr, stages_b0_lds_addr, stages_a1_lds_addr, stages_b1_lds_addr), (desc_a0_init, desc_b0_init, desc_a1_init, desc_b1_init), (adv_a_i32, adv_b_i32, adv_a_i32, adv_b_i32), ) + elif const_expr(wave_specialized_tdm): + active_stage_lds_addr, active_addr_lo, active_addr_hi, active_dgroup1, active_adv_i32 = _select_active_tdm( + (stages_a_lds_addr, stages_b_lds_addr, stages_as_lds_addr, stages_bs_lds_addr), + (desc_a_init, desc_b_init, desc_as_init, desc_bs_init), + (adv_a_i32, adv_b_i32, adv_as_i32, adv_bs_i32), + ) else: addr_lo_a = _dg0_lane(desc_a_init, 2) addr_hi_a = _dg0_lane(desc_a_init, 3) @@ -2091,144 +2366,61 @@ def _select_active_tdm(stage_lds_addrs, descs, advs): dgroup1_as = desc_as_init.dgroup1 dgroup1_bs = desc_bs_init.dgroup1 - if const_expr(use_scale_buffer_load): - scale_a_base = buffer_ops.extract_base_index(arg_a_scale) - scale_b_base = buffer_ops.extract_base_index(arg_b_scale) - scale_async_offset = fx.Int32(0) - scale_async_aux = fx.Int32(0) - - def _dma_scale_tile_to_lds( - global_base, - lds_mem, - global_row_base, - global_col_base, - row_stride, - row_bytes: int, - total_bytes: int, - ): - from flydsl._mlir.dialects import memref as memref_dialect - from flydsl._mlir.dialects import rocdl as rocdl_dialect - - for batch in range_constexpr( - (total_bytes + block_threads * _scale_dma_bytes - 1) // (block_threads * _scale_dma_bytes) - ): - batch_byte = batch * block_threads * _scale_dma_bytes - copy_byte = arith.index(batch_byte) + tx * arith.index(_scale_dma_bytes) - if copy_byte < arith.index(total_bytes): - row = copy_byte / arith.index(row_bytes) - col = copy_byte % arith.index(row_bytes) - global_byte = (global_row_base + row) * arith.index(row_stride) + global_col_base + col - global_ptr = buffer_ops.create_llvm_ptr(global_base + global_byte, address_space=1) - lds_ptr = buffer_ops.create_llvm_ptr( - memref_dialect.extract_aligned_pointer_as_index(lds_mem) + copy_byte, - address_space=3, - ) - rocdl_dialect.global_load_async_to_lds_b128( - global_ptr, - lds_ptr, - scale_async_offset, - scale_async_aux, - ) - - def _issue_scale_buffer_loads(stage_idx, k_base): - k_scale_off = k_base / arith.index(SCALE_BLOCK) - _dma_scale_tile_to_lds( - scale_a_base, - stages_as_mem[stage_idx], - blk_m / arith.index(wmma_m_rep), - k_scale_off * arith.index(wmma_m_rep), - wmma_m_rep * K_scale, - interleaved_scale_cols_a, - tile_m * scale_k_per_tile, - ) - _dma_scale_tile_to_lds( - scale_b_base, - stages_bs_mem[stage_idx], - blk_n / arith.index(b_scale_load_rep), - k_scale_off * arith.index(b_scale_load_rep), - b_scale_load_rep * K_scale, - interleaved_scale_cols_b, - tile_n * scale_k_per_tile, - ) - - def _wait_scale_all(): - if const_expr(use_scale_buffer_load): - rocdl.s_wait_asynccnt(0) - - def _wait_scale_for_compute_stage(): - if const_expr(use_scale_buffer_load): - rocdl.s_wait_asynccnt(_scale_async_future_wait_count) - def _pipeline_fence(outstanding=0): pipeline_fence(outstanding=outstanding, use_cluster=use_cluster) def _pipeline_fence_signal(outstanding=0): pipeline_fence_signal(outstanding=outstanding, use_cluster=use_cluster) - def _issue_ab_tdm(load_stage, addr_a, addr_b): - dg0_a = _pack_dg0(pred_const, stages_a_lds_addr[load_stage], addr_a, addr_hi_a) - dg0_b = _pack_dg0(pred_const, stages_b_lds_addr[load_stage], addr_b, addr_hi_b) - issue_tdm_loads( - tdm_ops.TDMDescriptor2D(dg0_a, dgroup1_a), - tdm_ops.TDMDescriptor2D(dg0_b, dgroup1_b), - wave_specialized=wave_specialized_tdm, + def _pipeline_fence_partial(outstanding=0, first_half=True): + pipeline_fence_partial( + outstanding=outstanding, + first_half=first_half, + n_loader_waves=4, + use_cluster=use_cluster, ) - if const_expr(wave_specialized_tdm and (not use_scale_buffer_load or use_ab_split_scale_buffer_load)): + if const_expr(wave_specialized_tdm): - def _issue_active_tdm(load_stage, addr_box, scale_k_box=None, k_prefetch=None): + def _issue_active_tdm(load_stage, addr_box, k_prefetch=None): dg0 = _pack_dg0(active_pred_const, active_stage_lds_addr[load_stage], addr_box[0], active_addr_hi) tdm_ops.tensor_load_2d(tdm_ops.TDMDescriptor2D(dg0, active_dgroup1)) addr_box[0] = addr_box[0] + active_adv_i32 - if scale_k_box is not None: - _issue_scale_buffer_loads(load_stage, scale_k_box[0]) - scale_k_box[0] = scale_k_box[0] + arith.index(tile_k) if k_prefetch is not None: _l2_prefetch(k_prefetch) # Prologue - if const_expr(wave_specialized_tdm and not use_scale_buffer_load): + if const_expr(wave_specialized_tdm): for i in range_constexpr(pre_loaded): addr_box = [active_addr_lo] _issue_active_tdm(i, addr_box) active_addr_lo = addr_box[0] - elif const_expr(use_ab_split_scale_buffer_load): - for i in range_constexpr(pre_loaded): - addr_box = [active_addr_lo] - scale_k_box = [split_k_base + arith.index(i * tile_k)] - _issue_active_tdm(i, addr_box, scale_k_box=scale_k_box) - active_addr_lo = addr_box[0] else: for i in range_constexpr(pre_loaded): dg0_a = _pack_dg0(pred_const, stages_a_lds_addr[i], addr_lo_a, addr_hi_a) dg0_b = _pack_dg0(pred_const, stages_b_lds_addr[i], addr_lo_b, addr_hi_b) - if const_expr(use_scale_buffer_load): - issue_tdm_loads( - tdm_ops.TDMDescriptor2D(dg0_a, dgroup1_a), - tdm_ops.TDMDescriptor2D(dg0_b, dgroup1_b), - wave_specialized=wave_specialized_tdm, - ) - _issue_scale_buffer_loads(i, split_k_base + arith.index(i * tile_k)) - else: - dg0_as = _pack_dg0(pred_const, stages_as_lds_addr[i], addr_lo_as, addr_hi_as) - dg0_bs = _pack_dg0(pred_const, stages_bs_lds_addr[i], addr_lo_bs, addr_hi_bs) - issue_tdm_loads( - tdm_ops.TDMDescriptor2D(dg0_a, dgroup1_a), - tdm_ops.TDMDescriptor2D(dg0_b, dgroup1_b), - tdm_ops.TDMDescriptor2D(dg0_as, dgroup1_as), - tdm_ops.TDMDescriptor2D(dg0_bs, dgroup1_bs), - wave_specialized=wave_specialized_tdm, - ) + dg0_as = _pack_dg0(pred_const, stages_as_lds_addr[i], addr_lo_as, addr_hi_as) + dg0_bs = _pack_dg0(pred_const, stages_bs_lds_addr[i], addr_lo_bs, addr_hi_bs) + issue_tdm_loads( + tdm_ops.TDMDescriptor2D(dg0_a, dgroup1_a), + tdm_ops.TDMDescriptor2D(dg0_b, dgroup1_b), + tdm_ops.TDMDescriptor2D(dg0_as, dgroup1_as), + tdm_ops.TDMDescriptor2D(dg0_bs, dgroup1_bs), + wave_specialized=wave_specialized_tdm, + ) addr_lo_a = addr_lo_a + adv_a_i32 addr_lo_b = addr_lo_b + adv_b_i32 - if const_expr(not use_scale_buffer_load): - addr_lo_as = addr_lo_as + adv_as_i32 - addr_lo_bs = addr_lo_bs + adv_bs_i32 - if const_expr(use_scale_buffer_load): - scale_next_k_base = split_k_base + arith.index(pre_loaded * tile_k) + addr_lo_as = addr_lo_as + adv_as_i32 + addr_lo_bs = addr_lo_bs + adv_bs_i32 + + if const_expr(use_buffer_vgpr_scale): + # Prologue: prefetch the first _bvs_D K-tiles (global->VGPR). Carried as + # FLAT lists of i32 (list-of-tuples can't be loop-carried). + _bvs_pf = [_bvs_prefetch(split_k_base + arith.index(_d * tile_k)) for _d in range(_bvs_D)] + _bvs_ra = [_v for (_a, _b) in _bvs_pf for _v in _a] + _bvs_rb = [_v for (_a, _b) in _bvs_pf for _v in _b] - _wait_scale_for_compute_stage() _pipeline_fence(outstanding=TDM_LOADS_PER_STEP * (num_buffers - 2)) # Main loop — acc_mixed style: fence at top, TDM_load mid-compute. @@ -2239,20 +2431,23 @@ def _issue_active_tdm(load_stage, addr_box, scale_k_box=None, k_prefetch=None): _pipeline_fence_signal(outstanding=_fence_outstanding) if const_expr(loop_iters > 0): - if const_expr(wave_specialized_tdm and not use_scale_buffer_load): + if const_expr(wave_specialized_tdm): init_args = list(accs) + [active_addr_lo] + if const_expr(use_buffer_vgpr_scale): + init_args = init_args + _bvs_ra + _bvs_rb for loop_iter, state in range(0, loop_iters, 1, init=init_args): accs_in = list(state[:n_accs]) cur_addr_lo = state[n_accs] + if const_expr(use_buffer_vgpr_scale): + _ra0 = n_accs + 1 + _ring_a = list(state[_ra0 : _ra0 + _bvs_D * wmma_m_rep]) + _rb0 = _ra0 + _bvs_D * wmma_m_rep + _ring_b = list(state[_rb0 : _rb0 + _bvs_D * b_scale_load_rep]) for buf_idx in range_constexpr(num_buffers): load_stage = (buf_idx + num_buffers - 1) % num_buffers - if const_expr(not use_ws_tdm_split_signal_overlap): - _pipeline_fence_signal(outstanding=_fence_outstanding) - pipeline_fence_wait(use_cluster=use_cluster) - addr_box = [cur_addr_lo] def _mid_tdm_ws( @@ -2266,141 +2461,98 @@ def _mid_tdm_ws( ): _issue_active_tdm(_ls, _ab, k_prefetch=_k_off) - _late_tdm_ws_fence_signal = None - if const_expr(use_ws_tdm_split_signal_overlap): - - def _late_tdm_ws_split_signal(): + if const_expr(use_ab_half_fence): + # Fence {A0,B0} only; the {A1,B1} fence + # fires mid-compute so its TDM overlaps the quadrant WMMAs. + if const_expr(use_buffer_vgpr_scale): + _cur_a = _ring_a[:wmma_m_rep] + _cur_b = _ring_b[:b_scale_load_rep] + _next_kb = ( + split_k_base + + loop_iter * arith.index(num_buffers * tile_k) + + arith.index((buf_idx + _bvs_D) * tile_k) + ) + _na, _nb2 = _bvs_prefetch(_next_kb) + _ring_a = _ring_a[wmma_m_rep:] + list(_na) + _ring_b = _ring_b[b_scale_load_rep:] + list(_nb2) + else: + _cur_a = None + _cur_b = None + + _pipeline_fence_partial(outstanding=_fence_outstanding, first_half=True) + + def _quad_fence(): + _pipeline_fence_partial(outstanding=_fence_outstanding, first_half=False) + + accs_in = compute_tile_fp8_half_fence( + accs_in, + stages_a_idx[buf_idx], + stages_b_idx[buf_idx], + stages_as_idx[buf_idx], + stages_bs_idx[buf_idx], + mid_compute_callback=_mid_tdm_ws, + quadrant_fence_callback=_quad_fence, + pf_a_scales=_cur_a, + pf_b_scales=_cur_b, + ) + else: + if const_expr(not use_ws_tdm_split_signal_overlap): _pipeline_fence_signal(outstanding=_fence_outstanding) - - _late_tdm_ws_fence_signal = _late_tdm_ws_split_signal - - a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[buf_idx]) - rocdl.sched_barrier(0) - accs_in = compute_tile_scheduled( - accs_in, - stages_a_idx[buf_idx], - stages_b_idx[buf_idx], - stages_as_idx[buf_idx], - stages_bs_idx[buf_idx], - mid_compute_callback=_mid_tdm_ws, - late_compute_callback=_late_tdm_ws_fence_signal, - a0_prefetch=a0_prefetch, - ) - cur_addr_lo = addr_box[0] - hot_loop_scheduler_scheduled() - - results = yield list(accs_in) + [cur_addr_lo] - - accs = list(results[:n_accs]) - active_addr_lo = results[n_accs] - elif const_expr(use_ab_split_scale_buffer_load): - init_args = list(accs) + [active_addr_lo, scale_next_k_base] - - for loop_iter, state in range(0, loop_iters, 1, init=init_args): - accs_in = list(state[:n_accs]) - cur_addr_lo = state[n_accs] - cur_scale_k = state[n_accs + 1] - - for buf_idx in range_constexpr(num_buffers): - load_stage = (buf_idx + num_buffers - 1) % num_buffers - - _wait_scale_for_compute_stage() - _pipeline_fence_signal(outstanding=_fence_outstanding) - pipeline_fence_wait(use_cluster=use_cluster) - - addr_box = [cur_addr_lo] - scale_k_box = [cur_scale_k] - - def _mid_tdm_split_scale_dma( - _ls=load_stage, - _ab=addr_box, - _scale_k=scale_k_box, - _k_off=( - split_k_base - + loop_iter * arith.index(num_buffers * tile_k) - + arith.index(buf_idx * tile_k) - ), - ): - _issue_active_tdm(_ls, _ab, scale_k_box=_scale_k, k_prefetch=_k_off) - - a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[buf_idx]) - rocdl.sched_barrier(0) - accs_in = compute_tile_scheduled( - accs_in, - stages_a_idx[buf_idx], - stages_b_idx[buf_idx], - stages_as_idx[buf_idx], - stages_bs_idx[buf_idx], - mid_compute_callback=_mid_tdm_split_scale_dma, - a0_prefetch=a0_prefetch, - ) + pipeline_fence_wait(use_cluster=use_cluster) + + _late_tdm_ws_fence_signal = None + if const_expr(use_ws_tdm_split_signal_overlap): + + def _late_tdm_ws_split_signal(): + _pipeline_fence_signal(outstanding=_fence_outstanding) + + _late_tdm_ws_fence_signal = _late_tdm_ws_split_signal + + a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[buf_idx]) + rocdl.sched_barrier(0) + # Consume scale prefetched _bvs_D K-tiles ago; issue the + # K-tile +_bvs_D prefetch now (overlaps this tile's WMMAs). + # NOTE: must stay AFTER the fence (matches the original + # ordering); issuing the scale buffer_loads before the + # cluster barrier hangs the vgpr path. + if const_expr(use_buffer_vgpr_scale): + _cur_a = _ring_a[:wmma_m_rep] + _cur_b = _ring_b[:b_scale_load_rep] + _next_kb = ( + split_k_base + + loop_iter * arith.index(num_buffers * tile_k) + + arith.index((buf_idx + _bvs_D) * tile_k) + ) + _na, _nb2 = _bvs_prefetch(_next_kb) + _ring_a = _ring_a[wmma_m_rep:] + list(_na) + _ring_b = _ring_b[b_scale_load_rep:] + list(_nb2) + else: + _cur_a = None + _cur_b = None + + accs_in = compute_tile_scheduled( + accs_in, + stages_a_idx[buf_idx], + stages_b_idx[buf_idx], + stages_as_idx[buf_idx], + stages_bs_idx[buf_idx], + mid_compute_callback=_mid_tdm_ws, + late_compute_callback=_late_tdm_ws_fence_signal, + a0_prefetch=a0_prefetch, + pf_a_scales=_cur_a, + pf_b_scales=_cur_b, + ) cur_addr_lo = addr_box[0] - cur_scale_k = scale_k_box[0] hot_loop_scheduler_scheduled() - results = yield list(accs_in) + [cur_addr_lo, cur_scale_k] + if const_expr(use_buffer_vgpr_scale): + _bvs_yield = _ring_a + _ring_b + else: + _bvs_yield = [] + results = yield list(accs_in) + [cur_addr_lo] + _bvs_yield accs = list(results[:n_accs]) active_addr_lo = results[n_accs] - scale_next_k_base = results[n_accs + 1] - elif const_expr(use_scale_buffer_load): - init_args = list(accs) + [addr_lo_a, addr_lo_b, scale_next_k_base] - - for loop_iter, state in range(0, loop_iters, 1, init=init_args): - accs_in = list(state[:n_accs]) - cur_lo_a = state[n_accs] - cur_lo_b = state[n_accs + 1] - cur_scale_k = state[n_accs + 2] - - for buf_idx in range_constexpr(num_buffers): - load_stage = (buf_idx + num_buffers - 1) % num_buffers - - _wait_scale_for_compute_stage() - _pipeline_fence_signal(outstanding=_fence_outstanding) - pipeline_fence_wait(use_cluster=use_cluster) - - addr_boxes = [[cur_lo_a], [cur_lo_b]] - scale_k_box = [cur_scale_k] - - def _mid_tdm_scale_dma( - _ls=load_stage, - _ab=addr_boxes, - _scale_k=scale_k_box, - _k_off=( - split_k_base - + loop_iter * arith.index(num_buffers * tile_k) - + arith.index(buf_idx * tile_k) - ), - ): - _issue_ab_tdm(_ls, _ab[0][0], _ab[1][0]) - _ab[0][0] = _ab[0][0] + adv_a_i32 - _ab[1][0] = _ab[1][0] + adv_b_i32 - _issue_scale_buffer_loads(_ls, _scale_k[0]) - _scale_k[0] = _scale_k[0] + arith.index(tile_k) - _l2_prefetch(_k_off) - - a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[buf_idx]) - rocdl.sched_barrier(0) - accs_in = compute_tile_scheduled( - accs_in, - stages_a_idx[buf_idx], - stages_b_idx[buf_idx], - stages_as_idx[buf_idx], - stages_bs_idx[buf_idx], - mid_compute_callback=_mid_tdm_scale_dma, - a0_prefetch=a0_prefetch, - ) - cur_lo_a = addr_boxes[0][0] - cur_lo_b = addr_boxes[1][0] - cur_scale_k = scale_k_box[0] - hot_loop_scheduler_scheduled() - - results = yield list(accs_in) + [cur_lo_a, cur_lo_b, cur_scale_k] - - accs = list(results[:n_accs]) - addr_lo_a = results[n_accs] - addr_lo_b = results[n_accs + 1] - scale_next_k_base = results[n_accs + 2] else: init_args = list(accs) + [addr_lo_a, addr_lo_b, addr_lo_as, addr_lo_bs] @@ -2473,17 +2625,26 @@ def _mid_tdm_nws( # Tail — same acc_mixed pattern: fence at top, TDM mid-compute. if const_expr(loop_iters > 0 and use_ws_tdm_split_signal_overlap): pipeline_fence_wait(use_cluster=use_cluster) - _wait_scale_all() if const_expr(loop_iters > 0): _pipeline_fence(outstanding=0) elif const_expr(use_cluster): cluster.cluster_barrier() epi_addrs_box = [None] _tail_had_load = False + # Tail K-tile index, so the VGPR-path scale buffer_load uses the right k_base. + _bvs_tail_kt = [loop_iters * num_buffers] + + def _bvs_tail_kb(): + if const_expr(not use_buffer_vgpr_scale): + return None + kb = split_k_base + arith.index(_bvs_tail_kt[0] * tile_k) + _bvs_tail_kt[0] += 1 + return kb + for _load_stage, _compute_stage, _outstanding in tail_plan: + _entry_kb = _bvs_tail_kb() if const_expr(_outstanding == -1): if const_expr(_tail_had_load): - _wait_scale_all() _pipeline_fence(outstanding=0) if const_expr(use_tdm_store): a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[_compute_stage]) @@ -2494,6 +2655,7 @@ def _mid_tdm_nws( stages_as_idx[_compute_stage], stages_bs_idx[_compute_stage], a0_prefetch=a0_prefetch, + scale_k_base=_entry_kb, ) else: @@ -2509,36 +2671,16 @@ def _emit_epi_addrs(): stages_bs_idx[_compute_stage], emit_filler=_emit_epi_addrs, a0_prefetch=a0_prefetch, + scale_k_base=_entry_kb, ) else: - _wait_scale_all() _pipeline_fence_signal(outstanding=_outstanding) pipeline_fence_wait(use_cluster=use_cluster) _tail_mid_cb = None if const_expr(_load_stage is not None): _tail_had_load = True - if const_expr(use_ab_split_scale_buffer_load): - _tail_addr_box = [active_addr_lo] - _tail_scale_k = [scale_next_k_base] - - def _tail_mid_split_scale_dma(_ls=_load_stage, _ab=_tail_addr_box, _scale_k=_tail_scale_k): - _issue_active_tdm(_ls, _ab, scale_k_box=_scale_k) - - _tail_mid_cb = _tail_mid_split_scale_dma - elif const_expr(use_scale_buffer_load): - _tail_ab = [[addr_lo_a], [addr_lo_b]] - _tail_scale_k = [scale_next_k_base] - - def _tail_mid_scale_dma(_ls=_load_stage, _ab=_tail_ab, _scale_k=_tail_scale_k): - _issue_ab_tdm(_ls, _ab[0][0], _ab[1][0]) - _ab[0][0] = _ab[0][0] + adv_a_i32 - _ab[1][0] = _ab[1][0] + adv_b_i32 - _issue_scale_buffer_loads(_ls, _scale_k[0]) - _scale_k[0] = _scale_k[0] + arith.index(tile_k) - - _tail_mid_cb = _tail_mid_scale_dma - elif const_expr(wave_specialized_tdm): + if const_expr(wave_specialized_tdm): _tail_addr_box = [active_addr_lo] def _tail_mid_ws(_ls=_load_stage, _ab=_tail_addr_box): @@ -2577,17 +2719,11 @@ def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): stages_bs_idx[_compute_stage], mid_compute_callback=_tail_mid_cb, a0_prefetch=a0_prefetch, + scale_k_base=_entry_kb, ) if const_expr(_load_stage is not None): - if const_expr(use_ab_split_scale_buffer_load): - active_addr_lo = _tail_addr_box[0] - scale_next_k_base = _tail_scale_k[0] - elif const_expr(use_scale_buffer_load): - addr_lo_a = _tail_ab[0][0] - addr_lo_b = _tail_ab[1][0] - scale_next_k_base = _tail_scale_k[0] - elif const_expr(wave_specialized_tdm): + if const_expr(wave_specialized_tdm): active_addr_lo = _tail_addr_box[0] else: addr_lo_a = _tail_ab[0][0] diff --git a/tests/kernels/test_gemm_fp8fp4_gfx1250.py b/tests/kernels/test_gemm_fp8fp4_gfx1250.py index 59de0224a..74539ff10 100644 --- a/tests/kernels/test_gemm_fp8fp4_gfx1250.py +++ b/tests/kernels/test_gemm_fp8fp4_gfx1250.py @@ -34,10 +34,35 @@ SCALE_BLOCK = 32 +def preshuffle_e8m0_scale_coalesced(scale: torch.Tensor, block: int = 128) -> torch.Tensor: + """Lane-major scale layout for direct buffer_load->VGPR. + + Per (M_block=128, K_tile): [group(2), lane16(16), 4 i32], so a buffer_load_b128's + 16 lanes read 256 contiguous bytes. M = mb*128 + (group*4 + j)*16 + lane16. + """ + M, Ks = scale.shape + assert M % block == 0 and Ks % 4 == 0, f"M={M} Ks={Ks} block={block}" + assert block == 128, "coalesced scale layout assumes warp_tile=128 (8 subtiles)" + Kt = Ks // 4 + g = scale.view(M // block, 2, 4, 16, Kt, 4) # [mb, group, j, lane16, kt, spw] + g = g.permute(0, 4, 1, 3, 2, 5).contiguous() # [mb, kt, group, lane16, j, spw] + return g.view(M, Ks) + + def preshuffle_e8m0_scale( - scale: torch.Tensor, warp_tile: int, scale_k_per_tile: int = 4, WMMA_DIM: int = 16 + scale: torch.Tensor, + warp_tile: int, + scale_k_per_tile: int = 4, + WMMA_DIM: int = 16, + coalesced: bool = False, ) -> torch.Tensor: - """Preshuffle E8M0 scale: optional byte swap + interleave for WMMA access.""" + """Preshuffle E8M0 scale: optional byte swap + interleave for WMMA access. + + ``coalesced=True`` produces the lane-major layout the scale_load_path + "vgpr"/"vgpr_ab_split" buffer_load->VGPR path expects. + """ + if coalesced: + return preshuffle_e8m0_scale_coalesced(scale, block=warp_tile) _, K_scale = scale.shape assert K_scale % 4 == 0, f"K_scale must be divisible by 4, got {K_scale}" SCALES_PER_WMMA = 4 @@ -383,8 +408,9 @@ def _run_mxscale_gemm_test( skt = tile_k // SCALE_BLOCK warp_tile_m = tile_m // m_warp warp_tile_n = tile_n // n_warp - a_scale = preshuffle_e8m0_scale(a_scale, warp_tile_m, scale_k_per_tile=skt) - b_scale = preshuffle_e8m0_scale(b_scale, warp_tile_n, scale_k_per_tile=skt) + _coalesced_scale = scale_load_path in ("vgpr", "vgpr_ab_split") + a_scale = preshuffle_e8m0_scale(a_scale, warp_tile_m, scale_k_per_tile=skt, coalesced=_coalesced_scale) + b_scale = preshuffle_e8m0_scale(b_scale, warp_tile_n, scale_k_per_tile=skt, coalesced=_coalesced_scale) # Preshuffle B data K_packed = padded_k // padded_shape["pack_b"] @@ -600,7 +626,7 @@ def test_mxfp4_metadata_and_spill_regression(out_dtype): @pytest.mark.parametrize("use_tdm_store", [True, False]) @pytest.mark.parametrize("use_scale_opsel", [True, False]) @pytest.mark.parametrize("out_dtype", ["f32", "bf16"]) -@pytest.mark.parametrize("scale_load_path", ["tdm", "buffer_lds_stage"]) +@pytest.mark.parametrize("scale_load_path", ["tdm"]) def test_mxfp8_gemm( M, N, @@ -750,11 +776,10 @@ def test_b_streaming_with_wave_spec_tdm(data_format, M, N, K, tile_m, tile_n, ti ) -@pytest.mark.parametrize("scale_load_path", ["tdm", "buffer_lds_stage", "buffer_lds_stage_ab_split"]) @pytest.mark.parametrize("num_buffers", [2, 3]) @pytest.mark.parametrize("use_tdm_store", [True, False]) @pytest.mark.parametrize("use_scale_opsel", [False, True]) -def test_mxfp8_wave_spec_scale_load_paths(scale_load_path, num_buffers, use_tdm_store, use_scale_opsel): +def test_mxfp8_wave_spec_scale_load_paths(num_buffers, use_tdm_store, use_scale_opsel): _run_mxscale_gemm_test( "fp8", 128, @@ -771,28 +796,31 @@ def test_mxfp8_wave_spec_scale_load_paths(scale_load_path, num_buffers, use_tdm_ l2_prefetch_distance=2, wave_specialized_tdm=True, use_scale_opsel=use_scale_opsel, - scale_load_path=scale_load_path, + scale_load_path="tdm", ) -def test_mxfp8_ab_split_scale_load_allows_extra_waves(): +@pytest.mark.parametrize("scale_load_path", ["vgpr", "vgpr_ab_split"]) +@pytest.mark.parametrize("cluster_m, cluster_n", [(1, 1), (2, 2)]) +def test_mxfp8_vgpr_scale_load(scale_load_path, cluster_m, cluster_n): _run_mxscale_gemm_test( "fp8", - 128, + 256 * cluster_m, + 256 * cluster_n, + 512, 256, - 384, - 128, 256, 128, 2, - 4, - num_buffers=3, + 2, + num_buffers=4, use_tdm_store=True, out_dtype="bf16", l2_prefetch_distance=2, wave_specialized_tdm=True, - use_scale_opsel=True, - scale_load_path="buffer_lds_stage_ab_split", + cluster_m=cluster_m, + cluster_n=cluster_n, + scale_load_path=scale_load_path, ) @@ -1180,8 +1208,9 @@ def _run_benchmark(args): skt = tile_k // SCALE_BLOCK warp_tile_m = tile_m // args.m_warp warp_tile_n = tile_n // args.n_warp - a_scale = preshuffle_e8m0_scale(a_scale, warp_tile_m, scale_k_per_tile=skt) - b_scale = preshuffle_e8m0_scale(b_scale, warp_tile_n, scale_k_per_tile=skt) + _coalesced_scale = args.scale_load_path in ("vgpr", "vgpr_ab_split") + a_scale = preshuffle_e8m0_scale(a_scale, warp_tile_m, scale_k_per_tile=skt, coalesced=_coalesced_scale) + b_scale = preshuffle_e8m0_scale(b_scale, warp_tile_n, scale_k_per_tile=skt, coalesced=_coalesced_scale) K_packed = padded_k // PACK_B b = fp4_utils.preshuffle_b_16x16(b, padded_n, K_packed) @@ -1367,8 +1396,9 @@ def _run_graph_verify(args): skt = tile_k // SCALE_BLOCK warp_tile_m = tile_m // args.m_warp warp_tile_n = tile_n // args.n_warp - a_scale = preshuffle_e8m0_scale(a_scale, warp_tile_m, scale_k_per_tile=skt) - b_scale = preshuffle_e8m0_scale(b_scale, warp_tile_n, scale_k_per_tile=skt) + _coalesced_scale = args.scale_load_path in ("vgpr", "vgpr_ab_split") + a_scale = preshuffle_e8m0_scale(a_scale, warp_tile_m, scale_k_per_tile=skt, coalesced=_coalesced_scale) + b_scale = preshuffle_e8m0_scale(b_scale, warp_tile_n, scale_k_per_tile=skt, coalesced=_coalesced_scale) K_packed = padded_k // padded_shape["pack_b"] b = fp4_utils.preshuffle_b_16x16(b, padded_n, K_packed) @@ -1499,7 +1529,7 @@ def launch(): "--scale-load-path", type=str, default="tdm", - choices=["tdm", "buffer_lds_stage", "buffer_lds_stage_ab_split"], + choices=["tdm", "vgpr", "vgpr_ab_split"], ) parser.add_argument("--disable-expert-sched-mode", dest="expert_sched_mode", action="store_false", default=True) parser.add_argument("--b-streaming", action="store_true", default=False) @@ -1530,7 +1560,7 @@ def launch(): action="store_true", default=False, help="Functional verification: capture the kernel in a hipGraph, " - "replay once, assert bit-exact match against an eager launch. " + "replay once, assert bit-exact match against an eager launch. ", ) parser.add_argument( "--fill-mode", From 064a11aab2276f4173d68c922e1cf7dd3eed9de8 Mon Sep 17 00:00:00 2001 From: aoli26 Date: Mon, 1 Jun 2026 10:39:21 +0000 Subject: [PATCH 12/19] Clean up gfx1250 FP8 GEMM experiments --- kernels/gemm_common_gfx1250.py | 27 -- kernels/gemm_fp8fp4_gfx1250.py | 476 ++++++++------------------------- 2 files changed, 105 insertions(+), 398 deletions(-) diff --git a/kernels/gemm_common_gfx1250.py b/kernels/gemm_common_gfx1250.py index 3d4227828..b269192d3 100644 --- a/kernels/gemm_common_gfx1250.py +++ b/kernels/gemm_common_gfx1250.py @@ -154,33 +154,6 @@ def pipeline_fence_wait(use_cluster=False): cluster.cluster_wait() -def pipeline_fence_partial(outstanding=0, first_half=True, n_loader_waves=4, use_cluster=False): - """Half fence for the A/B half-split schedule. - - Only the loader waves owning the targeted half tensor_wait (wave0->A0, - wave1->B0, wave2->A1, wave3->B1), then *all* waves run a full barrier: - first_half=True fences {A0,B0}, first_half=False fences {A1,B1}. - - SAFETY: the cluster barrier is never weakened -- each call is a full - ``cluster_barrier``. It fences only the halves tensor_waited before its - signal; the un-waited half stays in flight until the matching second call. - """ - half = n_loader_waves // 2 - wid = rocdl.wave_id() - if first_half: - cond = arith.cmpi(arith.CmpIPredicate.ult, wid, arith.constant(half, type=T.i32)) - else: - cond = arith.cmpi(arith.CmpIPredicate.uge, wid, arith.constant(half, type=T.i32)) - if_op = scf.IfOp(cond) - with ir.InsertionPoint(if_op.then_block): - tdm_ops.tensor_wait(outstanding) - scf.YieldOp([]) - if use_cluster: - cluster.cluster_barrier() - else: - gpu.barrier() - - def issue_tdm_loads(*descs, wave_specialized=False, wave_id=None): """Emit one or more TDM loads, optionally one descriptor per loader wave.""" if wave_specialized: diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index 933f523da..e04a1dc25 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -23,7 +23,6 @@ issue_tdm_loads, lds_load_b128_raw, pipeline_fence, - pipeline_fence_partial, pipeline_fence_signal, pipeline_fence_wait, store_acc_vec8_to_buffer, @@ -286,17 +285,18 @@ def _align_up(value: int, align: int) -> int: # Scale prefetch depth (K-tiles ahead) for the buffer->VGPR path. D=1 is the # sweet spot; D=2 doubles scale VGPRs -> spill + ~18% regression. _bvs_D = int(os.environ.get("FLYDSL_BUFFER_VGPR_SCALE_DEPTH", "1")) - # overlay-chunks: issue one fragment load after each WMMA (vs batch-after-row) - # for finer LDS-load pipelining. OPT-IN: FLYDSL_OVERLAY_CHUNKS=1. - use_overlay_chunks = use_ref_segmented_lds_layout and (os.environ.get("FLYDSL_OVERLAY_CHUNKS", "0") == "1") # ab_half_split: repurpose the (under "vgpr") idle scale waves 2,3 as the # second halves of A/B, so all 4 waves share the A/B TDM (wave0=A0, wave1=B0, # wave2=A1, wave3=B1). Measured wall-neutral. use_ab_half_split = scale_load_path == "vgpr_ab_split" - # ab_half_fence: fence only {A0,B0}, compute the (m0-3,n0-3) quadrant while - # {A1,B1} TDM is still in flight, then fence {A1,B1} and compute the rest. - # Requires the half-split. OPT-IN: FLYDSL_AB_HALF_FENCE=1. - use_ab_half_fence = use_ab_half_split and (os.environ.get("FLYDSL_AB_HALF_FENCE", "0") == "1") + # scale-disabled: perf-ceiling probe. Feed a constant E8M0=1.0 to the scaled + # WMMA instead of loading scale -> removes the scale buffer_loads (issue slots) + # AND the scale VGPR ring (register pressure), while keeping the scaled-WMMA op + # identical. Numerically correct only when the real scale is all 1.0 (fill 0.1). + # OPT-IN: FLYDSL_SCALE_DISABLED=1. + use_scale_disabled = os.environ.get("FLYDSL_SCALE_DISABLED", "0") == "1" + # The buffer_load->VGPR scale ring is built only when scale is actually loaded. + _bvs_active = use_buffer_vgpr_scale and not use_scale_disabled if use_ref_segmented_lds_layout: # The A/B data pools are no longer packed into the same per-stage @@ -443,7 +443,6 @@ def _pick_compute_schedule_kind(): and (use_fp8_quadrant_schedule or use_fp8_deep_pipeline_schedule) and num_buffers == 4 and use_cluster - and not use_ab_half_fence ) if use_b_streaming_schedule: print( @@ -1434,6 +1433,8 @@ def load_b_pair(wn_pair, ks): ] def _load_a_scales(ks): + if const_expr(use_scale_disabled): + return [fx.Int32(0x7F7F7F7F) for _ in range_constexpr(wmma_m_rep)] if const_expr(use_buffer_vgpr_scale): if const_expr(pf_a_scales is not None): return pf_a_scales # prefetched (issued in the prior compute tile) @@ -1441,6 +1442,8 @@ def _load_a_scales(ks): return load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks) def _load_b_scales(ks): + if const_expr(use_scale_disabled): + return [fx.Int32(0x7F7F7F7F) for _ in range_constexpr(b_scale_load_rep)] if const_expr(use_buffer_vgpr_scale): if const_expr(pf_b_scales is not None): return pf_b_scales @@ -1454,8 +1457,6 @@ def emit_panel_2x2( b_pair, scale_pair, prefetch_after_first_row=None, - prefetch_after_first_row_wmma=None, - prefetch_after_second_row_wmma=None, ): a_scales, b_scales = scale_pair wm_base = wm_pair * _fp8_pair_wm @@ -1470,9 +1471,6 @@ def emit_panel_2x2( a_scales, b_scales, ) - # [overlay-chunks] one fragment load interleaved after each WMMA - if const_expr(prefetch_after_first_row_wmma is not None): - prefetch_after_first_row_wmma(wn_local) if const_expr(prefetch_after_first_row is not None): prefetch_after_first_row() for wn_local in range_constexpr(_fp8_pair_wn): @@ -1485,8 +1483,6 @@ def emit_panel_2x2( a_scales, b_scales, ) - if const_expr(prefetch_after_second_row_wmma is not None): - prefetch_after_second_row_wmma(wn_local) def emit_panel_2x2_row(wm_pair, wn_pair, row_local, a_pair, b_pair, scale_pair): a_scales, b_scales = scale_pair @@ -1520,286 +1516,68 @@ def emit_panel_2x2_row(wm_pair, wn_pair, row_local, a_pair, b_pair, scale_pair): else: a0 = load_a_pair(0, ks) b1 = load_b_pair(1, ks) - if const_expr(use_overlay_chunks): - # overlay-chunks: defer b2/b3/a1/a2/a3 and issue one fragment - # load after each WMMA for finer LDS-load pipelining. - a1 = [None for _ in range_constexpr(_fp8_pair_wm)] - b2 = [None for _ in range_constexpr(_fp8_pair_wn)] - b3 = [None for _ in range_constexpr(_fp8_pair_wn)] - a2 = [None for _ in range_constexpr(_fp8_pair_wm)] - a3 = [None for _ in range_constexpr(_fp8_pair_wm)] - - def _pf_a1_chunk(ci): - a1[ci] = load_a_frag(a_buf, a_bases[_fp8_pair_wm + ci], ks) - - def _pf_b2_chunk(ci): - b2[ci] = load_b_frag(b_buf, b_bases, 2 * _fp8_pair_wn + ci, ks) - - first_wait_keep = _two_pair_loads + 3 - if const_expr(ks == 0 and a0_prefetch is not None): - first_wait_keep += DS_LOADS_PER_A_FRAG * len(a0_prefetch) - rocdl.s_wait_dscnt(first_wait_keep) - emit_panel_2x2( - 0, - 0, - a0, - b0, - scale_pair, - prefetch_after_first_row_wmma=_pf_a1_chunk, - prefetch_after_second_row_wmma=_pf_b2_chunk, - ) - - if const_expr(ks == 0 and mid_compute_callback is not None): - rocdl.sched_barrier(0) - mid_compute_callback() - - def _pf_b3_chunk(ci): - b3[ci] = load_b_frag(b_buf, b_bases, 3 * _fp8_pair_wn + ci, ks) - - def _pf_a3_chunk(ci): - a3[ci] = load_a_frag(a_buf, a_bases[3 * _fp8_pair_wm + ci], ks) - - rocdl.s_wait_dscnt(_pair_loads + _fp8_pair_b_loads) - emit_panel_2x2(0, 1, a0, b1, scale_pair, prefetch_after_first_row_wmma=_pf_b3_chunk) - - rocdl.s_wait_dscnt(_fp8_pair_b_loads + 2) - emit_panel_2x2(1, 0, a1, b0, scale_pair, prefetch_after_first_row_wmma=_pf_a3_chunk) - - def _pf_a2_chunk(ci): - a2[ci] = load_a_frag(a_buf, a_bases[2 * _fp8_pair_wm + ci], ks) - - emit_panel_2x2(1, 1, a1, b1, scale_pair) - emit_panel_2x2(0, 2, a0, b2, scale_pair, prefetch_after_first_row_wmma=_pf_a2_chunk) - emit_panel_2x2_row(1, 2, 0, a1, b2, scale_pair) - rocdl.s_wait_dscnt(_pair_loads) - emit_panel_2x2(0, 3, a0, b3, scale_pair) - emit_panel_2x2_row(1, 2, 1, a1, b2, scale_pair) - emit_panel_2x2(1, 3, a1, b3, scale_pair) - emit_panel_2x2(2, 0, a2, b0, scale_pair) - if const_expr(is_last_ks and late_compute_callback is not None): - rocdl.sched_barrier(0) - late_compute_callback() - emit_panel_2x2(2, 1, a2, b1, scale_pair) - rocdl.s_wait_dscnt(0) - emit_panel_2x2(3, 0, a3, b0, scale_pair) - emit_panel_2x2(3, 1, a3, b1, scale_pair) - if const_expr(is_last_ks and emit_filler is not None): - rocdl.sched_barrier(0) - emit_filler() - emit_panel_2x2(2, 2, a2, b2, scale_pair) - emit_panel_2x2(2, 3, a2, b3, scale_pair) - emit_panel_2x2(3, 2, a3, b2, scale_pair) - emit_panel_2x2(3, 3, a3, b3, scale_pair) - else: - b2 = load_b_pair(2, ks) - - a1_box = [None] - b3_box = [None] - a2_box = [None] - a3_box = [None] - - def _prefetch_a1(): - a1_box[0] = load_a_pair(1, ks) - - first_wait_keep = _two_pair_loads + 3 - if const_expr(ks == 0 and a0_prefetch is not None): - first_wait_keep += DS_LOADS_PER_A_FRAG * len(a0_prefetch) - rocdl.s_wait_dscnt(first_wait_keep) - emit_panel_2x2(0, 0, a0, b0, scale_pair, prefetch_after_first_row=_prefetch_a1) - - if const_expr(ks == 0 and mid_compute_callback is not None): - rocdl.sched_barrier(0) - mid_compute_callback() - - def _prefetch_b3(): - b3_box[0] = load_b_pair(3, ks) - - def _prefetch_a3(): - a3_box[0] = load_a_pair(3, ks) - - rocdl.s_wait_dscnt(_pair_loads + _fp8_pair_b_loads) - emit_panel_2x2(0, 1, a0, b1, scale_pair, prefetch_after_first_row=_prefetch_b3) - - rocdl.s_wait_dscnt(_fp8_pair_b_loads + 2) - emit_panel_2x2(1, 0, a1_box[0], b0, scale_pair, prefetch_after_first_row=_prefetch_a3) - - def _prefetch_a2(): - a2_box[0] = load_a_pair(2, ks) - - emit_panel_2x2(1, 1, a1_box[0], b1, scale_pair) - - emit_panel_2x2(0, 2, a0, b2, scale_pair, prefetch_after_first_row=_prefetch_a2) - emit_panel_2x2_row(1, 2, 0, a1_box[0], b2, scale_pair) - emit_panel_2x2_row(1, 2, 1, a1_box[0], b2, scale_pair) - rocdl.s_wait_dscnt(_pair_loads) - emit_panel_2x2(0, 3, a0, b3_box[0], scale_pair) - emit_panel_2x2(1, 3, a1_box[0], b3_box[0], scale_pair) - - emit_panel_2x2(2, 0, a2_box[0], b0, scale_pair) - if const_expr(is_last_ks and late_compute_callback is not None): - rocdl.sched_barrier(0) - late_compute_callback() - emit_panel_2x2(2, 1, a2_box[0], b1, scale_pair) - - rocdl.s_wait_dscnt(0) - emit_panel_2x2(3, 0, a3_box[0], b0, scale_pair) - emit_panel_2x2(3, 1, a3_box[0], b1, scale_pair) - - if const_expr(is_last_ks and emit_filler is not None): - rocdl.sched_barrier(0) - emit_filler() - - emit_panel_2x2(2, 2, a2_box[0], b2, scale_pair) - emit_panel_2x2(2, 3, a2_box[0], b3_box[0], scale_pair) - emit_panel_2x2(3, 2, a3_box[0], b2, scale_pair) - emit_panel_2x2(3, 3, a3_box[0], b3_box[0], scale_pair) - - return current_accs - - def compute_tile_fp8_half_fence( - accs_in, - lds_a, - lds_b, - lds_as, - lds_bs, - mid_compute_callback=None, - quadrant_fence_callback=None, - scale_k_base=None, - pf_a_scales=None, - pf_b_scales=None, - ): - """Two-phase compute for the A/B half-fence schedule. - - Phase 1 reads the first A/B half (A0/B0 = frags 0..3) and computes the - (0..3, 0..3) quadrant; phase 2 reads A1/B1 and computes the remaining - 12 panels. ``quadrant_fence_callback`` (the {A1,B1} fence) runs between - the phases, at ks==0 only. No second-half ds_load is issued before that - fence -- the invariant the deep-pipeline schedule can't honour. - """ - current_accs = list(accs_in) - a_buf, a_bases = _precompute_a_lane_bases(lds_a) - b_buf, b_bases = _precompute_b_lane_bases(lds_b) - as_buf, as_bases = _precompute_scale_lane_bases(lds_as, warp_m_base, wmma_m_rep, interleaved_scale_cols_a) - bs_buf, bs_bases = _precompute_scale_lane_bases( - lds_bs, warp_n_base, b_scale_load_rep, interleaved_scale_cols_b - ) - - def load_a_pair(wm_pair, ks): - wm_base = wm_pair * _fp8_pair_wm - return [ - load_a_frag(a_buf, a_bases[wm_base + wm_local], ks) for wm_local in range_constexpr(_fp8_pair_wm) - ] - - def load_b_pair(wn_pair, ks): - wn_base = wn_pair * _fp8_pair_wn - return [ - load_b_frag(b_buf, b_bases, wn_base + wn_local, ks) for wn_local in range_constexpr(_fp8_pair_wn) - ] - - def _load_a_scales(ks): - if const_expr(pf_a_scales is not None): - return pf_a_scales - if const_expr(use_buffer_vgpr_scale): - return _bvs_load_scales(_bvs_a_rsrc, _bvs_mb_a, wmma_m_rep, scale_k_base) - return load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks) - - def _load_b_scales(ks): - if const_expr(pf_b_scales is not None): - return pf_b_scales - if const_expr(use_buffer_vgpr_scale): - return _bvs_load_scales(_bvs_b_rsrc, _bvs_mb_b, b_scale_load_rep, scale_k_base) - return load_scale_b128(bs_buf, bs_bases[0], b_scale_load_rep, ks) - - def emit_panel(wm_pair, wn_pair, a_pair, b_pair, scale_pair, prefetch_after_first_row=None): - a_scales, b_scales = scale_pair - wm_base = wm_pair * _fp8_pair_wm - wn_base = wn_pair * _fp8_pair_wn - for wn_local in range_constexpr(_fp8_pair_wn): - _emit_wmma( - current_accs, - wm_base, - wn_base + wn_local, - a_pair[0], - b_pair[wn_local], - a_scales, - b_scales, - ) - if const_expr(prefetch_after_first_row is not None): - prefetch_after_first_row() - for wn_local in range_constexpr(_fp8_pair_wn): - _emit_wmma( - current_accs, - wm_base + 1, - wn_base + wn_local, - a_pair[1], - b_pair[wn_local], - a_scales, - b_scales, - ) - - _pair_loads = _fp8_pair_a_loads - _two_pair_loads = _fp8_pair_a_loads + _fp8_pair_b_loads - - for ks in range_constexpr(k_wmma_steps): - scale_pair = (_load_a_scales(ks), _load_b_scales(ks)) + b2 = load_b_pair(2, ks) - # ---- Phase 1: first half {A0,B0} -> quadrant (m0-3, n0-3) ---- - # Issue a0,b0,b1 up front; prefetch a1 mid-panel(0,0). NO - # second-half (a2/a3/b2/b3) load may be issued before the fence. - b0 = load_b_pair(0, ks) - a0 = load_a_pair(0, ks) - b1 = load_b_pair(1, ks) a1_box = [None] + b3_box = [None] + a2_box = [None] + a3_box = [None] def _prefetch_a1(): a1_box[0] = load_a_pair(1, ks) - rocdl.s_wait_dscnt(_two_pair_loads + 3) - emit_panel(0, 0, a0, b0, scale_pair, prefetch_after_first_row=_prefetch_a1) + first_wait_keep = _two_pair_loads + 3 + if const_expr(ks == 0 and a0_prefetch is not None): + first_wait_keep += DS_LOADS_PER_A_FRAG * len(a0_prefetch) + rocdl.s_wait_dscnt(first_wait_keep) + emit_panel_2x2(0, 0, a0, b0, scale_pair, prefetch_after_first_row=_prefetch_a1) + if const_expr(ks == 0 and mid_compute_callback is not None): rocdl.sched_barrier(0) mid_compute_callback() - rocdl.s_wait_dscnt(_pair_loads + _fp8_pair_b_loads) - emit_panel(0, 1, a0, b1, scale_pair) - rocdl.s_wait_dscnt(_fp8_pair_b_loads) - emit_panel(1, 0, a1_box[0], b0, scale_pair) - emit_panel(1, 1, a1_box[0], b1, scale_pair) - rocdl.s_wait_dscnt(0) - # ---- {A1,B1} half-fence (once, before any second-half read) ---- - if const_expr(ks == 0 and quadrant_fence_callback is not None): - rocdl.sched_barrier(0) - quadrant_fence_callback() + def _prefetch_b3(): + b3_box[0] = load_b_pair(3, ks) - # ---- Phase 2: second half {A1,B1} -> remaining 12 panels ---- - a1 = a1_box[0] - b2 = load_b_pair(2, ks) - b3 = load_b_pair(3, ks) - a2_box = [None] - a3_box = [None] + def _prefetch_a3(): + a3_box[0] = load_a_pair(3, ks) + + rocdl.s_wait_dscnt(_pair_loads + _fp8_pair_b_loads) + emit_panel_2x2(0, 1, a0, b1, scale_pair, prefetch_after_first_row=_prefetch_b3) + + rocdl.s_wait_dscnt(_fp8_pair_b_loads + 2) + emit_panel_2x2(1, 0, a1_box[0], b0, scale_pair, prefetch_after_first_row=_prefetch_a3) def _prefetch_a2(): a2_box[0] = load_a_pair(2, ks) - def _prefetch_a3(): - a3_box[0] = load_a_pair(3, ks) + emit_panel_2x2(1, 1, a1_box[0], b1, scale_pair) + + emit_panel_2x2(0, 2, a0, b2, scale_pair, prefetch_after_first_row=_prefetch_a2) + emit_panel_2x2_row(1, 2, 0, a1_box[0], b2, scale_pair) + emit_panel_2x2_row(1, 2, 1, a1_box[0], b2, scale_pair) + rocdl.s_wait_dscnt(_pair_loads) + emit_panel_2x2(0, 3, a0, b3_box[0], scale_pair) + emit_panel_2x2(1, 3, a1_box[0], b3_box[0], scale_pair) + + emit_panel_2x2(2, 0, a2_box[0], b0, scale_pair) + if const_expr(is_last_ks and late_compute_callback is not None): + rocdl.sched_barrier(0) + late_compute_callback() + emit_panel_2x2(2, 1, a2_box[0], b1, scale_pair) - # panels reading only a0/a1 + b2/b3 first, prefetch a2,a3 meanwhile - rocdl.s_wait_dscnt(_two_pair_loads) - emit_panel(0, 2, a0, b2, scale_pair, prefetch_after_first_row=_prefetch_a2) - emit_panel(0, 3, a0, b3, scale_pair, prefetch_after_first_row=_prefetch_a3) - emit_panel(1, 2, a1, b2, scale_pair) - emit_panel(1, 3, a1, b3, scale_pair) - rocdl.s_wait_dscnt(_fp8_pair_a_loads) - emit_panel(2, 0, a2_box[0], b0, scale_pair) - emit_panel(2, 1, a2_box[0], b1, scale_pair) rocdl.s_wait_dscnt(0) - emit_panel(3, 0, a3_box[0], b0, scale_pair) - emit_panel(3, 1, a3_box[0], b1, scale_pair) - emit_panel(2, 2, a2_box[0], b2, scale_pair) - emit_panel(2, 3, a2_box[0], b3, scale_pair) - emit_panel(3, 2, a3_box[0], b2, scale_pair) - emit_panel(3, 3, a3_box[0], b3, scale_pair) + emit_panel_2x2(3, 0, a3_box[0], b0, scale_pair) + emit_panel_2x2(3, 1, a3_box[0], b1, scale_pair) + + if const_expr(is_last_ks and emit_filler is not None): + rocdl.sched_barrier(0) + emit_filler() + + emit_panel_2x2(2, 2, a2_box[0], b2, scale_pair) + emit_panel_2x2(2, 3, a2_box[0], b3_box[0], scale_pair) + emit_panel_2x2(3, 2, a3_box[0], b2, scale_pair) + emit_panel_2x2(3, 3, a3_box[0], b3_box[0], scale_pair) return current_accs @@ -2372,14 +2150,6 @@ def _pipeline_fence(outstanding=0): def _pipeline_fence_signal(outstanding=0): pipeline_fence_signal(outstanding=outstanding, use_cluster=use_cluster) - def _pipeline_fence_partial(outstanding=0, first_half=True): - pipeline_fence_partial( - outstanding=outstanding, - first_half=first_half, - n_loader_waves=4, - use_cluster=use_cluster, - ) - if const_expr(wave_specialized_tdm): def _issue_active_tdm(load_stage, addr_box, k_prefetch=None): @@ -2414,7 +2184,7 @@ def _issue_active_tdm(load_stage, addr_box, k_prefetch=None): addr_lo_as = addr_lo_as + adv_as_i32 addr_lo_bs = addr_lo_bs + adv_bs_i32 - if const_expr(use_buffer_vgpr_scale): + if const_expr(_bvs_active): # Prologue: prefetch the first _bvs_D K-tiles (global->VGPR). Carried as # FLAT lists of i32 (list-of-tuples can't be loop-carried). _bvs_pf = [_bvs_prefetch(split_k_base + arith.index(_d * tile_k)) for _d in range(_bvs_D)] @@ -2433,13 +2203,13 @@ def _issue_active_tdm(load_stage, addr_box, k_prefetch=None): if const_expr(loop_iters > 0): if const_expr(wave_specialized_tdm): init_args = list(accs) + [active_addr_lo] - if const_expr(use_buffer_vgpr_scale): + if const_expr(_bvs_active): init_args = init_args + _bvs_ra + _bvs_rb for loop_iter, state in range(0, loop_iters, 1, init=init_args): accs_in = list(state[:n_accs]) cur_addr_lo = state[n_accs] - if const_expr(use_buffer_vgpr_scale): + if const_expr(_bvs_active): _ra0 = n_accs + 1 _ring_a = list(state[_ra0 : _ra0 + _bvs_D * wmma_m_rep]) _rb0 = _ra0 + _bvs_D * wmma_m_rep @@ -2461,91 +2231,55 @@ def _mid_tdm_ws( ): _issue_active_tdm(_ls, _ab, k_prefetch=_k_off) - if const_expr(use_ab_half_fence): - # Fence {A0,B0} only; the {A1,B1} fence - # fires mid-compute so its TDM overlaps the quadrant WMMAs. - if const_expr(use_buffer_vgpr_scale): - _cur_a = _ring_a[:wmma_m_rep] - _cur_b = _ring_b[:b_scale_load_rep] - _next_kb = ( - split_k_base - + loop_iter * arith.index(num_buffers * tile_k) - + arith.index((buf_idx + _bvs_D) * tile_k) - ) - _na, _nb2 = _bvs_prefetch(_next_kb) - _ring_a = _ring_a[wmma_m_rep:] + list(_na) - _ring_b = _ring_b[b_scale_load_rep:] + list(_nb2) - else: - _cur_a = None - _cur_b = None - - _pipeline_fence_partial(outstanding=_fence_outstanding, first_half=True) - - def _quad_fence(): - _pipeline_fence_partial(outstanding=_fence_outstanding, first_half=False) - - accs_in = compute_tile_fp8_half_fence( - accs_in, - stages_a_idx[buf_idx], - stages_b_idx[buf_idx], - stages_as_idx[buf_idx], - stages_bs_idx[buf_idx], - mid_compute_callback=_mid_tdm_ws, - quadrant_fence_callback=_quad_fence, - pf_a_scales=_cur_a, - pf_b_scales=_cur_b, - ) - else: - if const_expr(not use_ws_tdm_split_signal_overlap): + if const_expr(not use_ws_tdm_split_signal_overlap): + _pipeline_fence_signal(outstanding=_fence_outstanding) + pipeline_fence_wait(use_cluster=use_cluster) + + _late_tdm_ws_fence_signal = None + if const_expr(use_ws_tdm_split_signal_overlap): + + def _late_tdm_ws_split_signal(): _pipeline_fence_signal(outstanding=_fence_outstanding) - pipeline_fence_wait(use_cluster=use_cluster) - - _late_tdm_ws_fence_signal = None - if const_expr(use_ws_tdm_split_signal_overlap): - - def _late_tdm_ws_split_signal(): - _pipeline_fence_signal(outstanding=_fence_outstanding) - - _late_tdm_ws_fence_signal = _late_tdm_ws_split_signal - - a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[buf_idx]) - rocdl.sched_barrier(0) - # Consume scale prefetched _bvs_D K-tiles ago; issue the - # K-tile +_bvs_D prefetch now (overlaps this tile's WMMAs). - # NOTE: must stay AFTER the fence (matches the original - # ordering); issuing the scale buffer_loads before the - # cluster barrier hangs the vgpr path. - if const_expr(use_buffer_vgpr_scale): - _cur_a = _ring_a[:wmma_m_rep] - _cur_b = _ring_b[:b_scale_load_rep] - _next_kb = ( - split_k_base - + loop_iter * arith.index(num_buffers * tile_k) - + arith.index((buf_idx + _bvs_D) * tile_k) - ) - _na, _nb2 = _bvs_prefetch(_next_kb) - _ring_a = _ring_a[wmma_m_rep:] + list(_na) - _ring_b = _ring_b[b_scale_load_rep:] + list(_nb2) - else: - _cur_a = None - _cur_b = None - - accs_in = compute_tile_scheduled( - accs_in, - stages_a_idx[buf_idx], - stages_b_idx[buf_idx], - stages_as_idx[buf_idx], - stages_bs_idx[buf_idx], - mid_compute_callback=_mid_tdm_ws, - late_compute_callback=_late_tdm_ws_fence_signal, - a0_prefetch=a0_prefetch, - pf_a_scales=_cur_a, - pf_b_scales=_cur_b, + + _late_tdm_ws_fence_signal = _late_tdm_ws_split_signal + + a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[buf_idx]) + rocdl.sched_barrier(0) + # Consume scale prefetched _bvs_D K-tiles ago; issue the + # K-tile +_bvs_D prefetch now (overlaps this tile's WMMAs). + # NOTE: must stay AFTER the fence; issuing the scale + # buffer_loads before the cluster barrier hangs the vgpr path. + if const_expr(_bvs_active): + _cur_a = _ring_a[:wmma_m_rep] + _cur_b = _ring_b[:b_scale_load_rep] + _next_kb = ( + split_k_base + + loop_iter * arith.index(num_buffers * tile_k) + + arith.index((buf_idx + _bvs_D) * tile_k) ) + _na, _nb2 = _bvs_prefetch(_next_kb) + _ring_a = _ring_a[wmma_m_rep:] + list(_na) + _ring_b = _ring_b[b_scale_load_rep:] + list(_nb2) + else: + _cur_a = None + _cur_b = None + + accs_in = compute_tile_scheduled( + accs_in, + stages_a_idx[buf_idx], + stages_b_idx[buf_idx], + stages_as_idx[buf_idx], + stages_bs_idx[buf_idx], + mid_compute_callback=_mid_tdm_ws, + late_compute_callback=_late_tdm_ws_fence_signal, + a0_prefetch=a0_prefetch, + pf_a_scales=_cur_a, + pf_b_scales=_cur_b, + ) cur_addr_lo = addr_box[0] hot_loop_scheduler_scheduled() - if const_expr(use_buffer_vgpr_scale): + if const_expr(_bvs_active): _bvs_yield = _ring_a + _ring_b else: _bvs_yield = [] @@ -2635,7 +2369,7 @@ def _mid_tdm_nws( _bvs_tail_kt = [loop_iters * num_buffers] def _bvs_tail_kb(): - if const_expr(not use_buffer_vgpr_scale): + if const_expr(not _bvs_active): return None kb = split_k_base + arith.index(_bvs_tail_kt[0] * tile_k) _bvs_tail_kt[0] += 1 From 42b8e9f7ccd25759d391e141aa6c02f65687043c Mon Sep 17 00:00:00 2001 From: aoli26 Date: Mon, 1 Jun 2026 12:01:38 +0000 Subject: [PATCH 13/19] [experiment] FP8 deep-pipeline hot_loop_sched_mode (iglp / manual sched_group_barrier) Opt-in hot-loop scheduling for the FP8 deep-pipeline (default unchanged): - hot_loop_sched_mode = default | iglp | manual - iglp: clean region (internal sched_barrier(0) suppressed) + iglp_opt(0) (LLVM MFMASmallGemmOpt) -> SP3-style DS/MFMA interleave. - manual: same clean region + hand-emitted sched_group_barrier template, granularity hot_loop_manual_ds:hot_loop_manual_mfma (default 8:8). default mode byte-identical to baseline. All modes cosine=1.0; fp8 wave_spec tests pass. Net win over default unconfirmed (AM underestimates ds latency) - to be ranked on silicon ATT. --- kernels/gemm_fp8fp4_gfx1250.py | 77 ++++++++++++++++++++++++++++++---- 1 file changed, 68 insertions(+), 9 deletions(-) diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index e04a1dc25..ac727adcc 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -70,6 +70,9 @@ def compile_mxscale_gemm( b_streaming: bool = False, scale_load_path: str = "tdm", fp8_schedule: str = "auto", + hot_loop_sched_mode: str = "default", + hot_loop_manual_ds: int = 8, + hot_loop_manual_mfma: int = 8, ): """Compile an MXFP4 or MXFP8 GEMM kernel with TDM async copy. @@ -438,6 +441,28 @@ def _pick_compute_schedule_kind(): use_fp8_quadrant_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_FP8_QUADRANT use_fp8_deep_pipeline_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE use_b_streaming_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_B_STREAMING + # hot_loop_sched_mode: deep-pipeline hot-loop scheduling. + # "default" — original sched_barrier(0) fences + coarse hot_loop_scheduler. + # "iglp" — clean region (internal sched_barrier(0) suppressed), LLVM + # MFMASmallGemmOpt via iglp_opt(0). + # "manual" — clean region + hand-emitted sched_group_barrier template, + # granularity hot_loop_manual_ds:hot_loop_manual_mfma (default 8:8). + # Latency-hiding: a small DS group degenerates to "issue few ds -> wait(0) + # full-drain -> 1-2 WMMA" (exposes ds latency); larger DS+MFMA keeps ds in + # flight with partial drains + long WMMA bursts that hide it. Net win over + # "default" unconfirmed (AM underestimates ds latency) -> rank on silicon ATT. + _hl_modes = ("default", "iglp", "manual") + if hot_loop_sched_mode not in _hl_modes: + raise ValueError(f"hot_loop_sched_mode must be one of {_hl_modes}, got {hot_loop_sched_mode!r}") + _hl_iglp = hot_loop_sched_mode == "iglp" + _hl_manual = hot_loop_sched_mode == "manual" + _hl_clean_region = _hl_iglp or _hl_manual + if _hl_clean_region and not use_fp8_deep_pipeline_schedule: + raise ValueError( + f"hot_loop_sched_mode={hot_loop_sched_mode!r} currently requires the FP8 deep-pipeline schedule" + ) + if _hl_manual and (hot_loop_manual_ds < 1 or hot_loop_manual_mfma < 1): + raise ValueError("hot_loop_manual_ds / hot_loop_manual_mfma must be >= 1") use_ws_tdm_split_signal_overlap = ( wave_specialized_tdm and (use_fp8_quadrant_schedule or use_fp8_deep_pipeline_schedule) @@ -496,6 +521,20 @@ def kernel_mxscale_gemm( # Enable back-to-back WMMA issue (SCHED_MODE bit[4] = DISABLE_VALU_STALL) rocdl.disable_xdl_arb_stall() + # One-shot per clean scheduling region (iglp/manual): drive the LLVM + # PipelineSolver. iglp_opt(0) uses MFMASmallGemmOpt; manual emits the + # equivalent SchedGroup template by hand with repo-standard + # sched_group_barrier (tunable DS:MFMA). Over-provisioning is harmless. + _manual_pipeline_reps = num_buffers * n_accs + + def _emit_pipeline_region_hint(): + if const_expr(_hl_iglp): + rocdl.iglp_opt(0) + elif const_expr(_hl_manual): + for _grp in range_constexpr(_manual_pipeline_reps): + rocdl.sched_group_barrier(rocdl.mask_dsrd, hot_loop_manual_ds, 0) + rocdl.sched_group_barrier(rocdl.mask_mfma, hot_loop_manual_mfma, 0) + if const_expr(inst_prefetch): if rocdl.wave_id() == fx.Int32(0): rocdl.s_prefetch_inst_burst(num_pages=4) @@ -1533,7 +1572,8 @@ def _prefetch_a1(): emit_panel_2x2(0, 0, a0, b0, scale_pair, prefetch_after_first_row=_prefetch_a1) if const_expr(ks == 0 and mid_compute_callback is not None): - rocdl.sched_barrier(0) + if const_expr(not _hl_clean_region): + rocdl.sched_barrier(0) mid_compute_callback() def _prefetch_b3(): @@ -1562,7 +1602,8 @@ def _prefetch_a2(): emit_panel_2x2(2, 0, a2_box[0], b0, scale_pair) if const_expr(is_last_ks and late_compute_callback is not None): - rocdl.sched_barrier(0) + if const_expr(not _hl_clean_region): + rocdl.sched_barrier(0) late_compute_callback() emit_panel_2x2(2, 1, a2_box[0], b1, scale_pair) @@ -1571,7 +1612,8 @@ def _prefetch_a2(): emit_panel_2x2(3, 1, a3_box[0], b1, scale_pair) if const_expr(is_last_ks and emit_filler is not None): - rocdl.sched_barrier(0) + if const_expr(not _hl_clean_region): + rocdl.sched_barrier(0) emit_filler() emit_panel_2x2(2, 2, a2_box[0], b2, scale_pair) @@ -2244,7 +2286,11 @@ def _late_tdm_ws_split_signal(): _late_tdm_ws_fence_signal = _late_tdm_ws_split_signal a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[buf_idx]) - rocdl.sched_barrier(0) + if const_expr(_hl_clean_region): + if const_expr(buf_idx == 0): + _emit_pipeline_region_hint() + else: + rocdl.sched_barrier(0) # Consume scale prefetched _bvs_D K-tiles ago; issue the # K-tile +_bvs_D prefetch now (overlaps this tile's WMMAs). # NOTE: must stay AFTER the fence; issuing the scale @@ -2277,7 +2323,8 @@ def _late_tdm_ws_split_signal(): pf_b_scales=_cur_b, ) cur_addr_lo = addr_box[0] - hot_loop_scheduler_scheduled() + if const_expr(not _hl_clean_region): + hot_loop_scheduler_scheduled() if const_expr(_bvs_active): _bvs_yield = _ring_a + _ring_b @@ -2332,7 +2379,11 @@ def _mid_tdm_nws( _l2_prefetch(_k_off) a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[buf_idx]) - rocdl.sched_barrier(0) + if const_expr(_hl_clean_region): + if const_expr(buf_idx == 0): + _emit_pipeline_region_hint() + else: + rocdl.sched_barrier(0) accs_in = compute_tile_scheduled( accs_in, stages_a_idx[buf_idx], @@ -2346,7 +2397,8 @@ def _mid_tdm_nws( cur_lo_b = addr_boxes[1][0] cur_lo_as = addr_boxes[2][0] cur_lo_bs = addr_boxes[3][0] - hot_loop_scheduler_scheduled() + if const_expr(not _hl_clean_region): + hot_loop_scheduler_scheduled() results = yield list(accs_in) + [cur_lo_a, cur_lo_b, cur_lo_as, cur_lo_bs] @@ -2444,7 +2496,10 @@ def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): _tail_mid_cb = _tail_mid_nws a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[_compute_stage]) - rocdl.sched_barrier(0) + if const_expr(_hl_clean_region): + _emit_pipeline_region_hint() + else: + rocdl.sched_barrier(0) accs = compute_tile_scheduled( accs, stages_a_idx[_compute_stage], @@ -2465,7 +2520,8 @@ def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): addr_lo_as = _tail_ab[2][0] addr_lo_bs = _tail_ab[3][0] - hot_loop_scheduler_scheduled() + if const_expr(not _hl_clean_region): + hot_loop_scheduler_scheduled() accs = finalize_acc_layout(accs) @@ -2511,6 +2567,9 @@ def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): b_streaming, scale_load_path, fp8_schedule, + hot_loop_sched_mode, + hot_loop_manual_ds, + hot_loop_manual_mfma, ) @flyc.jit From 4973cd9c46316135e524c26f33ac44b4436655e5 Mon Sep 17 00:00:00 2001 From: aoli26 Date: Tue, 2 Jun 2026 02:47:50 +0000 Subject: [PATCH 14/19] Add FLYDSL_SCALE_SKIP_TDM scale-TDM-isolation control probe Skip only the scale TDM vmem->LDS load path: drop the scale loader waves (2,3) from the active TDM the same way the vgpr path does, and pre-fill the scale LDS stages once with a constant E8M0=1.0 (0x7F) byte. The downstream scale LDS read path and the scaled-WMMA op are left byte-for-byte unchanged, so the variant isolates the cost of scale TDM delivery alone (unlike FLYDSL_SCALE_DISABLED, which also removes the LDS reads). Gated on scale_load_path='tdm' + wave_specialized_tdm; mutually exclusive with FLYDSL_SCALE_DISABLED. Also adds test_mxfp8_hot_loop_sched_modes. --- kernels/gemm_fp8fp4_gfx1250.py | 94 +++++++++++++++++++++-- tests/kernels/test_gemm_fp8fp4_gfx1250.py | 81 +++++++++++++++++++ 2 files changed, 168 insertions(+), 7 deletions(-) diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index ac727adcc..91d914ba8 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -11,8 +11,9 @@ import flydsl.compiler as flyc import flydsl.expr as fx from flydsl._mlir import ir +from flydsl._mlir.dialects import scf from flydsl.compiler.kernel_function import CompilationContext -from flydsl.expr import arith, buffer_ops, const_expr, gpu, idx2crd, range_constexpr, rocdl, tdm_ops +from flydsl.expr import arith, buffer_ops, const_expr, gpu, idx2crd, range_constexpr, rocdl, tdm_ops, vector from flydsl.expr.rocdl import cluster from flydsl.expr.typing import T from flydsl.runtime.device import get_rocm_arch as get_hip_arch @@ -22,6 +23,7 @@ get_lds_memref, issue_tdm_loads, lds_load_b128_raw, + lds_store_b128, pipeline_fence, pipeline_fence_signal, pipeline_fence_wait, @@ -298,6 +300,30 @@ def _align_up(value: int, align: int) -> int: # identical. Numerically correct only when the real scale is all 1.0 (fill 0.1). # OPT-IN: FLYDSL_SCALE_DISABLED=1. use_scale_disabled = os.environ.get("FLYDSL_SCALE_DISABLED", "0") == "1" + # skip-scale-TDM: scale/TDM-isolation control probe. Keep the ENTIRE downstream + # scale path identical (LDS layout + ds_load scale reads + scaled-WMMA op), but + # never deliver scale via TDM. Instead the scale LDS stages are pre-filled once + # with a fixed E8M0 byte before the hot loop. This isolates the cost of the + # scale TDM vmem->LDS traffic alone (vs FLYDSL_SCALE_DISABLED, which also removes + # the LDS reads). Numerically correct only when the real scale is all 1.0 + # (fill 0.1). OPT-IN: FLYDSL_SCALE_SKIP_TDM=1. Only valid on the TDM scale path + # with wave-specialized TDM (scale lives on loader waves 2,3, which are dropped). + use_skip_scale_tdm = os.environ.get("FLYDSL_SCALE_SKIP_TDM", "0") == "1" + if use_skip_scale_tdm: + if scale_load_path != "tdm": + raise ValueError( + "FLYDSL_SCALE_SKIP_TDM=1 only applies to scale_load_path='tdm', " + f"got {scale_load_path!r}" + ) + if not wave_specialized_tdm: + raise ValueError( + "FLYDSL_SCALE_SKIP_TDM=1 requires wave_specialized_tdm=True " + "(scale is carried on dedicated loader waves 2,3)" + ) + if use_scale_disabled: + raise ValueError( + "FLYDSL_SCALE_SKIP_TDM and FLYDSL_SCALE_DISABLED are mutually exclusive probes" + ) # The buffer_load->VGPR scale ring is built only when scale is actually loaded. _bvs_active = use_buffer_vgpr_scale and not use_scale_disabled @@ -463,6 +489,26 @@ def _pick_compute_schedule_kind(): ) if _hl_manual and (hot_loop_manual_ds < 1 or hot_loop_manual_mfma < 1): raise ValueError("hot_loop_manual_ds / hot_loop_manual_mfma must be >= 1") + # Deep-pipeline explicit s_wait_dscnt control (clean-region iglp/manual only). + # The SP3 prototype (poc_kl mxfp8fp4gemm) drains dscnt ~once per K-iteration; + # FlyDSL's compute_tile_fp8_deep_pipeline hardcodes 5 explicit s_wait_dscnt per + # ks, which double as scheduling barriers that pin each ds_load right before its + # consuming WMMA. Those explicit drains prevent the MI scheduler from hoisting + # the next ks's ds_load prefetch to overlap the current ks's WMMA (the + # prototype's "dW dW" future-buffer interleave). In a clean region the LLVM + # auto-waitcnt pass re-derives correct dscnt waits after reorder, so dropping + # some/all explicit drains is correctness-safe and lets IGLP+auto-waitcnt + # approach the prototype's single-drain structure. + # keep - emit all explicit dscnt drains (default; byte-identical to today) + # single - one full drain (s_wait_dscnt 0) at the top of each ks, drop the + # other 4 (mirrors the prototype's 1x s_wait_dscnt 0x0 per K-iter) + # auto - drop every explicit dscnt drain; auto-waitcnt owns them entirely + _hl_dscnt_mode = os.environ.get("FLYDSL_HL_DSCNT_MODE", "keep") + if _hl_dscnt_mode not in ("keep", "single", "auto"): + raise ValueError(f"FLYDSL_HL_DSCNT_MODE must be keep/single/auto, got {_hl_dscnt_mode!r}") + if _hl_dscnt_mode != "keep" and not _hl_clean_region: + raise ValueError("FLYDSL_HL_DSCNT_MODE!=keep requires hot_loop_sched_mode in {iglp,manual}") + _hl_keep_mid_dscnt = _hl_dscnt_mode == "keep" use_ws_tdm_split_signal_overlap = ( wave_specialized_tdm and (use_fp8_quadrant_schedule or use_fp8_deep_pipeline_schedule) @@ -1568,7 +1614,11 @@ def _prefetch_a1(): first_wait_keep = _two_pair_loads + 3 if const_expr(ks == 0 and a0_prefetch is not None): first_wait_keep += DS_LOADS_PER_A_FRAG * len(a0_prefetch) - rocdl.s_wait_dscnt(first_wait_keep) + if const_expr(_hl_dscnt_mode == "keep"): + rocdl.s_wait_dscnt(first_wait_keep) + elif const_expr(_hl_dscnt_mode == "single"): + rocdl.s_wait_dscnt(0) # prototype-style: one full drain per ks + # auto: emit nothing here; auto-waitcnt re-derives all dscnt waits emit_panel_2x2(0, 0, a0, b0, scale_pair, prefetch_after_first_row=_prefetch_a1) if const_expr(ks == 0 and mid_compute_callback is not None): @@ -1582,10 +1632,12 @@ def _prefetch_b3(): def _prefetch_a3(): a3_box[0] = load_a_pair(3, ks) - rocdl.s_wait_dscnt(_pair_loads + _fp8_pair_b_loads) + if const_expr(_hl_keep_mid_dscnt): + rocdl.s_wait_dscnt(_pair_loads + _fp8_pair_b_loads) emit_panel_2x2(0, 1, a0, b1, scale_pair, prefetch_after_first_row=_prefetch_b3) - rocdl.s_wait_dscnt(_fp8_pair_b_loads + 2) + if const_expr(_hl_keep_mid_dscnt): + rocdl.s_wait_dscnt(_fp8_pair_b_loads + 2) emit_panel_2x2(1, 0, a1_box[0], b0, scale_pair, prefetch_after_first_row=_prefetch_a3) def _prefetch_a2(): @@ -1596,7 +1648,8 @@ def _prefetch_a2(): emit_panel_2x2(0, 2, a0, b2, scale_pair, prefetch_after_first_row=_prefetch_a2) emit_panel_2x2_row(1, 2, 0, a1_box[0], b2, scale_pair) emit_panel_2x2_row(1, 2, 1, a1_box[0], b2, scale_pair) - rocdl.s_wait_dscnt(_pair_loads) + if const_expr(_hl_keep_mid_dscnt): + rocdl.s_wait_dscnt(_pair_loads) emit_panel_2x2(0, 3, a0, b3_box[0], scale_pair) emit_panel_2x2(1, 3, a1_box[0], b3_box[0], scale_pair) @@ -1607,7 +1660,8 @@ def _prefetch_a2(): late_compute_callback() emit_panel_2x2(2, 1, a2_box[0], b1, scale_pair) - rocdl.s_wait_dscnt(0) + if const_expr(_hl_keep_mid_dscnt): + rocdl.s_wait_dscnt(0) emit_panel_2x2(3, 0, a3_box[0], b0, scale_pair) emit_panel_2x2(3, 1, a3_box[0], b1, scale_pair) @@ -2127,7 +2181,10 @@ def _pack_dg0(pred, lds_addr, addr_lo, addr_hi): if const_expr(wave_specialized_tdm): # With scale on the VGPR path, drop scale waves 2,3 from the active TDM # path -- unless ab-half-split repurposes them as the second A/B halves. - _active_wave_limit = 2 if (use_buffer_vgpr_scale and not use_ab_half_split) else 4 + # skip-scale-TDM drops the scale loader waves (2,3) from the active TDM + # path -- identical mechanism to the vgpr path -- so scale never DMAs. + _drop_scale_waves = (use_buffer_vgpr_scale and not use_ab_half_split) or use_skip_scale_tdm + _active_wave_limit = 2 if _drop_scale_waves else 4 active_pred_const = arith.select(tdm_wave_id < fx.Int32(_active_wave_limit), fx.Int32(1), fx.Int32(0)) def _select4(values): @@ -2233,6 +2290,29 @@ def _issue_active_tdm(load_stage, addr_box, k_prefetch=None): _bvs_ra = [_v for (_a, _b) in _bvs_pf for _v in _a] _bvs_rb = [_v for (_a, _b) in _bvs_pf for _v in _b] + if const_expr(use_skip_scale_tdm): + # Scale TDM is skipped (loader waves 2,3 dropped above). Pre-fill every + # scale LDS stage with a fixed E8M0 byte (0x7F = 1.0) so the unchanged + # downstream ds_load scale reads see a defined, constant value. The + # _pipeline_fence below issues the workgroup barrier that makes these + # writes visible before the first scale read. Threads cooperate: each + # lane writes a strided set of 16B chunks (one ds_store_b128 each). + _scale_fill = vector.full((4,), 0x7F7F7F7F, fx.Int32) + for i in range_constexpr(num_buffers): + for _smem, _sbytes in ( + (stages_as_mem[i], lds_a_scale_bytes), + (stages_bs_mem[i], lds_b_scale_bytes), + ): + _nchunks = _sbytes // 16 + _rounds = (_nchunks + block_threads - 1) // block_threads + for _r in range_constexpr(_rounds): + _chunk = tx + arith.index(_r * block_threads) + _pred = arith.cmpi(arith.CmpIPredicate.ult, _chunk, arith.index(_nchunks)) + _if = scf.IfOp(_pred) + with ir.InsertionPoint(_if.then_block): + lds_store_b128(_smem, _chunk * arith.index(8), _scale_fill) + scf.YieldOp([]) + _pipeline_fence(outstanding=TDM_LOADS_PER_STEP * (num_buffers - 2)) # Main loop — acc_mixed style: fence at top, TDM_load mid-compute. diff --git a/tests/kernels/test_gemm_fp8fp4_gfx1250.py b/tests/kernels/test_gemm_fp8fp4_gfx1250.py index 74539ff10..50f6da69b 100644 --- a/tests/kernels/test_gemm_fp8fp4_gfx1250.py +++ b/tests/kernels/test_gemm_fp8fp4_gfx1250.py @@ -331,6 +331,9 @@ def _run_mxscale_gemm_test( split_k=1, b_streaming=False, scale_load_path="tdm", + hot_loop_sched_mode="default", + hot_loop_manual_ds=8, + hot_loop_manual_mfma=8, return_launch_fn=False, ): """Unified test body for FP4 and FP8.""" @@ -447,6 +450,9 @@ def _run_mxscale_gemm_test( expert_sched_mode=expert_sched_mode, b_streaming=b_streaming, scale_load_path=scale_load_path, + hot_loop_sched_mode=hot_loop_sched_mode, + hot_loop_manual_ds=hot_loop_manual_ds, + hot_loop_manual_mfma=hot_loop_manual_mfma, ) # Keep 2D — dynamic_layout=True packs shape as i32; flattening overflows for M*K >= 2^31. @@ -824,6 +830,51 @@ def test_mxfp8_vgpr_scale_load(scale_load_path, cluster_m, cluster_n): ) +# IGLP hot-loop scheduling modes (FP8 deep-pipeline, scale=tdm). Covers the +# baseline 'default' plus the two clean-region drivers ('iglp' = iglp_opt(0), +# 'manual' = sched_group_barrier DS:MFMA template) at a few granularities in one +# test. All modes must stay numerically correct; perf ranking is a separate +# silicon-ATT exercise. Scope mirrors the deep-pipeline target: +# 256x256x128, warps 2x2, num_buffers=4, wave_specialized_tdm, cluster (2,2). +@pytest.mark.parametrize( + "hot_loop_sched_mode, hot_loop_manual_ds, hot_loop_manual_mfma", + [ + ("default", 8, 8), + ("iglp", 8, 8), + ("manual", 8, 8), + ("manual", 4, 4), + ("manual", 2, 1), + ], +) +def test_mxfp8_hot_loop_sched_modes(hot_loop_sched_mode, hot_loop_manual_ds, hot_loop_manual_mfma): + if str(get_rocm_arch()) != "gfx1250": + pytest.skip("requires gfx1250") + if "FFMLITE_TOPOLOGY" in os.environ or "AM_TOPOLOGY" in os.environ: + pytest.skip("cluster multicast not supported on simulator") + _run_mxscale_gemm_test( + "fp8", + 512, + 512, + 512, + 256, + 256, + 128, + 2, + 2, + num_buffers=4, + use_tdm_store=True, + out_dtype="bf16", + l2_prefetch_distance=2, + wave_specialized_tdm=True, + cluster_m=2, + cluster_n=2, + scale_load_path="tdm", + hot_loop_sched_mode=hot_loop_sched_mode, + hot_loop_manual_ds=hot_loop_manual_ds, + hot_loop_manual_mfma=hot_loop_manual_mfma, + ) + + @pytest.mark.parametrize( "data_format, M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp, cluster_m, cluster_n", [ @@ -1252,6 +1303,9 @@ def _run_benchmark(args): atomic_barrier_enable=args.atomic_barrier_enable, b_streaming=args.b_streaming, scale_load_path=args.scale_load_path, + hot_loop_sched_mode=args.hot_loop_sched_mode, + hot_loop_manual_ds=args.hot_loop_manual_ds, + hot_loop_manual_mfma=args.hot_loop_manual_mfma, ) compiled_exe = flyc.compile( @@ -1435,6 +1489,9 @@ def _run_graph_verify(args): atomic_barrier_enable=args.atomic_barrier_enable, b_streaming=args.b_streaming, scale_load_path=args.scale_load_path, + hot_loop_sched_mode=args.hot_loop_sched_mode, + hot_loop_manual_ds=args.hot_loop_manual_ds, + hot_loop_manual_mfma=args.hot_loop_manual_mfma, ) c_flat = c_gpu.contiguous() @@ -1531,6 +1588,27 @@ def launch(): default="tdm", choices=["tdm", "vgpr", "vgpr_ab_split"], ) + parser.add_argument( + "--hot-loop-sched-mode", + type=str, + default="default", + choices=["default", "iglp", "manual"], + help="FP8 deep-pipeline hot-loop scheduling: 'default' (baseline, " + "byte-identical), 'iglp' (iglp_opt(0) MFMASmallGemmOpt), or 'manual' " + "(hand-emitted sched_group_barrier DS:MFMA template).", + ) + parser.add_argument( + "--hot-loop-manual-ds", + type=int, + default=8, + help="DS group size for --hot-loop-sched-mode manual.", + ) + parser.add_argument( + "--hot-loop-manual-mfma", + type=int, + default=8, + help="MFMA group size for --hot-loop-sched-mode manual.", + ) parser.add_argument("--disable-expert-sched-mode", dest="expert_sched_mode", action="store_false", default=True) parser.add_argument("--b-streaming", action="store_true", default=False) parser.add_argument( @@ -1603,4 +1681,7 @@ def launch(): expert_sched_mode=args.expert_sched_mode, b_streaming=args.b_streaming, scale_load_path=args.scale_load_path, + hot_loop_sched_mode=args.hot_loop_sched_mode, + hot_loop_manual_ds=args.hot_loop_manual_ds, + hot_loop_manual_mfma=args.hot_loop_manual_mfma, ) From 48acb876286c9df86e5e69f8dd5458919b2f337b Mon Sep 17 00:00:00 2001 From: aoli26 Date: Tue, 2 Jun 2026 07:08:08 +0000 Subject: [PATCH 15/19] add cluster early timeout and rm experimental codes --- kernels/gemm_fp8fp4_gfx1250.py | 182 ++-------------------- python/flydsl/expr/rocdl/tdm_ops.py | 6 +- tests/kernels/test_gemm_fp8fp4_gfx1250.py | 89 +---------- 3 files changed, 29 insertions(+), 248 deletions(-) diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index 91d914ba8..d71d85a76 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -11,9 +11,8 @@ import flydsl.compiler as flyc import flydsl.expr as fx from flydsl._mlir import ir -from flydsl._mlir.dialects import scf from flydsl.compiler.kernel_function import CompilationContext -from flydsl.expr import arith, buffer_ops, const_expr, gpu, idx2crd, range_constexpr, rocdl, tdm_ops, vector +from flydsl.expr import arith, buffer_ops, const_expr, gpu, idx2crd, range_constexpr, rocdl, tdm_ops from flydsl.expr.rocdl import cluster from flydsl.expr.typing import T from flydsl.runtime.device import get_rocm_arch as get_hip_arch @@ -23,7 +22,6 @@ get_lds_memref, issue_tdm_loads, lds_load_b128_raw, - lds_store_b128, pipeline_fence, pipeline_fence_signal, pipeline_fence_wait, @@ -72,9 +70,6 @@ def compile_mxscale_gemm( b_streaming: bool = False, scale_load_path: str = "tdm", fp8_schedule: str = "auto", - hot_loop_sched_mode: str = "default", - hot_loop_manual_ds: int = 8, - hot_loop_manual_mfma: int = 8, ): """Compile an MXFP4 or MXFP8 GEMM kernel with TDM async copy. @@ -294,38 +289,8 @@ def _align_up(value: int, align: int) -> int: # second halves of A/B, so all 4 waves share the A/B TDM (wave0=A0, wave1=B0, # wave2=A1, wave3=B1). Measured wall-neutral. use_ab_half_split = scale_load_path == "vgpr_ab_split" - # scale-disabled: perf-ceiling probe. Feed a constant E8M0=1.0 to the scaled - # WMMA instead of loading scale -> removes the scale buffer_loads (issue slots) - # AND the scale VGPR ring (register pressure), while keeping the scaled-WMMA op - # identical. Numerically correct only when the real scale is all 1.0 (fill 0.1). - # OPT-IN: FLYDSL_SCALE_DISABLED=1. - use_scale_disabled = os.environ.get("FLYDSL_SCALE_DISABLED", "0") == "1" - # skip-scale-TDM: scale/TDM-isolation control probe. Keep the ENTIRE downstream - # scale path identical (LDS layout + ds_load scale reads + scaled-WMMA op), but - # never deliver scale via TDM. Instead the scale LDS stages are pre-filled once - # with a fixed E8M0 byte before the hot loop. This isolates the cost of the - # scale TDM vmem->LDS traffic alone (vs FLYDSL_SCALE_DISABLED, which also removes - # the LDS reads). Numerically correct only when the real scale is all 1.0 - # (fill 0.1). OPT-IN: FLYDSL_SCALE_SKIP_TDM=1. Only valid on the TDM scale path - # with wave-specialized TDM (scale lives on loader waves 2,3, which are dropped). - use_skip_scale_tdm = os.environ.get("FLYDSL_SCALE_SKIP_TDM", "0") == "1" - if use_skip_scale_tdm: - if scale_load_path != "tdm": - raise ValueError( - "FLYDSL_SCALE_SKIP_TDM=1 only applies to scale_load_path='tdm', " - f"got {scale_load_path!r}" - ) - if not wave_specialized_tdm: - raise ValueError( - "FLYDSL_SCALE_SKIP_TDM=1 requires wave_specialized_tdm=True " - "(scale is carried on dedicated loader waves 2,3)" - ) - if use_scale_disabled: - raise ValueError( - "FLYDSL_SCALE_SKIP_TDM and FLYDSL_SCALE_DISABLED are mutually exclusive probes" - ) # The buffer_load->VGPR scale ring is built only when scale is actually loaded. - _bvs_active = use_buffer_vgpr_scale and not use_scale_disabled + _bvs_active = use_buffer_vgpr_scale if use_ref_segmented_lds_layout: # The A/B data pools are no longer packed into the same per-stage @@ -467,48 +432,6 @@ def _pick_compute_schedule_kind(): use_fp8_quadrant_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_FP8_QUADRANT use_fp8_deep_pipeline_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE use_b_streaming_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_B_STREAMING - # hot_loop_sched_mode: deep-pipeline hot-loop scheduling. - # "default" — original sched_barrier(0) fences + coarse hot_loop_scheduler. - # "iglp" — clean region (internal sched_barrier(0) suppressed), LLVM - # MFMASmallGemmOpt via iglp_opt(0). - # "manual" — clean region + hand-emitted sched_group_barrier template, - # granularity hot_loop_manual_ds:hot_loop_manual_mfma (default 8:8). - # Latency-hiding: a small DS group degenerates to "issue few ds -> wait(0) - # full-drain -> 1-2 WMMA" (exposes ds latency); larger DS+MFMA keeps ds in - # flight with partial drains + long WMMA bursts that hide it. Net win over - # "default" unconfirmed (AM underestimates ds latency) -> rank on silicon ATT. - _hl_modes = ("default", "iglp", "manual") - if hot_loop_sched_mode not in _hl_modes: - raise ValueError(f"hot_loop_sched_mode must be one of {_hl_modes}, got {hot_loop_sched_mode!r}") - _hl_iglp = hot_loop_sched_mode == "iglp" - _hl_manual = hot_loop_sched_mode == "manual" - _hl_clean_region = _hl_iglp or _hl_manual - if _hl_clean_region and not use_fp8_deep_pipeline_schedule: - raise ValueError( - f"hot_loop_sched_mode={hot_loop_sched_mode!r} currently requires the FP8 deep-pipeline schedule" - ) - if _hl_manual and (hot_loop_manual_ds < 1 or hot_loop_manual_mfma < 1): - raise ValueError("hot_loop_manual_ds / hot_loop_manual_mfma must be >= 1") - # Deep-pipeline explicit s_wait_dscnt control (clean-region iglp/manual only). - # The SP3 prototype (poc_kl mxfp8fp4gemm) drains dscnt ~once per K-iteration; - # FlyDSL's compute_tile_fp8_deep_pipeline hardcodes 5 explicit s_wait_dscnt per - # ks, which double as scheduling barriers that pin each ds_load right before its - # consuming WMMA. Those explicit drains prevent the MI scheduler from hoisting - # the next ks's ds_load prefetch to overlap the current ks's WMMA (the - # prototype's "dW dW" future-buffer interleave). In a clean region the LLVM - # auto-waitcnt pass re-derives correct dscnt waits after reorder, so dropping - # some/all explicit drains is correctness-safe and lets IGLP+auto-waitcnt - # approach the prototype's single-drain structure. - # keep - emit all explicit dscnt drains (default; byte-identical to today) - # single - one full drain (s_wait_dscnt 0) at the top of each ks, drop the - # other 4 (mirrors the prototype's 1x s_wait_dscnt 0x0 per K-iter) - # auto - drop every explicit dscnt drain; auto-waitcnt owns them entirely - _hl_dscnt_mode = os.environ.get("FLYDSL_HL_DSCNT_MODE", "keep") - if _hl_dscnt_mode not in ("keep", "single", "auto"): - raise ValueError(f"FLYDSL_HL_DSCNT_MODE must be keep/single/auto, got {_hl_dscnt_mode!r}") - if _hl_dscnt_mode != "keep" and not _hl_clean_region: - raise ValueError("FLYDSL_HL_DSCNT_MODE!=keep requires hot_loop_sched_mode in {iglp,manual}") - _hl_keep_mid_dscnt = _hl_dscnt_mode == "keep" use_ws_tdm_split_signal_overlap = ( wave_specialized_tdm and (use_fp8_quadrant_schedule or use_fp8_deep_pipeline_schedule) @@ -567,20 +490,6 @@ def kernel_mxscale_gemm( # Enable back-to-back WMMA issue (SCHED_MODE bit[4] = DISABLE_VALU_STALL) rocdl.disable_xdl_arb_stall() - # One-shot per clean scheduling region (iglp/manual): drive the LLVM - # PipelineSolver. iglp_opt(0) uses MFMASmallGemmOpt; manual emits the - # equivalent SchedGroup template by hand with repo-standard - # sched_group_barrier (tunable DS:MFMA). Over-provisioning is harmless. - _manual_pipeline_reps = num_buffers * n_accs - - def _emit_pipeline_region_hint(): - if const_expr(_hl_iglp): - rocdl.iglp_opt(0) - elif const_expr(_hl_manual): - for _grp in range_constexpr(_manual_pipeline_reps): - rocdl.sched_group_barrier(rocdl.mask_dsrd, hot_loop_manual_ds, 0) - rocdl.sched_group_barrier(rocdl.mask_mfma, hot_loop_manual_mfma, 0) - if const_expr(inst_prefetch): if rocdl.wave_id() == fx.Int32(0): rocdl.s_prefetch_inst_burst(num_pages=4) @@ -1518,8 +1427,6 @@ def load_b_pair(wn_pair, ks): ] def _load_a_scales(ks): - if const_expr(use_scale_disabled): - return [fx.Int32(0x7F7F7F7F) for _ in range_constexpr(wmma_m_rep)] if const_expr(use_buffer_vgpr_scale): if const_expr(pf_a_scales is not None): return pf_a_scales # prefetched (issued in the prior compute tile) @@ -1527,8 +1434,6 @@ def _load_a_scales(ks): return load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks) def _load_b_scales(ks): - if const_expr(use_scale_disabled): - return [fx.Int32(0x7F7F7F7F) for _ in range_constexpr(b_scale_load_rep)] if const_expr(use_buffer_vgpr_scale): if const_expr(pf_b_scales is not None): return pf_b_scales @@ -1614,16 +1519,11 @@ def _prefetch_a1(): first_wait_keep = _two_pair_loads + 3 if const_expr(ks == 0 and a0_prefetch is not None): first_wait_keep += DS_LOADS_PER_A_FRAG * len(a0_prefetch) - if const_expr(_hl_dscnt_mode == "keep"): - rocdl.s_wait_dscnt(first_wait_keep) - elif const_expr(_hl_dscnt_mode == "single"): - rocdl.s_wait_dscnt(0) # prototype-style: one full drain per ks - # auto: emit nothing here; auto-waitcnt re-derives all dscnt waits + rocdl.s_wait_dscnt(first_wait_keep) emit_panel_2x2(0, 0, a0, b0, scale_pair, prefetch_after_first_row=_prefetch_a1) if const_expr(ks == 0 and mid_compute_callback is not None): - if const_expr(not _hl_clean_region): - rocdl.sched_barrier(0) + rocdl.sched_barrier(0) mid_compute_callback() def _prefetch_b3(): @@ -1632,12 +1532,10 @@ def _prefetch_b3(): def _prefetch_a3(): a3_box[0] = load_a_pair(3, ks) - if const_expr(_hl_keep_mid_dscnt): - rocdl.s_wait_dscnt(_pair_loads + _fp8_pair_b_loads) + rocdl.s_wait_dscnt(_pair_loads + _fp8_pair_b_loads) emit_panel_2x2(0, 1, a0, b1, scale_pair, prefetch_after_first_row=_prefetch_b3) - if const_expr(_hl_keep_mid_dscnt): - rocdl.s_wait_dscnt(_fp8_pair_b_loads + 2) + rocdl.s_wait_dscnt(_fp8_pair_b_loads + 2) emit_panel_2x2(1, 0, a1_box[0], b0, scale_pair, prefetch_after_first_row=_prefetch_a3) def _prefetch_a2(): @@ -1648,26 +1546,22 @@ def _prefetch_a2(): emit_panel_2x2(0, 2, a0, b2, scale_pair, prefetch_after_first_row=_prefetch_a2) emit_panel_2x2_row(1, 2, 0, a1_box[0], b2, scale_pair) emit_panel_2x2_row(1, 2, 1, a1_box[0], b2, scale_pair) - if const_expr(_hl_keep_mid_dscnt): - rocdl.s_wait_dscnt(_pair_loads) + rocdl.s_wait_dscnt(_pair_loads) emit_panel_2x2(0, 3, a0, b3_box[0], scale_pair) emit_panel_2x2(1, 3, a1_box[0], b3_box[0], scale_pair) emit_panel_2x2(2, 0, a2_box[0], b0, scale_pair) if const_expr(is_last_ks and late_compute_callback is not None): - if const_expr(not _hl_clean_region): - rocdl.sched_barrier(0) + rocdl.sched_barrier(0) late_compute_callback() emit_panel_2x2(2, 1, a2_box[0], b1, scale_pair) - if const_expr(_hl_keep_mid_dscnt): - rocdl.s_wait_dscnt(0) + rocdl.s_wait_dscnt(0) emit_panel_2x2(3, 0, a3_box[0], b0, scale_pair) emit_panel_2x2(3, 1, a3_box[0], b1, scale_pair) if const_expr(is_last_ks and emit_filler is not None): - if const_expr(not _hl_clean_region): - rocdl.sched_barrier(0) + rocdl.sched_barrier(0) emit_filler() emit_panel_2x2(2, 2, a2_box[0], b2, scale_pair) @@ -2181,9 +2075,7 @@ def _pack_dg0(pred, lds_addr, addr_lo, addr_hi): if const_expr(wave_specialized_tdm): # With scale on the VGPR path, drop scale waves 2,3 from the active TDM # path -- unless ab-half-split repurposes them as the second A/B halves. - # skip-scale-TDM drops the scale loader waves (2,3) from the active TDM - # path -- identical mechanism to the vgpr path -- so scale never DMAs. - _drop_scale_waves = (use_buffer_vgpr_scale and not use_ab_half_split) or use_skip_scale_tdm + _drop_scale_waves = use_buffer_vgpr_scale and not use_ab_half_split _active_wave_limit = 2 if _drop_scale_waves else 4 active_pred_const = arith.select(tdm_wave_id < fx.Int32(_active_wave_limit), fx.Int32(1), fx.Int32(0)) @@ -2290,29 +2182,6 @@ def _issue_active_tdm(load_stage, addr_box, k_prefetch=None): _bvs_ra = [_v for (_a, _b) in _bvs_pf for _v in _a] _bvs_rb = [_v for (_a, _b) in _bvs_pf for _v in _b] - if const_expr(use_skip_scale_tdm): - # Scale TDM is skipped (loader waves 2,3 dropped above). Pre-fill every - # scale LDS stage with a fixed E8M0 byte (0x7F = 1.0) so the unchanged - # downstream ds_load scale reads see a defined, constant value. The - # _pipeline_fence below issues the workgroup barrier that makes these - # writes visible before the first scale read. Threads cooperate: each - # lane writes a strided set of 16B chunks (one ds_store_b128 each). - _scale_fill = vector.full((4,), 0x7F7F7F7F, fx.Int32) - for i in range_constexpr(num_buffers): - for _smem, _sbytes in ( - (stages_as_mem[i], lds_a_scale_bytes), - (stages_bs_mem[i], lds_b_scale_bytes), - ): - _nchunks = _sbytes // 16 - _rounds = (_nchunks + block_threads - 1) // block_threads - for _r in range_constexpr(_rounds): - _chunk = tx + arith.index(_r * block_threads) - _pred = arith.cmpi(arith.CmpIPredicate.ult, _chunk, arith.index(_nchunks)) - _if = scf.IfOp(_pred) - with ir.InsertionPoint(_if.then_block): - lds_store_b128(_smem, _chunk * arith.index(8), _scale_fill) - scf.YieldOp([]) - _pipeline_fence(outstanding=TDM_LOADS_PER_STEP * (num_buffers - 2)) # Main loop — acc_mixed style: fence at top, TDM_load mid-compute. @@ -2366,11 +2235,7 @@ def _late_tdm_ws_split_signal(): _late_tdm_ws_fence_signal = _late_tdm_ws_split_signal a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[buf_idx]) - if const_expr(_hl_clean_region): - if const_expr(buf_idx == 0): - _emit_pipeline_region_hint() - else: - rocdl.sched_barrier(0) + rocdl.sched_barrier(0) # Consume scale prefetched _bvs_D K-tiles ago; issue the # K-tile +_bvs_D prefetch now (overlaps this tile's WMMAs). # NOTE: must stay AFTER the fence; issuing the scale @@ -2403,8 +2268,7 @@ def _late_tdm_ws_split_signal(): pf_b_scales=_cur_b, ) cur_addr_lo = addr_box[0] - if const_expr(not _hl_clean_region): - hot_loop_scheduler_scheduled() + hot_loop_scheduler_scheduled() if const_expr(_bvs_active): _bvs_yield = _ring_a + _ring_b @@ -2459,11 +2323,7 @@ def _mid_tdm_nws( _l2_prefetch(_k_off) a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[buf_idx]) - if const_expr(_hl_clean_region): - if const_expr(buf_idx == 0): - _emit_pipeline_region_hint() - else: - rocdl.sched_barrier(0) + rocdl.sched_barrier(0) accs_in = compute_tile_scheduled( accs_in, stages_a_idx[buf_idx], @@ -2477,8 +2337,7 @@ def _mid_tdm_nws( cur_lo_b = addr_boxes[1][0] cur_lo_as = addr_boxes[2][0] cur_lo_bs = addr_boxes[3][0] - if const_expr(not _hl_clean_region): - hot_loop_scheduler_scheduled() + hot_loop_scheduler_scheduled() results = yield list(accs_in) + [cur_lo_a, cur_lo_b, cur_lo_as, cur_lo_bs] @@ -2576,10 +2435,7 @@ def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): _tail_mid_cb = _tail_mid_nws a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[_compute_stage]) - if const_expr(_hl_clean_region): - _emit_pipeline_region_hint() - else: - rocdl.sched_barrier(0) + rocdl.sched_barrier(0) accs = compute_tile_scheduled( accs, stages_a_idx[_compute_stage], @@ -2600,8 +2456,7 @@ def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): addr_lo_as = _tail_ab[2][0] addr_lo_bs = _tail_ab[3][0] - if const_expr(not _hl_clean_region): - hot_loop_scheduler_scheduled() + hot_loop_scheduler_scheduled() accs = finalize_acc_layout(accs) @@ -2647,9 +2502,6 @@ def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): b_streaming, scale_load_path, fp8_schedule, - hot_loop_sched_mode, - hot_loop_manual_ds, - hot_loop_manual_mfma, ) @flyc.jit diff --git a/python/flydsl/expr/rocdl/tdm_ops.py b/python/flydsl/expr/rocdl/tdm_ops.py index 8c727218a..5b2bcae63 100644 --- a/python/flydsl/expr/rocdl/tdm_ops.py +++ b/python/flydsl/expr/rocdl/tdm_ops.py @@ -368,12 +368,16 @@ def make_tensor_descriptor_2d( # sgpr0: config bitfields _abe = 1 if atomic_barrier_enable else 0 + # early_timeout (bit 21) is a multicast-load knob: 1 = GL1 returns to the + # requesters present when GL2 data arrives (latecomers re-broadcast); 0 = + # standard (wider merge) timeout. keep stores at 0 (store cannot multicast). + _early_timeout = 0 if for_store else 1 g1_s0_upper = ( (data_size_code << 16) # data_size [17:16] | (_abe << 18) # atomic_barrier_enable | (0 << 19) # iterate_enable | (pad_enable << 20) # pad_enable - | (0 << 21) # early_timeout + | (_early_timeout << 21) # early_timeout | (enc_interval << 22) # pad_interval [24:22] | (enc_amount << 25) # pad_amount [31:25] ) diff --git a/tests/kernels/test_gemm_fp8fp4_gfx1250.py b/tests/kernels/test_gemm_fp8fp4_gfx1250.py index 50f6da69b..d1dfa9ade 100644 --- a/tests/kernels/test_gemm_fp8fp4_gfx1250.py +++ b/tests/kernels/test_gemm_fp8fp4_gfx1250.py @@ -331,9 +331,6 @@ def _run_mxscale_gemm_test( split_k=1, b_streaming=False, scale_load_path="tdm", - hot_loop_sched_mode="default", - hot_loop_manual_ds=8, - hot_loop_manual_mfma=8, return_launch_fn=False, ): """Unified test body for FP4 and FP8.""" @@ -450,9 +447,6 @@ def _run_mxscale_gemm_test( expert_sched_mode=expert_sched_mode, b_streaming=b_streaming, scale_load_path=scale_load_path, - hot_loop_sched_mode=hot_loop_sched_mode, - hot_loop_manual_ds=hot_loop_manual_ds, - hot_loop_manual_mfma=hot_loop_manual_mfma, ) # Keep 2D — dynamic_layout=True packs shape as i32; flattening overflows for M*K >= 2^31. @@ -496,7 +490,13 @@ def _run_mxscale_gemm_test( diff = (c_out_f - ref_f).abs() print(f"Abs diff: max={diff.max():.4f}, mean={diff.mean():.4f}") - cos_sim = torch.nn.functional.cosine_similarity(c_out_f.flatten().unsqueeze(0), ref_f.flatten().unsqueeze(0)).item() + # Compute cosine in float64: for large M/N/K with large E8M0 scales the values + # (and their squares) overflow float32's accurate-summation range, so an fp32 + # cosine reduction saturates and can print values outside [-1, 1]. fp64 keeps + # the diagnostic meaningful. (Pass/fail is gated by assert_close below, not this.) + cos_sim = torch.nn.functional.cosine_similarity( + c_out_f.flatten().unsqueeze(0).double(), ref_f.flatten().unsqueeze(0).double() + ).item() print(f"Cosine similarity: {cos_sim:.6f}") # Tolerances: FP4 is exact; FP8/A8W4 have FP accumulation error @@ -830,51 +830,6 @@ def test_mxfp8_vgpr_scale_load(scale_load_path, cluster_m, cluster_n): ) -# IGLP hot-loop scheduling modes (FP8 deep-pipeline, scale=tdm). Covers the -# baseline 'default' plus the two clean-region drivers ('iglp' = iglp_opt(0), -# 'manual' = sched_group_barrier DS:MFMA template) at a few granularities in one -# test. All modes must stay numerically correct; perf ranking is a separate -# silicon-ATT exercise. Scope mirrors the deep-pipeline target: -# 256x256x128, warps 2x2, num_buffers=4, wave_specialized_tdm, cluster (2,2). -@pytest.mark.parametrize( - "hot_loop_sched_mode, hot_loop_manual_ds, hot_loop_manual_mfma", - [ - ("default", 8, 8), - ("iglp", 8, 8), - ("manual", 8, 8), - ("manual", 4, 4), - ("manual", 2, 1), - ], -) -def test_mxfp8_hot_loop_sched_modes(hot_loop_sched_mode, hot_loop_manual_ds, hot_loop_manual_mfma): - if str(get_rocm_arch()) != "gfx1250": - pytest.skip("requires gfx1250") - if "FFMLITE_TOPOLOGY" in os.environ or "AM_TOPOLOGY" in os.environ: - pytest.skip("cluster multicast not supported on simulator") - _run_mxscale_gemm_test( - "fp8", - 512, - 512, - 512, - 256, - 256, - 128, - 2, - 2, - num_buffers=4, - use_tdm_store=True, - out_dtype="bf16", - l2_prefetch_distance=2, - wave_specialized_tdm=True, - cluster_m=2, - cluster_n=2, - scale_load_path="tdm", - hot_loop_sched_mode=hot_loop_sched_mode, - hot_loop_manual_ds=hot_loop_manual_ds, - hot_loop_manual_mfma=hot_loop_manual_mfma, - ) - - @pytest.mark.parametrize( "data_format, M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp, cluster_m, cluster_n", [ @@ -1303,9 +1258,6 @@ def _run_benchmark(args): atomic_barrier_enable=args.atomic_barrier_enable, b_streaming=args.b_streaming, scale_load_path=args.scale_load_path, - hot_loop_sched_mode=args.hot_loop_sched_mode, - hot_loop_manual_ds=args.hot_loop_manual_ds, - hot_loop_manual_mfma=args.hot_loop_manual_mfma, ) compiled_exe = flyc.compile( @@ -1489,9 +1441,6 @@ def _run_graph_verify(args): atomic_barrier_enable=args.atomic_barrier_enable, b_streaming=args.b_streaming, scale_load_path=args.scale_load_path, - hot_loop_sched_mode=args.hot_loop_sched_mode, - hot_loop_manual_ds=args.hot_loop_manual_ds, - hot_loop_manual_mfma=args.hot_loop_manual_mfma, ) c_flat = c_gpu.contiguous() @@ -1588,27 +1537,6 @@ def launch(): default="tdm", choices=["tdm", "vgpr", "vgpr_ab_split"], ) - parser.add_argument( - "--hot-loop-sched-mode", - type=str, - default="default", - choices=["default", "iglp", "manual"], - help="FP8 deep-pipeline hot-loop scheduling: 'default' (baseline, " - "byte-identical), 'iglp' (iglp_opt(0) MFMASmallGemmOpt), or 'manual' " - "(hand-emitted sched_group_barrier DS:MFMA template).", - ) - parser.add_argument( - "--hot-loop-manual-ds", - type=int, - default=8, - help="DS group size for --hot-loop-sched-mode manual.", - ) - parser.add_argument( - "--hot-loop-manual-mfma", - type=int, - default=8, - help="MFMA group size for --hot-loop-sched-mode manual.", - ) parser.add_argument("--disable-expert-sched-mode", dest="expert_sched_mode", action="store_false", default=True) parser.add_argument("--b-streaming", action="store_true", default=False) parser.add_argument( @@ -1681,7 +1609,4 @@ def launch(): expert_sched_mode=args.expert_sched_mode, b_streaming=args.b_streaming, scale_load_path=args.scale_load_path, - hot_loop_sched_mode=args.hot_loop_sched_mode, - hot_loop_manual_ds=args.hot_loop_manual_ds, - hot_loop_manual_mfma=args.hot_loop_manual_mfma, ) From ed4af4ba5c6bfa0b4bba1be588d06f38708f9942 Mon Sep 17 00:00:00 2001 From: aoli26 Date: Tue, 2 Jun 2026 07:56:44 +0000 Subject: [PATCH 16/19] Reject VGPR scale loads outside FP8 deep pipeline --- kernels/gemm_fp8fp4_gfx1250.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index d71d85a76..0c617441d 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -432,6 +432,10 @@ def _pick_compute_schedule_kind(): use_fp8_quadrant_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_FP8_QUADRANT use_fp8_deep_pipeline_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE use_b_streaming_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_B_STREAMING + if use_buffer_vgpr_scale and not use_fp8_deep_pipeline_schedule: + raise ValueError( + f"scale_load_path={scale_load_path!r} is only supported with the FP8 deep-pipeline schedule" + ) use_ws_tdm_split_signal_overlap = ( wave_specialized_tdm and (use_fp8_quadrant_schedule or use_fp8_deep_pipeline_schedule) From 2032b724d29669a46842cc530935ef06d7bcb189 Mon Sep 17 00:00:00 2001 From: aoli26 Date: Tue, 2 Jun 2026 08:11:18 +0000 Subject: [PATCH 17/19] Fix gfx1250 GEMM Python formatting --- kernels/gemm_fp8fp4_gfx1250.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index 0c617441d..1119e78e4 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -433,9 +433,7 @@ def _pick_compute_schedule_kind(): use_fp8_deep_pipeline_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE use_b_streaming_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_B_STREAMING if use_buffer_vgpr_scale and not use_fp8_deep_pipeline_schedule: - raise ValueError( - f"scale_load_path={scale_load_path!r} is only supported with the FP8 deep-pipeline schedule" - ) + raise ValueError(f"scale_load_path={scale_load_path!r} is only supported with the FP8 deep-pipeline schedule") use_ws_tdm_split_signal_overlap = ( wave_specialized_tdm and (use_fp8_quadrant_schedule or use_fp8_deep_pipeline_schedule) From c91972b3cc481ffc569e5e261740901b1947473f Mon Sep 17 00:00:00 2001 From: aoli26 Date: Tue, 2 Jun 2026 09:07:45 +0000 Subject: [PATCH 18/19] Address review feedback --- kernels/gemm_fp8fp4_gfx1250.py | 49 ++++++++++++++++++----- python/flydsl/expr/rocdl/inline_asm.py | 18 --------- python/flydsl/expr/rocdl/tdm_ops.py | 10 +++-- tests/kernels/benchmark_common.py | 2 +- tests/kernels/test_gemm_fp8fp4_gfx1250.py | 4 +- 5 files changed, 49 insertions(+), 34 deletions(-) diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index 1119e78e4..6f20afcb3 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -6,6 +6,7 @@ """ import functools +import inspect import os import flydsl.compiler as flyc @@ -30,6 +31,30 @@ ) from kernels.pipeline_utils import make_tail_plan, tdm_epilogue_fence_threshold_bytes + +def _s_prefetch_inst_burst(num_pages: int, page_bytes: int = 4096): + """gfx1250: prefetch ``num_pages`` × 4 KB of instructions ahead of PC. + + Caller must keep ``num_pages * page_bytes`` within shader bounds; over-reach + page-faults. + """ + from flydsl._mlir.dialects import llvm as _llvm + + lines = [f"s_prefetch_inst_pc_rel {pg * page_bytes}, null, 31" for pg in range(num_pages)] + _llvm.inline_asm(None, [], "\n".join(lines), "", has_side_effects=True) + + +# compatible with no early_timeout descriptor +_TDM_HAS_EARLY_TIMEOUT = "early_timeout" in inspect.signature(tdm_ops.make_tensor_descriptor_2d).parameters + + +def _make_tdm_desc(*, early_timeout=False, **kwargs): + """Build a TDM descriptor, applying early_timeout only when supports it.""" + if _TDM_HAS_EARLY_TIMEOUT: + kwargs["early_timeout"] = early_timeout + return tdm_ops.make_tensor_descriptor_2d(**kwargs) + + # Common constants WMMA_M, WMMA_N, WMMA_K = 16, 16, 128 WAVE_SIZE = 32 @@ -284,7 +309,7 @@ def _align_up(value: int, align: int) -> int: ) # Scale prefetch depth (K-tiles ahead) for the buffer->VGPR path. D=1 is the # sweet spot; D=2 doubles scale VGPRs -> spill + ~18% regression. - _bvs_D = int(os.environ.get("FLYDSL_BUFFER_VGPR_SCALE_DEPTH", "1")) + _bvs_D = max(1, int(os.environ.get("FLYDSL_BUFFER_VGPR_SCALE_DEPTH", "1"))) # ab_half_split: repurpose the (under "vgpr") idle scale waves 2,3 as the # second halves of A/B, so all 4 waves share the A/B TDM (wave0=A0, wave1=B0, # wave2=A1, wave3=B1). Measured wall-neutral. @@ -494,7 +519,7 @@ def kernel_mxscale_gemm( if const_expr(inst_prefetch): if rocdl.wave_id() == fx.Int32(0): - rocdl.s_prefetch_inst_burst(num_pages=4) + _s_prefetch_inst_burst(num_pages=4) tx = gpu.thread_id("x") bx = gpu.block_id("x") @@ -566,7 +591,7 @@ def _bvs_prefetch(k_base): def make_desc_a(memref, k_base): k_packed_off = k_base / arith.index(PACK_FACTOR_A) - return tdm_ops.make_tensor_descriptor_2d( + return _make_tdm_desc( global_ptr=arg_a, lds_memref=memref, global_offset=(blk_m, k_packed_off), @@ -579,11 +604,12 @@ def make_desc_a(memref, k_base): num_warps=tdm_desc_num_warps, workgroup_mask=a_mcast_mask, atomic_barrier_enable=atomic_barrier_enable, + early_timeout=True, ) def make_desc_b(memref, k_base): k_packed_off = k_base / arith.index(PACK_FACTOR_B) - return tdm_ops.make_tensor_descriptor_2d( + return _make_tdm_desc( global_ptr=arg_b, lds_memref=memref, global_offset=(blk_n / arith.index(16), k_packed_off * arith.index(16)), @@ -596,12 +622,13 @@ def make_desc_b(memref, k_base): num_warps=tdm_desc_num_warps, workgroup_mask=b_mcast_mask, atomic_barrier_enable=atomic_barrier_enable, + early_timeout=True, ) def make_desc_a_half(memref, k_base, m_half: int): row_start = m_half * ab_split_a_rows k_packed_off = k_base / arith.index(PACK_FACTOR_A) - return tdm_ops.make_tensor_descriptor_2d( + return _make_tdm_desc( global_ptr=arg_a, lds_memref=memref, global_offset=(blk_m + arith.index(row_start), k_packed_off), @@ -615,12 +642,13 @@ def make_desc_a_half(memref, k_base, m_half: int): workgroup_mask=a_mcast_mask, lds_byte_offset=arith.index(row_start * lds_a_stride_bytes), atomic_barrier_enable=atomic_barrier_enable, + early_timeout=True, ) def make_desc_b_half(memref, k_base, n_half: int): group_start = n_half * ab_split_b_groups k_packed_off = k_base / arith.index(PACK_FACTOR_B) - return tdm_ops.make_tensor_descriptor_2d( + return _make_tdm_desc( global_ptr=arg_b, lds_memref=memref, global_offset=(blk_n / arith.index(16) + arith.index(group_start), k_packed_off * arith.index(16)), @@ -634,13 +662,14 @@ def make_desc_b_half(memref, k_base, n_half: int): workgroup_mask=b_mcast_mask, lds_byte_offset=arith.index(group_start * packed_tile_k_b * 16), atomic_barrier_enable=atomic_barrier_enable, + early_timeout=True, ) def make_desc_as(memref, k_base): k_scale_off = k_base / arith.index(SCALE_BLOCK) outer_off = blk_m / arith.index(wmma_m_rep) inner_off = k_scale_off * arith.index(wmma_m_rep) - return tdm_ops.make_tensor_descriptor_2d( + return _make_tdm_desc( global_ptr=arg_a_scale, lds_memref=memref, global_offset=(outer_off, inner_off), @@ -653,13 +682,14 @@ def make_desc_as(memref, k_base): num_warps=tdm_desc_num_warps, workgroup_mask=a_mcast_mask, atomic_barrier_enable=atomic_barrier_enable, + early_timeout=True, ) def make_desc_bs(memref, k_base): k_scale_off = k_base / arith.index(SCALE_BLOCK) outer_off = blk_n / arith.index(b_scale_load_rep) inner_off = k_scale_off * arith.index(b_scale_load_rep) - return tdm_ops.make_tensor_descriptor_2d( + return _make_tdm_desc( global_ptr=arg_b_scale, lds_memref=memref, global_offset=(outer_off, inner_off), @@ -672,6 +702,7 @@ def make_desc_bs(memref, k_base): num_warps=tdm_desc_num_warps, workgroup_mask=b_mcast_mask, atomic_barrier_enable=atomic_barrier_enable, + early_timeout=True, ) if const_expr(wave_specialized_tdm): @@ -2015,7 +2046,7 @@ def _l2_prefetch(k_base): d_warp_off_sgpr = d_warp_linear_sgpr * arith.index(warp_d_bytes) + arith.index(d_output_off) warp_m_off_sgpr = wave_m_sgpr * arith.index(warp_tile_m) warp_n_off_sgpr = wave_n_sgpr * arith.index(warp_tile_n) - d_desc = tdm_ops.make_tensor_descriptor_2d( + d_desc = _make_tdm_desc( global_ptr=arg_c, lds_memref=d_lds_base_ptr, global_offset=(blk_m + warp_m_off_sgpr, blk_n + warp_n_off_sgpr), diff --git a/python/flydsl/expr/rocdl/inline_asm.py b/python/flydsl/expr/rocdl/inline_asm.py index f16ffacb5..5cb7fb1a1 100644 --- a/python/flydsl/expr/rocdl/inline_asm.py +++ b/python/flydsl/expr/rocdl/inline_asm.py @@ -65,21 +65,3 @@ def cvt_pk_bf16_f32(src_a_f32, src_b_f32): "=v,v,v", has_side_effects=False, ) - - -def s_prefetch_inst_burst(num_pages: int = 3, page_bytes: int = 4096): - """gfx1250: prefetch ``num_pages`` × 4 KB of instructions ahead of PC. - - Caller must keep ``num_pages * page_bytes`` within shader bounds; over-reach - page-faults. - """ - from ..._mlir.dialects import llvm as _llvm - - lines = [f"s_prefetch_inst_pc_rel {pg * page_bytes}, null, 31" for pg in range(num_pages)] - _llvm.inline_asm( - None, - [], - "\n".join(lines), - "", - has_side_effects=True, - ) diff --git a/python/flydsl/expr/rocdl/tdm_ops.py b/python/flydsl/expr/rocdl/tdm_ops.py index 5b2bcae63..f2644d315 100644 --- a/python/flydsl/expr/rocdl/tdm_ops.py +++ b/python/flydsl/expr/rocdl/tdm_ops.py @@ -215,6 +215,7 @@ def make_tensor_descriptor_2d( lds_byte_offset=None, for_store: bool = False, atomic_barrier_enable: bool = False, + early_timeout: bool = False, ) -> TDMDescriptor2D: """Build a 2D TDM descriptor for tensor_load_to_lds_d2. @@ -260,6 +261,10 @@ def make_tensor_descriptor_2d( relying on TDM atomic-barrier semantics; this helper keeps the encoded atomic-barrier address at zero, so all participating waves must agree on that protocol. + early_timeout: Set the descriptor's early-timeout bit [21]. This is a + multicast-load knob (1 = GL1 returns to the requesters + present when GL2 data arrives, latecomers re-broadcast; + default 0 = standard wider-merge timeout). Returns: TDMDescriptor2D with dgroup0 and dgroup1 ready for tensor_load_2d. @@ -368,10 +373,7 @@ def make_tensor_descriptor_2d( # sgpr0: config bitfields _abe = 1 if atomic_barrier_enable else 0 - # early_timeout (bit 21) is a multicast-load knob: 1 = GL1 returns to the - # requesters present when GL2 data arrives (latecomers re-broadcast); 0 = - # standard (wider merge) timeout. keep stores at 0 (store cannot multicast). - _early_timeout = 0 if for_store else 1 + _early_timeout = 1 if early_timeout else 0 g1_s0_upper = ( (data_size_code << 16) # data_size [17:16] | (_abe << 18) # atomic_barrier_enable diff --git a/tests/kernels/benchmark_common.py b/tests/kernels/benchmark_common.py index 06192c0f0..7cd238b98 100644 --- a/tests/kernels/benchmark_common.py +++ b/tests/kernels/benchmark_common.py @@ -469,7 +469,7 @@ def bench_kernel_us(run_fn, warmup=10, iters=50, flush_l2=True, prep_fn=None): torch.cuda.synchronize() if flush_buf is None and prep_fn is None: - # Single event pair preserves back-to-back launch pipelining. + # Single event pair preserves back-to-back launch pipelining (returns mean latency). start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) start.record() diff --git a/tests/kernels/test_gemm_fp8fp4_gfx1250.py b/tests/kernels/test_gemm_fp8fp4_gfx1250.py index d1dfa9ade..fb202ebd2 100644 --- a/tests/kernels/test_gemm_fp8fp4_gfx1250.py +++ b/tests/kernels/test_gemm_fp8fp4_gfx1250.py @@ -785,7 +785,7 @@ def test_b_streaming_with_wave_spec_tdm(data_format, M, N, K, tile_m, tile_n, ti @pytest.mark.parametrize("num_buffers", [2, 3]) @pytest.mark.parametrize("use_tdm_store", [True, False]) @pytest.mark.parametrize("use_scale_opsel", [False, True]) -def test_mxfp8_wave_spec_scale_load_paths(num_buffers, use_tdm_store, use_scale_opsel): +def test_mxfp8_wave_spec_scale_load_tdm(num_buffers, use_tdm_store, use_scale_opsel): _run_mxscale_gemm_test( "fp8", 128, @@ -1120,7 +1120,7 @@ def _bench_kernel_us(run_fn, warmup=10, iters=50, flush_l2=True, prep_fn=None): torch.cuda.synchronize() if flush_buf is None and prep_fn is None: - # Single event pair preserves back-to-back launch pipelining. + # Single event pair preserves back-to-back launch pipelining (returns mean latency). start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) start.record() From 1a90fd8cbc2f5f1fdac3660d6d561778553f5a4a Mon Sep 17 00:00:00 2001 From: aoli26 Date: Tue, 2 Jun 2026 16:16:33 +0000 Subject: [PATCH 19/19] [gemm_fp8fp4_gfx1250] Fix split-K precision and use global atomicrmw Split-K accumulates partial K-sums across workgroups via atomic add into C. Two fixes: - Correctness: route the atomic accumulation through llvm.atomicrmw fadd on a global (addrspace 1) pointer with syncscope("agent") instead of buffer atomics. - Precision: tests run split-K with f32 accumulation and convert to the requested bf16/f16 on the host, avoiding compounded rounding from half-precision atomics. Adds test_mxfp8_gemm_splitk over split_k in {2,4,6,8}. --- kernels/gemm_fp8fp4_gfx1250.py | 29 ++++++++++-- tests/kernels/test_gemm_fp8fp4_gfx1250.py | 57 +++++++++++++++++++---- 2 files changed, 71 insertions(+), 15 deletions(-) diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index 6f20afcb3..0270893ce 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -12,6 +12,7 @@ import flydsl.compiler as flyc import flydsl.expr as fx from flydsl._mlir import ir +from flydsl._mlir.dialects import fly, llvm from flydsl.compiler.kernel_function import CompilationContext from flydsl.expr import arith, buffer_ops, const_expr, gpu, idx2crd, range_constexpr, rocdl, tdm_ops from flydsl.expr.rocdl import cluster @@ -587,7 +588,10 @@ def _bvs_prefetch(k_base): n_stride = arith.index(N) c_nrec = m_idx * n_stride * arith.index(elem_bytes_d) c_rsrc = buffer_ops.create_buffer_resource(arg_c, num_records_bytes=c_nrec) - zero_i32 = fx.Int32(0) + c_global_ptr_type = ir.Type.parse("!llvm.ptr<1>") + c_global_base_i64 = llvm.PtrToIntOp( + T.i64, fly.extract_aligned_pointer_as_index(c_global_ptr_type, arg_c.__extract_to_ir_values__()[0]) + ).result def make_desc_a(memref, k_base): k_packed_off = k_base / arith.index(PACK_FACTOR_A) @@ -1920,13 +1924,28 @@ def epilogue_lds_stores(final_accs, d_buf, d_base): imm = m_off * _lds_d_stride_elems + wn * _n_col_d_elems store_acc_vec8_to_lds(d_buf, d_base, imm, sub8, out_elem=_out_elem_local) + def _atomic_fadd_global(val, byte_off): + # Device-scoped, relaxed atomic add into C at c_global_base_i64 + byte_off. + addr_i64 = llvm.AddOp( + c_global_base_i64, arith.index_cast(T.i64, byte_off), llvm.IntegerOverflowFlags(0) + ).result + ptr = llvm.IntToPtrOp(c_global_ptr_type, addr_i64).result + llvm.AtomicRMWOp( + llvm.AtomicBinOp.fadd, + ptr, + val.ir_value(), + llvm.AtomicOrdering.monotonic, + syncscope="agent", + alignment=4, + ) + def _atomic_add_acc_vec8_to_buffer(acc_vec8, addr): if const_expr(_bf16_out): h_vec = fx.Vector(arith.trunc_f(T.vec(8, _out_elem_local), acc_vec8)) for pair in range_constexpr(4): pair_vec = fx.Vector.from_elements([h_vec[pair * 2], h_vec[pair * 2 + 1]]) - byte_off = arith.index_cast(T.i32, addr + arith.index(pair * 4)) - rocdl.raw_ptr_buffer_atomic_fadd(pair_vec, c_rsrc, byte_off, zero_i32, zero_i32) + byte_off = addr + arith.index(pair * 4) + _atomic_fadd_global(pair_vec, byte_off) return 1 acc_vec = fx.Vector(acc_vec8) @@ -1934,8 +1953,8 @@ def _atomic_add_acc_vec8_to_buffer(acc_vec8, addr): base_addr = addr[half] if isinstance(addr, (list, tuple)) else addr for vi in range_constexpr(4): val = acc_vec[half * 4 + vi] - byte_off = arith.index_cast(T.i32, (base_addr + arith.index(vi)) * arith.index(4)) - rocdl.raw_ptr_buffer_atomic_fadd(val, c_rsrc, byte_off, zero_i32, zero_i32) + byte_off = (base_addr + arith.index(vi)) * arith.index(4) + _atomic_fadd_global(val, byte_off) return 2 def epilogue_atomic_adds(final_accs, addrs): diff --git a/tests/kernels/test_gemm_fp8fp4_gfx1250.py b/tests/kernels/test_gemm_fp8fp4_gfx1250.py index fb202ebd2..382a73f0b 100644 --- a/tests/kernels/test_gemm_fp8fp4_gfx1250.py +++ b/tests/kernels/test_gemm_fp8fp4_gfx1250.py @@ -364,6 +364,10 @@ def _run_mxscale_gemm_test( _dtype_map = {"f32": torch.float32, "bf16": torch.bfloat16, "f16": torch.float16} torch_out_dtype = _dtype_map[out_dtype] + # Split-K accumulates across workgroups in fp32; half outputs are converted after. + kernel_out_dtype = "f32" if (split_k > 1 and out_dtype in ("bf16", "f16")) else out_dtype + torch_kernel_dtype = _dtype_map[kernel_out_dtype] + torch.manual_seed(0) fmt_name = "A8W4" if is_a8w4 else ("MXFP4" if is_fp4 else "MXFP8") @@ -421,7 +425,7 @@ def _run_mxscale_gemm_test( b_gpu = b.cuda() as_gpu = a_scale.cuda() bs_gpu = b_scale.cuda() - c_gpu = torch.zeros(padded_m, padded_n, dtype=torch_out_dtype, device="cuda") + c_gpu = torch.zeros(padded_m, padded_n, dtype=torch_kernel_dtype, device="cuda") launch_fn = compile_mxscale_gemm( data_format=data_format, @@ -439,7 +443,7 @@ def _run_mxscale_gemm_test( cluster_m=cluster_m, cluster_n=cluster_n, use_tdm_store=use_tdm_store, - out_dtype=out_dtype, + out_dtype=kernel_out_dtype, inst_prefetch=inst_prefetch, wave_specialized_tdm=wave_specialized_tdm, split_k=split_k, @@ -469,7 +473,8 @@ def _run_mxscale_gemm_test( ) torch.cuda.synchronize() - c_out = c_gpu[:M, :N].cpu() + # Convert the fp32 split-K accumulation back to the requested half dtype. + c_out = c_gpu[:M, :N].to(torch_out_dtype).cpu() print( f"Out stats: min={c_out.float().min():.2f}, max={c_out.float().max():.2f}, " @@ -667,6 +672,32 @@ def test_mxfp8_gemm( ) +@pytest.mark.parametrize("split_k", [2, 4, 6, 8]) +@pytest.mark.parametrize("out_dtype", ["f32", "bf16"]) +def test_mxfp8_gemm_splitk(split_k, out_dtype): + """FP8 split-K: split_k workgroups accumulate partial K-sums into C via atomic add. + + Exercises the atomic epilogue path (use_tdm_store=False). K=2048/tile_k=128 gives + every split_k value >= 2 local K-tiles (needed for double buffering). + """ + _run_mxscale_gemm_test( + "fp8", + 128, + 256, + 2048, + 128, + 256, + 128, + 2, + 4, + num_buffers=2, + use_tdm_store=False, + out_dtype=out_dtype, + l2_prefetch_distance=2, + split_k=split_k, + ) + + @pytest.mark.parametrize( "M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp", [ @@ -1181,8 +1212,11 @@ def _run_benchmark(args): is_fp4 = data_format == "fp4" is_a8w4 = data_format == "a8w4" _dtype_map = {"f32": torch.float32, "bf16": torch.bfloat16, "f16": torch.float16} - torch_out_dtype = _dtype_map[args.out_dtype] - elem_bytes_d = 2 if args.out_dtype in ("bf16", "f16") else 4 + # split_k>1 accumulates partial K-sums in fp32 for precision; bf16/f16 atomics are + # supported but compound rounding error, so we run f32 and convert back on the host. + kernel_out_dtype = "f32" if (args.split_k > 1 and args.out_dtype in ("bf16", "f16")) else args.out_dtype + torch_kernel_dtype = _dtype_map[kernel_out_dtype] + elem_bytes_d = 2 if kernel_out_dtype in ("bf16", "f16") else 4 fmt_name = "A8W4" if is_a8w4 else ("MXFP4" if is_fp4 else "MXFP8") print("=" * 72) @@ -1225,7 +1259,7 @@ def _run_benchmark(args): b_gpu = b.cuda() as_gpu = a_scale.cuda() bs_gpu = b_scale.cuda() - c_gpu = torch.zeros(padded_m, padded_n, dtype=torch_out_dtype, device="cuda") + c_gpu = torch.zeros(padded_m, padded_n, dtype=torch_kernel_dtype, device="cuda") print("\n[1/3] Compiling kernel...") t0 = time.perf_counter() @@ -1249,7 +1283,7 @@ def _run_benchmark(args): cluster_m=args.cluster_m, cluster_n=args.cluster_n, use_tdm_store=use_tdm_store, - out_dtype=args.out_dtype, + out_dtype=kernel_out_dtype, inst_prefetch=args.inst_prefetch, wave_specialized_tdm=args.wave_spec_tdm, split_k=args.split_k, @@ -1412,8 +1446,11 @@ def _run_graph_verify(args): b_gpu = b.cuda() as_gpu = a_scale.cuda() bs_gpu = b_scale.cuda() - torch_out_dtype = {"f32": torch.float32, "bf16": torch.bfloat16, "f16": torch.float16}[args.out_dtype] - c_gpu = torch.zeros(padded_m, padded_n, dtype=torch_out_dtype, device="cuda") + _dtype_map = {"f32": torch.float32, "bf16": torch.bfloat16, "f16": torch.float16} + # split_k>1 accumulates partial K-sums in fp32 for precision; bf16/f16 atomics are + # supported but compound rounding error, so we run f32 and convert back on the host. + kernel_out_dtype = "f32" if (args.split_k > 1 and args.out_dtype in ("bf16", "f16")) else args.out_dtype + c_gpu = torch.zeros(padded_m, padded_n, dtype=_dtype_map[kernel_out_dtype], device="cuda") use_tdm_store = not args.no_tdm_store and args.split_k == 1 launch_fn = compile_mxscale_gemm( @@ -1432,7 +1469,7 @@ def _run_graph_verify(args): cluster_m=args.cluster_m, cluster_n=args.cluster_n, use_tdm_store=use_tdm_store, - out_dtype=args.out_dtype, + out_dtype=kernel_out_dtype, inst_prefetch=args.inst_prefetch, wave_specialized_tdm=args.wave_spec_tdm, split_k=args.split_k,