
Commit 2e598d6

doraemonmj and majin0824 authored
Refactor: migrate legacy golden.py/kernel_config.py tests to @scene_test format (#548)
- Move 10 synthetic examples to tests/st/ as SceneTestCase tests: spmd_basic, spmd_multiblock_aiv, spmd_multiblock_mix, spmd_starvation, spmd_sync_start, spmd_sync_start_aiv, spmd_sync_start_edge, spmd_sync_start_stress, mixed_example, multi_round_paged_attention
- Merge batch_paged_attention (examples + tests/st) into single example with unified runtime-dispatch kernels (typename T template parameter)
- Merge paged_attention into examples/ with sim-compatible cases
- Convert bgemm, benchmark_bgemm, paged_attention_unroll to @scene_test
- Remove all golden.py and kernel_config.py legacy files
- Delete original example dirs after migration to tests/st/

Fix: add runtime common/ to incore kernel include path

The SPMD kernel sources #include "intrinsic.h" which lives in src/{arch}/runtime/{runtime}/common/. This directory was implicitly available via the legacy run_example.py pipeline but missing from the SceneTestCase kernel compiler include dirs.

Co-authored-by: majin0824 <majin15@huawei.com>
1 parent c599350 commit 2e598d6

87 files changed

Lines changed: 1784 additions & 4381 deletions
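The "unified runtime-dispatch kernels (typename T template parameter)" change described in the commit message comes down to one pattern: kernel_entry inspects tensor shapes at runtime and forwards to a templated implementation. The full kernels are in the diffs below; the following is a minimal host-side sketch of the same pattern. The fp16/bf16 stand-in structs, the dispatch() wrapper, and the printf body (in place of the real TLOAD/TMOV/TMATMUL sequence) are illustrative assumptions so the sketch compiles with a plain C++ compiler; only the shape tests and template arguments mirror the diff.

#include <cstdint>
#include <cstdio>

// Stand-ins for the device element types; on device these are half and bfloat16_t.
struct fp16 { uint16_t bits; };
struct bf16 { uint16_t bits; };

// T = element type, M = q_tile, K = block_size, N = head_dim (as in aic_pv_matmul.cpp).
template <typename T, int M, int K, int N>
static void pv_matmul_batch_impl(uint64_t batch_count) {
    // The real kernel loops over batches and issues TLOAD/TMOV/TMATMUL here.
    std::printf("M=%d K=%d N=%d elem=%zuB batches=%llu\n",
                M, K, N, sizeof(T), static_cast<unsigned long long>(batch_count));
}

// Shape-based dispatch, mirroring kernel_entry: the tile configuration is derived
// from the shape of the incoming P_ij tensor rather than fixed at compile time.
void dispatch(uint64_t pij_rows, uint64_t pij_cols, uint64_t batch_count) {
    uint64_t q_tile_size = pij_rows / batch_count;
    uint64_t block_size = pij_cols;
    if (q_tile_size == 16 && block_size == 16) {
        pv_matmul_batch_impl<fp16, 16, 16, 16>(batch_count);    // Small (fp16)
    } else if (q_tile_size == 16) {
        pv_matmul_batch_impl<bf16, 16, 128, 128>(batch_count);  // Case1 (bf16)
    } else {
        pv_matmul_batch_impl<bf16, 64, 64, 128>(batch_count);   // Case2 (bf16)
    }
}

int main() {
    dispatch(16 * 4, 16, 4);   // selects <fp16, 16, 16, 16>
    dispatch(16 * 2, 128, 2);  // selects <bf16, 16, 128, 128>
    dispatch(64 * 3, 64, 3);   // selects <bf16, 64, 64, 128>
    return 0;
}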


examples/a2a3/tensormap_and_ringbuffer/batch_paged_attention/golden.py

Lines changed: 0 additions & 77 deletions
This file was deleted.

examples/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/aic/aic_pv_matmul.cpp

Lines changed: 36 additions & 17 deletions
@@ -14,10 +14,14 @@
 // Processes batch_count batches in a single kernel invocation.
 // Per-batch addresses are computed from global tensor bases + block_table lookup.
 //
-// Template: M=q_tile, K=block_size, N=head_dim
+// Supports three tile configurations via runtime dispatch:
+//   Small: (16,  16) @ ( 16,  16) -> (16,  16)  [fp16]
+//   Case1: (16, 128) @ (128, 128) -> (16, 128)  [bf16]
+//   Case2: (64,  64) @ ( 64, 128) -> (64, 128)  [bf16]
+//
+// Template: T=data_type, M=q_tile, K=block_size, N=head_dim
 
 #include <cstdint>
-// NOLINTBEGIN(clang-diagnostic-error,bugprone-reserved-identifier,bugprone-easily-swappable-parameters,modernize-use-auto)
 #include <pto/pto-inst.hpp>
 
 #include "tensor.h"
@@ -33,25 +37,25 @@ using namespace pto;
 #define __aicore__ [aicore] // NOLINT(whitespace/braces)
 #endif
 
-template <int M, int K, int N>
+template <typename T, int M, int K, int N>
 static __aicore__ void pv_matmul_batch_impl(
     __gm__ Tensor *pij_batch, __gm__ Tensor *value_cache, __gm__ Tensor *block_table_t, __gm__ Tensor *oi_new_batch,
     uint64_t batch_count, uint64_t block_idx, uint64_t block_num, uint64_t batch_start
 ) {
-    __gm__ half *pij_base = reinterpret_cast<__gm__ half *>(pij_batch->buffer.addr);
-    __gm__ half *val_base = reinterpret_cast<__gm__ half *>(value_cache->buffer.addr);
+    __gm__ T *pij_base = reinterpret_cast<__gm__ T *>(pij_batch->buffer.addr);
+    __gm__ T *val_base = reinterpret_cast<__gm__ T *>(value_cache->buffer.addr);
     __gm__ float *oi_base = reinterpret_cast<__gm__ float *>(oi_new_batch->buffer.addr);
     __gm__ int32_t *bt = reinterpret_cast<__gm__ int32_t *>(block_table_t->buffer.addr);
 
-    using GlobalA = GlobalTensor<half, Shape<1, 1, 1, M, K>, Stride<M * K, M * K, M * K, K, 1>>;
-    using GlobalB = GlobalTensor<half, Shape<1, 1, 1, K, N>, Stride<K * N, K * N, K * N, N, 1>>;
+    using GlobalA = GlobalTensor<T, Shape<1, 1, 1, M, K>, Stride<M * K, M * K, M * K, K, 1>>;
+    using GlobalB = GlobalTensor<T, Shape<1, 1, 1, K, N>, Stride<K * N, K * N, K * N, N, 1>>;
     using GlobalOut = GlobalTensor<float, Shape<1, 1, 1, M, N>, Stride<M * N, M * N, M * N, N, 1>>;
 
-    using TileMatA = Tile<TileType::Mat, half, M, K, BLayout::ColMajor, M, K, SLayout::RowMajor, 512>;
-    using TileMatB = Tile<TileType::Mat, half, K, N, BLayout::ColMajor, K, N, SLayout::RowMajor, 512>;
+    using TileMatA = Tile<TileType::Mat, T, M, K, BLayout::ColMajor, M, K, SLayout::RowMajor, 512>;
+    using TileMatB = Tile<TileType::Mat, T, K, N, BLayout::ColMajor, K, N, SLayout::RowMajor, 512>;
 
-    using LeftTile = TileLeft<half, M, K, M, K>;
-    using RightTile = TileRight<half, K, N, K, N>;
+    using LeftTile = TileLeft<T, M, K, M, K>;
+    using RightTile = TileRight<T, K, N, K, N>;
     using AccTile = TileAcc<float, M, N, M, N>;
 
     TileMatA aMatTile;
@@ -67,9 +71,9 @@ static __aicore__ void pv_matmul_batch_impl(
     TASSIGN(cTile, 0x0);
 
     for (uint64_t b = 0; b < batch_count; b++) {
-        __gm__ half *pij_addr = pij_base + b * M * K;
+        __gm__ T *pij_addr = pij_base + b * M * K;
         int32_t phys_block = bt[(batch_start + b) * block_num + block_idx];
-        __gm__ half *vj_addr = val_base + static_cast<uint64_t>(phys_block) * K * N;
+        __gm__ T *vj_addr = val_base + static_cast<uint64_t>(phys_block) * K * N;
         __gm__ float *oi_addr = oi_base + b * M * N;
 
         GlobalA pijGlobal(pij_addr);
@@ -99,6 +103,9 @@ static __aicore__ void pv_matmul_batch_impl(
             pipe_barrier(PIPE_ALL);
         }
     }
+
+    set_flag(PIPE_FIX, PIPE_S, EVENT_ID7);
+    wait_flag(PIPE_FIX, PIPE_S, EVENT_ID7);
 }
 
 extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) {
@@ -111,8 +118,20 @@ extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) {
     uint64_t block_num = static_cast<uint64_t>(args[6]);
     uint64_t batch_start = static_cast<uint64_t>(args[7]);
 
-    pv_matmul_batch_impl<16, 16, 16>(
-        pij_batch, value_cache, block_table_t, oi_new_batch, batch_count, block_idx, block_num, batch_start
-    );
+    uint64_t q_tile_size = static_cast<uint64_t>(pij_batch->shapes[0] / batch_count);
+    uint64_t block_size = static_cast<uint64_t>(pij_batch->shapes[1]);
+
+    if (q_tile_size == 16 && block_size == 16) {
+        pv_matmul_batch_impl<half, 16, 16, 16>(
+            pij_batch, value_cache, block_table_t, oi_new_batch, batch_count, block_idx, block_num, batch_start
+        );
+    } else if (q_tile_size == 16) {
+        pv_matmul_batch_impl<bfloat16_t, 16, 128, 128>(
+            pij_batch, value_cache, block_table_t, oi_new_batch, batch_count, block_idx, block_num, batch_start
+        );
+    } else {
+        pv_matmul_batch_impl<bfloat16_t, 64, 64, 128>(
+            pij_batch, value_cache, block_table_t, oi_new_batch, batch_count, block_idx, block_num, batch_start
+        );
+    }
 }
-// NOLINTEND(clang-diagnostic-error,bugprone-reserved-identifier,bugprone-easily-swappable-parameters,modernize-use-auto)
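Both kernel implementations also gain a trailing flag pair before returning. The intrinsics below are exactly as in the diff; the comments are an interpretation of why the pair is there, not text from the commit.

// Post an event from the fixpipe (PIPE_FIX) and block the scalar pipe (PIPE_S)
// on it, so the implementation does not return while preceding fixpipe writes
// of the accumulator results are still in flight.
set_flag(PIPE_FIX, PIPE_S, EVENT_ID7);
wait_flag(PIPE_FIX, PIPE_S, EVENT_ID7);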

examples/a2a3/tensormap_and_ringbuffer/batch_paged_attention/kernels/aic/aic_qk_matmul.cpp

Lines changed: 42 additions & 17 deletions
@@ -14,7 +14,12 @@
 // Processes batch_count batches in a single kernel invocation.
 // Per-batch addresses are computed from global tensor bases + block_table lookup.
 //
-// Template: M=q_tile, K=head_dim, N=block_size
+// Supports three tile configurations via runtime dispatch:
+//   Small: (16,  16) @ ( 16,  16).T -> (16,  16)  [fp16]
+//   Case1: (16, 128) @ (128, 128).T -> (16, 128)  [bf16]
+//   Case2: (64, 128) @ (128,  64).T -> (64,  64)  [bf16]
+//
+// Template: T=data_type, M=q_tile, K=head_dim, N=block_size
 
 #include <cstdint>
 #include <pto/pto-inst.hpp>
@@ -32,26 +37,26 @@ using namespace pto;
 #define __aicore__ [aicore] // NOLINT(whitespace/braces)
 #endif
 
-template <int M, int K, int N>
+template <typename T, int M, int K, int N>
 static __aicore__ void qk_matmul_batch_impl(
     __gm__ Tensor *query, __gm__ Tensor *key_cache, __gm__ Tensor *block_table_t, __gm__ Tensor *sij_batch,
     uint64_t batch_count, uint64_t block_idx, uint64_t q_offset, uint64_t block_num, uint64_t num_heads,
     uint64_t batch_start
 ) {
-    __gm__ half *query_base = reinterpret_cast<__gm__ half *>(query->buffer.addr);
-    __gm__ half *key_base = reinterpret_cast<__gm__ half *>(key_cache->buffer.addr);
+    __gm__ T *query_base = reinterpret_cast<__gm__ T *>(query->buffer.addr);
+    __gm__ T *key_base = reinterpret_cast<__gm__ T *>(key_cache->buffer.addr);
     __gm__ float *sij_base = reinterpret_cast<__gm__ float *>(sij_batch->buffer.addr);
     __gm__ int32_t *bt = reinterpret_cast<__gm__ int32_t *>(block_table_t->buffer.addr);
 
-    using GlobalA = GlobalTensor<half, Shape<1, 1, 1, M, K>, Stride<M * K, M * K, M * K, K, 1>>;
-    using GlobalB = GlobalTensor<half, Shape<1, 1, 1, K, N>, Stride<K * N, K * N, K * N, 1, K>, Layout::DN>;
+    using GlobalA = GlobalTensor<T, Shape<1, 1, 1, M, K>, Stride<M * K, M * K, M * K, K, 1>>;
+    using GlobalB = GlobalTensor<T, Shape<1, 1, 1, K, N>, Stride<K * N, K * N, K * N, 1, K>, Layout::DN>;
     using GlobalOut = GlobalTensor<float, Shape<1, 1, 1, M, N>, Stride<M * N, M * N, M * N, N, 1>>;
 
-    using TileMatA = Tile<TileType::Mat, half, M, K, BLayout::ColMajor, M, K, SLayout::RowMajor, 512>;
-    using TileMatB = Tile<TileType::Mat, half, K, N, BLayout::RowMajor, K, N, SLayout::ColMajor, 512>;
+    using TileMatA = Tile<TileType::Mat, T, M, K, BLayout::ColMajor, M, K, SLayout::RowMajor, 512>;
+    using TileMatB = Tile<TileType::Mat, T, K, N, BLayout::RowMajor, K, N, SLayout::ColMajor, 512>;
 
-    using LeftTile = TileLeft<half, M, K, M, K>;
-    using RightTile = TileRight<half, K, N, K, N>;
+    using LeftTile = TileLeft<T, M, K, M, K>;
+    using RightTile = TileRight<T, K, N, K, N>;
     using AccTile = TileAcc<float, M, N, M, N>;
 
     TileMatA aMatTile;
@@ -67,22 +72,23 @@ static __aicore__ void qk_matmul_batch_impl(
     TASSIGN(cTile, 0x0);
 
     for (uint64_t b = 0; b < batch_count; b++) {
-        __gm__ half *qi_addr = query_base + ((batch_start + b) * num_heads + q_offset) * K;
+        __gm__ T *qi_addr = query_base + ((batch_start + b) * num_heads + q_offset) * K;
         int32_t phys_block = bt[(batch_start + b) * block_num + block_idx];
-        __gm__ half *kj_addr = key_base + static_cast<uint64_t>(phys_block) * N * K;
+        __gm__ T *kj_addr = key_base + static_cast<uint64_t>(phys_block) * N * K;
         __gm__ float *sij_addr = sij_base + b * M * N;
 
         GlobalA qiGlobal(qi_addr);
         GlobalB kjGlobal(kj_addr);
         GlobalOut sijGlobal(sij_addr);
 
         TLOAD(aMatTile, qiGlobal);
+        set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);
         TLOAD(bMatTile, kjGlobal);
+        set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1);
 
-        set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);
         wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);
-
         TMOV(aTile, aMatTile);
+        wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1);
         TMOV(bTile, bMatTile);
 
         set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0);
@@ -99,6 +105,9 @@ static __aicore__ void qk_matmul_batch_impl(
             pipe_barrier(PIPE_ALL);
         }
     }
+
+    set_flag(PIPE_FIX, PIPE_S, EVENT_ID7);
+    wait_flag(PIPE_FIX, PIPE_S, EVENT_ID7);
 }
 
 extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) {
@@ -113,7 +122,23 @@ extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) {
     uint64_t num_heads = static_cast<uint64_t>(args[8]);
     uint64_t batch_start = static_cast<uint64_t>(args[9]);
 
-    qk_matmul_batch_impl<16, 16, 16>(
-        query, key_cache, block_table_t, sij_batch, batch_count, block_idx, q_offset, block_num, num_heads, batch_start
-    );
+    uint64_t q_tile_size = static_cast<uint64_t>(sij_batch->shapes[0] / batch_count);
+    uint64_t block_size = static_cast<uint64_t>(sij_batch->shapes[1]);
+
+    if (q_tile_size == 16 && block_size == 16) {
+        qk_matmul_batch_impl<half, 16, 16, 16>(
+            query, key_cache, block_table_t, sij_batch, batch_count, block_idx, q_offset, block_num, num_heads,
+            batch_start
+        );
+    } else if (q_tile_size == 16) {
+        qk_matmul_batch_impl<bfloat16_t, 16, 128, 128>(
+            query, key_cache, block_table_t, sij_batch, batch_count, block_idx, q_offset, block_num, num_heads,
+            batch_start
+        );
+    } else {
+        qk_matmul_batch_impl<bfloat16_t, 64, 128, 64>(
+            query, key_cache, block_table_t, sij_batch, batch_count, block_idx, q_offset, block_num, num_heads,
+            batch_start
+        );
+    }
 }
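The qk_matmul diff also splits the single MTE2-to-MTE1 flag into per-load events. The code lines below are exactly as in the new version; the comment is an interpretation of the change, not text from the commit.

// Per-load events instead of one combined flag: each TLOAD (MTE2) posts its own
// event, and each TMOV (MTE1) waits only for the tile it actually consumes, so
// moving the Q tile can begin while the K tile is still loading.
TLOAD(aMatTile, qiGlobal);
set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);
TLOAD(bMatTile, kjGlobal);
set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1);

wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);
TMOV(aTile, aMatTile);
wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1);
TMOV(bTile, bMatTile);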
