diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index ee09dc7a4..0270893ce 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -6,10 +6,13 @@ """ import functools +import inspect +import os import flydsl.compiler as flyc import flydsl.expr as fx from flydsl._mlir import ir +from flydsl._mlir.dialects import fly, llvm from flydsl.compiler.kernel_function import CompilationContext from flydsl.expr import arith, buffer_ops, const_expr, gpu, idx2crd, range_constexpr, rocdl, tdm_ops from flydsl.expr.rocdl import cluster @@ -29,6 +32,30 @@ ) from kernels.pipeline_utils import make_tail_plan, tdm_epilogue_fence_threshold_bytes + +def _s_prefetch_inst_burst(num_pages: int, page_bytes: int = 4096): + """gfx1250: prefetch ``num_pages`` × 4 KB of instructions ahead of PC. + + Caller must keep ``num_pages * page_bytes`` within shader bounds; over-reach + page-faults. + """ + from flydsl._mlir.dialects import llvm as _llvm + + lines = [f"s_prefetch_inst_pc_rel {pg * page_bytes}, null, 31" for pg in range(num_pages)] + _llvm.inline_asm(None, [], "\n".join(lines), "", has_side_effects=True) + + +# compatible with no early_timeout descriptor +_TDM_HAS_EARLY_TIMEOUT = "early_timeout" in inspect.signature(tdm_ops.make_tensor_descriptor_2d).parameters + + +def _make_tdm_desc(*, early_timeout=False, **kwargs): + """Build a TDM descriptor, applying early_timeout only when supports it.""" + if _TDM_HAS_EARLY_TIMEOUT: + kwargs["early_timeout"] = early_timeout + return tdm_ops.make_tensor_descriptor_2d(**kwargs) + + # Common constants WMMA_M, WMMA_N, WMMA_K = 16, 16, 128 WAVE_SIZE = 32 @@ -37,6 +64,8 @@ LDS_PAD_A_BYTES = 16 LDS_PAD_D_BYTES = 16 +LDS_SEGMENT_BYTES = 64 * 1024 +LDS_GFX1250_MAX_BYTES = 5 * LDS_SEGMENT_BYTES @functools.lru_cache(maxsize=256) @@ -66,6 +95,7 @@ def compile_mxscale_gemm( atomic_barrier_enable: bool = False, b_streaming: bool = False, scale_load_path: str = "tdm", + fp8_schedule: str = "auto", ): """Compile an MXFP4 or MXFP8 GEMM kernel with TDM async copy. @@ -90,11 +120,20 @@ def compile_mxscale_gemm( if out_dtype not in ("f32", "bf16", "f16"): raise ValueError(f"out_dtype must be 'f32', 'bf16', or 'f16', got {out_dtype!r}") elem_bytes_d = 2 if out_dtype in ("bf16", "f16") else 4 - scale_load_paths = ("tdm", "buffer_lds_stage", "buffer_lds_stage_ab_split") + # scale_load_path: "tdm" = TDM->LDS (default); "vgpr" = buffer_load->VGPR, + # off the LDS/TDM/barrier path; "vgpr_ab_split" = "vgpr" plus repurposing the + # idle scale waves 2,3 to load the second A/B halves. + scale_load_paths = ("tdm", "vgpr", "vgpr_ab_split") if scale_load_path not in scale_load_paths: raise ValueError(f"scale_load_path must be one of {scale_load_paths}, got {scale_load_path!r}") - use_scale_buffer_load = scale_load_path != "tdm" - use_ab_split_scale_buffer_load = scale_load_path == "buffer_lds_stage_ab_split" + fp8_schedule_modes = ("auto", "quadrant", "deep-pipeline") + if fp8_schedule not in fp8_schedule_modes: + raise ValueError(f"fp8_schedule must be one of {fp8_schedule_modes}, got {fp8_schedule!r}") + if fp8_schedule != "auto" and data_format != "fp8": + raise ValueError(f"fp8_schedule={fp8_schedule!r} is only valid for data_format='fp8'") + if fp8_schedule != "auto" and b_streaming: + raise ValueError("fp8_schedule cannot be combined with b_streaming=True") + effective_expert_sched_mode = bool(expert_sched_mode) if num_buffers not in (2, 3, 4): raise ValueError(f"num_buffers must be 2, 3, or 4, got {num_buffers}") @@ -114,8 +153,6 @@ def compile_mxscale_gemm( if wave_specialized_tdm and num_warps < 4: raise ValueError(f"wave_specialized_tdm requires at least 4 waves, got {num_warps}") - if use_ab_split_scale_buffer_load and not wave_specialized_tdm: - raise ValueError("scale_load_path='buffer_lds_stage_ab_split' requires wave_specialized_tdm=True") # ── Format-dependent compile-time constants ── # A8W4: activation is FP8 (PACK_FACTOR_A=1), weight is FP4 (PACK_FACTOR_B=2) @@ -192,11 +229,11 @@ def compile_mxscale_gemm( _as_ds_loads = wmma_m_rep * _a_frag_loads_per_wm + _scale_ds_loads lds_a_stride_bytes = packed_tile_k_a + LDS_PAD_A_BYTES - if use_ab_split_scale_buffer_load: + if scale_load_path == "vgpr_ab_split": if tile_m % 2 != 0: - raise ValueError(f"buffer_lds_stage_ab_split requires even tile_m, got {tile_m}") + raise ValueError(f"scale_load_path='vgpr_ab_split' requires even tile_m, got {tile_m}") if tile_n % 32 != 0: - raise ValueError(f"buffer_lds_stage_ab_split requires tile_n divisible by 32, got {tile_n}") + raise ValueError(f"scale_load_path='vgpr_ab_split' requires tile_n divisible by 32, got {tile_n}") lds_a_data_bytes = tile_m * lds_a_stride_bytes lds_b_data_bytes = tile_n * packed_tile_k_b @@ -207,18 +244,6 @@ def compile_mxscale_gemm( lds_b_scale_bytes = tile_n * scale_k_per_tile + _scale_guard_bytes interleaved_scale_cols_a = wmma_m_rep * scale_k_per_tile interleaved_scale_cols_b = b_scale_load_rep * scale_k_per_tile - _scale_dma_bytes = 16 - if use_scale_buffer_load: - if interleaved_scale_cols_a % _scale_dma_bytes != 0: - raise ValueError( - "buffer_lds_stage scale loads require A scale rows to be 16-byte aligned, " - f"got interleaved_scale_cols_a={interleaved_scale_cols_a}" - ) - if interleaved_scale_cols_b % _scale_dma_bytes != 0: - raise ValueError( - "buffer_lds_stage scale loads require B scale rows to be 16-byte aligned, " - f"got interleaved_scale_cols_b={interleaved_scale_cols_b}" - ) def _align_up(value: int, align: int) -> int: if value % align == 0: @@ -231,9 +256,9 @@ def _align_up(value: int, align: int) -> int: # active loader wave must issue a full-tile descriptor by itself. tdm_desc_num_warps = 1 if wave_specialized_tdm else num_warps - # All pipeline stages share the same intra-stage layout. Keep that layout - # unchanged and only remap each logical stage to a physical base inside one - # LDS arena so TDM epilogue can alias the dead prefix of the arena. + # All pipeline stages share the same intra-stage layout in the generic + # arena path. The active gfx1250 FP8 TDM tile uses a separate reference + # pool layout below. stage_layout = SmemAllocator(None, arch=gpu_arch, global_sym_name=f"mxscale_{data_format}_layout") stage_a_data_rel_off = stage_layout._align(stage_layout.ptr, 16) stage_layout.ptr = stage_a_data_rel_off + lds_a_data_bytes @@ -262,24 +287,88 @@ def _align_up(value: int, align: int) -> int: ), ) - stage_phys_order = [i for i in range(num_buffers) if i != _last_compute_stage] - stage_phys_order.append(_last_compute_stage) - stage_base_off = [0] * num_buffers - for phys_i, logical_i in enumerate(stage_phys_order): - stage_base_off[logical_i] = phys_i * stage_pitch_bytes - arena_alloc.ptr = stage_pitch_bytes * num_buffers - arena_total_bytes = arena_alloc.ptr - epilogue_fence_threshold_bytes = tdm_epilogue_fence_threshold_bytes( - stage_base_off=stage_base_off, - tail_plan=_base_tail_plan, - loop_iters=loop_iters, - extra=extra, + use_ref_segmented_lds_layout = ( + data_format == "fp8" + and tile_m == 256 + and tile_n == 256 + and tile_k == 128 + and m_warp == 2 + and n_warp == 2 + and num_buffers == 4 + and split_k == 1 + and wave_specialized_tdm + and not use_scale_opsel ) - stage_a_data_off = [stage_base_off[i] + stage_a_data_rel_off for i in range(num_buffers)] - stage_b_data_off = [stage_base_off[i] + stage_b_data_rel_off for i in range(num_buffers)] - stage_a_scale_off = [stage_base_off[i] + stage_a_scale_rel_off for i in range(num_buffers)] - stage_b_scale_off = [stage_base_off[i] + stage_b_scale_rel_off for i in range(num_buffers)] + # "vgpr"/"vgpr_ab_split": load scale global->VGPR via buffer_load, bypassing + # TDM+LDS entirely. Requires the reference segmented LDS layout. + use_buffer_vgpr_scale = scale_load_path in ("vgpr", "vgpr_ab_split") + if use_buffer_vgpr_scale and not use_ref_segmented_lds_layout: + raise ValueError( + f"scale_load_path={scale_load_path!r} requires the reference segmented " + "LDS layout (not active for this tile/format configuration)" + ) + # Scale prefetch depth (K-tiles ahead) for the buffer->VGPR path. D=1 is the + # sweet spot; D=2 doubles scale VGPRs -> spill + ~18% regression. + _bvs_D = max(1, int(os.environ.get("FLYDSL_BUFFER_VGPR_SCALE_DEPTH", "1"))) + # ab_half_split: repurpose the (under "vgpr") idle scale waves 2,3 as the + # second halves of A/B, so all 4 waves share the A/B TDM (wave0=A0, wave1=B0, + # wave2=A1, wave3=B1). Measured wall-neutral. + use_ab_half_split = scale_load_path == "vgpr_ab_split" + # The buffer_load->VGPR scale ring is built only when scale is actually loaded. + _bvs_active = use_buffer_vgpr_scale + + if use_ref_segmented_lds_layout: + # The A/B data pools are no longer packed into the same per-stage + # 64KiB segment window. Scale pools keep the reference 0x800 stride so + # every TDM LDS target remains 2KiB-aligned. + ref_a_stage_stride = 0x9000 + ref_b_stage_stride = 0x8000 + ref_scale_stage_stride = 0x800 + if lds_a_data_bytes > ref_a_stage_stride: + raise RuntimeError( + "reference segmented LDS layout requires A stage <= 0x9000 bytes, " f"got {lds_a_data_bytes}" + ) + if lds_b_data_bytes > ref_b_stage_stride: + raise RuntimeError( + "reference segmented LDS layout requires B stage <= 0x8000 bytes, " f"got {lds_b_data_bytes}" + ) + if lds_a_scale_bytes > ref_scale_stage_stride or lds_b_scale_bytes > ref_scale_stage_stride: + raise RuntimeError( + "reference segmented LDS layout requires scale stage <= 0x800 bytes, " + f"got A={lds_a_scale_bytes} B={lds_b_scale_bytes}" + ) + + stage_a_data_off = [0x00000, 0x09000, 0x16000, 0x1F000] + stage_a_scale_off = [0x12000 + i * ref_scale_stage_stride for i in range(num_buffers)] + stage_b_scale_off = [0x28000 + i * ref_scale_stage_stride for i in range(num_buffers)] + stage_b_data_off = [0x30000 + i * ref_b_stage_stride for i in range(num_buffers)] + arena_alloc.ptr = LDS_GFX1250_MAX_BYTES + arena_total_bytes = arena_alloc.ptr + + # The epilogue may reuse the prefix only after all main/tail TDM traffic + # is fully fenced. This is outside the hot loop and avoids assuming a + # single monotonic per-stage base for the segmented pool layout. + epilogue_fence_threshold_bytes = 0 + else: + stage_phys_order = [i for i in range(num_buffers) if i != _last_compute_stage] + stage_phys_order.append(_last_compute_stage) + stage_base_off = [0] * num_buffers + for phys_i, logical_i in enumerate(stage_phys_order): + stage_base_off[logical_i] = phys_i * stage_pitch_bytes + arena_alloc.ptr = stage_pitch_bytes * num_buffers + arena_total_bytes = arena_alloc.ptr + epilogue_fence_threshold_bytes = tdm_epilogue_fence_threshold_bytes( + stage_base_off=stage_base_off, + tail_plan=_base_tail_plan, + loop_iters=loop_iters, + extra=extra, + ) + + stage_a_data_off = [stage_base_off[i] + stage_a_data_rel_off for i in range(num_buffers)] + stage_b_data_off = [stage_base_off[i] + stage_b_data_rel_off for i in range(num_buffers)] + stage_a_scale_off = [stage_base_off[i] + stage_a_scale_rel_off for i in range(num_buffers)] + stage_b_scale_off = [stage_base_off[i] + stage_b_scale_rel_off for i in range(num_buffers)] if use_tdm_store: lds_d_row_stride = warp_tile_n * elem_bytes_d + LDS_PAD_D_BYTES @@ -295,12 +384,12 @@ def _align_up(value: int, align: int) -> int: arena_alloc.ptr = total_d_bytes check_smem_capacity(arena_total_bytes, gpu_arch) - # TENSORcnt is tracked per-wave in hardware. When scale is loaded through - # buffer_load_lds, TDM only carries A/B data. + # TENSORcnt is tracked per-wave in hardware. Wave-specialized TDM issues one + # tensor_load per wave per step; otherwise all 4 (A/B/A_scale/B_scale). if wave_specialized_tdm: TDM_LOADS_PER_STEP = 1 else: - TDM_LOADS_PER_STEP = 2 if use_scale_buffer_load else 4 + TDM_LOADS_PER_STEP = 4 tail_plan = [(ls, cs, o * TDM_LOADS_PER_STEP // 2 if o > 0 else o) for ls, cs, o in _base_tail_plan] # Pre-compute epilogue sub-tile layout (unified for FP4 vec16 and FP8 vec8) @@ -325,8 +414,28 @@ def _align_up(value: int, align: int) -> int: COMPUTE_SCHEDULE_ROW_MAJOR_STREAMING = "row_major_streaming" COMPUTE_SCHEDULE_FP4_COL_BAND = "fp4_col_band" COMPUTE_SCHEDULE_FP8_QUADRANT = "fp8_quadrant" + COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE = "fp8_deep_pipeline" COMPUTE_SCHEDULE_B_STREAMING = "b_streaming" + fp8_deep_pipeline_eligible = ( + data_format == "fp8" + and tile_m == 256 + and tile_n == 256 + and tile_k == 128 + and m_warp == 2 + and n_warp == 2 + and num_buffers == 4 + and wave_specialized_tdm + and out_dtype == "bf16" + and not use_scale_opsel + ) + if fp8_schedule == "deep-pipeline" and not fp8_deep_pipeline_eligible: + raise ValueError( + "fp8_schedule='deep-pipeline' requires fp8 256x256x128, " + "m_warp=n_warp=2, num_buffers=4, wave_specialized_tdm=True, " + "out_dtype='bf16', and use_scale_opsel=False" + ) + def _pick_compute_schedule_kind(): if b_streaming: return COMPUTE_SCHEDULE_B_STREAMING @@ -339,12 +448,29 @@ def _pick_compute_schedule_kind(): if is_fp4: return COMPUTE_SCHEDULE_FP4_COL_BAND if data_format == "fp8": + if fp8_schedule == "deep-pipeline" or (fp8_schedule == "auto" and fp8_deep_pipeline_eligible): + return COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE return COMPUTE_SCHEDULE_FP8_QUADRANT return COMPUTE_SCHEDULE_ROW_MAJOR_STREAMING compute_schedule_kind = _pick_compute_schedule_kind() use_fp4_bank_friendly_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_FP4_COL_BAND use_fp8_quadrant_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_FP8_QUADRANT + use_fp8_deep_pipeline_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE + use_b_streaming_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_B_STREAMING + if use_buffer_vgpr_scale and not use_fp8_deep_pipeline_schedule: + raise ValueError(f"scale_load_path={scale_load_path!r} is only supported with the FP8 deep-pipeline schedule") + use_ws_tdm_split_signal_overlap = ( + wave_specialized_tdm + and (use_fp8_quadrant_schedule or use_fp8_deep_pipeline_schedule) + and num_buffers == 4 + and use_cluster + ) + if use_b_streaming_schedule: + print( + f"[b_streaming] {data_format} tile=({tile_m},{tile_n},{tile_k}) " f"M_r={wmma_m_rep} N_r={wmma_n_rep}", + flush=True, + ) if use_fp4_bank_friendly_schedule: _bank_half_wm = wmma_m_rep // 2 @@ -365,11 +491,19 @@ def _pick_compute_schedule_kind(): for _wn in range(_bank_half_wn, wmma_n_rep): _bank_group_to_row_major.append(_wm * wmma_n_rep + _wn) - if use_fp8_quadrant_schedule: + if use_fp8_quadrant_schedule or use_fp8_deep_pipeline_schedule: _fp8_half_wm = wmma_m_rep // 2 _fp8_half_wn = wmma_n_rep // 2 _fp8_group_size = _fp8_half_wm * _fp8_half_wn _fp8_b_scale_loads = (b_scale_load_rep + 3) // 4 + if use_fp8_deep_pipeline_schedule: + _fp8_pair_wm = 2 + _fp8_pair_wn = 2 + _fp8_wm_pairs = wmma_m_rep // _fp8_pair_wm + _fp8_wn_pairs = wmma_n_rep // _fp8_pair_wn + _fp8_pair_a_loads = _fp8_pair_wm * DS_LOADS_PER_A_FRAG + _fp8_pair_b_loads = _fp8_pair_wn * _b_frag_loads_per_wn + _fp8_scale_loads = (wmma_m_rep + 3) // 4 + (b_scale_load_rep + 3) // 4 @flyc.kernel(known_block_size=[block_threads, 1, 1]) def kernel_mxscale_gemm( @@ -386,7 +520,7 @@ def kernel_mxscale_gemm( if const_expr(inst_prefetch): if rocdl.wave_id() == fx.Int32(0): - rocdl.s_prefetch_inst_burst(num_pages=10) + _s_prefetch_inst_burst(num_pages=4) tx = gpu.thread_id("x") bx = gpu.block_id("x") @@ -404,7 +538,12 @@ def kernel_mxscale_gemm( a_mcast_mask = 0 b_mcast_mask = 0 - layout_thr = fx.make_layout((m_warp, n_warp, 2, 16), (n_warp * WAVE_SIZE, WAVE_SIZE, 16, 1)) + # The FP8 deep pipeline runs cleaner when adjacent wave ids advance M + # first; keep the default mapping for the other schedules. + if const_expr(use_fp8_deep_pipeline_schedule): + layout_thr = fx.make_layout((m_warp, n_warp, 2, 16), (WAVE_SIZE, m_warp * WAVE_SIZE, 16, 1)) + else: + layout_thr = fx.make_layout((m_warp, n_warp, 2, 16), (n_warp * WAVE_SIZE, WAVE_SIZE, 16, 1)) thr_coord = idx2crd(tx, layout_thr) wave_m_idx, wave_n_idx, lane_kgrp, lane16 = ( fx.get(thr_coord, 0), @@ -416,15 +555,47 @@ def kernel_mxscale_gemm( warp_m_base = wave_m_idx * arith.index(warp_tile_m) warp_n_base = wave_n_idx * arith.index(warp_tile_n) + if const_expr(use_buffer_vgpr_scale): + # Direct global->VGPR scale load (no TDM/LDS). Coalesced lane-major + # host layout [M_block(128), K_tile, group(2), lane16(16), 4 i32], so + # each buffer_load_b128's 16 lanes read 256 contiguous bytes: + # i32_off(group) = (mb*Kt + kt)*128 + group*64 + lane16*4 + _bvs_a_rsrc = buffer_ops.create_buffer_resource(arg_a_scale, max_size=False) + _bvs_b_rsrc = buffer_ops.create_buffer_resource(arg_b_scale, max_size=False) + _bvs_Kt = K // tile_k # total K-tiles + _bvs_mb_a = blk_m / arith.index(128) + wave_m_idx + _bvs_mb_b = blk_n / arith.index(128) + wave_n_idx + _bvs_lane4 = lane16 * arith.index(4) + + def _bvs_load_scales(rsrc, mb, rep, k_base): + kt = k_base / arith.index(tile_k) + tile_i32 = (mb * arith.index(_bvs_Kt) + kt) * arith.index(128) + vals = [] + for ld in range_constexpr(rep // 4): # rep=8 -> 2 groups of 4 i32 + off = arith.index_cast(T.i32, tile_i32 + arith.index(ld * 64) + _bvs_lane4) + v = fx.Vector(buffer_ops.buffer_load(rsrc, off, vec_width=4, dtype=T.i32)) + for j in range_constexpr(4): + vals.append(v[j]) + return vals + + def _bvs_prefetch(k_base): + # Issue scale buffer_load for one K-tile; returns (a[8], b[8]) VGPR. + a = _bvs_load_scales(_bvs_a_rsrc, _bvs_mb_a, wmma_m_rep, k_base) + b = _bvs_load_scales(_bvs_b_rsrc, _bvs_mb_b, b_scale_load_rep, k_base) + return a, b + m_idx = fx.Index(i32_m) n_stride = arith.index(N) c_nrec = m_idx * n_stride * arith.index(elem_bytes_d) c_rsrc = buffer_ops.create_buffer_resource(arg_c, num_records_bytes=c_nrec) - zero_i32 = fx.Int32(0) + c_global_ptr_type = ir.Type.parse("!llvm.ptr<1>") + c_global_base_i64 = llvm.PtrToIntOp( + T.i64, fly.extract_aligned_pointer_as_index(c_global_ptr_type, arg_c.__extract_to_ir_values__()[0]) + ).result def make_desc_a(memref, k_base): k_packed_off = k_base / arith.index(PACK_FACTOR_A) - return tdm_ops.make_tensor_descriptor_2d( + return _make_tdm_desc( global_ptr=arg_a, lds_memref=memref, global_offset=(blk_m, k_packed_off), @@ -437,11 +608,12 @@ def make_desc_a(memref, k_base): num_warps=tdm_desc_num_warps, workgroup_mask=a_mcast_mask, atomic_barrier_enable=atomic_barrier_enable, + early_timeout=True, ) def make_desc_b(memref, k_base): k_packed_off = k_base / arith.index(PACK_FACTOR_B) - return tdm_ops.make_tensor_descriptor_2d( + return _make_tdm_desc( global_ptr=arg_b, lds_memref=memref, global_offset=(blk_n / arith.index(16), k_packed_off * arith.index(16)), @@ -454,12 +626,13 @@ def make_desc_b(memref, k_base): num_warps=tdm_desc_num_warps, workgroup_mask=b_mcast_mask, atomic_barrier_enable=atomic_barrier_enable, + early_timeout=True, ) def make_desc_a_half(memref, k_base, m_half: int): row_start = m_half * ab_split_a_rows k_packed_off = k_base / arith.index(PACK_FACTOR_A) - return tdm_ops.make_tensor_descriptor_2d( + return _make_tdm_desc( global_ptr=arg_a, lds_memref=memref, global_offset=(blk_m + arith.index(row_start), k_packed_off), @@ -473,12 +646,13 @@ def make_desc_a_half(memref, k_base, m_half: int): workgroup_mask=a_mcast_mask, lds_byte_offset=arith.index(row_start * lds_a_stride_bytes), atomic_barrier_enable=atomic_barrier_enable, + early_timeout=True, ) def make_desc_b_half(memref, k_base, n_half: int): group_start = n_half * ab_split_b_groups k_packed_off = k_base / arith.index(PACK_FACTOR_B) - return tdm_ops.make_tensor_descriptor_2d( + return _make_tdm_desc( global_ptr=arg_b, lds_memref=memref, global_offset=(blk_n / arith.index(16) + arith.index(group_start), k_packed_off * arith.index(16)), @@ -492,13 +666,14 @@ def make_desc_b_half(memref, k_base, n_half: int): workgroup_mask=b_mcast_mask, lds_byte_offset=arith.index(group_start * packed_tile_k_b * 16), atomic_barrier_enable=atomic_barrier_enable, + early_timeout=True, ) def make_desc_as(memref, k_base): k_scale_off = k_base / arith.index(SCALE_BLOCK) outer_off = blk_m / arith.index(wmma_m_rep) inner_off = k_scale_off * arith.index(wmma_m_rep) - return tdm_ops.make_tensor_descriptor_2d( + return _make_tdm_desc( global_ptr=arg_a_scale, lds_memref=memref, global_offset=(outer_off, inner_off), @@ -511,13 +686,14 @@ def make_desc_as(memref, k_base): num_warps=tdm_desc_num_warps, workgroup_mask=a_mcast_mask, atomic_barrier_enable=atomic_barrier_enable, + early_timeout=True, ) def make_desc_bs(memref, k_base): k_scale_off = k_base / arith.index(SCALE_BLOCK) outer_off = blk_n / arith.index(b_scale_load_rep) inner_off = k_scale_off * arith.index(b_scale_load_rep) - return tdm_ops.make_tensor_descriptor_2d( + return _make_tdm_desc( global_ptr=arg_b_scale, lds_memref=memref, global_offset=(outer_off, inner_off), @@ -530,6 +706,7 @@ def make_desc_bs(memref, k_base): num_warps=tdm_desc_num_warps, workgroup_mask=b_mcast_mask, atomic_barrier_enable=atomic_barrier_enable, + early_timeout=True, ) if const_expr(wave_specialized_tdm): @@ -1084,6 +1261,7 @@ def compute_tile_fp8_quadrant( lds_bs, emit_filler=None, mid_compute_callback=None, + late_compute_callback=None, ): current_accs = list(accs_in) a_buf, a_bases = _precompute_a_lane_bases(lds_a) @@ -1118,11 +1296,22 @@ def _load_b_scales(ks): def _load_b_left_bundle(ks): return _load_b_half(0, ks), _load_b_scales(ks) - def _emit_group(wm_base, wn_base, a_frags, b_frags, a_scales, b_scales, emit_filler_now=False): + def _emit_group_rows( + wm_base, + wn_base, + a_frags, + b_frags, + a_scales, + b_scales, + row_start, + row_count, + emit_filler_now=False, + ): if const_expr(emit_filler_now and emit_filler is not None): rocdl.sched_barrier(0) emit_filler() - for wm_local in range_constexpr(_fp8_half_wm): + for row_offset in range_constexpr(row_count): + wm_local = row_start + row_offset global_wm = wm_base + wm_local for wn_local in range_constexpr(_fp8_half_wn): global_wn = wn_base + wn_local @@ -1136,24 +1325,67 @@ def _emit_group(wm_base, wn_base, a_frags, b_frags, a_scales, b_scales, emit_fil b_scales, ) + def _emit_group(wm_base, wn_base, a_frags, b_frags, a_scales, b_scales, emit_filler_now=False): + _emit_group_rows( + wm_base, + wn_base, + a_frags, + b_frags, + a_scales, + b_scales, + 0, + _fp8_half_wm, + emit_filler_now=emit_filler_now, + ) + + def _emit_group_col(wm_base, wn_base, a_frags, b_frags, a_scales, b_scales, wn_local): + global_wn = wn_base + wn_local + for wm_local in range_constexpr(_fp8_half_wm): + global_wm = wm_base + wm_local + _emit_wmma( + current_accs, + global_wm, + global_wn, + a_frags[wm_local], + b_frags[wn_local], + a_scales, + b_scales, + ) + b_left_frags, b_scales = _load_b_left_bundle(0) + _first_top_row_keep = max((_fp8_half_wm - 1) * DS_LOADS_PER_A_FRAG - _fp8_b_scale_loads, 0) + _bottom_left_keep = max(_b_half_loads - DS_LOADS_PER_A_FRAG, 0) for ks in range_constexpr(k_wmma_steps): is_last_ks = ks == k_wmma_steps - 1 a_scales = _load_a_scales(ks) a_top_frags = _load_a_group(0, _fp8_half_wm, ks) - a_bottom_frags = _load_a_group(_fp8_half_wm, _fp8_half_wm, ks) - # Keep bottom A outstanding while the first quadrant consumes top A. - rocdl.s_wait_dscnt(_fp8_half_wm * DS_LOADS_PER_A_FRAG) + # Consume the first top-left row before issuing bottom-A. + # The barriers only constrain LLVM scheduling; they are not + # hardware synchronization points. + rocdl.s_wait_dscnt(_first_top_row_keep) + rocdl.sched_barrier(0) + _emit_group_rows(0, 0, a_top_frags, b_left_frags, a_scales, b_scales, 0, 1) + rocdl.sched_barrier(0) - _emit_group(0, 0, a_top_frags, b_left_frags, a_scales, b_scales) + a_bottom_frags = _load_a_group(_fp8_half_wm, _fp8_half_wm, ks) + if const_expr(_fp8_half_wm > 1): + _emit_group_rows( + 0, + 0, + a_top_frags, + b_left_frags, + a_scales, + b_scales, + 1, + _fp8_half_wm - 1, + ) b_right_frags = _load_b_half(_fp8_half_wn, ks) - # Keep the newly issued right-half B loads outstanding while - # bottom A becomes ready for the second quadrant. - rocdl.s_wait_dscnt(_b_half_loads) + # Drain bottom-A while keeping most right-half B in flight. + rocdl.s_wait_dscnt(_bottom_left_keep) _emit_group(_fp8_half_wm, 0, a_bottom_frags, b_left_frags, a_scales, b_scales) @@ -1163,22 +1395,33 @@ def _emit_group(wm_base, wn_base, a_frags, b_frags, a_scales, b_scales, emit_fil if const_expr(not is_last_ks): next_left_frags, next_b_scales = _load_b_left_bundle(ks + 1) - # Current right-half B must be ready before Q2/Q3, while - # the next ks left-half bundle stays in flight. - rocdl.s_wait_dscnt(_b_left_bundle_loads) - else: - rocdl.s_wait_dscnt(0) - _emit_group(0, _fp8_half_wn, a_top_frags, b_right_frags, a_scales, b_scales) - _emit_group( - _fp8_half_wm, - _fp8_half_wn, - a_bottom_frags, - b_right_frags, - a_scales, - b_scales, - emit_filler_now=is_last_ks, - ) + for wn_local in range_constexpr(_fp8_half_wn): + if const_expr(not is_last_ks): + _right_keep = _b_left_bundle_loads + (_fp8_half_wn - wn_local - 1) * _b_frag_loads_per_wn + else: + _right_keep = (_fp8_half_wn - wn_local - 1) * _b_frag_loads_per_wn + rocdl.s_wait_dscnt(_right_keep) + _emit_group_col(0, _fp8_half_wn, a_top_frags, b_right_frags, a_scales, b_scales, wn_local) + + if const_expr(is_last_ks and late_compute_callback is not None): + rocdl.sched_barrier(0) + late_compute_callback() + + if const_expr(is_last_ks and emit_filler is not None): + rocdl.sched_barrier(0) + emit_filler() + + for wn_local in range_constexpr(_fp8_half_wn): + _emit_group_col( + _fp8_half_wm, + _fp8_half_wn, + a_bottom_frags, + b_right_frags, + a_scales, + b_scales, + wn_local, + ) if const_expr(not is_last_ks): b_left_frags = next_left_frags @@ -1186,6 +1429,185 @@ def _emit_group(wm_base, wn_base, a_frags, b_frags, a_scales, b_scales, emit_fil return current_accs + def compute_tile_fp8_deep_pipeline( + accs_in, + lds_a, + lds_b, + lds_as, + lds_bs, + emit_filler=None, + mid_compute_callback=None, + late_compute_callback=None, + a0_prefetch=None, + scale_k_base=None, + pf_a_scales=None, + pf_b_scales=None, + ): + current_accs = list(accs_in) + a_buf, a_bases = _precompute_a_lane_bases(lds_a) + b_buf, b_bases = _precompute_b_lane_bases(lds_b) + as_buf, as_bases = _precompute_scale_lane_bases(lds_as, warp_m_base, wmma_m_rep, interleaved_scale_cols_a) + bs_buf, bs_bases = _precompute_scale_lane_bases( + lds_bs, warp_n_base, b_scale_load_rep, interleaved_scale_cols_b + ) + + def load_a_pair(wm_pair, ks): + wm_base = wm_pair * _fp8_pair_wm + return [ + load_a_frag(a_buf, a_bases[wm_base + wm_local], ks) for wm_local in range_constexpr(_fp8_pair_wm) + ] + + def load_b_pair(wn_pair, ks): + wn_base = wn_pair * _fp8_pair_wn + return [ + load_b_frag(b_buf, b_bases, wn_base + wn_local, ks) for wn_local in range_constexpr(_fp8_pair_wn) + ] + + def _load_a_scales(ks): + if const_expr(use_buffer_vgpr_scale): + if const_expr(pf_a_scales is not None): + return pf_a_scales # prefetched (issued in the prior compute tile) + return _bvs_load_scales(_bvs_a_rsrc, _bvs_mb_a, wmma_m_rep, scale_k_base) + return load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks) + + def _load_b_scales(ks): + if const_expr(use_buffer_vgpr_scale): + if const_expr(pf_b_scales is not None): + return pf_b_scales + return _bvs_load_scales(_bvs_b_rsrc, _bvs_mb_b, b_scale_load_rep, scale_k_base) + return load_scale_b128(bs_buf, bs_bases[0], b_scale_load_rep, ks) + + def emit_panel_2x2( + wm_pair, + wn_pair, + a_pair, + b_pair, + scale_pair, + prefetch_after_first_row=None, + ): + a_scales, b_scales = scale_pair + wm_base = wm_pair * _fp8_pair_wm + wn_base = wn_pair * _fp8_pair_wn + for wn_local in range_constexpr(_fp8_pair_wn): + _emit_wmma( + current_accs, + wm_base, + wn_base + wn_local, + a_pair[0], + b_pair[wn_local], + a_scales, + b_scales, + ) + if const_expr(prefetch_after_first_row is not None): + prefetch_after_first_row() + for wn_local in range_constexpr(_fp8_pair_wn): + _emit_wmma( + current_accs, + wm_base + 1, + wn_base + wn_local, + a_pair[1], + b_pair[wn_local], + a_scales, + b_scales, + ) + + def emit_panel_2x2_row(wm_pair, wn_pair, row_local, a_pair, b_pair, scale_pair): + a_scales, b_scales = scale_pair + wm_base = wm_pair * _fp8_pair_wm + wn_base = wn_pair * _fp8_pair_wn + for wn_local in range_constexpr(_fp8_pair_wn): + _emit_wmma( + current_accs, + wm_base + row_local, + wn_base + wn_local, + a_pair[row_local], + b_pair[wn_local], + a_scales, + b_scales, + ) + + _pair_loads = _fp8_pair_a_loads + _two_pair_loads = _fp8_pair_a_loads + _fp8_pair_b_loads + + for ks in range_constexpr(k_wmma_steps): + is_last_ks = ks == k_wmma_steps - 1 + a_scales = _load_a_scales(ks) + b_scales = _load_b_scales(ks) + scale_pair = (a_scales, b_scales) + + b0 = load_b_pair(0, ks) + if const_expr(ks == 0 and a0_prefetch is not None and len(a0_prefetch) == _fp8_pair_wm): + a0 = list(a0_prefetch) + elif const_expr(ks == 0 and a0_prefetch is not None): + a0 = [a0_prefetch[0], load_a_frag(a_buf, a_bases[1], ks)] + else: + a0 = load_a_pair(0, ks) + b1 = load_b_pair(1, ks) + b2 = load_b_pair(2, ks) + + a1_box = [None] + b3_box = [None] + a2_box = [None] + a3_box = [None] + + def _prefetch_a1(): + a1_box[0] = load_a_pair(1, ks) + + first_wait_keep = _two_pair_loads + 3 + if const_expr(ks == 0 and a0_prefetch is not None): + first_wait_keep += DS_LOADS_PER_A_FRAG * len(a0_prefetch) + rocdl.s_wait_dscnt(first_wait_keep) + emit_panel_2x2(0, 0, a0, b0, scale_pair, prefetch_after_first_row=_prefetch_a1) + + if const_expr(ks == 0 and mid_compute_callback is not None): + rocdl.sched_barrier(0) + mid_compute_callback() + + def _prefetch_b3(): + b3_box[0] = load_b_pair(3, ks) + + def _prefetch_a3(): + a3_box[0] = load_a_pair(3, ks) + + rocdl.s_wait_dscnt(_pair_loads + _fp8_pair_b_loads) + emit_panel_2x2(0, 1, a0, b1, scale_pair, prefetch_after_first_row=_prefetch_b3) + + rocdl.s_wait_dscnt(_fp8_pair_b_loads + 2) + emit_panel_2x2(1, 0, a1_box[0], b0, scale_pair, prefetch_after_first_row=_prefetch_a3) + + def _prefetch_a2(): + a2_box[0] = load_a_pair(2, ks) + + emit_panel_2x2(1, 1, a1_box[0], b1, scale_pair) + + emit_panel_2x2(0, 2, a0, b2, scale_pair, prefetch_after_first_row=_prefetch_a2) + emit_panel_2x2_row(1, 2, 0, a1_box[0], b2, scale_pair) + emit_panel_2x2_row(1, 2, 1, a1_box[0], b2, scale_pair) + rocdl.s_wait_dscnt(_pair_loads) + emit_panel_2x2(0, 3, a0, b3_box[0], scale_pair) + emit_panel_2x2(1, 3, a1_box[0], b3_box[0], scale_pair) + + emit_panel_2x2(2, 0, a2_box[0], b0, scale_pair) + if const_expr(is_last_ks and late_compute_callback is not None): + rocdl.sched_barrier(0) + late_compute_callback() + emit_panel_2x2(2, 1, a2_box[0], b1, scale_pair) + + rocdl.s_wait_dscnt(0) + emit_panel_2x2(3, 0, a3_box[0], b0, scale_pair) + emit_panel_2x2(3, 1, a3_box[0], b1, scale_pair) + + if const_expr(is_last_ks and emit_filler is not None): + rocdl.sched_barrier(0) + emit_filler() + + emit_panel_2x2(2, 2, a2_box[0], b2, scale_pair) + emit_panel_2x2(2, 3, a2_box[0], b3_box[0], scale_pair) + emit_panel_2x2(3, 2, a3_box[0], b2, scale_pair) + emit_panel_2x2(3, 3, a3_box[0], b3_box[0], scale_pair) + + return current_accs + def compute_tile_b_streaming( accs_in, lds_a, lds_b, lds_as, lds_bs, emit_filler=None, mid_compute_callback=None ): @@ -1277,27 +1699,86 @@ def hot_loop_scheduler_fp4_bank_friendly(): rocdl.sched_barrier(0) def hot_loop_scheduler_fp8_quadrant(): - _a_all_loads = wmma_m_rep * DS_LOADS_PER_A_FRAG _a_scale_loads = (wmma_m_rep + 3) // 4 + _a_top_loads = _fp8_half_wm * DS_LOADS_PER_A_FRAG + _a_bottom_loads = _a_top_loads _b_half_loads = _fp8_half_wn * _b_frag_loads_per_wn _b_left_bundle_loads = _b_half_loads + _fp8_b_scale_loads _group_wmma = _fp8_group_size + _first_row_wmma = _fp8_half_wn + _remaining_top_left_wmma = (_fp8_half_wm - 1) * _fp8_half_wn for _ks in range_constexpr(k_wmma_steps): if const_expr(_ks == 0): - rocdl.sched_dsrd(_b_left_bundle_loads + _a_scale_loads + _a_all_loads) + rocdl.sched_dsrd(_b_left_bundle_loads + _a_scale_loads + _a_top_loads) else: - rocdl.sched_dsrd(_a_scale_loads + _a_all_loads) - rocdl.sched_mfma(_group_wmma) + rocdl.sched_dsrd(_a_scale_loads + _a_top_loads) + rocdl.sched_mfma(_first_row_wmma) + rocdl.sched_dsrd(_a_bottom_loads) + if const_expr(_remaining_top_left_wmma > 0): + rocdl.sched_mfma(_remaining_top_left_wmma) rocdl.sched_dsrd(_b_half_loads) rocdl.sched_mfma(_group_wmma) if const_expr(_ks < k_wmma_steps - 1): rocdl.sched_dsrd(_b_left_bundle_loads) - rocdl.sched_mfma(_group_wmma) - rocdl.sched_mfma(_group_wmma) + for _wn_local in range_constexpr(_fp8_half_wn): + rocdl.sched_mfma(_fp8_half_wm) + for _wn_local in range_constexpr(_fp8_half_wn): + rocdl.sched_mfma(_fp8_half_wm) + rocdl.sched_barrier(0) + + def hot_loop_scheduler_fp8_deep_pipeline(): + def _sched_panel_2x2(prefetch_loads=0): + if const_expr(prefetch_loads > 0): + rocdl.sched_mfma(_fp8_pair_wn) + rocdl.sched_dsrd(prefetch_loads) + rocdl.sched_mfma(_fp8_pair_wn) + else: + rocdl.sched_mfma(_fp8_pair_wm * _fp8_pair_wn) + + def _sched_panel_row(): + rocdl.sched_mfma(_fp8_pair_wn) + + _initial_loads = _fp8_scale_loads + _fp8_pair_b_loads * 3 + _fp8_pair_a_loads + + for _ks in range_constexpr(k_wmma_steps): + _ks_initial_loads = _initial_loads + if const_expr(_ks == 0): + _ks_initial_loads -= _fp8_pair_a_loads + rocdl.sched_dsrd(_ks_initial_loads) + _sched_panel_2x2(_fp8_pair_a_loads) + _sched_panel_2x2(_fp8_pair_b_loads) + _sched_panel_2x2(_fp8_pair_a_loads) + _sched_panel_2x2() + _sched_panel_2x2(_fp8_pair_a_loads) + _sched_panel_row() + _sched_panel_row() + _sched_panel_2x2() + _sched_panel_2x2() + _sched_panel_2x2() + _sched_panel_2x2() + _sched_panel_2x2() + _sched_panel_2x2() + _sched_panel_2x2() + _sched_panel_2x2() + _sched_panel_2x2() + _sched_panel_2x2() rocdl.sched_barrier(0) - def compute_tile_scheduled(accs_in, lds_a, lds_b, lds_as, lds_bs, emit_filler=None, mid_compute_callback=None): + def compute_tile_scheduled( + accs_in, + lds_a, + lds_b, + lds_as, + lds_bs, + emit_filler=None, + mid_compute_callback=None, + late_compute_callback=None, + a0_prefetch=None, + scale_k_base=None, + pf_a_scales=None, + pf_b_scales=None, + ): if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_B_STREAMING): return compute_tile_b_streaming( accs_in, @@ -1327,6 +1808,22 @@ def compute_tile_scheduled(accs_in, lds_a, lds_b, lds_as, lds_bs, emit_filler=No lds_bs, emit_filler=emit_filler, mid_compute_callback=mid_compute_callback, + late_compute_callback=late_compute_callback, + ) + if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE): + return compute_tile_fp8_deep_pipeline( + accs_in, + lds_a, + lds_b, + lds_as, + lds_bs, + emit_filler=emit_filler, + mid_compute_callback=mid_compute_callback, + late_compute_callback=late_compute_callback, + a0_prefetch=a0_prefetch, + scale_k_base=scale_k_base, + pf_a_scales=pf_a_scales, + pf_b_scales=pf_b_scales, ) return compute_tile( accs_in, @@ -1365,11 +1862,23 @@ def hot_loop_scheduler_scheduled(): hot_loop_scheduler_b_streaming() elif const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP4_COL_BAND): hot_loop_scheduler_fp4_bank_friendly() + elif const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE): + hot_loop_scheduler_fp8_deep_pipeline() elif const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP8_QUADRANT): hot_loop_scheduler_fp8_quadrant() else: hot_loop_scheduler() + def prefetch_fp8_deep_a0_frags(lds_a): + a_buf, a_bases = _precompute_a_lane_bases(lds_a) + return [load_a_frag(a_buf, a_bases[wm_local], 0) for wm_local in range_constexpr(_fp8_pair_wm)] + + def maybe_prefetch_fp8_deep_a0(lds_a): + # Call only after the TDM fence for this stage; pre-fence LDS reads can race multicast delivery. + if const_expr(use_fp8_deep_pipeline_schedule): + return prefetch_fp8_deep_a0_frags(lds_a) + return None + # ── Epilogue (unified via _sub_tiles) ── def _get_acc_sub8(accs, acc_idx, vec_base): """Extract 8-element sub-vector from accumulator.""" @@ -1415,13 +1924,28 @@ def epilogue_lds_stores(final_accs, d_buf, d_base): imm = m_off * _lds_d_stride_elems + wn * _n_col_d_elems store_acc_vec8_to_lds(d_buf, d_base, imm, sub8, out_elem=_out_elem_local) + def _atomic_fadd_global(val, byte_off): + # Device-scoped, relaxed atomic add into C at c_global_base_i64 + byte_off. + addr_i64 = llvm.AddOp( + c_global_base_i64, arith.index_cast(T.i64, byte_off), llvm.IntegerOverflowFlags(0) + ).result + ptr = llvm.IntToPtrOp(c_global_ptr_type, addr_i64).result + llvm.AtomicRMWOp( + llvm.AtomicBinOp.fadd, + ptr, + val.ir_value(), + llvm.AtomicOrdering.monotonic, + syncscope="agent", + alignment=4, + ) + def _atomic_add_acc_vec8_to_buffer(acc_vec8, addr): if const_expr(_bf16_out): h_vec = fx.Vector(arith.trunc_f(T.vec(8, _out_elem_local), acc_vec8)) for pair in range_constexpr(4): pair_vec = fx.Vector.from_elements([h_vec[pair * 2], h_vec[pair * 2 + 1]]) - byte_off = arith.index_cast(T.i32, addr + arith.index(pair * 4)) - rocdl.raw_ptr_buffer_atomic_fadd(pair_vec, c_rsrc, byte_off, zero_i32, zero_i32) + byte_off = addr + arith.index(pair * 4) + _atomic_fadd_global(pair_vec, byte_off) return 1 acc_vec = fx.Vector(acc_vec8) @@ -1429,8 +1953,8 @@ def _atomic_add_acc_vec8_to_buffer(acc_vec8, addr): base_addr = addr[half] if isinstance(addr, (list, tuple)) else addr for vi in range_constexpr(4): val = acc_vec[half * 4 + vi] - byte_off = arith.index_cast(T.i32, (base_addr + arith.index(vi)) * arith.index(4)) - rocdl.raw_ptr_buffer_atomic_fadd(val, c_rsrc, byte_off, zero_i32, zero_i32) + byte_off = (base_addr + arith.index(vi)) * arith.index(4) + _atomic_fadd_global(val, byte_off) return 2 def epilogue_atomic_adds(final_accs, addrs): @@ -1530,10 +2054,18 @@ def _l2_prefetch(k_base): warp_lds_off + lane16 * arith.index(_lds_d_stride_elems) + lane_kgrp * arith.index(4 * elem_bytes_d) ) wave_id_idx = arith.index_cast(T.index, rocdl.wave_id()) - d_warp_off_sgpr = wave_id_idx * arith.index(warp_d_bytes) + arith.index(d_output_off) - warp_m_off_sgpr = (wave_id_idx / arith.index(n_warp)) * arith.index(warp_tile_m) - warp_n_off_sgpr = (wave_id_idx % arith.index(n_warp)) * arith.index(warp_tile_n) - d_desc = tdm_ops.make_tensor_descriptor_2d( + # Match the TDM-store descriptor offsets to the compute wave mapping. + if const_expr(use_fp8_deep_pipeline_schedule): + wave_m_sgpr = wave_id_idx % arith.index(m_warp) + wave_n_sgpr = wave_id_idx / arith.index(m_warp) + else: + wave_m_sgpr = wave_id_idx / arith.index(n_warp) + wave_n_sgpr = wave_id_idx % arith.index(n_warp) + d_warp_linear_sgpr = wave_m_sgpr * arith.index(n_warp) + wave_n_sgpr + d_warp_off_sgpr = d_warp_linear_sgpr * arith.index(warp_d_bytes) + arith.index(d_output_off) + warp_m_off_sgpr = wave_m_sgpr * arith.index(warp_tile_m) + warp_n_off_sgpr = wave_n_sgpr * arith.index(warp_tile_n) + d_desc = _make_tdm_desc( global_ptr=arg_c, lds_memref=d_lds_base_ptr, global_offset=(blk_m + warp_m_off_sgpr, blk_n + warp_n_off_sgpr), @@ -1570,7 +2102,7 @@ def _pack_dg0(pred, lds_addr, addr_lo, addr_hi): desc_b_init = make_desc_b(stages_b_mem[0], split_k_base) desc_as_init = make_desc_as(stages_as_mem[0], split_k_base) desc_bs_init = make_desc_bs(stages_bs_mem[0], split_k_base) - if const_expr(use_ab_split_scale_buffer_load): + if const_expr(use_ab_half_split): stages_a0_lds_addr = [] stages_b0_lds_addr = [] stages_a1_lds_addr = [] @@ -1593,7 +2125,11 @@ def _pack_dg0(pred, lds_addr, addr_lo, addr_hi): pred_const = fx.Int32(1) if const_expr(wave_specialized_tdm): - active_pred_const = arith.select(tdm_wave_id < fx.Int32(4), fx.Int32(1), fx.Int32(0)) + # With scale on the VGPR path, drop scale waves 2,3 from the active TDM + # path -- unless ab-half-split repurposes them as the second A/B halves. + _drop_scale_waves = use_buffer_vgpr_scale and not use_ab_half_split + _active_wave_limit = 2 if _drop_scale_waves else 4 + active_pred_const = arith.select(tdm_wave_id < fx.Int32(_active_wave_limit), fx.Int32(1), fx.Int32(0)) def _select4(values): return _select_wave_tdm_value(values[0], values[1], values[2], values[3]) @@ -1622,18 +2158,20 @@ def _select_active_tdm(stage_lds_addrs, descs, advs): else: active_pred_const = pred_const - if const_expr(wave_specialized_tdm and not use_scale_buffer_load): - active_stage_lds_addr, active_addr_lo, active_addr_hi, active_dgroup1, active_adv_i32 = _select_active_tdm( - (stages_a_lds_addr, stages_b_lds_addr, stages_as_lds_addr, stages_bs_lds_addr), - (desc_a_init, desc_b_init, desc_as_init, desc_bs_init), - (adv_a_i32, adv_b_i32, adv_as_i32, adv_bs_i32), - ) - elif const_expr(use_ab_split_scale_buffer_load): + if const_expr(use_ab_half_split): + # All 4 waves load A/B halves: wave0=A0, wave1=B0, wave2=A1, wave3=B1. + # Both halves of A share adv_a (same K-step); both halves of B share adv_b. active_stage_lds_addr, active_addr_lo, active_addr_hi, active_dgroup1, active_adv_i32 = _select_active_tdm( (stages_a0_lds_addr, stages_b0_lds_addr, stages_a1_lds_addr, stages_b1_lds_addr), (desc_a0_init, desc_b0_init, desc_a1_init, desc_b1_init), (adv_a_i32, adv_b_i32, adv_a_i32, adv_b_i32), ) + elif const_expr(wave_specialized_tdm): + active_stage_lds_addr, active_addr_lo, active_addr_hi, active_dgroup1, active_adv_i32 = _select_active_tdm( + (stages_a_lds_addr, stages_b_lds_addr, stages_as_lds_addr, stages_bs_lds_addr), + (desc_a_init, desc_b_init, desc_as_init, desc_bs_init), + (adv_a_i32, adv_b_i32, adv_as_i32, adv_bs_i32), + ) else: addr_lo_a = _dg0_lane(desc_a_init, 2) addr_hi_a = _dg0_lane(desc_a_init, 3) @@ -1649,140 +2187,52 @@ def _select_active_tdm(stage_lds_addrs, descs, advs): dgroup1_as = desc_as_init.dgroup1 dgroup1_bs = desc_bs_init.dgroup1 - if const_expr(use_scale_buffer_load): - scale_a_base = buffer_ops.extract_base_index(arg_a_scale) - scale_b_base = buffer_ops.extract_base_index(arg_b_scale) - scale_async_offset = fx.Int32(0) - scale_async_aux = fx.Int32(0) - - def _dma_scale_tile_to_lds( - global_base, - lds_mem, - global_row_base, - global_col_base, - row_stride, - row_bytes: int, - total_bytes: int, - ): - from flydsl._mlir.dialects import memref as memref_dialect - from flydsl._mlir.dialects import rocdl as rocdl_dialect - - for batch in range_constexpr( - (total_bytes + block_threads * _scale_dma_bytes - 1) // (block_threads * _scale_dma_bytes) - ): - batch_byte = batch * block_threads * _scale_dma_bytes - copy_byte = arith.index(batch_byte) + tx * arith.index(_scale_dma_bytes) - if copy_byte < arith.index(total_bytes): - row = copy_byte / arith.index(row_bytes) - col = copy_byte % arith.index(row_bytes) - global_byte = (global_row_base + row) * arith.index(row_stride) + global_col_base + col - global_ptr = buffer_ops.create_llvm_ptr(global_base + global_byte, address_space=1) - lds_ptr = buffer_ops.create_llvm_ptr( - memref_dialect.extract_aligned_pointer_as_index(lds_mem) + copy_byte, - address_space=3, - ) - rocdl_dialect.global_load_async_to_lds_b128( - global_ptr, - lds_ptr, - scale_async_offset, - scale_async_aux, - ) - - def _issue_scale_buffer_loads(stage_idx, k_base): - k_scale_off = k_base / arith.index(SCALE_BLOCK) - _dma_scale_tile_to_lds( - scale_a_base, - stages_as_mem[stage_idx], - blk_m / arith.index(wmma_m_rep), - k_scale_off * arith.index(wmma_m_rep), - wmma_m_rep * K_scale, - interleaved_scale_cols_a, - tile_m * scale_k_per_tile, - ) - _dma_scale_tile_to_lds( - scale_b_base, - stages_bs_mem[stage_idx], - blk_n / arith.index(b_scale_load_rep), - k_scale_off * arith.index(b_scale_load_rep), - b_scale_load_rep * K_scale, - interleaved_scale_cols_b, - tile_n * scale_k_per_tile, - ) - - def _wait_scale_buffer_loads(): - if const_expr(use_scale_buffer_load): - rocdl.s_wait_asynccnt(0) - def _pipeline_fence(outstanding=0): - _wait_scale_buffer_loads() pipeline_fence(outstanding=outstanding, use_cluster=use_cluster) def _pipeline_fence_signal(outstanding=0): - _wait_scale_buffer_loads() pipeline_fence_signal(outstanding=outstanding, use_cluster=use_cluster) - def _issue_ab_tdm(load_stage, addr_a, addr_b): - dg0_a = _pack_dg0(pred_const, stages_a_lds_addr[load_stage], addr_a, addr_hi_a) - dg0_b = _pack_dg0(pred_const, stages_b_lds_addr[load_stage], addr_b, addr_hi_b) - issue_tdm_loads( - tdm_ops.TDMDescriptor2D(dg0_a, dgroup1_a), - tdm_ops.TDMDescriptor2D(dg0_b, dgroup1_b), - wave_specialized=wave_specialized_tdm, - ) - - if const_expr(wave_specialized_tdm and (not use_scale_buffer_load or use_ab_split_scale_buffer_load)): + if const_expr(wave_specialized_tdm): - def _issue_active_tdm(load_stage, addr_box, scale_k_box=None, k_prefetch=None): + def _issue_active_tdm(load_stage, addr_box, k_prefetch=None): dg0 = _pack_dg0(active_pred_const, active_stage_lds_addr[load_stage], addr_box[0], active_addr_hi) tdm_ops.tensor_load_2d(tdm_ops.TDMDescriptor2D(dg0, active_dgroup1)) addr_box[0] = addr_box[0] + active_adv_i32 - if scale_k_box is not None: - _issue_scale_buffer_loads(load_stage, scale_k_box[0]) - scale_k_box[0] = scale_k_box[0] + arith.index(tile_k) if k_prefetch is not None: _l2_prefetch(k_prefetch) # Prologue - if const_expr(wave_specialized_tdm and not use_scale_buffer_load): + if const_expr(wave_specialized_tdm): for i in range_constexpr(pre_loaded): addr_box = [active_addr_lo] _issue_active_tdm(i, addr_box) active_addr_lo = addr_box[0] - elif const_expr(use_ab_split_scale_buffer_load): - for i in range_constexpr(pre_loaded): - addr_box = [active_addr_lo] - scale_k_box = [split_k_base + arith.index(i * tile_k)] - _issue_active_tdm(i, addr_box, scale_k_box=scale_k_box) - active_addr_lo = addr_box[0] else: for i in range_constexpr(pre_loaded): dg0_a = _pack_dg0(pred_const, stages_a_lds_addr[i], addr_lo_a, addr_hi_a) dg0_b = _pack_dg0(pred_const, stages_b_lds_addr[i], addr_lo_b, addr_hi_b) - if const_expr(use_scale_buffer_load): - issue_tdm_loads( - tdm_ops.TDMDescriptor2D(dg0_a, dgroup1_a), - tdm_ops.TDMDescriptor2D(dg0_b, dgroup1_b), - wave_specialized=wave_specialized_tdm, - ) - _issue_scale_buffer_loads(i, split_k_base + arith.index(i * tile_k)) - else: - dg0_as = _pack_dg0(pred_const, stages_as_lds_addr[i], addr_lo_as, addr_hi_as) - dg0_bs = _pack_dg0(pred_const, stages_bs_lds_addr[i], addr_lo_bs, addr_hi_bs) - issue_tdm_loads( - tdm_ops.TDMDescriptor2D(dg0_a, dgroup1_a), - tdm_ops.TDMDescriptor2D(dg0_b, dgroup1_b), - tdm_ops.TDMDescriptor2D(dg0_as, dgroup1_as), - tdm_ops.TDMDescriptor2D(dg0_bs, dgroup1_bs), - wave_specialized=wave_specialized_tdm, - ) + dg0_as = _pack_dg0(pred_const, stages_as_lds_addr[i], addr_lo_as, addr_hi_as) + dg0_bs = _pack_dg0(pred_const, stages_bs_lds_addr[i], addr_lo_bs, addr_hi_bs) + issue_tdm_loads( + tdm_ops.TDMDescriptor2D(dg0_a, dgroup1_a), + tdm_ops.TDMDescriptor2D(dg0_b, dgroup1_b), + tdm_ops.TDMDescriptor2D(dg0_as, dgroup1_as), + tdm_ops.TDMDescriptor2D(dg0_bs, dgroup1_bs), + wave_specialized=wave_specialized_tdm, + ) addr_lo_a = addr_lo_a + adv_a_i32 addr_lo_b = addr_lo_b + adv_b_i32 - if const_expr(not use_scale_buffer_load): - addr_lo_as = addr_lo_as + adv_as_i32 - addr_lo_bs = addr_lo_bs + adv_bs_i32 - if const_expr(use_scale_buffer_load): - scale_next_k_base = split_k_base + arith.index(pre_loaded * tile_k) + addr_lo_as = addr_lo_as + adv_as_i32 + addr_lo_bs = addr_lo_bs + adv_bs_i32 + + if const_expr(_bvs_active): + # Prologue: prefetch the first _bvs_D K-tiles (global->VGPR). Carried as + # FLAT lists of i32 (list-of-tuples can't be loop-carried). + _bvs_pf = [_bvs_prefetch(split_k_base + arith.index(_d * tile_k)) for _d in range(_bvs_D)] + _bvs_ra = [_v for (_a, _b) in _bvs_pf for _v in _a] + _bvs_rb = [_v for (_a, _b) in _bvs_pf for _v in _b] _pipeline_fence(outstanding=TDM_LOADS_PER_STEP * (num_buffers - 2)) @@ -1790,20 +2240,27 @@ def _issue_active_tdm(load_stage, addr_box, scale_k_box=None, k_prefetch=None): # This overlaps TDM DMA with the remaining WMMA instructions, _fence_outstanding = TDM_LOADS_PER_STEP * (num_buffers - 2) + if const_expr(loop_iters > 0 and use_ws_tdm_split_signal_overlap): + _pipeline_fence_signal(outstanding=_fence_outstanding) + if const_expr(loop_iters > 0): - if const_expr(wave_specialized_tdm and not use_scale_buffer_load): + if const_expr(wave_specialized_tdm): init_args = list(accs) + [active_addr_lo] + if const_expr(_bvs_active): + init_args = init_args + _bvs_ra + _bvs_rb for loop_iter, state in range(0, loop_iters, 1, init=init_args): accs_in = list(state[:n_accs]) cur_addr_lo = state[n_accs] + if const_expr(_bvs_active): + _ra0 = n_accs + 1 + _ring_a = list(state[_ra0 : _ra0 + _bvs_D * wmma_m_rep]) + _rb0 = _ra0 + _bvs_D * wmma_m_rep + _ring_b = list(state[_rb0 : _rb0 + _bvs_D * b_scale_load_rep]) for buf_idx in range_constexpr(num_buffers): load_stage = (buf_idx + num_buffers - 1) % num_buffers - _pipeline_fence_signal(outstanding=_fence_outstanding) - pipeline_fence_wait(use_cluster=use_cluster) - addr_box = [cur_addr_lo] def _mid_tdm_ws( @@ -1817,124 +2274,62 @@ def _mid_tdm_ws( ): _issue_active_tdm(_ls, _ab, k_prefetch=_k_off) - rocdl.sched_barrier(0) - accs_in = compute_tile_scheduled( - accs_in, - stages_a_idx[buf_idx], - stages_b_idx[buf_idx], - stages_as_idx[buf_idx], - stages_bs_idx[buf_idx], - mid_compute_callback=_mid_tdm_ws, - ) - cur_addr_lo = addr_box[0] - hot_loop_scheduler_scheduled() - - results = yield list(accs_in) + [cur_addr_lo] - - accs = list(results[:n_accs]) - active_addr_lo = results[n_accs] - elif const_expr(use_ab_split_scale_buffer_load): - init_args = list(accs) + [active_addr_lo, scale_next_k_base] - - for loop_iter, state in range(0, loop_iters, 1, init=init_args): - accs_in = list(state[:n_accs]) - cur_addr_lo = state[n_accs] - cur_scale_k = state[n_accs + 1] + if const_expr(not use_ws_tdm_split_signal_overlap): + _pipeline_fence_signal(outstanding=_fence_outstanding) + pipeline_fence_wait(use_cluster=use_cluster) - for buf_idx in range_constexpr(num_buffers): - load_stage = (buf_idx + num_buffers - 1) % num_buffers + _late_tdm_ws_fence_signal = None + if const_expr(use_ws_tdm_split_signal_overlap): - _pipeline_fence_signal(outstanding=_fence_outstanding) - pipeline_fence_wait(use_cluster=use_cluster) + def _late_tdm_ws_split_signal(): + _pipeline_fence_signal(outstanding=_fence_outstanding) - addr_box = [cur_addr_lo] - scale_k_box = [cur_scale_k] + _late_tdm_ws_fence_signal = _late_tdm_ws_split_signal - def _mid_tdm_split_scale_dma( - _ls=load_stage, - _ab=addr_box, - _scale_k=scale_k_box, - _k_off=( + a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[buf_idx]) + rocdl.sched_barrier(0) + # Consume scale prefetched _bvs_D K-tiles ago; issue the + # K-tile +_bvs_D prefetch now (overlaps this tile's WMMAs). + # NOTE: must stay AFTER the fence; issuing the scale + # buffer_loads before the cluster barrier hangs the vgpr path. + if const_expr(_bvs_active): + _cur_a = _ring_a[:wmma_m_rep] + _cur_b = _ring_b[:b_scale_load_rep] + _next_kb = ( split_k_base + loop_iter * arith.index(num_buffers * tile_k) - + arith.index(buf_idx * tile_k) - ), - ): - _issue_active_tdm(_ls, _ab, scale_k_box=_scale_k, k_prefetch=_k_off) + + arith.index((buf_idx + _bvs_D) * tile_k) + ) + _na, _nb2 = _bvs_prefetch(_next_kb) + _ring_a = _ring_a[wmma_m_rep:] + list(_na) + _ring_b = _ring_b[b_scale_load_rep:] + list(_nb2) + else: + _cur_a = None + _cur_b = None - rocdl.sched_barrier(0) accs_in = compute_tile_scheduled( accs_in, stages_a_idx[buf_idx], stages_b_idx[buf_idx], stages_as_idx[buf_idx], stages_bs_idx[buf_idx], - mid_compute_callback=_mid_tdm_split_scale_dma, + mid_compute_callback=_mid_tdm_ws, + late_compute_callback=_late_tdm_ws_fence_signal, + a0_prefetch=a0_prefetch, + pf_a_scales=_cur_a, + pf_b_scales=_cur_b, ) cur_addr_lo = addr_box[0] - cur_scale_k = scale_k_box[0] hot_loop_scheduler_scheduled() - results = yield list(accs_in) + [cur_addr_lo, cur_scale_k] + if const_expr(_bvs_active): + _bvs_yield = _ring_a + _ring_b + else: + _bvs_yield = [] + results = yield list(accs_in) + [cur_addr_lo] + _bvs_yield accs = list(results[:n_accs]) active_addr_lo = results[n_accs] - scale_next_k_base = results[n_accs + 1] - elif const_expr(use_scale_buffer_load): - init_args = list(accs) + [addr_lo_a, addr_lo_b, scale_next_k_base] - - for loop_iter, state in range(0, loop_iters, 1, init=init_args): - accs_in = list(state[:n_accs]) - cur_lo_a = state[n_accs] - cur_lo_b = state[n_accs + 1] - cur_scale_k = state[n_accs + 2] - - for buf_idx in range_constexpr(num_buffers): - load_stage = (buf_idx + num_buffers - 1) % num_buffers - - _pipeline_fence_signal(outstanding=_fence_outstanding) - pipeline_fence_wait(use_cluster=use_cluster) - - addr_boxes = [[cur_lo_a], [cur_lo_b]] - scale_k_box = [cur_scale_k] - - def _mid_tdm_scale_dma( - _ls=load_stage, - _ab=addr_boxes, - _scale_k=scale_k_box, - _k_off=( - split_k_base - + loop_iter * arith.index(num_buffers * tile_k) - + arith.index(buf_idx * tile_k) - ), - ): - _issue_ab_tdm(_ls, _ab[0][0], _ab[1][0]) - _ab[0][0] = _ab[0][0] + adv_a_i32 - _ab[1][0] = _ab[1][0] + adv_b_i32 - _issue_scale_buffer_loads(_ls, _scale_k[0]) - _scale_k[0] = _scale_k[0] + arith.index(tile_k) - _l2_prefetch(_k_off) - - rocdl.sched_barrier(0) - accs_in = compute_tile_scheduled( - accs_in, - stages_a_idx[buf_idx], - stages_b_idx[buf_idx], - stages_as_idx[buf_idx], - stages_bs_idx[buf_idx], - mid_compute_callback=_mid_tdm_scale_dma, - ) - cur_lo_a = addr_boxes[0][0] - cur_lo_b = addr_boxes[1][0] - cur_scale_k = scale_k_box[0] - hot_loop_scheduler_scheduled() - - results = yield list(accs_in) + [cur_lo_a, cur_lo_b, cur_scale_k] - - accs = list(results[:n_accs]) - addr_lo_a = results[n_accs] - addr_lo_b = results[n_accs + 1] - scale_next_k_base = results[n_accs + 2] else: init_args = list(accs) + [addr_lo_a, addr_lo_b, addr_lo_as, addr_lo_bs] @@ -1979,6 +2374,7 @@ def _mid_tdm_nws( _ab[3][0] = _ab[3][0] + adv_bs_i32 _l2_prefetch(_k_off) + a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[buf_idx]) rocdl.sched_barrier(0) accs_in = compute_tile_scheduled( accs_in, @@ -1987,6 +2383,7 @@ def _mid_tdm_nws( stages_as_idx[buf_idx], stages_bs_idx[buf_idx], mid_compute_callback=_mid_tdm_nws, + a0_prefetch=a0_prefetch, ) cur_lo_a = addr_boxes[0][0] cur_lo_b = addr_boxes[1][0] @@ -2003,29 +2400,46 @@ def _mid_tdm_nws( addr_lo_bs = results[n_accs + 3] # Tail — same acc_mixed pattern: fence at top, TDM mid-compute. + if const_expr(loop_iters > 0 and use_ws_tdm_split_signal_overlap): + pipeline_fence_wait(use_cluster=use_cluster) if const_expr(loop_iters > 0): _pipeline_fence(outstanding=0) elif const_expr(use_cluster): cluster.cluster_barrier() epi_addrs_box = [None] _tail_had_load = False + # Tail K-tile index, so the VGPR-path scale buffer_load uses the right k_base. + _bvs_tail_kt = [loop_iters * num_buffers] + + def _bvs_tail_kb(): + if const_expr(not _bvs_active): + return None + kb = split_k_base + arith.index(_bvs_tail_kt[0] * tile_k) + _bvs_tail_kt[0] += 1 + return kb + for _load_stage, _compute_stage, _outstanding in tail_plan: + _entry_kb = _bvs_tail_kb() if const_expr(_outstanding == -1): if const_expr(_tail_had_load): _pipeline_fence(outstanding=0) if const_expr(use_tdm_store): + a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[_compute_stage]) accs = compute_tile_scheduled( accs, stages_a_idx[_compute_stage], stages_b_idx[_compute_stage], stages_as_idx[_compute_stage], stages_bs_idx[_compute_stage], + a0_prefetch=a0_prefetch, + scale_k_base=_entry_kb, ) else: def _emit_epi_addrs(): epi_addrs_box[0] = epilogue_prepare_addrs() + a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[_compute_stage]) accs = compute_tile_scheduled( accs, stages_a_idx[_compute_stage], @@ -2033,6 +2447,8 @@ def _emit_epi_addrs(): stages_as_idx[_compute_stage], stages_bs_idx[_compute_stage], emit_filler=_emit_epi_addrs, + a0_prefetch=a0_prefetch, + scale_k_base=_entry_kb, ) else: _pipeline_fence_signal(outstanding=_outstanding) @@ -2041,27 +2457,7 @@ def _emit_epi_addrs(): _tail_mid_cb = None if const_expr(_load_stage is not None): _tail_had_load = True - if const_expr(use_ab_split_scale_buffer_load): - _tail_addr_box = [active_addr_lo] - _tail_scale_k = [scale_next_k_base] - - def _tail_mid_split_scale_dma(_ls=_load_stage, _ab=_tail_addr_box, _scale_k=_tail_scale_k): - _issue_active_tdm(_ls, _ab, scale_k_box=_scale_k) - - _tail_mid_cb = _tail_mid_split_scale_dma - elif const_expr(use_scale_buffer_load): - _tail_ab = [[addr_lo_a], [addr_lo_b]] - _tail_scale_k = [scale_next_k_base] - - def _tail_mid_scale_dma(_ls=_load_stage, _ab=_tail_ab, _scale_k=_tail_scale_k): - _issue_ab_tdm(_ls, _ab[0][0], _ab[1][0]) - _ab[0][0] = _ab[0][0] + adv_a_i32 - _ab[1][0] = _ab[1][0] + adv_b_i32 - _issue_scale_buffer_loads(_ls, _scale_k[0]) - _scale_k[0] = _scale_k[0] + arith.index(tile_k) - - _tail_mid_cb = _tail_mid_scale_dma - elif const_expr(wave_specialized_tdm): + if const_expr(wave_specialized_tdm): _tail_addr_box = [active_addr_lo] def _tail_mid_ws(_ls=_load_stage, _ab=_tail_addr_box): @@ -2090,6 +2486,7 @@ def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): _tail_mid_cb = _tail_mid_nws + a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[_compute_stage]) rocdl.sched_barrier(0) accs = compute_tile_scheduled( accs, @@ -2098,17 +2495,12 @@ def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): stages_as_idx[_compute_stage], stages_bs_idx[_compute_stage], mid_compute_callback=_tail_mid_cb, + a0_prefetch=a0_prefetch, + scale_k_base=_entry_kb, ) if const_expr(_load_stage is not None): - if const_expr(use_ab_split_scale_buffer_load): - active_addr_lo = _tail_addr_box[0] - scale_next_k_base = _tail_scale_k[0] - elif const_expr(use_scale_buffer_load): - addr_lo_a = _tail_ab[0][0] - addr_lo_b = _tail_ab[1][0] - scale_next_k_base = _tail_scale_k[0] - elif const_expr(wave_specialized_tdm): + if const_expr(wave_specialized_tdm): active_addr_lo = _tail_addr_box[0] else: addr_lo_a = _tail_ab[0][0] @@ -2161,6 +2553,7 @@ def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): atomic_barrier_enable, b_streaming, scale_load_path, + fp8_schedule, ) @flyc.jit @@ -2204,7 +2597,7 @@ def launch_mxscale_gemm( cluster=cluster_arg, ) - if expert_sched_mode: + if effective_expert_sched_mode: launch_mxscale_gemm.compile_hints["llvm_options"] = { "amdgpu-expert-scheduling-mode": True, } diff --git a/python/flydsl/expr/rocdl/inline_asm.py b/python/flydsl/expr/rocdl/inline_asm.py index 357af0d21..5cb7fb1a1 100644 --- a/python/flydsl/expr/rocdl/inline_asm.py +++ b/python/flydsl/expr/rocdl/inline_asm.py @@ -65,29 +65,3 @@ def cvt_pk_bf16_f32(src_a_f32, src_b_f32): "=v,v,v", has_side_effects=False, ) - - -def s_prefetch_inst_burst(num_pages: int = 10, page_bytes: int = 4096): - """gfx1250: prefetch ``num_pages`` cache lines of instructions ahead of PC. - - Sets HW_REG_WAVE_MODE bit 8 (allow s_prefetch) then issues - ``s_prefetch_inst_pc_rel offset, s0, 31`` for ``offset = 0, page_bytes, - 2*page_bytes, ...``. Used by GEMM kernels to warm the I-cache before the - main loop starts so the first few iterations don't stall on instruction - fetch. - - Wraps the inline-asm sequence so callers do not need to import the raw - ``llvm`` dialect. - """ - from ..._mlir.dialects import llvm as _llvm - - lines = ["s_setreg_imm32_b32 hwreg(HW_REG_WAVE_MODE, 8, 1), 1"] - for pg in range(num_pages): - lines.append(f"s_prefetch_inst_pc_rel {pg * page_bytes}, s0, 31") - _llvm.inline_asm( - None, - [], - "\n".join(lines), - "", - has_side_effects=True, - ) diff --git a/python/flydsl/expr/rocdl/tdm_ops.py b/python/flydsl/expr/rocdl/tdm_ops.py index 8c727218a..f2644d315 100644 --- a/python/flydsl/expr/rocdl/tdm_ops.py +++ b/python/flydsl/expr/rocdl/tdm_ops.py @@ -215,6 +215,7 @@ def make_tensor_descriptor_2d( lds_byte_offset=None, for_store: bool = False, atomic_barrier_enable: bool = False, + early_timeout: bool = False, ) -> TDMDescriptor2D: """Build a 2D TDM descriptor for tensor_load_to_lds_d2. @@ -260,6 +261,10 @@ def make_tensor_descriptor_2d( relying on TDM atomic-barrier semantics; this helper keeps the encoded atomic-barrier address at zero, so all participating waves must agree on that protocol. + early_timeout: Set the descriptor's early-timeout bit [21]. This is a + multicast-load knob (1 = GL1 returns to the requesters + present when GL2 data arrives, latecomers re-broadcast; + default 0 = standard wider-merge timeout). Returns: TDMDescriptor2D with dgroup0 and dgroup1 ready for tensor_load_2d. @@ -368,12 +373,13 @@ def make_tensor_descriptor_2d( # sgpr0: config bitfields _abe = 1 if atomic_barrier_enable else 0 + _early_timeout = 1 if early_timeout else 0 g1_s0_upper = ( (data_size_code << 16) # data_size [17:16] | (_abe << 18) # atomic_barrier_enable | (0 << 19) # iterate_enable | (pad_enable << 20) # pad_enable - | (0 << 21) # early_timeout + | (_early_timeout << 21) # early_timeout | (enc_interval << 22) # pad_interval [24:22] | (enc_amount << 25) # pad_amount [31:25] ) diff --git a/tests/kernels/benchmark_common.py b/tests/kernels/benchmark_common.py index 2f5749361..7cd238b98 100644 --- a/tests/kernels/benchmark_common.py +++ b/tests/kernels/benchmark_common.py @@ -468,6 +468,17 @@ def bench_kernel_us(run_fn, warmup=10, iters=50, flush_l2=True, prep_fn=None): run_fn() torch.cuda.synchronize() + if flush_buf is None and prep_fn is None: + # Single event pair preserves back-to-back launch pipelining (returns mean latency). + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(iters): + run_fn() + end.record() + torch.cuda.synchronize() + return start.elapsed_time(end) * 1e3 / iters + start_ev = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] end_ev = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] diff --git a/tests/kernels/test_gemm_fp8fp4_gfx1250.py b/tests/kernels/test_gemm_fp8fp4_gfx1250.py index acabe17e0..382a73f0b 100644 --- a/tests/kernels/test_gemm_fp8fp4_gfx1250.py +++ b/tests/kernels/test_gemm_fp8fp4_gfx1250.py @@ -4,6 +4,7 @@ Kernel implementation: kernels/gemm_fp8fp4_gfx1250.py """ +import math import os import re import sys @@ -33,10 +34,35 @@ SCALE_BLOCK = 32 +def preshuffle_e8m0_scale_coalesced(scale: torch.Tensor, block: int = 128) -> torch.Tensor: + """Lane-major scale layout for direct buffer_load->VGPR. + + Per (M_block=128, K_tile): [group(2), lane16(16), 4 i32], so a buffer_load_b128's + 16 lanes read 256 contiguous bytes. M = mb*128 + (group*4 + j)*16 + lane16. + """ + M, Ks = scale.shape + assert M % block == 0 and Ks % 4 == 0, f"M={M} Ks={Ks} block={block}" + assert block == 128, "coalesced scale layout assumes warp_tile=128 (8 subtiles)" + Kt = Ks // 4 + g = scale.view(M // block, 2, 4, 16, Kt, 4) # [mb, group, j, lane16, kt, spw] + g = g.permute(0, 4, 1, 3, 2, 5).contiguous() # [mb, kt, group, lane16, j, spw] + return g.view(M, Ks) + + def preshuffle_e8m0_scale( - scale: torch.Tensor, warp_tile: int, scale_k_per_tile: int = 4, WMMA_DIM: int = 16 + scale: torch.Tensor, + warp_tile: int, + scale_k_per_tile: int = 4, + WMMA_DIM: int = 16, + coalesced: bool = False, ) -> torch.Tensor: - """Preshuffle E8M0 scale: optional byte swap + interleave for WMMA access.""" + """Preshuffle E8M0 scale: optional byte swap + interleave for WMMA access. + + ``coalesced=True`` produces the lane-major layout the scale_load_path + "vgpr"/"vgpr_ab_split" buffer_load->VGPR path expects. + """ + if coalesced: + return preshuffle_e8m0_scale_coalesced(scale, block=warp_tile) _, K_scale = scale.shape assert K_scale % 4 == 0, f"K_scale must be divisible by 4, got {K_scale}" SCALES_PER_WMMA = 4 @@ -53,6 +79,101 @@ def random_fp8_data(rows: int, cols: int, *, device="cpu") -> torch.Tensor: return torch.randint(0, 126, (rows, cols), dtype=torch.uint8, device=device) +def _fp8_e4m3fn_byte(value: float) -> int: + """Return torch's FP8 E4M3FN byte encoding for a finite scalar.""" + t = torch.tensor([float(value)], dtype=torch.float8_e4m3fn) + byte = int(t.view(torch.uint8).item()) + if (byte & 0x7F) == 0x7F: + raise SystemExit(f"--fill-mode constant {value:g} is outside the finite FP8 E4M3FN range") + return byte + + +def _parse_fill_mode(arg: str): + """Parse --fill-mode as ('random',) or ('const', value).""" + if arg == "random": + return ("random",) + if arg == "zero": + return ("const", 0.0) + try: + value = float(arg) + except ValueError as e: + raise SystemExit(f"--fill-mode must be 'random' or a finite float constant, got {arg!r}") from e + if not math.isfinite(value): + raise SystemExit(f"--fill-mode constant must be finite, got {arg!r}") + return ("const", value) + + +def _fp4_e2m1_packed_fill(rows: int, cols: int, value: float) -> torch.Tensor: + dense = torch.full((rows, cols), float(value), dtype=torch.float32) + return fp4_utils.f32_to_mxfp4(dense).view(torch.uint8) + + +def _random_mxscale_inputs(M: int, N: int, K: int, data_format: str): + if data_format == "a8w4": + a = random_fp8_data(M, K) + b = fp4_utils.random_fp4_packed(N, K) + elif data_format == "fp4": + a = fp4_utils.random_fp4_packed(M, K) + b = fp4_utils.random_fp4_packed(N, K) + elif data_format == "fp8": + a = random_fp8_data(M, K) + b = random_fp8_data(N, K) + else: + raise ValueError(f"unsupported data_format={data_format!r}") + return a, b, fp4_utils.random_e8m0(M, K // SCALE_BLOCK), fp4_utils.random_e8m0(N, K // SCALE_BLOCK) + + +def _const_fill_inputs(M, N, K, data_format: str, value: float): + """Build constant A/B tensors with neutral E8M0 scales for CLI runs.""" + if data_format == "fp4": + a = _fp4_e2m1_packed_fill(M, K, value) + b = _fp4_e2m1_packed_fill(N, K, value) + elif data_format == "a8w4": + fp8_byte = _fp8_e4m3fn_byte(value) + a = torch.full((M, K), fp8_byte, dtype=torch.uint8) + b = _fp4_e2m1_packed_fill(N, K, value) + elif data_format == "fp8": + fp8_byte = _fp8_e4m3fn_byte(value) + a = torch.full((M, K), fp8_byte, dtype=torch.uint8) + b = torch.full((N, K), fp8_byte, dtype=torch.uint8) + else: + raise ValueError(f"unsupported data_format={data_format!r}") + a_scale = torch.full((M, K // SCALE_BLOCK), 127, dtype=torch.uint8) + b_scale = torch.full((N, K // SCALE_BLOCK), 127, dtype=torch.uint8) + return a, b, a_scale, b_scale + + +def _fill_mode_inputs(M: int, N: int, K: int, data_format: str, fill_mode: str): + fill_spec = _parse_fill_mode(fill_mode) + if fill_spec[0] == "const": + a, b, a_scale, b_scale = _const_fill_inputs(M, N, K, data_format, fill_spec[1]) + else: + a, b, a_scale, b_scale = _random_mxscale_inputs(M, N, K, data_format) + return a, b, a_scale, b_scale, fill_spec + + +def _fill_mode_label(fill_spec, data_format: str) -> str: + if fill_spec[0] == "random": + return "random (seed=0)" + label = f"const={fill_spec[1]:g}, E8M0 byte=127" + if data_format in ("fp8", "a8w4"): + label += f", FP8 byte=0x{_fp8_e4m3fn_byte(fill_spec[1]):02x}" + return label + + +def _has_nonzero_quantized_values(tensor: torch.Tensor, data_format: str) -> bool: + convert = fp4_utils.mxfp4_to_f32 if data_format == "fp4" else fp4_utils.fp8_e4m3_to_f32 + return bool(convert(tensor.view(torch.uint8)).abs().max().item() > 0) + + +def _expect_nonzero_graph_output(a: torch.Tensor, b: torch.Tensor, data_format: str, fill_spec) -> bool: + if fill_spec[0] == "random": + return True + a_format = "fp4" if data_format == "fp4" else "fp8" + b_format = "fp8" if data_format == "fp8" else "fp4" + return _has_nonzero_quantized_values(a, a_format) and _has_nonzero_quantized_values(b, b_format) + + def _reference_scaled_gemm(a, b, a_scale, b_scale, M, N, K, convert_fn, convert_fn_b=None): """Reference scaled GEMM: D = (A * A_scale) @ (B * B_scale)^T.""" a_f32 = convert_fn(a.view(torch.uint8))[:M, :K] @@ -243,6 +364,10 @@ def _run_mxscale_gemm_test( _dtype_map = {"f32": torch.float32, "bf16": torch.bfloat16, "f16": torch.float16} torch_out_dtype = _dtype_map[out_dtype] + # Split-K accumulates across workgroups in fp32; half outputs are converted after. + kernel_out_dtype = "f32" if (split_k > 1 and out_dtype in ("bf16", "f16")) else out_dtype + torch_kernel_dtype = _dtype_map[kernel_out_dtype] + torch.manual_seed(0) fmt_name = "A8W4" if is_a8w4 else ("MXFP4" if is_fp4 else "MXFP8") @@ -287,8 +412,9 @@ def _run_mxscale_gemm_test( skt = tile_k // SCALE_BLOCK warp_tile_m = tile_m // m_warp warp_tile_n = tile_n // n_warp - a_scale = preshuffle_e8m0_scale(a_scale, warp_tile_m, scale_k_per_tile=skt) - b_scale = preshuffle_e8m0_scale(b_scale, warp_tile_n, scale_k_per_tile=skt) + _coalesced_scale = scale_load_path in ("vgpr", "vgpr_ab_split") + a_scale = preshuffle_e8m0_scale(a_scale, warp_tile_m, scale_k_per_tile=skt, coalesced=_coalesced_scale) + b_scale = preshuffle_e8m0_scale(b_scale, warp_tile_n, scale_k_per_tile=skt, coalesced=_coalesced_scale) # Preshuffle B data K_packed = padded_k // padded_shape["pack_b"] @@ -299,7 +425,7 @@ def _run_mxscale_gemm_test( b_gpu = b.cuda() as_gpu = a_scale.cuda() bs_gpu = b_scale.cuda() - c_gpu = torch.zeros(padded_m, padded_n, dtype=torch_out_dtype, device="cuda") + c_gpu = torch.zeros(padded_m, padded_n, dtype=torch_kernel_dtype, device="cuda") launch_fn = compile_mxscale_gemm( data_format=data_format, @@ -317,7 +443,7 @@ def _run_mxscale_gemm_test( cluster_m=cluster_m, cluster_n=cluster_n, use_tdm_store=use_tdm_store, - out_dtype=out_dtype, + out_dtype=kernel_out_dtype, inst_prefetch=inst_prefetch, wave_specialized_tdm=wave_specialized_tdm, split_k=split_k, @@ -327,18 +453,12 @@ def _run_mxscale_gemm_test( scale_load_path=scale_load_path, ) - # Pre-bind via flyc.compile so the launch goes through the CompiledFunction - # ctypes fast path (matches test_blockscale_preshuffle_gemm.py and any - # production caller that bench-times this kernel). The slow JitFunction - # path adds ~17us of inspect.Signature.bind + _resolve_and_make_cache_key per call, - # which would mask genuine kernel timing differences in the bench path. - # flyc.compile() launches the kernel once internally to trigger - # compilation, so no separate eager call is needed for correctness. - c_flat = c_gpu.contiguous().view(-1) - a_flat = a_gpu.contiguous().view(-1) - b_flat = b_gpu.contiguous().view(-1) - as_flat = as_gpu.contiguous().view(-1) - bs_flat = bs_gpu.contiguous().view(-1) + # Keep 2D — dynamic_layout=True packs shape as i32; flattening overflows for M*K >= 2^31. + c_flat = c_gpu.contiguous() + a_flat = a_gpu.contiguous() + b_flat = b_gpu.contiguous() + as_flat = as_gpu.contiguous() + bs_flat = bs_gpu.contiguous() flyc.compile( launch_fn, @@ -353,7 +473,8 @@ def _run_mxscale_gemm_test( ) torch.cuda.synchronize() - c_out = c_gpu[:M, :N].cpu() + # Convert the fp32 split-K accumulation back to the requested half dtype. + c_out = c_gpu[:M, :N].to(torch_out_dtype).cpu() print( f"Out stats: min={c_out.float().min():.2f}, max={c_out.float().max():.2f}, " @@ -374,7 +495,13 @@ def _run_mxscale_gemm_test( diff = (c_out_f - ref_f).abs() print(f"Abs diff: max={diff.max():.4f}, mean={diff.mean():.4f}") - cos_sim = torch.nn.functional.cosine_similarity(c_out_f.flatten().unsqueeze(0), ref_f.flatten().unsqueeze(0)).item() + # Compute cosine in float64: for large M/N/K with large E8M0 scales the values + # (and their squares) overflow float32's accurate-summation range, so an fp32 + # cosine reduction saturates and can print values outside [-1, 1]. fp64 keeps + # the diagnostic meaningful. (Pass/fail is gated by assert_close below, not this.) + cos_sim = torch.nn.functional.cosine_similarity( + c_out_f.flatten().unsqueeze(0).double(), ref_f.flatten().unsqueeze(0).double() + ).item() print(f"Cosine similarity: {cos_sim:.6f}") # Tolerances: FP4 is exact; FP8/A8W4 have FP accumulation error @@ -510,7 +637,7 @@ def test_mxfp4_metadata_and_spill_regression(out_dtype): @pytest.mark.parametrize("use_tdm_store", [True, False]) @pytest.mark.parametrize("use_scale_opsel", [True, False]) @pytest.mark.parametrize("out_dtype", ["f32", "bf16"]) -@pytest.mark.parametrize("scale_load_path", ["tdm", "buffer_lds_stage"]) +@pytest.mark.parametrize("scale_load_path", ["tdm"]) def test_mxfp8_gemm( M, N, @@ -545,6 +672,32 @@ def test_mxfp8_gemm( ) +@pytest.mark.parametrize("split_k", [2, 4, 6, 8]) +@pytest.mark.parametrize("out_dtype", ["f32", "bf16"]) +def test_mxfp8_gemm_splitk(split_k, out_dtype): + """FP8 split-K: split_k workgroups accumulate partial K-sums into C via atomic add. + + Exercises the atomic epilogue path (use_tdm_store=False). K=2048/tile_k=128 gives + every split_k value >= 2 local K-tiles (needed for double buffering). + """ + _run_mxscale_gemm_test( + "fp8", + 128, + 256, + 2048, + 128, + 256, + 128, + 2, + 4, + num_buffers=2, + use_tdm_store=False, + out_dtype=out_dtype, + l2_prefetch_distance=2, + split_k=split_k, + ) + + @pytest.mark.parametrize( "M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp", [ @@ -660,11 +813,10 @@ def test_b_streaming_with_wave_spec_tdm(data_format, M, N, K, tile_m, tile_n, ti ) -@pytest.mark.parametrize("scale_load_path", ["tdm", "buffer_lds_stage", "buffer_lds_stage_ab_split"]) @pytest.mark.parametrize("num_buffers", [2, 3]) @pytest.mark.parametrize("use_tdm_store", [True, False]) @pytest.mark.parametrize("use_scale_opsel", [False, True]) -def test_mxfp8_wave_spec_scale_load_paths(scale_load_path, num_buffers, use_tdm_store, use_scale_opsel): +def test_mxfp8_wave_spec_scale_load_tdm(num_buffers, use_tdm_store, use_scale_opsel): _run_mxscale_gemm_test( "fp8", 128, @@ -681,28 +833,31 @@ def test_mxfp8_wave_spec_scale_load_paths(scale_load_path, num_buffers, use_tdm_ l2_prefetch_distance=2, wave_specialized_tdm=True, use_scale_opsel=use_scale_opsel, - scale_load_path=scale_load_path, + scale_load_path="tdm", ) -def test_mxfp8_ab_split_scale_load_allows_extra_waves(): +@pytest.mark.parametrize("scale_load_path", ["vgpr", "vgpr_ab_split"]) +@pytest.mark.parametrize("cluster_m, cluster_n", [(1, 1), (2, 2)]) +def test_mxfp8_vgpr_scale_load(scale_load_path, cluster_m, cluster_n): _run_mxscale_gemm_test( "fp8", - 128, + 256 * cluster_m, + 256 * cluster_n, + 512, 256, - 384, - 128, 256, 128, 2, - 4, - num_buffers=3, + 2, + num_buffers=4, use_tdm_store=True, out_dtype="bf16", l2_prefetch_distance=2, wave_specialized_tdm=True, - use_scale_opsel=True, - scale_load_path="buffer_lds_stage_ab_split", + cluster_m=cluster_m, + cluster_n=cluster_n, + scale_load_path=scale_load_path, ) @@ -855,11 +1010,11 @@ def test_mxscale_gemm_cudagraph(data_format, M, N, K, tile_m, tile_n, tile_k, m_ split_k=1, ) - c_flat = c_gpu.contiguous().view(-1) - a_flat = a_gpu.contiguous().view(-1) - b_flat = b_gpu.contiguous().view(-1) - as_flat = as_gpu.contiguous().view(-1) - bs_flat = bs_gpu.contiguous().view(-1) + c_flat = c_gpu.contiguous() + a_flat = a_gpu.contiguous() + b_flat = b_gpu.contiguous() + as_flat = as_gpu.contiguous() + bs_flat = bs_gpu.contiguous() compiled_exe = flyc.compile( launch_fn, c_flat, @@ -914,27 +1069,11 @@ def launch(): ) -def _bench_kernel_us_cudagraph(run_fn, warmup=10, iters=100, prep_fn=None): - """Per-launch timer that strips host launch overhead via hipGraph. - - How it works: - - Capture a single kernel launch into a hipGraph - - Replay it N times in one stream submission burst - - Per-launch time = total_burst_time / N - - NB: stream-ordered execution guarantees the N replays serialise — each - g.replay() is one submission to the stream and the next one cannot - start until the previous one finishes. - - NB: no L2 flush between replays. The whole point of the graph is to - measure the kernel in a hot, back-to-back launch scenario (which is - what production serving looks like). For cold-cache numbers use the - regular _bench_kernel_us with flush_l2=True. - """ +def _bench_kernel_us_cudagraph(run_fn, warmup=10, iters=100, prep_fn=None, n_per_graph=20): + """Per-launch timer via hipGraph: capture n_per_graph launches, replay iters times, single event pair around the whole replay loop.""" capture_stream = torch.cuda.Stream() capture_stream.wait_stream(torch.cuda.current_stream()) - # Warmup on the capture stream so the allocator / JIT cache is settled. with torch.cuda.stream(capture_stream): for _ in range(warmup): if prep_fn is not None: @@ -943,15 +1082,46 @@ def _bench_kernel_us_cudagraph(run_fn, warmup=10, iters=100, prep_fn=None): torch.cuda.current_stream().wait_stream(capture_stream) torch.cuda.synchronize() - # Capture exactly one kernel launch into the graph. g = torch.cuda.CUDAGraph() if prep_fn is not None: prep_fn() with torch.cuda.graph(g, stream=capture_stream): + for _ in range(n_per_graph): + run_fn() + torch.cuda.synchronize() + + # Sanity guard against empty graph capture. + ref_start = torch.cuda.Event(enable_timing=True) + ref_end = torch.cuda.Event(enable_timing=True) + ref_start.record() + for _ in range(n_per_graph): run_fn() + ref_end.record() torch.cuda.synchronize() + ref_per_launch_us = ref_start.elapsed_time(ref_end) * 1e3 / n_per_graph + + rep_start = torch.cuda.Event(enable_timing=True) + rep_end = torch.cuda.Event(enable_timing=True) + rep_start.record() + g.replay() + rep_end.record() + torch.cuda.synchronize() + first_replay_per_launch_us = rep_start.elapsed_time(rep_end) * 1e3 / n_per_graph + + print( + f"SANITY_GRAPH,n_per_graph={n_per_graph}," + f"ref_per_launch_us={ref_per_launch_us:.3f}," + f"first_replay_per_launch_us={first_replay_per_launch_us:.3f}", + file=sys.stderr, + flush=True, + ) + if first_replay_per_launch_us < 1.0 and ref_per_launch_us > 2.0: + raise RuntimeError( + f"hipGraph replay per-launch={first_replay_per_launch_us:.3f}us " + f"<< ref direct-launch={ref_per_launch_us:.3f}us. " + f"Graph capture likely empty (stream mismatch?)." + ) - # Time iters replays as one batch. start_ev = torch.cuda.Event(enable_timing=True) end_ev = torch.cuda.Event(enable_timing=True) start_ev.record() @@ -959,18 +1129,11 @@ def _bench_kernel_us_cudagraph(run_fn, warmup=10, iters=100, prep_fn=None): g.replay() end_ev.record() torch.cuda.synchronize() - total_us = start_ev.elapsed_time(end_ev) * 1e3 - return total_us / iters + return start_ev.elapsed_time(end_ev) * 1e3 / (iters * n_per_graph) def _bench_kernel_us(run_fn, warmup=10, iters=50, flush_l2=True, prep_fn=None): - """Per-iteration CUDA events timer with L2 flush, IQR outlier removal, median. - - Follows the golden practices from benchmarks/gemm_gfx1250_benchmark.py: - - L2 flush between iterations via zero-fill of 2× L2 buffer - - Per-iteration event pairs (no inter-iteration sync on kernel) - - IQR outlier removal (if n >= 8), median latency - """ + """Per-iter CUDA events with L2 flush + IQR-trimmed median; fast path uses a single event pair when no flush/prep is requested (preserves back-to-back launch pipelining).""" flush_buf = None if flush_l2: l2_bytes = getattr( @@ -987,6 +1150,17 @@ def _bench_kernel_us(run_fn, warmup=10, iters=50, flush_l2=True, prep_fn=None): run_fn() torch.cuda.synchronize() + if flush_buf is None and prep_fn is None: + # Single event pair preserves back-to-back launch pipelining (returns mean latency). + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(iters): + run_fn() + end.record() + torch.cuda.synchronize() + return start.elapsed_time(end) * 1e3 / iters + start_ev = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] end_ev = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] @@ -1038,8 +1212,11 @@ def _run_benchmark(args): is_fp4 = data_format == "fp4" is_a8w4 = data_format == "a8w4" _dtype_map = {"f32": torch.float32, "bf16": torch.bfloat16, "f16": torch.float16} - torch_out_dtype = _dtype_map[args.out_dtype] - elem_bytes_d = 2 if args.out_dtype in ("bf16", "f16") else 4 + # split_k>1 accumulates partial K-sums in fp32 for precision; bf16/f16 atomics are + # supported but compound rounding error, so we run f32 and convert back on the host. + kernel_out_dtype = "f32" if (args.split_k > 1 and args.out_dtype in ("bf16", "f16")) else args.out_dtype + torch_kernel_dtype = _dtype_map[kernel_out_dtype] + elem_bytes_d = 2 if kernel_out_dtype in ("bf16", "f16") else 4 fmt_name = "A8W4" if is_a8w4 else ("MXFP4" if is_fp4 else "MXFP8") print("=" * 72) @@ -1057,32 +1234,23 @@ def _run_benchmark(args): ) if args.split_k > 1: print(f" Split-K={args.split_k} (atomic accumulate, buffer-store epilogue)") - print(f" Warmup={args.warmup}, Iters={args.iters}, " f"L2 flush={'ON' if not args.no_flush_l2 else 'OFF'}") - print(" Zero fill: ON (outside timing)") + l2_flush_label = "OFF (graph)" if getattr(args, "use_graph", False) else ("OFF" if args.no_flush_l2 else "ON") + print(f" Warmup={args.warmup}, Iters={args.iters}, L2 flush={l2_flush_label}") + print(" Output init: zero before warmup") print("=" * 72) torch.manual_seed(0) - - if is_a8w4: - a = random_fp8_data(M, K) - b = fp4_utils.random_fp4_packed(N, K) - elif is_fp4: - a = fp4_utils.random_fp4_packed(M, K) - b = fp4_utils.random_fp4_packed(N, K) - else: - a = random_fp8_data(M, K) - b = random_fp8_data(N, K) - - a_scale = fp4_utils.random_e8m0(M, K // SCALE_BLOCK) - b_scale = fp4_utils.random_e8m0(N, K // SCALE_BLOCK) + a, b, a_scale, b_scale, fill_spec = _fill_mode_inputs(M, N, K, data_format, getattr(args, "fill_mode", "random")) + print(f" Fill mode: {_fill_mode_label(fill_spec, data_format)}") a, b, a_scale, b_scale = _pad_mxscale_inputs(a, b, a_scale, b_scale, padded_shape) skt = tile_k // SCALE_BLOCK warp_tile_m = tile_m // args.m_warp warp_tile_n = tile_n // args.n_warp - a_scale = preshuffle_e8m0_scale(a_scale, warp_tile_m, scale_k_per_tile=skt) - b_scale = preshuffle_e8m0_scale(b_scale, warp_tile_n, scale_k_per_tile=skt) + _coalesced_scale = args.scale_load_path in ("vgpr", "vgpr_ab_split") + a_scale = preshuffle_e8m0_scale(a_scale, warp_tile_m, scale_k_per_tile=skt, coalesced=_coalesced_scale) + b_scale = preshuffle_e8m0_scale(b_scale, warp_tile_n, scale_k_per_tile=skt, coalesced=_coalesced_scale) K_packed = padded_k // PACK_B b = fp4_utils.preshuffle_b_16x16(b, padded_n, K_packed) @@ -1091,7 +1259,7 @@ def _run_benchmark(args): b_gpu = b.cuda() as_gpu = a_scale.cuda() bs_gpu = b_scale.cuda() - c_gpu = torch.zeros(padded_m, padded_n, dtype=torch_out_dtype, device="cuda") + c_gpu = torch.zeros(padded_m, padded_n, dtype=torch_kernel_dtype, device="cuda") print("\n[1/3] Compiling kernel...") t0 = time.perf_counter() @@ -1115,7 +1283,7 @@ def _run_benchmark(args): cluster_m=args.cluster_m, cluster_n=args.cluster_n, use_tdm_store=use_tdm_store, - out_dtype=args.out_dtype, + out_dtype=kernel_out_dtype, inst_prefetch=args.inst_prefetch, wave_specialized_tdm=args.wave_spec_tdm, split_k=args.split_k, @@ -1126,23 +1294,13 @@ def _run_benchmark(args): scale_load_path=args.scale_load_path, ) - c_flat = c_gpu.view(-1) - a_flat = a_gpu.view(-1) - b_flat = b_gpu.view(-1) - as_flat = as_gpu.view(-1) - bs_flat = bs_gpu.view(-1) - - # Pre-bind via flyc.compile so the bench loop calls go through the - # CompiledFunction ctypes fast path. The slow JitFunction path adds - # ~17us of inspect.Signature.bind + _resolve_and_make_cache_key per call, which - # would dominate per-launch latency for short kernels. compiled_exe = flyc.compile( launch_fn, - c_flat, - a_flat, - b_flat, - as_flat, - bs_flat, + c_gpu, + a_gpu, + b_gpu, + as_gpu, + bs_gpu, padded_m, padded_n, torch.cuda.current_stream(), @@ -1151,16 +1309,13 @@ def _run_benchmark(args): def prep_kernel(): c_gpu.zero_() - # Resolve the stream lazily inside the closure so the graph-bench path - # captures on the active capture stream rather than the stream bound - # before capture. Same value on the eager path. def run_kernel(): compiled_exe( - c_flat, - a_flat, - b_flat, - as_flat, - bs_flat, + c_gpu, + a_gpu, + b_gpu, + as_gpu, + bs_gpu, padded_m, padded_n, torch.cuda.current_stream(), @@ -1174,12 +1329,7 @@ def run_kernel(): use_graph = getattr(args, "use_graph", False) if use_graph: - print(f"[2/3] Warming up ({args.warmup} iters) + bench via hipGraph ({args.iters} replays)...") - # Graph mode prep: don't zero c_gpu inside the captured kernel - # (zero would be baked into the graph and runs every replay, but - # that's also fine — it would add a trivial memset per replay). - # We omit prep_fn here because the c_gpu state across replays - # doesn't matter for timing. + print(f"[2/3] Warming up ({args.warmup} iters) + bench via hipGraph " f"({args.iters} replays)...") us = _bench_kernel_us_cudagraph(run_kernel, warmup=args.warmup, iters=args.iters) else: print(f"[2/3] Warming up ({args.warmup} iters) + benchmarking ({args.iters} iters)...") @@ -1249,17 +1399,162 @@ def run_kernel(): return us, reported_tflops, bw_gbs +def _run_graph_verify(args): + """Compare eager launch and hipGraph replay for the CLI-selected shape.""" + arch = str(get_rocm_arch()) + if arch != "gfx1250": + raise SystemExit(f"WMMA_SCALE requires gfx1250, got {arch}") + + data_format = args.data_format + M, N, K = args.M, args.N, args.K + tile_m, tile_n, tile_k = args.tile_m, args.tile_n, args.tile_k + if K % SCALE_BLOCK != 0: + raise SystemExit(f"K={K} must be divisible by SCALE_BLOCK={SCALE_BLOCK}") + + padded_shape = _get_padded_problem_shape(data_format, M, N, K, tile_m, tile_n, tile_k, args.split_k) + padded_m = padded_shape["M"] + padded_n = padded_shape["N"] + padded_k = padded_shape["K"] + + print("=" * 72) + print(f" Graph functional verification ({data_format}) on gfx1250") + print(f" Shape: M={M}, N={N}, K={K} (padded {padded_m}x{padded_n}x{padded_k})") + print( + f" Tile: ({tile_m},{tile_n},{tile_k}) warps=({args.m_warp}x{args.n_warp}) " + f"nb={args.num_buffers} sk={args.split_k} " + f"cluster=({args.cluster_m},{args.cluster_n})" + ) + print("=" * 72) + + torch.manual_seed(0) + a, b, a_scale, b_scale, fill_spec = _fill_mode_inputs(M, N, K, data_format, getattr(args, "fill_mode", "random")) + expect_nonzero_output = _expect_nonzero_graph_output(a, b, data_format, fill_spec) + print(f" Fill: {_fill_mode_label(fill_spec, data_format)}") + + a, b, a_scale, b_scale = _pad_mxscale_inputs(a, b, a_scale, b_scale, padded_shape) + + skt = tile_k // SCALE_BLOCK + warp_tile_m = tile_m // args.m_warp + warp_tile_n = tile_n // args.n_warp + _coalesced_scale = args.scale_load_path in ("vgpr", "vgpr_ab_split") + a_scale = preshuffle_e8m0_scale(a_scale, warp_tile_m, scale_k_per_tile=skt, coalesced=_coalesced_scale) + b_scale = preshuffle_e8m0_scale(b_scale, warp_tile_n, scale_k_per_tile=skt, coalesced=_coalesced_scale) + K_packed = padded_k // padded_shape["pack_b"] + b = fp4_utils.preshuffle_b_16x16(b, padded_n, K_packed) + + a_gpu = a.cuda() + b_gpu = b.cuda() + as_gpu = a_scale.cuda() + bs_gpu = b_scale.cuda() + _dtype_map = {"f32": torch.float32, "bf16": torch.bfloat16, "f16": torch.float16} + # split_k>1 accumulates partial K-sums in fp32 for precision; bf16/f16 atomics are + # supported but compound rounding error, so we run f32 and convert back on the host. + kernel_out_dtype = "f32" if (args.split_k > 1 and args.out_dtype in ("bf16", "f16")) else args.out_dtype + c_gpu = torch.zeros(padded_m, padded_n, dtype=_dtype_map[kernel_out_dtype], device="cuda") + + use_tdm_store = not args.no_tdm_store and args.split_k == 1 + launch_fn = compile_mxscale_gemm( + data_format=data_format, + M=padded_m, + N=padded_n, + K=padded_k, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + m_warp=args.m_warp, + n_warp=args.n_warp, + num_buffers=args.num_buffers, + waves_per_eu=args.waves_per_eu, + l2_prefetch_distance=args.l2_prefetch_distance, + cluster_m=args.cluster_m, + cluster_n=args.cluster_n, + use_tdm_store=use_tdm_store, + out_dtype=kernel_out_dtype, + inst_prefetch=args.inst_prefetch, + wave_specialized_tdm=args.wave_spec_tdm, + split_k=args.split_k, + use_scale_opsel=args.use_scale_opsel, + expert_sched_mode=args.expert_sched_mode, + atomic_barrier_enable=args.atomic_barrier_enable, + b_streaming=args.b_streaming, + scale_load_path=args.scale_load_path, + ) + + c_flat = c_gpu.contiguous() + a_flat = a_gpu.contiguous() + b_flat = b_gpu.contiguous() + as_flat = as_gpu.contiguous() + bs_flat = bs_gpu.contiguous() + compiled_exe = flyc.compile( + launch_fn, + c_flat, + a_flat, + b_flat, + as_flat, + bs_flat, + padded_m, + padded_n, + torch.cuda.current_stream(), + ) + + def launch(): + compiled_exe(c_flat, a_flat, b_flat, as_flat, bs_flat, padded_m, padded_n, torch.cuda.current_stream()) + + c_gpu.zero_() + launch() + torch.cuda.synchronize() + eager_result = c_gpu.clone() + + g = torch.cuda.CUDAGraph() + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + launch() + torch.cuda.current_stream().wait_stream(s) + torch.cuda.synchronize() + + c_gpu.zero_() + with torch.cuda.graph(g, stream=s): + launch() + torch.cuda.synchronize() + + c_gpu.zero_() + g.replay() + torch.cuda.synchronize() + graph_result = c_gpu.clone() + + if expect_nonzero_output: + if eager_result.abs().max().item() == 0: + raise SystemExit( + "FAIL: eager run produced all zeros -- kernel did not execute (unexpected for non-zero fill)." + ) + if graph_result.abs().max().item() == 0: + raise SystemExit( + "FAIL: hipGraph replay produced all zeros -- kernel was NOT captured (stream mismatch suspected)." + ) + if not torch.equal(eager_result, graph_result): + diff = (eager_result.float() - graph_result.float()).abs().max().item() + raise SystemExit(f"FAIL: eager vs hipGraph result mismatch, max abs diff = {diff:.6f}") + + sample_max = eager_result.abs().max().item() + print( + f" Eager output |max| = {sample_max:.6g}" + + ("" if expect_nonzero_output else " (zero is expected for this fill)") + ) + print(" PASS: eager == hipGraph replay (bit-exact)") + + if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() - parser.add_argument("--data-format", type=str, default="fp4", choices=["fp4", "fp8", "a8w4"]) + parser.add_argument("--data-format", type=str, default="fp8", choices=["fp4", "fp8", "a8w4"]) parser.add_argument("-M", type=int, default=1024) parser.add_argument("-N", type=int, default=1024) parser.add_argument("-K", type=int, default=2048) - parser.add_argument("--tile-m", type=int, default=128) + parser.add_argument("--tile-m", type=int, default=256) parser.add_argument("--tile-n", type=int, default=256) - parser.add_argument("--tile-k", type=int, default=256) + parser.add_argument("--tile-k", type=int, default=128) parser.add_argument("--m-warp", type=int, default=2) parser.add_argument("--n-warp", type=int, default=2) parser.add_argument("--num-buffers", type=int, default=4, choices=[2, 3, 4]) @@ -1270,14 +1565,14 @@ def run_kernel(): parser.add_argument("--no-tdm-store", action="store_true", default=False) parser.add_argument("--out-dtype", type=str, default="bf16", choices=["f32", "bf16", "f16"]) parser.add_argument("--inst-prefetch", action="store_true", default=False) - parser.add_argument("--wave-spec-tdm", action="store_true", default=False) + parser.add_argument("--no-wave-spec-tdm", dest="wave_spec_tdm", action="store_false", default=True) parser.add_argument("--waves-per-eu", type=int, default=None) parser.add_argument("--use-scale-opsel", action="store_true", default=False) parser.add_argument( "--scale-load-path", type=str, default="tdm", - choices=["tdm", "buffer_lds_stage", "buffer_lds_stage_ab_split"], + choices=["tdm", "vgpr", "vgpr_ab_split"], ) parser.add_argument("--disable-expert-sched-mode", dest="expert_sched_mode", action="store_false", default=True) parser.add_argument("--b-streaming", action="store_true", default=False) @@ -1303,8 +1598,26 @@ def run_kernel(): "Implicitly disables L2 flush (graph replays " "are back-to-back, hot-cache).", ) + parser.add_argument( + "--verify-graph", + action="store_true", + default=False, + help="Functional verification: capture the kernel in a hipGraph, " + "replay once, assert bit-exact match against an eager launch. ", + ) + parser.add_argument( + "--fill-mode", + type=str, + default="random", + help="Input fill mode: 'random', 'zero', or a finite float. Constant " + "mode uses FP8/FP4 encodings for A/B and neutral E8M0 scales.", + ) args = parser.parse_args() + if args.verify_graph: + _run_graph_verify(args) + if not args.benchmark: + sys.exit(0) if args.benchmark: _run_benchmark(args) else: