From 22f76d94753b74d54c7150a7936ae8afdfd8df2c Mon Sep 17 00:00:00 2001 From: wcwxy <26245345+ChaoWao@users.noreply.github.com> Date: Mon, 2 Mar 2026 12:19:35 +0800 Subject: [PATCH] Add: batch_paged_attention device test for production-scale bfloat16 Port batch_paged_attention from examples to device tests with: - bfloat16 data type (replacing float16 from example) - Production tile sizes (128x128/64x128) with runtime dispatch - Production scale: batch=64, head_dim=128, context_len=8193 - Variable sequence length test case (CaseVarSeq) - Tighter tolerance (RTOL/ATOL=1e-3 vs 1e-2 in example) - Chunked batch orchestration with IN_CORE_BATCH=16 --- .../batch_paged_attention/golden.py | 300 ++++++++++++++++++ .../kernels/aic/aic_hub.cpp | 14 + .../kernels/aic/aic_pv_matmul.cpp | 123 +++++++ .../kernels/aic/aic_qk_matmul.cpp | 129 ++++++++ .../kernels/aiv/aiv_hub.cpp | 14 + .../kernels/aiv/aiv_online_update.cpp | 237 ++++++++++++++ .../kernels/aiv/aiv_softmax_prepare.cpp | 177 +++++++++++ .../kernels/kernel_config.py | 43 +++ .../orchestration/paged_attention_orch.cpp | 219 +++++++++++++ 9 files changed, 1256 insertions(+) create mode 100644 tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/golden.py create mode 100644 tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/kernels/aic/aic_hub.cpp create mode 100644 tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/kernels/aic/aic_pv_matmul.cpp create mode 100644 tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/kernels/aic/aic_qk_matmul.cpp create mode 100644 tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/kernels/aiv/aiv_hub.cpp create mode 100644 tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/kernels/aiv/aiv_online_update.cpp create mode 100644 tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/kernels/aiv/aiv_softmax_prepare.cpp create mode 100644 tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/kernels/kernel_config.py create mode 100644 tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/kernels/orchestration/paged_attention_orch.cpp diff --git a/tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/golden.py b/tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/golden.py new file mode 100644 index 000000000..55049cafb --- /dev/null +++ b/tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/golden.py @@ -0,0 +1,300 @@ +""" +Batch Paged Attention Golden Implementation - Production Scale + +Implements the online softmax algorithm for batched paged attention with: +- bfloat16 Q/K/V inputs +- Non-transposed K storage: (total_blocks, block_size, kv_head_num, head_dim) +- GQA support (kv_head_num=1) +- Head tiling: q_tile = min(q_head_num, 128) +- Variable sequence lengths per batch (CaseVarSeq) + +Args layout: [ptr_query, ..., ptr_config, size_query, size_key_cache, size_value_cache] +""" + +import ctypes +import struct +import torch + +__outputs__ = ["out"] + +RTOL = 1e-3 +ATOL = 1e-3 + + +# All test cases - production scale +ALL_CASES = { + "Case1": { + "batch": 64, + "num_heads": 16, + "kv_head_num": 1, + "head_dim": 128, + "block_size": 128, + "context_len": 8193, + "max_model_len": 32768, + }, + "Case2": { + "batch": 64, + "num_heads": 64, + "kv_head_num": 1, + "head_dim": 128, + "block_size": 64, + "context_len": 8192, + "max_model_len": 32768, + }, + # Variable sequence length cases + "CaseVarSeq": { + "batch": 64, + "num_heads": 16, + "kv_head_num": 1, + 
"head_dim": 128, + "block_size": 128, + "context_len": 8193, + "context_lens_list": [8193, 4096, 1024, 256, 8000, 512, 2048, 7777], + "max_model_len": 32768, + }, +} + +DEFAULT_CASE = "Case1" + + +def generate_inputs(params: dict) -> list: + """Generate input tensors and zeroed output tensor.""" + batch = params["batch"] + num_heads = params["num_heads"] + kv_head_num = params["kv_head_num"] + head_dim = params["head_dim"] + block_size = params["block_size"] + context_len = params["context_len"] + max_model_len = params["max_model_len"] + context_lens_list = params.get("context_lens_list") + + assert context_len >= 1, "context_len must be >= 1 to avoid division by zero in attention" + + max_num_blocks_per_req = max_model_len // block_size + scale_value = 1.0 + scale_bits = struct.unpack('I', struct.pack('f', scale_value))[0] + + # Build per-batch context_lens tensor + if context_lens_list is not None: + seq_vals = context_lens_list + if len(seq_vals) < batch: + seq_vals = (seq_vals * ((batch + len(seq_vals) - 1) // len(seq_vals)))[:batch] + elif len(seq_vals) > batch: + seq_vals = seq_vals[:batch] + context_lens = torch.tensor(seq_vals, dtype=torch.int32) + else: + context_lens = torch.full((batch,), context_len, dtype=torch.int32) + + max_ctx = int(context_lens.max().item()) + cur_valid_blocks = (max_ctx + block_size - 1) // block_size + total_blocks = batch * cur_valid_blocks + + # Random block table: (batch, max_num_blocks_per_req) int32 + block_table = torch.randint( + 0, + max(total_blocks, 1), + size=(batch, max_num_blocks_per_req), + dtype=torch.int32, + ) + + config = torch.tensor( + [batch, num_heads, kv_head_num, head_dim, block_size, + max_num_blocks_per_req, scale_bits], + dtype=torch.int64, + ) + + # Query: (batch, 1, num_heads * head_dim) -> (batch, num_heads, head_dim) bfloat16 + query_bf16 = torch.empty(batch, 1, num_heads * head_dim).uniform_(-0.5, 0.5).to(torch.bfloat16) + query_bf16 = query_bf16.reshape(batch, num_heads, head_dim) + + # Key cache: (total_blocks, block_size, kv_head_num, head_dim) bfloat16 + key_bf16 = torch.empty(total_blocks, block_size, kv_head_num, head_dim).uniform_(-0.5, 0.5).to(torch.bfloat16) + + # Value cache: (total_blocks, block_size, kv_head_num, head_dim) bfloat16 + value_bf16 = torch.empty(total_blocks, block_size, kv_head_num, head_dim).uniform_(-1, 1).to(torch.bfloat16) + + query = query_bf16.flatten() + key_cache = key_bf16.flatten() + value_cache = value_bf16.flatten() + block_table_flat = block_table.flatten() + out = torch.zeros(batch * num_heads * head_dim, dtype=torch.float32) + + return [ + ("query", query), + ("key_cache", key_cache), + ("value_cache", value_cache), + ("block_table", block_table_flat), + ("context_lens", context_lens), + ("out", out), + ("config", config), + ("size_query", ctypes.c_int64(query.nbytes)), + ("size_key_cache", ctypes.c_int64(key_cache.nbytes)), + ("size_value_cache", ctypes.c_int64(value_cache.nbytes)), + ] + + +def paged_attention( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + num_heads: int, + scale_value: float, + block_table: torch.Tensor, + context_lens: torch.Tensor, +) -> torch.Tensor: + """ + Compute paged attention using online softmax with head tiling and GQA. + + Vectorized across the batch dimension for performance. + Supports different context_lens per batch via masking. 
+ + Args: + query: (batch, num_heads, head_dim) bfloat16 + key_cache: (total_blocks, block_size, num_kv_heads, head_dim) bfloat16 + value_cache: (total_blocks, block_size, num_kv_heads, head_dim) bfloat16 + num_kv_heads: int + num_heads: int + scale_value: float + block_table: (batch, block_num) int32 (non-negative) + context_lens: (batch,) int32 + + Returns: + out: (batch * num_heads, head_dim) float32 + """ + assert num_kv_heads == 1 + batch, num_heads_dim, head_dim = query.shape + _, block_size, _, _ = key_cache.shape + + # Reshape for batched computation + key_cache_flat = key_cache.reshape(-1, block_size, head_dim) + value_cache_flat = value_cache.reshape(-1, block_size, head_dim) + + out = torch.zeros((batch, num_heads_dim, head_dim), dtype=torch.float32) + + q_tile = min(num_heads_dim, 128) + + # Max blocks across all batches (each batch may have different context_len) + max_bn = int(((context_lens.max().item()) + block_size - 1) // block_size) + + for q_offset in range(0, num_heads_dim, q_tile): + q_tile_size = min(q_tile, num_heads_dim - q_offset) + # qi: (batch, q_tile_size, head_dim) + qi = query[:, q_offset:q_offset + q_tile_size, :].to(torch.float32) + + oi = None # (batch, q_tile_size, head_dim) + li = None # (batch, q_tile_size, 1) + mi = None # (batch, q_tile_size, 1) + + for bn in range(max_bn): + # valid_len per batch for this block position + valid_lens = torch.clamp(context_lens - bn * block_size, min=0, max=block_size) + active_mask = valid_lens > 0 # (batch,) + + if not active_mask.any(): + break + + # Gather block indices for all batches + block_indices = block_table[:, bn] # (batch,) + + # Gather K and V: (batch, block_size, head_dim) + kj_all = key_cache_flat[block_indices].to(torch.float32) + vj_all = value_cache_flat[block_indices].to(torch.float32) + + # QK matmul: (batch, q_tile_size, block_size) + sij = torch.bmm(qi, kj_all.transpose(1, 2)) * scale_value + + # Mask out invalid positions (beyond valid_len per batch) + pos = torch.arange(block_size, device=sij.device).unsqueeze(0) # (1, block_size) + valid_mask = pos < valid_lens.unsqueeze(1) # (batch, block_size) + valid_mask = valid_mask.unsqueeze(1) # (batch, 1, block_size) + sij = sij.masked_fill(~valid_mask, float('-inf')) + + # Also mask inactive batches (no blocks at this position) + batch_mask = active_mask.view(-1, 1, 1) # (batch, 1, 1) + sij = sij.masked_fill(~batch_mask, float('-inf')) + + mij = sij.max(dim=-1, keepdim=True)[0] # (batch, q_tile_size, 1) + mij = mij.clamp(min=-1e30) + pij = torch.exp(sij - mij) + pij = pij.masked_fill(~valid_mask, 0.0) + pij = pij.masked_fill(~batch_mask, 0.0) + pij = pij.to(torch.bfloat16).to(torch.float32) + lij = pij.sum(dim=-1, keepdim=True) # (batch, q_tile_size, 1) + + # PV matmul: (batch, q_tile_size, head_dim) + oi_new = torch.bmm(pij, vj_all) + + if bn == 0: + oi = oi_new + li = lij + mi = mij + else: + mi_new = torch.maximum(mi, mij) + alpha = torch.exp(mi - mi_new) + beta = torch.exp(mij - mi_new) + li = alpha * li + beta * lij + oi = alpha * oi + beta * oi_new + mi = mi_new + + # Final normalization + out[:, q_offset:q_offset + q_tile_size, :] = oi / li + + return out.reshape(-1, head_dim) + + +def compute_golden(tensors: dict, params: dict) -> None: + """Compute expected output in-place using online softmax paged attention.""" + batch = params["batch"] + num_heads = params["num_heads"] + kv_head_num = params["kv_head_num"] + head_dim = params["head_dim"] + block_size = params["block_size"] + max_model_len = params["max_model_len"] + + 
max_num_blocks_per_req = max_model_len // block_size + + # Reconstruct shaped tensors from flat tensors + query = tensors["query"].reshape(batch, num_heads, head_dim) + key_cache = tensors["key_cache"].reshape(-1, block_size, kv_head_num, head_dim) + value_cache = tensors["value_cache"].reshape(-1, block_size, kv_head_num, head_dim) + block_table = tensors["block_table"].reshape(batch, max_num_blocks_per_req) + context_lens = tensors["context_lens"] + + out = paged_attention( + query=query, + key_cache=key_cache, + value_cache=value_cache, + num_kv_heads=kv_head_num, + num_heads=num_heads, + scale_value=1.0, + block_table=block_table, + context_lens=context_lens, + ) + + tensors["out"][:] = out.flatten() + + +if __name__ == "__main__": + params = {"name": DEFAULT_CASE, **ALL_CASES[DEFAULT_CASE]} + result = generate_inputs(params) + tensors = {name: tensor for name, tensor in result if isinstance(tensor, torch.Tensor)} + compute_golden(tensors, params) + + print(f"=== Batch Paged Attention Golden Test ({params['name']}) ===") + print(f"batch={params['batch']}, num_heads={params['num_heads']}, head_dim={params['head_dim']}") + print(f"kv_head_num={params['kv_head_num']}, block_size={params['block_size']}") + if params.get('context_lens_list'): + print(f"context_lens (variable): {params['context_lens_list'][:8]}{'...' if len(params['context_lens_list']) > 8 else ''}") + else: + print(f"context_len={params['context_len']}") + + max_num_blocks = params['max_model_len'] // params['block_size'] + q_tile = min(params['num_heads'], 128) + print(f"max_num_blocks_per_req={max_num_blocks}, q_tile_size={q_tile}") + + out = tensors["out"].reshape(params["batch"] * params["num_heads"], params["head_dim"]) + print(f"Output shape: {out.shape}") + print(f"Output range: [{out.min():.4f}, {out.max():.4f}]") + print(f"Output mean: {out.mean():.4f}") + print("Golden test passed!") diff --git a/tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/kernels/aic/aic_hub.cpp b/tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/kernels/aic/aic_hub.cpp new file mode 100644 index 000000000..98be505ed --- /dev/null +++ b/tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/kernels/aic/aic_hub.cpp @@ -0,0 +1,14 @@ +#include +#include + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) {} diff --git a/tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/kernels/aic/aic_pv_matmul.cpp b/tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/kernels/aic/aic_pv_matmul.cpp new file mode 100644 index 000000000..015a785a4 --- /dev/null +++ b/tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/kernels/aic/aic_pv_matmul.cpp @@ -0,0 +1,123 @@ +// Batched PV Matmul Kernel: for each batch b, pij(M, K) @ vj(K, N) -> oi_new(M, N) +// +// Processes batch_count batches in a single kernel invocation. +// Per-batch addresses are computed from global tensor bases + block_table lookup. 
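// Addressing sketch for chunk-local batch b (mirrors the loop body below; shown only
// as a reading aid):
//   pij_addr = pij_base + b * M * K
//   phys     = block_table[(batch_start + b) * block_num + block_idx]
//   vj_addr  = val_base  + phys * K * N
//   oi_addr  = oi_base   + b * M * N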
+// +// Supports two tile configurations via runtime dispatch: +// Case1: (16, 128) @ (128, 128) -> (16, 128) +// Case2: (64, 64) @ ( 64, 128) -> (64, 128) +// +// Template: M=q_tile, K=block_size, N=head_dim + +#include +#include + +#include "tensor.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +template +static __aicore__ void pv_matmul_batch_impl( + __gm__ TensorData* pij_batch, + __gm__ TensorData* value_cache, + __gm__ TensorData* oi_new_batch, + uint64_t block_table_ptr, + uint64_t batch_count, + uint64_t block_idx, + uint64_t block_num, + uint64_t batch_start) { + + __gm__ bfloat16_t* pij_base = reinterpret_cast<__gm__ bfloat16_t*>(pij_batch->buffer.addr); + __gm__ bfloat16_t* val_base = reinterpret_cast<__gm__ bfloat16_t*>(value_cache->buffer.addr); + __gm__ float* oi_base = reinterpret_cast<__gm__ float*>(oi_new_batch->buffer.addr); + // Block table values are always non-negative (physical block indices) + __gm__ int32_t* bt = reinterpret_cast<__gm__ int32_t*>(block_table_ptr); + + using GlobalA = GlobalTensor, Stride>; + using GlobalB = GlobalTensor, Stride>; + using GlobalOut = GlobalTensor, Stride>; + + using TileMatA = Tile; + using TileMatB = Tile; + + using LeftTile = TileLeft; + using RightTile = TileRight; + using AccTile = TileAcc; + + TileMatA aMatTile; + TileMatB bMatTile; + TASSIGN(aMatTile, 0x0); + TASSIGN(bMatTile, 0x20000); + + LeftTile aTile; + RightTile bTile; + AccTile cTile; + TASSIGN(aTile, 0x0); + TASSIGN(bTile, 0x0); + TASSIGN(cTile, 0x0); + + for (uint64_t b = 0; b < batch_count; b++) { + __gm__ bfloat16_t* pij_addr = pij_base + b * M * K; + int32_t phys_block = bt[(batch_start + b) * block_num + block_idx]; + __gm__ bfloat16_t* vj_addr = val_base + (uint64_t)phys_block * K * N; + __gm__ float* oi_addr = oi_base + b * M * N; + + GlobalA pijGlobal(pij_addr); + GlobalB vjGlobal(vj_addr); + GlobalOut oiGlobal(oi_addr); + + TLOAD(aMatTile, pijGlobal); + TLOAD(bMatTile, vjGlobal); + + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + + TMOV(aTile, aMatTile); + TMOV(bTile, bMatTile); + + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + + TMATMUL(cTile, aTile, bTile); + + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + + TSTORE(oiGlobal, cTile); + + if (b + 1 < batch_count) { + pipe_barrier(PIPE_ALL); + } + } +} + +extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) { + __gm__ TensorData* pij_batch = reinterpret_cast<__gm__ TensorData*>(args[0]); + __gm__ TensorData* value_cache = reinterpret_cast<__gm__ TensorData*>(args[1]); + __gm__ TensorData* oi_new_batch = reinterpret_cast<__gm__ TensorData*>(args[2]); + uint64_t block_table_ptr = static_cast(args[3]); + uint64_t batch_count = static_cast(args[4]); + uint64_t block_idx = static_cast(args[5]); + uint64_t block_num = static_cast(args[6]); + uint64_t batch_start = static_cast(args[7]); + + uint64_t q_tile_size = static_cast(pij_batch->shapes[0] / batch_count); + + if (q_tile_size == 16) { + pv_matmul_batch_impl<16, 128, 128>( + pij_batch, value_cache, oi_new_batch, + block_table_ptr, batch_count, block_idx, block_num, batch_start); + } else { + pv_matmul_batch_impl<64, 64, 128>( + pij_batch, value_cache, oi_new_batch, + block_table_ptr, batch_count, block_idx, block_num, batch_start); + } +} diff --git a/tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/kernels/aic/aic_qk_matmul.cpp 
b/tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/kernels/aic/aic_qk_matmul.cpp new file mode 100644 index 000000000..b65512f9d --- /dev/null +++ b/tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/kernels/aic/aic_qk_matmul.cpp @@ -0,0 +1,129 @@ +// Batched QK Matmul Kernel: for each batch b, qi(M, K) @ kj.T(K, N) -> sij(M, N) +// +// Processes batch_count batches in a single kernel invocation. +// Per-batch addresses are computed from global tensor bases + block_table lookup. +// +// Supports two tile configurations via runtime dispatch: +// Case1: (16, 128) @ (128, 128).T -> (16, 128) +// Case2: (64, 128) @ (128, 64).T -> (64, 64) +// +// Template: M=q_tile, K=head_dim, N=block_size + +#include +#include + +#include "tensor.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +template +static __aicore__ void qk_matmul_batch_impl( + __gm__ TensorData* query, + __gm__ TensorData* key_cache, + __gm__ TensorData* sij_batch, + uint64_t block_table_ptr, + uint64_t batch_count, + uint64_t block_idx, + uint64_t q_offset, + uint64_t block_num, + uint64_t num_heads, + uint64_t batch_start) { + + __gm__ bfloat16_t* query_base = reinterpret_cast<__gm__ bfloat16_t*>(query->buffer.addr); + __gm__ bfloat16_t* key_base = reinterpret_cast<__gm__ bfloat16_t*>(key_cache->buffer.addr); + __gm__ float* sij_base = reinterpret_cast<__gm__ float*>(sij_batch->buffer.addr); + // Block table values are always non-negative (physical block indices) + __gm__ int32_t* bt = reinterpret_cast<__gm__ int32_t*>(block_table_ptr); + + using GlobalA = GlobalTensor, Stride>; + using GlobalB = GlobalTensor, Stride, Layout::DN>; + using GlobalOut = GlobalTensor, Stride>; + + using TileMatA = Tile; + using TileMatB = Tile; + + using LeftTile = TileLeft; + using RightTile = TileRight; + using AccTile = TileAcc; + + TileMatA aMatTile; + TileMatB bMatTile; + TASSIGN(aMatTile, 0x0); + TASSIGN(bMatTile, 0x20000); + + LeftTile aTile; + RightTile bTile; + AccTile cTile; + TASSIGN(aTile, 0x0); + TASSIGN(bTile, 0x0); + TASSIGN(cTile, 0x0); + + for (uint64_t b = 0; b < batch_count; b++) { + __gm__ bfloat16_t* qi_addr = query_base + ((batch_start + b) * num_heads + q_offset) * K; + int32_t phys_block = bt[(batch_start + b) * block_num + block_idx]; + __gm__ bfloat16_t* kj_addr = key_base + (uint64_t)phys_block * N * K; + __gm__ float* sij_addr = sij_base + b * M * N; + + GlobalA qiGlobal(qi_addr); + GlobalB kjGlobal(kj_addr); + GlobalOut sijGlobal(sij_addr); + + TLOAD(aMatTile, qiGlobal); + TLOAD(bMatTile, kjGlobal); + + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + + TMOV(aTile, aMatTile); + TMOV(bTile, bMatTile); + + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + + TMATMUL(cTile, aTile, bTile); + + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + + TSTORE(sijGlobal, cTile); + + if (b + 1 < batch_count) { + pipe_barrier(PIPE_ALL); + } + } +} + +extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) { + __gm__ TensorData* query = reinterpret_cast<__gm__ TensorData*>(args[0]); + __gm__ TensorData* key_cache = reinterpret_cast<__gm__ TensorData*>(args[1]); + __gm__ TensorData* sij_batch = reinterpret_cast<__gm__ TensorData*>(args[2]); + uint64_t block_table_ptr = static_cast(args[3]); + uint64_t batch_count = static_cast(args[4]); + uint64_t block_idx = static_cast(args[5]); + uint64_t q_offset = 
static_cast(args[6]); + uint64_t block_num = static_cast(args[7]); + uint64_t num_heads = static_cast(args[8]); + uint64_t batch_start = static_cast(args[9]); + + uint64_t q_tile_size = static_cast(sij_batch->shapes[0] / batch_count); + + if (q_tile_size == 16) { + qk_matmul_batch_impl<16, 128, 128>( + query, key_cache, sij_batch, + block_table_ptr, batch_count, block_idx, q_offset, block_num, num_heads, + batch_start); + } else { + qk_matmul_batch_impl<64, 128, 64>( + query, key_cache, sij_batch, + block_table_ptr, batch_count, block_idx, q_offset, block_num, num_heads, + batch_start); + } +} diff --git a/tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/kernels/aiv/aiv_hub.cpp b/tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/kernels/aiv/aiv_hub.cpp new file mode 100644 index 000000000..98be505ed --- /dev/null +++ b/tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/kernels/aiv/aiv_hub.cpp @@ -0,0 +1,14 @@ +#include +#include + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) {} diff --git a/tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/kernels/aiv/aiv_online_update.cpp b/tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/kernels/aiv/aiv_online_update.cpp new file mode 100644 index 000000000..ee55c00ab --- /dev/null +++ b/tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/kernels/aiv/aiv_online_update.cpp @@ -0,0 +1,237 @@ +// Batched Online Softmax Update + Normalize Kernel (AIV) +// +// Processes batch_count batches in a single kernel invocation. +// For each batch b, updates accumulators mi/li/oi with new block's mij/lij/oi_new. +// On is_last, normalizes and writes to the output tensor at the correct batch offset. +// +// Supports two tile configurations via runtime dispatch: +// Case1: (16, 128) -- q_tile=16, head_dim=128 +// Case2: (64, 128) -- q_tile=64, head_dim=128 +// +// Scalar layout strategy (unchanged from unbatched version): +// M scalar floats stored contiguously in GM can be loaded as either: +// - ND (kScalarRows, kScalarCols) RowMajor for element-wise ops +// - DN (kAlignedRows, 1) ColMajor for row-broadcast ops +// Conversion between layouts uses GM round-trip: ND TSTORE -> DN TLOAD. 
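// Worked constants for the two dispatched cases (derived from kScalarCols/kScalarRows/
// kAlignedRows defined below; listed here only as a reading aid):
//   M = 16: kScalarCols = 8, kScalarRows = 2, kAlignedRows = 16  (ND view 2x8,  DN view 16x1)
//   M = 64: kScalarCols = 8, kScalarRows = 8, kAlignedRows = 64  (ND view 8x8,  DN view 64x1)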
+ +#include +#include + +#include "tensor.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +template +static __aicore__ void online_update_batch_impl( + __gm__ TensorData* mij_batch, + __gm__ TensorData* lij_batch, + __gm__ TensorData* oi_new_batch, + __gm__ TensorData* mi_batch, + __gm__ TensorData* li_batch, + __gm__ TensorData* oi_batch, + __gm__ TensorData* out, + uint64_t is_first, + uint64_t is_last, + uint64_t batch_count, + uint64_t q_offset, + uint64_t num_heads, + uint64_t batch_start) { + + __gm__ float* mij_base = reinterpret_cast<__gm__ float*>(mij_batch->buffer.addr); + __gm__ float* lij_base = reinterpret_cast<__gm__ float*>(lij_batch->buffer.addr); + __gm__ float* oi_new_base = reinterpret_cast<__gm__ float*>(oi_new_batch->buffer.addr); + __gm__ float* mi_base = reinterpret_cast<__gm__ float*>(mi_batch->buffer.addr); + __gm__ float* li_base = reinterpret_cast<__gm__ float*>(li_batch->buffer.addr); + __gm__ float* oi_base = reinterpret_cast<__gm__ float*>(oi_batch->buffer.addr); + __gm__ float* out_base = reinterpret_cast<__gm__ float*>(out->buffer.addr); + + constexpr int kScalarCols = 32 / sizeof(float); + constexpr int kScalarRows = M / kScalarCols; + constexpr int kAlignedRows = ((M * sizeof(float) + 31) / 32) * (32 / sizeof(float)); + + using GlobalDataMxN = GlobalTensor, Stride<1, 1, 1, N, 1>>; + using GlobalScalarND = + GlobalTensor, Stride<1, 1, 1, kScalarCols, 1>>; + using GlobalScalarDN = GlobalTensor, Stride<1, 1, 1, 1, 1>, Layout::DN>; + + using TileDataMxN = Tile; + using TileScalarND = + Tile; + using TileScalarDN = Tile; + + constexpr int kDataBytes = M * N * sizeof(float); + constexpr int kScalarNDBytes = kScalarRows * kScalarCols * sizeof(float); + constexpr int kScalarDNBytes = kAlignedRows * sizeof(float); + + TileDataMxN oiNewTile; + TileDataMxN oiTile; + + TileScalarND mijND, lijND, miND, liND; + TileScalarND miNewND, alphaND, betaND, tmpND; + + TileScalarDN alphaDN, betaDN, liDN; + + TASSIGN(oiNewTile, 0); + TASSIGN(oiTile, kDataBytes); + TASSIGN(mijND, 2 * kDataBytes); + TASSIGN(lijND, 2 * kDataBytes + kScalarNDBytes); + TASSIGN(miND, 2 * kDataBytes + 2 * kScalarNDBytes); + TASSIGN(liND, 2 * kDataBytes + 3 * kScalarNDBytes); + TASSIGN(miNewND, 2 * kDataBytes + 4 * kScalarNDBytes); + TASSIGN(alphaND, 2 * kDataBytes + 5 * kScalarNDBytes); + TASSIGN(betaND, 2 * kDataBytes + 6 * kScalarNDBytes); + TASSIGN(tmpND, 2 * kDataBytes + 7 * kScalarNDBytes); + TASSIGN(alphaDN, 2 * kDataBytes + 8 * kScalarNDBytes); + TASSIGN(betaDN, 2 * kDataBytes + 8 * kScalarNDBytes + kScalarDNBytes); + TASSIGN(liDN, 2 * kDataBytes + 8 * kScalarNDBytes + 2 * kScalarDNBytes); + + for (uint64_t b = 0; b < batch_count; b++) { + __gm__ float* mij_ptr = mij_base + b * M; + __gm__ float* lij_ptr = lij_base + b * M; + __gm__ float* oi_new_ptr = oi_new_base + b * M * N; + __gm__ float* mi_ptr = mi_base + b * M; + __gm__ float* li_ptr = li_base + b * M; + __gm__ float* oi_ptr = oi_base + b * M * N; + __gm__ float* dst_ptr = out_base + ((batch_start + b) * num_heads + q_offset) * N; + + GlobalDataMxN oiNewGlobal(oi_new_ptr); + GlobalDataMxN oiGlobal(oi_ptr); + GlobalDataMxN dstGlobal(dst_ptr); + + GlobalScalarND mijGlobalND(mij_ptr); + GlobalScalarND lijGlobalND(lij_ptr); + GlobalScalarND miGlobalND(mi_ptr); + GlobalScalarND liGlobalND(li_ptr); + + GlobalScalarDN mijGlobalDN(mij_ptr); + GlobalScalarDN lijGlobalDN(lij_ptr); + GlobalScalarDN liGlobalDN(li_ptr); + + if (is_first) { + TLOAD(oiNewTile, oiNewGlobal); 
+ TLOAD(mijND, mijGlobalND); + TLOAD(lijND, lijGlobalND); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(miGlobalND, mijND); + TSTORE(liGlobalND, lijND); + TSTORE(oiGlobal, oiNewTile); + + if (is_last) { + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + TLOAD(liDN, liGlobalDN); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + TROWEXPANDDIV(oiNewTile, oiNewTile, liDN); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + TSTORE(dstGlobal, oiNewTile); + } + } else { + TLOAD(oiNewTile, oiNewGlobal); + TLOAD(oiTile, oiGlobal); + TLOAD(mijND, mijGlobalND); + TLOAD(lijND, lijGlobalND); + TLOAD(miND, miGlobalND); + TLOAD(liND, liGlobalND); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + TMAX(miNewND, miND, mijND); + pipe_barrier(PIPE_V); + TSUB(alphaND, miND, miNewND); + pipe_barrier(PIPE_V); + TEXP(alphaND, alphaND); + pipe_barrier(PIPE_V); + TSUB(betaND, mijND, miNewND); + pipe_barrier(PIPE_V); + TEXP(betaND, betaND); + pipe_barrier(PIPE_V); + TMUL(liND, alphaND, liND); + pipe_barrier(PIPE_V); + TMUL(tmpND, betaND, lijND); + pipe_barrier(PIPE_V); + TADD(liND, liND, tmpND); + + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(miGlobalND, miNewND); + TSTORE(liGlobalND, liND); + TSTORE(mijGlobalND, alphaND); + TSTORE(lijGlobalND, betaND); + + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + TLOAD(alphaDN, mijGlobalDN); + TLOAD(betaDN, lijGlobalDN); + if (is_last) { + TLOAD(liDN, liGlobalDN); + } + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + + TROWEXPANDMUL(oiTile, oiTile, alphaDN); + TROWEXPANDMUL(oiNewTile, oiNewTile, betaDN); + pipe_barrier(PIPE_V); + TADD(oiTile, oiTile, oiNewTile); + + if (is_last) { + pipe_barrier(PIPE_V); + TROWEXPANDDIV(oiTile, oiTile, liDN); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + TSTORE(dstGlobal, oiTile); + } else { + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + TSTORE(oiGlobal, oiTile); + } + } + + if (b + 1 < batch_count) { + pipe_barrier(PIPE_ALL); + } + } +} + +extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) { + __gm__ TensorData* mij_batch = reinterpret_cast<__gm__ TensorData*>(args[0]); + __gm__ TensorData* lij_batch = reinterpret_cast<__gm__ TensorData*>(args[1]); + __gm__ TensorData* oi_new_batch = reinterpret_cast<__gm__ TensorData*>(args[2]); + __gm__ TensorData* mi_batch = reinterpret_cast<__gm__ TensorData*>(args[3]); + __gm__ TensorData* li_batch = reinterpret_cast<__gm__ TensorData*>(args[4]); + __gm__ TensorData* oi_batch = reinterpret_cast<__gm__ TensorData*>(args[5]); + __gm__ TensorData* out = reinterpret_cast<__gm__ TensorData*>(args[6]); + uint64_t is_first = static_cast(args[7]); + uint64_t is_last = static_cast(args[8]); + uint64_t batch_count = static_cast(args[9]); + uint64_t q_offset = static_cast(args[10]); + uint64_t num_heads = static_cast(args[11]); + uint64_t batch_start = static_cast(args[12]); + + uint64_t q_tile_size = static_cast(mij_batch->shapes[0] / batch_count); + + if (q_tile_size == 16) { + online_update_batch_impl<16, 128>( + mij_batch, lij_batch, oi_new_batch, + mi_batch, li_batch, oi_batch, out, + is_first, is_last, 
batch_count, q_offset, num_heads, batch_start); + } else { + online_update_batch_impl<64, 128>( + mij_batch, lij_batch, oi_new_batch, + mi_batch, li_batch, oi_batch, out, + is_first, is_last, batch_count, q_offset, num_heads, batch_start); + } +} diff --git a/tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/kernels/aiv/aiv_softmax_prepare.cpp b/tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/kernels/aiv/aiv_softmax_prepare.cpp new file mode 100644 index 000000000..29818985a --- /dev/null +++ b/tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/kernels/aiv/aiv_softmax_prepare.cpp @@ -0,0 +1,177 @@ +// Batched Softmax Preparation Kernel (AIV) +// +// Processes batch_count batches in a single kernel invocation. +// For each batch b at block_idx bn: +// valid_len = min(N, context_lens[b] - bn * N) +// sij_masked = pad(sij[b], valid_len, -inf) +// sij_scale = sij_masked * scale +// mij[b] = row_max(sij_scale) +// pij[b] = exp(sij_scale - mij[b]) (truncated to bf16 then back) +// lij[b] = row_sum(pij[b]) +// +// Supports two tile configurations via runtime dispatch: +// Case1: (16, 128) -- q_tile=16, block_size=128 +// Case2: (64, 64) -- q_tile=64, block_size=64 + +#include +#include + +#include "tensor.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +template +static __aicore__ void softmax_prepare_batch_impl( + __gm__ TensorData* sij_batch, + __gm__ TensorData* pij_batch, + __gm__ TensorData* mij_batch, + __gm__ TensorData* lij_batch, + float scale_value, + uint64_t context_lens_ptr, + uint64_t batch_count, + uint64_t block_idx, + uint64_t batch_start) { + + __gm__ float* sij_base = reinterpret_cast<__gm__ float*>(sij_batch->buffer.addr); + __gm__ bfloat16_t* pij_base = reinterpret_cast<__gm__ bfloat16_t*>(pij_batch->buffer.addr); + __gm__ float* mij_base = reinterpret_cast<__gm__ float*>(mij_batch->buffer.addr); + __gm__ float* lij_base = reinterpret_cast<__gm__ float*>(lij_batch->buffer.addr); + __gm__ int32_t* ctx_lens = reinterpret_cast<__gm__ int32_t*>(context_lens_ptr); + + constexpr int kAlignedRows = ((M * sizeof(float) + 31) / 32) * (32 / sizeof(float)); + + using GlobalDataMxN = GlobalTensor, Stride<1, 1, 1, N, 1>>; + using GlobalDataMxN_bf16 = GlobalTensor, Stride<1, 1, 1, N, 1>>; + using GlobalScalarDN = GlobalTensor, Stride<1, 1, 1, 1, 1>, Layout::DN>; + + using TileSijDyn = Tile; + using TileSijPad = Tile; + + using TileVecMxN = Tile; + using TileVecMxN_bf16 = Tile; + using TileScalarDN = Tile; + + TileVecMxN sijTile; + TileSijPad sijPadTile; + TileVecMxN pijTile; + TileVecMxN tmpTile; + TileScalarDN maxTile; + TileScalarDN sumTile; + TileVecMxN_bf16 pijBf16Tile; + + TASSIGN(sijTile, 0x0); + TASSIGN(sijPadTile, 0x0); + TASSIGN(pijTile, M * N * sizeof(float)); + TASSIGN(tmpTile, 2 * M * N * sizeof(float)); + TASSIGN(maxTile, 3 * M * N * sizeof(float)); + TASSIGN(sumTile, 3 * M * N * sizeof(float) + kAlignedRows * sizeof(float)); + TASSIGN(pijBf16Tile, 3 * M * N * sizeof(float) + 2 * kAlignedRows * sizeof(float)); + + for (uint64_t b = 0; b < batch_count; b++) { + int32_t cur_seq = ctx_lens[batch_start + b]; + uint64_t start = block_idx * N; + uint64_t valid_len = 0; + if (start < (uint64_t)cur_seq) { + uint64_t remaining = (uint64_t)cur_seq - start; + valid_len = (remaining < N) ? 
remaining : N; + } + + __gm__ float* sij_addr = sij_base + b * M * N; + __gm__ bfloat16_t* pij_addr = pij_base + b * M * N; + __gm__ float* mij_addr = mij_base + b * M; + __gm__ float* lij_addr = lij_base + b * M; + + GlobalDataMxN sijGlobal(sij_addr); + GlobalDataMxN_bf16 pijGlobal(pij_addr); + GlobalScalarDN mijGlobal(mij_addr); + GlobalScalarDN lijGlobal(lij_addr); + + if (valid_len == 0) { + // Block entirely beyond sequence: write mij=-1e30, lij=0, pij=0 + // Use -1e30 instead of -inf to avoid NaN in online_update (exp(-inf - (-inf)) = NaN) + constexpr float NEG_LARGE = -1e30f; + for (int i = 0; i < kAlignedRows; i++) { + maxTile.SetValue(i, NEG_LARGE); + sumTile.SetValue(i, 0.0f); + } + for (int i = 0; i < M * N; i++) { + pijBf16Tile.SetValue(i, static_cast(0.0f)); + } + + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(mijGlobal, maxTile); + TSTORE(lijGlobal, sumTile); + TSTORE(pijGlobal, pijBf16Tile); + + if (b + 1 < batch_count) { + pipe_barrier(PIPE_ALL); + } + continue; + } + + TLOAD(sijTile, sijGlobal); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + TileSijDyn sijDynTile(static_cast(valid_len)); + TASSIGN(sijDynTile, 0x0); + TFILLPAD_INPLACE(sijPadTile, sijDynTile); + + TMULS(sijTile, sijTile, scale_value); + pipe_barrier(PIPE_V); + TROWMAX(maxTile, sijTile, tmpTile); + pipe_barrier(PIPE_V); + TROWEXPANDSUB(pijTile, sijTile, maxTile); + pipe_barrier(PIPE_V); + TEXP(pijTile, pijTile); + // Truncate pij to bf16 first, then compute lij from truncated values (matches golden) + TCVT(pijBf16Tile, pijTile, RoundMode::CAST_ROUND); + TCVT(pijTile, pijBf16Tile, RoundMode::CAST_ROUND); + TROWSUM(sumTile, pijTile, tmpTile); + + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(mijGlobal, maxTile); + TSTORE(lijGlobal, sumTile); + TSTORE(pijGlobal, pijBf16Tile); + + if (b + 1 < batch_count) { + pipe_barrier(PIPE_ALL); + } + } +} + +extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) { + __gm__ TensorData* sij_batch = reinterpret_cast<__gm__ TensorData*>(args[0]); + __gm__ TensorData* pij_batch = reinterpret_cast<__gm__ TensorData*>(args[1]); + __gm__ TensorData* mij_batch = reinterpret_cast<__gm__ TensorData*>(args[2]); + __gm__ TensorData* lij_batch = reinterpret_cast<__gm__ TensorData*>(args[3]); + union { uint64_t u; float f; } scale_conv; + scale_conv.u = static_cast(args[4]); + float scale_value = scale_conv.f; + uint64_t context_lens_ptr = static_cast(args[5]); + uint64_t batch_count = static_cast(args[6]); + uint64_t block_idx = static_cast(args[7]); + uint64_t batch_start = static_cast(args[8]); + + uint64_t q_tile_size = static_cast(sij_batch->shapes[0] / batch_count); + + if (q_tile_size == 16) { + softmax_prepare_batch_impl<16, 128>( + sij_batch, pij_batch, mij_batch, lij_batch, + scale_value, context_lens_ptr, batch_count, block_idx, batch_start); + } else { + softmax_prepare_batch_impl<64, 64>( + sij_batch, pij_batch, mij_batch, lij_batch, + scale_value, context_lens_ptr, batch_count, block_idx, batch_start); + } +} diff --git a/tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/kernels/kernel_config.py b/tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/kernels/kernel_config.py new file mode 100644 index 000000000..8b04925bc --- /dev/null +++ b/tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/kernels/kernel_config.py @@ -0,0 +1,43 @@ +""" +Batch Paged Attention Kernel and Orchestration 
Configuration + +Defines the kernels and orchestration function for batched paged attention +with AIC/AIV subgraph splitting: + +AIC Kernels (Matrix Multiplication): + - aic_qk_matmul: Q @ K^T computation (batched) + - aic_pv_matmul: P @ V computation (batched) + +AIV Kernels (Vector Operations): + - aiv_softmax_prepare: scale, rowmax, exp, rowsum (batched) + - aiv_online_update: online softmax accumulation + fused normalization (batched) +""" + +from pathlib import Path + +_KERNELS_ROOT = Path(__file__).parent + +# Orchestration config +ORCHESTRATION = { + "source": str(_KERNELS_ROOT / "orchestration" / "paged_attention_orch.cpp"), + "function_name": "aicpu_orchestration_entry", +} + +# Kernel configs +KERNELS = [ + # AIC kernels (matrix multiplication using Cube unit) + {"func_id": 0, "name": "QK", "source": str(_KERNELS_ROOT / "aic" / "aic_qk_matmul.cpp"), "core_type": "aic"}, + {"func_id": 2, "name": "PV", "source": str(_KERNELS_ROOT / "aic" / "aic_pv_matmul.cpp"), "core_type": "aic"}, + {"func_id": 4, "name": "AIC_HUB", "source": str(_KERNELS_ROOT / "aic" / "aic_hub.cpp"), "core_type": "aic"}, + # AIV kernels (vector operations) + {"func_id": 1, "name": "SF", "source": str(_KERNELS_ROOT / "aiv" / "aiv_softmax_prepare.cpp"), "core_type": "aiv"}, + {"func_id": 3, "name": "UP", "source": str(_KERNELS_ROOT / "aiv" / "aiv_online_update.cpp"), "core_type": "aiv"}, + {"func_id": 5, "name": "AIV_HUB", "source": str(_KERNELS_ROOT / "aiv" / "aiv_hub.cpp"), "core_type": "aiv"}, +] + +# Runtime configuration +RUNTIME_CONFIG = { + "runtime": "tensormap_and_ringbuffer", + "aicpu_thread_num": 4, + "block_dim": 24, +} diff --git a/tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/kernels/orchestration/paged_attention_orch.cpp b/tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/kernels/orchestration/paged_attention_orch.cpp new file mode 100644 index 000000000..b13687c45 --- /dev/null +++ b/tests/device_tests/tensormap_and_ringbuffer/batch_paged_attention/kernels/orchestration/paged_attention_orch.cpp @@ -0,0 +1,219 @@ +/** + * Batch Paged Attention Orchestration Function - Production Scale + * + * Chunked batched architecture: the full batch is split into chunks of + * IN_CORE_BATCH size. Each chunk's QK/SF/PV/UP tasks are independent + * and can be scheduled to different cores in parallel. + * + * Task count = num_chunks * (1 + max_bn * 4), where + * num_chunks = ceil(batch / IN_CORE_BATCH) + * + * For batch <= IN_CORE_BATCH, behavior is identical to the non-chunked version. + * + * Memory Layout: + * Query: (batch * num_heads, head_dim) bf16 + * Key: (total_blocks, block_size, head_dim) bf16 (stored as K^T for QK) + * Value: (total_blocks, block_size, head_dim) bf16 + * + * Per-chunk intermediate tensors (contiguous across chunk_bc dimension): + * sij: (chunk_bc * q_tile, block_size) fp32 + * pij: (chunk_bc * q_tile, block_size) bf16 + * mij/lij: (chunk_bc * q_tile) fp32 + * oi_new: (chunk_bc * q_tile, head_dim) fp32 + * oi: (chunk_bc * q_tile, head_dim) fp32 accumulator + * mi/li: (chunk_bc * q_tile) fp32 accumulator + * + * Kernels receive global tensors + scalar metadata (including batch_start) + * and compute per-batch addresses internally. 
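 *
 * Worked example (Case1 sizes from golden.py, listed only as a reading aid):
 *   batch=64, IN_CORE_BATCH=16        -> num_chunks = 4
 *   context_len=8193, block_size=128  -> max_bn = 65
 *   task count = 4 * (1 + 65 * 4) = 1044 per q_loop iteration
 *   (q_loop = 1 for both production cases, since num_heads <= 128).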
+ */ + +#include +#include + +#include "pto_orchestration_api.h" + +#define FUNC_QK_MATMUL 0 +#define FUNC_SOFTMAX_PREPARE 1 +#define FUNC_PV_MATMUL 2 +#define FUNC_ONLINE_UPDATE 3 +#define FUNC_AIC_HUB 4 +#define FUNC_AIV_HUB 5 + +static uint64_t float_to_u64(float f) { + union { + float f32; + uint64_t u64; + } conv; + conv.u64 = 0; + conv.f32 = f; + return conv.u64; +} + +extern "C" { + +__attribute__((visibility("default"))) +PTO2OrchestrationConfig aicpu_orchestration_config(uint64_t* args, int arg_count) { + (void)args; + (void)arg_count; + return PTO2OrchestrationConfig{ + .expected_arg_count = 10, + }; +} + +__attribute__((visibility("default"))) +void aicpu_orchestration_entry(PTO2Runtime* rt, uint64_t* args, int arg_count) { + (void)arg_count; + pto2_rt_init_tensor_pool(rt); + + void* host_query = (void*)(uintptr_t)args[0]; + void* host_key_cache = (void*)(uintptr_t)args[1]; + void* host_value_cache = (void*)(uintptr_t)args[2]; + int* host_block_table = (int*)(uintptr_t)args[3]; + int* host_context_lens = (int*)(uintptr_t)args[4]; + void* host_out = (void*)(uintptr_t)args[5]; + int64_t* host_config = (int64_t*)(uintptr_t)args[6]; + + size_t key_cache_size = (size_t)args[8]; + + uint64_t batch = static_cast(host_config[0]); + uint64_t num_heads = static_cast(host_config[1]); + uint64_t head_dim = static_cast(host_config[3]); + uint64_t block_size = static_cast(host_config[4]); + uint64_t block_num = static_cast(host_config[5]); + union { uint32_t u; float f; } scale_conv; + scale_conv.u = (uint32_t)host_config[6]; + float scale_value = scale_conv.f; + + uint64_t q_tile = std::min(num_heads, 128UL); + uint64_t q_loop = (num_heads + q_tile - 1) / q_tile; + DataType data_type = DataType::BFLOAT16; + uint64_t elem_size = get_element_size(data_type); + + LOG_INFO(rt, "batch_paged_attention: batch=%lu, num_heads=%lu", + (unsigned long)batch, (unsigned long)num_heads); + + uint64_t max_bn = 0; + for (uint64_t b = 0; b < batch; b++) { + uint64_t cur_seq = host_context_lens[b]; + uint64_t bn_b = (cur_seq + block_size - 1) / block_size; + if (bn_b > max_bn) max_bn = bn_b; + } + + uint64_t query_shapes[2] = {batch * num_heads, head_dim}; + uint64_t kv_total_rows = key_cache_size / (head_dim * elem_size); + uint64_t key_cache_shapes[2] = {kv_total_rows, head_dim}; + uint64_t value_cache_shapes[2] = {kv_total_rows, head_dim}; + uint64_t out_shapes[2] = {batch * num_heads, head_dim}; + + Tensor query = make_tensor_external(host_query, query_shapes, 2, data_type); + Tensor key_cache = make_tensor_external(host_key_cache, key_cache_shapes, 2, data_type); + Tensor value_cache = make_tensor_external(host_value_cache, value_cache_shapes, 2, data_type); + Tensor out = make_tensor_external(host_out, out_shapes, 2, DataType::FLOAT32); + + uint64_t bt_addr = (uint64_t)(uintptr_t)host_block_table; + uint64_t cl_addr = (uint64_t)(uintptr_t)host_context_lens; + + constexpr uint64_t IN_CORE_BATCH = 16; + uint64_t num_chunks = (batch + IN_CORE_BATCH - 1) / IN_CORE_BATCH; + + for (uint64_t q_idx = 0; q_idx < q_loop; q_idx++) { + uint64_t q_offset = q_idx * q_tile; + + for (uint64_t batch_start = 0; batch_start < batch; batch_start += IN_CORE_BATCH) { + uint64_t chunk_bc = batch - batch_start; + if (chunk_bc > IN_CORE_BATCH) chunk_bc = IN_CORE_BATCH; + + PTO2_SCOPE(rt) { + uint64_t oi_acc_shapes[2] = {chunk_bc * q_tile, head_dim}; + uint64_t scalar_acc_shapes[1] = {chunk_bc * q_tile}; + Tensor oi_batch = make_tensor(oi_acc_shapes, 2, DataType::FLOAT32); + Tensor li_batch = make_tensor(scalar_acc_shapes, 1, 
DataType::FLOAT32); + Tensor mi_batch = make_tensor(scalar_acc_shapes, 1, DataType::FLOAT32); + + PTOParam params_hub[] = { + make_output_param(oi_batch), + make_output_param(li_batch), + make_output_param(mi_batch), + }; + pto2_rt_submit_task(rt, FUNC_AIV_HUB, PTO2_WORKER_VECTOR, params_hub, 3); + + for (uint64_t bn = 0; bn < max_bn; bn++) { + uint64_t sij_shapes[2] = {chunk_bc * q_tile, block_size}; + uint64_t vec_shapes[1] = {chunk_bc * q_tile}; + uint64_t oi_new_shapes[2] = {chunk_bc * q_tile, head_dim}; + + Tensor sij_b = make_tensor(sij_shapes, 2, DataType::FLOAT32); + Tensor pij_b = make_tensor(sij_shapes, 2, data_type); + Tensor mij_b = make_tensor(vec_shapes, 1, DataType::FLOAT32); + Tensor lij_b = make_tensor(vec_shapes, 1, DataType::FLOAT32); + Tensor oi_new_b = make_tensor(oi_new_shapes, 2, DataType::FLOAT32); + + PTOParam params_qk[] = { + make_input_param(query), + make_input_param(key_cache), + make_output_param(sij_b), + make_scalar_param(bt_addr), + make_scalar_param(chunk_bc), + make_scalar_param(bn), + make_scalar_param(q_offset), + make_scalar_param(block_num), + make_scalar_param(num_heads), + make_scalar_param(batch_start), + }; + pto2_rt_submit_task(rt, FUNC_QK_MATMUL, PTO2_WORKER_CUBE, params_qk, 10); + + PTOParam params_sf[] = { + make_input_param(sij_b), + make_output_param(pij_b), + make_output_param(mij_b), + make_output_param(lij_b), + make_scalar_param(float_to_u64(scale_value)), + make_scalar_param(cl_addr), + make_scalar_param(chunk_bc), + make_scalar_param(bn), + make_scalar_param(batch_start), + }; + pto2_rt_submit_task(rt, FUNC_SOFTMAX_PREPARE, PTO2_WORKER_VECTOR, params_sf, 9); + + PTOParam params_pv[] = { + make_input_param(pij_b), + make_input_param(value_cache), + make_output_param(oi_new_b), + make_scalar_param(bt_addr), + make_scalar_param(chunk_bc), + make_scalar_param(bn), + make_scalar_param(block_num), + make_scalar_param(batch_start), + }; + pto2_rt_submit_task(rt, FUNC_PV_MATMUL, PTO2_WORKER_CUBE, params_pv, 8); + + uint64_t is_first = (bn == 0) ? 1 : 0; + uint64_t is_last = (bn == max_bn - 1) ? 1 : 0; + PTOParam params_up[] = { + make_input_param(mij_b), + make_input_param(lij_b), + make_input_param(oi_new_b), + make_inout_param(mi_batch), + make_inout_param(li_batch), + make_output_param(oi_batch), + make_output_param(out), + make_scalar_param(is_first), + make_scalar_param(is_last), + make_scalar_param(chunk_bc), + make_scalar_param(q_offset), + make_scalar_param(num_heads), + make_scalar_param(batch_start), + }; + pto2_rt_submit_task(rt, FUNC_ONLINE_UPDATE, PTO2_WORKER_VECTOR, params_up, 13); + } + } + } + } + + LOG_INFO(rt, "batch_paged_attention: %lu tasks (batch=%lu, max_bn=%lu, chunks=%lu, IN_CORE_BATCH=%lu)", + (unsigned long)(num_chunks * (1 + max_bn * 4)), + (unsigned long)batch, (unsigned long)max_bn, + (unsigned long)num_chunks, (unsigned long)IN_CORE_BATCH); +} + +} // extern "C"
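
Reviewer note: the vectorized golden above can be spot-checked per request with a much
simpler, non-vectorized reference. The sketch below assumes only torch and mirrors
golden.py's math (including the bf16 round-trip of pij); the function name
paged_attention_single and its argument layout are illustrative, not part of this patch.

    import torch

    def paged_attention_single(query_b, key_cache, value_cache, block_table_b, context_len, scale=1.0):
        """query_b: (num_heads, head_dim); caches: (blocks, block_size, 1, head_dim)."""
        num_heads, head_dim = query_b.shape
        block_size = key_cache.shape[1]
        k = key_cache.reshape(-1, block_size, head_dim)
        v = value_cache.reshape(-1, block_size, head_dim)
        q = query_b.to(torch.float32)
        oi = li = mi = None
        num_blocks = (context_len + block_size - 1) // block_size
        for bn in range(num_blocks):
            valid = min(block_size, context_len - bn * block_size)
            kj = k[block_table_b[bn]].to(torch.float32)[:valid]   # (valid, head_dim)
            vj = v[block_table_b[bn]].to(torch.float32)[:valid]
            sij = (q @ kj.T) * scale                               # (num_heads, valid)
            mij = sij.max(dim=-1, keepdim=True)[0]
            pij = torch.exp(sij - mij).to(torch.bfloat16).to(torch.float32)
            lij = pij.sum(dim=-1, keepdim=True)
            oij = pij @ vj                                         # (num_heads, head_dim)
            if bn == 0:
                oi, li, mi = oij, lij, mij
            else:
                m_new = torch.maximum(mi, mij)
                alpha, beta = torch.exp(mi - m_new), torch.exp(mij - m_new)
                li = alpha * li + beta * lij
                oi = alpha * oi + beta * oij
                mi = m_new
        return oi / li                                             # (num_heads, head_dim)

Comparing paged_attention_single(query[b], key_cache, value_cache, block_table[b],
int(context_lens[b])) against rows b*num_heads:(b+1)*num_heads of the golden output
should agree to well within the stated RTOL/ATOL for the single-tile cases
(num_heads <= 128) used here.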