Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions kernels/blockscale_preshuffle_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,12 +203,12 @@ def kernel_gemm(
# ---- Wave / lane decomposition ----
wave_size = 64
layout_wave_lane = fx.make_layout((4, wave_size), (64, 1))
coord_wave_lane = fx.idx2crd(tx, layout_wave_lane)
coord_wave_lane = fx.idx2crd(fx.Int32(tx), layout_wave_lane)
wave_id = fx.get(coord_wave_lane, 0)
lane_id = fx.get(coord_wave_lane, 1)

layout_lane16 = fx.make_layout((4, 16), (16, 1))
coord_lane16 = fx.idx2crd(lane_id, layout_lane16)
coord_lane16 = fx.idx2crd(fx.Int32(lane_id), layout_lane16)
lane_div_16 = fx.get(coord_lane16, 0)
lane_mod_16 = fx.get(coord_lane16, 1)

Expand Down Expand Up @@ -252,8 +252,8 @@ def load_b_packs_k64(base_k, ku: int, ni: int):
k0_base = base_k_bytes // c64_b
k0 = k0_base + ku
k1 = lane_div_16
coord_pack = (n_blk_list[ni], k0, k1, n_intra_list[ni], fx.Index(0))
idx_pack = crd2idx(coord_pack, layout_b)
coord_pack = (n_blk_list[ni], k0, k1, n_intra_list[ni], fx.Int32(0))
idx_pack = crd2idx(tuple(fx.Int32(c) for c in coord_pack), layout_b)
b16 = _buffer_load_vec(
buffer_ops,
vector,
Expand Down
2 changes: 1 addition & 1 deletion kernels/gemm_fp8fp4_gfx1250.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def kernel_mxscale_gemm(
b_mcast_mask = 0

layout_thr = fx.make_layout((m_warp, n_warp, 2, 16), (n_warp * WAVE_SIZE, WAVE_SIZE, 16, 1))
thr_coord = idx2crd(tx, layout_thr)
thr_coord = idx2crd(fx.Int32(tx), layout_thr)
wave_m_idx, wave_n_idx, lane_kgrp, lane16 = (
fx.get(thr_coord, 0),
fx.get(thr_coord, 1),
Expand Down
2 changes: 1 addition & 1 deletion kernels/layernorm_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,7 @@ def _load_norm_input_value(index):
mean = sum_val / n_float
var = sumsq_val / n_float - mean * mean
var = (var < c_zero_f).select(c_zero_f, var)
rstd = (var + eps_c).rsqrt(fastmath=fm_fast)
rstd = fmath.rsqrt(var + eps_c, fastmath=fm_fast)

thread_row_max = c_zero_f
for base_idx_int in range_constexpr(0, N, BLOCK_THREADS):
Expand Down
12 changes: 7 additions & 5 deletions kernels/layout_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,15 @@ def idx2crd(idx, layout):
"""
parsed = _parse_layout(layout)

if hasattr(idx, "ir_value"):
idx = idx.ir_value()

if parsed is None or _has_dynamic_strides(parsed[1]):
result = fx.idx2crd(idx, layout)
result = fx.idx2crd(fx.Int32(idx), layout)
ndims = len(parsed[1]) if parsed else 1
return [_wrap(fx.get(result, i)) for i in range(ndims)]

if hasattr(idx, "type") and str(idx.type) != "index":
if isinstance(idx, ir.Value) and not isinstance(idx.type, ir.IndexType):
idx = arith.index_cast(T.index, idx)
shapes, strides = parsed
ndims = len(strides)
Expand Down Expand Up @@ -156,9 +159,8 @@ def crd2idx(crd, layout):
cv = raw
crd_i32.append(cv)
coord_val = fx.make_coord(*crd_i32)
result = fx.crd2idx(coord_val, layout)
scalar = fx.get_scalar(result)
if isinstance(scalar, ir.Value) and not isinstance(scalar.type, ir.IndexType):
scalar = fx.get_scalar(fx.crd2idx(coord_val, layout)).ir_value()
if not isinstance(scalar.type, ir.IndexType):
scalar = arith.index_cast(T.index, scalar)
return _wrap(scalar)

Expand Down
27 changes: 13 additions & 14 deletions kernels/mfma_preshuffle_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@


def crd2idx(crd, layout):
"""crd2idx returning an index-type scalar (unwraps fly.int_tuple)."""
result = fx.crd2idx(crd, layout)
scalar = fx.get_scalar(result)
if isinstance(scalar, ir.Value) and not isinstance(scalar.type, ir.IndexType):
scalar = _arith.IndexCastOp(T.index, scalar).result
return scalar
"""crd2idx returning an index-typed ir.Value (unwraps fly.int_tuple)."""
scalar = fx.get_scalar(fx.crd2idx(crd, layout)).ir_value()
Comment thread
sjfeng1999 marked this conversation as resolved.
if isinstance(scalar.type, ir.IndexType):
return scalar
return _arith.IndexCastOp(T.index, scalar).result


def swizzle_xor16(row, col, k_blocks16):
Expand Down Expand Up @@ -326,7 +325,7 @@ def load_b_raw_w4a16(
k2_base = lane_odd * fx.Index(half_bytes)

coord_pack = (n_blk, k0, k1_local, n_intra, fx.Index(0))
idx_pack = crd2idx(coord_pack, layout_b)
idx_pack = crd2idx(tuple(fx.Int32(c) for c in coord_pack), layout_b)
idx_bytes = idx_pack + k2_base

b4 = _buffer_load_vec(
Expand Down Expand Up @@ -464,7 +463,7 @@ def load_b_pack_k32(
k2_base = arith.constant((ki_step % 2) * half_bytes, index=True)

coord_pack = (n_blk, k0, k1, n_intra, fx.Index(0))
idx_pack = crd2idx(coord_pack, layout_b)
idx_pack = crd2idx(tuple(fx.Int32(c) for c in coord_pack), layout_b)

if unpack_int4:
idx_bytes = idx_pack + k2_base
Expand Down Expand Up @@ -527,7 +526,7 @@ def tile_chunk_coord_i32(
raise ValueError(f"chunk_i32 must be one of (1,2,4), got {chunk_i32!r}")
chunk_off_i32 = arith.constant(i * total_threads * chunk_i32, index=True)
tile_idx_i32 = tx_i32_base + chunk_off_i32
coord_local = fx.idx2crd(tile_idx_i32, layout_tile_div4)
coord_local = fx.idx2crd(fx.Int32(tile_idx_i32), layout_tile_div4)
row_local = fx.get(coord_local, 0)
col_local_i32 = fx.get(coord_local, 1)
return row_local, col_local_i32
Expand Down Expand Up @@ -580,7 +579,7 @@ def lds_store_16b_xor16(
col_swz_bytes = swizzle_xor16(row_local, col_local_bytes, k_blocks16)
col_swz = col_swz_bytes if elem_bytes == 1 else col_swz_bytes // 2
coord_store = (row_local, col_swz)
idx0 = crd2idx(coord_store, layout_lds) + lds_base
idx0 = crd2idx(tuple(fx.Int32(c) for c in coord_store), layout_lds) + lds_base
v16 = vector.bitcast(vec16_ty, vec_part_i32x4)
vector.store(v16, lds_memref, [idx0])

Expand All @@ -607,7 +606,7 @@ def lds_store_8b_xor16(
col_swz_bytes = swizzle_xor16(row_local, col_local_bytes, k_blocks16)
col_swz = col_swz_bytes if elem_bytes == 1 else col_swz_bytes // 2
coord_store = (row_local, col_swz)
idx0 = crd2idx(coord_store, layout_lds) + lds_base
idx0 = crd2idx(tuple(fx.Int32(c) for c in coord_store), layout_lds) + lds_base
v8 = vector.bitcast(vec8_ty, vec_part_i32x2)
vector.store(v8, lds_memref, [idx0])

Expand All @@ -634,7 +633,7 @@ def lds_store_4b_xor16(
col_swz_bytes = swizzle_xor16(row_local, col_local_bytes, k_blocks16)
col_swz = col_swz_bytes if elem_bytes == 1 else col_swz_bytes // 2
coord_store = (row_local, col_swz)
idx0 = crd2idx(coord_store, layout_lds) + lds_base
idx0 = crd2idx(tuple(fx.Int32(c) for c in coord_store), layout_lds) + lds_base
v4 = vector.bitcast(vec4_ty, vec_part_i32x1)
vector.store(v4, lds_memref, [idx0])

Expand All @@ -660,14 +659,14 @@ def lds_load_pack_k32(
col_base_swz = swizzle_xor16(curr_row_a_lds, col_base, k_blocks16)
if ck_lds128:
coord_a16 = (curr_row_a_lds, col_base_swz)
idx_a16 = crd2idx(coord_a16, layout_lds) + lds_base
idx_a16 = crd2idx(tuple(fx.Int32(c) for c in coord_a16), layout_lds) + lds_base
loaded_a16 = vector.load_op(vec16_ty, lds_memref, [idx_a16])
a_vec128 = vector.bitcast(vec2_i64_ty, loaded_a16)
return vector.extract(a_vec128, static_position=[half], dynamic_position=[])
else:
col_swizzled = col_base_swz + (half * 8)
coord_a = (curr_row_a_lds, col_swizzled)
idx_a = crd2idx(coord_a, layout_lds) + lds_base
idx_a = crd2idx(tuple(fx.Int32(c) for c in coord_a), layout_lds) + lds_base
loaded_a8 = vector.load_op(vec8_ty, lds_memref, [idx_a])
a_vec64 = vector.bitcast(vec1_i64_ty, loaded_a8)
return vector.extract(a_vec64, static_position=[0], dynamic_position=[])
Expand Down
18 changes: 9 additions & 9 deletions kernels/mixed_moe_gemm_2stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,10 +729,10 @@ def load_x_tile(base_k):
return parts

# Wave/lane decomposition (identical to stage2)
coord_wl = idx2crd(tx, layout_tx_wave_lane)
coord_wl = idx2crd(fx.Int32(tx), layout_tx_wave_lane)
wave_id = layout_get(coord_wl, 0)
lane_id = layout_get(coord_wl, 1)
coord_l16 = idx2crd(lane_id, layout_lane16)
coord_l16 = idx2crd(fx.Int32(lane_id), layout_lane16)
lane_div_16 = layout_get(coord_l16, 0)
lane_mod_16 = layout_get(coord_l16, 1)
row_a_lds = lane_mod_16
Expand Down Expand Up @@ -763,12 +763,12 @@ def load_x_tile(base_k):
global_n = by_n + n_tile_base + c_offset + lane_mod_16
# Gate/interleave: rows [expert_off, expert_off + 2*inter_dim)
gate_row_w = expert_off_idx + global_n
gate_coord = idx2crd(gate_row_w, layout_n_blk_intra)
gate_coord = idx2crd(fx.Int32(gate_row_w), layout_n_blk_intra)
gate_n_blk_list.append(layout_get(gate_coord, 0))
gate_n_intra_list.append(layout_get(gate_coord, 1))
if const_expr(not mock_gate_only and not gate_up_interleave):
up_row_w = gate_row_w + inter_idx
up_coord = idx2crd(up_row_w, layout_n_blk_intra)
up_coord = idx2crd(fx.Int32(up_row_w), layout_n_blk_intra)
up_n_blk_list.append(layout_get(up_coord, 0))
up_n_intra_list.append(layout_get(up_coord, 1))

Expand Down Expand Up @@ -799,7 +799,7 @@ def load_b_packs_k64(base_k, ku: int, n_blk, n_intra):
k0 = base_k_bytes // c64 + arith.constant(ku, index=True)
k1 = lane_div_16
coord_pack = (n_blk, k0, k1, n_intra, arith.constant(0, index=True))
idx_pack = crd2idx(coord_pack, layout_b)
idx_pack = crd2idx(tuple(fx.Int32(c) for c in coord_pack), layout_b)
vec_elems = kpack_bytes // int(b_elem_bytes)
b16 = _buffer_load_vec(
buffer_ops,
Expand Down Expand Up @@ -1015,7 +1015,7 @@ def prefetch_x_to_lds(base_k, lds_buffer):
def lds_load_packs_k64(curr_row_a_lds, col_base, lds_buffer):
col_base_swz_bytes = swizzle_xor16(curr_row_a_lds, col_base, k_blocks16)
col_base_swz = col_base_swz_bytes if elem_bytes == 1 else (col_base_swz_bytes / arith.index(2))
idx_a16 = crd2idx([curr_row_a_lds, col_base_swz], layout_lds)
idx_a16 = crd2idx([fx.Int32(curr_row_a_lds), fx.Int32(col_base_swz)], layout_lds)
loaded_a16 = vector.load_op(vec16_x, lds_buffer, [idx_a16])
a_i64x2 = vector.bitcast(vec2_i64, loaded_a16)
a0 = vector.extract(a_i64x2, static_position=[0], dynamic_position=[])
Expand Down Expand Up @@ -3074,10 +3074,10 @@ def load_x_tile(base_k):
return parts

# tx -> wave/lane (GEMM-style decomposition).
coord_wl = idx2crd(tx, layout_tx_wave_lane)
coord_wl = idx2crd(fx.Int32(tx), layout_tx_wave_lane)
wave_id = layout_get(coord_wl, 0)
lane_id = layout_get(coord_wl, 1)
coord_l16 = idx2crd(lane_id, layout_lane16)
coord_l16 = idx2crd(fx.Int32(lane_id), layout_lane16)
lane_div_16 = layout_get(coord_l16, 0)
lane_mod_16 = layout_get(coord_l16, 1)

Expand Down Expand Up @@ -3330,7 +3330,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_buffer):
def lds_load_packs_k64(curr_row_a_lds, col_base, lds_buffer):
col_base_swz_bytes = swizzle_xor16(curr_row_a_lds, col_base, k_blocks16)
col_base_swz = col_base_swz_bytes if elem_bytes == 1 else (col_base_swz_bytes / arith.index(2))
idx_a16 = crd2idx([curr_row_a_lds, col_base_swz], layout_lds)
idx_a16 = crd2idx([fx.Int32(curr_row_a_lds), fx.Int32(col_base_swz)], layout_lds)
loaded_a16 = vector.load_op(vec16_x, lds_buffer, [idx_a16])
a_i64x2 = vector.bitcast(vec2_i64, loaded_a16)
a0 = vector.extract(a_i64x2, static_position=[0], dynamic_position=[])
Expand Down
18 changes: 9 additions & 9 deletions kernels/moe_blockscale_2stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,10 +466,10 @@ def load_x_tile(base_k, x_load_bytes_v):
return parts

# tx -> wave/lane (GEMM-style decomposition).
coord_wl = fx.idx2crd(tx, layout_tx_wave_lane)
coord_wl = fx.idx2crd(fx.Int32(tx), layout_tx_wave_lane)
wave_id = fx.get(coord_wl, 0)
lane_id = fx.get(coord_wl, 1)
coord_l16 = fx.idx2crd(lane_id, layout_lane16)
coord_l16 = fx.idx2crd(fx.Int32(lane_id), layout_lane16)
lane_div_16 = fx.get(coord_l16, 0)
lane_mod_16 = fx.get(coord_l16, 1)

Expand Down Expand Up @@ -511,11 +511,11 @@ def load_x_tile(base_k, x_load_bytes_v):
row_gate = expert_off_idx + col_g
row_up = row_gate + inter_idx

coord_gate = fx.idx2crd(row_gate, layout_n_blk_intra)
coord_gate = fx.idx2crd(fx.Int32(row_gate), layout_n_blk_intra)
n_blk_gate.append(fx.get(coord_gate, 0))
n_intra_gate.append(fx.get(coord_gate, 1))

coord_up = fx.idx2crd(row_up, layout_n_blk_intra)
coord_up = fx.idx2crd(fx.Int32(row_up), layout_n_blk_intra)
n_blk_up.append(fx.get(coord_up, 0))
n_intra_up.append(fx.get(coord_up, 1))

Expand Down Expand Up @@ -620,7 +620,7 @@ def lds_load_packs_k64(curr_row_a_lds, col_base_bytes, lds_base):
col_base_swz = (
col_base_swz_bytes if elem_bytes == 1 else (col_base_swz_bytes // arith.index(int(elem_bytes)))
)
idx_a16 = crd2idx((curr_row_a_lds, col_base_swz), layout_lds)
idx_a16 = crd2idx((fx.Int32(curr_row_a_lds), fx.Int32(col_base_swz)), layout_lds)
idx_a16 = idx_a16 + lds_base
loaded_a16 = vector.load_op(vec16_x, lds_x, [idx_a16])
a_i64x2 = vector.bitcast(T.i64x2, loaded_a16)
Expand Down Expand Up @@ -1604,10 +1604,10 @@ def load_x_tile(base_k, x_load_bytes_v):
return parts

# tx -> wave/lane (GEMM-style decomposition).
coord_wl = fx.idx2crd(tx, layout_tx_wave_lane)
coord_wl = fx.idx2crd(fx.Int32(tx), layout_tx_wave_lane)
wave_id = fx.get(coord_wl, 0)
lane_id = fx.get(coord_wl, 1)
coord_l16 = fx.idx2crd(lane_id, layout_lane16)
coord_l16 = fx.idx2crd(fx.Int32(lane_id), layout_lane16)
lane_div_16 = fx.get(coord_l16, 0)
lane_mod_16 = fx.get(coord_l16, 1)

Expand Down Expand Up @@ -1640,7 +1640,7 @@ def load_x_tile(base_k, x_load_bytes_v):
col_g_list.append(col_g)

row_w = expert_off_idx + col_g
coord_w = fx.idx2crd(row_w, layout_n_blk_intra)
coord_w = fx.idx2crd(fx.Int32(row_w), layout_n_blk_intra)
n_blk_list.append(fx.get(coord_w, 0))
n_intra_list.append(fx.get(coord_w, 1))

Expand Down Expand Up @@ -1742,7 +1742,7 @@ def lds_load_packs_k64(curr_row_a_lds, col_base_bytes, lds_base):
col_base_swz = (
col_base_swz_bytes if elem_bytes == 1 else (col_base_swz_bytes // arith.index(int(elem_bytes)))
)
idx_a16 = crd2idx((curr_row_a_lds, col_base_swz), layout_lds)
idx_a16 = crd2idx((fx.Int32(curr_row_a_lds), fx.Int32(col_base_swz)), layout_lds)
idx_a16 = idx_a16 + lds_base
loaded_a16 = vector.load_op(vec16_x, lds_x, [idx_a16])
a_i64x2 = vector.bitcast(T.i64x2, loaded_a16)
Expand Down
18 changes: 9 additions & 9 deletions kernels/moe_gemm_2stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,10 +598,10 @@ def load_x_tile(base_k):
return parts

# tx -> wave/lane (GEMM-style decomposition).
coord_wl = fx.idx2crd(tx, layout_tx_wave_lane)
coord_wl = fx.idx2crd(fx.Int32(tx), layout_tx_wave_lane)
wave_id = fx.get(coord_wl, 0)
lane_id = fx.get(coord_wl, 1)
coord_l16 = fx.idx2crd(lane_id, layout_lane16)
coord_l16 = fx.idx2crd(fx.Int32(lane_id), layout_lane16)
lane_div_16 = fx.get(coord_l16, 0)
lane_mod_16 = fx.get(coord_l16, 1)

Expand Down Expand Up @@ -644,11 +644,11 @@ def load_x_tile(base_k):
row_gate = expert_off_idx + col_g
row_up = row_gate + inter_idx

coord_gate = fx.idx2crd(row_gate, layout_n_blk_intra)
coord_gate = fx.idx2crd(fx.Int32(row_gate), layout_n_blk_intra)
n_blk_gate.append(fx.get(coord_gate, 0))
n_intra_gate.append(fx.get(coord_gate, 1))

coord_up = fx.idx2crd(row_up, layout_n_blk_intra)
coord_up = fx.idx2crd(fx.Int32(row_up), layout_n_blk_intra)
n_blk_up.append(fx.get(coord_up, 0))
n_intra_up.append(fx.get(coord_up, 1))

Expand Down Expand Up @@ -811,7 +811,7 @@ def lds_load_packs_k64(curr_row_a_lds, col_base_bytes, lds_base):
col_base_swz = (
col_base_swz_bytes if elem_bytes == 1 else (col_base_swz_bytes // arith.index(int(elem_bytes)))
)
idx_a16 = crd2idx((curr_row_a_lds, col_base_swz), layout_lds)
idx_a16 = crd2idx((fx.Int32(curr_row_a_lds), fx.Int32(col_base_swz)), layout_lds)
idx_a16 = idx_a16 + lds_base
loaded_a16 = vector.load_op(vec16_x, lds_x, [idx_a16])
a_i64x2 = vector.bitcast(T.i64x2, loaded_a16)
Expand Down Expand Up @@ -2256,10 +2256,10 @@ def load_x_tile(base_k):
return parts

# tx -> wave/lane (GEMM-style decomposition).
coord_wl = fx.idx2crd(tx, layout_tx_wave_lane)
coord_wl = fx.idx2crd(fx.Int32(tx), layout_tx_wave_lane)
wave_id = fx.get(coord_wl, 0)
lane_id = fx.get(coord_wl, 1)
coord_l16 = fx.idx2crd(lane_id, layout_lane16)
coord_l16 = fx.idx2crd(fx.Int32(lane_id), layout_lane16)
lane_div_16 = fx.get(coord_l16, 0)
lane_mod_16 = fx.get(coord_l16, 1)

Expand Down Expand Up @@ -2293,7 +2293,7 @@ def load_x_tile(base_k):
col_g_list.append(col_g)

row_w = expert_off_idx + col_g
coord_w = fx.idx2crd(row_w, layout_n_blk_intra)
coord_w = fx.idx2crd(fx.Int32(row_w), layout_n_blk_intra)
n_blk_list.append(fx.get(coord_w, 0))
n_intra_list.append(fx.get(coord_w, 1))

Expand Down Expand Up @@ -2453,7 +2453,7 @@ def lds_load_packs_k64(curr_row_a_lds, col_base_bytes, lds_base):
col_base_swz = (
col_base_swz_bytes if elem_bytes == 1 else (col_base_swz_bytes // arith.index(int(elem_bytes)))
)
idx_a16 = crd2idx((curr_row_a_lds, col_base_swz), layout_lds)
idx_a16 = crd2idx((fx.Int32(curr_row_a_lds), fx.Int32(col_base_swz)), layout_lds)
idx_a16 = idx_a16 + lds_base
loaded_a16 = vector.load_op(vec16_x, lds_x, [idx_a16])
a_i64x2 = vector.bitcast(T.i64x2, loaded_a16)
Expand Down
4 changes: 2 additions & 2 deletions kernels/moe_gemm_2stage_mxscale_gfx1250.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ def moe_mxscale_stage1_single(
block_ok = arith.andi(block_in_valid, arith.andi(eid_ok0, eid_ok1))

layout_thr = _make_moe_wave_layout(m_warp=m_warp, n_warp=n_warp, WAVE_SIZE=WAVE_SIZE, fx=fx)
thr_coord = idx2crd(tx, layout_thr)
thr_coord = idx2crd(fx.Int32(tx), layout_thr)
wave_m_idx, wave_n_idx, lane_kgrp, lane16 = (
fx.get(thr_coord, 0), fx.get(thr_coord, 1), fx.get(thr_coord, 2), fx.get(thr_coord, 3)
)
Expand Down Expand Up @@ -2560,7 +2560,7 @@ def moe_mxscale_stage2_single(
block_ok = arith.andi(block_in_valid, arith.andi(eid_ok0, eid_ok1))

layout_thr = _make_moe_wave_layout(m_warp=m_warp, n_warp=n_warp, WAVE_SIZE=WAVE_SIZE, fx=fx)
thr_coord = idx2crd(tx, layout_thr)
thr_coord = idx2crd(fx.Int32(tx), layout_thr)
wave_m_idx, wave_n_idx, lane_kgrp, lane16 = (
fx.get(thr_coord, 0), fx.get(thr_coord, 1), fx.get(thr_coord, 2), fx.get(thr_coord, 3)
)
Expand Down
Loading
Loading