From 3e7b9fad4e93b45b74a173f18a22d23de27f556c Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Thu, 21 May 2026 10:41:14 +0000 Subject: [PATCH 1/7] [Enh] Ensure type closure for primitive func --- kernels/layout_utils.py | 5 +- kernels/mfma_preshuffle_pipeline.py | 11 ++-- python/flydsl/expr/math.py | 27 +++++--- python/flydsl/expr/primitive.py | 99 ++++++++++++++++++++++++---- python/flydsl/expr/typing.py | 9 +-- tests/unit/test_layout_algebra.py | 22 ++----- tests/unit/test_static_vs_dynamic.py | 20 ++---- 7 files changed, 127 insertions(+), 66 deletions(-) diff --git a/kernels/layout_utils.py b/kernels/layout_utils.py index 976996c06..1439af186 100644 --- a/kernels/layout_utils.py +++ b/kernels/layout_utils.py @@ -156,9 +156,8 @@ def crd2idx(crd, layout): cv = raw crd_i32.append(cv) coord_val = fx.make_coord(*crd_i32) - result = fx.crd2idx(coord_val, layout) - scalar = fx.get_scalar(result) - if isinstance(scalar, ir.Value) and not isinstance(scalar.type, ir.IndexType): + scalar = fx.get_scalar(fx.crd2idx(coord_val, layout)).ir_value() + if not isinstance(scalar.type, ir.IndexType): scalar = arith.index_cast(T.index, scalar) return _wrap(scalar) diff --git a/kernels/mfma_preshuffle_pipeline.py b/kernels/mfma_preshuffle_pipeline.py index 118ba6703..0a1309956 100644 --- a/kernels/mfma_preshuffle_pipeline.py +++ b/kernels/mfma_preshuffle_pipeline.py @@ -20,12 +20,11 @@ def crd2idx(crd, layout): - """crd2idx returning an index-type scalar (unwraps fly.int_tuple).""" - result = fx.crd2idx(crd, layout) - scalar = fx.get_scalar(result) - if isinstance(scalar, ir.Value) and not isinstance(scalar.type, ir.IndexType): - scalar = _arith.IndexCastOp(T.index, scalar).result - return scalar + """crd2idx returning an index-typed ir.Value (unwraps fly.int_tuple).""" + scalar = fx.get_scalar(fx.crd2idx(crd, layout)).ir_value() + if isinstance(scalar.type, ir.IndexType): + return scalar + return _arith.IndexCastOp(T.index, scalar).result def swizzle_xor16(row, col, k_blocks16): diff --git a/python/flydsl/expr/math.py b/python/flydsl/expr/math.py index c4d128c2f..4cc061fef 100644 --- a/python/flydsl/expr/math.py +++ b/python/flydsl/expr/math.py @@ -23,17 +23,21 @@ def _traced_math_op(fn): - """Like @traced_op, but re-wraps results to preserve Numeric class hierarchy. + """Like @traced_op, but re-wraps results to preserve DslType closure. - If the first positional arg is a Numeric (Float32, Int32, …), the MLIR - result is wrapped back into the appropriate Numeric subclass via - ``Numeric.from_ir_type``. Raw ir.Value inputs pass through unchanged. + If the first positional arg is a ``Numeric`` (Float32, Int32, …) or a + ``Vector``, the MLIR result is wrapped back into the matching DSL type so + callers stay at the DSL level instead of dropping to raw ir.Value / ArithValue. + Raw ir.Value inputs pass through unchanged. """ @wraps(fn) def wrapper(*args, **kwargs): + from .typing import Vector + first = args[0] if args else None - do_rewrap = isinstance(first, Numeric) + is_vector = isinstance(first, Vector) + is_numeric = isinstance(first, Numeric) loc = kwargs.pop("loc", None) if loc is None: @@ -42,12 +46,19 @@ def wrapper(*args, **kwargs): with loc: result = fn(*args, **kwargs) - if not do_rewrap: + if not (is_vector or is_numeric): return result + + def _wrap_arith_type(value): + if is_vector: + elem_dtype = Numeric.from_ir_type(ir.VectorType(value.type).element_type) + return Vector(value, first.shape, elem_dtype) + return Numeric.from_ir_type(value.type)(value) + if isinstance(result, ir.Value): - return Numeric.from_ir_type(result.type)(result) + return _wrap_arith_type(result) # Multi-result (e.g. sincos) - return tuple(Numeric.from_ir_type(r.type)(r) for r in result) + return tuple(_wrap_arith_type(r) for r in result) return wrapper diff --git a/python/flydsl/expr/primitive.py b/python/flydsl/expr/primitive.py index 10f0de307..831330e32 100644 --- a/python/flydsl/expr/primitive.py +++ b/python/flydsl/expr/primitive.py @@ -247,6 +247,16 @@ def _check_profile(match_func, lhs, rhs): raise ValueError(f"profile mismatch: {match_func.__name__}({lhs.type}, {rhs.type}) is False") +def _wrap_numeric_type(value): + from .numeric import Numeric + + if not isinstance(value, ir.Value): + return value + if isinstance(value, Numeric): + return value + return Numeric.from_ir_type(value.type)(value) + + # ===----------------------------------------------------------------------=== # # Compile-time utility # ===----------------------------------------------------------------------=== # @@ -478,13 +488,54 @@ def make_fragment_like(tensor, dtype=None, loc=None, ip=None): @traced_op def get_scalar(int_tuple, loc=None, ip=None): - return fly.get_scalar(int_tuple, loc=loc, ip=ip) + """Unwrap a rank-1, single-element tuple back to a plain scalar value. + + Fails if the input has more than one leaf - use this only when you know + the tuple is a trivial wrapper. + + Examples: + get_scalar(make_coord(tid)) -> Int32(tid) + get_scalar(make_int_tuple(5)) -> 5 + """ + if not _is_int_tuple_value(int_tuple): + return int_tuple + if int_tuple.is_leaf and int_tuple.is_static: + return int_tuple.get_static_leaf_int + return _wrap_numeric_type(fly.get_scalar(int_tuple, loc=loc, ip=ip)) @traced_op def get_leaves(input, dynamic_only=False, loc=None, ip=None): - res_lists = fly.GetLeavesOp(input, dynamicOnly=dynamic_only, loc=loc, ip=ip) - return tuple(res_lists.results) + """Flatten an IntTuple into a flat sequence of leaf values. + + Set *dynamic_only=True* to keep only runtime values and drop static + constants - handy when you need the inputs that were passed at call time. + + Examples: + get_leaves(make_coord(tid, 0)) -> (Int32(tid), 0) + get_leaves(make_coord(tid, 0), dynamic_only=True) -> (Int32(tid),) # 0 is static, dropped + """ + if dynamic_only: + res_lists = fly.GetLeavesOp(input, dynamicOnly=True, loc=loc, ip=ip) + return tuple(_wrap_numeric_type(r) for r in res_lists.results) + + def _walk_int_tuple_leaves(ty): + if ty.is_leaf: + yield ty + return + for i in range(ty.rank): + yield from _walk_int_tuple_leaves(ty.at(i)) + + ty = IntTupleType(input.type) + res_lists = fly.GetLeavesOp(input, dynamicOnly=True, loc=loc, ip=ip) + dyn_iter = iter(res_lists.results) + out = [] + for leaf_ty in _walk_int_tuple_leaves(ty): + if leaf_ty.is_static: + out.append(leaf_ty.get_static_leaf_int) + else: + out.append(_wrap_numeric_type(next(dyn_iter))) + return tuple(out) @traced_op @@ -1041,11 +1092,17 @@ def ptrtoint(ptr, loc=None, ip=None): if is_generic_address_space(ptr.address_space, AddressSpace.Register): raise ValueError("ptrtoint is not supported for register address space") - return fly.ptrtoint(ptr, loc=loc, ip=ip) + return _wrap_numeric_type(fly.ptrtoint(ptr, loc=loc, ip=ip)) @traced_op def add_offset(ptr, offset, loc=None, ip=None): + """Shift *ptr* by *offset* elements + + Examples: + ptr2 = add_offset(ptr, 16) # move forward 16 elements + ptr2 = add_offset(ptr, tile_id * BM) # runtime offset + """ if not _is_int_tuple_value(offset): offset = make_int_tuple(offset, loc=loc, ip=ip) return fly.add_offset(ptr, offset, loc=loc, ip=ip) @@ -1058,13 +1115,27 @@ def apply_swizzle(ptr, swizzle, loc=None, ip=None): @traced_op def ptr_load(ptr, result_type=None, loc=None, ip=None): + """Load one value (scalar or vector) from *ptr*; dtype defaults to ptr's element type. + + Examples: + v = ptr_load(ptr) + """ if result_type is None: result_type = ptr.element_type - return fly.ptr_load(result_type.ir_type, ptr, loc=loc, ip=ip) + if not isinstance(result_type, ir.Type): + result_type = result_type.ir_type + return _wrap_numeric_type(fly.ptr_load(result_type, ptr, loc=loc, ip=ip)) @traced_op def ptr_store(value, ptr, loc=None, ip=None): + """Store *value* into *ptr*. Types must match the pointer's element type. + + Examples: + ptr_store(val, ptr) + """ + if not isinstance(value, ir.Value): + value = ptr.element_type(value).ir_value() return fly.ptr_store(value, ptr, loc=loc, ip=ip) @@ -1095,7 +1166,9 @@ def memref_alloca(memref_type, layout, loc=None, ip=None): @traced_op def memref_load_vec(memref, loc=None, ip=None): - return fly.memref_load_vec(memref, loc=loc, ip=ip) + from .typing import Vector + + return Vector(fly.memref_load_vec(memref, loc=loc, ip=ip), memref.shape.to_py_value(), memref.dtype) @traced_op @@ -1106,26 +1179,24 @@ def memref_store_vec(vector, memref, loc=None, ip=None): @traced_op def memref_load(memref, indices, loc=None, ip=None): if isinstance(indices, ir.Value): - if str(indices.type).startswith("!fly.int_tuple"): - return fly.memref_load(memref, indices, loc=loc, ip=ip) if str(indices.type) == "index": indices = _arith.IndexCastOp(T.i32(), indices) - indices = make_int_tuple(indices, loc=loc, ip=ip) - return fly.memref_load(memref, indices, loc=loc, ip=ip) + if not _is_int_tuple_value(indices): + indices = make_int_tuple(indices, loc=loc, ip=ip) + return _wrap_numeric_type(fly.memref_load(memref, indices, loc=loc, ip=ip)) indices = make_int_tuple(indices, loc=loc, ip=ip) _check_profile(is_profile_weakly_congruent, indices, memref) - return fly.memref_load(memref, indices, loc=loc, ip=ip) + return _wrap_numeric_type(fly.memref_load(memref, indices, loc=loc, ip=ip)) @traced_op def memref_store(value, memref, indices, loc=None, ip=None): if isinstance(indices, ir.Value): - if str(indices.type).startswith("!fly.int_tuple"): - return fly.memref_store(value, memref, indices, loc=loc, ip=ip) if str(indices.type) == "index": indices = _arith.IndexCastOp(T.i32(), indices) - indices = make_int_tuple(indices, loc=loc, ip=ip) + if not _is_int_tuple_value(indices): + indices = make_int_tuple(indices, loc=loc, ip=ip) return fly.memref_store(value, memref, indices, loc=loc, ip=ip) indices = make_int_tuple(indices, loc=loc, ip=ip) diff --git a/python/flydsl/expr/typing.py b/python/flydsl/expr/typing.py index d8fdf5969..215081d83 100644 --- a/python/flydsl/expr/typing.py +++ b/python/flydsl/expr/typing.py @@ -582,11 +582,8 @@ def _rebuild_py_value(self, leaf_iter): if self.is_leaf: if self.is_static: return self.get_static_leaf_int - val = next(leaf_iter) - width = ir.IntegerType(val.type).width - wrapper = Int64 if width == 64 else Int32 - return wrapper(val) - return tuple(IntTuple(get_(self, i))._rebuild_py_value(leaf_iter) for i in range(self.rank)) + return next(leaf_iter) + return tuple(get_(self, i)._rebuild_py_value(leaf_iter) for i in range(self.rank)) @traced_op def to_py_value(self, loc=None, ip=None): @@ -917,7 +914,7 @@ def __setitem__(self, coord, value, loc=None, ip=None): @traced_op def load(self, loc=None, ip=None): - return Vector(memref_load_vec(self, loc=loc, ip=ip), self.shape.to_py_value(), self.dtype) + return memref_load_vec(self, loc=loc, ip=ip) @traced_op def store(self, vector, loc=None, ip=None): diff --git a/tests/unit/test_layout_algebra.py b/tests/unit/test_layout_algebra.py index 4f7a55d6c..610b3d851 100644 --- a/tests/unit/test_layout_algebra.py +++ b/tests/unit/test_layout_algebra.py @@ -32,13 +32,7 @@ FLY_PIPELINE = ( - "builtin.module(" - "fly-canonicalize," - "fly-layout-lowering," - "fly-canonicalize," - "convert-fly-to-rocdl," - "canonicalize," - "cse)" + "builtin.module(fly-canonicalize,fly-layout-lowering,fly-canonicalize,convert-fly-to-rocdl,canonicalize,cse)" ) @@ -218,9 +212,8 @@ def build_static(): with Location.unknown(ctx): module = Module.create() i32 = IntegerType.get_signless(32) - idx = IndexType.get() with InsertionPoint(module.body): - f = func.FuncOp("comp_dyn", FunctionType.get([i32] * 8, [idx])) + f = func.FuncOp("comp_dyn", FunctionType.get([i32] * 8, [i32])) entry = f.add_entry_block() with InsertionPoint(entry): args = list(entry.arguments) @@ -228,8 +221,8 @@ def build_static(): B = fx.make_layout(fx.make_shape(args[4], args[5]), fx.make_stride(args[6], args[7])) R = fx.composition(A, B) sz = fx.size(R) - sc = fx.get_scalar(sz) - func.ReturnOp([arith.IndexCastOp(idx, sc).result]) + sc = fx.get_scalar(sz).ir_value() + func.ReturnOp([sc]) pm = PassManager.parse(FLY_PIPELINE, ctx) pm.run(module.operation) assert module.operation.verify() @@ -317,9 +310,8 @@ def test_complement_rank_2_dynamic_stride_error(): with Location.unknown(ctx): module = Module.create() i32 = IntegerType.get_signless(32) - idx = IndexType.get() with InsertionPoint(module.body): - f = func.FuncOp("compl_dyn", FunctionType.get([i32], [idx])) + f = func.FuncOp("compl_dyn", FunctionType.get([i32], [i32])) entry = f.add_entry_block() with InsertionPoint(entry): runtime_stride = entry.arguments[0] @@ -328,8 +320,8 @@ def test_complement_rank_2_dynamic_stride_error(): tiler = fx.make_layout(shape, stride) comp = fx.complement(tiler, 12) sz = fx.size(comp) - sc = fx.get_scalar(sz) - func.ReturnOp([arith.IndexCastOp(idx, sc).result]) + sc = fx.get_scalar(sz).ir_value() + func.ReturnOp([sc]) pm = PassManager.parse(FLY_PIPELINE, ctx) pm.run(module.operation) diff --git a/tests/unit/test_static_vs_dynamic.py b/tests/unit/test_static_vs_dynamic.py index a3c42c320..e027320eb 100644 --- a/tests/unit/test_static_vs_dynamic.py +++ b/tests/unit/test_static_vs_dynamic.py @@ -27,13 +27,7 @@ FLY_PIPELINE = ( - "builtin.module(" - "fly-canonicalize," - "fly-layout-lowering," - "fly-canonicalize," - "convert-fly-to-rocdl," - "canonicalize," - "cse)" + "builtin.module(fly-canonicalize,fly-layout-lowering,fly-canonicalize,convert-fly-to-rocdl,canonicalize,cse)" ) @@ -85,9 +79,8 @@ def test_layout_dynamic_types(): with Location.unknown(ctx): module = Module.create() i32 = IntegerType.get_signless(32) - idx = IndexType.get() with InsertionPoint(module.body): - f = func.FuncOp("dynamic_layout", FunctionType.get([i32] * 4, [idx])) + f = func.FuncOp("dynamic_layout", FunctionType.get([i32] * 4, [i32])) entry = f.add_entry_block() with InsertionPoint(entry): dim0, dim1, stride0, stride1 = entry.arguments @@ -98,7 +91,7 @@ def test_layout_dynamic_types(): layout = fx.make_layout(shape, stride) sz = fx.size(layout) sc = fx.get_scalar(sz) - func.ReturnOp([arith.IndexCastOp(idx, sc).result]) + func.ReturnOp([sc.ir_value()]) pm = PassManager.parse(FLY_PIPELINE, ctx) pm.run(module.operation) @@ -138,9 +131,8 @@ def test_mixed_static_dynamic(): with Location.unknown(ctx): module = Module.create() i32 = IntegerType.get_signless(32) - idx = IndexType.get() with InsertionPoint(module.body): - f = func.FuncOp("mixed_layout", FunctionType.get([i32, i32], [idx])) + f = func.FuncOp("mixed_layout", FunctionType.get([i32, i32], [i32])) entry = f.add_entry_block() with InsertionPoint(entry): runtime_extent, runtime_stride = entry.arguments @@ -152,8 +144,8 @@ def test_mixed_static_dynamic(): stride = fx.make_stride(c16, runtime_stride) layout = fx.make_layout(shape, stride) sz = fx.size(layout) - sc = fx.get_scalar(sz) - func.ReturnOp([arith.IndexCastOp(idx, sc).result]) + sc = fx.get_scalar(sz).ir_value() + func.ReturnOp([sc]) pm = PassManager.parse(FLY_PIPELINE, ctx) pm.run(module.operation) From abb308a89995b5aaedcbc5959e046fdfd631ef17 Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Tue, 2 Jun 2026 08:53:57 +0000 Subject: [PATCH 2/7] ensure int_tuple covariance relationship --- kernels/blockscale_preshuffle_gemm.py | 6 +- kernels/gemm_fp8fp4_gfx1250.py | 2 +- kernels/layout_utils.py | 2 +- kernels/mfma_preshuffle_pipeline.py | 14 +- kernels/mixed_moe_gemm_2stage.py | 18 +- kernels/moe_blockscale_2stage.py | 18 +- kernels/moe_gemm_2stage.py | 18 +- kernels/moe_gemm_2stage_mxscale_gfx1250.py | 4 +- kernels/moe_gemm_2stage_wmma_gfx1250.py | 4 +- kernels/preshuffle_gemm.py | 4 +- kernels/wmma_gemm_gfx1250.py | 2 +- python/flydsl/compiler/ast_rewriter.py | 44 +-- python/flydsl/expr/__init__.py | 2 + python/flydsl/expr/derived.py | 9 +- python/flydsl/expr/extern.py | 9 +- python/flydsl/expr/gpu.py | 4 +- python/flydsl/expr/math.py | 426 ++++++++++++--------- python/flydsl/expr/meta.py | 44 +++ python/flydsl/expr/numeric.py | 89 ++--- python/flydsl/expr/primitive.py | 373 ++++++++++-------- python/flydsl/expr/rocdl.py | 4 +- python/flydsl/expr/typing.py | 105 ++++- tests/unit/test_numeric_promotion.py | 186 +++++++++ 23 files changed, 931 insertions(+), 456 deletions(-) create mode 100644 tests/unit/test_numeric_promotion.py diff --git a/kernels/blockscale_preshuffle_gemm.py b/kernels/blockscale_preshuffle_gemm.py index 2371d9e81..c2517e37c 100644 --- a/kernels/blockscale_preshuffle_gemm.py +++ b/kernels/blockscale_preshuffle_gemm.py @@ -203,12 +203,12 @@ def kernel_gemm( # ---- Wave / lane decomposition ---- wave_size = 64 layout_wave_lane = fx.make_layout((4, wave_size), (64, 1)) - coord_wave_lane = fx.idx2crd(tx, layout_wave_lane) + coord_wave_lane = fx.idx2crd(fx.Int32(tx), layout_wave_lane) wave_id = fx.get(coord_wave_lane, 0) lane_id = fx.get(coord_wave_lane, 1) layout_lane16 = fx.make_layout((4, 16), (16, 1)) - coord_lane16 = fx.idx2crd(lane_id, layout_lane16) + coord_lane16 = fx.idx2crd(fx.Int32(lane_id), layout_lane16) lane_div_16 = fx.get(coord_lane16, 0) lane_mod_16 = fx.get(coord_lane16, 1) @@ -253,7 +253,7 @@ def load_b_packs_k64(base_k, ku: int, ni: int): k0 = k0_base + ku k1 = lane_div_16 coord_pack = (n_blk_list[ni], k0, k1, n_intra_list[ni], fx.Index(0)) - idx_pack = crd2idx(coord_pack, layout_b) + idx_pack = crd2idx(tuple(fx.Int32(c) for c in coord_pack), layout_b) b16 = _buffer_load_vec( buffer_ops, vector, diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index ee09dc7a4..5d5ba55f8 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -405,7 +405,7 @@ def kernel_mxscale_gemm( b_mcast_mask = 0 layout_thr = fx.make_layout((m_warp, n_warp, 2, 16), (n_warp * WAVE_SIZE, WAVE_SIZE, 16, 1)) - thr_coord = idx2crd(tx, layout_thr) + thr_coord = idx2crd(fx.Int32(tx), layout_thr) wave_m_idx, wave_n_idx, lane_kgrp, lane16 = ( fx.get(thr_coord, 0), fx.get(thr_coord, 1), diff --git a/kernels/layout_utils.py b/kernels/layout_utils.py index 1439af186..ccb1d992e 100644 --- a/kernels/layout_utils.py +++ b/kernels/layout_utils.py @@ -87,7 +87,7 @@ def idx2crd(idx, layout): parsed = _parse_layout(layout) if parsed is None or _has_dynamic_strides(parsed[1]): - result = fx.idx2crd(idx, layout) + result = fx.idx2crd(fx.Int32(idx), layout) ndims = len(parsed[1]) if parsed else 1 return [_wrap(fx.get(result, i)) for i in range(ndims)] diff --git a/kernels/mfma_preshuffle_pipeline.py b/kernels/mfma_preshuffle_pipeline.py index 0a1309956..4cd0bc9b3 100644 --- a/kernels/mfma_preshuffle_pipeline.py +++ b/kernels/mfma_preshuffle_pipeline.py @@ -325,7 +325,7 @@ def load_b_raw_w4a16( k2_base = lane_odd * fx.Index(half_bytes) coord_pack = (n_blk, k0, k1_local, n_intra, fx.Index(0)) - idx_pack = crd2idx(coord_pack, layout_b) + idx_pack = crd2idx(tuple(fx.Int32(c) for c in coord_pack), layout_b) idx_bytes = idx_pack + k2_base b4 = _buffer_load_vec( @@ -463,7 +463,7 @@ def load_b_pack_k32( k2_base = arith.constant((ki_step % 2) * half_bytes, index=True) coord_pack = (n_blk, k0, k1, n_intra, fx.Index(0)) - idx_pack = crd2idx(coord_pack, layout_b) + idx_pack = crd2idx(tuple(fx.Int32(c) for c in coord_pack), layout_b) if unpack_int4: idx_bytes = idx_pack + k2_base @@ -579,7 +579,7 @@ def lds_store_16b_xor16( col_swz_bytes = swizzle_xor16(row_local, col_local_bytes, k_blocks16) col_swz = col_swz_bytes if elem_bytes == 1 else col_swz_bytes // 2 coord_store = (row_local, col_swz) - idx0 = crd2idx(coord_store, layout_lds) + lds_base + idx0 = crd2idx(tuple(fx.Int32(c) for c in coord_store), layout_lds) + lds_base v16 = vector.bitcast(vec16_ty, vec_part_i32x4) vector.store(v16, lds_memref, [idx0]) @@ -606,7 +606,7 @@ def lds_store_8b_xor16( col_swz_bytes = swizzle_xor16(row_local, col_local_bytes, k_blocks16) col_swz = col_swz_bytes if elem_bytes == 1 else col_swz_bytes // 2 coord_store = (row_local, col_swz) - idx0 = crd2idx(coord_store, layout_lds) + lds_base + idx0 = crd2idx(tuple(fx.Int32(c) for c in coord_store), layout_lds) + lds_base v8 = vector.bitcast(vec8_ty, vec_part_i32x2) vector.store(v8, lds_memref, [idx0]) @@ -633,7 +633,7 @@ def lds_store_4b_xor16( col_swz_bytes = swizzle_xor16(row_local, col_local_bytes, k_blocks16) col_swz = col_swz_bytes if elem_bytes == 1 else col_swz_bytes // 2 coord_store = (row_local, col_swz) - idx0 = crd2idx(coord_store, layout_lds) + lds_base + idx0 = crd2idx(tuple(fx.Int32(c) for c in coord_store), layout_lds) + lds_base v4 = vector.bitcast(vec4_ty, vec_part_i32x1) vector.store(v4, lds_memref, [idx0]) @@ -659,14 +659,14 @@ def lds_load_pack_k32( col_base_swz = swizzle_xor16(curr_row_a_lds, col_base, k_blocks16) if ck_lds128: coord_a16 = (curr_row_a_lds, col_base_swz) - idx_a16 = crd2idx(coord_a16, layout_lds) + lds_base + idx_a16 = crd2idx(tuple(fx.Int32(c) for c in coord_a16), layout_lds) + lds_base loaded_a16 = vector.load_op(vec16_ty, lds_memref, [idx_a16]) a_vec128 = vector.bitcast(vec2_i64_ty, loaded_a16) return vector.extract(a_vec128, static_position=[half], dynamic_position=[]) else: col_swizzled = col_base_swz + (half * 8) coord_a = (curr_row_a_lds, col_swizzled) - idx_a = crd2idx(coord_a, layout_lds) + lds_base + idx_a = crd2idx(tuple(fx.Int32(c) for c in coord_a), layout_lds) + lds_base loaded_a8 = vector.load_op(vec8_ty, lds_memref, [idx_a]) a_vec64 = vector.bitcast(vec1_i64_ty, loaded_a8) return vector.extract(a_vec64, static_position=[0], dynamic_position=[]) diff --git a/kernels/mixed_moe_gemm_2stage.py b/kernels/mixed_moe_gemm_2stage.py index 5a7f3c24a..712291931 100644 --- a/kernels/mixed_moe_gemm_2stage.py +++ b/kernels/mixed_moe_gemm_2stage.py @@ -729,10 +729,10 @@ def load_x_tile(base_k): return parts # Wave/lane decomposition (identical to stage2) - coord_wl = idx2crd(tx, layout_tx_wave_lane) + coord_wl = idx2crd(fx.Int32(tx), layout_tx_wave_lane) wave_id = layout_get(coord_wl, 0) lane_id = layout_get(coord_wl, 1) - coord_l16 = idx2crd(lane_id, layout_lane16) + coord_l16 = idx2crd(fx.Int32(lane_id), layout_lane16) lane_div_16 = layout_get(coord_l16, 0) lane_mod_16 = layout_get(coord_l16, 1) row_a_lds = lane_mod_16 @@ -763,12 +763,12 @@ def load_x_tile(base_k): global_n = by_n + n_tile_base + c_offset + lane_mod_16 # Gate/interleave: rows [expert_off, expert_off + 2*inter_dim) gate_row_w = expert_off_idx + global_n - gate_coord = idx2crd(gate_row_w, layout_n_blk_intra) + gate_coord = idx2crd(fx.Int32(gate_row_w), layout_n_blk_intra) gate_n_blk_list.append(layout_get(gate_coord, 0)) gate_n_intra_list.append(layout_get(gate_coord, 1)) if const_expr(not mock_gate_only and not gate_up_interleave): up_row_w = gate_row_w + inter_idx - up_coord = idx2crd(up_row_w, layout_n_blk_intra) + up_coord = idx2crd(fx.Int32(up_row_w), layout_n_blk_intra) up_n_blk_list.append(layout_get(up_coord, 0)) up_n_intra_list.append(layout_get(up_coord, 1)) @@ -799,7 +799,7 @@ def load_b_packs_k64(base_k, ku: int, n_blk, n_intra): k0 = base_k_bytes // c64 + arith.constant(ku, index=True) k1 = lane_div_16 coord_pack = (n_blk, k0, k1, n_intra, arith.constant(0, index=True)) - idx_pack = crd2idx(coord_pack, layout_b) + idx_pack = crd2idx(tuple(fx.Int32(c) for c in coord_pack), layout_b) vec_elems = kpack_bytes // int(b_elem_bytes) b16 = _buffer_load_vec( buffer_ops, @@ -1015,7 +1015,7 @@ def prefetch_x_to_lds(base_k, lds_buffer): def lds_load_packs_k64(curr_row_a_lds, col_base, lds_buffer): col_base_swz_bytes = swizzle_xor16(curr_row_a_lds, col_base, k_blocks16) col_base_swz = col_base_swz_bytes if elem_bytes == 1 else (col_base_swz_bytes / arith.index(2)) - idx_a16 = crd2idx([curr_row_a_lds, col_base_swz], layout_lds) + idx_a16 = crd2idx([fx.Int32(curr_row_a_lds), fx.Int32(col_base_swz)], layout_lds) loaded_a16 = vector.load_op(vec16_x, lds_buffer, [idx_a16]) a_i64x2 = vector.bitcast(vec2_i64, loaded_a16) a0 = vector.extract(a_i64x2, static_position=[0], dynamic_position=[]) @@ -3074,10 +3074,10 @@ def load_x_tile(base_k): return parts # tx -> wave/lane (GEMM-style decomposition). - coord_wl = idx2crd(tx, layout_tx_wave_lane) + coord_wl = idx2crd(fx.Int32(tx), layout_tx_wave_lane) wave_id = layout_get(coord_wl, 0) lane_id = layout_get(coord_wl, 1) - coord_l16 = idx2crd(lane_id, layout_lane16) + coord_l16 = idx2crd(fx.Int32(lane_id), layout_lane16) lane_div_16 = layout_get(coord_l16, 0) lane_mod_16 = layout_get(coord_l16, 1) @@ -3330,7 +3330,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_buffer): def lds_load_packs_k64(curr_row_a_lds, col_base, lds_buffer): col_base_swz_bytes = swizzle_xor16(curr_row_a_lds, col_base, k_blocks16) col_base_swz = col_base_swz_bytes if elem_bytes == 1 else (col_base_swz_bytes / arith.index(2)) - idx_a16 = crd2idx([curr_row_a_lds, col_base_swz], layout_lds) + idx_a16 = crd2idx([fx.Int32(curr_row_a_lds), fx.Int32(col_base_swz)], layout_lds) loaded_a16 = vector.load_op(vec16_x, lds_buffer, [idx_a16]) a_i64x2 = vector.bitcast(vec2_i64, loaded_a16) a0 = vector.extract(a_i64x2, static_position=[0], dynamic_position=[]) diff --git a/kernels/moe_blockscale_2stage.py b/kernels/moe_blockscale_2stage.py index 2c5eb635b..c257bad18 100644 --- a/kernels/moe_blockscale_2stage.py +++ b/kernels/moe_blockscale_2stage.py @@ -466,10 +466,10 @@ def load_x_tile(base_k, x_load_bytes_v): return parts # tx -> wave/lane (GEMM-style decomposition). - coord_wl = fx.idx2crd(tx, layout_tx_wave_lane) + coord_wl = fx.idx2crd(fx.Int32(tx), layout_tx_wave_lane) wave_id = fx.get(coord_wl, 0) lane_id = fx.get(coord_wl, 1) - coord_l16 = fx.idx2crd(lane_id, layout_lane16) + coord_l16 = fx.idx2crd(fx.Int32(lane_id), layout_lane16) lane_div_16 = fx.get(coord_l16, 0) lane_mod_16 = fx.get(coord_l16, 1) @@ -511,11 +511,11 @@ def load_x_tile(base_k, x_load_bytes_v): row_gate = expert_off_idx + col_g row_up = row_gate + inter_idx - coord_gate = fx.idx2crd(row_gate, layout_n_blk_intra) + coord_gate = fx.idx2crd(fx.Int32(row_gate), layout_n_blk_intra) n_blk_gate.append(fx.get(coord_gate, 0)) n_intra_gate.append(fx.get(coord_gate, 1)) - coord_up = fx.idx2crd(row_up, layout_n_blk_intra) + coord_up = fx.idx2crd(fx.Int32(row_up), layout_n_blk_intra) n_blk_up.append(fx.get(coord_up, 0)) n_intra_up.append(fx.get(coord_up, 1)) @@ -620,7 +620,7 @@ def lds_load_packs_k64(curr_row_a_lds, col_base_bytes, lds_base): col_base_swz = ( col_base_swz_bytes if elem_bytes == 1 else (col_base_swz_bytes // arith.index(int(elem_bytes))) ) - idx_a16 = crd2idx((curr_row_a_lds, col_base_swz), layout_lds) + idx_a16 = crd2idx((fx.Int32(curr_row_a_lds), fx.Int32(col_base_swz)), layout_lds) idx_a16 = idx_a16 + lds_base loaded_a16 = vector.load_op(vec16_x, lds_x, [idx_a16]) a_i64x2 = vector.bitcast(T.i64x2, loaded_a16) @@ -1604,10 +1604,10 @@ def load_x_tile(base_k, x_load_bytes_v): return parts # tx -> wave/lane (GEMM-style decomposition). - coord_wl = fx.idx2crd(tx, layout_tx_wave_lane) + coord_wl = fx.idx2crd(fx.Int32(tx), layout_tx_wave_lane) wave_id = fx.get(coord_wl, 0) lane_id = fx.get(coord_wl, 1) - coord_l16 = fx.idx2crd(lane_id, layout_lane16) + coord_l16 = fx.idx2crd(fx.Int32(lane_id), layout_lane16) lane_div_16 = fx.get(coord_l16, 0) lane_mod_16 = fx.get(coord_l16, 1) @@ -1640,7 +1640,7 @@ def load_x_tile(base_k, x_load_bytes_v): col_g_list.append(col_g) row_w = expert_off_idx + col_g - coord_w = fx.idx2crd(row_w, layout_n_blk_intra) + coord_w = fx.idx2crd(fx.Int32(row_w), layout_n_blk_intra) n_blk_list.append(fx.get(coord_w, 0)) n_intra_list.append(fx.get(coord_w, 1)) @@ -1742,7 +1742,7 @@ def lds_load_packs_k64(curr_row_a_lds, col_base_bytes, lds_base): col_base_swz = ( col_base_swz_bytes if elem_bytes == 1 else (col_base_swz_bytes // arith.index(int(elem_bytes))) ) - idx_a16 = crd2idx((curr_row_a_lds, col_base_swz), layout_lds) + idx_a16 = crd2idx((fx.Int32(curr_row_a_lds), fx.Int32(col_base_swz)), layout_lds) idx_a16 = idx_a16 + lds_base loaded_a16 = vector.load_op(vec16_x, lds_x, [idx_a16]) a_i64x2 = vector.bitcast(T.i64x2, loaded_a16) diff --git a/kernels/moe_gemm_2stage.py b/kernels/moe_gemm_2stage.py index 1402ffada..57769c7d2 100644 --- a/kernels/moe_gemm_2stage.py +++ b/kernels/moe_gemm_2stage.py @@ -598,10 +598,10 @@ def load_x_tile(base_k): return parts # tx -> wave/lane (GEMM-style decomposition). - coord_wl = fx.idx2crd(tx, layout_tx_wave_lane) + coord_wl = fx.idx2crd(fx.Int32(tx), layout_tx_wave_lane) wave_id = fx.get(coord_wl, 0) lane_id = fx.get(coord_wl, 1) - coord_l16 = fx.idx2crd(lane_id, layout_lane16) + coord_l16 = fx.idx2crd(fx.Int32(lane_id), layout_lane16) lane_div_16 = fx.get(coord_l16, 0) lane_mod_16 = fx.get(coord_l16, 1) @@ -644,11 +644,11 @@ def load_x_tile(base_k): row_gate = expert_off_idx + col_g row_up = row_gate + inter_idx - coord_gate = fx.idx2crd(row_gate, layout_n_blk_intra) + coord_gate = fx.idx2crd(fx.Int32(row_gate), layout_n_blk_intra) n_blk_gate.append(fx.get(coord_gate, 0)) n_intra_gate.append(fx.get(coord_gate, 1)) - coord_up = fx.idx2crd(row_up, layout_n_blk_intra) + coord_up = fx.idx2crd(fx.Int32(row_up), layout_n_blk_intra) n_blk_up.append(fx.get(coord_up, 0)) n_intra_up.append(fx.get(coord_up, 1)) @@ -811,7 +811,7 @@ def lds_load_packs_k64(curr_row_a_lds, col_base_bytes, lds_base): col_base_swz = ( col_base_swz_bytes if elem_bytes == 1 else (col_base_swz_bytes // arith.index(int(elem_bytes))) ) - idx_a16 = crd2idx((curr_row_a_lds, col_base_swz), layout_lds) + idx_a16 = crd2idx((fx.Int32(curr_row_a_lds), fx.Int32(col_base_swz)), layout_lds) idx_a16 = idx_a16 + lds_base loaded_a16 = vector.load_op(vec16_x, lds_x, [idx_a16]) a_i64x2 = vector.bitcast(T.i64x2, loaded_a16) @@ -2256,10 +2256,10 @@ def load_x_tile(base_k): return parts # tx -> wave/lane (GEMM-style decomposition). - coord_wl = fx.idx2crd(tx, layout_tx_wave_lane) + coord_wl = fx.idx2crd(fx.Int32(tx), layout_tx_wave_lane) wave_id = fx.get(coord_wl, 0) lane_id = fx.get(coord_wl, 1) - coord_l16 = fx.idx2crd(lane_id, layout_lane16) + coord_l16 = fx.idx2crd(fx.Int32(lane_id), layout_lane16) lane_div_16 = fx.get(coord_l16, 0) lane_mod_16 = fx.get(coord_l16, 1) @@ -2293,7 +2293,7 @@ def load_x_tile(base_k): col_g_list.append(col_g) row_w = expert_off_idx + col_g - coord_w = fx.idx2crd(row_w, layout_n_blk_intra) + coord_w = fx.idx2crd(fx.Int32(row_w), layout_n_blk_intra) n_blk_list.append(fx.get(coord_w, 0)) n_intra_list.append(fx.get(coord_w, 1)) @@ -2453,7 +2453,7 @@ def lds_load_packs_k64(curr_row_a_lds, col_base_bytes, lds_base): col_base_swz = ( col_base_swz_bytes if elem_bytes == 1 else (col_base_swz_bytes // arith.index(int(elem_bytes))) ) - idx_a16 = crd2idx((curr_row_a_lds, col_base_swz), layout_lds) + idx_a16 = crd2idx((fx.Int32(curr_row_a_lds), fx.Int32(col_base_swz)), layout_lds) idx_a16 = idx_a16 + lds_base loaded_a16 = vector.load_op(vec16_x, lds_x, [idx_a16]) a_i64x2 = vector.bitcast(T.i64x2, loaded_a16) diff --git a/kernels/moe_gemm_2stage_mxscale_gfx1250.py b/kernels/moe_gemm_2stage_mxscale_gfx1250.py index 5cb14c60f..f2a013f93 100644 --- a/kernels/moe_gemm_2stage_mxscale_gfx1250.py +++ b/kernels/moe_gemm_2stage_mxscale_gfx1250.py @@ -394,7 +394,7 @@ def moe_mxscale_stage1_single( block_ok = arith.andi(block_in_valid, arith.andi(eid_ok0, eid_ok1)) layout_thr = _make_moe_wave_layout(m_warp=m_warp, n_warp=n_warp, WAVE_SIZE=WAVE_SIZE, fx=fx) - thr_coord = idx2crd(tx, layout_thr) + thr_coord = idx2crd(fx.Int32(tx), layout_thr) wave_m_idx, wave_n_idx, lane_kgrp, lane16 = ( fx.get(thr_coord, 0), fx.get(thr_coord, 1), fx.get(thr_coord, 2), fx.get(thr_coord, 3) ) @@ -2560,7 +2560,7 @@ def moe_mxscale_stage2_single( block_ok = arith.andi(block_in_valid, arith.andi(eid_ok0, eid_ok1)) layout_thr = _make_moe_wave_layout(m_warp=m_warp, n_warp=n_warp, WAVE_SIZE=WAVE_SIZE, fx=fx) - thr_coord = idx2crd(tx, layout_thr) + thr_coord = idx2crd(fx.Int32(tx), layout_thr) wave_m_idx, wave_n_idx, lane_kgrp, lane16 = ( fx.get(thr_coord, 0), fx.get(thr_coord, 1), fx.get(thr_coord, 2), fx.get(thr_coord, 3) ) diff --git a/kernels/moe_gemm_2stage_wmma_gfx1250.py b/kernels/moe_gemm_2stage_wmma_gfx1250.py index ebed9aa5d..06cc63dd4 100644 --- a/kernels/moe_gemm_2stage_wmma_gfx1250.py +++ b/kernels/moe_gemm_2stage_wmma_gfx1250.py @@ -149,7 +149,7 @@ def moe_fp16_stage1_single( eid_ok = arith.andi(eid_ok0, eid_ok1) layout_thr = _make_moe_wave_layout(m_warp=m_warp, n_warp=n_warp, WAVE_SIZE=WAVE_SIZE, fx=fx) - thr_coord = idx2crd(tx, layout_thr) + thr_coord = idx2crd(fx.Int32(tx), layout_thr) wave_m_idx, wave_n_idx, lane_kgrp, lane16 = ( fx.get(thr_coord, 0), fx.get(thr_coord, 1), @@ -556,7 +556,7 @@ def moe_fp16_stage2_single( block_ok = arith.andi(block_in_valid, arith.andi(eid_ok0, eid_ok1)) layout_thr = _make_moe_wave_layout(m_warp=m_warp, n_warp=n_warp, WAVE_SIZE=WAVE_SIZE, fx=fx) - thr_coord = idx2crd(tx, layout_thr) + thr_coord = idx2crd(fx.Int32(tx), layout_thr) wave_m_idx, wave_n_idx, lane_kgrp, lane16 = ( fx.get(thr_coord, 0), fx.get(thr_coord, 1), diff --git a/kernels/preshuffle_gemm.py b/kernels/preshuffle_gemm.py index dedd3ac86..31947b723 100644 --- a/kernels/preshuffle_gemm.py +++ b/kernels/preshuffle_gemm.py @@ -441,12 +441,12 @@ def kernel_gemm( # ---- Wave / lane decomposition ---- wave_size = 64 layout_wave_lane = fx.make_layout((4, wave_size), (64, 1)) - coord_wave_lane = fx.idx2crd(tx, layout_wave_lane) + coord_wave_lane = fx.idx2crd(fx.Int32(tx), layout_wave_lane) wave_id = fx.get(coord_wave_lane, 0) lane_id = fx.get(coord_wave_lane, 1) layout_lane16 = fx.make_layout((4, 16), (16, 1)) - coord_lane16 = fx.idx2crd(lane_id, layout_lane16) + coord_lane16 = fx.idx2crd(fx.Int32(lane_id), layout_lane16) lane_div_16 = fx.get(coord_lane16, 0) lane_mod_16 = fx.get(coord_lane16, 1) diff --git a/kernels/wmma_gemm_gfx1250.py b/kernels/wmma_gemm_gfx1250.py index 51115078e..a2e05154f 100644 --- a/kernels/wmma_gemm_gfx1250.py +++ b/kernels/wmma_gemm_gfx1250.py @@ -250,7 +250,7 @@ def kernel_wmma_gemm_tdm( # --- Thread/wave decomposition --- layout_thr = fx.make_layout((m_warp, n_warp, 2, 16), (n_warp * WAVE_SIZE, WAVE_SIZE, 16, 1)) - thr_coord = idx2crd(tx, layout_thr) + thr_coord = idx2crd(fx.Int32(tx), layout_thr) wave_m_idx, wave_n_idx, lane_kgrp, lane16 = ( fx.get(thr_coord, 0), fx.get(thr_coord, 1), diff --git a/python/flydsl/compiler/ast_rewriter.py b/python/flydsl/compiler/ast_rewriter.py index 8282b0e4a..e41b40ad6 100644 --- a/python/flydsl/compiler/ast_rewriter.py +++ b/python/flydsl/compiler/ast_rewriter.py @@ -14,7 +14,7 @@ from .._mlir import ir from .._mlir.dialects import arith, scf from ..expr import const_expr -from ..expr.numeric import _unwrap_value, _wrap_like +from ..expr.typing import as_dsl_value, as_ir_value from ..utils import env, log @@ -496,7 +496,7 @@ def _is_dynamic(cond): @staticmethod def _to_i1(cond): - return _unwrap_value(cond) + return as_ir_value(cond) @staticmethod def _normalize_named_values(names, values, names_label="names", values_label="values"): @@ -542,7 +542,7 @@ def _normalize_branch_result(branch_result, state_names, state_map, branch_label def _unwrap_mlir_values(values, state_names, branch_label): raw_values = [] for name, value in zip(state_names, values): - raw = _unwrap_value(value) + raw = as_ir_value(value) if not isinstance(raw, ir.Value): raise TypeError( f"if/else variable '{name}' in {branch_label} is {type(raw).__name__}, " @@ -555,7 +555,7 @@ def _unwrap_mlir_values(values, state_names, branch_label): def _pack_dispatch_results(results, state_values): if not results: return None - wrapped = [_wrap_like(v, exemplar) for v, exemplar in zip(results, state_values)] + wrapped = [as_dsl_value(v, exemplar) for v, exemplar in zip(results, state_values)] if len(wrapped) == 1: return wrapped[0] return tuple(wrapped) @@ -622,7 +622,7 @@ def scf_if_dispatch( if not isinstance(cond_i1, ir.Value): raise TypeError(f"dynamic if condition must lower to ir.Value, got {type(cond_i1).__name__}") - none_vars = [name for name, value in zip(result_names, result_values) if _unwrap_value(value) is None] + none_vars = [name for name, value in zip(result_names, result_values) if as_ir_value(value) is None] if none_vars: raise TypeError( f"Variable(s) {none_vars} initialized as None before a dynamic " @@ -652,7 +652,7 @@ def scf_if_dispatch( state_raw = [] for name, value in zip(result_names, result_values): - raw = _unwrap_value(value) + raw = as_ir_value(value) if not isinstance(raw, ir.Value): raise TypeError( f"state variable '{name}' is {type(raw).__name__}, not an MLIR Value; " @@ -881,9 +881,9 @@ def scf_ifexp_dispatch(cond, then_fn, else_fn): sandbox.region.blocks.append() with ir.InsertionPoint(sandbox.region.blocks[0]): probe_then = then_fn() - probe_then_raw = _unwrap_value(probe_then) + probe_then_raw = as_ir_value(probe_then) probe_else = else_fn() - probe_else_raw = _unwrap_value(probe_else) + probe_else_raw = as_ir_value(probe_else) if not isinstance(probe_then_raw, ir.Value): raise TypeError( f"dynamic ifexp then-branch must produce an MLIR Value, " f"got {type(probe_then_raw).__name__}" @@ -902,14 +902,14 @@ def scf_ifexp_dispatch(cond, then_fn, else_fn): op = scf.IfOp(cond_i1, [yield_type], has_else=True, loc=ir.Location.unknown()) with ir.InsertionPoint(op.regions[0].blocks[0]): - scf.YieldOp([_unwrap_value(then_fn())]) + scf.YieldOp([as_ir_value(then_fn())]) if len(op.regions[1].blocks) == 0: op.regions[1].blocks.append() with ir.InsertionPoint(op.regions[1].blocks[0]): - scf.YieldOp([_unwrap_value(else_fn())]) + scf.YieldOp([as_ir_value(else_fn())]) sandbox.operation.erase() - return _wrap_like(op.results[0], probe_then) + return as_dsl_value(op.results[0], probe_then) @ASTRewriter.register @@ -942,7 +942,7 @@ def scf_range(start, stop=None, step=None, *, init=None): stop_val = InsertEmptyYieldForSCFFor._to_index(stop) step_val = InsertEmptyYieldForSCFFor._to_index(step) if init is not None: - init = [_unwrap_value(v) for v in init] + init = [as_ir_value(v) for v in init] for_op = scf.ForOp(start_val, stop_val, step_val, init) with ir.InsertionPoint(for_op.body): yield for_op.induction_variable, list(for_op.inner_iter_args) @@ -953,9 +953,9 @@ def scf_range(start, stop=None, step=None, *, init=None): @staticmethod def scf_for_dispatch(start, stop, step, body_fn, *, result_names=(), result_values=()): - start_val = _unwrap_value(start) - stop_val = _unwrap_value(stop) - step_val = _unwrap_value(step) + start_val = as_ir_value(start) + stop_val = as_ir_value(stop) + step_val = as_ir_value(step) i32_ty = ir.IntegerType.get_signless(32) idx_ty = ir.IndexType.get() @@ -974,7 +974,7 @@ def scf_for_dispatch(start, stop, step, body_fn, *, result_names=(), result_valu result_values = tuple(result_values) result_map = {name: value for name, value in zip(result_names, result_values)} - none_vars = [name for name, value in zip(result_names, result_values) if _unwrap_value(value) is None] + none_vars = [name for name, value in zip(result_names, result_values) if as_ir_value(value) is None] if none_vars: raise TypeError( f"Variable(s) {none_vars} initialized as None before a dynamic " @@ -994,7 +994,7 @@ def scf_for_dispatch(start, stop, step, body_fn, *, result_names=(), result_valu state_raw = [] for name, value in zip(result_names, result_values): - raw = _unwrap_value(value) + raw = as_ir_value(value) if not isinstance(raw, ir.Value): raise TypeError( f"for-loop variable '{name}' is {type(raw).__name__}, not an MLIR Value; " @@ -1006,7 +1006,7 @@ def scf_for_dispatch(start, stop, step, body_fn, *, result_names=(), result_valu with ir.InsertionPoint(for_op.body): iv = for_op.induction_variable - inner_args = [_wrap_like(a, ex) for a, ex in zip(for_op.inner_iter_args, result_values)] + inner_args = [as_dsl_value(a, ex) for a, ex in zip(for_op.inner_iter_args, result_values)] body_result = body_fn(iv, result_names, *inner_args) @@ -1305,7 +1305,7 @@ def scf_while_dispatch(before_fn, after_fn, *, result_names=(), result_values=() ) result_map = {name: value for name, value in zip(result_names, result_values)} - none_vars = [name for name, value in zip(result_names, result_values) if _unwrap_value(value) is None] + none_vars = [name for name, value in zip(result_names, result_values) if as_ir_value(value) is None] if none_vars: raise TypeError( f"Variable(s) {none_vars} initialized as None before a dynamic " @@ -1318,7 +1318,7 @@ def scf_while_dispatch(before_fn, after_fn, *, result_names=(), result_values=() state_raw = [] for name, value in zip(result_names, result_values): - raw = _unwrap_value(value) + raw = as_ir_value(value) if not isinstance(raw, ir.Value): raise TypeError( f"while-loop variable '{name}' is {type(raw).__name__}, not an MLIR Value; " @@ -1333,7 +1333,7 @@ def scf_while_dispatch(before_fn, after_fn, *, result_names=(), result_values=() with ir.InsertionPoint(while_op.regions[0].blocks[0]): before_args = list(while_op.regions[0].blocks[0].arguments) - wrapped_before = [_wrap_like(a, ex) for a, ex in zip(before_args, result_values)] if result_names else [] + wrapped_before = [as_dsl_value(a, ex) for a, ex in zip(before_args, result_values)] if result_names else [] before_cond = ReplaceIfWithDispatch._call_branch(before_fn, result_names, wrapped_before) cond_i1 = ReplaceIfWithDispatch._to_i1(before_cond) if not isinstance(cond_i1, ir.Value): @@ -1342,7 +1342,7 @@ def scf_while_dispatch(before_fn, after_fn, *, result_names=(), result_values=() with ir.InsertionPoint(while_op.regions[1].blocks[0]): after_args = list(while_op.regions[1].blocks[0].arguments) - wrapped_after = [_wrap_like(a, ex) for a, ex in zip(after_args, result_values)] if result_names else [] + wrapped_after = [as_dsl_value(a, ex) for a, ex in zip(after_args, result_values)] if result_names else [] body_result = ReplaceIfWithDispatch._call_branch(after_fn, result_names, wrapped_after) if result_names: body_values = ReplaceIfWithDispatch._normalize_branch_result( diff --git a/python/flydsl/expr/__init__.py b/python/flydsl/expr/__init__.py index b630550b3..f904c5e9f 100644 --- a/python/flydsl/expr/__init__.py +++ b/python/flydsl/expr/__init__.py @@ -7,6 +7,8 @@ from .gpu import * from .derived import * from .struct import * +from .arith import * +from .math import * from . import utils as utils from . import arith as arith diff --git a/python/flydsl/expr/derived.py b/python/flydsl/expr/derived.py index c3b739bca..8d3204b51 100644 --- a/python/flydsl/expr/derived.py +++ b/python/flydsl/expr/derived.py @@ -7,6 +7,7 @@ from .meta import traced_op from .numeric import Boolean, Numeric from .primitive import * +from .primitive import _coerce_int_tuple from .typing import Int8, Layout, Tensor, TiledCopy, TiledMma __all__ = [ @@ -34,7 +35,7 @@ def __init__(self, tiled_copy: TiledCopy, thr_idx): super().__init__(tiled_copy) self.tiled_copy = tiled_copy self._thr_idx = thr_idx - self._thr_idx_int = make_int_tuple(self.thr_idx) + self._thr_idx_int = _coerce_int_tuple(self.thr_idx) @property def thr_idx(self): @@ -64,7 +65,7 @@ def __init__(self, tiled_mma: TiledMma, thr_idx): super().__init__(tiled_mma) self.tiled_mma = tiled_mma self._thr_idx = thr_idx - self._thr_idx_int = make_int_tuple(self.thr_idx) + self._thr_idx_int = _coerce_int_tuple(self.thr_idx) @property def thr_idx(self): @@ -93,8 +94,8 @@ def make_rmem_tensor(shape_or_layout, dtype, *, loc=None, ip=None): tensor = make_rmem_tensor(8, fx.Float32) tensor = make_rmem_tensor(make_layout(4, 1), fx.Float16) """ - if not issubclass(dtype, Numeric): - raise TypeError(f"dtype must be a Numeric type, but got {type(dtype)}") + if not (isinstance(dtype, type) and issubclass(dtype, Numeric)): + raise TypeError(f"dtype must be a Numeric subclass, but got {dtype!r}") elem_ty = dtype.ir_type if dtype is not Boolean else Int8.ir_type if not isinstance(shape_or_layout, Layout): diff --git a/python/flydsl/expr/extern.py b/python/flydsl/expr/extern.py index 1269e8425..acf86ce33 100644 --- a/python/flydsl/expr/extern.py +++ b/python/flydsl/expr/extern.py @@ -20,6 +20,7 @@ DenseI32ArrayAttr, FlatSymbolRefAttr, InsertionPoint, + IntegerAttr, IntegerType, TypeAttr, ) @@ -121,16 +122,18 @@ def __call__(self, *args: Any) -> Any: if len(args) != len(arg_types): raise TypeError(f"ffi {self.symbol!r} expects {len(arg_types)} argument(s), got {len(args)}") - from .._mlir.dialects import llvm as _llvm - from .._mlir.ir import IntegerAttr + from .numeric import Numeric raw_args: List[ir.Value] = [] for arg_pos, arg in enumerate(args): expected_type = arg_types[arg_pos] + if isinstance(arg, Numeric) and isinstance(arg.value, (bool, int)): + arg = int(arg.value) + if isinstance(arg, int): target_type = expected_type or IntegerType.get_signless(64) - raw_args.append(_llvm.ConstantOp(target_type, IntegerAttr.get(target_type, arg)).result) + raw_args.append(llvm.ConstantOp(target_type, IntegerAttr.get(target_type, arg)).result) continue if isinstance(arg, ir.Value): diff --git a/python/flydsl/expr/gpu.py b/python/flydsl/expr/gpu.py index 0686048a1..2aafbaa16 100644 --- a/python/flydsl/expr/gpu.py +++ b/python/flydsl/expr/gpu.py @@ -20,7 +20,7 @@ from .._mlir.dialects import gpu from .._mlir.dialects._fly_enum_gen import AddressSpace from ..compiler.protocol import dsl_align_of, dsl_size_of -from .numeric import Uint8 +from .numeric import Numeric, Uint8 from .primitive import get_dyn_shared, make_ptr from .struct import ( Arena, @@ -104,6 +104,8 @@ def base_ptr(self): return self._base def allocate(self, storable_or_int, alignment=None): + if isinstance(storable_or_int, Numeric) and not isinstance(storable_or_int.value, ir.Value): + storable_or_int = int(storable_or_int.value) if not self._static: return super().allocate(storable_or_int, alignment) return self._allocate_static(storable_or_int, alignment) diff --git a/python/flydsl/expr/math.py b/python/flydsl/expr/math.py index 4cc061fef..e62613fb6 100644 --- a/python/flydsl/expr/math.py +++ b/python/flydsl/expr/math.py @@ -1,36 +1,76 @@ # SPDX-License-Identifier: Apache-2.0 -# Copyright (c) 2025 FlyDSL Project Contributors +# Copyright (c) 2026 FlyDSL Project Contributors -"""Math dialect API — DSL-friendly wrappers with traced locations and auto-unwrap. +"""Math dialect API — thin DSL wrappers over the MLIR ``math`` dialect. Usage: - from flydsl.expr import math + import flydsl.expr as fx - y = math.exp(x) - y = math.sqrt(x, fastmath="fast") - y = math.fma(a, b, c) - pred = math.isnan(x) + y = fx.exp(x) + y = fx.sqrt(x, fastmath="fast") + y = fx.fma(a, b, c) + pred = fx.isnan(x) """ from functools import wraps from .._mlir import ir -from .._mlir.dialects import math as _mlir_math -from .._mlir.dialects.math import * # noqa: F401,F403 -from .meta import _caller_location, _flatten_args +from .._mlir.dialects import math +from .meta import dsl_loc_tracing from .numeric import Numeric -from .utils.arith import _to_raw - - -def _traced_math_op(fn): - """Like @traced_op, but re-wraps results to preserve DslType closure. - - If the first positional arg is a ``Numeric`` (Float32, Int32, …) or a - ``Vector``, the MLIR result is wrapped back into the matching DSL type so - callers stay at the DSL level instead of dropping to raw ir.Value / ArithValue. - Raw ir.Value inputs pass through unchanged. - """ - +from .typing import as_ir_value + +__all__ = [ + "absf", + "ceil", + "floor", + "trunc", + "round", + "roundeven", + "exp", + "exp2", + "expm1", + "log", + "log2", + "log10", + "log1p", + "sqrt", + "rsqrt", + "cbrt", + "sin", + "cos", + "tan", + "asin", + "acos", + "atan", + "sinh", + "cosh", + "tanh", + "asinh", + "acosh", + "atanh", + "erf", + "erfc", + "sincos", + "absi", + "ctlz", + "cttz", + "ctpop", + "powf", + "fpowi", + "ipowi", + "atan2", + "copysign", + "fma", + "clampf", + "isnan", + "isinf", + "isfinite", + "isnormal", +] + + +def dsl_math_wrap_result(fn): @wraps(fn) def wrapper(*args, **kwargs): from .typing import Vector @@ -39,26 +79,22 @@ def wrapper(*args, **kwargs): is_vector = isinstance(first, Vector) is_numeric = isinstance(first, Numeric) - loc = kwargs.pop("loc", None) - if loc is None: - loc = _caller_location(depth=1) - args, kwargs = _flatten_args(args, kwargs) - with loc: - result = fn(*args, **kwargs) + result = fn(*args, **kwargs) if not (is_vector or is_numeric): - return result + return tuple(result) if not isinstance(result, ir.Value) and hasattr(result, "__iter__") else result - def _wrap_arith_type(value): + def dsl_wrap(value): + if not isinstance(value, ir.Value): + return value if is_vector: elem_dtype = Numeric.from_ir_type(ir.VectorType(value.type).element_type) return Vector(value, first.shape, elem_dtype) return Numeric.from_ir_type(value.type)(value) if isinstance(result, ir.Value): - return _wrap_arith_type(result) - # Multi-result (e.g. sincos) - return tuple(_wrap_arith_type(r) for r in result) + return dsl_wrap(result) + return tuple(dsl_wrap(r) for r in result) return wrapper @@ -68,154 +104,184 @@ def _wrap_arith_type(value): # --------------------------------------------------------------------------- -@_traced_math_op -def absf(x, *, fastmath=None, **kw): - return _mlir_math.absf(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def absf(x, *, fastmath=None, **kwargs): + return math.absf(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def ceil(x, *, fastmath=None, **kw): - return _mlir_math.ceil(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def ceil(x, *, fastmath=None, **kwargs): + return math.ceil(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def floor(x, *, fastmath=None, **kw): - return _mlir_math.floor(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def floor(x, *, fastmath=None, **kwargs): + return math.floor(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def trunc(x, *, fastmath=None, **kw): - return _mlir_math.trunc(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def trunc(x, *, fastmath=None, **kwargs): + return math.trunc(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def round(x, *, fastmath=None, **kw): - return _mlir_math.round(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def round(x, *, fastmath=None, **kwargs): + return math.round(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def roundeven(x, *, fastmath=None, **kw): - return _mlir_math.roundeven(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def roundeven(x, *, fastmath=None, **kwargs): + return math.roundeven(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def exp(x, *, fastmath=None, **kw): - return _mlir_math.exp(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def exp(x, *, fastmath=None, **kwargs): + return math.exp(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def exp2(x, *, fastmath=None, **kw): - return _mlir_math.exp2(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def exp2(x, *, fastmath=None, **kwargs): + return math.exp2(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def expm1(x, *, fastmath=None, **kw): - return _mlir_math.expm1(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def expm1(x, *, fastmath=None, **kwargs): + return math.expm1(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def log(x, *, fastmath=None, **kw): - return _mlir_math.log(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def log(x, *, fastmath=None, **kwargs): + return math.log(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def log2(x, *, fastmath=None, **kw): - return _mlir_math.log2(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def log2(x, *, fastmath=None, **kwargs): + return math.log2(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def log10(x, *, fastmath=None, **kw): - return _mlir_math.log10(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def log10(x, *, fastmath=None, **kwargs): + return math.log10(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def log1p(x, *, fastmath=None, **kw): - return _mlir_math.log1p(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def log1p(x, *, fastmath=None, **kwargs): + return math.log1p(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def sqrt(x, *, fastmath=None, **kw): - return _mlir_math.sqrt(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def sqrt(x, *, fastmath=None, **kwargs): + return math.sqrt(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def rsqrt(x, *, fastmath=None, **kw): - return _mlir_math.rsqrt(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def rsqrt(x, *, fastmath=None, **kwargs): + return math.rsqrt(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def cbrt(x, *, fastmath=None, **kw): - return _mlir_math.cbrt(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def cbrt(x, *, fastmath=None, **kwargs): + return math.cbrt(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def sin(x, *, fastmath=None, **kw): - return _mlir_math.sin(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def sin(x, *, fastmath=None, **kwargs): + return math.sin(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def cos(x, *, fastmath=None, **kw): - return _mlir_math.cos(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def cos(x, *, fastmath=None, **kwargs): + return math.cos(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def tan(x, *, fastmath=None, **kw): - return _mlir_math.tan(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def tan(x, *, fastmath=None, **kwargs): + return math.tan(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def asin(x, *, fastmath=None, **kw): - return _mlir_math.asin(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def asin(x, *, fastmath=None, **kwargs): + return math.asin(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def acos(x, *, fastmath=None, **kw): - return _mlir_math.acos(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def acos(x, *, fastmath=None, **kwargs): + return math.acos(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def atan(x, *, fastmath=None, **kw): - return _mlir_math.atan(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def atan(x, *, fastmath=None, **kwargs): + return math.atan(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def sinh(x, *, fastmath=None, **kw): - return _mlir_math.sinh(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def sinh(x, *, fastmath=None, **kwargs): + return math.sinh(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def cosh(x, *, fastmath=None, **kw): - return _mlir_math.cosh(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def cosh(x, *, fastmath=None, **kwargs): + return math.cosh(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def tanh(x, *, fastmath=None, **kw): - return _mlir_math.tanh(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def tanh(x, *, fastmath=None, **kwargs): + return math.tanh(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def asinh(x, *, fastmath=None, **kw): - return _mlir_math.asinh(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def asinh(x, *, fastmath=None, **kwargs): + return math.asinh(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def acosh(x, *, fastmath=None, **kw): - return _mlir_math.acosh(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def acosh(x, *, fastmath=None, **kwargs): + return math.acosh(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def atanh(x, *, fastmath=None, **kw): - return _mlir_math.atanh(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def atanh(x, *, fastmath=None, **kwargs): + return math.atanh(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def erf(x, *, fastmath=None, **kw): - return _mlir_math.erf(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def erf(x, *, fastmath=None, **kwargs): + return math.erf(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def erfc(x, *, fastmath=None, **kw): - return _mlir_math.erfc(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def erfc(x, *, fastmath=None, **kwargs): + return math.erfc(as_ir_value(x), fastmath=fastmath, **kwargs) # --------------------------------------------------------------------------- @@ -223,10 +289,11 @@ def erfc(x, *, fastmath=None, **kw): # --------------------------------------------------------------------------- -@_traced_math_op -def sincos(x, *, fastmath=None, **kw): +@dsl_loc_tracing +@dsl_math_wrap_result +def sincos(x, *, fastmath=None, **kwargs): """Simultaneous sin and cos. Returns ``(sin(x), cos(x))``.""" - return _mlir_math.sincos(_to_raw(x), fastmath=fastmath, **kw) + return math.sincos(as_ir_value(x), fastmath=fastmath, **kwargs) # --------------------------------------------------------------------------- @@ -234,24 +301,28 @@ def sincos(x, *, fastmath=None, **kw): # --------------------------------------------------------------------------- -@_traced_math_op -def absi(x, **kw): - return _mlir_math.absi(_to_raw(x), **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def absi(x, **kwargs): + return math.absi(as_ir_value(x), **kwargs) -@_traced_math_op -def ctlz(x, **kw): - return _mlir_math.ctlz(_to_raw(x), **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def ctlz(x, **kwargs): + return math.ctlz(as_ir_value(x), **kwargs) -@_traced_math_op -def cttz(x, **kw): - return _mlir_math.cttz(_to_raw(x), **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def cttz(x, **kwargs): + return math.cttz(as_ir_value(x), **kwargs) -@_traced_math_op -def ctpop(x, **kw): - return _mlir_math.ctpop(_to_raw(x), **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def ctpop(x, **kwargs): + return math.ctpop(as_ir_value(x), **kwargs) # --------------------------------------------------------------------------- @@ -259,29 +330,34 @@ def ctpop(x, **kw): # --------------------------------------------------------------------------- -@_traced_math_op -def powf(base, exp, *, fastmath=None, **kw): - return _mlir_math.powf(_to_raw(base), _to_raw(exp), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def powf(base, exp, *, fastmath=None, **kwargs): + return math.powf(as_ir_value(base), as_ir_value(exp), fastmath=fastmath, **kwargs) -@_traced_math_op -def fpowi(base, exp, *, fastmath=None, **kw): - return _mlir_math.fpowi(_to_raw(base), _to_raw(exp), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def fpowi(base, exp, *, fastmath=None, **kwargs): + return math.fpowi(as_ir_value(base), as_ir_value(exp), fastmath=fastmath, **kwargs) -@_traced_math_op -def ipowi(base, exp, **kw): - return _mlir_math.ipowi(_to_raw(base), _to_raw(exp), **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def ipowi(base, exp, **kwargs): + return math.ipowi(as_ir_value(base), as_ir_value(exp), **kwargs) -@_traced_math_op -def atan2(y, x, *, fastmath=None, **kw): - return _mlir_math.atan2(_to_raw(y), _to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def atan2(y, x, *, fastmath=None, **kwargs): + return math.atan2(as_ir_value(y), as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def copysign(mag, sign, *, fastmath=None, **kw): - return _mlir_math.copysign(_to_raw(mag), _to_raw(sign), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def copysign(mag, sign, *, fastmath=None, **kwargs): + return math.copysign(as_ir_value(mag), as_ir_value(sign), fastmath=fastmath, **kwargs) # --------------------------------------------------------------------------- @@ -289,36 +365,42 @@ def copysign(mag, sign, *, fastmath=None, **kw): # --------------------------------------------------------------------------- -@_traced_math_op -def fma(a, b, c, *, fastmath=None, **kw): - return _mlir_math.fma(_to_raw(a), _to_raw(b), _to_raw(c), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def fma(a, b, c, *, fastmath=None, **kwargs): + return math.fma(as_ir_value(a), as_ir_value(b), as_ir_value(c), fastmath=fastmath, **kwargs) -@_traced_math_op -def clampf(x, lo, hi, *, fastmath=None, **kw): - return _mlir_math.clampf(_to_raw(x), _to_raw(lo), _to_raw(hi), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def clampf(x, lo, hi, *, fastmath=None, **kwargs): + return math.clampf(as_ir_value(x), as_ir_value(lo), as_ir_value(hi), fastmath=fastmath, **kwargs) # --------------------------------------------------------------------------- -# Predicates (return i1) +# Predicates :: Float -> Boolean # --------------------------------------------------------------------------- -@_traced_math_op -def isnan(x, *, fastmath=None, **kw): - return _mlir_math.isnan(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def isnan(x, *, fastmath=None, **kwargs): + return math.isnan(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def isinf(x, *, fastmath=None, **kw): - return _mlir_math.isinf(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def isinf(x, *, fastmath=None, **kwargs): + return math.isinf(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def isfinite(x, *, fastmath=None, **kw): - return _mlir_math.isfinite(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def isfinite(x, *, fastmath=None, **kwargs): + return math.isfinite(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def isnormal(x, *, fastmath=None, **kw): - return _mlir_math.isnormal(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def isnormal(x, *, fastmath=None, **kwargs): + return math.isnormal(as_ir_value(x), fastmath=fastmath, **kwargs) diff --git a/python/flydsl/expr/meta.py b/python/flydsl/expr/meta.py index eb5e24f9f..68457ef1a 100644 --- a/python/flydsl/expr/meta.py +++ b/python/flydsl/expr/meta.py @@ -7,6 +7,7 @@ from .._mlir import ir +# TODO: remove this in the future. def _to_raw_value(obj): if isinstance(obj, ir.Value): return obj @@ -24,6 +25,7 @@ def _to_raw_value(obj): return obj +# TODO: remove this in the future. def _flatten_args(args, kwargs): new_args = tuple(_to_raw_value(a) for a in args) new_kwargs = {k: _to_raw_value(v) if k not in ("loc", "ip") else v for k, v in kwargs.items()} @@ -52,6 +54,7 @@ def _caller_location(depth=1): return ir.Location.name(label, childLoc=file_loc) +# TODO: remove this in the future. def traced_op(op): @wraps(op) def wrapper(*args, **kwargs): @@ -63,3 +66,44 @@ def wrapper(*args, **kwargs): return op(*args, **kwargs) return wrapper + + +def dsl_loc_tracing(op): + """Capture the caller's Python source position as an MLIR Location + + TODO: enhance this in the recent changes. loc is missed in the op arguments. + """ + + @wraps(op) + def wrapper(*args, **kwargs): + loc = kwargs.pop("loc", None) + if loc is None: + loc = _caller_location(depth=1) + with loc: + return op(*args, **kwargs) + + return wrapper + + +def dsl_wrap_result(target=None): + """Wrap the op result(s) back into DslType values. + + - ``target=None`` (default): dispatch by the result's ``ir.Type``. + - ``target=SomeClass``: force ``SomeClass(value)`` — useful when the result + type cannot be uniquely determined from the ``ir.Type`` (vectors, …). + + Multi-value returns (tuples / lists) are wrapped element-wise. + """ + + def decorator(op, target): + @wraps(op) + def wrapper(*args, **kwargs): + from .typing import as_dsl_value + + return as_dsl_value(op(*args, **kwargs), target) + + return wrapper + + if inspect.isfunction(target): + return decorator(target, None) + return lambda op: decorator(op, target) diff --git a/python/flydsl/expr/numeric.py b/python/flydsl/expr/numeric.py index 8fe4ce0aa..0245e190c 100644 --- a/python/flydsl/expr/numeric.py +++ b/python/flydsl/expr/numeric.py @@ -10,7 +10,6 @@ from .._mlir import ir from .._mlir.dialects import arith from .._mlir.extras import types as T -from ..utils import log from .utils.arith import ( ArithValue, _to_raw, @@ -171,14 +170,17 @@ def zero(cls): _CMP_OPS = frozenset({operator.lt, operator.le, operator.gt, operator.ge, operator.eq, operator.ne}) -def _widen_narrow_int(x, widen_bool=False): - """Promote sub-32-bit integers (and optionally bools) to i32.""" - ty = type(x) - if ty is Boolean and not widen_bool: - return x, ty - if ty.is_integer and ty.width < 32: +def _widen_bool_to_int32(x, widen_bool=False): + """Promote Boolean to Int32 for arithmetic when widen_bool=True. + + Per C++-style usual arithmetic conversions, we deliberately do NOT apply + integer promotion: i8/i16/u8/u16 stay at their narrow width. + Same-width same-signedness operands keep their type; cross-width or + cross-sign mixing is resolved by ``_coerce_operands``. + """ + if widen_bool and type(x) is Boolean: return x.to(Int32), Int32 - return x, ty + return x, type(x) def _resolve_float_type(ta, tb): @@ -205,8 +207,8 @@ def _resolve_float_type(ta, tb): def _coerce_operands(a, b, widen_bool=False): """Promote *a* and *b* to a common scalar type.""" ta, tb = type(a), type(b) - a, ta = _widen_narrow_int(a, widen_bool=widen_bool) - b, tb = _widen_narrow_int(b, widen_bool=widen_bool) + a, ta = _widen_bool_to_int32(a, widen_bool=widen_bool) + b, tb = _widen_bool_to_int32(b, widen_bool=widen_bool) if ta is tb: return a, b, ta @@ -247,48 +249,6 @@ def _extract_arith(val, signed): return v.with_signedness(signed) if isinstance(v, ArithValue) else v -def _unwrap_value(value): - """Convert FlyDSL wrappers to raw MLIR values when possible.""" - if isinstance(value, ir.Value): - return value - if isinstance(value, (bool, int, float)): - try: - return as_numeric(value).ir_value() - except Exception: - log().error(f"failed to construct {as_numeric(value)} from {value}") - return value - if hasattr(value, "__extract_to_ir_values__"): - values = value.__extract_to_ir_values__() - if len(values) == 1: - return values[0] - if hasattr(value, "ir_value"): - return value.ir_value() - return value - - -def _wrap_like(value, exemplar=None): - """Wrap an MLIR value back to a FlyDSL wrapper when possible.""" - if not isinstance(value, ir.Value): - return value - - if exemplar is not None: - if isinstance(exemplar, Numeric): - return type(exemplar)(value) - ctor = getattr(type(exemplar), "__construct_from_ir_values__", None) - if ctor is not None: - try: - return ctor([value]) - except Exception: - log().error(f"failed to construct {type(exemplar)} from {value}") - return value - - try: - return Numeric.from_ir_type(value.type)(value) - except Exception: - log().error(f"failed to construct {Numeric.from_ir_type(value.type)} from {value}") - return value - - def _make_binop(op, promote=True, widen_bool=False, swap=False): """Create a binary-operator closure for Numeric subclasses.""" @@ -331,7 +291,10 @@ def __hash__(self): def select(self, true_value, false_value, *, loc=None): """Ternary select (for Boolean conditions from Int32 comparisons).""" - return ArithValue(self).select(true_value, false_value, loc=loc) + from .typing import as_dsl_value + + result = ArithValue(self).select(true_value, false_value, loc=loc) + return as_dsl_value(result, true_value) @classmethod def __coerce__(cls, value): @@ -453,6 +416,9 @@ def from_ir_type(ir_type): T.ui32(): Uint32, T.ui16(): Uint16, T.ui8(): Uint8, + T.i(128): Int128, + T.si(128): Int128, + T.ui(128): Uint128, T.f8E5M2(): Float8E5M2, T.f8E4M3(): Float8E4M3, T.f8E4M3FN(): Float8E4M3FN, @@ -552,6 +518,13 @@ def __gt__(self, other, *, loc=None, ip=None): def __ge__(self, other, *, loc=None, ip=None): return _make_binop(operator.ge)(self, other, loc=loc, ip=ip) + def bitcast(self, dtype, *, loc=None, ip=None): + """Reinterpret this value's bits as *dtype* (a same-width Numeric type).""" + if not (isinstance(dtype, type) and issubclass(dtype, Numeric)): + raise TypeError(f"dtype must be a Numeric subclass, but got {dtype!r}") + res = arith.BitcastOp(dtype.ir_type, self.ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + return dtype(res, loc=loc, ip=ip) + def as_numeric(obj): if isinstance(obj, Numeric): @@ -717,6 +690,11 @@ class Int64(Integer, metaclass=NumericMeta, width=64, signed=True, ir_type=T.i64 pass +class Int128(Integer, metaclass=NumericMeta, width=128, signed=True, ir_type=lambda: T.i(128)): + def __get_c_pointers__(self): + raise TypeError("Int128 is not a JitArgument for now. ctypes has no support for 128b integers.") + + class Uint8(Integer, metaclass=NumericMeta, width=8, signed=False, ir_type=T.i8): pass @@ -733,6 +711,11 @@ class Uint64(Integer, metaclass=NumericMeta, width=64, signed=False, ir_type=T.i pass +class Uint128(Integer, metaclass=NumericMeta, width=128, signed=False, ir_type=lambda: T.i(128)): + def __get_c_pointers__(self): + raise TypeError("Uint128 is not a JitArgument for now. ctypes has no support for 128b integers.") + + class Float16(Float, metaclass=NumericMeta, width=16, ir_type=T.f16): def __get_c_pointers__(self): if not isinstance(self.value, float): diff --git a/python/flydsl/expr/primitive.py b/python/flydsl/expr/primitive.py index 831330e32..735bb56c8 100644 --- a/python/flydsl/expr/primitive.py +++ b/python/flydsl/expr/primitive.py @@ -1,7 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright (c) 2025 FlyDSL Project Contributors +import inspect from enum import IntEnum +from functools import wraps from typing import overload from .._mlir import ir @@ -33,7 +35,7 @@ has_none, ) from .._mlir.extras import types as T -from .meta import traced_op +from .meta import dsl_loc_tracing, dsl_wrap_result __all__ = [ # Maybe remove it in the future @@ -217,14 +219,19 @@ def _is_int_tuple_value(value): def _expand_int_tuple_leaves(value, loc=None, ip=None): - from .numeric import Numeric + from .numeric import Int32, Numeric if _is_int_tuple_value(value): return _expand_int_tuple_leaves(value.to_py_value(loc=loc, ip=ip)) if isinstance(value, (list, tuple)): return tuple(_expand_int_tuple_leaves(v, loc=loc, ip=ip) for v in value) + # widen narrow dynamic ints to i32 if isinstance(value, Numeric): + if isinstance(value.value, ir.Value) and type(value).width < 32: + return Int32(value, loc=loc, ip=ip).value return value.value + if isinstance(value, ir.Value) and isinstance(value.type, ir.IntegerType) and value.type.width < 32: + return Int32(value, loc=loc, ip=ip).value return value @@ -247,14 +254,45 @@ def _check_profile(match_func, lhs, rhs): raise ValueError(f"profile mismatch: {match_func.__name__}({lhs.type}, {rhs.type}) is False") -def _wrap_numeric_type(value): - from .numeric import Numeric +# ---- IntTuple covariance ---- +# Covariance rules (Python value → fly.IntTuple): +# int <: fly.IntTuple (leaf) +# Numeric <: fly.IntTuple (leaf, e.g. Int32(5)) +# tuple(X1, ...) <: fly.IntTuple<(X1, ...)> (non-leaf; tuple is constructor) +# fly.IntTuple <: fly.IntTuple (trivial) - if not isinstance(value, ir.Value): - return value - if isinstance(value, Numeric): - return value - return Numeric.from_ir_type(value.type)(value) + +def _coerce_int_tuple(v): + if _is_int_tuple_value(v): + return v + return make_int_tuple(v) + + +def _coerce_int_tuple_permissive(v): + if isinstance(v, ir.Value): + return v + return make_int_tuple(v) + + +def coerce_int_tuple_args(*arg_names, permissive=False): + coerce = _coerce_int_tuple_permissive if permissive else _coerce_int_tuple + + def decorator(fn): + sig = inspect.signature(fn) + + @wraps(fn) + def wrapper(*args, **kwargs): + bound = sig.bind_partial(*args, **kwargs) + for name in arg_names: + v = bound.arguments.get(name) + if v is None: + continue + bound.arguments[name] = coerce(v) + return fn(*bound.args, **bound.kwargs) + + return wrapper + + return decorator # ===----------------------------------------------------------------------=== # @@ -310,7 +348,7 @@ def depth(int_or_tuple): # ===----------------------------------------------------------------------=== # -@traced_op +@dsl_loc_tracing def static(result_type, loc=None, ip=None): """Materialize a value whose entire content is encoded in *result_type*. @@ -324,7 +362,7 @@ def static(result_type, loc=None, ip=None): return fly.static(result_type, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def make_int_tuple(elems, loc=None, ip=None): """Build a (possibly nested) integer tuple from Python ints or runtime values. @@ -338,7 +376,7 @@ def make_int_tuple(elems, loc=None, ip=None): return fly.make_int_tuple(IntTupleTy, dyncElems, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def make_shape(*shape, loc=None, ip=None): """Build a shape tuple describing the extent of each mode. @@ -352,7 +390,7 @@ def make_shape(*shape, loc=None, ip=None): return fly.make_shape(IntTupleTy, dyncElems, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def make_stride(*stride, loc=None, ip=None): """Build a stride tuple: the step (in elements) when moving along each mode. @@ -366,7 +404,7 @@ def make_stride(*stride, loc=None, ip=None): return fly.make_stride(IntTupleTy, dyncElems, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def make_coord(*coord, loc=None, ip=None): """Build a coordinate used for indexing / slicing a layout. @@ -380,7 +418,7 @@ def make_coord(*coord, loc=None, ip=None): return fly.make_coord(IntTupleTy, dyncElems, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def make_layout(shape, stride, loc=None, ip=None): """Pair a *shape* with a *stride* to describe how logical coords map to memory. @@ -391,20 +429,20 @@ def make_layout(shape, stride, loc=None, ip=None): make_layout((4, 8), (1, 4)) -> ((4, 8), (1, 4)) make_layout((4, 8), (8, 1)) -> ((4, 8), (8, 1)) """ - if not isinstance(shape, ir.Value): + if not _is_int_tuple_value(shape): shape = make_int_tuple(shape, loc=loc, ip=ip) - if not isinstance(stride, ir.Value): + if not _is_int_tuple_value(stride): stride = make_int_tuple(stride, loc=loc, ip=ip) _check_profile(is_profile_congruent, shape, stride) return fly.make_layout(shape, stride=stride, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def make_layout_like(ref, loc=None, ip=None): return fly.make_layout_like(ref, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def make_ordered_layout(shape, order, loc=None, ip=None): """Build a compact layout whose stride order matches *order*. @@ -415,9 +453,9 @@ def make_ordered_layout(shape, order, loc=None, ip=None): make_ordered_layout((M, N), (0, 1)) # column-major: M iterates fastest make_ordered_layout((M, N), (1, 0)) # row-major: N iterates fastest """ - if not isinstance(shape, ir.Value): + if not _is_int_tuple_value(shape): shape = make_int_tuple(shape, loc=loc, ip=ip) - if not isinstance(order, ir.Value): + if not _is_int_tuple_value(order): order = make_int_tuple(order, loc=loc, ip=ip) _check_profile(is_profile_weakly_congruent, order, shape) return fly.make_ordered_layout(shape, order, loc=loc, ip=ip) @@ -427,7 +465,7 @@ def make_ordered_layout(shape, order, loc=None, ip=None): def make_composed_layout(inner, offset, outer, loc=None, ip=None): ... @overload def make_composed_layout(inner, outer, loc=None, ip=None): ... -@traced_op +@dsl_loc_tracing def make_composed_layout(inner, offset_or_outer, outer=None, loc=None, ip=None): """Stack two layouts: a coord is first mapped by *outer*, then by *inner*. @@ -443,12 +481,12 @@ def make_composed_layout(inner, offset_or_outer, outer=None, loc=None, ip=None): offset = coprofile(outer, loc=loc, ip=ip) else: offset = offset_or_outer - if not isinstance(offset, ir.Value): + if not _is_int_tuple_value(offset): offset = make_int_tuple(offset, loc=loc, ip=ip) return fly.make_composed_layout(inner, offset, outer, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def make_identity_layout(shape, loc=None, ip=None): """Build the identity layout in FlyDSL's layout-algebra sense. @@ -459,22 +497,22 @@ def make_identity_layout(shape, loc=None, ip=None): Examples: make_identity_layout((4, 8)) -> ((4, 8), (1E0, 1E1)) """ - if not isinstance(shape, ir.Value): + if not _is_int_tuple_value(shape): shape = make_int_tuple(shape, loc=loc, ip=ip) return fly.make_identity_layout(shape, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def make_view(iter, layout, loc=None, ip=None): return fly.make_view(iter, layout, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def make_fragment_layout_like(tensor, loc=None, ip=None): return fly.make_fragment_layout_like(tensor, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def make_fragment_like(tensor, dtype=None, loc=None, ip=None): if hasattr(dtype, "ir_type"): dtype = dtype.ir_type @@ -486,7 +524,8 @@ def make_fragment_like(tensor, dtype=None, loc=None, ip=None): # ===----------------------------------------------------------------------=== # -@traced_op +@dsl_loc_tracing +@dsl_wrap_result def get_scalar(int_tuple, loc=None, ip=None): """Unwrap a rank-1, single-element tuple back to a plain scalar value. @@ -501,10 +540,11 @@ def get_scalar(int_tuple, loc=None, ip=None): return int_tuple if int_tuple.is_leaf and int_tuple.is_static: return int_tuple.get_static_leaf_int - return _wrap_numeric_type(fly.get_scalar(int_tuple, loc=loc, ip=ip)) + return fly.get_scalar(int_tuple, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@dsl_wrap_result def get_leaves(input, dynamic_only=False, loc=None, ip=None): """Flatten an IntTuple into a flat sequence of leaf values. @@ -517,7 +557,7 @@ def get_leaves(input, dynamic_only=False, loc=None, ip=None): """ if dynamic_only: res_lists = fly.GetLeavesOp(input, dynamicOnly=True, loc=loc, ip=ip) - return tuple(_wrap_numeric_type(r) for r in res_lists.results) + return tuple(res_lists.results) def _walk_int_tuple_leaves(ty): if ty.is_leaf: @@ -534,41 +574,41 @@ def _walk_int_tuple_leaves(ty): if leaf_ty.is_static: out.append(leaf_ty.get_static_leaf_int) else: - out.append(_wrap_numeric_type(next(dyn_iter))) + out.append(next(dyn_iter)) return tuple(out) -@traced_op +@dsl_loc_tracing def get_shape(layout, loc=None, ip=None): return fly.get_shape(layout, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def get_stride(layout, loc=None, ip=None): return fly.get_stride(layout, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def get_layout(memref, loc=None, ip=None): return fly.get_layout(memref, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def get_iter(memref, loc=None, ip=None): return fly.get_iter(memref, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def composed_get_inner(input, loc=None, ip=None): return fly.composed_get_inner(input, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def composed_get_offset(input, loc=None, ip=None): return fly.composed_get_offset(input, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def composed_get_outer(input, loc=None, ip=None): return fly.composed_get_outer(input, loc=loc, ip=ip) @@ -578,62 +618,76 @@ def composed_get_outer(input, loc=None, ip=None): # ===----------------------------------------------------------------------=== # -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("lhs", "rhs") def int_tuple_add(lhs, rhs, loc=None, ip=None): return fly.int_tuple_add(lhs, rhs, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("lhs", "rhs") def int_tuple_sub(lhs, rhs, loc=None, ip=None): return fly.int_tuple_sub(lhs, rhs, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("lhs", "rhs") def int_tuple_mul(lhs, rhs, loc=None, ip=None): return fly.int_tuple_mul(lhs, rhs, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("lhs", "rhs") def int_tuple_div(lhs, rhs, loc=None, ip=None): return fly.int_tuple_div(lhs, rhs, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("lhs", "rhs") def int_tuple_mod(lhs, rhs, loc=None, ip=None): return fly.int_tuple_mod(lhs, rhs, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("int_tuple") def int_tuple_product(int_tuple, loc=None, ip=None): return fly.int_tuple_product(int_tuple, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("int_tuple") def int_tuple_product_each(int_tuple, loc=None, ip=None): return fly.int_tuple_product_each(int_tuple, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("lhs", "rhs") def int_tuple_product_like(lhs, rhs, loc=None, ip=None): return fly.int_tuple_product_like(lhs, rhs, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("lhs", "rhs") def shape_div(lhs, rhs, loc=None, ip=None): return fly.shape_div(lhs, rhs, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("lhs", "rhs") def ceil_div(lhs, rhs, loc=None, ip=None): return fly.ceil_div(lhs, rhs, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@dsl_wrap_result +@coerce_int_tuple_args("lhs", "rhs") def elem_less(lhs, rhs, loc=None, ip=None): return fly.elem_less(lhs, rhs, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@dsl_wrap_result +@coerce_int_tuple_args("lhs", "rhs") def equal(lhs, rhs, loc=None, ip=None): return fly.equal(lhs, rhs, loc=loc, ip=ip) @@ -643,7 +697,7 @@ def equal(lhs, rhs, loc=None, ip=None): # ===----------------------------------------------------------------------=== # -@traced_op +@dsl_loc_tracing def get(int_tuple, mode, loc=None, ip=None): if isinstance(int_tuple, (list, tuple)): return int_tuple[mode] @@ -654,39 +708,45 @@ def get(int_tuple, mode, loc=None, ip=None): return result -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("int_tuple") def get_(int_tuple, mode, loc=None, ip=None): if isinstance(mode, int): mode = [mode] return fly.get(int_tuple, mode, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("int_tuple") def take(int_tuple, begin: int, end: int, loc=None, ip=None): return fly.take(int_tuple, begin=begin, end=end, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("int_tuple") def select(int_tuple, indices, loc=None, ip=None): return fly.select(int_tuple, indices=indices, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("int_tuple") def group(int_tuple, begin: int, end: int, loc=None, ip=None): return fly.group(int_tuple, begin=begin, end=end, loc=loc, ip=ip) -@traced_op -def append(base, elem, n: int | None = None, loc=None, ip=None): +@dsl_loc_tracing +@coerce_int_tuple_args("base", "elem", permissive=True) +def append(base, elem, *, n: int | None = None, loc=None, ip=None): return fly.append(base, elem, n=n, loc=loc, ip=ip) -@traced_op -def prepend(base, elem, n: int | None = None, loc=None, ip=None): +@dsl_loc_tracing +@coerce_int_tuple_args("base", "elem", permissive=True) +def prepend(base, elem, *, n: int | None = None, loc=None, ip=None): return fly.prepend(base, elem, n=n, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def slice(src, coord, loc=None, ip=None): """Keep the modes where *coord* has `None` (wildcard), drop the rest. @@ -697,13 +757,13 @@ def slice(src, coord, loc=None, ip=None): slice((4, 8, 16), (None, 3, None)) -> (4, 16) # mode 1 fixed, dropped slice(layout, make_coord(None, bid)) -> sub-layout for column `bid` """ - if not isinstance(coord, ir.Value): + if not _is_int_tuple_value(coord): coord = make_int_tuple(coord, loc=loc, ip=ip) _check_profile(is_profile_weakly_congruent, coord, src) return fly.slice(src, coord, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def dice(src, coord, loc=None, ip=None): """Complement of `slice`: keep the *fixed* modes, drop the `None` (wildcard) ones. @@ -713,7 +773,7 @@ def dice(src, coord, loc=None, ip=None): dice((4, 8, 16), (None, 3, None)) -> (8,) dice(coord_tensor, make_coord(tid, None)) -> the thread-only part """ - if not isinstance(coord, ir.Value): + if not _is_int_tuple_value(coord): coord = make_int_tuple(coord, loc=loc, ip=ip) _check_profile(is_profile_weakly_congruent, coord, src) return fly.dice(src, coord, loc=loc, ip=ip) @@ -724,34 +784,28 @@ def dice(src, coord, loc=None, ip=None): # ===----------------------------------------------------------------------=== # -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("int_tuple", permissive=True) def size(int_tuple, loc=None, ip=None): return fly.size(int_tuple, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def coprofile(layout, loc=None, ip=None): return fly.coprofile(layout, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def coshape(layout, loc=None, ip=None): return fly.coshape(layout, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def cosize(layout, loc=None, ip=None): return fly.cosize(layout, loc=loc, ip=ip) -def _to_i32(v): - """Cast index-type ir.Value to i32 (required by fly.make_int_tuple).""" - if isinstance(v, ir.Value) and isinstance(v.type, ir.IndexType): - return _arith.IndexCastOp(T.i32(), v).result - return v - - -@traced_op +@dsl_loc_tracing def crd2idx(crd, layout, loc=None, ip=None): """Map a coordinate tuple to an index through *layout*. @@ -763,15 +817,13 @@ def crd2idx(crd, layout, loc=None, ip=None): crd2idx((1, 2), make_layout((4, 8), (1, 4))) -> 9 crd2idx(7, make_layout((4, 8), (1, 4))) -> 7 """ - if not isinstance(crd, ir.Value): - if isinstance(crd, (list, tuple)): - crd = tuple(_to_i32(c) for c in crd) + if not _is_int_tuple_value(crd): crd = make_int_tuple(crd, loc=loc, ip=ip) _check_profile(is_profile_weakly_congruent, crd, layout) return fly.crd2idx(crd, layout, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def idx2crd(index, layout, loc=None, ip=None): """Map an index back to a coordinate tuple for a plain `Layout`. @@ -783,15 +835,12 @@ def idx2crd(index, layout, loc=None, ip=None): idx2crd(9, make_layout((4, 8), (1, 4))) -> (1, 2) idx2crd(5, make_layout((4, 8), (8, 1))) -> (0, 5) """ - if isinstance(index, ir.Value) and not isinstance(index.type, IntTupleType): - index = _to_i32(index) - index = make_int_tuple(index, loc=loc, ip=ip) - if not isinstance(index, ir.Value): + if not _is_int_tuple_value(index): index = make_int_tuple(index, loc=loc, ip=ip) return fly.idx2crd(index, layout, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def get_flat_coord(index, layout, loc=None, ip=None): """Map an index to a *fully flattened* coordinate, ignoring nested grouping. @@ -802,12 +851,12 @@ def get_flat_coord(index, layout, loc=None, ip=None): get_flat_coord(9, make_layout((4, 8), (1, 4))) -> (1, 2) get_flat_coord(3, make_layout(((2, 2), 4), ((1, 2), 4))) -> (1, 1, 0) """ - if not isinstance(index, ir.Value): + if not _is_int_tuple_value(index): index = make_int_tuple(index, loc=loc, ip=ip) return fly.get_flat_coord(index, layout, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def get_1d_coord(index, layout, loc=None, ip=None): """Map an index to a single 1-D coordinate in the layout's shape space. @@ -815,97 +864,104 @@ def get_1d_coord(index, layout, loc=None, ip=None): get_1d_coord(9, make_layout((4, 8), (1, 4))) -> 9 get_1d_coord(5, make_layout((4, 8), (8, 1))) -> 20 """ - if not isinstance(index, ir.Value): + if not _is_int_tuple_value(index): index = make_int_tuple(index, loc=loc, ip=ip) return fly.get_1d_coord(index, layout, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("pattern") def coalesce(layout, pattern=None, loc=None, ip=None): return fly.coalesce(layout, pattern=pattern, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("tiler", permissive=True) def composition(layout, tiler, loc=None, ip=None): return fly.composition(layout, tiler, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("codomain_size") def complement(layout, codomain_size=None, loc=None, ip=None): - if codomain_size is not None and not isinstance(codomain_size, ir.Value): - codomain_size = make_int_tuple(codomain_size, loc=loc, ip=ip) return fly.complement(layout, codomain_size=codomain_size, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def right_inverse(layout, loc=None, ip=None): return fly.right_inverse(layout, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def left_inverse(layout, loc=None, ip=None): return fly.left_inverse(layout, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def logical_divide(layout, divisor, loc=None, ip=None): if not isinstance(divisor, ir.Value): divisor = make_tile(*divisor, loc=loc, ip=ip) return fly.logical_divide(layout, divisor, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def zipped_divide(layout, divisor, loc=None, ip=None): if not isinstance(divisor, ir.Value): divisor = make_tile(*divisor, loc=loc, ip=ip) return fly.zipped_divide(layout, divisor, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def tiled_divide(layout, divisor, loc=None, ip=None): if not isinstance(divisor, ir.Value): divisor = make_tile(*divisor, loc=loc, ip=ip) return fly.tiled_divide(layout, divisor, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def flat_divide(layout, divisor, loc=None, ip=None): if not isinstance(divisor, ir.Value): divisor = make_tile(*divisor, loc=loc, ip=ip) return fly.flat_divide(layout, divisor, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("tiler", permissive=True) def logical_product(layout, tiler, loc=None, ip=None): return fly.logical_product(layout, tiler, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("tiler", permissive=True) def zipped_product(layout, tiler, loc=None, ip=None): return fly.zipped_product(layout, tiler, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("tiler", permissive=True) def tiled_product(layout, tiler, loc=None, ip=None): return fly.tiled_product(layout, tiler, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("tiler", permissive=True) def flat_product(layout, tiler, loc=None, ip=None): return fly.flat_product(layout, tiler, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("tiler", permissive=True) def blocked_product(layout, tiler, loc=None, ip=None): return fly.blocked_product(layout, tiler, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("tiler", permissive=True) def raked_product(layout, tiler, loc=None, ip=None): return fly.raked_product(layout, tiler, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def recast_layout(layout, old_type_bits, new_type_bits, loc=None, ip=None): def _to_static_bits(v): if isinstance(v, int): @@ -921,7 +977,8 @@ def _to_static_bits(v): return fly.recast_layout(new_type_bits=new_type_bits, old_type_bits=old_type_bits, src=layout, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("trg_shape", "ord_shape") def tile_to_shape(block, trg_shape, ord_shape, loc=None, ip=None): return fly.tile_to_shape(block, trg_shape, ord_shape, loc=loc, ip=ip) @@ -931,13 +988,13 @@ def tile_to_shape(block, trg_shape, ord_shape, loc=None, ip=None): # ===----------------------------------------------------------------------=== # -@traced_op +@dsl_loc_tracing def make_mma_atom(mma_op_type, loc=None, ip=None): mma_atom_ty = MmaAtomType.get(mma_op=mma_op_type) return fly.make_mma_atom(mma_atom_ty, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def make_copy_atom(copy_op_type, elem_type, loc=None, ip=None): from .numeric import NumericMeta @@ -956,73 +1013,79 @@ def make_copy_atom(copy_op_type, elem_type, loc=None, ip=None): return fly.make_copy_atom(copy_atom_ty, val_bits=val_bits, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def atom_set_value(atom, field, value, loc=None, ip=None): + from .typing import as_ir_value + if isinstance(field, IntEnum): field = str(field) - return fly.atom_set_value(atom, field, value, loc=loc, ip=ip) + return fly.atom_set_value(atom, field, as_ir_value(value), loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def copy_atom_call(copy_atom, src, dst, *, pred=None, loc=None, ip=None): return fly.copy_atom_call(copy_atom, src, dst, pred=pred, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def mma_atom_call(mma_atom, d, a, b, c, loc=None, ip=None): return fly.mma_atom_call(mma_atom, d, a, b, c, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def make_tiled_copy(copy_atom, layout_thr_val, tile_mn, loc=None, ip=None): if not isinstance(tile_mn, ir.Value): tile_mn = make_tile(*tile_mn, loc=loc, ip=ip) return fly.make_tiled_copy(copy_atom, layout_thr_val, tile_mn, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def make_tiled_mma(mma_atom, atom_layout, permutation=None, loc=None, ip=None): if permutation is not None and not isinstance(permutation, ir.Value): permutation = make_tile(*permutation, loc=loc, ip=ip) return fly.make_tiled_mma(mma_atom, atom_layout, permutation=permutation, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("thr_int_tuple") def tiled_copy_partition_src(tiled_copy, src, thr_int_tuple, loc=None, ip=None): return fly.tiled_copy_partition_src(tiled_copy, src, thr_int_tuple, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("thr_int_tuple") def tiled_copy_partition_dst(tiled_copy, dst, thr_int_tuple, loc=None, ip=None): return fly.tiled_copy_partition_dst(tiled_copy, dst, thr_int_tuple, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def tiled_copy_retile(tiled_copy, t, loc=None, ip=None): return fly.tiled_copy_retile(tiled_copy, t, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("coord") def tiled_mma_partition(operand_id, tiled_mma, t, coord, loc=None, ip=None): return fly.tiled_mma_partition(operand_id, tiled_mma, t, coord, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("shape") def tiled_mma_partition_shape(operand_id, tiled_mma, shape, loc=None, ip=None): return fly.tiled_mma_partition_shape(operand_id, tiled_mma, shape, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def mma_make_fragment(operand_id, tiled_mma, input, *, stages=None, loc=None, ip=None): return fly.mma_make_fragment(operand_id, tiled_mma, input, stages=stages, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def copy(copy_atom, src, dst, *, pred=None, loc=None, ip=None, **kwargs): return fly.copy(copy_atom.set_value(kwargs), src, dst, pred=pred, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def gemm(mma_atom, d, a, b, c, *, traversal_order=None, traversal_layout=None, loc=None, ip=None, **kwargs): if traversal_order is not None and traversal_layout is not None: raise ValueError("Only one of 'traversal_order' or 'traversal_layout' can be specified, not both") @@ -1044,7 +1107,7 @@ def gemm(mma_atom, d, a, b, c, *, traversal_order=None, traversal_layout=None, l # ===----------------------------------------------------------------------=== # -@traced_op +@dsl_loc_tracing def make_ptr(result_type, args, *, dict_attrs=None, loc=None, ip=None): result = fly.make_ptr(result_type, args, loc=loc, ip=ip) if dict_attrs is not None: @@ -1052,7 +1115,7 @@ def make_ptr(result_type, args, *, dict_attrs=None, loc=None, ip=None): return result -@traced_op +@dsl_loc_tracing def get_dyn_shared(dtype=None, loc=None, ip=None): """Return a pointer to the start of the kernel's dynamic shared-memory buffer. @@ -1066,20 +1129,21 @@ def get_dyn_shared(dtype=None, loc=None, ip=None): return recast_iter(dtype, raw_ptr) -@traced_op +@dsl_loc_tracing def inttoptr(result_type, src, loc=None, ip=None): """Interpret an integer address *src* as a pointer of *result_type*. Requirement: ptr.address_space != Register """ - from .typing import is_generic_address_space + from .typing import as_ir_value, is_generic_address_space if is_generic_address_space(result_type.address_space, AddressSpace.Register): raise ValueError("inttoptr is not supported for register address space") - return fly.inttoptr(result_type, src, loc=loc, ip=ip) + return fly.inttoptr(result_type, as_ir_value(src), loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@dsl_wrap_result def ptrtoint(ptr, loc=None, ip=None): """Get the raw integer address underlying *ptr*. @@ -1092,10 +1156,10 @@ def ptrtoint(ptr, loc=None, ip=None): if is_generic_address_space(ptr.address_space, AddressSpace.Register): raise ValueError("ptrtoint is not supported for register address space") - return _wrap_numeric_type(fly.ptrtoint(ptr, loc=loc, ip=ip)) + return fly.ptrtoint(ptr, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def add_offset(ptr, offset, loc=None, ip=None): """Shift *ptr* by *offset* elements @@ -1108,12 +1172,13 @@ def add_offset(ptr, offset, loc=None, ip=None): return fly.add_offset(ptr, offset, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def apply_swizzle(ptr, swizzle, loc=None, ip=None): return fly.apply_swizzle(ptr, swizzle, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@dsl_wrap_result def ptr_load(ptr, result_type=None, loc=None, ip=None): """Load one value (scalar or vector) from *ptr*; dtype defaults to ptr's element type. @@ -1124,22 +1189,26 @@ def ptr_load(ptr, result_type=None, loc=None, ip=None): result_type = ptr.element_type if not isinstance(result_type, ir.Type): result_type = result_type.ir_type - return _wrap_numeric_type(fly.ptr_load(result_type, ptr, loc=loc, ip=ip)) + return fly.ptr_load(result_type, ptr, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def ptr_store(value, ptr, loc=None, ip=None): """Store *value* into *ptr*. Types must match the pointer's element type. Examples: ptr_store(val, ptr) """ - if not isinstance(value, ir.Value): + from .numeric import Numeric + + if isinstance(value, Numeric): + value = value.ir_value() + elif not isinstance(value, ir.Value): value = ptr.element_type(value).ir_value() return fly.ptr_store(value, ptr, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def recast_iter(result_type, src, loc=None, ip=None): """Reinterpret a pointer / iterator as another element type (like `reinterpret_cast`). @@ -1159,42 +1228,42 @@ def recast_iter(result_type, src, loc=None, ip=None): return fly.recast_iter(result_type, src, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def memref_alloca(memref_type, layout, loc=None, ip=None): return fly.memref_alloca(memref_type, layout, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def memref_load_vec(memref, loc=None, ip=None): from .typing import Vector return Vector(fly.memref_load_vec(memref, loc=loc, ip=ip), memref.shape.to_py_value(), memref.dtype) -@traced_op +@dsl_loc_tracing def memref_store_vec(vector, memref, loc=None, ip=None): return fly.memref_store_vec(vector, memref, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@dsl_wrap_result def memref_load(memref, indices, loc=None, ip=None): if isinstance(indices, ir.Value): - if str(indices.type) == "index": - indices = _arith.IndexCastOp(T.i32(), indices) if not _is_int_tuple_value(indices): indices = make_int_tuple(indices, loc=loc, ip=ip) - return _wrap_numeric_type(fly.memref_load(memref, indices, loc=loc, ip=ip)) + return fly.memref_load(memref, indices, loc=loc, ip=ip) indices = make_int_tuple(indices, loc=loc, ip=ip) _check_profile(is_profile_weakly_congruent, indices, memref) - return _wrap_numeric_type(fly.memref_load(memref, indices, loc=loc, ip=ip)) + return fly.memref_load(memref, indices, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def memref_store(value, memref, indices, loc=None, ip=None): + from .typing import as_ir_value + + value = as_ir_value(value) if isinstance(indices, ir.Value): - if str(indices.type) == "index": - indices = _arith.IndexCastOp(T.i32(), indices) if not _is_int_tuple_value(indices): indices = make_int_tuple(indices, loc=loc, ip=ip) return fly.memref_store(value, memref, indices, loc=loc, ip=ip) @@ -1209,7 +1278,7 @@ def memref_store(value, memref, indices, loc=None, ip=None): # ===----------------------------------------------------------------------=== # -@traced_op +@dsl_loc_tracing def printf(*args, format_str="", loc=None, ip=None): def _convert_printf_value(val): if isinstance(val, ir.Value): @@ -1268,7 +1337,7 @@ def _convert_printf_value(val): return fly.print_(final_format, ir_values, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def assume(result_type, dst, src, loc=None, ip=None): """ WIP, unsupported for now @@ -1281,7 +1350,7 @@ def assume(result_type, dst, src, loc=None, ip=None): # ===----------------------------------------------------------------------=== # -@traced_op +@dsl_loc_tracing def make_tile(*args, loc=None, ip=None): from .typing import Layout diff --git a/python/flydsl/expr/rocdl.py b/python/flydsl/expr/rocdl.py index eb52be5fd..505bb116b 100644 --- a/python/flydsl/expr/rocdl.py +++ b/python/flydsl/expr/rocdl.py @@ -95,9 +95,9 @@ def make_buffer_tensor(memref, alignment=4, loc=None, ip=None): from .._mlir.dialects import arith as _arith from .._mlir.dialects import fly from . import primitive as _prim - from .meta import _to_raw_value + from .typing import as_ir_value - raw_memref = _to_raw_value(memref) + raw_memref = as_ir_value(memref) layout = _prim.get_layout(memref, loc=loc, ip=ip) elem_type = fly.MemRefType(raw_memref.type).element_type diff --git a/python/flydsl/expr/typing.py b/python/flydsl/expr/typing.py index 215081d83..96746788e 100644 --- a/python/flydsl/expr/typing.py +++ b/python/flydsl/expr/typing.py @@ -37,12 +37,15 @@ Int16, Int32, Int64, + Int128, Integer, Numeric, Uint8, Uint16, Uint32, Uint64, + Uint128, + as_numeric, ) from .primitive import * from .utils.arith import ( @@ -56,6 +59,97 @@ ) +def as_ir_value(value, *, keep_static=False): + """Convert any DslType value into a raw ``ir.Value`` + + This is the *canonical* "DSL -> ir.Value" converter. Body code that + needs to feed an MLIR builder should call this explicitly per argument. + + Behavior summary: + - ``None`` -> ``None`` + - ``ir.Value`` -> returned unchanged + - ``Numeric`` holding a Python literal, when + ``keep_static=True`` -> returned unchanged + ``keep_static=False`` -> promoted via ``as_numeric(value).ir_value()`` + - ``tuple`` / ``list`` -> recursed, shape preserved + - object with ``__extract_to_ir_values__`` -> single value extracted; multi-value returns a list + - ``bool`` / ``int`` / ``float`` -> promoted via ``as_numeric(value).ir_value()`` + - object with ``ir_value()`` -> called as a fallback + - anything else -> returned unchanged + """ + if value is None: + return None + if isinstance(value, ir.Value): + return value + if keep_static and isinstance(value, Numeric) and not isinstance(value.value, ir.Value): + return value + if isinstance(value, tuple): + return tuple(as_ir_value(v, keep_static=keep_static) for v in value) + if isinstance(value, list): + return [as_ir_value(v, keep_static=keep_static) for v in value] + if hasattr(value, "__extract_to_ir_values__"): + values = value.__extract_to_ir_values__() + if len(values) == 1: + return values[0] + return values + if isinstance(value, (bool, int, float)): + return as_numeric(value).ir_value() + if hasattr(value, "ir_value"): + return value.ir_value() + return value + + +def as_dsl_value(value, exemplar=None): + """Wrap a raw ``ir.Value`` back into a DSL value. This is the inverse + of :func:`as_ir_value` (``ir.Value -> DslType``). + + ``exemplar`` is an optional *type template* describing how to wrap ``value``: + - a DslType class -> constructed directly via ``exemplar(value)`` + - a DslType instance -> ``type(exemplar)(value)`` + + Behavior summary (mirrors the branches of :func:`as_ir_value`): + - ``None`` -> ``None`` + - ``tuple`` / ``list`` -> recursed, shape preserved, + paired element-wise with ``exemplar`` (a non-sequence ``exemplar`` is + broadcast to every element) + - with no usable ``exemplar``: a ``value`` already satisfying the + ``DslType`` protocol is returned unchanged; a bare scalar ``ir.Value`` + is dispatched by ``value.type`` via ``Numeric.from_ir_type``; any other + non-``ir.Value`` is returned unchanged. + + Raises ``TypeError`` when a bare ``ir.Value`` cannot be wrapped into any DSL + value. + """ + if value is None: + return None + if isinstance(value, (tuple, list)): + exemplars = exemplar if isinstance(exemplar, (tuple, list)) else [exemplar] * len(value) + return type(value)(as_dsl_value(v, ex) for v, ex in zip(value, exemplars)) + + if exemplar is not None and isinstance(value, ir.Value): + if isclass(exemplar): + return exemplar(value) + if isinstance(exemplar, Numeric): + return type(exemplar)(value) + ctor = getattr(type(exemplar), "__construct_from_ir_values__", None) + if ctor is not None: + try: + return ctor([value]) + except Exception: + raise ValueError(f"failed to construct {type(exemplar)} from {value}") + + from ..compiler.protocol import DslType + + if isinstance(value, DslType): + return value + if not isinstance(value, ir.Value): + return value + try: + return Numeric.from_ir_type(value.type)(value) + except Exception as e: + raise TypeError(f"as_dsl_value cannot wrap ir.Value of type {value.type!s} into a DSL value") from e + + def _vec(n: int, elem: ir.Type) -> ir.Type: return ir.VectorType.get([int(n)], elem) @@ -165,6 +259,10 @@ def i64(self) -> ir.Type: def i64x2(self) -> ir.Type: return _vec(2, Int64.ir_type) + @property + def i128(self) -> ir.Type: + return Int128.ir_type + # ---- Float scalars & vectors ---- @property def f16(self) -> ir.Type: @@ -248,6 +346,9 @@ def vec(self, n: int, elem: ir.Type) -> ir.Type: "Types", "T", "default_f8_type", + # DSL utilities + "as_ir_value", + "as_dsl_value", "is_generic_address_space", "is_target_address_space", # DSL value types @@ -272,11 +373,13 @@ def vec(self, n: int, elem: ir.Type) -> ir.Type: "Int16", "Int32", "Int64", + "Int128", "Index", "Uint8", "Uint16", "Uint32", "Uint64", + "Uint128", "Constexpr", "IntTuple", "Layout", @@ -818,7 +921,7 @@ def load(self, loc=None, ip=None): @traced_op def store(self, value, loc=None, ip=None): - if isinstance(value, (bool, int, float)): + if isinstance(value, (bool, int, float, Numeric)): value = self.element_type(value) return ptr_store(value, self, loc=loc, ip=ip) diff --git a/tests/unit/test_numeric_promotion.py b/tests/unit/test_numeric_promotion.py new file mode 100644 index 000000000..c4aa7a68a --- /dev/null +++ b/tests/unit/test_numeric_promotion.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python3 + +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 FlyDSL Project Contributors + +"""C++-style usual-arithmetic-conversion promotion for DSL Numeric types. + +We deliberately skip the C++ "integer promotion to int" step: ``int8 + int8`` +must stay ``int8``, ``uint16 + uint16`` stays ``uint16``. Cross-width and +cross-sign promotion follows usual arithmetic conversions (unsigned wins at +equal width; wider wins among same-sign; signed-can-represent rule for +mixed-sign mixed-width). +""" + +import pytest + +import flydsl.expr as fx +from flydsl._mlir.ir import Context, InsertionPoint, Location, Module + +pytestmark = [pytest.mark.l1b_target_dialect] + + +def _binop(lhs_ty, rhs_ty, op): + """Build two block-arg values of the requested DSL types and apply `op`. + + Returns the resulting Numeric. We use block args so the operands are + genuinely dynamic ir.Values (not Python literals), which is the path + most kernel code hits. + """ + with Context() as ctx: + ctx.allow_unregistered_dialects = True + with Location.unknown(ctx): + module = Module.create() + from flydsl._mlir.dialects import func + from flydsl._mlir.ir import FunctionType + + with InsertionPoint(module.body): + f = func.FuncOp("k", FunctionType.get([lhs_ty.ir_type, rhs_ty.ir_type], [])) + entry = f.add_entry_block() + with InsertionPoint(entry): + a = lhs_ty(entry.arguments[0]) + b = rhs_ty(entry.arguments[1]) + result = op(a, b) + func.ReturnOp([]) + assert module.operation.verify() + return result + + +# Same-sign / same-width: must stay narrow (no auto-int32 promotion). +@pytest.mark.parametrize( + "ty", + [fx.Int8, fx.Int16, fx.Uint8, fx.Uint16, fx.Int32, fx.Int64, fx.Uint32, fx.Uint64, fx.Int128, fx.Uint128], +) +def test_same_type_stays_narrow(ty): + assert _binop(ty, ty, lambda a, b: a + b).dtype is ty + assert _binop(ty, ty, lambda a, b: a * b).dtype is ty + + +# Same-sign cross-width: wider wins. +@pytest.mark.parametrize( + "a,b,expected", + [ + (fx.Int8, fx.Int16, fx.Int16), + (fx.Int8, fx.Int32, fx.Int32), + (fx.Int16, fx.Int64, fx.Int64), + (fx.Uint8, fx.Uint16, fx.Uint16), + (fx.Uint16, fx.Uint64, fx.Uint64), + (fx.Int32, fx.Int128, fx.Int128), + (fx.Int64, fx.Int128, fx.Int128), + (fx.Uint32, fx.Uint128, fx.Uint128), + ], +) +def test_same_sign_wider_wins(a, b, expected): + assert _binop(a, b, lambda x, y: x + y).dtype is expected + assert _binop(b, a, lambda x, y: x + y).dtype is expected # commutative + + +# Mixed sign: unsigned wins iff u.width >= s.width, else signed. +@pytest.mark.parametrize( + "a,b,expected", + [ + (fx.Int32, fx.Uint32, fx.Uint32), # equal width → unsigned wins + (fx.Int32, fx.Uint64, fx.Uint64), # u wider → unsigned wins + (fx.Int64, fx.Uint32, fx.Int64), # s wider → signed (signed-can-represent) + (fx.Int8, fx.Uint16, fx.Uint16), # u wider → unsigned + (fx.Int16, fx.Uint8, fx.Int16), # s wider → signed + (fx.Int128, fx.Uint128, fx.Uint128), # equal width → unsigned + (fx.Int128, fx.Uint64, fx.Int128), # s wider → signed + (fx.Int128, fx.Uint32, fx.Int128), # s wider → signed + (fx.Uint128, fx.Int32, fx.Uint128), # u wider → unsigned + (fx.Uint128, fx.Int64, fx.Uint128), # u wider → unsigned + ], +) +def test_mixed_sign(a, b, expected): + assert _binop(a, b, lambda x, y: x + y).dtype is expected + assert _binop(b, a, lambda x, y: x + y).dtype is expected + + +# Python literal: as_numeric promotes int→Int32 (C++ `int` literal default), +# then C++ promotion runs. +def test_python_int_literal_promotes_via_int32(): + # Int8(arg) + 5 → Int8 + Int32 → Int32 (wider wins) + with Context() as ctx, Location.unknown(ctx): + ctx.allow_unregistered_dialects = True + module = Module.create() + from flydsl._mlir.dialects import func + from flydsl._mlir.ir import FunctionType + + with InsertionPoint(module.body): + f = func.FuncOp("k", FunctionType.get([fx.Int8.ir_type], [])) + entry = f.add_entry_block() + with InsertionPoint(entry): + a = fx.Int8(entry.arguments[0]) + r = a + 5 + func.ReturnOp([]) + assert module.operation.verify() + assert r.dtype is fx.Int32 + + +# Int + Float: promote to the float side. +@pytest.mark.parametrize( + "itype,ftype", + [ + (fx.Int8, fx.Float16), + (fx.Int32, fx.Float32), + (fx.Int64, fx.Float64), + (fx.Int128, fx.Float64), # no Float128; precision loss is expected and OK + ], +) +def test_int_plus_float(itype, ftype): + assert _binop(itype, ftype, lambda x, y: x + y).dtype is ftype + assert _binop(ftype, itype, lambda x, y: x + y).dtype is ftype + + +# Float + Float: wider wins. +@pytest.mark.parametrize( + "a,b,expected", + [ + (fx.Float16, fx.Float32, fx.Float32), + (fx.Float32, fx.Float64, fx.Float64), + (fx.Float16, fx.Float64, fx.Float64), + ], +) +def test_float_wider_wins(a, b, expected): + assert _binop(a, b, lambda x, y: x + y).dtype is expected + assert _binop(b, a, lambda x, y: x + y).dtype is expected + + +# Boolean arithmetic: bool + bool → Int32 (matches C++ "bool participates as int"). +def test_bool_plus_bool_widens_to_int32(): + with Context() as ctx, Location.unknown(ctx): + ctx.allow_unregistered_dialects = True + module = Module.create() + from flydsl._mlir.dialects import func + from flydsl._mlir.ir import FunctionType + + with InsertionPoint(module.body): + f = func.FuncOp("k", FunctionType.get([fx.Boolean.ir_type, fx.Boolean.ir_type], [])) + entry = f.add_entry_block() + with InsertionPoint(entry): + a = fx.Boolean(entry.arguments[0]) + b = fx.Boolean(entry.arguments[1]) + r = a + b + func.ReturnOp([]) + assert module.operation.verify() + assert r.dtype is fx.Int32 + + +# True division on integers: Python `/` lifts int/int to float. +@pytest.mark.parametrize( + "ty,expected", + [ + (fx.Int8, fx.Float32), + (fx.Int32, fx.Float32), + (fx.Int64, fx.Float64), + (fx.Int128, fx.Float64), + ], +) +def test_truediv_int_lifts_to_float(ty, expected): + assert _binop(ty, ty, lambda x, y: x / y).dtype is expected + + +# Floor division on integers: stays integer (Python `//` semantics). +@pytest.mark.parametrize("ty", [fx.Int8, fx.Int32, fx.Int64, fx.Uint32, fx.Int128]) +def test_floordiv_int_stays_int(ty): + assert _binop(ty, ty, lambda x, y: x // y).dtype is ty From 8287aa3728b3446ce4dbd4c3840824ab680b3c2b Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Tue, 2 Jun 2026 09:01:26 +0000 Subject: [PATCH 3/7] drop coerce in the derived --- python/flydsl/expr/derived.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/flydsl/expr/derived.py b/python/flydsl/expr/derived.py index 8d3204b51..8576249c6 100644 --- a/python/flydsl/expr/derived.py +++ b/python/flydsl/expr/derived.py @@ -7,7 +7,6 @@ from .meta import traced_op from .numeric import Boolean, Numeric from .primitive import * -from .primitive import _coerce_int_tuple from .typing import Int8, Layout, Tensor, TiledCopy, TiledMma __all__ = [ @@ -35,7 +34,7 @@ def __init__(self, tiled_copy: TiledCopy, thr_idx): super().__init__(tiled_copy) self.tiled_copy = tiled_copy self._thr_idx = thr_idx - self._thr_idx_int = _coerce_int_tuple(self.thr_idx) + self._thr_idx_int = make_int_tuple(self.thr_idx) @property def thr_idx(self): @@ -65,7 +64,7 @@ def __init__(self, tiled_mma: TiledMma, thr_idx): super().__init__(tiled_mma) self.tiled_mma = tiled_mma self._thr_idx = thr_idx - self._thr_idx_int = _coerce_int_tuple(self.thr_idx) + self._thr_idx_int = make_int_tuple(self.thr_idx) @property def thr_idx(self): From 8813c9adc394591ab2a4fba610656146ba1fd89e Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Fri, 5 Jun 2026 03:39:30 +0000 Subject: [PATCH 4/7] [Refactor]: rm IR value --- kernels/blockscale_preshuffle_gemm.py | 2 +- kernels/layernorm_kernel.py | 2 +- kernels/mfma_preshuffle_pipeline.py | 2 +- kernels/rmsnorm_kernel.py | 8 ++++---- python/flydsl/expr/primitive.py | 2 ++ 5 files changed, 9 insertions(+), 7 deletions(-) diff --git a/kernels/blockscale_preshuffle_gemm.py b/kernels/blockscale_preshuffle_gemm.py index c2517e37c..648e06c54 100644 --- a/kernels/blockscale_preshuffle_gemm.py +++ b/kernels/blockscale_preshuffle_gemm.py @@ -252,7 +252,7 @@ def load_b_packs_k64(base_k, ku: int, ni: int): k0_base = base_k_bytes // c64_b k0 = k0_base + ku k1 = lane_div_16 - coord_pack = (n_blk_list[ni], k0, k1, n_intra_list[ni], fx.Index(0)) + coord_pack = (n_blk_list[ni], k0, k1, n_intra_list[ni], fx.Int32(0)) idx_pack = crd2idx(tuple(fx.Int32(c) for c in coord_pack), layout_b) b16 = _buffer_load_vec( buffer_ops, diff --git a/kernels/layernorm_kernel.py b/kernels/layernorm_kernel.py index ffc3530ad..b0dcdb7fc 100644 --- a/kernels/layernorm_kernel.py +++ b/kernels/layernorm_kernel.py @@ -720,7 +720,7 @@ def _load_norm_input_value(index): mean = sum_val / n_float var = sumsq_val / n_float - mean * mean var = (var < c_zero_f).select(c_zero_f, var) - rstd = (var + eps_c).rsqrt(fastmath=fm_fast) + rstd = fmath.rsqrt(var + eps_c, fastmath=fm_fast) thread_row_max = c_zero_f for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): diff --git a/kernels/mfma_preshuffle_pipeline.py b/kernels/mfma_preshuffle_pipeline.py index 4cd0bc9b3..c556ee970 100644 --- a/kernels/mfma_preshuffle_pipeline.py +++ b/kernels/mfma_preshuffle_pipeline.py @@ -526,7 +526,7 @@ def tile_chunk_coord_i32( raise ValueError(f"chunk_i32 must be one of (1,2,4), got {chunk_i32!r}") chunk_off_i32 = arith.constant(i * total_threads * chunk_i32, index=True) tile_idx_i32 = tx_i32_base + chunk_off_i32 - coord_local = fx.idx2crd(tile_idx_i32, layout_tile_div4) + coord_local = fx.idx2crd(fx.Int32(tile_idx_i32), layout_tile_div4) row_local = fx.get(coord_local, 0) col_local_i32 = fx.get(coord_local, 1) return row_local, col_local_i32 diff --git a/kernels/rmsnorm_kernel.py b/kernels/rmsnorm_kernel.py index ce4bd0a98..66235110a 100644 --- a/kernels/rmsnorm_kernel.py +++ b/kernels/rmsnorm_kernel.py @@ -214,7 +214,7 @@ def block_reduce_add2(val0, val1): _, sum_sq = block_reduce_add2(thread_dummy, thread_sumsq) mean_sq = sum_sq / n_float ms_eps = mean_sq + eps_c - rrms = ms_eps.rsqrt(fastmath=fm_fast) + rrms = fmath.rsqrt(ms_eps, fastmath=fm_fast) # Pass 2: normalize + gamma + store (reuse cached input) for tile_i in range_constexpr(num_tiles): @@ -517,7 +517,7 @@ def block_reduce_add2(val0, val1): _, sum_sq = block_reduce_add2(thread_dummy, thread_sumsq) mean_sq = sum_sq / n_float ms_eps = mean_sq + eps_c - rrms = ms_eps.rsqrt(fastmath=fm_fast) + rrms = fmath.rsqrt(ms_eps, fastmath=fm_fast) # Pass 2: normalize + gamma + store (reuse cached added values) for tile_i in range_constexpr(num_tiles): @@ -790,7 +790,7 @@ def block_reduce_max(val): _, sum_sq = block_reduce_add2(thread_dummy, thread_sumsq) mean_sq = sum_sq / n_float ms_eps = mean_sq + eps_c - rrms = ms_eps.rsqrt(fastmath=fm_fast) + rrms = fmath.rsqrt(ms_eps, fastmath=fm_fast) thread_row_max = c_zero_f y_local = [] @@ -1176,7 +1176,7 @@ def block_reduce_max(val): _, sum_sq = block_reduce_add2(thread_dummy, thread_sumsq) mean_sq = sum_sq / n_float ms_eps = mean_sq + eps_c - rrms = ms_eps.rsqrt(fastmath=fm_fast) + rrms = fmath.rsqrt(ms_eps, fastmath=fm_fast) thread_row_max = c_zero_f y_local = [] diff --git a/python/flydsl/expr/primitive.py b/python/flydsl/expr/primitive.py index 735bb56c8..bd00fd06f 100644 --- a/python/flydsl/expr/primitive.py +++ b/python/flydsl/expr/primitive.py @@ -232,6 +232,8 @@ def _expand_int_tuple_leaves(value, loc=None, ip=None): return value.value if isinstance(value, ir.Value) and isinstance(value.type, ir.IntegerType) and value.type.width < 32: return Int32(value, loc=loc, ip=ip).value + if isinstance(value, ir.Value) and isinstance(value.type, ir.IndexType): + return Int32(value, loc=loc, ip=ip).value return value From a22c8ecc1e711aceed9ed27344b119628638e573 Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Fri, 5 Jun 2026 07:33:36 +0000 Subject: [PATCH 5/7] [Bugfix]: import __all__ --- python/flydsl/expr/arith.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/python/flydsl/expr/arith.py b/python/flydsl/expr/arith.py index c04998c2e..fa37e61d6 100644 --- a/python/flydsl/expr/arith.py +++ b/python/flydsl/expr/arith.py @@ -15,6 +15,25 @@ from .._mlir.dialects.arith import * # noqa: F401,F403 +__all__ = [ + "ArithValue", # Deprecated: will be removed in a future release + "_to_raw", # Deprecated: will be removed in a future release + "andi", + "constant", + "constant_vector", + "index", # Deprecated: will be removed in a future release + "index_cast", # Deprecated: will be removed in a future release + "int_to_fp", + "select", + "shli", + "sitofp", + "trunc_f", + "unwrap", # Deprecated: will be removed in a future release + "xori", + "cmpi", + "cmpf", +] + # Override star-import cmpi/cmpf to accept Numeric types (Int32, etc.) from .._mlir.dialects import arith as _mlir_arith from .meta import traced_op From 9587ccd6b478e1cec9f7a852dfeed5c7ce9068fb Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Fri, 5 Jun 2026 10:22:01 +0000 Subject: [PATCH 6/7] Fix DSL type compatibility in layout_utils.idx2crd Unwrap DSL types (Int32, etc.) to raw ir.Value at the entry of idx2crd, so that downstream arith ops (shrui, andi) receive proper ir.Value operands instead of Numeric wrappers. Also replace fragile hasattr/str-based type check with isinstance. Co-Authored-By: Claude Opus 4 --- kernels/layout_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/kernels/layout_utils.py b/kernels/layout_utils.py index ccb1d992e..0adc68eb9 100644 --- a/kernels/layout_utils.py +++ b/kernels/layout_utils.py @@ -86,12 +86,15 @@ def idx2crd(idx, layout): """ parsed = _parse_layout(layout) + if hasattr(idx, "ir_value"): + idx = idx.ir_value() + if parsed is None or _has_dynamic_strides(parsed[1]): result = fx.idx2crd(fx.Int32(idx), layout) ndims = len(parsed[1]) if parsed else 1 return [_wrap(fx.get(result, i)) for i in range(ndims)] - if hasattr(idx, "type") and str(idx.type) != "index": + if isinstance(idx, ir.Value) and not isinstance(idx.type, ir.IndexType): idx = arith.index_cast(T.index, idx) shapes, strides = parsed ndims = len(strides) From 0750258f9e0b28694dfd7eec6d8b4033fc469789 Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Sat, 6 Jun 2026 16:57:43 +0000 Subject: [PATCH 7/7] Replace arith.constant(index=True) with fx.Int32() in mixed_moe_gemm_2stage Eliminate all index-typed arith.constant and arith.index calls in mixed_moe_gemm_2stage.py, replacing them with fx.Int32() DSL types. This prevents type mismatch errors (i32 vs index) when DSL-typed values (from fx.get, fx.idx2crd) participate in arithmetic with index-typed constants. 174 occurrences replaced. Co-Authored-By: Claude Opus 4 --- kernels/mixed_moe_gemm_2stage.py | 348 +++++++++++++++---------------- 1 file changed, 174 insertions(+), 174 deletions(-) diff --git a/kernels/mixed_moe_gemm_2stage.py b/kernels/mixed_moe_gemm_2stage.py index 712291931..67b1fccd9 100644 --- a/kernels/mixed_moe_gemm_2stage.py +++ b/kernels/mixed_moe_gemm_2stage.py @@ -457,7 +457,7 @@ def moe_gemm1( # B preshuffle layout: [E*2*inter_dim, model_dim] # Gate rows for expert e: [e*2*inter_dim, e*2*inter_dim + inter_dim) - c_n_total = arith.constant(experts * (2 * inter_dim), index=True) + c_n_total = fx.Int32(experts * (2 * inter_dim)) b_layout = make_preshuffle_b_layout( arith, c_n=c_n_total, @@ -470,13 +470,13 @@ def moe_gemm1( # A-scale: [sorted_size, K/32] -- pre-scattered by caller into sorted layout # Same as stage2: indexed by sorted_row position, not by token_id. - sorted_m = size_expert_ids_in * arith.constant(sort_block_m, index=True) + sorted_m = size_expert_ids_in * fx.Int32(sort_block_m) layout_a_scale = make_preshuffle_scale_layout( - arith, c_mn=sorted_m, c_k=arith.constant(model_dim, index=True) + arith, c_mn=sorted_m, c_k=fx.Int32(model_dim) ) # B-scale: [E*2*inter_dim, K/32] layout_b_scale = make_preshuffle_scale_layout( - arith, c_mn=c_n_total, c_k=arith.constant(model_dim, index=True) + arith, c_mn=c_n_total, c_k=fx.Int32(model_dim) ) _eff_lds_stride = lds_stride @@ -495,26 +495,26 @@ def moe_gemm1( if const_expr(xcd_swizzle > 0): _NUM_XCDS_S1 = 8 - _c1_sw = arith.constant(1, index=True) - _c_tn_sw = arith.constant(tile_n, index=True) - _c_idp_sw = arith.constant(2 * inter_dim_pad, index=True) + _c1_sw = fx.Int32(1) + _c_tn_sw = fx.Int32(tile_n) + _c_idp_sw = fx.Int32(2 * inter_dim_pad) if const_expr(mock_gate_only or gate_up_interleave): _gx = (n_in - _c_idp_sw + _c_tn_sw - _c1_sw) / _c_tn_sw else: - _c2_sw = arith.constant(2, index=True) + _c2_sw = fx.Int32(2) _gx = (n_in - _c_idp_sw + _c2_sw * _c_tn_sw - _c1_sw) / _c_tn_sw / _c2_sw - _c_pm_sw = arith.constant(persist_m, index=True) + _c_pm_sw = fx.Int32(persist_m) _gy = (size_expert_ids_in + _c_pm_sw - _c1_sw) / _c_pm_sw _linear_id = bx_persist * _gx + by _num_wgs = _gx * _gy - _c_xcds = arith.constant(_NUM_XCDS_S1, index=True) + _c_xcds = fx.Int32(_NUM_XCDS_S1) _wgs_per_xcd = _num_wgs / _c_xcds _wgid = (_linear_id % _c_xcds) * _wgs_per_xcd + (_linear_id / _c_xcds) _WGM_S1 = xcd_swizzle - _c_wgm = arith.constant(_WGM_S1, index=True) + _c_wgm = fx.Int32(_WGM_S1) _num_wgid_in_group = _c_wgm * _gx _group_id = _wgid / _num_wgid_in_group _first_pid_m = _group_id * _c_wgm @@ -526,14 +526,14 @@ def moe_gemm1( bx_persist = _first_pid_m + (_wgid_in_group % _group_size_m) by = _wgid_in_group / _group_size_m - by_n = by * arith.constant(tile_n, index=True) + by_n = by * fx.Int32(tile_n) - k_base_idx = arith.index(0) + k_base_idx = fx.Int32(0) if const_expr(_is_splitk): bz = gpu.block_id("z") # K-batch id - k_base_idx = bz * arith.constant(_k_dim, index=True) + k_base_idx = bz * fx.Int32(_k_dim) - k_blocks16 = arith.constant(_eff_tile_k_bytes // 16, index=True) + k_blocks16 = fx.Int32(_eff_tile_k_bytes // 16) layout_tx_wave_lane = fx.make_layout((num_waves, 64), stride=(64, 1)) layout_lane16 = fx.make_layout((4, 16), stride=(16, 1)) @@ -571,8 +571,8 @@ def moe_gemm1( lds_tid = SmemPtr(base_ptr_pong, _lds_tid_offset_pong, T.i32, shape=(tile_m,)).get() # Buffer resources - c_a_pack = arith.constant(int(a_elem_vec_pack), index=True) - c_elem_bytes = arith.constant(int(a_elem_bytes), index=True) + c_a_pack = fx.Int32(int(a_elem_vec_pack)) + c_elem_bytes = fx.Int32(int(a_elem_bytes)) # X: [tokens, model_dim] x_nbytes_idx = (tokens_in * k_in * c_elem_bytes) / c_a_pack @@ -587,13 +587,13 @@ def moe_gemm1( max_size=False, num_records_bytes=arith.constant(4, type=T.i32), ) - num_valid_i32 = buffer_ops.buffer_load(numids_rsrc, arith.constant(0, index=True), vec_width=1, dtype=T.i32) + num_valid_i32 = buffer_ops.buffer_load(numids_rsrc, fx.Int32(0), vec_width=1, dtype=T.i32) sx_rsrc = 1 sw_rsrc = 1 if const_expr(not (is_f16_a or a_scale_one)): # A scale: [sorted_size, model_dim/32] pre-scattered by caller - c32 = arith.constant(32, index=True) + c32 = fx.Int32(32) kblk = k_in / c32 sx_nbytes_idx = sorted_m * kblk sx_nbytes_i32 = arith.index_cast(T.i32, sx_nbytes_idx) @@ -602,16 +602,16 @@ def moe_gemm1( ) if const_expr(not is_f16_b): - c32 = arith.constant(32, index=True) + c32 = fx.Int32(32) kblk_w = k_in / c32 - mn_w = arith.constant(experts * (2 * inter_dim), index=True) + mn_w = fx.Int32(experts * (2 * inter_dim)) sw_nbytes_idx = mn_w * kblk_w sw_nbytes_i32 = arith.index_cast(T.i32, sw_nbytes_idx) sw_rsrc = buffer_ops.create_buffer_resource( arg_scale_w, max_size=False, num_records_bytes=sw_nbytes_i32 ) - sorted_nbytes_idx = size_expert_ids_in * arith.constant(sort_block_m * 4, index=True) + sorted_nbytes_idx = size_expert_ids_in * fx.Int32(sort_block_m * 4) sorted_nbytes_i32 = arith.index_cast(T.i32, sorted_nbytes_idx) sorted_rsrc = buffer_ops.create_buffer_resource( arg_sorted_token_ids, @@ -622,7 +622,7 @@ def moe_gemm1( arg_sorted_weights, max_size=False, num_records_bytes=sorted_nbytes_i32 ) - eid_nbytes_idx = size_expert_ids_in * arith.constant(4, index=True) + eid_nbytes_idx = size_expert_ids_in * fx.Int32(4) eid_nbytes_i32 = arith.index_cast(T.i32, eid_nbytes_idx) expert_rsrc = buffer_ops.create_buffer_resource( arg_expert_ids, max_size=False, num_records_bytes=eid_nbytes_i32 @@ -638,15 +638,15 @@ def moe_gemm1( # ---- persist_m loop (same pattern as stage2) ---- _PERSIST_M = persist_m - _c0_p = arith.constant(0, index=True) - _c1_p = arith.constant(1, index=True) - _c_pm = arith.constant(_PERSIST_M, index=True) + _c0_p = fx.Int32(0) + _c1_p = fx.Int32(1) + _c_pm = fx.Int32(_PERSIST_M) _for_persist = scf.ForOp(_c0_p, _c_pm, _c1_p) _for_ip = ir.InsertionPoint(_for_persist.body) _for_ip.__enter__() _mi_p = _for_persist.induction_variable bx = bx_persist * _c_pm + _mi_p - bx_m = bx * arith.constant(sort_block_m, index=True) + bx_m = bx * fx.Int32(sort_block_m) # Block validity bx_m_i32 = arith.index_cast(T.i32, bx_m) @@ -657,17 +657,17 @@ def moe_gemm1( def _moe_gemm1_body(): # Gate expert offset: first inter_dim rows of each expert's 2*inter_dim block - expert_off_idx = expert_idx * arith.constant(2 * inter_dim, index=True) + expert_off_idx = expert_idx * fx.Int32(2 * inter_dim) # X loading -- KEY DIFFERENCE from stage2: X row = token_id only x_load_bytes = 16 num_x_loads = bytes_per_thread_x // x_load_bytes chunk_i32 = x_load_bytes // 4 - c_k_div4 = ((k_in / c_a_pack) * arith.constant(int(a_elem_bytes), index=True)) / arith.index(4) + c_k_div4 = ((k_in / c_a_pack) * fx.Int32(int(a_elem_bytes))) / fx.Int32(4) tile_k_dwords = (int(tile_k) * int(a_elem_bytes)) // (4 * int(a_elem_vec_pack)) layout_x_tile_div4 = fx.make_layout((tile_m, tile_k_dwords), stride=(tile_k_dwords, 1)) - c_chunk_i32 = arith.constant(chunk_i32, index=True) + c_chunk_i32 = fx.Int32(chunk_i32) tx_i32_base = tx * c_chunk_i32 topk_i32 = arith.constant(topk) @@ -685,7 +685,7 @@ def x_tile_chunk_coord_i32(i: int): ) def load_x(idx_i32): - idx_elem = idx_i32 if a_elem_bytes == 1 else (idx_i32 * arith.index(2)) + idx_elem = idx_i32 if a_elem_bytes == 1 else (idx_i32 * fx.Int32(2)) return buffer_copy_gmem16_dwordx4( buffer_ops, vector, @@ -720,7 +720,7 @@ def load_x(idx_i32): x_row_base_div4.append(t_idx * c_k_div4) def load_x_tile(base_k): - base_k_div4 = ((base_k / c_a_pack) * arith.constant(int(a_elem_bytes), index=True)) / arith.index(4) + base_k_div4 = ((base_k / c_a_pack) * fx.Int32(int(a_elem_bytes))) / fx.Int32(4) parts = [] for i in range_constexpr(num_x_loads): idx_i32 = x_row_base_div4[i] + base_k_div4 + x_col_local_i32[i] @@ -736,11 +736,11 @@ def load_x_tile(base_k): lane_div_16 = layout_get(coord_l16, 0) lane_mod_16 = layout_get(coord_l16, 1) row_a_lds = lane_mod_16 - col_offset_base = lane_div_16 * arith.constant(16, index=True) + col_offset_base = lane_div_16 * fx.Int32(16) num_acc_n = n_per_wave // 16 - c_n_per_wave = arith.constant(n_per_wave, index=True) - wave_n_id = wave_id % arith.constant(num_waves, index=True) + c_n_per_wave = fx.Int32(n_per_wave) + wave_n_id = wave_id % fx.Int32(num_waves) n_tile_base = wave_n_id * c_n_per_wave # N-tile precompute for gate AND up weights @@ -751,11 +751,11 @@ def load_x_tile(base_k): col_g_list = [] c_n0_static = experts * (2 * inter_dim) // 16 layout_n_blk_intra = fx.make_layout((c_n0_static, 16), stride=(16, 1)) - inter_idx = arith.constant(inter_dim, index=True) + inter_idx = fx.Int32(inter_dim) for i in range_constexpr(num_acc_n): offset = i * 16 - c_offset = arith.constant(offset, index=True) + c_offset = fx.Int32(offset) if const_expr(not gate_up_interleave): col_g = by_n + n_tile_base + c_offset + lane_mod_16 col_g_list.append(col_g) @@ -776,8 +776,8 @@ def load_x_tile(base_k): _gui_num_acc_n_out = num_acc_n // pack_N for _gui_i in range_constexpr(_gui_num_acc_n_out): _gui_offset = _gui_i * 16 - _gui_c_offset = arith.constant(_gui_offset, index=True) - _gui_col_g = (by_n + n_tile_base) // arith.constant(2, index=True) + _gui_c_offset + lane_mod_16 + _gui_c_offset = fx.Int32(_gui_offset) + _gui_col_g = (by_n + n_tile_base) // fx.Int32(2) + _gui_c_offset + lane_mod_16 col_g_list.append(_gui_col_g) m_repeat = tile_m // 16 @@ -794,11 +794,11 @@ def load_x_tile(base_k): # B load for gate and up separately def load_b_packs_k64(base_k, ku: int, n_blk, n_intra): - c64 = arith.constant(64, index=True) - base_k_bytes = base_k * arith.constant(int(b_elem_bytes), index=True) - k0 = base_k_bytes // c64 + arith.constant(ku, index=True) + c64 = fx.Int32(64) + base_k_bytes = base_k * fx.Int32(int(b_elem_bytes)) + k0 = base_k_bytes // c64 + fx.Int32(ku) k1 = lane_div_16 - coord_pack = (n_blk, k0, k1, n_intra, arith.constant(0, index=True)) + coord_pack = (n_blk, k0, k1, n_intra, fx.Int32(0)) idx_pack = crd2idx(tuple(fx.Int32(c) for c in coord_pack), layout_b) vec_elems = kpack_bytes // int(b_elem_bytes) b16 = _buffer_load_vec( @@ -847,11 +847,11 @@ def load_b_tile(base_k, ku_limit=k_unroll): _gate_scale_bases = [] _up_scale_bases = [] for _ni in range_constexpr(num_acc_n_packed): - _col_base = by_n + n_tile_base + arith.constant(_ni * 16 * pack_N, index=True) - _gate_mni = (expert_off_idx + _col_base) // arith.constant(32, index=True) + _col_base = by_n + n_tile_base + fx.Int32(_ni * 16 * pack_N) + _gate_mni = (expert_off_idx + _col_base) // fx.Int32(32) _gate_scale_bases.append(_gate_mni * layout_b_scale.stride_n0 + _scale_lane_elem) if const_expr(not mock_gate_only and not gate_up_interleave): - _up_mni = (expert_off_idx + inter_idx + _col_base) // arith.constant(32, index=True) + _up_mni = (expert_off_idx + inter_idx + _col_base) // fx.Int32(32) _up_scale_bases.append(_up_mni * layout_b_scale.stride_n0 + _scale_lane_elem) if const_expr(not a_scale_one): @@ -860,8 +860,8 @@ def load_b_tile(base_k, ku_limit=k_unroll): _a_mni = _mi + bx_m // scale_mn_pack // 16 _a_scale_bases.append(_a_mni * layout_a_scale.stride_n0 + _scale_lane_elem) - _c16_idx = arith.constant(16, index=True) - _c2_idx = arith.constant(2, index=True) + _c16_idx = fx.Int32(16) + _c2_idx = fx.Int32(2) _scale_mask_lo = arith.constant(0xFF, type=T.i32) _m_half_idx = arith.constant(0, type=T.i32) @@ -945,7 +945,7 @@ def prefetch_ab_scale_tile(base_k, ku_packed_limit=k_unroll_packed): up_b_scale.append(vector.from_elements(T.vec(1, T.i32), [us])) return [a_scale_tile, gate_b_scale, up_b_scale] - _lds_base_zero = arith.index(0) + _lds_base_zero = fx.Int32(0) def store_x_tile_to_lds(vec_x_in_parts, lds_buffer): for i in range_constexpr(num_x_loads): @@ -960,7 +960,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_buffer): layout_lds=layout_lds, row_local=row_local, col_local_i32=col_local_i32, - tx_c4=arith.index(4), + tx_c4=fx.Int32(4), k_blocks16=k_blocks16, lds_base=_lds_base_zero, vec_part_i32x4=vec_x_in_parts[i], @@ -974,8 +974,8 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_buffer): _num_dma_loads = max(1, _eff_bytes_per_buffer // (total_threads * _dma_bytes)) def dma_x_tile_to_lds(base_k, lds_buffer): - c4_idx = arith.index(4) - base_k_div4 = ((base_k / c_a_pack) * arith.constant(int(elem_bytes), index=True)) / arith.index( + c4_idx = fx.Int32(4) + base_k_div4 = ((base_k / c_a_pack) * fx.Int32(int(elem_bytes))) / fx.Int32( 4 ) @@ -991,7 +991,7 @@ def dma_x_tile_to_lds(base_k, lds_buffer): if const_expr(i == 0): lds_addr = memref.extract_aligned_pointer_as_index( lds_buffer - ) + wave_id * arith.constant(_wave_size * _dma_bytes, index=True) + ) + wave_id * fx.Int32(_wave_size * _dma_bytes) lds_ptr_i64 = rocdl.readfirstlane(T.i64, arith.index_cast(T.i64, lds_addr)) else: lds_ptr_i64 = lds_ptr_i64 + arith.constant(total_threads * _dma_bytes, type=T.i64) @@ -1014,7 +1014,7 @@ def prefetch_x_to_lds(base_k, lds_buffer): def lds_load_packs_k64(curr_row_a_lds, col_base, lds_buffer): col_base_swz_bytes = swizzle_xor16(curr_row_a_lds, col_base, k_blocks16) - col_base_swz = col_base_swz_bytes if elem_bytes == 1 else (col_base_swz_bytes / arith.index(2)) + col_base_swz = col_base_swz_bytes if elem_bytes == 1 else (col_base_swz_bytes / fx.Int32(2)) idx_a16 = crd2idx([fx.Int32(curr_row_a_lds), fx.Int32(col_base_swz)], layout_lds) loaded_a16 = vector.load_op(vec16_x, lds_buffer, [idx_a16]) a_i64x2 = vector.bitcast(vec2_i64, loaded_a16) @@ -1028,7 +1028,7 @@ def prefetch_full_a_from_lds(lds_buffer, ku_limit=k_unroll): for k_idx in range_constexpr(ku_limit): col_base = col_offset_base + (k_idx * 128) // a_elem_vec_pack for mi_idx in range_constexpr(m_repeat): - mi_val = arith.constant(mi_idx * 16, index=True) + mi_val = fx.Int32(mi_idx * 16) curr_row = row_a_lds + mi_val a0, a1 = lds_load_packs_k64(curr_row, col_base, lds_buffer) if const_expr(is_f8_a): @@ -1065,18 +1065,18 @@ def compute_tile( bias_pf = [] for ni in range_constexpr(num_acc_n): _logical_col = ( - (by_n + n_tile_base) // arith.constant(2, index=True) - + arith.constant((ni // 2) * 16, index=True) + (by_n + n_tile_base) // fx.Int32(2) + + fx.Int32((ni // 2) * 16) + lane_mod_16 ) - _up_off = inter_idx if (ni % 2 == 1) else arith.constant(0, index=True) + _up_off = inter_idx if (ni % 2 == 1) else fx.Int32(0) bias_offset = expert_off_idx + _up_off + _logical_col bias_pf.append(_load_bias_scalar(bias_rsrc, bias_offset)) else: gate_bias_pf = [] up_bias_pf = [] if const_expr(not mock_gate_only) else None for ni in range_constexpr(num_acc_n): - global_n = by_n + n_tile_base + arith.constant(ni * 16, index=True) + lane_mod_16 + global_n = by_n + n_tile_base + fx.Int32(ni * 16) + lane_mod_16 gate_bias_pf.append(_load_bias_scalar(bias_rsrc, expert_off_idx + global_n)) if const_expr(not mock_gate_only): up_bias_pf.append( @@ -1089,10 +1089,10 @@ def compute_tile( tw_pf = None if const_expr(doweight_stage1): tw_pf = [] - lane_div_16_mul4_pf = lane_div_16 * arith.index(4) - ii_idx_list_pf = [arith.constant(ii, index=True) for ii in range(4)] + lane_div_16_mul4_pf = lane_div_16 * fx.Int32(4) + ii_idx_list_pf = [fx.Int32(ii) for ii in range(4)] for mi in range_constexpr(m_repeat): - mi_base_pf = arith.constant(mi * 16, index=True) + mi_base_pf = fx.Int32(mi * 16) for ii in range_constexpr(4): row_off_pf = lane_div_16_mul4_pf + ii_idx_list_pf[ii] sorted_row_pf = bx_m + mi_base_pf + row_off_pf @@ -1194,7 +1194,7 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): def load_a_subtile(k_idx, mi_idx, lds_buffer): """Load a single A sub-tile from LDS (one ds_read).""" col_base = col_offset_base + (k_idx * 128) // a_elem_vec_pack - mi_val = arith.constant(mi_idx * 16, index=True) + mi_val = fx.Int32(mi_idx * 16) curr_row = row_a_lds + mi_val a0, a1 = lds_load_packs_k64(curr_row, col_base, lds_buffer) if const_expr(is_f8_a): @@ -1309,9 +1309,9 @@ def _interleaved_half( Phase 1..N: B VMEM(distributed) + 2 ds_read(A, if avail) -> 4 MFMA(prev) Phase N+1..: remaining B VMEM -> 4 MFMA(prev) """ - _abs_k = k_base_idx + arith.constant(next_k_load, index=True) - _bk = _abs_k // arith.constant(2, index=True) - _sk = _abs_k // arith.constant(pack_K * 128, index=True) + _abs_k = k_base_idx + fx.Int32(next_k_load) + _bk = _abs_k // fx.Int32(2) + _sk = _abs_k // fx.Int32(pack_K * 128) _k_off = _sk * layout_b_scale.stride_k0 rocdl.sched_barrier(0) @@ -1320,7 +1320,7 @@ def _interleaved_half( rocdl.sched_barrier(0) # DMA A to OTHER buffer (for next half), non-blocking - _abs_k_dma = k_base_idx + arith.constant(next_k_dma_py, index=True) + _abs_k_dma = k_base_idx + fx.Int32(next_k_dma_py) if const_expr(use_async_copy and next_k_dma_py < int(_k_dim)): prefetch_x_to_lds(_abs_k_dma, lds_write) if const_expr(not use_async_copy): @@ -1526,9 +1526,9 @@ def _interleaved_half( x_regs0 = load_x_tile(k0) store_x_tile_to_lds(x_regs0, lds_x_pong) rocdl.sched_barrier(0) - _k0_scale = k_base_idx // arith.constant(pack_K * 128, index=True) + _k0_scale = k_base_idx // fx.Int32(pack_K * 128) a_scale_pong, gate_bs_pong, up_bs_pong = prefetch_ab_scale_tile(_k0_scale) - _c_tile_m_idx = arith.constant(tile_m, index=True) + _c_tile_m_idx = fx.Int32(tile_m) _tid_in_range = arith.cmpi(CmpIPredicate.ult, tx, _c_tile_m_idx) _if_tid = scf.IfOp(_tid_in_range) with ir.InsertionPoint(_if_tid.then_block): @@ -1541,7 +1541,7 @@ def _interleaved_half( acc_gate = [acc_init] * num_acc_n * m_repeat acc_up = [acc_init] * num_acc_n * m_repeat if not _single_b_pipe else None - _k1 = k_base_idx + arith.constant(tile_k, index=True) + _k1 = k_base_idx + fx.Int32(tile_k) rocdl.sched_barrier(0) if const_expr(use_async_copy): prefetch_x_to_lds(_k1, lds_x_ping) @@ -1549,7 +1549,7 @@ def _interleaved_half( _x_regs_prime = load_x_tile(_k1) store_x_tile_to_lds(_x_regs_prime, lds_x_ping) - _k0_b = k_base_idx // arith.constant(2, index=True) + _k0_b = k_base_idx // fx.Int32(2) gate_w0, up_w0 = load_b_tile(_k0_b) # Prime the deep pipeline: DMA K=tile_k -> ping (1 tile ahead) if const_expr(use_async_copy): @@ -1630,9 +1630,9 @@ def _interleaved_half( acc_up, ) - # _wave_mod2_b = wave_id % arith.constant(2, index=True) + # _wave_mod2_b = wave_id % fx.Int32(2) # _wave_odd = arith.cmpi( - # CmpIPredicate.eq, _wave_mod2_b, arith.constant(1, index=True) + # CmpIPredicate.eq, _wave_mod2_b, fx.Int32(1) # ) # _if_wave_odd = scf.IfOp(_wave_odd) # with ir.InsertionPoint(_if_wave_odd.then_block): @@ -1654,7 +1654,7 @@ def _interleaved_half( ku_count=_tail_ku if _pad_ku_skip > 0 else k_unroll, ) else: - _k_tail_rel = arith.constant(_k_dim - tile_k, index=True) + _k_tail_rel = fx.Int32(_k_dim - tile_k) k_tail1 = k_base_idx + _k_tail_rel x_regs_ping = [] if const_expr(use_async_copy): @@ -1663,17 +1663,17 @@ def _interleaved_half( x_regs_ping = load_x_tile(k_tail1) if const_expr(_pad_ku_skip > 0): gate_w_ping, up_w_ping = load_b_tile( - k_tail1 // arith.constant(2, index=True), + k_tail1 // fx.Int32(2), ku_limit=_tail_ku, ) a_scale_ping, gate_bs_ping, up_bs_ping = prefetch_ab_scale_tile( - k_tail1 // arith.constant(pack_K * 128, index=True), + k_tail1 // fx.Int32(pack_K * 128), ku_packed_limit=_tail_ku_packed, ) else: - gate_w_ping, up_w_ping = load_b_tile(k_tail1 // arith.constant(2, index=True)) + gate_w_ping, up_w_ping = load_b_tile(k_tail1 // fx.Int32(2)) a_scale_ping, gate_bs_ping, up_bs_ping = prefetch_ab_scale_tile( - k_tail1 // arith.constant(pack_K * 128, index=True) + k_tail1 // fx.Int32(pack_K * 128) ) acc_gate, acc_up, _ = compute_tile( acc_gate, @@ -1792,20 +1792,20 @@ def _act_vec4(gate_v4, up_v4): for _ni in range_constexpr(num_acc_n): if const_expr(gate_up_interleave): _logical_col = ( - (by_n + n_tile_base) // arith.constant(2, index=True) - + arith.constant((_ni // 2) * 16, index=True) + (by_n + n_tile_base) // fx.Int32(2) + + fx.Int32((_ni // 2) * 16) + lane_mod_16 ) - _up_off = inter_idx if (_ni % 2 == 1) else arith.constant(0, index=True) + _up_off = inter_idx if (_ni % 2 == 1) else fx.Int32(0) _bias_off = expert_off_idx + _up_off + _logical_col else: - _bn = by_n + n_tile_base + arith.constant(_ni * 16, index=True) + lane_mod_16 + _bn = by_n + n_tile_base + fx.Int32(_ni * 16) + lane_mod_16 _bias_off = expert_off_idx + _bn _bias_gate_vals.append(_load_bias_scalar(bias_rsrc, _bias_off)) if const_expr(not (mock_gate_only or gate_up_interleave)): _bias_up_vals = [] for _ni in range_constexpr(num_acc_n): - _bn = by_n + n_tile_base + arith.constant(_ni * 16, index=True) + lane_mod_16 + _bn = by_n + n_tile_base + fx.Int32(_ni * 16) + lane_mod_16 _bias_up_vals.append(_load_bias_scalar(bias_rsrc, expert_off_idx + inter_idx + _bn)) for _mi in range_constexpr(m_repeat): for _ni in range_constexpr(num_acc_n): @@ -1912,8 +1912,8 @@ def precompute_row(*, row_local, row): row_valid = arith.andi(row_valid0, arith.andi(t_ok, s_ok)) t_idx = arith.index_cast(ir.IndexType.get(), t) s_idx = arith.index_cast(ir.IndexType.get(), s) - ts_idx = t_idx * arith.constant(topk, index=True) + s_idx - row_byte_base = out_base_idx + ts_idx * arith.constant(_out_row_stride, index=True) + ts_idx = t_idx * fx.Int32(topk) + s_idx + row_byte_base = out_base_idx + ts_idx * fx.Int32(_out_row_stride) return ((fused2, row_byte_base), row_valid) def _idx_to_llvm_ptr(idx_val, addr_space=1): @@ -2027,7 +2027,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): byte_k = fp4_vals[2 * k] | (fp4_vals[2 * k + 1] << _c4_i32) packed_i32 = packed_i32 | (byte_k << arith.constant(k * 8, type=T.i32)) - ptr_addr_idx = row_byte_base + col_g0 / arith.constant(2, index=True) + ptr_addr_idx = row_byte_base + col_g0 / fx.Int32(2) out_ptr_v = _idx_to_llvm_ptr(ptr_addr_idx) _pack_bytes = _e_vec // 2 if const_expr(_pack_bytes == 1): @@ -2094,7 +2094,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): packed_w, 1, ) - word_ptr = ptr_addr_idx + arith.constant(_wg * 4, index=True) + word_ptr = ptr_addr_idx + fx.Int32(_wg * 4) out_ptr_v = _idx_to_llvm_ptr(word_ptr) packed_raw = packed_w._value if hasattr(packed_w, "_value") else packed_w llvm.StoreOp( @@ -2129,8 +2129,8 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): ) scf.YieldOp([]) elif const_expr(_is_splitk): - col_idx = col_g0 + arith.constant(_sk_n_offset[0], index=True) - byte_off_col = col_idx * arith.constant(out_elem_bytes, index=True) + col_idx = col_g0 + fx.Int32(_sk_n_offset[0]) + byte_off_col = col_idx * fx.Int32(out_elem_bytes) ptr_addr_idx = row_byte_base + byte_off_col out_ptr_v = _idx_to_llvm_ptr(ptr_addr_idx) frag_v = frag._value if hasattr(frag, "_value") else frag @@ -2144,7 +2144,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): ) else: col_idx = col_g0 - byte_off_col = col_idx * arith.constant(out_elem_bytes, index=True) + byte_off_col = col_idx * fx.Int32(out_elem_bytes) ptr_addr_idx = row_byte_base + byte_off_col out_ptr_v = _idx_to_llvm_ptr(ptr_addr_idx) frag_v = frag._value if hasattr(frag, "_value") else frag @@ -2164,8 +2164,8 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): _gui_eff_n = _gui_out_n _gui_tile_n = tile_n // 2 _gui_cshuffle_nlane = min(32, _gui_tile_n // _e_vec) - _gui_by_n = by_n / arith.constant(2, index=True) - _gui_n_tile_base = n_tile_base / arith.constant(2, index=True) + _gui_by_n = by_n / fx.Int32(2) + _gui_n_tile_base = n_tile_base / fx.Int32(2) c_shuffle_epilog( arith=arith, vector=vector, @@ -2378,17 +2378,17 @@ def launch_mixed_moe_gemm1( allocator_ping.finalize() inter_in = arith.index_cast(ir.IndexType.get(), i32_inter_in.ir_value()) - tile_n_index = arith.constant(tile_n, index=True) - inter_dim_pad_total = arith.constant(2 * inter_dim_pad, index=True) + tile_n_index = fx.Int32(tile_n) + inter_dim_pad_total = fx.Int32(2 * inter_dim_pad) if const_expr(mock_gate_only or gate_up_interleave): gx = (inter_in - inter_dim_pad_total + tile_n_index - 1) / tile_n_index else: - gx = (inter_in - inter_dim_pad_total + 2 * tile_n_index - 1) / tile_n_index / arith.constant(2, index=True) - _c_pm_l = arith.constant(persist_m, index=True) + gx = (inter_in - inter_dim_pad_total + 2 * tile_n_index - 1) / tile_n_index / fx.Int32(2) + _c_pm_l = fx.Int32(persist_m) gy = ( arith.index_cast(ir.IndexType.get(), i32_size_expert_ids_in.ir_value()) + _c_pm_l - - arith.constant(1, index=True) + - fx.Int32(1) ) / _c_pm_l moe_gemm1( @@ -2697,11 +2697,11 @@ def moe_gemm2( acc_init = arith.constant_vector(0, vec4_i32) if is_int8 else arith.constant_vector(0.0, vec4_f32) # A2 layout (flatten token-slot -> M; use i32 for fly.make_shape). - topk_idx = arith.constant(topk, index=True) + topk_idx = fx.Int32(topk) m_in = tokens_in * topk_idx # B preshuffle layout: [experts*model_dim, inter_dim] - c_n_total = arith.constant(experts * model_dim, index=True) + c_n_total = fx.Int32(experts * model_dim) kpack_bytes = 8 if is_int4 else 16 from .layout_utils import _div_pow2, _mod_pow2 @@ -2713,7 +2713,7 @@ def check_c_k_valid_gate(base_k): # A&B's scale preshuffle layout # For fp4, k_in is already packed (inter_dim // a_elem_vec_pack), so we need original inter_dim - c_k_orig = arith.constant(inter_dim, index=True) + c_k_orig = fx.Int32(inter_dim) layout_a_scale = make_preshuffle_scale_layout(arith, c_mn=m_in, c_k=c_k_orig) layout_b_scale = make_preshuffle_scale_layout(arith, c_mn=c_n_total, c_k=c_k_orig) @@ -2727,25 +2727,25 @@ def check_c_k_valid_gate(base_k): if const_expr(xcd_swizzle > 0): _NUM_XCDS_S = 8 - _c1_sw = arith.constant(1, index=True) - _c_tn_sw = arith.constant(tile_n, index=True) - _c_mdp_sw = arith.constant(model_dim_pad, index=True) + _c1_sw = fx.Int32(1) + _c_tn_sw = fx.Int32(tile_n) + _c_mdp_sw = fx.Int32(model_dim_pad) _gx = (n_in - _c_mdp_sw + _c_tn_sw - _c1_sw) / _c_tn_sw if const_expr(_persistent): - _gy = arith.constant(_cu_num, index=True) + _gy = fx.Int32(_cu_num) else: - _c_pm_sw = arith.constant(persist_m, index=True) + _c_pm_sw = fx.Int32(persist_m) _gy = (size_expert_ids_in + _c_pm_sw - _c1_sw) / _c_pm_sw _linear_id = bx_persist * _gx + by _num_wgs = _gx * _gy - _c_xcds = arith.constant(_NUM_XCDS_S, index=True) + _c_xcds = fx.Int32(_NUM_XCDS_S) _wgs_per_xcd = _num_wgs / _c_xcds _wgid = (_linear_id % _c_xcds) * _wgs_per_xcd + (_linear_id / _c_xcds) _WGM_S = xcd_swizzle - _c_wgm = arith.constant(_WGM_S, index=True) + _c_wgm = fx.Int32(_WGM_S) _num_wgid_in_group = _c_wgm * _gx _group_id = _wgid / _num_wgid_in_group _first_pid_m = _group_id * _c_wgm @@ -2758,7 +2758,7 @@ def check_c_k_valid_gate(base_k): by = _wgid_in_group / _group_size_m # XOR16 swizzle parameter (in bytes; constant, power-of-two in our configs). - k_blocks16 = arith.constant(_eff_tile_k_bytes // 16, index=True) + k_blocks16 = fx.Int32(_eff_tile_k_bytes // 16) layout_tx_wave_lane = fx.make_layout((4, 64), stride=(64, 1)) layout_lane16 = fx.make_layout((4, 16), stride=(16, 1)) @@ -2781,12 +2781,12 @@ def check_c_k_valid_gate(base_k): # Buffer resources. # For dynamic memrefs, `max_size=False` cannot infer the logical size from the memref *type*, # so we should pass `num_records_bytes` explicitly for stable hardware OOB behavior. - c_topk = arith.constant(topk, index=True) + c_topk = fx.Int32(topk) # X(A2): buffer size in bytes, accounting for FP4 packing (2 elements per byte). # fp8/int8: 1 byte per element -> bytes = tokens*topk * K # fp4: 2 elements per byte -> bytes = tokens*topk * K / 2 - c_elem_bytes = arith.constant(int(a_elem_bytes), index=True) + c_elem_bytes = fx.Int32(int(a_elem_bytes)) x_nbytes_idx = _div_pow2((tokens_in * c_topk) * k_in * c_elem_bytes, int(a_elem_vec_pack)) x_nbytes_i32 = arith.index_cast(T.i32, x_nbytes_idx) x_rsrc = buffer_ops.create_buffer_resource(arg_x, max_size=False, num_records_bytes=x_nbytes_i32) @@ -2795,9 +2795,9 @@ def check_c_k_valid_gate(base_k): # OUT: [tokens, model_dim] -> clamp to descriptor max (i32 bytes) to avoid overflow on huge tokens. out_elem_bytes = 4 if out_is_f32 else 2 - out_nbytes_idx = tokens_in * n_in * arith.constant(out_elem_bytes, index=True) + out_nbytes_idx = tokens_in * n_in * fx.Int32(out_elem_bytes) if const_expr(not bool(accumulate)): - out_nbytes_idx = tokens_in * arith.index(topk) * n_in * arith.constant(out_elem_bytes, index=True) + out_nbytes_idx = tokens_in * fx.Int32(topk) * n_in * fx.Int32(out_elem_bytes) out_nbytes_i32 = arith.index_cast(T.i32, out_nbytes_idx) out_rsrc = buffer_ops.create_buffer_resource(arg_out, max_size=False, num_records_bytes=out_nbytes_i32) @@ -2807,7 +2807,7 @@ def check_c_k_valid_gate(base_k): max_size=False, num_records_bytes=arith.constant(4, type=T.i32), ) - num_valid_i32 = buffer_ops.buffer_load(numids_rsrc, arith.constant(0, index=True), vec_width=1, dtype=T.i32) + num_valid_i32 = buffer_ops.buffer_load(numids_rsrc, fx.Int32(0), vec_width=1, dtype=T.i32) # num_valid_ids is a scalar (same value for all lanes) loaded into # VGPR. Promote to SGPR so downstream buffer resource descriptors # that use it for num_records stay in SGPRs, eliminating the @@ -2830,7 +2830,7 @@ def check_c_k_valid_gate(base_k): ) else: # scale_x (A2 scale): [tokens*topk] f32 -> bytes = tokens*topk*4 - sx_nbytes_idx = (tokens_in * c_topk) * arith.constant(4, index=True) + sx_nbytes_idx = (tokens_in * c_topk) * fx.Int32(4) sx_nbytes_i32 = arith.index_cast(T.i32, sx_nbytes_idx) sx_rsrc = buffer_ops.create_buffer_resource( arg_scale_x, max_size=False, num_records_bytes=sx_nbytes_i32 @@ -2840,7 +2840,7 @@ def check_c_k_valid_gate(base_k): # Weight microscale buffer (packed i32 holding e8m0 bytes). # Use an exact descriptor size so hardware OOB checking works. kblk_w = _div_pow2(k_in, 32) # K/32 - mn_w = arith.constant(experts * model_dim, index=True) + mn_w = fx.Int32(experts * model_dim) sw_nbytes_idx = mn_w * kblk_w # bytes (e8m0) sw_nbytes_i32 = arith.index_cast(T.i32, sw_nbytes_idx) sw_rsrc = buffer_ops.create_buffer_resource( @@ -2848,7 +2848,7 @@ def check_c_k_valid_gate(base_k): ) # sorted_token_ids / sorted_weights: [blocks*tile_m] (padded length) - sorted_nbytes_idx = size_expert_ids_in * arith.constant(tile_m, index=True) * arith.constant(4, index=True) + sorted_nbytes_idx = size_expert_ids_in * fx.Int32(tile_m) * fx.Int32(4) sorted_nbytes_i32 = arith.index_cast(T.i32, sorted_nbytes_idx) sorted_rsrc = buffer_ops.create_buffer_resource( arg_sorted_token_ids, @@ -2860,11 +2860,11 @@ def check_c_k_valid_gate(base_k): ) # expert ids: [sort_blocks] i32. - _c_sbm = arith.constant(_sort_block_m, index=True) - _c_tm = arith.constant(tile_m, index=True) - _c1 = arith.constant(1, index=True) + _c_sbm = fx.Int32(_sort_block_m) + _c_tm = fx.Int32(tile_m) + _c1 = fx.Int32(1) _sort_blocks_ub = _div_pow2(size_expert_ids_in * _c_tm + _c_sbm - _c1, _sort_block_m) - eid_nbytes_idx = _sort_blocks_ub * arith.constant(4, index=True) + eid_nbytes_idx = _sort_blocks_ub * fx.Int32(4) eid_nbytes_i32 = arith.index_cast(T.i32, eid_nbytes_idx) expert_rsrc = buffer_ops.create_buffer_resource( arg_expert_ids, max_size=False, num_records_bytes=eid_nbytes_i32 @@ -2872,16 +2872,16 @@ def check_c_k_valid_gate(base_k): bias_rsrc = buffer_ops.create_buffer_resource(arg_bias, max_size=False) if enable_bias else None # ---- persist loop ---- - _c0_p = arith.constant(0, index=True) - _c1_p = arith.constant(1, index=True) + _c0_p = fx.Int32(0) + _c1_p = fx.Int32(1) if const_expr(_persistent): # Expert-phase scheduling: contiguous M-tile dispatch. # grid_y = cu_num, each CTA handles a contiguous chunk of M-tiles: # [bx_persist * tiles_per_block, ..., (bx_persist+1) * tiles_per_block - 1] # Adjacent blocks process adjacent M-tiles -> same expert -> B weight L2 reuse. - _c_cu = arith.constant(_cu_num, index=True) - _c_tm_p = arith.constant(tile_m, index=True) + _c_cu = fx.Int32(_cu_num) + _c_tm_p = fx.Int32(tile_m) _num_valid_idx = arith.index_cast(ir.IndexType.get(), num_valid_i32) _total_m_tiles = (_num_valid_idx + _c_tm_p - _c1_p) / _c_tm_p _tiles_per_block = (_total_m_tiles + _c_cu - _c1_p) / _c_cu @@ -2890,9 +2890,9 @@ def check_c_k_valid_gate(base_k): _for_persist = scf.ForOp(_c0_p, _tiles_per_block, _c1_p, [_init_active]) else: # Legacy mode: fixed persist_m consecutive tiles. - _c_pm = arith.constant(persist_m, index=True) + _c_pm = fx.Int32(persist_m) _init_prev_expert = arith.constant(0, type=T.i32) - _init_prev_b_base = arith.constant(0, index=True) + _init_prev_b_base = fx.Int32(0) _for_persist = scf.ForOp( _c0_p, _c_pm, @@ -2910,9 +2910,9 @@ def check_c_k_valid_gate(base_k): else: _prev_expert_i32 = _for_persist.inner_iter_args[0] _prev_expert_b_base = _for_persist.inner_iter_args[1] - bx = bx_persist * arith.constant(persist_m, index=True) + _mi_p + bx = bx_persist * fx.Int32(persist_m) + _mi_p - bx_m = bx * arith.constant(tile_m, index=True) + bx_m = bx * fx.Int32(tile_m) # Early-exit guard: skip garbage expert blocks beyond `num_valid_ids`. bx_m_i32 = arith.index_cast(T.i32, bx_m) @@ -2925,12 +2925,12 @@ def check_c_k_valid_gate(base_k): if const_expr(_persistent): # Absolute B-base: no cross-iteration state needed. - _expert_b_base = expert_idx * arith.constant(_expert_b_stride, index=True) + _expert_b_base = expert_idx * fx.Int32(_expert_b_stride) else: # Legacy incremental B-base: delta = (cur - prev) * stride _delta_expert = arith.subi(expert_i32, _prev_expert_i32) _delta_expert_idx = arith.index_cast(ir.IndexType.get(), _delta_expert) - _delta_b = _delta_expert_idx * arith.constant(_expert_b_stride, index=True) + _delta_b = _delta_expert_idx * fx.Int32(_expert_b_stride) _expert_b_base = _prev_expert_b_base + _delta_b # Early-exit: if the first row of this tile is a sentinel (all-padding tile), @@ -2944,13 +2944,13 @@ def check_c_k_valid_gate(base_k): # correct bytes land at the op_sel positions we use. if const_expr(pack_M < _scale_pack_m): _m_off = _mod_pow2(_div_pow2(bx_m, 16), _scale_pack_m) - _m_scale_shift_i32 = arith.index_cast(T.i32, _m_off * arith.constant(8, index=True)) + _m_scale_shift_i32 = arith.index_cast(T.i32, _m_off * fx.Int32(8)) else: _m_scale_shift_i32 = None def _moe_gemm2_then_body(): # Expert id for this M tile. - n_idx = arith.constant(model_dim, index=True) + n_idx = fx.Int32(model_dim) expert_off_idx = expert_idx * n_idx # index # ---- X gmem->reg prefetch (match preshuffle GEMM mapping) ---- @@ -2976,12 +2976,12 @@ def _moe_gemm2_then_body(): vec4_i32 = T.vec(4, i32) c_k_div4 = _div_pow2( - _div_pow2(k_in, int(a_elem_vec_pack)) * arith.constant(int(a_elem_bytes), index=True), + _div_pow2(k_in, int(a_elem_vec_pack)) * fx.Int32(int(a_elem_bytes)), 4, ) tile_k_dwords = (int(tile_k) * int(a_elem_bytes)) // (4 * int(a_elem_vec_pack)) layout_x_tile_div4 = fx.make_layout((tile_m, tile_k_dwords), stride=(tile_k_dwords, 1)) - c_chunk_i32 = arith.constant(chunk_i32, index=True) + c_chunk_i32 = fx.Int32(chunk_i32) tx_i32_base = tx * c_chunk_i32 topk_i32 = arith.constant(topk) @@ -3009,7 +3009,7 @@ def load_x(idx_i32): For 16B, keep the fast dwordx4 path. For 8B/4B, use byte offsets. """ if const_expr(x_load_bytes == 16): - idx_elem = idx_i32 if a_elem_bytes == 1 else (idx_i32 * arith.index(2)) + idx_elem = idx_i32 if a_elem_bytes == 1 else (idx_i32 * fx.Int32(2)) return buffer_copy_gmem16_dwordx4( buffer_ops, vector, @@ -3019,7 +3019,7 @@ def load_x(idx_i32): vec_elems=vec16_elems, ) # 8B/4B: convert dword index to byte offset and use offset_in_bytes path. - idx_bytes = idx_i32 * arith.index(4) + idx_bytes = idx_i32 * fx.Int32(4) return _buffer_load_vec( buffer_ops, vector, @@ -3057,7 +3057,7 @@ def load_x(idx_i32): def load_x_tile(base_k): base_k_div4 = _div_pow2( - _div_pow2(base_k, int(a_elem_vec_pack)) * arith.constant(int(a_elem_bytes), index=True), + _div_pow2(base_k, int(a_elem_vec_pack)) * fx.Int32(int(a_elem_bytes)), 4, ) parts = [] @@ -3083,22 +3083,22 @@ def load_x_tile(base_k): row_a_lds = lane_mod_16 - col_offset_base = lane_div_16 * arith.constant(16, index=True) + col_offset_base = lane_div_16 * fx.Int32(16) # Dynamic N tiling within block. num_waves = 4 n_per_wave = tile_n // num_waves num_acc_n = n_per_wave // 16 - c_n_per_wave = arith.constant(n_per_wave, index=True) + c_n_per_wave = fx.Int32(n_per_wave) wave_mod_4 = _mod_pow2(wave_id, 4) n_tile_base = wave_mod_4 * c_n_per_wave - by_n = by * arith.constant(tile_n, index=True) + by_n = by * fx.Int32(tile_n) if const_expr(pack_N < _scale_pack_n): _global_n_base = expert_off_idx + by_n + n_tile_base _n_off = _mod_pow2(_div_pow2(_global_n_base, 16), _scale_pack_n) - _n_scale_shift_i32 = arith.index_cast(T.i32, _n_off * arith.constant(8, index=True)) + _n_scale_shift_i32 = arith.index_cast(T.i32, _n_off * fx.Int32(8)) else: _n_scale_shift_i32 = None n_intra_list = [None] * num_acc_n @@ -3110,7 +3110,7 @@ def load_x_tile(base_k): col_g = _div_pow2(col_g, 2) + offset col_g = col_g + lane_mod_16 col_g_list[i] = col_g - c_offset = arith.constant(offset, index=True) + c_offset = fx.Int32(offset) global_n = by_n + n_tile_base + c_offset + lane_mod_16 n_blk_list[i] = _div_pow2(global_n, 16) n_intra_list[i] = _mod_pow2(global_n, 16) @@ -3132,9 +3132,9 @@ def load_x_tile(base_k): # --- B Load Logic (K64) - shared layout with preshuffle GEMM --- def load_b_packs_k64(base_k, ku: int, ni: int): """Load one K64-byte B micro-step: single 16B load, split into 2x i64.""" - base_k_bytes = base_k * arith.constant(int(b_elem_bytes), index=True) + base_k_bytes = base_k * fx.Int32(int(b_elem_bytes)) k0_base = _div_pow2(base_k_bytes, 64) - k0 = k0_base + arith.constant(ku, index=True) + k0 = k0_base + fx.Int32(ku) k1 = lane_div_16 # Incremental B addressing: _expert_b_base carries the # expert's preshuffle offset (updated via delta each @@ -3143,10 +3143,10 @@ def load_b_packs_k64(base_k, ku: int, ni: int): # compile-time constants -> shift/mul, no Barrett. idx_pack = ( _expert_b_base - + n_blk_list[ni] * arith.constant(_b_stride_n0, index=True) - + k0 * arith.constant(_b_stride_k0, index=True) - + k1 * arith.constant(_b_stride_klane, index=True) - + n_intra_list[ni] * arith.constant(_b_stride_nlane, index=True) + + n_blk_list[ni] * fx.Int32(_b_stride_n0) + + k0 * fx.Int32(_b_stride_k0) + + k1 * fx.Int32(_b_stride_klane) + + n_intra_list[ni] * fx.Int32(_b_stride_nlane) ) vec_elems = kpack_bytes // int(b_elem_bytes) @@ -3274,7 +3274,7 @@ def prefetch_ab_scale_tile(base_k, k_shift_bits=0, ku_packed_limit=k_unroll_pack vec4_x_lds = T.vec(vec4_elems, x_elem) # ---- Pipeline helpers: store X tile to LDS (unused in DMA path) ---- - _lds_base_zero = arith.index(0) + _lds_base_zero = fx.Int32(0) def store_x_tile_to_lds(vec_x_in_parts, lds_buffer): for i in range_constexpr(num_x_loads): @@ -3289,7 +3289,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_buffer): layout_lds=layout_lds, row_local=row_local, col_local_i32=col_local_i32, - tx_c4=arith.index(4), + tx_c4=fx.Int32(4), k_blocks16=k_blocks16, lds_base=_lds_base_zero, vec_part_i32x4=vec_x_in_parts[i], @@ -3304,7 +3304,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_buffer): layout_lds=layout_lds, row_local=row_local, col_local_i32=col_local_i32, - tx_c4=arith.index(4), + tx_c4=fx.Int32(4), k_blocks16=k_blocks16, lds_base=_lds_base_zero, vec_part_i32x2=vec_x_in_parts[i], @@ -3319,7 +3319,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_buffer): layout_lds=layout_lds, row_local=row_local, col_local_i32=col_local_i32, - tx_c4=arith.index(4), + tx_c4=fx.Int32(4), k_blocks16=k_blocks16, lds_base=_lds_base_zero, vec_part_i32x1=vec_x_in_parts[i], @@ -3329,7 +3329,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_buffer): # --- A LDS load helper for K64 (load 16B once, extract 2x i64 halves) --- def lds_load_packs_k64(curr_row_a_lds, col_base, lds_buffer): col_base_swz_bytes = swizzle_xor16(curr_row_a_lds, col_base, k_blocks16) - col_base_swz = col_base_swz_bytes if elem_bytes == 1 else (col_base_swz_bytes / arith.index(2)) + col_base_swz = col_base_swz_bytes if elem_bytes == 1 else (col_base_swz_bytes / fx.Int32(2)) idx_a16 = crd2idx([fx.Int32(curr_row_a_lds), fx.Int32(col_base_swz)], layout_lds) loaded_a16 = vector.load_op(vec16_x, lds_buffer, [idx_a16]) a_i64x2 = vector.bitcast(vec2_i64, loaded_a16) @@ -3371,10 +3371,10 @@ def compute_tile( tw_pf = None if const_expr(doweight_stage2): tw_pf = [] - lane_div_16_mul4_pf = lane_div_16 * arith.index(4) - ii_idx_list_pf = [arith.constant(ii, index=True) for ii in range(4)] + lane_div_16_mul4_pf = lane_div_16 * fx.Int32(4) + ii_idx_list_pf = [fx.Int32(ii) for ii in range(4)] for mi in range_constexpr(m_repeat): - mi_base_pf = arith.constant(mi * 16, index=True) + mi_base_pf = fx.Int32(mi * 16) for ii in range_constexpr(4): row_off_pf = lane_div_16_mul4_pf + ii_idx_list_pf[ii] row_in_tile_pf = mi_base_pf + row_off_pf @@ -3434,7 +3434,7 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): for imxdl in range_constexpr(pack_M): col_base0 = col_base mi_idx = mi * pack_M + imxdl - mi_val = arith.constant(mi_idx * 16, index=True) + mi_val = fx.Int32(mi_idx * 16) curr_row_a_lds = row_a_lds + mi_val if const_expr((a0_prefetch is not None) and (k_idx == 0) and (mi_idx == 0)): @@ -3484,9 +3484,9 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): _num_dma_loads = max(1, _eff_bytes_per_buffer // (total_threads * _dma_bytes)) def dma_x_tile_to_lds(base_k, lds_buffer): - c4_idx = arith.index(4) + c4_idx = fx.Int32(4) base_k_div4 = _div_pow2( - _div_pow2(base_k, int(a_elem_vec_pack)) * arith.constant(int(a_elem_bytes), index=True), + _div_pow2(base_k, int(a_elem_vec_pack)) * fx.Int32(int(a_elem_bytes)), 4, ) @@ -3538,7 +3538,7 @@ def _k_base(k_py): # Preload sorted_idx into lds_tid for epilogue precompute_row # (N-independent; placed before N-tile loop so it's done once per M-tile.) - _c_tile_m_idx = arith.constant(tile_m, index=True) + _c_tile_m_idx = fx.Int32(tile_m) _tid_in_range = arith.cmpi(CmpIPredicate.ult, tx, _c_tile_m_idx) _if_tid = scf.IfOp(_tid_in_range) with ir.InsertionPoint(_if_tid.then_block): @@ -3551,7 +3551,7 @@ def _k_base(k_py): gpu.barrier() # Prologue -- B-first + async DMA X(0) -> pong. - k0 = arith.index(0) + k0 = fx.Int32(0) if const_expr(_b_split_enabled): b_cur = load_b_tile_lo(k0) else: @@ -3581,7 +3581,7 @@ def _k_base(k_py): if const_expr(k_main2_py < 0): k_main2_py = 0 - c2_tile_k = arith.constant(tile_k * 2, index=True) + c2_tile_k = fx.Int32(tile_k * 2) b_pong = b_cur k0_pong_bk = k0 @@ -3596,7 +3596,7 @@ def _make_b_hi_loader(base_k): if const_expr(k_main2_py > 0): for k_iv_py in range_constexpr(0, k_main2_py, tile_k * 2): rocdl.sched_barrier(0) - k_iv = arith.index(k_iv_py) + k_iv = fx.Int32(k_iv_py) next_k1 = k_iv + tile_k next_k1_bk = next_k1 // 2 # DMA X(next_k1) -> ping (non-blocking, overlaps with compute) @@ -3828,11 +3828,11 @@ def precompute_row(*, row_local, row): row_valid = arith.andi(row_valid0, arith.andi(t_ok, s_ok)) t_idx = arith.index_cast(ir.IndexType.get(), t) s_idx = arith.index_cast(ir.IndexType.get(), s) - ts_idx = t_idx * arith.constant(topk, index=True) + s_idx + ts_idx = t_idx * fx.Int32(topk) + s_idx if const_expr(accumulate): - row_byte_base = out_base_idx + t_idx * arith.constant(model_dim * out_elem_bytes, index=True) + row_byte_base = out_base_idx + t_idx * fx.Int32(model_dim * out_elem_bytes) else: - row_byte_base = out_base_idx + ts_idx * arith.constant(model_dim * out_elem_bytes, index=True) + row_byte_base = out_base_idx + ts_idx * fx.Int32(model_dim * out_elem_bytes) return ((fused2, row_byte_base), row_valid) def _idx_to_llvm_ptr(idx_val, addr_space=1): @@ -3848,7 +3848,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): if const_expr(not bool(accumulate)): # ---- 64-bit global store path (avoids i32 offset overflow) ---- col_idx = col_g0 - byte_off_col = col_idx * arith.constant(out_elem_bytes, index=True) + byte_off_col = col_idx * fx.Int32(out_elem_bytes) ptr_addr_idx = row_byte_base + byte_off_col out_ptr_v = _idx_to_llvm_ptr(ptr_addr_idx) frag_v = frag._value if hasattr(frag, "_value") else frag @@ -3861,7 +3861,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): else: # ---- accumulate=True: 64-bit global atomic path ---- col_idx = col_g0 - byte_off_col = col_idx * arith.constant(out_elem_bytes, index=True) + byte_off_col = col_idx * fx.Int32(out_elem_bytes) ptr_addr_idx = row_byte_base + byte_off_col out_ptr_v = _idx_to_llvm_ptr(ptr_addr_idx) frag_v = frag._value if hasattr(frag, "_value") else frag @@ -3971,17 +3971,17 @@ def launch_mixed_moe_gemm2( allocator_ping.finalize() n_in = arith.index_cast(ir.IndexType.get(), i32_n_in.ir_value()) - _tile_n_idx = arith.constant(tile_n, index=True) - _model_dim_pad_idx = arith.constant(model_dim_pad, index=True) - gx = (n_in - _model_dim_pad_idx + _tile_n_idx - arith.constant(1, index=True)) / _tile_n_idx + _tile_n_idx = fx.Int32(tile_n) + _model_dim_pad_idx = fx.Int32(model_dim_pad) + gx = (n_in - _model_dim_pad_idx + _tile_n_idx - fx.Int32(1)) / _tile_n_idx if const_expr(_persistent): - gy = arith.constant(_cu_num, index=True) + gy = fx.Int32(_cu_num) else: - _c_pm_l = arith.constant(persist_m, index=True) + _c_pm_l = fx.Int32(persist_m) gy = ( arith.index_cast(ir.IndexType.get(), i32_size_expert_ids_in.ir_value()) + _c_pm_l - - arith.constant(1, index=True) + - fx.Int32(1) ) / _c_pm_l moe_gemm2(