Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,9 @@ static __aicore__ void aic_qk_step(

// TPUSH sij (C2V): AccTile L0C -> GM. Ensure prior MTE3 is done,
// then push, then wait for MTE3 DMA to complete before signaling consumer.
pipe_barrier(PIPE_MTE3);
TPUSH<SijPipeT, AccTile_QK, TileSplitAxis::TILE_UP_DOWN>(sij_pipe, cTile_QK);
pipe_barrier(PIPE_MTE3);
// set_flag(PIPE_FIX, PIPE_S, EVENT_ID7);
// wait_flag(PIPE_FIX, PIPE_S, EVENT_ID7);
set_flag(PIPE_FIX, PIPE_S, EVENT_ID7);
wait_flag(PIPE_FIX, PIPE_S, EVENT_ID7);
Comment thread
chenshengxin2026 marked this conversation as resolved.
sij_pipe.prod.record();
}

Expand All @@ -206,15 +204,15 @@ static __aicore__ void aic_pv_step(
TileMatB_PV &bMatTile_PV_A, TileMatB_PV &bMatTile_PV_B, LeftTile_PV &aTile_PV, RightTile_PV &bTile_PV,
AccTile_PV &cTile_PV, PijPipeT &pij_pipe, OiPipeT &oi_pipe
) {
TPOP<PijPipeT, PijMatTile, TileSplitAxis::TILE_NO_SPLIT>(pij_pipe, pijMatTile);

GlobalB_PV vjGlobal(val_base + static_cast<uint64_t>(bt[bt_offset + i]) * N * K);
if (i % 2 == 0) {
TLOAD(bMatTile_PV_A, vjGlobal);
} else {
TLOAD(bMatTile_PV_B, vjGlobal);
}

TPOP<PijPipeT, PijMatTile, TileSplitAxis::TILE_NO_SPLIT>(pij_pipe, pijMatTile);

// PV step uses EVENT_ID1 (QK step uses EVENT_ID0) to avoid flag aliasing
// when pipe_barrier(PIPE_ALL) is removed between steps.
set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1);
Expand All @@ -236,9 +234,9 @@ static __aicore__ void aic_pv_step(
wait_flag(PIPE_M, PIPE_FIX, EVENT_ID1);

// TPUSH oi (C2V): AccTile L0C -> GM. Same manual record pattern as sij.
pipe_barrier(PIPE_MTE3);
TPUSH<OiPipeT, AccTile_PV, TileSplitAxis::TILE_UP_DOWN>(oi_pipe, cTile_PV);
pipe_barrier(PIPE_MTE3);
set_flag(PIPE_FIX, PIPE_S, EVENT_ID7);
wait_flag(PIPE_FIX, PIPE_S, EVENT_ID7);
Comment thread
chenshengxin2026 marked this conversation as resolved.
oi_pipe.prod.record();
}

Expand Down Expand Up @@ -315,23 +313,16 @@ static __aicore__ void aic_process_blocks(
key_base, bt, bt_offset, 0, aMatTile_QK, bMatTile_QK_A, bMatTile_QK_B, aTile_QK, bTile_QK, cTile_QK,
sij_pipe
);
pipe_barrier(PIPE_ALL);
// Steady state: QK[i] then PV[i-1] (QK-first order).
// QK[i]'s Cube compute overlaps with AIV's SF[i-1] Vector compute.
// By the time AIC finishes QK[i] and TPOP(pij[i-1]), SF[i-1] is done.
// EVENT_ID separation (QK=ID0, PV=ID1) prevents flag aliasing between
// overlapping QK and PV operations within the same iteration.
for (uint64_t i = 1; i < n_blocks; i++) {
aic_qk_step<M, K, N, SijPipeT, GlobalB_QK>(
key_base, bt, bt_offset, i, aMatTile_QK, bMatTile_QK_A, bMatTile_QK_B, aTile_QK, bTile_QK, cTile_QK,
sij_pipe
);
pipe_barrier(PIPE_ALL);
aic_pv_step<M, K, N, PijPipeT, OiPipeT, GlobalB_PV>(
val_base, bt, bt_offset, i - 1, pijMatTile, bMatTile_PV_A, bMatTile_PV_B, aTile_PV, bTile_PV, cTile_PV,
pij_pipe, oi_pipe
);
pipe_barrier(PIPE_ALL);
}

// Epilogue: PV[n-1] — consume last pij
Expand All @@ -340,9 +331,6 @@ static __aicore__ void aic_process_blocks(
cTile_PV, pij_pipe, oi_pipe
);
}

set_flag(PIPE_FIX, PIPE_S, EVENT_ID7);
wait_flag(PIPE_FIX, PIPE_S, EVENT_ID7);
}

// ============================================================================
Expand All @@ -364,7 +352,7 @@ template <
typename Cfg, int TM, int TN, typename SijVecTile, typename TileSijPad, typename TileVecMxN,
typename PijVecBf16Tile, typename TileScalarDN, typename TileScalarRow>
static __aicore__ void aiv_sf_step(
uint64_t i, uint64_t n_blocks, uint64_t valid_len_last, float scale_value, SijVecTile &sijTile,
uint64_t i, bool is_last_partial, uint64_t valid_len_last, float scale_value, SijVecTile &sijTile,
TileSijPad &sijPadTile, TileVecMxN &pijTile, TileVecMxN &tmpTile, PijVecBf16Tile &pijBf16Tile,
TileScalarDN &localMaxDN, TileScalarDN &globalMaxDN, TileScalarDN &llDN, TileScalarRow &localMaxRow,
TileScalarRow &globalMaxRow, typename Cfg::SijPipeT &sij_pipe, typename Cfg::PijPipeT &pij_pipe
Expand All @@ -378,7 +366,7 @@ static __aicore__ void aiv_sf_step(
set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);

if (i == n_blocks - 1 && valid_len_last < static_cast<uint64_t>(TN)) {
if (is_last_partial) {
int sij_addr = Cfg::SIJ_UB_BASE + static_cast<int>((i % 2) * TM * TN * static_cast<int>(sizeof(float)));
TASSIGN(sijPadTile, sij_addr);
TileSijDyn sijDynTile(static_cast<size_t>(valid_len_last));
Expand All @@ -398,7 +386,6 @@ static __aicore__ void aiv_sf_step(
pipe_barrier(PIPE_V);
TMAX(globalMaxRow, globalMaxRow, localMaxRow);
}
pipe_barrier(PIPE_V);
TRESHAPE(globalMaxDN, globalMaxRow);

TMULS(sijTile, sijTile, scale_value);
Expand All @@ -414,7 +401,6 @@ static __aicore__ void aiv_sf_step(
pipe_barrier(PIPE_V);

TROWSUM(llDN, pijTile, tmpTile);
pipe_barrier(PIPE_V);

set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
Expand Down Expand Up @@ -444,28 +430,23 @@ static __aicore__ void aiv_up_step(

if (i == 0) {
TMULS(goTile, oiNewTile, 1.0f);
pipe_barrier(PIPE_V);
TRESHAPE(llND, llDN_i);
pipe_barrier(PIPE_V);
TMULS(glND, llND, 1.0f);
} else {
TRESHAPE(llND, llDN_i);
pipe_barrier(PIPE_V);

TRESHAPE(mijND, curMaxRow);
TRESHAPE(dmND, prevMaxRow);
pipe_barrier(PIPE_V);

TSUB(alphaND, dmND, mijND);
pipe_barrier(PIPE_V);
TEXP(alphaND, alphaND);
pipe_barrier(PIPE_V);

TRESHAPE(alphaDN_dn, alphaND);
pipe_barrier(PIPE_V);
TROWEXPANDMUL(goTile, goTile, alphaDN_dn);
pipe_barrier(PIPE_V);
TADD(goTile, goTile, oiNewTile);
pipe_barrier(PIPE_V);

TMUL(glND, glND, alphaND);
pipe_barrier(PIPE_V);
Expand Down Expand Up @@ -561,27 +542,25 @@ static __aicore__ void aiv_process_blocks(

GlobalDataMxHD dstGlobal(dst_ptr);

bool last_partial = (valid_len_last < static_cast<uint64_t>(TN));

if (n_blocks == 1) {
aiv_sf_step<Cfg, TM, TN>(
0, n_blocks, valid_len_last, scale_value, sijTile, sijPadTile, pijTile, tmpTile, pijBf16Tile, localMaxDN,
globalMaxDN, llDN, localMaxRow, globalMaxRow, sij_pipe, pij_pipe
0, last_partial, valid_len_last, scale_value, sijTile, sijPadTile, pijTile, tmpTile, pijBf16Tile,
localMaxDN, globalMaxDN, llDN, localMaxRow, globalMaxRow, sij_pipe, pij_pipe
);
aiv_up_step<Cfg, TM, TN>(
0, oiNewTile, goTile, alphaDN_dn, llDN, glND, alphaND, llND, dmND, mijND, globalMaxRow, globalMaxRow,
oi_pipe
);
} else {
// Prologue: SF[0]
// Prologue: SF[0] — not the last block
aiv_sf_step<Cfg, TM, TN>(
0, n_blocks, valid_len_last, scale_value, sijTile, sijPadTile, pijTile, tmpTile, pijBf16Tile, localMaxDN,
0, false, valid_len_last, scale_value, sijTile, sijPadTile, pijTile, tmpTile, pijBf16Tile, localMaxDN,
globalMaxDN, llDN, localMaxRow, globalMaxRow, sij_pipe, pij_pipe
);

// Steady state: SF[i] then UP[i-1] (SF-first order).
// SF[i] overlaps with AIC's PV[i-1] matmul when AIC uses QK-first order.
// Before SF[i], save globalMaxRow and llDN for UP[i-1]'s use since SF
// overwrites both. Two generations of max are maintained (savedMaxRow
// and prevMaxRow) for the alpha = exp(prevMax - curMax) correction.
for (uint64_t i = 1; i < n_blocks; i++) {
// Shift max history: prevMaxRow ← savedMaxRow (M[i-2])
// Save current: savedMaxRow ← globalMaxRow (M[i-1])
Expand All @@ -590,8 +569,9 @@ static __aicore__ void aiv_process_blocks(
TMULS(savedLlDN, llDN, 1.0f);
pipe_barrier(PIPE_V);

bool cur_last_partial = (i == n_blocks - 1) && last_partial;
aiv_sf_step<Cfg, TM, TN>(
i, n_blocks, valid_len_last, scale_value, sijTile, sijPadTile, pijTile, tmpTile, pijBf16Tile,
i, cur_last_partial, valid_len_last, scale_value, sijTile, sijPadTile, pijTile, tmpTile, pijBf16Tile,
localMaxDN, globalMaxDN, llDN, localMaxRow, globalMaxRow, sij_pipe, pij_pipe
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class TestPagedAttentionUnrollTpushPop(SceneTestCase):

CALLABLE = {
"orchestration": {
"source": "kernels/orchestration/spmd_paged_attention_tpush_orch.cpp",
"source": "kernels/orchestration/spmd_paged_attention_orch.cpp",
"function_name": "aicpu_orchestration_entry",
"signature": [D.IN, D.IN, D.IN, D.IN, D.IN, D.OUT],
},
Expand Down
4 changes: 2 additions & 2 deletions tools/benchmark_rounds.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ declare -A TMR_EXAMPLE_CASES=(
[benchmark_bgemm]="Case0"
[paged_attention_unroll]="Case1,Case2"
[batch_paged_attention]="Case1"
[spmd_paged_attention_tpush]="Case1,Case2"
[spmd_paged_attention]="Case1,Case2"
)
TMR_EXAMPLE_ORDER=(
alternating_matmul_add
benchmark_bgemm
paged_attention_unroll
batch_paged_attention
spmd_paged_attention_tpush
spmd_paged_attention
)

# --- aicpu_build_graph ---
Expand Down
Loading