diff --git a/kernels/blockscale_preshuffle_gemm.py b/kernels/blockscale_preshuffle_gemm.py index 2371d9e81..648e06c54 100644 --- a/kernels/blockscale_preshuffle_gemm.py +++ b/kernels/blockscale_preshuffle_gemm.py @@ -203,12 +203,12 @@ def kernel_gemm( # ---- Wave / lane decomposition ---- wave_size = 64 layout_wave_lane = fx.make_layout((4, wave_size), (64, 1)) - coord_wave_lane = fx.idx2crd(tx, layout_wave_lane) + coord_wave_lane = fx.idx2crd(fx.Int32(tx), layout_wave_lane) wave_id = fx.get(coord_wave_lane, 0) lane_id = fx.get(coord_wave_lane, 1) layout_lane16 = fx.make_layout((4, 16), (16, 1)) - coord_lane16 = fx.idx2crd(lane_id, layout_lane16) + coord_lane16 = fx.idx2crd(fx.Int32(lane_id), layout_lane16) lane_div_16 = fx.get(coord_lane16, 0) lane_mod_16 = fx.get(coord_lane16, 1) @@ -252,8 +252,8 @@ def load_b_packs_k64(base_k, ku: int, ni: int): k0_base = base_k_bytes // c64_b k0 = k0_base + ku k1 = lane_div_16 - coord_pack = (n_blk_list[ni], k0, k1, n_intra_list[ni], fx.Index(0)) - idx_pack = crd2idx(coord_pack, layout_b) + coord_pack = (n_blk_list[ni], k0, k1, n_intra_list[ni], fx.Int32(0)) + idx_pack = crd2idx(tuple(fx.Int32(c) for c in coord_pack), layout_b) b16 = _buffer_load_vec( buffer_ops, vector, diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index ee09dc7a4..5d5ba55f8 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -405,7 +405,7 @@ def kernel_mxscale_gemm( b_mcast_mask = 0 layout_thr = fx.make_layout((m_warp, n_warp, 2, 16), (n_warp * WAVE_SIZE, WAVE_SIZE, 16, 1)) - thr_coord = idx2crd(tx, layout_thr) + thr_coord = idx2crd(fx.Int32(tx), layout_thr) wave_m_idx, wave_n_idx, lane_kgrp, lane16 = ( fx.get(thr_coord, 0), fx.get(thr_coord, 1), diff --git a/kernels/layernorm_kernel.py b/kernels/layernorm_kernel.py index ffc3530ad..b0dcdb7fc 100644 --- a/kernels/layernorm_kernel.py +++ b/kernels/layernorm_kernel.py @@ -720,7 +720,7 @@ def _load_norm_input_value(index): mean = sum_val / n_float var = sumsq_val / n_float - mean * mean var = (var < c_zero_f).select(c_zero_f, var) - rstd = (var + eps_c).rsqrt(fastmath=fm_fast) + rstd = fmath.rsqrt(var + eps_c, fastmath=fm_fast) thread_row_max = c_zero_f for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): diff --git a/kernels/layout_utils.py b/kernels/layout_utils.py index 976996c06..0adc68eb9 100644 --- a/kernels/layout_utils.py +++ b/kernels/layout_utils.py @@ -86,12 +86,15 @@ def idx2crd(idx, layout): """ parsed = _parse_layout(layout) + if hasattr(idx, "ir_value"): + idx = idx.ir_value() + if parsed is None or _has_dynamic_strides(parsed[1]): - result = fx.idx2crd(idx, layout) + result = fx.idx2crd(fx.Int32(idx), layout) ndims = len(parsed[1]) if parsed else 1 return [_wrap(fx.get(result, i)) for i in range(ndims)] - if hasattr(idx, "type") and str(idx.type) != "index": + if isinstance(idx, ir.Value) and not isinstance(idx.type, ir.IndexType): idx = arith.index_cast(T.index, idx) shapes, strides = parsed ndims = len(strides) @@ -156,9 +159,8 @@ def crd2idx(crd, layout): cv = raw crd_i32.append(cv) coord_val = fx.make_coord(*crd_i32) - result = fx.crd2idx(coord_val, layout) - scalar = fx.get_scalar(result) - if isinstance(scalar, ir.Value) and not isinstance(scalar.type, ir.IndexType): + scalar = fx.get_scalar(fx.crd2idx(coord_val, layout)).ir_value() + if not isinstance(scalar.type, ir.IndexType): scalar = arith.index_cast(T.index, scalar) return _wrap(scalar) diff --git a/kernels/mfma_preshuffle_pipeline.py b/kernels/mfma_preshuffle_pipeline.py index 118ba6703..c556ee970 100644 --- a/kernels/mfma_preshuffle_pipeline.py +++ b/kernels/mfma_preshuffle_pipeline.py @@ -20,12 +20,11 @@ def crd2idx(crd, layout): - """crd2idx returning an index-type scalar (unwraps fly.int_tuple).""" - result = fx.crd2idx(crd, layout) - scalar = fx.get_scalar(result) - if isinstance(scalar, ir.Value) and not isinstance(scalar.type, ir.IndexType): - scalar = _arith.IndexCastOp(T.index, scalar).result - return scalar + """crd2idx returning an index-typed ir.Value (unwraps fly.int_tuple).""" + scalar = fx.get_scalar(fx.crd2idx(crd, layout)).ir_value() + if isinstance(scalar.type, ir.IndexType): + return scalar + return _arith.IndexCastOp(T.index, scalar).result def swizzle_xor16(row, col, k_blocks16): @@ -326,7 +325,7 @@ def load_b_raw_w4a16( k2_base = lane_odd * fx.Index(half_bytes) coord_pack = (n_blk, k0, k1_local, n_intra, fx.Index(0)) - idx_pack = crd2idx(coord_pack, layout_b) + idx_pack = crd2idx(tuple(fx.Int32(c) for c in coord_pack), layout_b) idx_bytes = idx_pack + k2_base b4 = _buffer_load_vec( @@ -464,7 +463,7 @@ def load_b_pack_k32( k2_base = arith.constant((ki_step % 2) * half_bytes, index=True) coord_pack = (n_blk, k0, k1, n_intra, fx.Index(0)) - idx_pack = crd2idx(coord_pack, layout_b) + idx_pack = crd2idx(tuple(fx.Int32(c) for c in coord_pack), layout_b) if unpack_int4: idx_bytes = idx_pack + k2_base @@ -527,7 +526,7 @@ def tile_chunk_coord_i32( raise ValueError(f"chunk_i32 must be one of (1,2,4), got {chunk_i32!r}") chunk_off_i32 = arith.constant(i * total_threads * chunk_i32, index=True) tile_idx_i32 = tx_i32_base + chunk_off_i32 - coord_local = fx.idx2crd(tile_idx_i32, layout_tile_div4) + coord_local = fx.idx2crd(fx.Int32(tile_idx_i32), layout_tile_div4) row_local = fx.get(coord_local, 0) col_local_i32 = fx.get(coord_local, 1) return row_local, col_local_i32 @@ -580,7 +579,7 @@ def lds_store_16b_xor16( col_swz_bytes = swizzle_xor16(row_local, col_local_bytes, k_blocks16) col_swz = col_swz_bytes if elem_bytes == 1 else col_swz_bytes // 2 coord_store = (row_local, col_swz) - idx0 = crd2idx(coord_store, layout_lds) + lds_base + idx0 = crd2idx(tuple(fx.Int32(c) for c in coord_store), layout_lds) + lds_base v16 = vector.bitcast(vec16_ty, vec_part_i32x4) vector.store(v16, lds_memref, [idx0]) @@ -607,7 +606,7 @@ def lds_store_8b_xor16( col_swz_bytes = swizzle_xor16(row_local, col_local_bytes, k_blocks16) col_swz = col_swz_bytes if elem_bytes == 1 else col_swz_bytes // 2 coord_store = (row_local, col_swz) - idx0 = crd2idx(coord_store, layout_lds) + lds_base + idx0 = crd2idx(tuple(fx.Int32(c) for c in coord_store), layout_lds) + lds_base v8 = vector.bitcast(vec8_ty, vec_part_i32x2) vector.store(v8, lds_memref, [idx0]) @@ -634,7 +633,7 @@ def lds_store_4b_xor16( col_swz_bytes = swizzle_xor16(row_local, col_local_bytes, k_blocks16) col_swz = col_swz_bytes if elem_bytes == 1 else col_swz_bytes // 2 coord_store = (row_local, col_swz) - idx0 = crd2idx(coord_store, layout_lds) + lds_base + idx0 = crd2idx(tuple(fx.Int32(c) for c in coord_store), layout_lds) + lds_base v4 = vector.bitcast(vec4_ty, vec_part_i32x1) vector.store(v4, lds_memref, [idx0]) @@ -660,14 +659,14 @@ def lds_load_pack_k32( col_base_swz = swizzle_xor16(curr_row_a_lds, col_base, k_blocks16) if ck_lds128: coord_a16 = (curr_row_a_lds, col_base_swz) - idx_a16 = crd2idx(coord_a16, layout_lds) + lds_base + idx_a16 = crd2idx(tuple(fx.Int32(c) for c in coord_a16), layout_lds) + lds_base loaded_a16 = vector.load_op(vec16_ty, lds_memref, [idx_a16]) a_vec128 = vector.bitcast(vec2_i64_ty, loaded_a16) return vector.extract(a_vec128, static_position=[half], dynamic_position=[]) else: col_swizzled = col_base_swz + (half * 8) coord_a = (curr_row_a_lds, col_swizzled) - idx_a = crd2idx(coord_a, layout_lds) + lds_base + idx_a = crd2idx(tuple(fx.Int32(c) for c in coord_a), layout_lds) + lds_base loaded_a8 = vector.load_op(vec8_ty, lds_memref, [idx_a]) a_vec64 = vector.bitcast(vec1_i64_ty, loaded_a8) return vector.extract(a_vec64, static_position=[0], dynamic_position=[]) diff --git a/kernels/mixed_moe_gemm_2stage.py b/kernels/mixed_moe_gemm_2stage.py index 5a7f3c24a..712291931 100644 --- a/kernels/mixed_moe_gemm_2stage.py +++ b/kernels/mixed_moe_gemm_2stage.py @@ -729,10 +729,10 @@ def load_x_tile(base_k): return parts # Wave/lane decomposition (identical to stage2) - coord_wl = idx2crd(tx, layout_tx_wave_lane) + coord_wl = idx2crd(fx.Int32(tx), layout_tx_wave_lane) wave_id = layout_get(coord_wl, 0) lane_id = layout_get(coord_wl, 1) - coord_l16 = idx2crd(lane_id, layout_lane16) + coord_l16 = idx2crd(fx.Int32(lane_id), layout_lane16) lane_div_16 = layout_get(coord_l16, 0) lane_mod_16 = layout_get(coord_l16, 1) row_a_lds = lane_mod_16 @@ -763,12 +763,12 @@ def load_x_tile(base_k): global_n = by_n + n_tile_base + c_offset + lane_mod_16 # Gate/interleave: rows [expert_off, expert_off + 2*inter_dim) gate_row_w = expert_off_idx + global_n - gate_coord = idx2crd(gate_row_w, layout_n_blk_intra) + gate_coord = idx2crd(fx.Int32(gate_row_w), layout_n_blk_intra) gate_n_blk_list.append(layout_get(gate_coord, 0)) gate_n_intra_list.append(layout_get(gate_coord, 1)) if const_expr(not mock_gate_only and not gate_up_interleave): up_row_w = gate_row_w + inter_idx - up_coord = idx2crd(up_row_w, layout_n_blk_intra) + up_coord = idx2crd(fx.Int32(up_row_w), layout_n_blk_intra) up_n_blk_list.append(layout_get(up_coord, 0)) up_n_intra_list.append(layout_get(up_coord, 1)) @@ -799,7 +799,7 @@ def load_b_packs_k64(base_k, ku: int, n_blk, n_intra): k0 = base_k_bytes // c64 + arith.constant(ku, index=True) k1 = lane_div_16 coord_pack = (n_blk, k0, k1, n_intra, arith.constant(0, index=True)) - idx_pack = crd2idx(coord_pack, layout_b) + idx_pack = crd2idx(tuple(fx.Int32(c) for c in coord_pack), layout_b) vec_elems = kpack_bytes // int(b_elem_bytes) b16 = _buffer_load_vec( buffer_ops, @@ -1015,7 +1015,7 @@ def prefetch_x_to_lds(base_k, lds_buffer): def lds_load_packs_k64(curr_row_a_lds, col_base, lds_buffer): col_base_swz_bytes = swizzle_xor16(curr_row_a_lds, col_base, k_blocks16) col_base_swz = col_base_swz_bytes if elem_bytes == 1 else (col_base_swz_bytes / arith.index(2)) - idx_a16 = crd2idx([curr_row_a_lds, col_base_swz], layout_lds) + idx_a16 = crd2idx([fx.Int32(curr_row_a_lds), fx.Int32(col_base_swz)], layout_lds) loaded_a16 = vector.load_op(vec16_x, lds_buffer, [idx_a16]) a_i64x2 = vector.bitcast(vec2_i64, loaded_a16) a0 = vector.extract(a_i64x2, static_position=[0], dynamic_position=[]) @@ -3074,10 +3074,10 @@ def load_x_tile(base_k): return parts # tx -> wave/lane (GEMM-style decomposition). - coord_wl = idx2crd(tx, layout_tx_wave_lane) + coord_wl = idx2crd(fx.Int32(tx), layout_tx_wave_lane) wave_id = layout_get(coord_wl, 0) lane_id = layout_get(coord_wl, 1) - coord_l16 = idx2crd(lane_id, layout_lane16) + coord_l16 = idx2crd(fx.Int32(lane_id), layout_lane16) lane_div_16 = layout_get(coord_l16, 0) lane_mod_16 = layout_get(coord_l16, 1) @@ -3330,7 +3330,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_buffer): def lds_load_packs_k64(curr_row_a_lds, col_base, lds_buffer): col_base_swz_bytes = swizzle_xor16(curr_row_a_lds, col_base, k_blocks16) col_base_swz = col_base_swz_bytes if elem_bytes == 1 else (col_base_swz_bytes / arith.index(2)) - idx_a16 = crd2idx([curr_row_a_lds, col_base_swz], layout_lds) + idx_a16 = crd2idx([fx.Int32(curr_row_a_lds), fx.Int32(col_base_swz)], layout_lds) loaded_a16 = vector.load_op(vec16_x, lds_buffer, [idx_a16]) a_i64x2 = vector.bitcast(vec2_i64, loaded_a16) a0 = vector.extract(a_i64x2, static_position=[0], dynamic_position=[]) diff --git a/kernels/moe_blockscale_2stage.py b/kernels/moe_blockscale_2stage.py index 2c5eb635b..c257bad18 100644 --- a/kernels/moe_blockscale_2stage.py +++ b/kernels/moe_blockscale_2stage.py @@ -466,10 +466,10 @@ def load_x_tile(base_k, x_load_bytes_v): return parts # tx -> wave/lane (GEMM-style decomposition). - coord_wl = fx.idx2crd(tx, layout_tx_wave_lane) + coord_wl = fx.idx2crd(fx.Int32(tx), layout_tx_wave_lane) wave_id = fx.get(coord_wl, 0) lane_id = fx.get(coord_wl, 1) - coord_l16 = fx.idx2crd(lane_id, layout_lane16) + coord_l16 = fx.idx2crd(fx.Int32(lane_id), layout_lane16) lane_div_16 = fx.get(coord_l16, 0) lane_mod_16 = fx.get(coord_l16, 1) @@ -511,11 +511,11 @@ def load_x_tile(base_k, x_load_bytes_v): row_gate = expert_off_idx + col_g row_up = row_gate + inter_idx - coord_gate = fx.idx2crd(row_gate, layout_n_blk_intra) + coord_gate = fx.idx2crd(fx.Int32(row_gate), layout_n_blk_intra) n_blk_gate.append(fx.get(coord_gate, 0)) n_intra_gate.append(fx.get(coord_gate, 1)) - coord_up = fx.idx2crd(row_up, layout_n_blk_intra) + coord_up = fx.idx2crd(fx.Int32(row_up), layout_n_blk_intra) n_blk_up.append(fx.get(coord_up, 0)) n_intra_up.append(fx.get(coord_up, 1)) @@ -620,7 +620,7 @@ def lds_load_packs_k64(curr_row_a_lds, col_base_bytes, lds_base): col_base_swz = ( col_base_swz_bytes if elem_bytes == 1 else (col_base_swz_bytes // arith.index(int(elem_bytes))) ) - idx_a16 = crd2idx((curr_row_a_lds, col_base_swz), layout_lds) + idx_a16 = crd2idx((fx.Int32(curr_row_a_lds), fx.Int32(col_base_swz)), layout_lds) idx_a16 = idx_a16 + lds_base loaded_a16 = vector.load_op(vec16_x, lds_x, [idx_a16]) a_i64x2 = vector.bitcast(T.i64x2, loaded_a16) @@ -1604,10 +1604,10 @@ def load_x_tile(base_k, x_load_bytes_v): return parts # tx -> wave/lane (GEMM-style decomposition). - coord_wl = fx.idx2crd(tx, layout_tx_wave_lane) + coord_wl = fx.idx2crd(fx.Int32(tx), layout_tx_wave_lane) wave_id = fx.get(coord_wl, 0) lane_id = fx.get(coord_wl, 1) - coord_l16 = fx.idx2crd(lane_id, layout_lane16) + coord_l16 = fx.idx2crd(fx.Int32(lane_id), layout_lane16) lane_div_16 = fx.get(coord_l16, 0) lane_mod_16 = fx.get(coord_l16, 1) @@ -1640,7 +1640,7 @@ def load_x_tile(base_k, x_load_bytes_v): col_g_list.append(col_g) row_w = expert_off_idx + col_g - coord_w = fx.idx2crd(row_w, layout_n_blk_intra) + coord_w = fx.idx2crd(fx.Int32(row_w), layout_n_blk_intra) n_blk_list.append(fx.get(coord_w, 0)) n_intra_list.append(fx.get(coord_w, 1)) @@ -1742,7 +1742,7 @@ def lds_load_packs_k64(curr_row_a_lds, col_base_bytes, lds_base): col_base_swz = ( col_base_swz_bytes if elem_bytes == 1 else (col_base_swz_bytes // arith.index(int(elem_bytes))) ) - idx_a16 = crd2idx((curr_row_a_lds, col_base_swz), layout_lds) + idx_a16 = crd2idx((fx.Int32(curr_row_a_lds), fx.Int32(col_base_swz)), layout_lds) idx_a16 = idx_a16 + lds_base loaded_a16 = vector.load_op(vec16_x, lds_x, [idx_a16]) a_i64x2 = vector.bitcast(T.i64x2, loaded_a16) diff --git a/kernels/moe_gemm_2stage.py b/kernels/moe_gemm_2stage.py index 1402ffada..57769c7d2 100644 --- a/kernels/moe_gemm_2stage.py +++ b/kernels/moe_gemm_2stage.py @@ -598,10 +598,10 @@ def load_x_tile(base_k): return parts # tx -> wave/lane (GEMM-style decomposition). - coord_wl = fx.idx2crd(tx, layout_tx_wave_lane) + coord_wl = fx.idx2crd(fx.Int32(tx), layout_tx_wave_lane) wave_id = fx.get(coord_wl, 0) lane_id = fx.get(coord_wl, 1) - coord_l16 = fx.idx2crd(lane_id, layout_lane16) + coord_l16 = fx.idx2crd(fx.Int32(lane_id), layout_lane16) lane_div_16 = fx.get(coord_l16, 0) lane_mod_16 = fx.get(coord_l16, 1) @@ -644,11 +644,11 @@ def load_x_tile(base_k): row_gate = expert_off_idx + col_g row_up = row_gate + inter_idx - coord_gate = fx.idx2crd(row_gate, layout_n_blk_intra) + coord_gate = fx.idx2crd(fx.Int32(row_gate), layout_n_blk_intra) n_blk_gate.append(fx.get(coord_gate, 0)) n_intra_gate.append(fx.get(coord_gate, 1)) - coord_up = fx.idx2crd(row_up, layout_n_blk_intra) + coord_up = fx.idx2crd(fx.Int32(row_up), layout_n_blk_intra) n_blk_up.append(fx.get(coord_up, 0)) n_intra_up.append(fx.get(coord_up, 1)) @@ -811,7 +811,7 @@ def lds_load_packs_k64(curr_row_a_lds, col_base_bytes, lds_base): col_base_swz = ( col_base_swz_bytes if elem_bytes == 1 else (col_base_swz_bytes // arith.index(int(elem_bytes))) ) - idx_a16 = crd2idx((curr_row_a_lds, col_base_swz), layout_lds) + idx_a16 = crd2idx((fx.Int32(curr_row_a_lds), fx.Int32(col_base_swz)), layout_lds) idx_a16 = idx_a16 + lds_base loaded_a16 = vector.load_op(vec16_x, lds_x, [idx_a16]) a_i64x2 = vector.bitcast(T.i64x2, loaded_a16) @@ -2256,10 +2256,10 @@ def load_x_tile(base_k): return parts # tx -> wave/lane (GEMM-style decomposition). - coord_wl = fx.idx2crd(tx, layout_tx_wave_lane) + coord_wl = fx.idx2crd(fx.Int32(tx), layout_tx_wave_lane) wave_id = fx.get(coord_wl, 0) lane_id = fx.get(coord_wl, 1) - coord_l16 = fx.idx2crd(lane_id, layout_lane16) + coord_l16 = fx.idx2crd(fx.Int32(lane_id), layout_lane16) lane_div_16 = fx.get(coord_l16, 0) lane_mod_16 = fx.get(coord_l16, 1) @@ -2293,7 +2293,7 @@ def load_x_tile(base_k): col_g_list.append(col_g) row_w = expert_off_idx + col_g - coord_w = fx.idx2crd(row_w, layout_n_blk_intra) + coord_w = fx.idx2crd(fx.Int32(row_w), layout_n_blk_intra) n_blk_list.append(fx.get(coord_w, 0)) n_intra_list.append(fx.get(coord_w, 1)) @@ -2453,7 +2453,7 @@ def lds_load_packs_k64(curr_row_a_lds, col_base_bytes, lds_base): col_base_swz = ( col_base_swz_bytes if elem_bytes == 1 else (col_base_swz_bytes // arith.index(int(elem_bytes))) ) - idx_a16 = crd2idx((curr_row_a_lds, col_base_swz), layout_lds) + idx_a16 = crd2idx((fx.Int32(curr_row_a_lds), fx.Int32(col_base_swz)), layout_lds) idx_a16 = idx_a16 + lds_base loaded_a16 = vector.load_op(vec16_x, lds_x, [idx_a16]) a_i64x2 = vector.bitcast(T.i64x2, loaded_a16) diff --git a/kernels/moe_gemm_2stage_mxscale_gfx1250.py b/kernels/moe_gemm_2stage_mxscale_gfx1250.py index 5cb14c60f..f2a013f93 100644 --- a/kernels/moe_gemm_2stage_mxscale_gfx1250.py +++ b/kernels/moe_gemm_2stage_mxscale_gfx1250.py @@ -394,7 +394,7 @@ def moe_mxscale_stage1_single( block_ok = arith.andi(block_in_valid, arith.andi(eid_ok0, eid_ok1)) layout_thr = _make_moe_wave_layout(m_warp=m_warp, n_warp=n_warp, WAVE_SIZE=WAVE_SIZE, fx=fx) - thr_coord = idx2crd(tx, layout_thr) + thr_coord = idx2crd(fx.Int32(tx), layout_thr) wave_m_idx, wave_n_idx, lane_kgrp, lane16 = ( fx.get(thr_coord, 0), fx.get(thr_coord, 1), fx.get(thr_coord, 2), fx.get(thr_coord, 3) ) @@ -2560,7 +2560,7 @@ def moe_mxscale_stage2_single( block_ok = arith.andi(block_in_valid, arith.andi(eid_ok0, eid_ok1)) layout_thr = _make_moe_wave_layout(m_warp=m_warp, n_warp=n_warp, WAVE_SIZE=WAVE_SIZE, fx=fx) - thr_coord = idx2crd(tx, layout_thr) + thr_coord = idx2crd(fx.Int32(tx), layout_thr) wave_m_idx, wave_n_idx, lane_kgrp, lane16 = ( fx.get(thr_coord, 0), fx.get(thr_coord, 1), fx.get(thr_coord, 2), fx.get(thr_coord, 3) ) diff --git a/kernels/moe_gemm_2stage_wmma_gfx1250.py b/kernels/moe_gemm_2stage_wmma_gfx1250.py index ebed9aa5d..06cc63dd4 100644 --- a/kernels/moe_gemm_2stage_wmma_gfx1250.py +++ b/kernels/moe_gemm_2stage_wmma_gfx1250.py @@ -149,7 +149,7 @@ def moe_fp16_stage1_single( eid_ok = arith.andi(eid_ok0, eid_ok1) layout_thr = _make_moe_wave_layout(m_warp=m_warp, n_warp=n_warp, WAVE_SIZE=WAVE_SIZE, fx=fx) - thr_coord = idx2crd(tx, layout_thr) + thr_coord = idx2crd(fx.Int32(tx), layout_thr) wave_m_idx, wave_n_idx, lane_kgrp, lane16 = ( fx.get(thr_coord, 0), fx.get(thr_coord, 1), @@ -556,7 +556,7 @@ def moe_fp16_stage2_single( block_ok = arith.andi(block_in_valid, arith.andi(eid_ok0, eid_ok1)) layout_thr = _make_moe_wave_layout(m_warp=m_warp, n_warp=n_warp, WAVE_SIZE=WAVE_SIZE, fx=fx) - thr_coord = idx2crd(tx, layout_thr) + thr_coord = idx2crd(fx.Int32(tx), layout_thr) wave_m_idx, wave_n_idx, lane_kgrp, lane16 = ( fx.get(thr_coord, 0), fx.get(thr_coord, 1), diff --git a/kernels/preshuffle_gemm.py b/kernels/preshuffle_gemm.py index dedd3ac86..31947b723 100644 --- a/kernels/preshuffle_gemm.py +++ b/kernels/preshuffle_gemm.py @@ -441,12 +441,12 @@ def kernel_gemm( # ---- Wave / lane decomposition ---- wave_size = 64 layout_wave_lane = fx.make_layout((4, wave_size), (64, 1)) - coord_wave_lane = fx.idx2crd(tx, layout_wave_lane) + coord_wave_lane = fx.idx2crd(fx.Int32(tx), layout_wave_lane) wave_id = fx.get(coord_wave_lane, 0) lane_id = fx.get(coord_wave_lane, 1) layout_lane16 = fx.make_layout((4, 16), (16, 1)) - coord_lane16 = fx.idx2crd(lane_id, layout_lane16) + coord_lane16 = fx.idx2crd(fx.Int32(lane_id), layout_lane16) lane_div_16 = fx.get(coord_lane16, 0) lane_mod_16 = fx.get(coord_lane16, 1) diff --git a/kernels/rmsnorm_kernel.py b/kernels/rmsnorm_kernel.py index ce4bd0a98..66235110a 100644 --- a/kernels/rmsnorm_kernel.py +++ b/kernels/rmsnorm_kernel.py @@ -214,7 +214,7 @@ def block_reduce_add2(val0, val1): _, sum_sq = block_reduce_add2(thread_dummy, thread_sumsq) mean_sq = sum_sq / n_float ms_eps = mean_sq + eps_c - rrms = ms_eps.rsqrt(fastmath=fm_fast) + rrms = fmath.rsqrt(ms_eps, fastmath=fm_fast) # Pass 2: normalize + gamma + store (reuse cached input) for tile_i in range_constexpr(num_tiles): @@ -517,7 +517,7 @@ def block_reduce_add2(val0, val1): _, sum_sq = block_reduce_add2(thread_dummy, thread_sumsq) mean_sq = sum_sq / n_float ms_eps = mean_sq + eps_c - rrms = ms_eps.rsqrt(fastmath=fm_fast) + rrms = fmath.rsqrt(ms_eps, fastmath=fm_fast) # Pass 2: normalize + gamma + store (reuse cached added values) for tile_i in range_constexpr(num_tiles): @@ -790,7 +790,7 @@ def block_reduce_max(val): _, sum_sq = block_reduce_add2(thread_dummy, thread_sumsq) mean_sq = sum_sq / n_float ms_eps = mean_sq + eps_c - rrms = ms_eps.rsqrt(fastmath=fm_fast) + rrms = fmath.rsqrt(ms_eps, fastmath=fm_fast) thread_row_max = c_zero_f y_local = [] @@ -1176,7 +1176,7 @@ def block_reduce_max(val): _, sum_sq = block_reduce_add2(thread_dummy, thread_sumsq) mean_sq = sum_sq / n_float ms_eps = mean_sq + eps_c - rrms = ms_eps.rsqrt(fastmath=fm_fast) + rrms = fmath.rsqrt(ms_eps, fastmath=fm_fast) thread_row_max = c_zero_f y_local = [] diff --git a/kernels/wmma_gemm_gfx1250.py b/kernels/wmma_gemm_gfx1250.py index 51115078e..a2e05154f 100644 --- a/kernels/wmma_gemm_gfx1250.py +++ b/kernels/wmma_gemm_gfx1250.py @@ -250,7 +250,7 @@ def kernel_wmma_gemm_tdm( # --- Thread/wave decomposition --- layout_thr = fx.make_layout((m_warp, n_warp, 2, 16), (n_warp * WAVE_SIZE, WAVE_SIZE, 16, 1)) - thr_coord = idx2crd(tx, layout_thr) + thr_coord = idx2crd(fx.Int32(tx), layout_thr) wave_m_idx, wave_n_idx, lane_kgrp, lane16 = ( fx.get(thr_coord, 0), fx.get(thr_coord, 1), diff --git a/python/flydsl/compiler/ast_rewriter.py b/python/flydsl/compiler/ast_rewriter.py index 8282b0e4a..e41b40ad6 100644 --- a/python/flydsl/compiler/ast_rewriter.py +++ b/python/flydsl/compiler/ast_rewriter.py @@ -14,7 +14,7 @@ from .._mlir import ir from .._mlir.dialects import arith, scf from ..expr import const_expr -from ..expr.numeric import _unwrap_value, _wrap_like +from ..expr.typing import as_dsl_value, as_ir_value from ..utils import env, log @@ -496,7 +496,7 @@ def _is_dynamic(cond): @staticmethod def _to_i1(cond): - return _unwrap_value(cond) + return as_ir_value(cond) @staticmethod def _normalize_named_values(names, values, names_label="names", values_label="values"): @@ -542,7 +542,7 @@ def _normalize_branch_result(branch_result, state_names, state_map, branch_label def _unwrap_mlir_values(values, state_names, branch_label): raw_values = [] for name, value in zip(state_names, values): - raw = _unwrap_value(value) + raw = as_ir_value(value) if not isinstance(raw, ir.Value): raise TypeError( f"if/else variable '{name}' in {branch_label} is {type(raw).__name__}, " @@ -555,7 +555,7 @@ def _unwrap_mlir_values(values, state_names, branch_label): def _pack_dispatch_results(results, state_values): if not results: return None - wrapped = [_wrap_like(v, exemplar) for v, exemplar in zip(results, state_values)] + wrapped = [as_dsl_value(v, exemplar) for v, exemplar in zip(results, state_values)] if len(wrapped) == 1: return wrapped[0] return tuple(wrapped) @@ -622,7 +622,7 @@ def scf_if_dispatch( if not isinstance(cond_i1, ir.Value): raise TypeError(f"dynamic if condition must lower to ir.Value, got {type(cond_i1).__name__}") - none_vars = [name for name, value in zip(result_names, result_values) if _unwrap_value(value) is None] + none_vars = [name for name, value in zip(result_names, result_values) if as_ir_value(value) is None] if none_vars: raise TypeError( f"Variable(s) {none_vars} initialized as None before a dynamic " @@ -652,7 +652,7 @@ def scf_if_dispatch( state_raw = [] for name, value in zip(result_names, result_values): - raw = _unwrap_value(value) + raw = as_ir_value(value) if not isinstance(raw, ir.Value): raise TypeError( f"state variable '{name}' is {type(raw).__name__}, not an MLIR Value; " @@ -881,9 +881,9 @@ def scf_ifexp_dispatch(cond, then_fn, else_fn): sandbox.region.blocks.append() with ir.InsertionPoint(sandbox.region.blocks[0]): probe_then = then_fn() - probe_then_raw = _unwrap_value(probe_then) + probe_then_raw = as_ir_value(probe_then) probe_else = else_fn() - probe_else_raw = _unwrap_value(probe_else) + probe_else_raw = as_ir_value(probe_else) if not isinstance(probe_then_raw, ir.Value): raise TypeError( f"dynamic ifexp then-branch must produce an MLIR Value, " f"got {type(probe_then_raw).__name__}" @@ -902,14 +902,14 @@ def scf_ifexp_dispatch(cond, then_fn, else_fn): op = scf.IfOp(cond_i1, [yield_type], has_else=True, loc=ir.Location.unknown()) with ir.InsertionPoint(op.regions[0].blocks[0]): - scf.YieldOp([_unwrap_value(then_fn())]) + scf.YieldOp([as_ir_value(then_fn())]) if len(op.regions[1].blocks) == 0: op.regions[1].blocks.append() with ir.InsertionPoint(op.regions[1].blocks[0]): - scf.YieldOp([_unwrap_value(else_fn())]) + scf.YieldOp([as_ir_value(else_fn())]) sandbox.operation.erase() - return _wrap_like(op.results[0], probe_then) + return as_dsl_value(op.results[0], probe_then) @ASTRewriter.register @@ -942,7 +942,7 @@ def scf_range(start, stop=None, step=None, *, init=None): stop_val = InsertEmptyYieldForSCFFor._to_index(stop) step_val = InsertEmptyYieldForSCFFor._to_index(step) if init is not None: - init = [_unwrap_value(v) for v in init] + init = [as_ir_value(v) for v in init] for_op = scf.ForOp(start_val, stop_val, step_val, init) with ir.InsertionPoint(for_op.body): yield for_op.induction_variable, list(for_op.inner_iter_args) @@ -953,9 +953,9 @@ def scf_range(start, stop=None, step=None, *, init=None): @staticmethod def scf_for_dispatch(start, stop, step, body_fn, *, result_names=(), result_values=()): - start_val = _unwrap_value(start) - stop_val = _unwrap_value(stop) - step_val = _unwrap_value(step) + start_val = as_ir_value(start) + stop_val = as_ir_value(stop) + step_val = as_ir_value(step) i32_ty = ir.IntegerType.get_signless(32) idx_ty = ir.IndexType.get() @@ -974,7 +974,7 @@ def scf_for_dispatch(start, stop, step, body_fn, *, result_names=(), result_valu result_values = tuple(result_values) result_map = {name: value for name, value in zip(result_names, result_values)} - none_vars = [name for name, value in zip(result_names, result_values) if _unwrap_value(value) is None] + none_vars = [name for name, value in zip(result_names, result_values) if as_ir_value(value) is None] if none_vars: raise TypeError( f"Variable(s) {none_vars} initialized as None before a dynamic " @@ -994,7 +994,7 @@ def scf_for_dispatch(start, stop, step, body_fn, *, result_names=(), result_valu state_raw = [] for name, value in zip(result_names, result_values): - raw = _unwrap_value(value) + raw = as_ir_value(value) if not isinstance(raw, ir.Value): raise TypeError( f"for-loop variable '{name}' is {type(raw).__name__}, not an MLIR Value; " @@ -1006,7 +1006,7 @@ def scf_for_dispatch(start, stop, step, body_fn, *, result_names=(), result_valu with ir.InsertionPoint(for_op.body): iv = for_op.induction_variable - inner_args = [_wrap_like(a, ex) for a, ex in zip(for_op.inner_iter_args, result_values)] + inner_args = [as_dsl_value(a, ex) for a, ex in zip(for_op.inner_iter_args, result_values)] body_result = body_fn(iv, result_names, *inner_args) @@ -1305,7 +1305,7 @@ def scf_while_dispatch(before_fn, after_fn, *, result_names=(), result_values=() ) result_map = {name: value for name, value in zip(result_names, result_values)} - none_vars = [name for name, value in zip(result_names, result_values) if _unwrap_value(value) is None] + none_vars = [name for name, value in zip(result_names, result_values) if as_ir_value(value) is None] if none_vars: raise TypeError( f"Variable(s) {none_vars} initialized as None before a dynamic " @@ -1318,7 +1318,7 @@ def scf_while_dispatch(before_fn, after_fn, *, result_names=(), result_values=() state_raw = [] for name, value in zip(result_names, result_values): - raw = _unwrap_value(value) + raw = as_ir_value(value) if not isinstance(raw, ir.Value): raise TypeError( f"while-loop variable '{name}' is {type(raw).__name__}, not an MLIR Value; " @@ -1333,7 +1333,7 @@ def scf_while_dispatch(before_fn, after_fn, *, result_names=(), result_values=() with ir.InsertionPoint(while_op.regions[0].blocks[0]): before_args = list(while_op.regions[0].blocks[0].arguments) - wrapped_before = [_wrap_like(a, ex) for a, ex in zip(before_args, result_values)] if result_names else [] + wrapped_before = [as_dsl_value(a, ex) for a, ex in zip(before_args, result_values)] if result_names else [] before_cond = ReplaceIfWithDispatch._call_branch(before_fn, result_names, wrapped_before) cond_i1 = ReplaceIfWithDispatch._to_i1(before_cond) if not isinstance(cond_i1, ir.Value): @@ -1342,7 +1342,7 @@ def scf_while_dispatch(before_fn, after_fn, *, result_names=(), result_values=() with ir.InsertionPoint(while_op.regions[1].blocks[0]): after_args = list(while_op.regions[1].blocks[0].arguments) - wrapped_after = [_wrap_like(a, ex) for a, ex in zip(after_args, result_values)] if result_names else [] + wrapped_after = [as_dsl_value(a, ex) for a, ex in zip(after_args, result_values)] if result_names else [] body_result = ReplaceIfWithDispatch._call_branch(after_fn, result_names, wrapped_after) if result_names: body_values = ReplaceIfWithDispatch._normalize_branch_result( diff --git a/python/flydsl/expr/__init__.py b/python/flydsl/expr/__init__.py index b630550b3..f904c5e9f 100644 --- a/python/flydsl/expr/__init__.py +++ b/python/flydsl/expr/__init__.py @@ -7,6 +7,8 @@ from .gpu import * from .derived import * from .struct import * +from .arith import * +from .math import * from . import utils as utils from . import arith as arith diff --git a/python/flydsl/expr/arith.py b/python/flydsl/expr/arith.py index c04998c2e..fa37e61d6 100644 --- a/python/flydsl/expr/arith.py +++ b/python/flydsl/expr/arith.py @@ -15,6 +15,25 @@ from .._mlir.dialects.arith import * # noqa: F401,F403 +__all__ = [ + "ArithValue", # Deprecated: will be removed in a future release + "_to_raw", # Deprecated: will be removed in a future release + "andi", + "constant", + "constant_vector", + "index", # Deprecated: will be removed in a future release + "index_cast", # Deprecated: will be removed in a future release + "int_to_fp", + "select", + "shli", + "sitofp", + "trunc_f", + "unwrap", # Deprecated: will be removed in a future release + "xori", + "cmpi", + "cmpf", +] + # Override star-import cmpi/cmpf to accept Numeric types (Int32, etc.) from .._mlir.dialects import arith as _mlir_arith from .meta import traced_op diff --git a/python/flydsl/expr/derived.py b/python/flydsl/expr/derived.py index c3b739bca..8576249c6 100644 --- a/python/flydsl/expr/derived.py +++ b/python/flydsl/expr/derived.py @@ -93,8 +93,8 @@ def make_rmem_tensor(shape_or_layout, dtype, *, loc=None, ip=None): tensor = make_rmem_tensor(8, fx.Float32) tensor = make_rmem_tensor(make_layout(4, 1), fx.Float16) """ - if not issubclass(dtype, Numeric): - raise TypeError(f"dtype must be a Numeric type, but got {type(dtype)}") + if not (isinstance(dtype, type) and issubclass(dtype, Numeric)): + raise TypeError(f"dtype must be a Numeric subclass, but got {dtype!r}") elem_ty = dtype.ir_type if dtype is not Boolean else Int8.ir_type if not isinstance(shape_or_layout, Layout): diff --git a/python/flydsl/expr/extern.py b/python/flydsl/expr/extern.py index 1269e8425..acf86ce33 100644 --- a/python/flydsl/expr/extern.py +++ b/python/flydsl/expr/extern.py @@ -20,6 +20,7 @@ DenseI32ArrayAttr, FlatSymbolRefAttr, InsertionPoint, + IntegerAttr, IntegerType, TypeAttr, ) @@ -121,16 +122,18 @@ def __call__(self, *args: Any) -> Any: if len(args) != len(arg_types): raise TypeError(f"ffi {self.symbol!r} expects {len(arg_types)} argument(s), got {len(args)}") - from .._mlir.dialects import llvm as _llvm - from .._mlir.ir import IntegerAttr + from .numeric import Numeric raw_args: List[ir.Value] = [] for arg_pos, arg in enumerate(args): expected_type = arg_types[arg_pos] + if isinstance(arg, Numeric) and isinstance(arg.value, (bool, int)): + arg = int(arg.value) + if isinstance(arg, int): target_type = expected_type or IntegerType.get_signless(64) - raw_args.append(_llvm.ConstantOp(target_type, IntegerAttr.get(target_type, arg)).result) + raw_args.append(llvm.ConstantOp(target_type, IntegerAttr.get(target_type, arg)).result) continue if isinstance(arg, ir.Value): diff --git a/python/flydsl/expr/gpu.py b/python/flydsl/expr/gpu.py index 0686048a1..2aafbaa16 100644 --- a/python/flydsl/expr/gpu.py +++ b/python/flydsl/expr/gpu.py @@ -20,7 +20,7 @@ from .._mlir.dialects import gpu from .._mlir.dialects._fly_enum_gen import AddressSpace from ..compiler.protocol import dsl_align_of, dsl_size_of -from .numeric import Uint8 +from .numeric import Numeric, Uint8 from .primitive import get_dyn_shared, make_ptr from .struct import ( Arena, @@ -104,6 +104,8 @@ def base_ptr(self): return self._base def allocate(self, storable_or_int, alignment=None): + if isinstance(storable_or_int, Numeric) and not isinstance(storable_or_int.value, ir.Value): + storable_or_int = int(storable_or_int.value) if not self._static: return super().allocate(storable_or_int, alignment) return self._allocate_static(storable_or_int, alignment) diff --git a/python/flydsl/expr/math.py b/python/flydsl/expr/math.py index c4d128c2f..e62613fb6 100644 --- a/python/flydsl/expr/math.py +++ b/python/flydsl/expr/math.py @@ -1,53 +1,100 @@ # SPDX-License-Identifier: Apache-2.0 -# Copyright (c) 2025 FlyDSL Project Contributors +# Copyright (c) 2026 FlyDSL Project Contributors -"""Math dialect API — DSL-friendly wrappers with traced locations and auto-unwrap. +"""Math dialect API — thin DSL wrappers over the MLIR ``math`` dialect. Usage: - from flydsl.expr import math + import flydsl.expr as fx - y = math.exp(x) - y = math.sqrt(x, fastmath="fast") - y = math.fma(a, b, c) - pred = math.isnan(x) + y = fx.exp(x) + y = fx.sqrt(x, fastmath="fast") + y = fx.fma(a, b, c) + pred = fx.isnan(x) """ from functools import wraps from .._mlir import ir -from .._mlir.dialects import math as _mlir_math -from .._mlir.dialects.math import * # noqa: F401,F403 -from .meta import _caller_location, _flatten_args +from .._mlir.dialects import math +from .meta import dsl_loc_tracing from .numeric import Numeric -from .utils.arith import _to_raw - - -def _traced_math_op(fn): - """Like @traced_op, but re-wraps results to preserve Numeric class hierarchy. - - If the first positional arg is a Numeric (Float32, Int32, …), the MLIR - result is wrapped back into the appropriate Numeric subclass via - ``Numeric.from_ir_type``. Raw ir.Value inputs pass through unchanged. - """ - +from .typing import as_ir_value + +__all__ = [ + "absf", + "ceil", + "floor", + "trunc", + "round", + "roundeven", + "exp", + "exp2", + "expm1", + "log", + "log2", + "log10", + "log1p", + "sqrt", + "rsqrt", + "cbrt", + "sin", + "cos", + "tan", + "asin", + "acos", + "atan", + "sinh", + "cosh", + "tanh", + "asinh", + "acosh", + "atanh", + "erf", + "erfc", + "sincos", + "absi", + "ctlz", + "cttz", + "ctpop", + "powf", + "fpowi", + "ipowi", + "atan2", + "copysign", + "fma", + "clampf", + "isnan", + "isinf", + "isfinite", + "isnormal", +] + + +def dsl_math_wrap_result(fn): @wraps(fn) def wrapper(*args, **kwargs): + from .typing import Vector + first = args[0] if args else None - do_rewrap = isinstance(first, Numeric) + is_vector = isinstance(first, Vector) + is_numeric = isinstance(first, Numeric) + + result = fn(*args, **kwargs) + + if not (is_vector or is_numeric): + return tuple(result) if not isinstance(result, ir.Value) and hasattr(result, "__iter__") else result - loc = kwargs.pop("loc", None) - if loc is None: - loc = _caller_location(depth=1) - args, kwargs = _flatten_args(args, kwargs) - with loc: - result = fn(*args, **kwargs) + def dsl_wrap(value): + if not isinstance(value, ir.Value): + return value + if is_vector: + elem_dtype = Numeric.from_ir_type(ir.VectorType(value.type).element_type) + return Vector(value, first.shape, elem_dtype) + return Numeric.from_ir_type(value.type)(value) - if not do_rewrap: - return result if isinstance(result, ir.Value): - return Numeric.from_ir_type(result.type)(result) - # Multi-result (e.g. sincos) - return tuple(Numeric.from_ir_type(r.type)(r) for r in result) + return dsl_wrap(result) + return tuple(dsl_wrap(r) for r in result) return wrapper @@ -57,154 +104,184 @@ def wrapper(*args, **kwargs): # --------------------------------------------------------------------------- -@_traced_math_op -def absf(x, *, fastmath=None, **kw): - return _mlir_math.absf(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def absf(x, *, fastmath=None, **kwargs): + return math.absf(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def ceil(x, *, fastmath=None, **kw): - return _mlir_math.ceil(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def ceil(x, *, fastmath=None, **kwargs): + return math.ceil(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def floor(x, *, fastmath=None, **kw): - return _mlir_math.floor(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def floor(x, *, fastmath=None, **kwargs): + return math.floor(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def trunc(x, *, fastmath=None, **kw): - return _mlir_math.trunc(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def trunc(x, *, fastmath=None, **kwargs): + return math.trunc(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def round(x, *, fastmath=None, **kw): - return _mlir_math.round(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def round(x, *, fastmath=None, **kwargs): + return math.round(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def roundeven(x, *, fastmath=None, **kw): - return _mlir_math.roundeven(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def roundeven(x, *, fastmath=None, **kwargs): + return math.roundeven(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def exp(x, *, fastmath=None, **kw): - return _mlir_math.exp(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def exp(x, *, fastmath=None, **kwargs): + return math.exp(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def exp2(x, *, fastmath=None, **kw): - return _mlir_math.exp2(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def exp2(x, *, fastmath=None, **kwargs): + return math.exp2(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def expm1(x, *, fastmath=None, **kw): - return _mlir_math.expm1(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def expm1(x, *, fastmath=None, **kwargs): + return math.expm1(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def log(x, *, fastmath=None, **kw): - return _mlir_math.log(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def log(x, *, fastmath=None, **kwargs): + return math.log(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def log2(x, *, fastmath=None, **kw): - return _mlir_math.log2(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def log2(x, *, fastmath=None, **kwargs): + return math.log2(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def log10(x, *, fastmath=None, **kw): - return _mlir_math.log10(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def log10(x, *, fastmath=None, **kwargs): + return math.log10(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def log1p(x, *, fastmath=None, **kw): - return _mlir_math.log1p(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def log1p(x, *, fastmath=None, **kwargs): + return math.log1p(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def sqrt(x, *, fastmath=None, **kw): - return _mlir_math.sqrt(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def sqrt(x, *, fastmath=None, **kwargs): + return math.sqrt(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def rsqrt(x, *, fastmath=None, **kw): - return _mlir_math.rsqrt(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def rsqrt(x, *, fastmath=None, **kwargs): + return math.rsqrt(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def cbrt(x, *, fastmath=None, **kw): - return _mlir_math.cbrt(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def cbrt(x, *, fastmath=None, **kwargs): + return math.cbrt(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def sin(x, *, fastmath=None, **kw): - return _mlir_math.sin(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def sin(x, *, fastmath=None, **kwargs): + return math.sin(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def cos(x, *, fastmath=None, **kw): - return _mlir_math.cos(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def cos(x, *, fastmath=None, **kwargs): + return math.cos(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def tan(x, *, fastmath=None, **kw): - return _mlir_math.tan(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def tan(x, *, fastmath=None, **kwargs): + return math.tan(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def asin(x, *, fastmath=None, **kw): - return _mlir_math.asin(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def asin(x, *, fastmath=None, **kwargs): + return math.asin(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def acos(x, *, fastmath=None, **kw): - return _mlir_math.acos(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def acos(x, *, fastmath=None, **kwargs): + return math.acos(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def atan(x, *, fastmath=None, **kw): - return _mlir_math.atan(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def atan(x, *, fastmath=None, **kwargs): + return math.atan(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def sinh(x, *, fastmath=None, **kw): - return _mlir_math.sinh(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def sinh(x, *, fastmath=None, **kwargs): + return math.sinh(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def cosh(x, *, fastmath=None, **kw): - return _mlir_math.cosh(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def cosh(x, *, fastmath=None, **kwargs): + return math.cosh(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def tanh(x, *, fastmath=None, **kw): - return _mlir_math.tanh(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def tanh(x, *, fastmath=None, **kwargs): + return math.tanh(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def asinh(x, *, fastmath=None, **kw): - return _mlir_math.asinh(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def asinh(x, *, fastmath=None, **kwargs): + return math.asinh(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def acosh(x, *, fastmath=None, **kw): - return _mlir_math.acosh(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def acosh(x, *, fastmath=None, **kwargs): + return math.acosh(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def atanh(x, *, fastmath=None, **kw): - return _mlir_math.atanh(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def atanh(x, *, fastmath=None, **kwargs): + return math.atanh(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def erf(x, *, fastmath=None, **kw): - return _mlir_math.erf(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def erf(x, *, fastmath=None, **kwargs): + return math.erf(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def erfc(x, *, fastmath=None, **kw): - return _mlir_math.erfc(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def erfc(x, *, fastmath=None, **kwargs): + return math.erfc(as_ir_value(x), fastmath=fastmath, **kwargs) # --------------------------------------------------------------------------- @@ -212,10 +289,11 @@ def erfc(x, *, fastmath=None, **kw): # --------------------------------------------------------------------------- -@_traced_math_op -def sincos(x, *, fastmath=None, **kw): +@dsl_loc_tracing +@dsl_math_wrap_result +def sincos(x, *, fastmath=None, **kwargs): """Simultaneous sin and cos. Returns ``(sin(x), cos(x))``.""" - return _mlir_math.sincos(_to_raw(x), fastmath=fastmath, **kw) + return math.sincos(as_ir_value(x), fastmath=fastmath, **kwargs) # --------------------------------------------------------------------------- @@ -223,24 +301,28 @@ def sincos(x, *, fastmath=None, **kw): # --------------------------------------------------------------------------- -@_traced_math_op -def absi(x, **kw): - return _mlir_math.absi(_to_raw(x), **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def absi(x, **kwargs): + return math.absi(as_ir_value(x), **kwargs) -@_traced_math_op -def ctlz(x, **kw): - return _mlir_math.ctlz(_to_raw(x), **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def ctlz(x, **kwargs): + return math.ctlz(as_ir_value(x), **kwargs) -@_traced_math_op -def cttz(x, **kw): - return _mlir_math.cttz(_to_raw(x), **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def cttz(x, **kwargs): + return math.cttz(as_ir_value(x), **kwargs) -@_traced_math_op -def ctpop(x, **kw): - return _mlir_math.ctpop(_to_raw(x), **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def ctpop(x, **kwargs): + return math.ctpop(as_ir_value(x), **kwargs) # --------------------------------------------------------------------------- @@ -248,29 +330,34 @@ def ctpop(x, **kw): # --------------------------------------------------------------------------- -@_traced_math_op -def powf(base, exp, *, fastmath=None, **kw): - return _mlir_math.powf(_to_raw(base), _to_raw(exp), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def powf(base, exp, *, fastmath=None, **kwargs): + return math.powf(as_ir_value(base), as_ir_value(exp), fastmath=fastmath, **kwargs) -@_traced_math_op -def fpowi(base, exp, *, fastmath=None, **kw): - return _mlir_math.fpowi(_to_raw(base), _to_raw(exp), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def fpowi(base, exp, *, fastmath=None, **kwargs): + return math.fpowi(as_ir_value(base), as_ir_value(exp), fastmath=fastmath, **kwargs) -@_traced_math_op -def ipowi(base, exp, **kw): - return _mlir_math.ipowi(_to_raw(base), _to_raw(exp), **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def ipowi(base, exp, **kwargs): + return math.ipowi(as_ir_value(base), as_ir_value(exp), **kwargs) -@_traced_math_op -def atan2(y, x, *, fastmath=None, **kw): - return _mlir_math.atan2(_to_raw(y), _to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def atan2(y, x, *, fastmath=None, **kwargs): + return math.atan2(as_ir_value(y), as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def copysign(mag, sign, *, fastmath=None, **kw): - return _mlir_math.copysign(_to_raw(mag), _to_raw(sign), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def copysign(mag, sign, *, fastmath=None, **kwargs): + return math.copysign(as_ir_value(mag), as_ir_value(sign), fastmath=fastmath, **kwargs) # --------------------------------------------------------------------------- @@ -278,36 +365,42 @@ def copysign(mag, sign, *, fastmath=None, **kw): # --------------------------------------------------------------------------- -@_traced_math_op -def fma(a, b, c, *, fastmath=None, **kw): - return _mlir_math.fma(_to_raw(a), _to_raw(b), _to_raw(c), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def fma(a, b, c, *, fastmath=None, **kwargs): + return math.fma(as_ir_value(a), as_ir_value(b), as_ir_value(c), fastmath=fastmath, **kwargs) -@_traced_math_op -def clampf(x, lo, hi, *, fastmath=None, **kw): - return _mlir_math.clampf(_to_raw(x), _to_raw(lo), _to_raw(hi), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def clampf(x, lo, hi, *, fastmath=None, **kwargs): + return math.clampf(as_ir_value(x), as_ir_value(lo), as_ir_value(hi), fastmath=fastmath, **kwargs) # --------------------------------------------------------------------------- -# Predicates (return i1) +# Predicates :: Float -> Boolean # --------------------------------------------------------------------------- -@_traced_math_op -def isnan(x, *, fastmath=None, **kw): - return _mlir_math.isnan(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def isnan(x, *, fastmath=None, **kwargs): + return math.isnan(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def isinf(x, *, fastmath=None, **kw): - return _mlir_math.isinf(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def isinf(x, *, fastmath=None, **kwargs): + return math.isinf(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def isfinite(x, *, fastmath=None, **kw): - return _mlir_math.isfinite(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def isfinite(x, *, fastmath=None, **kwargs): + return math.isfinite(as_ir_value(x), fastmath=fastmath, **kwargs) -@_traced_math_op -def isnormal(x, *, fastmath=None, **kw): - return _mlir_math.isnormal(_to_raw(x), fastmath=fastmath, **kw) +@dsl_loc_tracing +@dsl_math_wrap_result +def isnormal(x, *, fastmath=None, **kwargs): + return math.isnormal(as_ir_value(x), fastmath=fastmath, **kwargs) diff --git a/python/flydsl/expr/meta.py b/python/flydsl/expr/meta.py index eb5e24f9f..68457ef1a 100644 --- a/python/flydsl/expr/meta.py +++ b/python/flydsl/expr/meta.py @@ -7,6 +7,7 @@ from .._mlir import ir +# TODO: remove this in the future. def _to_raw_value(obj): if isinstance(obj, ir.Value): return obj @@ -24,6 +25,7 @@ def _to_raw_value(obj): return obj +# TODO: remove this in the future. def _flatten_args(args, kwargs): new_args = tuple(_to_raw_value(a) for a in args) new_kwargs = {k: _to_raw_value(v) if k not in ("loc", "ip") else v for k, v in kwargs.items()} @@ -52,6 +54,7 @@ def _caller_location(depth=1): return ir.Location.name(label, childLoc=file_loc) +# TODO: remove this in the future. def traced_op(op): @wraps(op) def wrapper(*args, **kwargs): @@ -63,3 +66,44 @@ def wrapper(*args, **kwargs): return op(*args, **kwargs) return wrapper + + +def dsl_loc_tracing(op): + """Capture the caller's Python source position as an MLIR Location + + TODO: enhance this in the recent changes. loc is missed in the op arguments. + """ + + @wraps(op) + def wrapper(*args, **kwargs): + loc = kwargs.pop("loc", None) + if loc is None: + loc = _caller_location(depth=1) + with loc: + return op(*args, **kwargs) + + return wrapper + + +def dsl_wrap_result(target=None): + """Wrap the op result(s) back into DslType values. + + - ``target=None`` (default): dispatch by the result's ``ir.Type``. + - ``target=SomeClass``: force ``SomeClass(value)`` — useful when the result + type cannot be uniquely determined from the ``ir.Type`` (vectors, …). + + Multi-value returns (tuples / lists) are wrapped element-wise. + """ + + def decorator(op, target): + @wraps(op) + def wrapper(*args, **kwargs): + from .typing import as_dsl_value + + return as_dsl_value(op(*args, **kwargs), target) + + return wrapper + + if inspect.isfunction(target): + return decorator(target, None) + return lambda op: decorator(op, target) diff --git a/python/flydsl/expr/numeric.py b/python/flydsl/expr/numeric.py index 8fe4ce0aa..0245e190c 100644 --- a/python/flydsl/expr/numeric.py +++ b/python/flydsl/expr/numeric.py @@ -10,7 +10,6 @@ from .._mlir import ir from .._mlir.dialects import arith from .._mlir.extras import types as T -from ..utils import log from .utils.arith import ( ArithValue, _to_raw, @@ -171,14 +170,17 @@ def zero(cls): _CMP_OPS = frozenset({operator.lt, operator.le, operator.gt, operator.ge, operator.eq, operator.ne}) -def _widen_narrow_int(x, widen_bool=False): - """Promote sub-32-bit integers (and optionally bools) to i32.""" - ty = type(x) - if ty is Boolean and not widen_bool: - return x, ty - if ty.is_integer and ty.width < 32: +def _widen_bool_to_int32(x, widen_bool=False): + """Promote Boolean to Int32 for arithmetic when widen_bool=True. + + Per C++-style usual arithmetic conversions, we deliberately do NOT apply + integer promotion: i8/i16/u8/u16 stay at their narrow width. + Same-width same-signedness operands keep their type; cross-width or + cross-sign mixing is resolved by ``_coerce_operands``. + """ + if widen_bool and type(x) is Boolean: return x.to(Int32), Int32 - return x, ty + return x, type(x) def _resolve_float_type(ta, tb): @@ -205,8 +207,8 @@ def _resolve_float_type(ta, tb): def _coerce_operands(a, b, widen_bool=False): """Promote *a* and *b* to a common scalar type.""" ta, tb = type(a), type(b) - a, ta = _widen_narrow_int(a, widen_bool=widen_bool) - b, tb = _widen_narrow_int(b, widen_bool=widen_bool) + a, ta = _widen_bool_to_int32(a, widen_bool=widen_bool) + b, tb = _widen_bool_to_int32(b, widen_bool=widen_bool) if ta is tb: return a, b, ta @@ -247,48 +249,6 @@ def _extract_arith(val, signed): return v.with_signedness(signed) if isinstance(v, ArithValue) else v -def _unwrap_value(value): - """Convert FlyDSL wrappers to raw MLIR values when possible.""" - if isinstance(value, ir.Value): - return value - if isinstance(value, (bool, int, float)): - try: - return as_numeric(value).ir_value() - except Exception: - log().error(f"failed to construct {as_numeric(value)} from {value}") - return value - if hasattr(value, "__extract_to_ir_values__"): - values = value.__extract_to_ir_values__() - if len(values) == 1: - return values[0] - if hasattr(value, "ir_value"): - return value.ir_value() - return value - - -def _wrap_like(value, exemplar=None): - """Wrap an MLIR value back to a FlyDSL wrapper when possible.""" - if not isinstance(value, ir.Value): - return value - - if exemplar is not None: - if isinstance(exemplar, Numeric): - return type(exemplar)(value) - ctor = getattr(type(exemplar), "__construct_from_ir_values__", None) - if ctor is not None: - try: - return ctor([value]) - except Exception: - log().error(f"failed to construct {type(exemplar)} from {value}") - return value - - try: - return Numeric.from_ir_type(value.type)(value) - except Exception: - log().error(f"failed to construct {Numeric.from_ir_type(value.type)} from {value}") - return value - - def _make_binop(op, promote=True, widen_bool=False, swap=False): """Create a binary-operator closure for Numeric subclasses.""" @@ -331,7 +291,10 @@ def __hash__(self): def select(self, true_value, false_value, *, loc=None): """Ternary select (for Boolean conditions from Int32 comparisons).""" - return ArithValue(self).select(true_value, false_value, loc=loc) + from .typing import as_dsl_value + + result = ArithValue(self).select(true_value, false_value, loc=loc) + return as_dsl_value(result, true_value) @classmethod def __coerce__(cls, value): @@ -453,6 +416,9 @@ def from_ir_type(ir_type): T.ui32(): Uint32, T.ui16(): Uint16, T.ui8(): Uint8, + T.i(128): Int128, + T.si(128): Int128, + T.ui(128): Uint128, T.f8E5M2(): Float8E5M2, T.f8E4M3(): Float8E4M3, T.f8E4M3FN(): Float8E4M3FN, @@ -552,6 +518,13 @@ def __gt__(self, other, *, loc=None, ip=None): def __ge__(self, other, *, loc=None, ip=None): return _make_binop(operator.ge)(self, other, loc=loc, ip=ip) + def bitcast(self, dtype, *, loc=None, ip=None): + """Reinterpret this value's bits as *dtype* (a same-width Numeric type).""" + if not (isinstance(dtype, type) and issubclass(dtype, Numeric)): + raise TypeError(f"dtype must be a Numeric subclass, but got {dtype!r}") + res = arith.BitcastOp(dtype.ir_type, self.ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + return dtype(res, loc=loc, ip=ip) + def as_numeric(obj): if isinstance(obj, Numeric): @@ -717,6 +690,11 @@ class Int64(Integer, metaclass=NumericMeta, width=64, signed=True, ir_type=T.i64 pass +class Int128(Integer, metaclass=NumericMeta, width=128, signed=True, ir_type=lambda: T.i(128)): + def __get_c_pointers__(self): + raise TypeError("Int128 is not a JitArgument for now. ctypes has no support for 128b integers.") + + class Uint8(Integer, metaclass=NumericMeta, width=8, signed=False, ir_type=T.i8): pass @@ -733,6 +711,11 @@ class Uint64(Integer, metaclass=NumericMeta, width=64, signed=False, ir_type=T.i pass +class Uint128(Integer, metaclass=NumericMeta, width=128, signed=False, ir_type=lambda: T.i(128)): + def __get_c_pointers__(self): + raise TypeError("Uint128 is not a JitArgument for now. ctypes has no support for 128b integers.") + + class Float16(Float, metaclass=NumericMeta, width=16, ir_type=T.f16): def __get_c_pointers__(self): if not isinstance(self.value, float): diff --git a/python/flydsl/expr/primitive.py b/python/flydsl/expr/primitive.py index 10f0de307..bd00fd06f 100644 --- a/python/flydsl/expr/primitive.py +++ b/python/flydsl/expr/primitive.py @@ -1,7 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright (c) 2025 FlyDSL Project Contributors +import inspect from enum import IntEnum +from functools import wraps from typing import overload from .._mlir import ir @@ -33,7 +35,7 @@ has_none, ) from .._mlir.extras import types as T -from .meta import traced_op +from .meta import dsl_loc_tracing, dsl_wrap_result __all__ = [ # Maybe remove it in the future @@ -217,14 +219,21 @@ def _is_int_tuple_value(value): def _expand_int_tuple_leaves(value, loc=None, ip=None): - from .numeric import Numeric + from .numeric import Int32, Numeric if _is_int_tuple_value(value): return _expand_int_tuple_leaves(value.to_py_value(loc=loc, ip=ip)) if isinstance(value, (list, tuple)): return tuple(_expand_int_tuple_leaves(v, loc=loc, ip=ip) for v in value) + # widen narrow dynamic ints to i32 if isinstance(value, Numeric): + if isinstance(value.value, ir.Value) and type(value).width < 32: + return Int32(value, loc=loc, ip=ip).value return value.value + if isinstance(value, ir.Value) and isinstance(value.type, ir.IntegerType) and value.type.width < 32: + return Int32(value, loc=loc, ip=ip).value + if isinstance(value, ir.Value) and isinstance(value.type, ir.IndexType): + return Int32(value, loc=loc, ip=ip).value return value @@ -247,6 +256,47 @@ def _check_profile(match_func, lhs, rhs): raise ValueError(f"profile mismatch: {match_func.__name__}({lhs.type}, {rhs.type}) is False") +# ---- IntTuple covariance ---- +# Covariance rules (Python value → fly.IntTuple): +# int <: fly.IntTuple (leaf) +# Numeric <: fly.IntTuple (leaf, e.g. Int32(5)) +# tuple(X1, ...) <: fly.IntTuple<(X1, ...)> (non-leaf; tuple is constructor) +# fly.IntTuple <: fly.IntTuple (trivial) + + +def _coerce_int_tuple(v): + if _is_int_tuple_value(v): + return v + return make_int_tuple(v) + + +def _coerce_int_tuple_permissive(v): + if isinstance(v, ir.Value): + return v + return make_int_tuple(v) + + +def coerce_int_tuple_args(*arg_names, permissive=False): + coerce = _coerce_int_tuple_permissive if permissive else _coerce_int_tuple + + def decorator(fn): + sig = inspect.signature(fn) + + @wraps(fn) + def wrapper(*args, **kwargs): + bound = sig.bind_partial(*args, **kwargs) + for name in arg_names: + v = bound.arguments.get(name) + if v is None: + continue + bound.arguments[name] = coerce(v) + return fn(*bound.args, **bound.kwargs) + + return wrapper + + return decorator + + # ===----------------------------------------------------------------------=== # # Compile-time utility # ===----------------------------------------------------------------------=== # @@ -300,7 +350,7 @@ def depth(int_or_tuple): # ===----------------------------------------------------------------------=== # -@traced_op +@dsl_loc_tracing def static(result_type, loc=None, ip=None): """Materialize a value whose entire content is encoded in *result_type*. @@ -314,7 +364,7 @@ def static(result_type, loc=None, ip=None): return fly.static(result_type, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def make_int_tuple(elems, loc=None, ip=None): """Build a (possibly nested) integer tuple from Python ints or runtime values. @@ -328,7 +378,7 @@ def make_int_tuple(elems, loc=None, ip=None): return fly.make_int_tuple(IntTupleTy, dyncElems, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def make_shape(*shape, loc=None, ip=None): """Build a shape tuple describing the extent of each mode. @@ -342,7 +392,7 @@ def make_shape(*shape, loc=None, ip=None): return fly.make_shape(IntTupleTy, dyncElems, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def make_stride(*stride, loc=None, ip=None): """Build a stride tuple: the step (in elements) when moving along each mode. @@ -356,7 +406,7 @@ def make_stride(*stride, loc=None, ip=None): return fly.make_stride(IntTupleTy, dyncElems, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def make_coord(*coord, loc=None, ip=None): """Build a coordinate used for indexing / slicing a layout. @@ -370,7 +420,7 @@ def make_coord(*coord, loc=None, ip=None): return fly.make_coord(IntTupleTy, dyncElems, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def make_layout(shape, stride, loc=None, ip=None): """Pair a *shape* with a *stride* to describe how logical coords map to memory. @@ -381,20 +431,20 @@ def make_layout(shape, stride, loc=None, ip=None): make_layout((4, 8), (1, 4)) -> ((4, 8), (1, 4)) make_layout((4, 8), (8, 1)) -> ((4, 8), (8, 1)) """ - if not isinstance(shape, ir.Value): + if not _is_int_tuple_value(shape): shape = make_int_tuple(shape, loc=loc, ip=ip) - if not isinstance(stride, ir.Value): + if not _is_int_tuple_value(stride): stride = make_int_tuple(stride, loc=loc, ip=ip) _check_profile(is_profile_congruent, shape, stride) return fly.make_layout(shape, stride=stride, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def make_layout_like(ref, loc=None, ip=None): return fly.make_layout_like(ref, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def make_ordered_layout(shape, order, loc=None, ip=None): """Build a compact layout whose stride order matches *order*. @@ -405,9 +455,9 @@ def make_ordered_layout(shape, order, loc=None, ip=None): make_ordered_layout((M, N), (0, 1)) # column-major: M iterates fastest make_ordered_layout((M, N), (1, 0)) # row-major: N iterates fastest """ - if not isinstance(shape, ir.Value): + if not _is_int_tuple_value(shape): shape = make_int_tuple(shape, loc=loc, ip=ip) - if not isinstance(order, ir.Value): + if not _is_int_tuple_value(order): order = make_int_tuple(order, loc=loc, ip=ip) _check_profile(is_profile_weakly_congruent, order, shape) return fly.make_ordered_layout(shape, order, loc=loc, ip=ip) @@ -417,7 +467,7 @@ def make_ordered_layout(shape, order, loc=None, ip=None): def make_composed_layout(inner, offset, outer, loc=None, ip=None): ... @overload def make_composed_layout(inner, outer, loc=None, ip=None): ... -@traced_op +@dsl_loc_tracing def make_composed_layout(inner, offset_or_outer, outer=None, loc=None, ip=None): """Stack two layouts: a coord is first mapped by *outer*, then by *inner*. @@ -433,12 +483,12 @@ def make_composed_layout(inner, offset_or_outer, outer=None, loc=None, ip=None): offset = coprofile(outer, loc=loc, ip=ip) else: offset = offset_or_outer - if not isinstance(offset, ir.Value): + if not _is_int_tuple_value(offset): offset = make_int_tuple(offset, loc=loc, ip=ip) return fly.make_composed_layout(inner, offset, outer, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def make_identity_layout(shape, loc=None, ip=None): """Build the identity layout in FlyDSL's layout-algebra sense. @@ -449,22 +499,22 @@ def make_identity_layout(shape, loc=None, ip=None): Examples: make_identity_layout((4, 8)) -> ((4, 8), (1E0, 1E1)) """ - if not isinstance(shape, ir.Value): + if not _is_int_tuple_value(shape): shape = make_int_tuple(shape, loc=loc, ip=ip) return fly.make_identity_layout(shape, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def make_view(iter, layout, loc=None, ip=None): return fly.make_view(iter, layout, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def make_fragment_layout_like(tensor, loc=None, ip=None): return fly.make_fragment_layout_like(tensor, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def make_fragment_like(tensor, dtype=None, loc=None, ip=None): if hasattr(dtype, "ir_type"): dtype = dtype.ir_type @@ -476,48 +526,91 @@ def make_fragment_like(tensor, dtype=None, loc=None, ip=None): # ===----------------------------------------------------------------------=== # -@traced_op +@dsl_loc_tracing +@dsl_wrap_result def get_scalar(int_tuple, loc=None, ip=None): + """Unwrap a rank-1, single-element tuple back to a plain scalar value. + + Fails if the input has more than one leaf - use this only when you know + the tuple is a trivial wrapper. + + Examples: + get_scalar(make_coord(tid)) -> Int32(tid) + get_scalar(make_int_tuple(5)) -> 5 + """ + if not _is_int_tuple_value(int_tuple): + return int_tuple + if int_tuple.is_leaf and int_tuple.is_static: + return int_tuple.get_static_leaf_int return fly.get_scalar(int_tuple, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@dsl_wrap_result def get_leaves(input, dynamic_only=False, loc=None, ip=None): - res_lists = fly.GetLeavesOp(input, dynamicOnly=dynamic_only, loc=loc, ip=ip) - return tuple(res_lists.results) + """Flatten an IntTuple into a flat sequence of leaf values. + + Set *dynamic_only=True* to keep only runtime values and drop static + constants - handy when you need the inputs that were passed at call time. + Examples: + get_leaves(make_coord(tid, 0)) -> (Int32(tid), 0) + get_leaves(make_coord(tid, 0), dynamic_only=True) -> (Int32(tid),) # 0 is static, dropped + """ + if dynamic_only: + res_lists = fly.GetLeavesOp(input, dynamicOnly=True, loc=loc, ip=ip) + return tuple(res_lists.results) + + def _walk_int_tuple_leaves(ty): + if ty.is_leaf: + yield ty + return + for i in range(ty.rank): + yield from _walk_int_tuple_leaves(ty.at(i)) + + ty = IntTupleType(input.type) + res_lists = fly.GetLeavesOp(input, dynamicOnly=True, loc=loc, ip=ip) + dyn_iter = iter(res_lists.results) + out = [] + for leaf_ty in _walk_int_tuple_leaves(ty): + if leaf_ty.is_static: + out.append(leaf_ty.get_static_leaf_int) + else: + out.append(next(dyn_iter)) + return tuple(out) -@traced_op + +@dsl_loc_tracing def get_shape(layout, loc=None, ip=None): return fly.get_shape(layout, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def get_stride(layout, loc=None, ip=None): return fly.get_stride(layout, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def get_layout(memref, loc=None, ip=None): return fly.get_layout(memref, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def get_iter(memref, loc=None, ip=None): return fly.get_iter(memref, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def composed_get_inner(input, loc=None, ip=None): return fly.composed_get_inner(input, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def composed_get_offset(input, loc=None, ip=None): return fly.composed_get_offset(input, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def composed_get_outer(input, loc=None, ip=None): return fly.composed_get_outer(input, loc=loc, ip=ip) @@ -527,62 +620,76 @@ def composed_get_outer(input, loc=None, ip=None): # ===----------------------------------------------------------------------=== # -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("lhs", "rhs") def int_tuple_add(lhs, rhs, loc=None, ip=None): return fly.int_tuple_add(lhs, rhs, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("lhs", "rhs") def int_tuple_sub(lhs, rhs, loc=None, ip=None): return fly.int_tuple_sub(lhs, rhs, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("lhs", "rhs") def int_tuple_mul(lhs, rhs, loc=None, ip=None): return fly.int_tuple_mul(lhs, rhs, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("lhs", "rhs") def int_tuple_div(lhs, rhs, loc=None, ip=None): return fly.int_tuple_div(lhs, rhs, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("lhs", "rhs") def int_tuple_mod(lhs, rhs, loc=None, ip=None): return fly.int_tuple_mod(lhs, rhs, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("int_tuple") def int_tuple_product(int_tuple, loc=None, ip=None): return fly.int_tuple_product(int_tuple, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("int_tuple") def int_tuple_product_each(int_tuple, loc=None, ip=None): return fly.int_tuple_product_each(int_tuple, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("lhs", "rhs") def int_tuple_product_like(lhs, rhs, loc=None, ip=None): return fly.int_tuple_product_like(lhs, rhs, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("lhs", "rhs") def shape_div(lhs, rhs, loc=None, ip=None): return fly.shape_div(lhs, rhs, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("lhs", "rhs") def ceil_div(lhs, rhs, loc=None, ip=None): return fly.ceil_div(lhs, rhs, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@dsl_wrap_result +@coerce_int_tuple_args("lhs", "rhs") def elem_less(lhs, rhs, loc=None, ip=None): return fly.elem_less(lhs, rhs, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@dsl_wrap_result +@coerce_int_tuple_args("lhs", "rhs") def equal(lhs, rhs, loc=None, ip=None): return fly.equal(lhs, rhs, loc=loc, ip=ip) @@ -592,7 +699,7 @@ def equal(lhs, rhs, loc=None, ip=None): # ===----------------------------------------------------------------------=== # -@traced_op +@dsl_loc_tracing def get(int_tuple, mode, loc=None, ip=None): if isinstance(int_tuple, (list, tuple)): return int_tuple[mode] @@ -603,39 +710,45 @@ def get(int_tuple, mode, loc=None, ip=None): return result -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("int_tuple") def get_(int_tuple, mode, loc=None, ip=None): if isinstance(mode, int): mode = [mode] return fly.get(int_tuple, mode, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("int_tuple") def take(int_tuple, begin: int, end: int, loc=None, ip=None): return fly.take(int_tuple, begin=begin, end=end, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("int_tuple") def select(int_tuple, indices, loc=None, ip=None): return fly.select(int_tuple, indices=indices, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("int_tuple") def group(int_tuple, begin: int, end: int, loc=None, ip=None): return fly.group(int_tuple, begin=begin, end=end, loc=loc, ip=ip) -@traced_op -def append(base, elem, n: int | None = None, loc=None, ip=None): +@dsl_loc_tracing +@coerce_int_tuple_args("base", "elem", permissive=True) +def append(base, elem, *, n: int | None = None, loc=None, ip=None): return fly.append(base, elem, n=n, loc=loc, ip=ip) -@traced_op -def prepend(base, elem, n: int | None = None, loc=None, ip=None): +@dsl_loc_tracing +@coerce_int_tuple_args("base", "elem", permissive=True) +def prepend(base, elem, *, n: int | None = None, loc=None, ip=None): return fly.prepend(base, elem, n=n, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def slice(src, coord, loc=None, ip=None): """Keep the modes where *coord* has `None` (wildcard), drop the rest. @@ -646,13 +759,13 @@ def slice(src, coord, loc=None, ip=None): slice((4, 8, 16), (None, 3, None)) -> (4, 16) # mode 1 fixed, dropped slice(layout, make_coord(None, bid)) -> sub-layout for column `bid` """ - if not isinstance(coord, ir.Value): + if not _is_int_tuple_value(coord): coord = make_int_tuple(coord, loc=loc, ip=ip) _check_profile(is_profile_weakly_congruent, coord, src) return fly.slice(src, coord, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def dice(src, coord, loc=None, ip=None): """Complement of `slice`: keep the *fixed* modes, drop the `None` (wildcard) ones. @@ -662,7 +775,7 @@ def dice(src, coord, loc=None, ip=None): dice((4, 8, 16), (None, 3, None)) -> (8,) dice(coord_tensor, make_coord(tid, None)) -> the thread-only part """ - if not isinstance(coord, ir.Value): + if not _is_int_tuple_value(coord): coord = make_int_tuple(coord, loc=loc, ip=ip) _check_profile(is_profile_weakly_congruent, coord, src) return fly.dice(src, coord, loc=loc, ip=ip) @@ -673,34 +786,28 @@ def dice(src, coord, loc=None, ip=None): # ===----------------------------------------------------------------------=== # -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("int_tuple", permissive=True) def size(int_tuple, loc=None, ip=None): return fly.size(int_tuple, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def coprofile(layout, loc=None, ip=None): return fly.coprofile(layout, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def coshape(layout, loc=None, ip=None): return fly.coshape(layout, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def cosize(layout, loc=None, ip=None): return fly.cosize(layout, loc=loc, ip=ip) -def _to_i32(v): - """Cast index-type ir.Value to i32 (required by fly.make_int_tuple).""" - if isinstance(v, ir.Value) and isinstance(v.type, ir.IndexType): - return _arith.IndexCastOp(T.i32(), v).result - return v - - -@traced_op +@dsl_loc_tracing def crd2idx(crd, layout, loc=None, ip=None): """Map a coordinate tuple to an index through *layout*. @@ -712,15 +819,13 @@ def crd2idx(crd, layout, loc=None, ip=None): crd2idx((1, 2), make_layout((4, 8), (1, 4))) -> 9 crd2idx(7, make_layout((4, 8), (1, 4))) -> 7 """ - if not isinstance(crd, ir.Value): - if isinstance(crd, (list, tuple)): - crd = tuple(_to_i32(c) for c in crd) + if not _is_int_tuple_value(crd): crd = make_int_tuple(crd, loc=loc, ip=ip) _check_profile(is_profile_weakly_congruent, crd, layout) return fly.crd2idx(crd, layout, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def idx2crd(index, layout, loc=None, ip=None): """Map an index back to a coordinate tuple for a plain `Layout`. @@ -732,15 +837,12 @@ def idx2crd(index, layout, loc=None, ip=None): idx2crd(9, make_layout((4, 8), (1, 4))) -> (1, 2) idx2crd(5, make_layout((4, 8), (8, 1))) -> (0, 5) """ - if isinstance(index, ir.Value) and not isinstance(index.type, IntTupleType): - index = _to_i32(index) - index = make_int_tuple(index, loc=loc, ip=ip) - if not isinstance(index, ir.Value): + if not _is_int_tuple_value(index): index = make_int_tuple(index, loc=loc, ip=ip) return fly.idx2crd(index, layout, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def get_flat_coord(index, layout, loc=None, ip=None): """Map an index to a *fully flattened* coordinate, ignoring nested grouping. @@ -751,12 +853,12 @@ def get_flat_coord(index, layout, loc=None, ip=None): get_flat_coord(9, make_layout((4, 8), (1, 4))) -> (1, 2) get_flat_coord(3, make_layout(((2, 2), 4), ((1, 2), 4))) -> (1, 1, 0) """ - if not isinstance(index, ir.Value): + if not _is_int_tuple_value(index): index = make_int_tuple(index, loc=loc, ip=ip) return fly.get_flat_coord(index, layout, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def get_1d_coord(index, layout, loc=None, ip=None): """Map an index to a single 1-D coordinate in the layout's shape space. @@ -764,97 +866,104 @@ def get_1d_coord(index, layout, loc=None, ip=None): get_1d_coord(9, make_layout((4, 8), (1, 4))) -> 9 get_1d_coord(5, make_layout((4, 8), (8, 1))) -> 20 """ - if not isinstance(index, ir.Value): + if not _is_int_tuple_value(index): index = make_int_tuple(index, loc=loc, ip=ip) return fly.get_1d_coord(index, layout, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("pattern") def coalesce(layout, pattern=None, loc=None, ip=None): return fly.coalesce(layout, pattern=pattern, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("tiler", permissive=True) def composition(layout, tiler, loc=None, ip=None): return fly.composition(layout, tiler, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("codomain_size") def complement(layout, codomain_size=None, loc=None, ip=None): - if codomain_size is not None and not isinstance(codomain_size, ir.Value): - codomain_size = make_int_tuple(codomain_size, loc=loc, ip=ip) return fly.complement(layout, codomain_size=codomain_size, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def right_inverse(layout, loc=None, ip=None): return fly.right_inverse(layout, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def left_inverse(layout, loc=None, ip=None): return fly.left_inverse(layout, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def logical_divide(layout, divisor, loc=None, ip=None): if not isinstance(divisor, ir.Value): divisor = make_tile(*divisor, loc=loc, ip=ip) return fly.logical_divide(layout, divisor, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def zipped_divide(layout, divisor, loc=None, ip=None): if not isinstance(divisor, ir.Value): divisor = make_tile(*divisor, loc=loc, ip=ip) return fly.zipped_divide(layout, divisor, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def tiled_divide(layout, divisor, loc=None, ip=None): if not isinstance(divisor, ir.Value): divisor = make_tile(*divisor, loc=loc, ip=ip) return fly.tiled_divide(layout, divisor, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def flat_divide(layout, divisor, loc=None, ip=None): if not isinstance(divisor, ir.Value): divisor = make_tile(*divisor, loc=loc, ip=ip) return fly.flat_divide(layout, divisor, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("tiler", permissive=True) def logical_product(layout, tiler, loc=None, ip=None): return fly.logical_product(layout, tiler, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("tiler", permissive=True) def zipped_product(layout, tiler, loc=None, ip=None): return fly.zipped_product(layout, tiler, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("tiler", permissive=True) def tiled_product(layout, tiler, loc=None, ip=None): return fly.tiled_product(layout, tiler, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("tiler", permissive=True) def flat_product(layout, tiler, loc=None, ip=None): return fly.flat_product(layout, tiler, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("tiler", permissive=True) def blocked_product(layout, tiler, loc=None, ip=None): return fly.blocked_product(layout, tiler, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("tiler", permissive=True) def raked_product(layout, tiler, loc=None, ip=None): return fly.raked_product(layout, tiler, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def recast_layout(layout, old_type_bits, new_type_bits, loc=None, ip=None): def _to_static_bits(v): if isinstance(v, int): @@ -870,7 +979,8 @@ def _to_static_bits(v): return fly.recast_layout(new_type_bits=new_type_bits, old_type_bits=old_type_bits, src=layout, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("trg_shape", "ord_shape") def tile_to_shape(block, trg_shape, ord_shape, loc=None, ip=None): return fly.tile_to_shape(block, trg_shape, ord_shape, loc=loc, ip=ip) @@ -880,13 +990,13 @@ def tile_to_shape(block, trg_shape, ord_shape, loc=None, ip=None): # ===----------------------------------------------------------------------=== # -@traced_op +@dsl_loc_tracing def make_mma_atom(mma_op_type, loc=None, ip=None): mma_atom_ty = MmaAtomType.get(mma_op=mma_op_type) return fly.make_mma_atom(mma_atom_ty, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def make_copy_atom(copy_op_type, elem_type, loc=None, ip=None): from .numeric import NumericMeta @@ -905,73 +1015,79 @@ def make_copy_atom(copy_op_type, elem_type, loc=None, ip=None): return fly.make_copy_atom(copy_atom_ty, val_bits=val_bits, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def atom_set_value(atom, field, value, loc=None, ip=None): + from .typing import as_ir_value + if isinstance(field, IntEnum): field = str(field) - return fly.atom_set_value(atom, field, value, loc=loc, ip=ip) + return fly.atom_set_value(atom, field, as_ir_value(value), loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def copy_atom_call(copy_atom, src, dst, *, pred=None, loc=None, ip=None): return fly.copy_atom_call(copy_atom, src, dst, pred=pred, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def mma_atom_call(mma_atom, d, a, b, c, loc=None, ip=None): return fly.mma_atom_call(mma_atom, d, a, b, c, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def make_tiled_copy(copy_atom, layout_thr_val, tile_mn, loc=None, ip=None): if not isinstance(tile_mn, ir.Value): tile_mn = make_tile(*tile_mn, loc=loc, ip=ip) return fly.make_tiled_copy(copy_atom, layout_thr_val, tile_mn, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def make_tiled_mma(mma_atom, atom_layout, permutation=None, loc=None, ip=None): if permutation is not None and not isinstance(permutation, ir.Value): permutation = make_tile(*permutation, loc=loc, ip=ip) return fly.make_tiled_mma(mma_atom, atom_layout, permutation=permutation, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("thr_int_tuple") def tiled_copy_partition_src(tiled_copy, src, thr_int_tuple, loc=None, ip=None): return fly.tiled_copy_partition_src(tiled_copy, src, thr_int_tuple, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("thr_int_tuple") def tiled_copy_partition_dst(tiled_copy, dst, thr_int_tuple, loc=None, ip=None): return fly.tiled_copy_partition_dst(tiled_copy, dst, thr_int_tuple, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def tiled_copy_retile(tiled_copy, t, loc=None, ip=None): return fly.tiled_copy_retile(tiled_copy, t, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("coord") def tiled_mma_partition(operand_id, tiled_mma, t, coord, loc=None, ip=None): return fly.tiled_mma_partition(operand_id, tiled_mma, t, coord, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@coerce_int_tuple_args("shape") def tiled_mma_partition_shape(operand_id, tiled_mma, shape, loc=None, ip=None): return fly.tiled_mma_partition_shape(operand_id, tiled_mma, shape, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def mma_make_fragment(operand_id, tiled_mma, input, *, stages=None, loc=None, ip=None): return fly.mma_make_fragment(operand_id, tiled_mma, input, stages=stages, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def copy(copy_atom, src, dst, *, pred=None, loc=None, ip=None, **kwargs): return fly.copy(copy_atom.set_value(kwargs), src, dst, pred=pred, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def gemm(mma_atom, d, a, b, c, *, traversal_order=None, traversal_layout=None, loc=None, ip=None, **kwargs): if traversal_order is not None and traversal_layout is not None: raise ValueError("Only one of 'traversal_order' or 'traversal_layout' can be specified, not both") @@ -993,7 +1109,7 @@ def gemm(mma_atom, d, a, b, c, *, traversal_order=None, traversal_layout=None, l # ===----------------------------------------------------------------------=== # -@traced_op +@dsl_loc_tracing def make_ptr(result_type, args, *, dict_attrs=None, loc=None, ip=None): result = fly.make_ptr(result_type, args, loc=loc, ip=ip) if dict_attrs is not None: @@ -1001,7 +1117,7 @@ def make_ptr(result_type, args, *, dict_attrs=None, loc=None, ip=None): return result -@traced_op +@dsl_loc_tracing def get_dyn_shared(dtype=None, loc=None, ip=None): """Return a pointer to the start of the kernel's dynamic shared-memory buffer. @@ -1015,20 +1131,21 @@ def get_dyn_shared(dtype=None, loc=None, ip=None): return recast_iter(dtype, raw_ptr) -@traced_op +@dsl_loc_tracing def inttoptr(result_type, src, loc=None, ip=None): """Interpret an integer address *src* as a pointer of *result_type*. Requirement: ptr.address_space != Register """ - from .typing import is_generic_address_space + from .typing import as_ir_value, is_generic_address_space if is_generic_address_space(result_type.address_space, AddressSpace.Register): raise ValueError("inttoptr is not supported for register address space") - return fly.inttoptr(result_type, src, loc=loc, ip=ip) + return fly.inttoptr(result_type, as_ir_value(src), loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@dsl_wrap_result def ptrtoint(ptr, loc=None, ip=None): """Get the raw integer address underlying *ptr*. @@ -1044,31 +1161,56 @@ def ptrtoint(ptr, loc=None, ip=None): return fly.ptrtoint(ptr, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def add_offset(ptr, offset, loc=None, ip=None): + """Shift *ptr* by *offset* elements + + Examples: + ptr2 = add_offset(ptr, 16) # move forward 16 elements + ptr2 = add_offset(ptr, tile_id * BM) # runtime offset + """ if not _is_int_tuple_value(offset): offset = make_int_tuple(offset, loc=loc, ip=ip) return fly.add_offset(ptr, offset, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def apply_swizzle(ptr, swizzle, loc=None, ip=None): return fly.apply_swizzle(ptr, swizzle, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@dsl_wrap_result def ptr_load(ptr, result_type=None, loc=None, ip=None): + """Load one value (scalar or vector) from *ptr*; dtype defaults to ptr's element type. + + Examples: + v = ptr_load(ptr) + """ if result_type is None: result_type = ptr.element_type - return fly.ptr_load(result_type.ir_type, ptr, loc=loc, ip=ip) + if not isinstance(result_type, ir.Type): + result_type = result_type.ir_type + return fly.ptr_load(result_type, ptr, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def ptr_store(value, ptr, loc=None, ip=None): + """Store *value* into *ptr*. Types must match the pointer's element type. + + Examples: + ptr_store(val, ptr) + """ + from .numeric import Numeric + + if isinstance(value, Numeric): + value = value.ir_value() + elif not isinstance(value, ir.Value): + value = ptr.element_type(value).ir_value() return fly.ptr_store(value, ptr, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def recast_iter(result_type, src, loc=None, ip=None): """Reinterpret a pointer / iterator as another element type (like `reinterpret_cast`). @@ -1088,29 +1230,29 @@ def recast_iter(result_type, src, loc=None, ip=None): return fly.recast_iter(result_type, src, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def memref_alloca(memref_type, layout, loc=None, ip=None): return fly.memref_alloca(memref_type, layout, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def memref_load_vec(memref, loc=None, ip=None): - return fly.memref_load_vec(memref, loc=loc, ip=ip) + from .typing import Vector + return Vector(fly.memref_load_vec(memref, loc=loc, ip=ip), memref.shape.to_py_value(), memref.dtype) -@traced_op + +@dsl_loc_tracing def memref_store_vec(vector, memref, loc=None, ip=None): return fly.memref_store_vec(vector, memref, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing +@dsl_wrap_result def memref_load(memref, indices, loc=None, ip=None): if isinstance(indices, ir.Value): - if str(indices.type).startswith("!fly.int_tuple"): - return fly.memref_load(memref, indices, loc=loc, ip=ip) - if str(indices.type) == "index": - indices = _arith.IndexCastOp(T.i32(), indices) - indices = make_int_tuple(indices, loc=loc, ip=ip) + if not _is_int_tuple_value(indices): + indices = make_int_tuple(indices, loc=loc, ip=ip) return fly.memref_load(memref, indices, loc=loc, ip=ip) indices = make_int_tuple(indices, loc=loc, ip=ip) @@ -1118,14 +1260,14 @@ def memref_load(memref, indices, loc=None, ip=None): return fly.memref_load(memref, indices, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def memref_store(value, memref, indices, loc=None, ip=None): + from .typing import as_ir_value + + value = as_ir_value(value) if isinstance(indices, ir.Value): - if str(indices.type).startswith("!fly.int_tuple"): - return fly.memref_store(value, memref, indices, loc=loc, ip=ip) - if str(indices.type) == "index": - indices = _arith.IndexCastOp(T.i32(), indices) - indices = make_int_tuple(indices, loc=loc, ip=ip) + if not _is_int_tuple_value(indices): + indices = make_int_tuple(indices, loc=loc, ip=ip) return fly.memref_store(value, memref, indices, loc=loc, ip=ip) indices = make_int_tuple(indices, loc=loc, ip=ip) @@ -1138,7 +1280,7 @@ def memref_store(value, memref, indices, loc=None, ip=None): # ===----------------------------------------------------------------------=== # -@traced_op +@dsl_loc_tracing def printf(*args, format_str="", loc=None, ip=None): def _convert_printf_value(val): if isinstance(val, ir.Value): @@ -1197,7 +1339,7 @@ def _convert_printf_value(val): return fly.print_(final_format, ir_values, loc=loc, ip=ip) -@traced_op +@dsl_loc_tracing def assume(result_type, dst, src, loc=None, ip=None): """ WIP, unsupported for now @@ -1210,7 +1352,7 @@ def assume(result_type, dst, src, loc=None, ip=None): # ===----------------------------------------------------------------------=== # -@traced_op +@dsl_loc_tracing def make_tile(*args, loc=None, ip=None): from .typing import Layout diff --git a/python/flydsl/expr/rocdl.py b/python/flydsl/expr/rocdl.py index 99fa11c09..ac0fe3e6b 100644 --- a/python/flydsl/expr/rocdl.py +++ b/python/flydsl/expr/rocdl.py @@ -95,9 +95,9 @@ def make_buffer_tensor(memref, alignment=4, loc=None, ip=None): from .._mlir.dialects import arith as _arith from .._mlir.dialects import fly from . import primitive as _prim - from .meta import _to_raw_value + from .typing import as_ir_value - raw_memref = _to_raw_value(memref) + raw_memref = as_ir_value(memref) layout = _prim.get_layout(memref, loc=loc, ip=ip) elem_type = fly.MemRefType(raw_memref.type).element_type diff --git a/python/flydsl/expr/typing.py b/python/flydsl/expr/typing.py index d8fdf5969..96746788e 100644 --- a/python/flydsl/expr/typing.py +++ b/python/flydsl/expr/typing.py @@ -37,12 +37,15 @@ Int16, Int32, Int64, + Int128, Integer, Numeric, Uint8, Uint16, Uint32, Uint64, + Uint128, + as_numeric, ) from .primitive import * from .utils.arith import ( @@ -56,6 +59,97 @@ ) +def as_ir_value(value, *, keep_static=False): + """Convert any DslType value into a raw ``ir.Value`` + + This is the *canonical* "DSL -> ir.Value" converter. Body code that + needs to feed an MLIR builder should call this explicitly per argument. + + Behavior summary: + - ``None`` -> ``None`` + - ``ir.Value`` -> returned unchanged + - ``Numeric`` holding a Python literal, when + ``keep_static=True`` -> returned unchanged + ``keep_static=False`` -> promoted via ``as_numeric(value).ir_value()`` + - ``tuple`` / ``list`` -> recursed, shape preserved + - object with ``__extract_to_ir_values__`` -> single value extracted; multi-value returns a list + - ``bool`` / ``int`` / ``float`` -> promoted via ``as_numeric(value).ir_value()`` + - object with ``ir_value()`` -> called as a fallback + - anything else -> returned unchanged + """ + if value is None: + return None + if isinstance(value, ir.Value): + return value + if keep_static and isinstance(value, Numeric) and not isinstance(value.value, ir.Value): + return value + if isinstance(value, tuple): + return tuple(as_ir_value(v, keep_static=keep_static) for v in value) + if isinstance(value, list): + return [as_ir_value(v, keep_static=keep_static) for v in value] + if hasattr(value, "__extract_to_ir_values__"): + values = value.__extract_to_ir_values__() + if len(values) == 1: + return values[0] + return values + if isinstance(value, (bool, int, float)): + return as_numeric(value).ir_value() + if hasattr(value, "ir_value"): + return value.ir_value() + return value + + +def as_dsl_value(value, exemplar=None): + """Wrap a raw ``ir.Value`` back into a DSL value. This is the inverse + of :func:`as_ir_value` (``ir.Value -> DslType``). + + ``exemplar`` is an optional *type template* describing how to wrap ``value``: + - a DslType class -> constructed directly via ``exemplar(value)`` + - a DslType instance -> ``type(exemplar)(value)`` + + Behavior summary (mirrors the branches of :func:`as_ir_value`): + - ``None`` -> ``None`` + - ``tuple`` / ``list`` -> recursed, shape preserved, + paired element-wise with ``exemplar`` (a non-sequence ``exemplar`` is + broadcast to every element) + - with no usable ``exemplar``: a ``value`` already satisfying the + ``DslType`` protocol is returned unchanged; a bare scalar ``ir.Value`` + is dispatched by ``value.type`` via ``Numeric.from_ir_type``; any other + non-``ir.Value`` is returned unchanged. + + Raises ``TypeError`` when a bare ``ir.Value`` cannot be wrapped into any DSL + value. + """ + if value is None: + return None + if isinstance(value, (tuple, list)): + exemplars = exemplar if isinstance(exemplar, (tuple, list)) else [exemplar] * len(value) + return type(value)(as_dsl_value(v, ex) for v, ex in zip(value, exemplars)) + + if exemplar is not None and isinstance(value, ir.Value): + if isclass(exemplar): + return exemplar(value) + if isinstance(exemplar, Numeric): + return type(exemplar)(value) + ctor = getattr(type(exemplar), "__construct_from_ir_values__", None) + if ctor is not None: + try: + return ctor([value]) + except Exception: + raise ValueError(f"failed to construct {type(exemplar)} from {value}") + + from ..compiler.protocol import DslType + + if isinstance(value, DslType): + return value + if not isinstance(value, ir.Value): + return value + try: + return Numeric.from_ir_type(value.type)(value) + except Exception as e: + raise TypeError(f"as_dsl_value cannot wrap ir.Value of type {value.type!s} into a DSL value") from e + + def _vec(n: int, elem: ir.Type) -> ir.Type: return ir.VectorType.get([int(n)], elem) @@ -165,6 +259,10 @@ def i64(self) -> ir.Type: def i64x2(self) -> ir.Type: return _vec(2, Int64.ir_type) + @property + def i128(self) -> ir.Type: + return Int128.ir_type + # ---- Float scalars & vectors ---- @property def f16(self) -> ir.Type: @@ -248,6 +346,9 @@ def vec(self, n: int, elem: ir.Type) -> ir.Type: "Types", "T", "default_f8_type", + # DSL utilities + "as_ir_value", + "as_dsl_value", "is_generic_address_space", "is_target_address_space", # DSL value types @@ -272,11 +373,13 @@ def vec(self, n: int, elem: ir.Type) -> ir.Type: "Int16", "Int32", "Int64", + "Int128", "Index", "Uint8", "Uint16", "Uint32", "Uint64", + "Uint128", "Constexpr", "IntTuple", "Layout", @@ -582,11 +685,8 @@ def _rebuild_py_value(self, leaf_iter): if self.is_leaf: if self.is_static: return self.get_static_leaf_int - val = next(leaf_iter) - width = ir.IntegerType(val.type).width - wrapper = Int64 if width == 64 else Int32 - return wrapper(val) - return tuple(IntTuple(get_(self, i))._rebuild_py_value(leaf_iter) for i in range(self.rank)) + return next(leaf_iter) + return tuple(get_(self, i)._rebuild_py_value(leaf_iter) for i in range(self.rank)) @traced_op def to_py_value(self, loc=None, ip=None): @@ -821,7 +921,7 @@ def load(self, loc=None, ip=None): @traced_op def store(self, value, loc=None, ip=None): - if isinstance(value, (bool, int, float)): + if isinstance(value, (bool, int, float, Numeric)): value = self.element_type(value) return ptr_store(value, self, loc=loc, ip=ip) @@ -917,7 +1017,7 @@ def __setitem__(self, coord, value, loc=None, ip=None): @traced_op def load(self, loc=None, ip=None): - return Vector(memref_load_vec(self, loc=loc, ip=ip), self.shape.to_py_value(), self.dtype) + return memref_load_vec(self, loc=loc, ip=ip) @traced_op def store(self, vector, loc=None, ip=None): diff --git a/tests/unit/test_layout_algebra.py b/tests/unit/test_layout_algebra.py index 4f7a55d6c..610b3d851 100644 --- a/tests/unit/test_layout_algebra.py +++ b/tests/unit/test_layout_algebra.py @@ -32,13 +32,7 @@ FLY_PIPELINE = ( - "builtin.module(" - "fly-canonicalize," - "fly-layout-lowering," - "fly-canonicalize," - "convert-fly-to-rocdl," - "canonicalize," - "cse)" + "builtin.module(fly-canonicalize,fly-layout-lowering,fly-canonicalize,convert-fly-to-rocdl,canonicalize,cse)" ) @@ -218,9 +212,8 @@ def build_static(): with Location.unknown(ctx): module = Module.create() i32 = IntegerType.get_signless(32) - idx = IndexType.get() with InsertionPoint(module.body): - f = func.FuncOp("comp_dyn", FunctionType.get([i32] * 8, [idx])) + f = func.FuncOp("comp_dyn", FunctionType.get([i32] * 8, [i32])) entry = f.add_entry_block() with InsertionPoint(entry): args = list(entry.arguments) @@ -228,8 +221,8 @@ def build_static(): B = fx.make_layout(fx.make_shape(args[4], args[5]), fx.make_stride(args[6], args[7])) R = fx.composition(A, B) sz = fx.size(R) - sc = fx.get_scalar(sz) - func.ReturnOp([arith.IndexCastOp(idx, sc).result]) + sc = fx.get_scalar(sz).ir_value() + func.ReturnOp([sc]) pm = PassManager.parse(FLY_PIPELINE, ctx) pm.run(module.operation) assert module.operation.verify() @@ -317,9 +310,8 @@ def test_complement_rank_2_dynamic_stride_error(): with Location.unknown(ctx): module = Module.create() i32 = IntegerType.get_signless(32) - idx = IndexType.get() with InsertionPoint(module.body): - f = func.FuncOp("compl_dyn", FunctionType.get([i32], [idx])) + f = func.FuncOp("compl_dyn", FunctionType.get([i32], [i32])) entry = f.add_entry_block() with InsertionPoint(entry): runtime_stride = entry.arguments[0] @@ -328,8 +320,8 @@ def test_complement_rank_2_dynamic_stride_error(): tiler = fx.make_layout(shape, stride) comp = fx.complement(tiler, 12) sz = fx.size(comp) - sc = fx.get_scalar(sz) - func.ReturnOp([arith.IndexCastOp(idx, sc).result]) + sc = fx.get_scalar(sz).ir_value() + func.ReturnOp([sc]) pm = PassManager.parse(FLY_PIPELINE, ctx) pm.run(module.operation) diff --git a/tests/unit/test_numeric_promotion.py b/tests/unit/test_numeric_promotion.py new file mode 100644 index 000000000..c4aa7a68a --- /dev/null +++ b/tests/unit/test_numeric_promotion.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python3 + +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 FlyDSL Project Contributors + +"""C++-style usual-arithmetic-conversion promotion for DSL Numeric types. + +We deliberately skip the C++ "integer promotion to int" step: ``int8 + int8`` +must stay ``int8``, ``uint16 + uint16`` stays ``uint16``. Cross-width and +cross-sign promotion follows usual arithmetic conversions (unsigned wins at +equal width; wider wins among same-sign; signed-can-represent rule for +mixed-sign mixed-width). +""" + +import pytest + +import flydsl.expr as fx +from flydsl._mlir.ir import Context, InsertionPoint, Location, Module + +pytestmark = [pytest.mark.l1b_target_dialect] + + +def _binop(lhs_ty, rhs_ty, op): + """Build two block-arg values of the requested DSL types and apply `op`. + + Returns the resulting Numeric. We use block args so the operands are + genuinely dynamic ir.Values (not Python literals), which is the path + most kernel code hits. + """ + with Context() as ctx: + ctx.allow_unregistered_dialects = True + with Location.unknown(ctx): + module = Module.create() + from flydsl._mlir.dialects import func + from flydsl._mlir.ir import FunctionType + + with InsertionPoint(module.body): + f = func.FuncOp("k", FunctionType.get([lhs_ty.ir_type, rhs_ty.ir_type], [])) + entry = f.add_entry_block() + with InsertionPoint(entry): + a = lhs_ty(entry.arguments[0]) + b = rhs_ty(entry.arguments[1]) + result = op(a, b) + func.ReturnOp([]) + assert module.operation.verify() + return result + + +# Same-sign / same-width: must stay narrow (no auto-int32 promotion). +@pytest.mark.parametrize( + "ty", + [fx.Int8, fx.Int16, fx.Uint8, fx.Uint16, fx.Int32, fx.Int64, fx.Uint32, fx.Uint64, fx.Int128, fx.Uint128], +) +def test_same_type_stays_narrow(ty): + assert _binop(ty, ty, lambda a, b: a + b).dtype is ty + assert _binop(ty, ty, lambda a, b: a * b).dtype is ty + + +# Same-sign cross-width: wider wins. +@pytest.mark.parametrize( + "a,b,expected", + [ + (fx.Int8, fx.Int16, fx.Int16), + (fx.Int8, fx.Int32, fx.Int32), + (fx.Int16, fx.Int64, fx.Int64), + (fx.Uint8, fx.Uint16, fx.Uint16), + (fx.Uint16, fx.Uint64, fx.Uint64), + (fx.Int32, fx.Int128, fx.Int128), + (fx.Int64, fx.Int128, fx.Int128), + (fx.Uint32, fx.Uint128, fx.Uint128), + ], +) +def test_same_sign_wider_wins(a, b, expected): + assert _binop(a, b, lambda x, y: x + y).dtype is expected + assert _binop(b, a, lambda x, y: x + y).dtype is expected # commutative + + +# Mixed sign: unsigned wins iff u.width >= s.width, else signed. +@pytest.mark.parametrize( + "a,b,expected", + [ + (fx.Int32, fx.Uint32, fx.Uint32), # equal width → unsigned wins + (fx.Int32, fx.Uint64, fx.Uint64), # u wider → unsigned wins + (fx.Int64, fx.Uint32, fx.Int64), # s wider → signed (signed-can-represent) + (fx.Int8, fx.Uint16, fx.Uint16), # u wider → unsigned + (fx.Int16, fx.Uint8, fx.Int16), # s wider → signed + (fx.Int128, fx.Uint128, fx.Uint128), # equal width → unsigned + (fx.Int128, fx.Uint64, fx.Int128), # s wider → signed + (fx.Int128, fx.Uint32, fx.Int128), # s wider → signed + (fx.Uint128, fx.Int32, fx.Uint128), # u wider → unsigned + (fx.Uint128, fx.Int64, fx.Uint128), # u wider → unsigned + ], +) +def test_mixed_sign(a, b, expected): + assert _binop(a, b, lambda x, y: x + y).dtype is expected + assert _binop(b, a, lambda x, y: x + y).dtype is expected + + +# Python literal: as_numeric promotes int→Int32 (C++ `int` literal default), +# then C++ promotion runs. +def test_python_int_literal_promotes_via_int32(): + # Int8(arg) + 5 → Int8 + Int32 → Int32 (wider wins) + with Context() as ctx, Location.unknown(ctx): + ctx.allow_unregistered_dialects = True + module = Module.create() + from flydsl._mlir.dialects import func + from flydsl._mlir.ir import FunctionType + + with InsertionPoint(module.body): + f = func.FuncOp("k", FunctionType.get([fx.Int8.ir_type], [])) + entry = f.add_entry_block() + with InsertionPoint(entry): + a = fx.Int8(entry.arguments[0]) + r = a + 5 + func.ReturnOp([]) + assert module.operation.verify() + assert r.dtype is fx.Int32 + + +# Int + Float: promote to the float side. +@pytest.mark.parametrize( + "itype,ftype", + [ + (fx.Int8, fx.Float16), + (fx.Int32, fx.Float32), + (fx.Int64, fx.Float64), + (fx.Int128, fx.Float64), # no Float128; precision loss is expected and OK + ], +) +def test_int_plus_float(itype, ftype): + assert _binop(itype, ftype, lambda x, y: x + y).dtype is ftype + assert _binop(ftype, itype, lambda x, y: x + y).dtype is ftype + + +# Float + Float: wider wins. +@pytest.mark.parametrize( + "a,b,expected", + [ + (fx.Float16, fx.Float32, fx.Float32), + (fx.Float32, fx.Float64, fx.Float64), + (fx.Float16, fx.Float64, fx.Float64), + ], +) +def test_float_wider_wins(a, b, expected): + assert _binop(a, b, lambda x, y: x + y).dtype is expected + assert _binop(b, a, lambda x, y: x + y).dtype is expected + + +# Boolean arithmetic: bool + bool → Int32 (matches C++ "bool participates as int"). +def test_bool_plus_bool_widens_to_int32(): + with Context() as ctx, Location.unknown(ctx): + ctx.allow_unregistered_dialects = True + module = Module.create() + from flydsl._mlir.dialects import func + from flydsl._mlir.ir import FunctionType + + with InsertionPoint(module.body): + f = func.FuncOp("k", FunctionType.get([fx.Boolean.ir_type, fx.Boolean.ir_type], [])) + entry = f.add_entry_block() + with InsertionPoint(entry): + a = fx.Boolean(entry.arguments[0]) + b = fx.Boolean(entry.arguments[1]) + r = a + b + func.ReturnOp([]) + assert module.operation.verify() + assert r.dtype is fx.Int32 + + +# True division on integers: Python `/` lifts int/int to float. +@pytest.mark.parametrize( + "ty,expected", + [ + (fx.Int8, fx.Float32), + (fx.Int32, fx.Float32), + (fx.Int64, fx.Float64), + (fx.Int128, fx.Float64), + ], +) +def test_truediv_int_lifts_to_float(ty, expected): + assert _binop(ty, ty, lambda x, y: x / y).dtype is expected + + +# Floor division on integers: stays integer (Python `//` semantics). +@pytest.mark.parametrize("ty", [fx.Int8, fx.Int32, fx.Int64, fx.Uint32, fx.Int128]) +def test_floordiv_int_stays_int(ty): + assert _binop(ty, ty, lambda x, y: x // y).dtype is ty diff --git a/tests/unit/test_static_vs_dynamic.py b/tests/unit/test_static_vs_dynamic.py index a3c42c320..e027320eb 100644 --- a/tests/unit/test_static_vs_dynamic.py +++ b/tests/unit/test_static_vs_dynamic.py @@ -27,13 +27,7 @@ FLY_PIPELINE = ( - "builtin.module(" - "fly-canonicalize," - "fly-layout-lowering," - "fly-canonicalize," - "convert-fly-to-rocdl," - "canonicalize," - "cse)" + "builtin.module(fly-canonicalize,fly-layout-lowering,fly-canonicalize,convert-fly-to-rocdl,canonicalize,cse)" ) @@ -85,9 +79,8 @@ def test_layout_dynamic_types(): with Location.unknown(ctx): module = Module.create() i32 = IntegerType.get_signless(32) - idx = IndexType.get() with InsertionPoint(module.body): - f = func.FuncOp("dynamic_layout", FunctionType.get([i32] * 4, [idx])) + f = func.FuncOp("dynamic_layout", FunctionType.get([i32] * 4, [i32])) entry = f.add_entry_block() with InsertionPoint(entry): dim0, dim1, stride0, stride1 = entry.arguments @@ -98,7 +91,7 @@ def test_layout_dynamic_types(): layout = fx.make_layout(shape, stride) sz = fx.size(layout) sc = fx.get_scalar(sz) - func.ReturnOp([arith.IndexCastOp(idx, sc).result]) + func.ReturnOp([sc.ir_value()]) pm = PassManager.parse(FLY_PIPELINE, ctx) pm.run(module.operation) @@ -138,9 +131,8 @@ def test_mixed_static_dynamic(): with Location.unknown(ctx): module = Module.create() i32 = IntegerType.get_signless(32) - idx = IndexType.get() with InsertionPoint(module.body): - f = func.FuncOp("mixed_layout", FunctionType.get([i32, i32], [idx])) + f = func.FuncOp("mixed_layout", FunctionType.get([i32, i32], [i32])) entry = f.add_entry_block() with InsertionPoint(entry): runtime_extent, runtime_stride = entry.arguments @@ -152,8 +144,8 @@ def test_mixed_static_dynamic(): stride = fx.make_stride(c16, runtime_stride) layout = fx.make_layout(shape, stride) sz = fx.size(layout) - sc = fx.get_scalar(sz) - func.ReturnOp([arith.IndexCastOp(idx, sc).result]) + sc = fx.get_scalar(sz).ir_value() + func.ReturnOp([sc]) pm = PassManager.parse(FLY_PIPELINE, ctx) pm.run(module.operation)