Skip to content

Commit 44d638e

Browse files
Update: convert all paged_attention examples from float16 to bfloat16
Convert all four paged_attention example variants to use bfloat16:
- a2a3/tensormap_and_ringbuffer/paged_attention
- a5/tensormap_and_ringbuffer/paged_attention
- a2a3/host_build_graph/paged_attention
- a5/host_build_graph/paged_attention

Changes across all variants:
- Change half/float16 to bfloat16_t/bfloat16 in kernel files
- Update golden.py dtype from "float16" to "bfloat16"
- Rename f16/fp16 variables and comments to bf16
- Align pto::Stride to Stride (matching reference style)
1 parent 87f961c commit 44d638e

16 files changed

Lines changed: 128 additions & 128 deletions

File tree

examples/a2a3/host_build_graph/paged_attention/golden.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
77
# See LICENSE in the root of the software repository for the full text of the License.
88
# -----------------------------------------------------------------------------------------------------------
9-
"""Paged Attention Golden - host_build_graph example (small scale, float16).
9+
"""Paged Attention Golden - host_build_graph example (small scale, bfloat16).
1010
1111
Args layout: [query, key_cache, value_cache, block_table, context_lens, out, scale]
1212
- Tensors retain original multi-dimensional shapes (ContinuousTensor metadata carries shape/dtype)
@@ -33,7 +33,7 @@
3333
"block_size": 16,
3434
"context_len": 16,
3535
"max_model_len": 256,
36-
"dtype": "float16",
36+
"dtype": "bfloat16",
3737
},
3838
"Case2": {
3939
"batch": 1,
@@ -43,7 +43,7 @@
4343
"block_size": 16,
4444
"context_len": 64,
4545
"max_model_len": 256,
46-
"dtype": "float16",
46+
"dtype": "bfloat16",
4747
},
4848
}
4949

examples/a2a3/host_build_graph/paged_attention/kernels/aic/aic_pv_matmul.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
//
1313
// Fixed tile size: (16, 16) @ (16, 16) -> (16, 16)
1414
//
15-
// pij is float16 (converted from fp32 in softmax_prepare via TCVT).
15+
// pij is bfloat16 (converted from fp32 in softmax_prepare via TCVT).
1616
// vj is stored as (K, N) = (block_size, head_dim) in row-major (ND) layout.
1717
// Standard non-transposed B pattern: ND GlobalB + ColMajor/RowMajor TileMatB.
1818

@@ -32,26 +32,26 @@ using namespace pto;
3232
static __aicore__ void pv_matmul_impl(__gm__ uint8_t *pij_raw, __gm__ uint8_t *vj_raw, __gm__ uint8_t *oi_raw) {
3333
constexpr int M = 16, K = 16, N = 16;
3434

35-
__gm__ half *pij = reinterpret_cast<__gm__ half *>(pij_raw);
36-
__gm__ half *vj = reinterpret_cast<__gm__ half *>(vj_raw);
35+
__gm__ bfloat16_t *pij = reinterpret_cast<__gm__ bfloat16_t *>(pij_raw);
36+
__gm__ bfloat16_t *vj = reinterpret_cast<__gm__ bfloat16_t *>(vj_raw);
3737
__gm__ float *oi = reinterpret_cast<__gm__ float *>(oi_raw);
3838

39-
// pij (M, K) fp16, vj (K, N) fp16 in ND (row-major), oi_new (M, N) fp32
40-
using GlobalA = GlobalTensor<half, Shape<1, 1, 1, M, K>, Stride<M * K, M * K, M * K, K, 1>>;
41-
using GlobalB = GlobalTensor<half, Shape<1, 1, 1, K, N>, Stride<K * N, K * N, K * N, N, 1>>;
39+
// pij (M, K) bf16, vj (K, N) bf16 in ND (row-major), oi_new (M, N) fp32
40+
using GlobalA = GlobalTensor<bfloat16_t, Shape<1, 1, 1, M, K>, Stride<M * K, M * K, M * K, K, 1>>;
41+
using GlobalB = GlobalTensor<bfloat16_t, Shape<1, 1, 1, K, N>, Stride<K * N, K * N, K * N, N, 1>>;
4242
using GlobalOut = GlobalTensor<float, Shape<1, 1, 1, M, N>, Stride<M * N, M * N, M * N, N, 1>>;
4343

4444
GlobalA pijGlobal(pij);
4545
GlobalB vjGlobal(vj);
4646
GlobalOut oiGlobal(oi);
4747

4848
// L1 Mat tiles: standard ND pattern for both A and B
49-
using TileMatA = Tile<TileType::Mat, half, M, K, BLayout::ColMajor, M, K, SLayout::RowMajor, 512>;
50-
using TileMatB = Tile<TileType::Mat, half, K, N, BLayout::ColMajor, K, N, SLayout::RowMajor, 512>;
49+
using TileMatA = Tile<TileType::Mat, bfloat16_t, M, K, BLayout::ColMajor, M, K, SLayout::RowMajor, 512>;
50+
using TileMatB = Tile<TileType::Mat, bfloat16_t, K, N, BLayout::ColMajor, K, N, SLayout::RowMajor, 512>;
5151

5252
// L0 tiles
53-
using LeftTile = TileLeft<half, M, K, M, K>;
54-
using RightTile = TileRight<half, K, N, K, N>;
53+
using LeftTile = TileLeft<bfloat16_t, M, K, M, K>;
54+
using RightTile = TileRight<bfloat16_t, K, N, K, N>;
5555
using AccTile = TileAcc<float, M, N, M, N>;
5656

5757
TileMatA aMatTile;

examples/a2a3/host_build_graph/paged_attention/kernels/aic/aic_qk_matmul.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,27 +32,27 @@ using namespace pto;
3232
static __aicore__ void qk_matmul_impl(__gm__ uint8_t *qi_raw, __gm__ uint8_t *kj_raw, __gm__ uint8_t *sij_raw) {
3333
constexpr int M = 16, K = 16, N = 16;
3434

35-
__gm__ half *qi = reinterpret_cast<__gm__ half *>(qi_raw);
36-
__gm__ half *kj = reinterpret_cast<__gm__ half *>(kj_raw);
35+
__gm__ bfloat16_t *qi = reinterpret_cast<__gm__ bfloat16_t *>(qi_raw);
36+
__gm__ bfloat16_t *kj = reinterpret_cast<__gm__ bfloat16_t *>(kj_raw);
3737
__gm__ float *sij = reinterpret_cast<__gm__ float *>(sij_raw);
3838

39-
// qi (M, K) fp16 in ND (row-major) layout
40-
using GlobalA = GlobalTensor<half, Shape<1, 1, 1, M, K>, Stride<M * K, M * K, M * K, K, 1>>;
39+
// qi (M, K) bf16 in ND (row-major) layout
40+
using GlobalA = GlobalTensor<bfloat16_t, Shape<1, 1, 1, M, K>, Stride<M * K, M * K, M * K, K, 1>>;
4141
// kj stored as (N, K) row-major = (K, N) column-major -> DN layout
42-
using GlobalB = GlobalTensor<half, Shape<1, 1, 1, K, N>, Stride<K * N, K * N, K * N, 1, K>, Layout::DN>;
42+
using GlobalB = GlobalTensor<bfloat16_t, Shape<1, 1, 1, K, N>, Stride<K * N, K * N, K * N, 1, K>, Layout::DN>;
4343
using GlobalOut = GlobalTensor<float, Shape<1, 1, 1, M, N>, Stride<M * N, M * N, M * N, N, 1>>;
4444

4545
GlobalA qiGlobal(qi);
4646
GlobalB kjGlobal(kj);
4747
GlobalOut sijGlobal(sij);
4848

4949
// L1 Mat tiles: A is standard ND, B uses transposed-B pattern (RowMajor/ColMajor)
50-
using TileMatA = Tile<TileType::Mat, half, M, K, BLayout::ColMajor, M, K, SLayout::RowMajor, 512>;
51-
using TileMatB = Tile<TileType::Mat, half, K, N, BLayout::RowMajor, K, N, SLayout::ColMajor, 512>;
50+
using TileMatA = Tile<TileType::Mat, bfloat16_t, M, K, BLayout::ColMajor, M, K, SLayout::RowMajor, 512>;
51+
using TileMatB = Tile<TileType::Mat, bfloat16_t, K, N, BLayout::RowMajor, K, N, SLayout::ColMajor, 512>;
5252

5353
// L0 tiles
54-
using LeftTile = TileLeft<half, M, K, M, K>;
55-
using RightTile = TileRight<half, K, N, K, N>;
54+
using LeftTile = TileLeft<bfloat16_t, M, K, M, K>;
55+
using RightTile = TileRight<bfloat16_t, K, N, K, N>;
5656
using AccTile = TileAcc<float, M, N, M, N>;
5757

5858
TileMatA aMatTile;

examples/a2a3/host_build_graph/paged_attention/kernels/aiv/aiv_softmax_prepare.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,38 +38,38 @@ static __aicore__ void softmax_prepare_impl(
3838
constexpr int M = 16, N = 16;
3939

4040
__gm__ float *sij = reinterpret_cast<__gm__ float *>(sij_raw);
41-
__gm__ half *pij = reinterpret_cast<__gm__ half *>(pij_raw);
41+
__gm__ bfloat16_t *pij = reinterpret_cast<__gm__ bfloat16_t *>(pij_raw);
4242
__gm__ float *mij = reinterpret_cast<__gm__ float *>(mij_raw);
4343
__gm__ float *lij = reinterpret_cast<__gm__ float *>(lij_raw);
4444

4545
constexpr int kAlignedRows = ((M * sizeof(float) + 31) / 32) * (32 / sizeof(float));
4646

4747
using GlobalDataMxN = GlobalTensor<float, Shape<1, 1, 1, M, N>, Stride<1, 1, 1, N, 1>>;
48-
using GlobalDataMxN_f16 = GlobalTensor<half, Shape<1, 1, 1, M, N>, Stride<1, 1, 1, N, 1>>;
48+
using GlobalDataMxN_bf16 = GlobalTensor<bfloat16_t, Shape<1, 1, 1, M, N>, Stride<1, 1, 1, N, 1>>;
4949
using GlobalScalarDN = GlobalTensor<float, Shape<1, 1, 1, kAlignedRows, 1>, Stride<1, 1, 1, 1, 1>, Layout::DN>;
5050

5151
GlobalDataMxN sijGlobal(sij);
52-
GlobalDataMxN_f16 pijGlobal(pij);
52+
GlobalDataMxN_bf16 pijGlobal(pij);
5353
GlobalScalarDN mijGlobal(mij);
5454
GlobalScalarDN lijGlobal(lij);
5555

5656
using TileVecMxN = Tile<TileType::Vec, float, M, N, BLayout::RowMajor, M, N>;
57-
using TileVecMxN_f16 = Tile<TileType::Vec, half, M, N, BLayout::RowMajor, M, N>;
57+
using TileVecMxN_bf16 = Tile<TileType::Vec, bfloat16_t, M, N, BLayout::RowMajor, M, N>;
5858
using TileScalarDN = Tile<TileType::Vec, float, kAlignedRows, 1, BLayout::ColMajor, M, 1>;
5959

6060
TileVecMxN sijTile;
6161
TileVecMxN pijTile;
6262
TileVecMxN tmpTile;
6363
TileScalarDN maxTile;
6464
TileScalarDN sumTile;
65-
TileVecMxN_f16 pijF16Tile;
65+
TileVecMxN_bf16 pijBf16Tile;
6666

6767
TASSIGN(sijTile, 0x0);
6868
TASSIGN(pijTile, M * N * sizeof(float));
6969
TASSIGN(tmpTile, 2 * M * N * sizeof(float));
7070
TASSIGN(maxTile, 3 * M * N * sizeof(float));
7171
TASSIGN(sumTile, 3 * M * N * sizeof(float) + kAlignedRows * sizeof(float));
72-
TASSIGN(pijF16Tile, 3 * M * N * sizeof(float) + 2 * kAlignedRows * sizeof(float));
72+
TASSIGN(pijBf16Tile, 3 * M * N * sizeof(float) + 2 * kAlignedRows * sizeof(float));
7373

7474
TLOAD(sijTile, sijGlobal);
7575
set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
@@ -79,16 +79,16 @@ static __aicore__ void softmax_prepare_impl(
7979
TROWMAX(maxTile, sijTile, tmpTile);
8080
TROWEXPANDSUB(pijTile, sijTile, maxTile);
8181
TEXP(pijTile, pijTile);
82-
// Truncate pij to fp16 first, then compute lij from truncated values (matches golden)
83-
TCVT(pijF16Tile, pijTile, RoundMode::CAST_ROUND);
84-
TCVT(pijTile, pijF16Tile, RoundMode::CAST_ROUND);
82+
// Truncate pij to bf16 first, then compute lij from truncated values (matches golden)
83+
TCVT(pijBf16Tile, pijTile, RoundMode::CAST_ROUND);
84+
TCVT(pijTile, pijBf16Tile, RoundMode::CAST_ROUND);
8585
TROWSUM(sumTile, pijTile, tmpTile);
8686

8787
set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
8888
wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
8989
TSTORE(mijGlobal, maxTile);
9090
TSTORE(lijGlobal, sumTile);
91-
TSTORE(pijGlobal, pijF16Tile);
91+
TSTORE(pijGlobal, pijBf16Tile);
9292
}
9393

9494
extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) {

examples/a2a3/host_build_graph/paged_attention/kernels/orchestration/paged_attention_orch.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
* Paged Attention Orchestration - Small Scale (16x16)
1313
*
1414
* Supports small-scale paged attention with:
15-
* Query: (batch, q_head_num, head_dim) fp16
16-
* Key: (total_blocks, block_size, kv_head_num, head_dim) fp16 (NOT transposed)
17-
* Value: (total_blocks, block_size, kv_head_num, head_dim) fp16
15+
* Query: (batch, q_head_num, head_dim) bf16
16+
* Key: (total_blocks, block_size, kv_head_num, head_dim) bf16 (NOT transposed)
17+
* Value: (total_blocks, block_size, kv_head_num, head_dim) bf16
1818
* Output: (batch, q_head_num, head_dim) float32
1919
*
2020
* Head tiling: q_tile_size = min(num_heads, 128)
@@ -148,7 +148,7 @@ int build_paged_attention_graph(OrchestrationRuntime *runtime, const ChipStorage
148148
for (uint32_t ht = 0; ht < num_head_tiles; ht++) {
149149
uint32_t cur_offset = ht * q_tile_size;
150150

151-
// Query: (batch, q_head_num, head_dim) fp16
151+
// Query: (batch, q_head_num, head_dim) bf16
152152
// qi points to heads [cur_offset .. cur_offset+q_tile_size) for batch b_idx
153153
uint8_t *qi_ptr = reinterpret_cast<uint8_t *>(dev_query) +
154154
static_cast<int64_t>(b_idx * num_heads + cur_offset) * head_dim * sizeof(uint16_t);
@@ -171,12 +171,12 @@ int build_paged_attention_graph(OrchestrationRuntime *runtime, const ChipStorage
171171
for (uint32_t bn = 0; bn < bn_this_batch; bn++) {
172172
int cur_block_idx = host_block_table[b_idx * max_num_blocks + bn];
173173

174-
// Key: (total_blocks, block_size, kv_head_num, head_dim) fp16
174+
// Key: (total_blocks, block_size, kv_head_num, head_dim) bf16
175175
uint8_t *kj_ptr = reinterpret_cast<uint8_t *>(dev_key_cache) +
176176
(static_cast<int64_t>(cur_block_idx) * block_size * kv_head_num + kv_head_idx) *
177177
head_dim * sizeof(uint16_t);
178178

179-
// Value: (total_blocks, block_size, kv_head_num, head_dim) fp16
179+
// Value: (total_blocks, block_size, kv_head_num, head_dim) bf16
180180
uint8_t *vj_ptr = reinterpret_cast<uint8_t *>(dev_value_cache) +
181181
(static_cast<int64_t>(cur_block_idx) * block_size * kv_head_num + kv_head_idx) *
182182
head_dim * sizeof(uint16_t);

examples/a5/host_build_graph/paged_attention/golden.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
77
# See LICENSE in the root of the software repository for the full text of the License.
88
# -----------------------------------------------------------------------------------------------------------
9-
"""Paged Attention Golden - host_build_graph example (small scale, float16).
9+
"""Paged Attention Golden - host_build_graph example (small scale, bfloat16).
1010
1111
Args layout: [query, key_cache, value_cache, block_table, context_lens, out, scale]
1212
- Tensors retain original multi-dimensional shapes (ContinuousTensor metadata carries shape/dtype)
@@ -33,7 +33,7 @@
3333
"block_size": 16,
3434
"context_len": 16,
3535
"max_model_len": 256,
36-
"dtype": "float16",
36+
"dtype": "bfloat16",
3737
},
3838
"Case2": {
3939
"batch": 1,
@@ -43,7 +43,7 @@
4343
"block_size": 16,
4444
"context_len": 64,
4545
"max_model_len": 256,
46-
"dtype": "float16",
46+
"dtype": "bfloat16",
4747
},
4848
}
4949

examples/a5/host_build_graph/paged_attention/kernels/aic/aic_pv_matmul.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
//
1313
// Fixed tile size: (16, 16) @ (16, 16) -> (16, 16)
1414
//
15-
// pij is float16 (converted from fp32 in softmax_prepare via TCVT).
15+
// pij is bfloat16 (converted from fp32 in softmax_prepare via TCVT).
1616
// vj is stored as (K, N) = (block_size, head_dim) in row-major (ND) layout.
1717
// Standard non-transposed B pattern: ND GlobalB + ColMajor/RowMajor TileMatB.
1818

@@ -32,26 +32,26 @@ using namespace pto;
3232
static __aicore__ void pv_matmul_impl(__gm__ uint8_t *pij_raw, __gm__ uint8_t *vj_raw, __gm__ uint8_t *oi_raw) {
3333
constexpr int M = 16, K = 16, N = 16;
3434

35-
__gm__ half *pij = reinterpret_cast<__gm__ half *>(pij_raw);
36-
__gm__ half *vj = reinterpret_cast<__gm__ half *>(vj_raw);
35+
__gm__ bfloat16_t *pij = reinterpret_cast<__gm__ bfloat16_t *>(pij_raw);
36+
__gm__ bfloat16_t *vj = reinterpret_cast<__gm__ bfloat16_t *>(vj_raw);
3737
__gm__ float *oi = reinterpret_cast<__gm__ float *>(oi_raw);
3838

39-
// pij (M, K) fp16, vj (K, N) fp16 in ND (row-major), oi_new (M, N) fp32
40-
using GlobalA = GlobalTensor<half, Shape<1, 1, 1, M, K>, pto::Stride<M * K, M * K, M * K, K, 1>>;
41-
using GlobalB = GlobalTensor<half, Shape<1, 1, 1, K, N>, pto::Stride<K * N, K * N, K * N, N, 1>>;
42-
using GlobalOut = GlobalTensor<float, Shape<1, 1, 1, M, N>, pto::Stride<M * N, M * N, M * N, N, 1>>;
39+
// pij (M, K) bf16, vj (K, N) bf16 in ND (row-major), oi_new (M, N) fp32
40+
using GlobalA = GlobalTensor<bfloat16_t, Shape<1, 1, 1, M, K>, Stride<M * K, M * K, M * K, K, 1>>;
41+
using GlobalB = GlobalTensor<bfloat16_t, Shape<1, 1, 1, K, N>, Stride<K * N, K * N, K * N, N, 1>>;
42+
using GlobalOut = GlobalTensor<float, Shape<1, 1, 1, M, N>, Stride<M * N, M * N, M * N, N, 1>>;
4343

4444
GlobalA pijGlobal(pij);
4545
GlobalB vjGlobal(vj);
4646
GlobalOut oiGlobal(oi);
4747

4848
// L1 Mat tiles: standard ND pattern for both A and B
49-
using TileMatA = Tile<TileType::Mat, half, M, K, BLayout::ColMajor, M, K, SLayout::RowMajor, 512>;
50-
using TileMatB = Tile<TileType::Mat, half, K, N, BLayout::ColMajor, K, N, SLayout::RowMajor, 512>;
49+
using TileMatA = Tile<TileType::Mat, bfloat16_t, M, K, BLayout::ColMajor, M, K, SLayout::RowMajor, 512>;
50+
using TileMatB = Tile<TileType::Mat, bfloat16_t, K, N, BLayout::ColMajor, K, N, SLayout::RowMajor, 512>;
5151

5252
// L0 tiles
53-
using LeftTile = TileLeft<half, M, K, M, K>;
54-
using RightTile = TileRight<half, K, N, K, N>;
53+
using LeftTile = TileLeft<bfloat16_t, M, K, M, K>;
54+
using RightTile = TileRight<bfloat16_t, K, N, K, N>;
5555
using AccTile = TileAcc<float, M, N, M, N>;
5656

5757
TileMatA aMatTile;

examples/a5/host_build_graph/paged_attention/kernels/aic/aic_qk_matmul.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,27 +32,27 @@ using namespace pto;
3232
static __aicore__ void qk_matmul_impl(__gm__ uint8_t *qi_raw, __gm__ uint8_t *kj_raw, __gm__ uint8_t *sij_raw) {
3333
constexpr int M = 16, K = 16, N = 16;
3434

35-
__gm__ half *qi = reinterpret_cast<__gm__ half *>(qi_raw);
36-
__gm__ half *kj = reinterpret_cast<__gm__ half *>(kj_raw);
35+
__gm__ bfloat16_t *qi = reinterpret_cast<__gm__ bfloat16_t *>(qi_raw);
36+
__gm__ bfloat16_t *kj = reinterpret_cast<__gm__ bfloat16_t *>(kj_raw);
3737
__gm__ float *sij = reinterpret_cast<__gm__ float *>(sij_raw);
3838

39-
// qi (M, K) fp16 in ND (row-major) layout
40-
using GlobalA = GlobalTensor<half, Shape<1, 1, 1, M, K>, pto::Stride<M * K, M * K, M * K, K, 1>>;
39+
// qi (M, K) bf16 in ND (row-major) layout
40+
using GlobalA = GlobalTensor<bfloat16_t, Shape<1, 1, 1, M, K>, Stride<M * K, M * K, M * K, K, 1>>;
4141
// kj stored as (N, K) row-major = (K, N) column-major -> DN layout
42-
using GlobalB = GlobalTensor<half, Shape<1, 1, 1, K, N>, pto::Stride<K * N, K * N, K * N, 1, K>, Layout::DN>;
43-
using GlobalOut = GlobalTensor<float, Shape<1, 1, 1, M, N>, pto::Stride<M * N, M * N, M * N, N, 1>>;
42+
using GlobalB = GlobalTensor<bfloat16_t, Shape<1, 1, 1, K, N>, Stride<K * N, K * N, K * N, 1, K>, Layout::DN>;
43+
using GlobalOut = GlobalTensor<float, Shape<1, 1, 1, M, N>, Stride<M * N, M * N, M * N, N, 1>>;
4444

4545
GlobalA qiGlobal(qi);
4646
GlobalB kjGlobal(kj);
4747
GlobalOut sijGlobal(sij);
4848

4949
// L1 Mat tiles: A is standard ND, B uses transposed-B pattern (RowMajor/ColMajor)
50-
using TileMatA = Tile<TileType::Mat, half, M, K, BLayout::ColMajor, M, K, SLayout::RowMajor, 512>;
51-
using TileMatB = Tile<TileType::Mat, half, K, N, BLayout::RowMajor, K, N, SLayout::ColMajor, 512>;
50+
using TileMatA = Tile<TileType::Mat, bfloat16_t, M, K, BLayout::ColMajor, M, K, SLayout::RowMajor, 512>;
51+
using TileMatB = Tile<TileType::Mat, bfloat16_t, K, N, BLayout::RowMajor, K, N, SLayout::ColMajor, 512>;
5252

5353
// L0 tiles
54-
using LeftTile = TileLeft<half, M, K, M, K>;
55-
using RightTile = TileRight<half, K, N, K, N>;
54+
using LeftTile = TileLeft<bfloat16_t, M, K, M, K>;
55+
using RightTile = TileRight<bfloat16_t, K, N, K, N>;
5656
using AccTile = TileAcc<float, M, N, M, N>;
5757

5858
TileMatA aMatTile;

examples/a5/host_build_graph/paged_attention/kernels/aiv/aiv_softmax_prepare.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,38 +38,38 @@ static __aicore__ void softmax_prepare_impl(
3838
constexpr int M = 16, N = 16;
3939

4040
__gm__ float *sij = reinterpret_cast<__gm__ float *>(sij_raw);
41-
__gm__ half *pij = reinterpret_cast<__gm__ half *>(pij_raw);
41+
__gm__ bfloat16_t *pij = reinterpret_cast<__gm__ bfloat16_t *>(pij_raw);
4242
__gm__ float *mij = reinterpret_cast<__gm__ float *>(mij_raw);
4343
__gm__ float *lij = reinterpret_cast<__gm__ float *>(lij_raw);
4444

4545
constexpr int kAlignedRows = ((M * sizeof(float) + 31) / 32) * (32 / sizeof(float));
4646

4747
using GlobalDataMxN = GlobalTensor<float, Shape<1, 1, 1, M, N>, pto::Stride<1, 1, 1, N, 1>>;
48-
using GlobalDataMxN_f16 = GlobalTensor<half, Shape<1, 1, 1, M, N>, pto::Stride<1, 1, 1, N, 1>>;
48+
using GlobalDataMxN_bf16 = GlobalTensor<bfloat16_t, Shape<1, 1, 1, M, N>, Stride<1, 1, 1, N, 1>>;
4949
using GlobalScalarDN = GlobalTensor<float, Shape<1, 1, 1, kAlignedRows, 1>, pto::Stride<1, 1, 1, 1, 1>, Layout::DN>;
5050

5151
GlobalDataMxN sijGlobal(sij);
52-
GlobalDataMxN_f16 pijGlobal(pij);
52+
GlobalDataMxN_bf16 pijGlobal(pij);
5353
GlobalScalarDN mijGlobal(mij);
5454
GlobalScalarDN lijGlobal(lij);
5555

5656
using TileVecMxN = Tile<TileType::Vec, float, M, N, BLayout::RowMajor, M, N>;
57-
using TileVecMxN_f16 = Tile<TileType::Vec, half, M, N, BLayout::RowMajor, M, N>;
57+
using TileVecMxN_bf16 = Tile<TileType::Vec, bfloat16_t, M, N, BLayout::RowMajor, M, N>;
5858
using TileScalarDN = Tile<TileType::Vec, float, kAlignedRows, 1, BLayout::ColMajor, M, 1>;
5959

6060
TileVecMxN sijTile;
6161
TileVecMxN pijTile;
6262
TileVecMxN tmpTile;
6363
TileScalarDN maxTile;
6464
TileScalarDN sumTile;
65-
TileVecMxN_f16 pijF16Tile;
65+
TileVecMxN_bf16 pijBf16Tile;
6666

6767
TASSIGN(sijTile, 0x0);
6868
TASSIGN(pijTile, M * N * sizeof(float));
6969
TASSIGN(tmpTile, 2 * M * N * sizeof(float));
7070
TASSIGN(maxTile, 3 * M * N * sizeof(float));
7171
TASSIGN(sumTile, 3 * M * N * sizeof(float) + kAlignedRows * sizeof(float));
72-
TASSIGN(pijF16Tile, 3 * M * N * sizeof(float) + 2 * kAlignedRows * sizeof(float));
72+
TASSIGN(pijBf16Tile, 3 * M * N * sizeof(float) + 2 * kAlignedRows * sizeof(float));
7373

7474
TLOAD(sijTile, sijGlobal);
7575
set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
@@ -79,16 +79,16 @@ static __aicore__ void softmax_prepare_impl(
7979
TROWMAX(maxTile, sijTile, tmpTile);
8080
TROWEXPANDSUB(pijTile, sijTile, maxTile);
8181
TEXP(pijTile, pijTile);
82-
// Truncate pij to fp16 first, then compute lij from truncated values (matches golden)
83-
TCVT(pijF16Tile, pijTile, RoundMode::CAST_ROUND);
84-
TCVT(pijTile, pijF16Tile, RoundMode::CAST_ROUND);
82+
// Truncate pij to bf16 first, then compute lij from truncated values (matches golden)
83+
TCVT(pijBf16Tile, pijTile, RoundMode::CAST_ROUND);
84+
TCVT(pijTile, pijBf16Tile, RoundMode::CAST_ROUND);
8585
TROWSUM(sumTile, pijTile, tmpTile);
8686

8787
set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
8888
wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
8989
TSTORE(mijGlobal, maxTile);
9090
TSTORE(lijGlobal, sumTile);
91-
TSTORE(pijGlobal, pijF16Tile);
91+
TSTORE(pijGlobal, pijBf16Tile);
9292
}
9393

9494
extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) {

0 commit comments

Comments (0)