diff --git a/tests/device_tests/tensormap_and_ringbuffer/paged_attention_taskring/golden.py b/tests/device_tests/tensormap_and_ringbuffer/paged_attention_taskring/golden.py new file mode 100644 index 000000000..f5b7730f1 --- /dev/null +++ b/tests/device_tests/tensormap_and_ringbuffer/paged_attention_taskring/golden.py @@ -0,0 +1,230 @@ +""" +Paged Attention Golden – Task Ring Stress Test + +Same algorithm as paged_attention/golden.py but with a larger case +(batch=64) to stress ring buffer wrapping with +small task_window / heap / dep_pool configured in kernel_config.py. +""" + +import ctypes +import struct +import torch + +__outputs__ = ["out"] + +RTOL = 1e-3 +ATOL = 1e-3 + + +ALL_CASES = { + "Case1": { + "batch": 64, + "num_heads": 16, + "kv_head_num": 1, + "head_dim": 128, + "block_size": 128, + "context_len": 1025, + "max_model_len": 32768, + }, +} + +DEFAULT_CASE = "Case1" + + +def generate_inputs(params: dict) -> list: + """Generate input tensors and zeroed output tensor.""" + batch = params["batch"] + num_heads = params["num_heads"] + kv_head_num = params["kv_head_num"] + head_dim = params["head_dim"] + block_size = params["block_size"] + context_len = params["context_len"] + max_model_len = params["max_model_len"] + + max_num_blocks_per_req = max_model_len // block_size + cur_valid_blocks = (context_len + block_size - 1) // block_size + total_blocks = batch * cur_valid_blocks + scale_value = 1.0 + scale_bits = struct.unpack('I', struct.pack('f', scale_value))[0] + + block_table = torch.randint( + 0, + max(total_blocks, 1), + size=(batch, max_num_blocks_per_req), + dtype=torch.int32, + ) + + context_lens = torch.full((batch,), context_len, dtype=torch.int32) + + config = torch.tensor( + [batch, num_heads, kv_head_num, head_dim, block_size, + max_num_blocks_per_req, scale_bits], + dtype=torch.int64, + ) + + query_bf16 = torch.empty(batch, 1, num_heads * head_dim).uniform_(-0.5, 0.5).to(torch.bfloat16) + query_bf16 = query_bf16.reshape(batch, num_heads, head_dim) + + key_bf16 = torch.empty(total_blocks, block_size, kv_head_num, head_dim).uniform_(-0.5, 0.5).to(torch.bfloat16) + value_bf16 = torch.empty(total_blocks, block_size, kv_head_num, head_dim).uniform_(-1, 1).to(torch.bfloat16) + + query = query_bf16.flatten() + key_cache = key_bf16.flatten() + value_cache = value_bf16.flatten() + block_table_flat = block_table.flatten() + out = torch.zeros(batch * num_heads * head_dim, dtype=torch.float32) + + return [ + ("query", query), + ("key_cache", key_cache), + ("value_cache", value_cache), + ("block_table", block_table_flat), + ("context_lens", context_lens), + ("out", out), + ("config", config), + ("size_query", ctypes.c_int64(query.nbytes)), + ("size_key_cache", ctypes.c_int64(key_cache.nbytes)), + ("size_value_cache", ctypes.c_int64(value_cache.nbytes)), + ] + + +def paged_attention( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + num_heads: int, + scale_value: float, + block_table: torch.Tensor, + context_lens: torch.Tensor, +) -> torch.Tensor: + """ + Compute paged attention using online softmax with head tiling and GQA. + + Vectorized across the batch dimension for performance. + Supports different context_lens per batch via masking. + + Args: + query: (batch, num_heads, head_dim) bfloat16 + key_cache: (total_blocks, block_size, num_kv_heads, head_dim) bfloat16 + value_cache: (total_blocks, block_size, num_kv_heads, head_dim) bfloat16 + num_kv_heads: int + num_heads: int + scale_value: float + block_table: (batch, block_num) int32 + context_lens: (batch,) int32 + + Returns: + out: (batch * num_heads, head_dim) float32 + """ + assert num_kv_heads == 1 + batch, num_heads_dim, head_dim = query.shape + _, block_size, _, _ = key_cache.shape + + # Reshape for batched computation + key_cache_flat = key_cache.reshape(-1, block_size, head_dim) + value_cache_flat = value_cache.reshape(-1, block_size, head_dim) + + out = torch.zeros((batch, num_heads_dim, head_dim), dtype=torch.float32) + + q_tile = min(num_heads_dim, 128) + + # Max blocks across all batches (each batch may have different context_len) + max_bn = int(((context_lens.max().item()) + block_size - 1) // block_size) + + for q_offset in range(0, num_heads_dim, q_tile): + q_tile_size = min(q_tile, num_heads_dim - q_offset) + # qi: (batch, q_tile_size, head_dim) + qi = query[:, q_offset:q_offset + q_tile_size, :].to(torch.float32) + + oi = None # (batch, q_tile_size, head_dim) + li = None # (batch, q_tile_size, 1) + mi = None # (batch, q_tile_size, 1) + + for bn in range(max_bn): + # valid_len per batch for this block position + valid_lens = torch.clamp(context_lens - bn * block_size, min=0, max=block_size) + active_mask = valid_lens > 0 # (batch,) + + if not active_mask.any(): + break + + # Gather block indices for all batches + block_indices = block_table[:, bn] # (batch,) + + # Gather K and V: (batch, block_size, head_dim) + kj_all = key_cache_flat[block_indices].to(torch.float32) + vj_all = value_cache_flat[block_indices].to(torch.float32) + + # QK matmul: (batch, q_tile_size, block_size) + sij = torch.bmm(qi, kj_all.transpose(1, 2)) * scale_value + + # Mask out invalid positions (beyond valid_len per batch) + pos = torch.arange(block_size, device=sij.device).unsqueeze(0) # (1, block_size) + valid_mask = pos < valid_lens.unsqueeze(1) # (batch, block_size) + valid_mask = valid_mask.unsqueeze(1) # (batch, 1, block_size) + sij = sij.masked_fill(~valid_mask, float('-inf')) + + # Also mask inactive batches (no blocks at this position) + batch_mask = active_mask.view(-1, 1, 1) # (batch, 1, 1) + sij = sij.masked_fill(~batch_mask, float('-inf')) + + mij = sij.max(dim=-1, keepdim=True)[0] # (batch, q_tile_size, 1) + mij = mij.clamp(min=-1e30) + pij = torch.exp(sij - mij) + pij = pij.masked_fill(~valid_mask, 0.0) + pij = pij.masked_fill(~batch_mask, 0.0) + pij = pij.to(torch.bfloat16).to(torch.float32) + lij = pij.sum(dim=-1, keepdim=True) # (batch, q_tile_size, 1) + + # PV matmul: (batch, q_tile_size, head_dim) + oi_new = torch.bmm(pij, vj_all) + + if bn == 0: + oi = oi_new + li = lij + mi = mij + else: + mi_new = torch.maximum(mi, mij) + alpha = torch.exp(mi - mi_new) + beta = torch.exp(mij - mi_new) + li = alpha * li + beta * lij + oi = alpha * oi + beta * oi_new + mi = mi_new + + # Final normalization + out[:, q_offset:q_offset + q_tile_size, :] = oi / li + + return out.reshape(-1, head_dim) + + +def compute_golden(tensors: dict, params: dict) -> None: + """Compute expected output in-place using online softmax paged attention.""" + batch = params["batch"] + num_heads = params["num_heads"] + kv_head_num = params["kv_head_num"] + head_dim = params["head_dim"] + block_size = params["block_size"] + max_model_len = params["max_model_len"] + + max_num_blocks_per_req = max_model_len // block_size + + # Reconstruct shaped tensors from flat tensors + query = tensors["query"].reshape(batch, num_heads, head_dim) + key_cache = tensors["key_cache"].reshape(-1, block_size, kv_head_num, head_dim) + value_cache = tensors["value_cache"].reshape(-1, block_size, kv_head_num, head_dim) + block_table = tensors["block_table"].reshape(batch, max_num_blocks_per_req) + context_lens = tensors["context_lens"] + + out = paged_attention( + query=query, + key_cache=key_cache, + value_cache=value_cache, + num_kv_heads=kv_head_num, + num_heads=num_heads, + scale_value=1.0, + block_table=block_table, + context_lens=context_lens, + ) + + tensors["out"][:] = out.flatten() diff --git a/tests/device_tests/tensormap_and_ringbuffer/paged_attention_taskring/kernels/kernel_config.py b/tests/device_tests/tensormap_and_ringbuffer/paged_attention_taskring/kernels/kernel_config.py new file mode 100644 index 000000000..65af24eda --- /dev/null +++ b/tests/device_tests/tensormap_and_ringbuffer/paged_attention_taskring/kernels/kernel_config.py @@ -0,0 +1,46 @@ +""" +Paged Attention – Task Ring Stress Test + +Reuses the same kernels and orchestration as paged_attention but configures +small ring buffers via RUNTIME_ENV to exercise: + - CAS-based watermark advancement (slot reuse with task_window=16) + - Heap ring wrapping (total allocations > 1MB heap) + - Dependency pool wrapping (256 entries) +""" + +from pathlib import Path + +_KERNELS_ROOT = Path(__file__).parent +_EXISTING_KERNELS = _KERNELS_ROOT.parent.parent / "paged_attention" / "kernels" + +# Orchestration config (reuse existing source) +ORCHESTRATION = { + "source": str(_EXISTING_KERNELS / "orchestration" / "paged_attention_orch.cpp"), + "function_name": "build_paged_attention_graph", +} + +# Kernel configs (reuse existing sources) +KERNELS = [ + # AIC kernels (matrix multiplication using Cube unit) + {"func_id": 0, "name": "QK", "source": str(_EXISTING_KERNELS / "aic" / "aic_qk_matmul.cpp"), "core_type": "aic"}, + {"func_id": 2, "name": "PV", "source": str(_EXISTING_KERNELS / "aic" / "aic_pv_matmul.cpp"), "core_type": "aic"}, + {"func_id": 4, "name": "AIC_HUB", "source": str(_EXISTING_KERNELS / "aic" / "aic_hub.cpp"), "core_type": "aic"}, + # AIV kernels (vector operations) + {"func_id": 1, "name": "SF", "source": str(_EXISTING_KERNELS / "aiv" / "aiv_softmax_prepare.cpp"), "core_type": "aiv"}, + {"func_id": 3, "name": "UP", "source": str(_EXISTING_KERNELS / "aiv" / "aiv_online_update.cpp"), "core_type": "aiv"}, + {"func_id": 5, "name": "AIV_HUB", "source": str(_EXISTING_KERNELS / "aiv" / "aiv_hub.cpp"), "core_type": "aiv"}, +] + +# Runtime configuration +RUNTIME_CONFIG = { + "runtime": "tensormap_and_ringbuffer", + "aicpu_thread_num": 4, + "block_dim": 24, +} + +# Small ring buffer sizes to stress task-ring slot reuse and heap wrapping +RUNTIME_ENV = { + "PTO2_RING_TASK_WINDOW": "16", # 16 slots (default 65536) - heavy slot reuse + "PTO2_RING_HEAP": "1048576", # 1MB (default 1GB) - heap ring wrapping + "PTO2_RING_DEP_POOL": "256", # 256 entries (default 65536) - dep pool wrapping +}