diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/hash_topk.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/hash_topk.cuh index 90dec3c11..e2562bedf 100644 --- a/python/sglang/jit_kernel/csrc/deepseek_v4/hash_topk.cuh +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/hash_topk.cuh @@ -29,6 +29,7 @@ struct MoEHashTopKParams { uint32_t num_routed_experts; uint32_t num_shared_experts; float routed_scaling_factor; + bool apply_routed_scaling_factor_on_output; }; template @@ -36,7 +37,8 @@ __global__ void moe_hash_topk_fused(const MoEHashTopKParams __grid_constant__ pa using namespace device; const auto& [ router_logits, input_id, tid2eid, topk_ids, topk_weights, // pointers - num_tokens, topk, num_routed_experts, num_shared_experts, routed_scaling_factor] = + num_tokens, topk, num_routed_experts, num_shared_experts, routed_scaling_factor, + apply_routed_scaling_factor_on_output] = params; const uint32_t topk_fused = topk + num_shared_experts; @@ -60,8 +62,9 @@ __global__ void moe_hash_topk_fused(const MoEHashTopKParams __grid_constant__ pa if (lane_id < topk_fused) { const bool is_shared = lane_id >= topk; const auto output_offset = warp_id * topk_fused + lane_id; + const auto scale = apply_routed_scaling_factor_on_output ? routed_scaling_factor : 1.0f; topk_ids[output_offset] = is_shared ? num_routed_experts + lane_id - topk : expert_id; - topk_weights[output_offset] = is_shared ? 1.0f / routed_scaling_factor : routed_weight / routed_sum; + topk_weights[output_offset] = is_shared ? scale / routed_scaling_factor : (routed_weight / routed_sum) * scale; } PDLTriggerSecondary(); @@ -100,7 +103,8 @@ struct HashTopKKernel { const tvm::ffi::TensorView tid2eid, const tvm::ffi::TensorView topk_weights, const tvm::ffi::TensorView topk_ids, - float routed_scaling_factor) { + float routed_scaling_factor, + bool apply_routed_scaling_factor_on_output) { using namespace host; auto N = SymbolicSize{"num_tokens"}; @@ -148,6 +152,7 @@ struct HashTopKKernel { .num_routed_experts = static_cast(E.unwrap()), .num_shared_experts = shared_experts, .routed_scaling_factor = routed_scaling_factor, + .apply_routed_scaling_factor_on_output = apply_routed_scaling_factor_on_output, }; const auto kBlockSize = 128u; const auto kNumWarps = kBlockSize / device::kWarpThreads; diff --git a/python/sglang/jit_kernel/deepseek_v4.py b/python/sglang/jit_kernel/deepseek_v4.py index 72192b533..cf8f4d9a8 100644 --- a/python/sglang/jit_kernel/deepseek_v4.py +++ b/python/sglang/jit_kernel/deepseek_v4.py @@ -349,6 +349,7 @@ def hash_topk( tid2eid: torch.Tensor, num_fused_shared_experts: int = 0, routed_scaling_factor: float = 1.0, + apply_routed_scaling_factor_on_output: bool = False, scoring_func: str = "sqrtsoftplus", ) -> Tuple[torch.Tensor, torch.Tensor]: assert scoring_func == "sqrtsoftplus" @@ -369,6 +370,7 @@ def hash_topk( topk_weights, topk_ids, routed_scaling_factor, + apply_routed_scaling_factor_on_output, ) return topk_weights, topk_ids diff --git a/python/sglang/srt/distributed/communication_op.py b/python/sglang/srt/distributed/communication_op.py index de83c9c81..99fdf3b5b 100644 --- a/python/sglang/srt/distributed/communication_op.py +++ b/python/sglang/srt/distributed/communication_op.py @@ -5,6 +5,11 @@ import torch import torch.distributed +from sglang.srt.models.dsv4_phase_trace import ( + record_current_dsv4_allreduce_start, + record_dsv4_phase_end, +) + from .parallel_state import ( get_attn_tp_group, get_moe_ep_group, @@ -15,12 +20,20 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: 
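# Hedged reference for the hash_topk change above (illustrative Python, not
# patch lines): the kernel now computes
#   scale  = routed_scaling_factor if apply_routed_scaling_factor_on_output else 1.0
#   shared = scale / routed_scaling_factor      # 1.0 when applied, 1/rsf as before
#   routed = (routed_weight / routed_sum) * scale
# i.e. the flag folds the routed scaling factor into the returned topk weights.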
"""All-reduce the input tensor across model parallel group.""" - return get_tp_group().all_reduce(input_) + trace = record_current_dsv4_allreduce_start("tp_allreduce", input_) + try: + return get_tp_group().all_reduce(input_) + finally: + record_dsv4_phase_end(trace) def tensor_model_parallel_quant_all_reduce(input_: torch.Tensor) -> torch.Tensor: """All-reduce the input tensor across model parallel group.""" - return get_tp_group().quant_all_reduce(input_) + trace = record_current_dsv4_allreduce_start("tp_quant_allreduce", input_) + try: + return get_tp_group().quant_all_reduce(input_) + finally: + record_dsv4_phase_end(trace) def tensor_model_parallel_fused_allreduce_rmsnorm( @@ -62,19 +75,31 @@ def broadcast_tensor_dict( def attention_tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: """All-reduce the input tensor across attention parallel group.""" - return get_attn_tp_group().all_reduce(input_) + trace = record_current_dsv4_allreduce_start("attn_tp_allreduce", input_) + try: + return get_attn_tp_group().all_reduce(input_) + finally: + record_dsv4_phase_end(trace) def attention_tensor_model_parallel_quant_all_reduce( input_: torch.Tensor, ) -> torch.Tensor: """All-reduce the input tensor across attention parallel group.""" - return get_attn_tp_group().quant_all_reduce(input_) + trace = record_current_dsv4_allreduce_start("attn_tp_quant_allreduce", input_) + try: + return get_attn_tp_group().quant_all_reduce(input_) + finally: + record_dsv4_phase_end(trace) def moe_tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: """All-reduce the input tensor across moe parallel group.""" - return get_moe_tp_group().all_reduce(input_) + trace = record_current_dsv4_allreduce_start("moe_tp_allreduce", input_) + try: + return get_moe_tp_group().all_reduce(input_) + finally: + record_dsv4_phase_end(trace) def moe_expert_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 7ff6318dc..099a30abd 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -497,24 +497,91 @@ def graph_capture( graph_capture_context: Optional[GraphCaptureContext] = None, stream: Optional[torch.cuda.Stream] = None, ): + trace_capture = os.environ.get("SGLANG_DSV4_GRAPH_CAPTURE_TRACE", "0") not in ( + "", + "0", + "false", + "False", + ) + trace_rank = ( + torch.distributed.get_rank() + if torch.distributed.is_initialized() + else self.rank + ) + trace_group = getattr(self, "unique_name", "unknown") + + def _trace(phase: str, **details): + if not trace_capture: + return + detail_str = " ".join(f"{key}={value}" for key, value in details.items()) + logger.info( + "[DSV4 graph trace] rank=%s group=%s phase=%s %s", + trace_rank, + trace_group, + phase, + detail_str, + ) + + _trace( + "group_capture_init", + has_context=graph_capture_context is not None, + has_stream=stream is not None, + world_size=getattr(self, "world_size", None), + ) if graph_capture_context is None: if stream is None: + _trace("stream_create_start") stream = self.device_module.Stream() + _trace("stream_create_done", stream=stream) graph_capture_context = GraphCaptureContext(stream) else: stream = graph_capture_context.stream # We don't need the context of custom quick allreduce because the ipc access # is already collected in init() and we can capture the quick allreduce directly. 
ca_comm = self.ca_comm + _trace( + "ca_context_create_start", + present=ca_comm is not None, + disabled=getattr(ca_comm, "disabled", None), + comm_type=type(ca_comm).__name__ if ca_comm is not None else None, + ) maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture() + _trace( + "ca_context_create_done", + present=ca_comm is not None, + disabled=getattr(ca_comm, "disabled", None), + ) # ensure all initialization operations complete before attempting to # capture the graph on another stream + _trace("current_stream_get_start") curr_stream = get_current_device_stream_fast() + _trace( + "current_stream_get_done", + same_stream=curr_stream == stream, + curr_stream=curr_stream, + capture_stream=stream, + ) if curr_stream != stream: + _trace("stream_wait_start") stream.wait_stream(curr_stream) + _trace("stream_wait_done") - with self.device_module.stream(stream), maybe_ca_context: + with contextlib.ExitStack() as stack: + _trace("device_stream_enter_start") + stack.enter_context(self.device_module.stream(stream)) + _trace("device_stream_enter_done") + _trace( + "ca_enter_start", + present=ca_comm is not None, + disabled=getattr(ca_comm, "disabled", None), + ) + stack.enter_context(maybe_ca_context) + _trace( + "ca_enter_done", + present=ca_comm is not None, + disabled=getattr(ca_comm, "disabled", None), + ) # In graph mode, we have to be very careful about the collective # operations. The current status is: # allreduce \ Mode | Eager | Graph | @@ -543,16 +610,60 @@ def graph_capture( if not pynccl_comm: maybe_pynccl_context = nullcontext() else: + _trace( + "pynccl_context_create_start", + disabled=getattr(pynccl_comm, "disabled", None), + comm_type=type(pynccl_comm).__name__, + ) maybe_pynccl_context = pynccl_comm.change_state(enable=True) + _trace( + "pynccl_context_create_done", + disabled=getattr(pynccl_comm, "disabled", None), + ) pymscclpp_comm = self.pymscclpp_comm maybe_pymscclpp_context: Any if not pymscclpp_comm: maybe_pymscclpp_context = nullcontext() else: + _trace( + "pymscclpp_context_create_start", + disabled=getattr(pymscclpp_comm, "disabled", None), + comm_type=type(pymscclpp_comm).__name__, + ) maybe_pymscclpp_context = pymscclpp_comm.change_state(enable=True) - with maybe_pynccl_context, maybe_pymscclpp_context: + _trace( + "pymscclpp_context_create_done", + disabled=getattr(pymscclpp_comm, "disabled", None), + ) + _trace( + "pynccl_enter_start", + present=pynccl_comm is not None, + disabled=getattr(pynccl_comm, "disabled", None), + ) + stack.enter_context(maybe_pynccl_context) + _trace( + "pynccl_enter_done", + present=pynccl_comm is not None, + disabled=getattr(pynccl_comm, "disabled", None), + ) + _trace( + "pymscclpp_enter_start", + present=pymscclpp_comm is not None, + disabled=getattr(pymscclpp_comm, "disabled", None), + ) + stack.enter_context(maybe_pymscclpp_context) + _trace( + "pymscclpp_enter_done", + present=pymscclpp_comm is not None, + disabled=getattr(pymscclpp_comm, "disabled", None), + ) + try: + _trace("yield_start") yield graph_capture_context + finally: + _trace("yield_exit_start") + _trace("contexts_exit_done") def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: """ @@ -1560,15 +1671,38 @@ def graph_capture(stream: Optional[torch.cuda.Stream] = None): in order to explicitly distinguish the kernels to capture from other kernels possibly launched on background in the default stream. 
""" + trace_capture = os.environ.get("SGLANG_DSV4_GRAPH_CAPTURE_TRACE", "0") not in ( + "", + "0", + "false", + "False", + ) + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else -1 + if trace_capture: + logger.info("[DSV4 graph trace] rank=%s group=tp phase=enter_start", rank) with get_tp_group().graph_capture( stream=stream ) as context, get_pp_group().graph_capture(context): + if trace_capture: + logger.info("[DSV4 graph trace] rank=%s group=tp_pp phase=enter_done", rank) with contextlib.ExitStack() as stack: seen = {id(_TP)} for group in (_MOE_EP, _MOE_TP): if group is not None and id(group) not in seen: seen.add(id(group)) + if trace_capture: + logger.info( + "[DSV4 graph trace] rank=%s group=%s phase=enter_start", + rank, + getattr(group, "group_name", None), + ) stack.enter_context(group.graph_capture(context)) + if trace_capture: + logger.info( + "[DSV4 graph trace] rank=%s group=%s phase=enter_done", + rank, + getattr(group, "group_name", None), + ) yield context diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 1d5fee38f..340905c9c 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -597,6 +597,8 @@ class Envs: # GEMM / kernel fusion SGLANG_OPT_FP8_WO_A_GEMM = EnvBool(False) + SGLANG_OPT_DSV4_PARALLEL_WO_A = EnvBool(False) + SGLANG_OPT_DSV4_DIRECT_WQ_B = EnvBool(False) SGLANG_OPT_BF16_FP32_GEMM_ALGO = EnvStr("cublas") SGLANG_OPT_USE_JIT_EP_ACTIVATION = EnvBool(True) SGLANG_OPT_USE_JIT_NORM = EnvBool(False) diff --git a/python/sglang/srt/layers/attention/deepseek_v4_backend.py b/python/sglang/srt/layers/attention/deepseek_v4_backend.py index 93e450765..be97d084a 100644 --- a/python/sglang/srt/layers/attention/deepseek_v4_backend.py +++ b/python/sglang/srt/layers/attention/deepseek_v4_backend.py @@ -1,10 +1,19 @@ from __future__ import annotations import enum +import atexit import functools +import hashlib +import json import logging +import os +from pathlib import Path +import signal +import threading +import time from dataclasses import dataclass, field from typing import ( + Any, TYPE_CHECKING, Dict, List, @@ -40,6 +49,8 @@ from sglang.srt.layers.dp_attention import ( get_attention_cp_rank, get_attention_cp_size, + get_attention_tp_rank, + get_attention_tp_size, ) from sglang.srt.mem_cache.deepseek_v4_memory_pool import DeepSeekV4TokenToKVPool from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode @@ -58,10 +69,1803 @@ C4_TOPK = 512 PAGE_INDEX_ALIGNED_SIZE = 64 +_B12X_DSV4_MLA_DECODE_HITS = 0 +_B12X_DSV4_MLA_DECODE_LOGS = 0 +_B12X_DSV4_MLA_TRACE_LOCK = threading.Lock() +_B12X_DSV4_MLA_SHADOW_COMPARE_COUNT = 0 +_B12X_DSV4_MLA_SHADOW_COMPARE_LOCK = threading.Lock() +_B12X_DSV4_MLA_TRACE_COUNTS: Dict[ + Tuple[str, Tuple[int, ...], str, Tuple[int, ...], str, str, int, bool, bool], + int, +] = {} +_B12X_DSV4_MLA_TRACE_REGISTERED = False +_B12X_DSV4_MLA_TRACE_WRITTEN = False +_DSV4_MTP_METADATA_MISMATCH_LOGS = 0 +_DSV4_MTP_METADATA_STEP_LOGS = 0 +_DSV4_ATTN_TRACE_RECORDS = 0 +_DSV4_STORE_TRACE_RECORDS = 0 +_DSV4_PAD_TRACE_RECORDS = 0 +_DSV4_EMPTY_EXTRA_TRACE_RECORDS = 0 +_B12X_DSV4_PREFILL_ROWS4_TRACE_LOCK = threading.Lock() +_B12X_DSV4_PREFILL_ROWS4_ROUTE_COUNTS: Dict[str, int] = {} +_B12X_DSV4_PREFILL_ROWS4_LN2 = 0.6931471805599453 + T = TypeVar("T", bound=Optional[torch.Tensor]) +def get_b12x_dsv4_mla_debug_counters() -> Dict[str, Any]: + with _B12X_DSV4_MLA_TRACE_LOCK: + trace_items = list(_B12X_DSV4_MLA_TRACE_COUNTS.items()) + with _B12X_DSV4_PREFILL_ROWS4_TRACE_LOCK: + 
rows4_routes = dict(_B12X_DSV4_PREFILL_ROWS4_ROUTE_COUNTS) + + reason_counts: Dict[str, int] = {} + mode_counts: Dict[str, int] = {} + capture_counts: Dict[str, int] = {} + top_shapes: List[Dict[str, Any]] = [] + + for ( + reason, + q_shape, + q_dtype, + k_shape, + k_dtype, + forward_mode, + compress_ratio, + is_draft_worker, + capturing, + ), count in trace_items: + count = int(count) + reason_counts[reason] = reason_counts.get(reason, 0) + count + mode_key = ( + f"{forward_mode}|compress_ratio={compress_ratio}|" + f"draft={int(is_draft_worker)}" + ) + mode_counts[mode_key] = mode_counts.get(mode_key, 0) + count + capture_key = f"{reason}|capturing={int(capturing)}" + capture_counts[capture_key] = capture_counts.get(capture_key, 0) + count + top_shapes.append( + { + "reason": reason, + "q_shape": [int(dim) for dim in q_shape], + "q_dtype": q_dtype, + "k_shape": [int(dim) for dim in k_shape], + "k_dtype": k_dtype, + "forward_mode": forward_mode, + "compress_ratio": int(compress_ratio), + "is_draft_worker": bool(is_draft_worker), + "capturing": bool(capturing), + "count": count, + } + ) + + top_shapes.sort(key=lambda item: int(item["count"]), reverse=True) + return { + "dsv4_mla_decode_hits": int(_B12X_DSV4_MLA_DECODE_HITS), + "dsv4_mla_trace_events": int(sum(count for _, count in trace_items)), + "reason_counts": dict(sorted(reason_counts.items())), + "mode_counts": dict(sorted(mode_counts.items())), + "capture_counts": dict(sorted(capture_counts.items())), + "top_shapes": top_shapes[:32], + "prefill_rows4_routes": dict(sorted(rows4_routes.items())), + } + + +def _env_true(name: str, default: str = "0") -> bool: + return os.environ.get(name, default).lower() in ("1", "true", "yes", "on") + + +def _env_int(name: str, default: int) -> int: + try: + return int(os.environ.get(name, str(default))) + except ValueError: + return default + + +def _is_rank0() -> bool: + try: + if int(get_attention_tp_rank()) != 0: + return False + except Exception: + pass + rank = os.environ.get("RANK") or os.environ.get("LOCAL_RANK") or "0" + return rank in ("0", "") + + +def _dsv4_attn_trace_enabled() -> bool: + return _env_true("SGLANG_DSV4_ATTN_TRACE", "0") + + +def _dsv4_attn_trace_decode_only() -> bool: + return _env_true("SGLANG_DSV4_ATTN_TRACE_DECODE_ONLY", "0") + + +def _dsv4_attn_trace_min_rows() -> int: + return max(1, _env_int("SGLANG_DSV4_ATTN_TRACE_MIN_ROWS", 1)) + + +def _dsv4_attn_trace_clusters_enabled() -> bool: + return _env_true("SGLANG_DSV4_ATTN_TRACE_CLUSTERS", "0") + + +def _dsv4_attn_trace_qk_proxy_enabled() -> bool: + return _env_true("SGLANG_DSV4_ATTN_TRACE_QK_PROXY", "0") + + +def _dsv4_attn_trace_qk_proxy_heads() -> int: + return max(1, _env_int("SGLANG_DSV4_ATTN_TRACE_QK_PROXY_HEADS", 16)) + + +def _dsv4_store_trace_enabled() -> bool: + return _env_true("SGLANG_DSV4_STORE_TRACE", "0") + + +def _dsv4_store_trace_decode_only() -> bool: + return _env_true("SGLANG_DSV4_STORE_TRACE_DECODE_ONLY", "0") + + +def _dsv4_store_trace_min_rows() -> int: + return max(1, _env_int("SGLANG_DSV4_STORE_TRACE_MIN_ROWS", 1)) + + +def _dsv4_pad_trace_enabled() -> bool: + return _env_true("SGLANG_DSV4_PAD_TRACE", "0") + + +def _dsv4_pad_trace_max_records() -> int: + return max(0, _env_int("SGLANG_DSV4_PAD_TRACE_MAX_RECORDS", 32)) + + +def _dsv4_zero_empty_extra_topk() -> bool: + return _env_true("SGLANG_DSV4_ZERO_EMPTY_EXTRA_TOPK", "0") + + +def _sample_int_tensor(tensor: Optional[torch.Tensor], rows: int = 8, cols: int = 8): + if tensor is None: + return None + try: + value = tensor.detach() + if value.ndim 
== 0: + return int(value.item()) + if value.ndim == 1: + return [int(x) for x in value[:rows].cpu().tolist()] + return value[:rows, :cols].cpu().tolist() + except Exception as exc: # noqa: BLE001 + return {"error": repr(exc)} + + +def _sample_list(value: Optional[List[Any]], rows: int = 8): + if value is None: + return None + return list(value[:rows]) + + +def _dsv4_row_clusters( + tensor: Optional[torch.Tensor], + *, + rows: int, + extra: Optional[torch.Tensor] = None, +) -> Optional[Dict[str, Any]]: + if ( + not _dsv4_attn_trace_clusters_enabled() + or tensor is None + or not isinstance(tensor, torch.Tensor) + or tensor.shape[0] == 0 + ): + return None + try: + row_count = min(int(tensor.shape[0]), rows) + if row_count <= 0: + return None + byte_rows = ( + tensor.detach()[:row_count] + .contiguous() + .view(torch.uint8) + .reshape(row_count, -1) + .cpu() + ) + extra_rows = None + if extra is not None: + extra_rows = ( + extra.detach()[:row_count] + .contiguous() + .view(torch.uint8) + .reshape(row_count, -1) + .cpu() + ) + byte_columns = int(byte_rows.shape[1]) + byte_limit = max(0, _env_int("SGLANG_DSV4_ATTN_TRACE_CLUSTER_BYTES", 0)) + hash_bytes = min(byte_limit or byte_columns, byte_columns) + groups: Dict[str, List[int]] = {} + row_hashes = [] + for row in range(row_count): + row_bytes = byte_rows[row, :hash_bytes].numpy().tobytes() + if extra_rows is not None: + row_bytes += b"|" + extra_rows[row].numpy().tobytes() + digest = hashlib.blake2s(row_bytes, digest_size=6).hexdigest() + row_hashes.append(digest) + groups.setdefault(digest, []).append(row) + clusters = [ + {"hash": digest, "rows": members} + for digest, members in sorted(groups.items(), key=lambda item: item[1][0]) + ] + return { + "rows": row_count, + "byte_columns": byte_columns, + "hash_bytes": hash_bytes, + "row_hashes": row_hashes, + "clusters": clusters, + "singleton_rows": [ + members[0] for members in groups.values() if len(members) == 1 + ], + } + except Exception as exc: # noqa: BLE001 + return {"error": repr(exc)} + + +def _trace_dsv4_mtp_padding( + *, + layer_id: int, + compress_ratio: int, + forward_batch: "ForwardBatch", + q_rows: int, + indices_before_rows: int, + topk_before_rows: int, + swa_page_indices: torch.Tensor, + swa_topk_lengths: torch.Tensor, +) -> None: + global _DSV4_PAD_TRACE_RECORDS + if not _dsv4_pad_trace_enabled() or not _is_rank0(): + return + max_records = _dsv4_pad_trace_max_records() + if _DSV4_PAD_TRACE_RECORDS >= max_records: + return + padded_from = min(indices_before_rows, topk_before_rows) + if q_rows <= padded_from: + return + sample_rows = max(1, _env_int("SGLANG_DSV4_PAD_TRACE_SAMPLE", 8)) + sample_cols = max(1, _env_int("SGLANG_DSV4_PAD_TRACE_INDEX_SAMPLE", 8)) + padded_indices = swa_page_indices[padded_from:q_rows] + padded_topk = swa_topk_lengths[padded_from:q_rows] + if padded_indices.ndim == 3: + padded_indices = padded_indices[:, 0, :] + active_rows = int((padded_topk > 0).sum().item()) + zero_index_rows = ( + int(((padded_indices == 0).any(dim=-1) & (padded_topk > 0)).sum().item()) + if padded_indices.ndim == 2 and padded_indices.shape[0] > 0 + else 0 + ) + payload = { + "layer": int(layer_id), + "record": int(_DSV4_PAD_TRACE_RECORDS + 1), + "mode": str(getattr(forward_batch, "forward_mode", None)), + "compress_ratio": int(compress_ratio), + "q_rows": int(q_rows), + "indices_before_rows": int(indices_before_rows), + "topk_before_rows": int(topk_before_rows), + "padded_rows": int(q_rows - padded_from), + "padded_active_rows": active_rows, + "padded_rows_with_index0": 
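# zero_index_rows (below) counts padded rows that are still active
# (topk > 0) yet reference page index 0, i.e. padding that would silently
# attend to the first KV page.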
zero_index_rows, + "padded_indices": _sample_int_tensor(padded_indices, sample_rows, sample_cols), + "padded_topk_lengths": _sample_int_tensor(padded_topk, sample_rows), + "seq_lens": _sample_int_tensor(getattr(forward_batch, "seq_lens", None), sample_rows), + "out_cache_loc": _sample_int_tensor( + getattr(forward_batch, "out_cache_loc", None), sample_rows + ), + } + logger.warning("DSV4_PAD_TRACE %s", json.dumps(payload, sort_keys=True)) + _DSV4_PAD_TRACE_RECORDS += 1 + + +def _maybe_zero_empty_extra_topk( + *, + layer_id: int, + compress_ratio: int, + forward_batch: "ForwardBatch", + extra_indices: Optional[torch.Tensor], + extra_topk_lengths: Optional[torch.Tensor], +) -> Optional[torch.Tensor]: + global _DSV4_EMPTY_EXTRA_TRACE_RECORDS + if ( + not _dsv4_zero_empty_extra_topk() + or extra_indices is None + or extra_topk_lengths is None + ): + return extra_topk_lengths + idx = extra_indices + if idx.ndim == 3: + idx = idx[:, 0, :] + if idx.ndim != 2 or idx.shape[0] != extra_topk_lengths.shape[0]: + return extra_topk_lengths + empty_rows = ~(idx >= 0).any(dim=-1) + active_empty_rows = empty_rows & (extra_topk_lengths > 0) + if not bool(active_empty_rows.any().item()): + return extra_topk_lengths + + fixed_lengths = torch.where( + empty_rows, torch.zeros_like(extra_topk_lengths), extra_topk_lengths + ) + max_records = max(0, _env_int("SGLANG_DSV4_EMPTY_EXTRA_TRACE_MAX_RECORDS", 32)) + if _is_rank0() and _DSV4_EMPTY_EXTRA_TRACE_RECORDS < max_records: + sample_rows = max(1, _env_int("SGLANG_DSV4_EMPTY_EXTRA_TRACE_SAMPLE", 8)) + payload = { + "layer": int(layer_id), + "record": int(_DSV4_EMPTY_EXTRA_TRACE_RECORDS + 1), + "mode": str(getattr(forward_batch, "forward_mode", None)), + "compress_ratio": int(compress_ratio), + "rows": int(idx.shape[0]), + "active_empty_rows": int(active_empty_rows.sum().item()), + "topk_before": _sample_int_tensor(extra_topk_lengths, sample_rows), + "topk_after": _sample_int_tensor(fixed_lengths, sample_rows), + "empty_row_indices": _sample_int_tensor( + torch.nonzero(active_empty_rows, as_tuple=False).flatten(), + sample_rows, + ), + } + logger.warning("DSV4_EMPTY_EXTRA_TRACE %s", json.dumps(payload, sort_keys=True)) + _DSV4_EMPTY_EXTRA_TRACE_RECORDS += 1 + return fixed_lengths + + +def _dsv4_attn_row_delta(tensor: Optional[torch.Tensor]) -> Optional[Dict[str, Any]]: + if tensor is None or not isinstance(tensor, torch.Tensor) or tensor.shape[0] == 0: + return None + try: + flat = tensor.detach().reshape(tensor.shape[0], -1).to(torch.float32) + diff = (flat - flat[:1]).abs() + row_max = diff.max(dim=1).values + max_rows = max(0, _env_int("SGLANG_DSV4_ATTN_TRACE_SAMPLE", 16)) + return { + "shape": [int(dim) for dim in tensor.shape], + "dtype": str(tensor.dtype), + "row_max_abs": float(row_max.max().item()), + "nonzero_rows": [ + int(i) + for i in torch.nonzero(row_max > 0, as_tuple=False) + .flatten() + .detach() + .cpu() + .tolist() + ][:max_rows], + "firstN": [ + float(x) + for x in row_max[:max_rows].detach().cpu().tolist() + ], + "clusters": _dsv4_row_clusters(tensor, rows=max_rows), + } + except Exception as exc: # noqa: BLE001 + return {"error": repr(exc)} + + +def _dsv4_local_q_slice(q: torch.Tensor) -> torch.Tensor: + try: + tp_size = int(get_attention_tp_size()) + if tp_size <= 1 or q.ndim not in (3, 4): + return q + num_heads = int(q.shape[-2]) + local_heads = num_heads // tp_size + if local_heads <= 0: + return q + tp_rank = int(get_attention_tp_rank()) + head_slice = slice(tp_rank * local_heads, (tp_rank + 1) * local_heads) + if q.ndim == 4: + return 
q[:, :, head_slice, :] + return q[:, head_slice, :] + except Exception: + return q + + +def _dsv4_attn_indices_summary( + indices: Optional[torch.Tensor], + topk_length: Optional[torch.Tensor], +) -> Optional[Dict[str, Any]]: + if indices is None: + return None + try: + rows = max(0, _env_int("SGLANG_DSV4_ATTN_TRACE_SAMPLE", 16)) + cols = max(1, _env_int("SGLANG_DSV4_ATTN_TRACE_INDEX_SAMPLE", 8)) + idx = indices.detach() + if idx.ndim == 3: + idx = idx[:, 0, :] + length_head = None + if topk_length is not None: + lengths = topk_length.detach()[:rows] + length_head = [int(x) for x in lengths.cpu().tolist()] + else: + lengths = None + return { + "shape": [int(dim) for dim in indices.shape], + "dtype": str(indices.dtype), + "head": idx[:rows, :cols].detach().cpu().tolist(), + "topk_length": length_head, + "clusters": _dsv4_row_clusters(idx, rows=rows, extra=lengths), + } + except Exception as exc: # noqa: BLE001 + return {"error": repr(exc)} + + +def _dsv4_attn_cache_delta( + k_cache: Optional[torch.Tensor], + indices: Optional[torch.Tensor], + topk_length: Optional[torch.Tensor], +) -> Optional[Dict[str, Any]]: + if k_cache is None or indices is None: + return None + try: + idx = indices.detach() + if idx.ndim == 3: + idx = idx[:, 0, :] + if idx.shape[0] == 0: + return None + rows = min(idx.shape[0], max(0, _env_int("SGLANG_DSV4_ATTN_TRACE_SAMPLE", 16))) + tokens = max(1, _env_int("SGLANG_DSV4_ATTN_TRACE_TOKENS", 2)) + page_tokens = int(k_cache.shape[1]) if k_cache.ndim >= 2 else 0 + value_bytes = 576 + scale_bytes = 8 + token_bytes = value_bytes + scale_bytes + if page_tokens <= 0 or int(k_cache.shape[-1]) != token_bytes: + return { + "cache_shape": [int(dim) for dim in k_cache.shape], + "cache_dtype": str(k_cache.dtype), + "error": "unexpected_dsv4_cache_shape", + } + page_bytes = page_tokens * token_bytes + value_region_bytes = page_tokens * value_bytes + page_cache = k_cache.detach().reshape(k_cache.shape[0], page_bytes) + if page_cache.dtype != torch.uint8: + page_cache = page_cache.contiguous().view(torch.uint8).reshape( + k_cache.shape[0], page_bytes + ) + page_cache = page_cache.to(torch.int16) + + def token_bytes_at(flat_index: int) -> torch.Tensor: + page = flat_index // page_tokens + offset = flat_index % page_tokens + value_start = offset * value_bytes + scale_start = value_region_bytes + offset * scale_bytes + return torch.cat( + [ + page_cache[page, value_start : value_start + value_bytes], + page_cache[page, scale_start : scale_start + scale_bytes], + ] + ) + + if topk_length is None: + lengths = torch.full((idx.shape[0],), idx.shape[-1], device=idx.device) + else: + lengths = topk_length.detach().to(idx.device) + base_valid = idx[0, : min(tokens, int(lengths[0].item()))] + base_valid = base_valid[base_valid >= 0] + per_token = [] + row_max = [] + cache_vectors = [] + for row in range(rows): + row_valid = idx[row, : min(tokens, int(lengths[row].item()))] + row_valid = row_valid[row_valid >= 0] + if _dsv4_attn_trace_clusters_enabled(): + cache_vectors.append( + torch.cat( + [ + token_bytes_at(int(row_valid[j].item())) + for j in range(int(row_valid.numel())) + ] + ) + if row_valid.numel() > 0 + else torch.empty(0, dtype=torch.int16, device=page_cache.device) + ) + pairs = min(base_valid.numel(), row_valid.numel()) + if pairs == 0: + row_max.append(None) + continue + vals = [] + max_delta = 0 + for j in range(pairs): + base_pos = int(base_valid[j].item()) + row_pos = int(row_valid[j].item()) + delta = (token_bytes_at(row_pos) - token_bytes_at(base_pos)).abs().max().item() + 
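# Hedged sketch of the page layout token_bytes_at() assumes: each page stores
# all value bytes first, then all scale bytes, so one token spans two disjoint
# regions (the helper name _token_regions is hypothetical):
#     def _token_regions(offset, page_tokens):
#         VALUE_BYTES, SCALE_BYTES = 576, 8  # 448B fp8 nope + 128B bf16 rope; 8B ue8m0 scales
#         value_start = offset * VALUE_BYTES
#         scale_start = page_tokens * VALUE_BYTES + offset * SCALE_BYTES
#         return (value_start, value_start + VALUE_BYTES), (scale_start, scale_start + SCALE_BYTES)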
delta_i = int(delta) + max_delta = max(max_delta, delta_i) + vals.append( + { + "token": int(j), + "base_idx": base_pos, + "row_idx": row_pos, + "byte_max_abs": delta_i, + } + ) + row_max.append(max_delta) + per_token.append({"row": int(row), "tokens": vals}) + token_byte_clusters = None + if cache_vectors: + max_width = max(int(vector.numel()) for vector in cache_vectors) + cache_rows = torch.full((rows, max_width), -1, dtype=torch.int16) + for row, vector in enumerate(cache_vectors): + width = int(vector.numel()) + if width > 0: + cache_rows[row, :width] = vector.detach().cpu() + token_byte_clusters = _dsv4_row_clusters(cache_rows, rows=rows) + return { + "cache_shape": [int(dim) for dim in k_cache.shape], + "cache_dtype": str(k_cache.dtype), + "row_byte_max_abs": row_max, + "token_deltas": per_token, + "token_byte_clusters": token_byte_clusters, + } + except Exception as exc: # noqa: BLE001 + return {"error": repr(exc)} + + +def _dsv4_decode_k_cache_token( + page_cache: torch.Tensor, + flat_index: int, + *, + page_tokens: int, + value_region_bytes: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + page = flat_index // page_tokens + offset = flat_index % page_tokens + value_start = offset * 576 + scale_start = value_region_bytes + offset * 8 + token = page_cache[page] + nope_bytes = token[value_start : value_start + 448] + rope_bytes = token[value_start + 448 : value_start + 576] + scale_u8 = token[scale_start : scale_start + 8] + + nope_fp8 = nope_bytes.contiguous().view(torch.float8_e4m3fn).to(torch.float32) + scales = torch.pow(2.0, scale_u8[:7].to(torch.float32) - 127.0) + nope = (nope_fp8.reshape(7, 64) * scales.unsqueeze(-1)).reshape(448) + rope = rope_bytes.contiguous().view(torch.bfloat16).to(torch.float32) + return nope, rope, scale_u8 + + +def _dsv4_attn_qk_proxy( + *, + q: torch.Tensor, + k_cache: Optional[torch.Tensor], + indices: Optional[torch.Tensor], + topk_length: Optional[torch.Tensor], + softmax_scale: float, + rows: int, + tokens: int, +) -> Optional[Dict[str, Any]]: + if not _dsv4_attn_trace_qk_proxy_enabled(): + return None + if k_cache is None or indices is None: + return None + try: + qv = q.detach() + if qv.ndim == 4: + qv = qv[:, 0] + if qv.ndim != 3 or qv.shape[-1] != 512: + return {"error": f"unexpected_q_shape:{tuple(q.shape)}"} + + idx = indices.detach() + if idx.ndim == 3: + idx = idx[:, 0, :] + if idx.ndim != 2 or idx.shape[0] == 0: + return {"error": f"unexpected_indices_shape:{tuple(indices.shape)}"} + + page_tokens = int(k_cache.shape[1]) if k_cache.ndim >= 2 else 0 + token_bytes = 584 + value_bytes = 576 + if page_tokens <= 0 or int(k_cache.shape[-1]) != token_bytes: + return { + "cache_shape": [int(dim) for dim in k_cache.shape], + "cache_dtype": str(k_cache.dtype), + "error": "unexpected_dsv4_cache_shape", + } + page_bytes = page_tokens * token_bytes + value_region_bytes = page_tokens * value_bytes + page_cache = k_cache.detach().reshape(k_cache.shape[0], page_bytes) + if page_cache.dtype is not torch.uint8: + page_cache = page_cache.contiguous().view(torch.uint8).reshape( + k_cache.shape[0], page_bytes + ) + + if topk_length is None: + lengths = torch.full((idx.shape[0],), idx.shape[-1], device=idx.device) + else: + lengths = topk_length.detach().to(idx.device) + + row_count = min(int(qv.shape[0]), int(idx.shape[0]), max(1, rows)) + head_count = min(int(qv.shape[1]), _dsv4_attn_trace_qk_proxy_heads()) + token_limit = max(1, tokens) + fp8_max = float(torch.finfo(torch.float8_e4m3fn).max) + + lse_abs_sum = 0.0 + lse_signed_sum = 0.0 + 
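# Hedged standalone sketch of the UE8M0 decode used by
# _dsv4_decode_k_cache_token above: each scale byte is a power-of-two exponent
# biased by 127, applied to one 64-element fp8 group (7 groups cover 448 dims).
# The helper name ue8m0_to_scale is hypothetical:
import torch
def ue8m0_to_scale(scale_u8: torch.Tensor) -> torch.Tensor:
    return torch.pow(2.0, scale_u8.to(torch.float32) - 127.0)
# ue8m0_to_scale(torch.tensor([127], dtype=torch.uint8)) -> tensor([1.])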
lse_max_abs = 0.0 + argmax_agree = 0 + topk_overlap_sum = 0.0 + score_abs_mean_sum = 0.0 + score_abs_max = 0.0 + total_heads = 0 + rows_with_tokens = 0 + selected_tokens = 0 + q_abs_max = 0.0 + k_nope_abs_max = 0.0 + k_rope_abs_max = 0.0 + k_scale_exp_min: Optional[int] = None + k_scale_exp_max: Optional[int] = None + + num_kv = int(k_cache.shape[0]) * page_tokens + for row in range(row_count): + row_len = min(token_limit, int(lengths[row].item())) + row_idx = idx[row, :row_len] + row_idx = row_idx[(row_idx >= 0) & (row_idx < num_kv)] + if row_idx.numel() == 0: + continue + + nope_parts = [] + rope_parts = [] + scale_parts = [] + for flat in row_idx.detach().cpu().tolist(): + nope, rope, scale_u8 = _dsv4_decode_k_cache_token( + page_cache, + int(flat), + page_tokens=page_tokens, + value_region_bytes=value_region_bytes, + ) + nope_parts.append(nope) + rope_parts.append(rope) + scale_parts.append(scale_u8[:7]) + k_nope = torch.stack(nope_parts, dim=0) + k_rope = torch.stack(rope_parts, dim=0) + k_scale = torch.stack(scale_parts, dim=0).to(torch.int16) - 127 + + q_row = qv[row, :head_count].to(torch.float32) + q_nope = q_row[:, :448] + q_rope = q_row[:, 448:] + q_nope_e4m3 = ( + q_nope.clamp(min=-fp8_max, max=fp8_max) + .to(torch.float8_e4m3fn) + .to(torch.float32) + ) + + ref_scores = ( + torch.matmul(q_nope, k_nope.transpose(0, 1)) + + torch.matmul(q_rope, k_rope.transpose(0, 1)) + ) * float(softmax_scale) + proxy_scores = ( + torch.matmul(q_nope_e4m3, k_nope.transpose(0, 1)) + + torch.matmul(q_rope, k_rope.transpose(0, 1)) + ) * float(softmax_scale) + diff = ref_scores - proxy_scores + lse_diff = torch.logsumexp(proxy_scores, dim=-1) - torch.logsumexp( + ref_scores, dim=-1 + ) + + local_heads = int(lse_diff.numel()) + total_heads += local_heads + rows_with_tokens += 1 + selected_tokens += int(row_idx.numel()) + lse_abs_sum += float(lse_diff.abs().sum().item()) + lse_signed_sum += float(lse_diff.sum().item()) + lse_max_abs = max(lse_max_abs, float(lse_diff.abs().max().item())) + score_abs = diff.abs() + score_abs_mean_sum += float(score_abs.mean().item()) + score_abs_max = max(score_abs_max, float(score_abs.max().item())) + + local_k = min(8, int(ref_scores.shape[-1])) + if local_k > 0: + ref_top = torch.topk(ref_scores, k=local_k, dim=-1).indices + proxy_top = torch.topk(proxy_scores, k=local_k, dim=-1).indices + argmax_agree += int((ref_top[:, 0] == proxy_top[:, 0]).sum().item()) + for h in range(int(ref_top.shape[0])): + ref_set = set(int(x) for x in ref_top[h].detach().cpu().tolist()) + proxy_set = set(int(x) for x in proxy_top[h].detach().cpu().tolist()) + topk_overlap_sum += float(len(ref_set & proxy_set)) / float(local_k) + + q_abs_max = max(q_abs_max, float(q_row.abs().max().item())) + k_nope_abs_max = max(k_nope_abs_max, float(k_nope.abs().max().item())) + k_rope_abs_max = max(k_rope_abs_max, float(k_rope.abs().max().item())) + row_scale_min = int(k_scale.min().item()) + row_scale_max = int(k_scale.max().item()) + k_scale_exp_min = ( + row_scale_min + if k_scale_exp_min is None + else min(k_scale_exp_min, row_scale_min) + ) + k_scale_exp_max = ( + row_scale_max + if k_scale_exp_max is None + else max(k_scale_exp_max, row_scale_max) + ) + + denom = float(max(total_heads, 1)) + return { + "sample_rows": int(row_count), + "rows_with_tokens": int(rows_with_tokens), + "sample_heads": int(head_count), + "sample_token_limit": int(token_limit), + "selected_tokens": int(selected_tokens), + "score_proxy_lse_mean_abs": lse_abs_sum / denom, + "score_proxy_lse_mean_signed": lse_signed_sum / 
denom, + "score_proxy_lse_max_abs": float(lse_max_abs), + "score_proxy_argmax_agree": float(argmax_agree) / denom, + "score_proxy_topk_overlap": topk_overlap_sum / denom, + "score_proxy_abs_mean": score_abs_mean_sum / float(max(rows_with_tokens, 1)), + "score_proxy_abs_max": float(score_abs_max), + "q_sample_abs_max": float(q_abs_max), + "k_nope_sample_abs_max": float(k_nope_abs_max), + "k_rope_sample_abs_max": float(k_rope_abs_max), + "k_scale_exp_min": k_scale_exp_min, + "k_scale_exp_max": k_scale_exp_max, + } + except Exception as exc: # noqa: BLE001 + return {"error": repr(exc)} + + +def _dsv4_sample_tensor(tensor: Optional[torch.Tensor]) -> Optional[List[Any]]: + if tensor is None or not isinstance(tensor, torch.Tensor): + return None + try: + sample = max(0, _env_int("SGLANG_DSV4_STORE_TRACE_SAMPLE", 16)) + return tensor.detach().flatten()[:sample].cpu().tolist() + except Exception as exc: # noqa: BLE001 + return [repr(exc)] + + +def _dsv4_byte_row_delta(tensor: Optional[Any]) -> Optional[Dict[str, Any]]: + if tensor is not None and not isinstance(tensor, torch.Tensor): + fields = {} + for name in ("k_nope_fp8", "k_rope_bf16", "scale_k_nope_ue8m0"): + value = getattr(tensor, name, None) + if value is not None: + fields[name] = _dsv4_byte_row_delta(value) + return fields or None + if tensor is None or tensor.shape[0] == 0: + return None + try: + rows = min( + tensor.shape[0], + max(0, _env_int("SGLANG_DSV4_STORE_TRACE_SAMPLE", 16)), + ) + byte_rows = ( + tensor.detach() + .contiguous() + .view(torch.uint8) + .reshape(tensor.shape[0], -1) + .to(torch.int16) + ) + diff = (byte_rows - byte_rows[:1]).abs().max(dim=1).values + return { + "shape": [int(dim) for dim in tensor.shape], + "dtype": str(tensor.dtype), + "row_byte_max_abs": [int(x) for x in diff[:rows].cpu().tolist()], + } + except Exception as exc: # noqa: BLE001 + return {"error": repr(exc)} + + +def _trace_dsv4_store_cache( + *, + layer_id: int, + forward_batch: ForwardBatch, + raw_loc: torch.Tensor, + swa_k: torch.Tensor, + packed: Optional[torch.Tensor] = None, +) -> None: + global _DSV4_STORE_TRACE_RECORDS + if not _dsv4_store_trace_enabled() or not _is_rank0(): + return + if swa_k is None or swa_k.shape[0] < _dsv4_store_trace_min_rows(): + return + mode = _b12x_dsv4_mla_forward_mode(forward_batch) + if _dsv4_store_trace_decode_only() and mode != "DECODE": + return + max_records = max(0, _env_int("SGLANG_DSV4_STORE_TRACE_MAX_RECORDS", 64)) + if _DSV4_STORE_TRACE_RECORDS >= max_records: + return + _DSV4_STORE_TRACE_RECORDS += 1 + try: + payload = { + "record": int(_DSV4_STORE_TRACE_RECORDS), + "layer": int(layer_id), + "mode": mode, + "raw_loc": _dsv4_sample_tensor(raw_loc), + "swa_k": _dsv4_attn_row_delta(swa_k), + "packed": _dsv4_byte_row_delta(packed), + } + logger.info("DSV4_STORE_TRACE %s", json.dumps(payload, sort_keys=True)) + except Exception as exc: # noqa: BLE001 + logger.warning("DSV4_STORE_TRACE failed: %s", exc) + + +def _trace_dsv4_attn_inputs( + *, + layer_id: int, + forward_batch: ForwardBatch, + compress_ratio: int, + softmax_scale: float, + q: torch.Tensor, + k_cache: torch.Tensor, + indices: torch.Tensor, + topk_length: torch.Tensor, + extra_k_cache: Optional[torch.Tensor], + extra_indices: Optional[torch.Tensor], + extra_topk_length: Optional[torch.Tensor], +) -> None: + global _DSV4_ATTN_TRACE_RECORDS + if not _dsv4_attn_trace_enabled() or not _is_rank0(): + return + if _b12x_dsv4_mla_capturing(): + return + if q.shape[0] < _dsv4_attn_trace_min_rows(): + return + mode = 
_b12x_dsv4_mla_forward_mode(forward_batch) + if _dsv4_attn_trace_decode_only() and mode != "DECODE": + return + max_records = max(0, _env_int("SGLANG_DSV4_ATTN_TRACE_MAX_RECORDS", 64)) + if _DSV4_ATTN_TRACE_RECORDS >= max_records: + return + _DSV4_ATTN_TRACE_RECORDS += 1 + try: + rows = max(0, _env_int("SGLANG_DSV4_ATTN_TRACE_SAMPLE", 16)) + proxy_tokens = max(1, _env_int("SGLANG_DSV4_ATTN_TRACE_TOKENS", 2)) + payload = { + "record": int(_DSV4_ATTN_TRACE_RECORDS), + "layer": int(layer_id), + "mode": mode, + "compress_ratio": int(compress_ratio), + "positions": _sample_int_tensor( + getattr(forward_batch, "positions", None), rows + ), + "seq_lens": _sample_int_tensor( + getattr(forward_batch, "seq_lens", None), rows + ), + "out_cache_loc": _sample_int_tensor( + getattr(forward_batch, "out_cache_loc", None), rows + ), + "input_ids": _sample_int_tensor( + getattr(forward_batch, "input_ids", None), rows + ), + "req_pool_indices": _sample_int_tensor( + getattr(forward_batch, "req_pool_indices", None), rows + ), + "rids": _sample_list(getattr(forward_batch, "rids", None), rows), + "q": _dsv4_attn_row_delta(q), + "q_local": _dsv4_attn_row_delta(_dsv4_local_q_slice(q)), + "indices": _dsv4_attn_indices_summary(indices, topk_length), + "k_cache": _dsv4_attn_cache_delta(k_cache, indices, topk_length), + "qk_proxy": _dsv4_attn_qk_proxy( + q=q, + k_cache=k_cache, + indices=indices, + topk_length=topk_length, + softmax_scale=softmax_scale, + rows=rows, + tokens=proxy_tokens, + ), + "extra_indices": _dsv4_attn_indices_summary( + extra_indices, extra_topk_length + ), + "extra_k_cache": _dsv4_attn_cache_delta( + extra_k_cache, extra_indices, extra_topk_length + ), + "extra_qk_proxy": _dsv4_attn_qk_proxy( + q=q, + k_cache=extra_k_cache, + indices=extra_indices, + topk_length=extra_topk_length, + softmax_scale=softmax_scale, + rows=rows, + tokens=proxy_tokens, + ), + } + logger.info("DSV4_ATTN_TRACE %s", json.dumps(payload, sort_keys=True)) + except Exception as exc: # noqa: BLE001 + logger.warning("DSV4_ATTN_TRACE failed: %s", exc) + + +def _b12x_dsv4_mla_decode_enabled() -> bool: + if not _env_true("B12X_MLA_DSV4_SGLANG_DIRECT", "1"): + return False + return _env_true("B12X_MLA_DSV4_CUTE") or _env_true("B12X_MLA_DECODE_CUTE") + + +def _b12x_dsv4_mla_verbose() -> bool: + return _env_true("B12X_MLA_DSV4_CUTE_VERBOSE") or _env_true( + "B12X_MLA_DECODE_CUTE_VERBOSE" + ) + + +def _b12x_dsv4_mla_target_only() -> bool: + return _env_true("B12X_MLA_DSV4_TARGET_ONLY", "1") + + +def _b12x_dsv4_mla_allow_target_verify() -> bool: + return _env_true("B12X_MLA_DSV4_ALLOW_TARGET_VERIFY", "0") + + +def _b12x_dsv4_mla_allow_draft_extend() -> bool: + return _env_true("B12X_MLA_DSV4_ALLOW_DRAFT_EXTEND", "0") + + +def _b12x_dsv4_mla_allow_draft_decode() -> bool: + return _env_true("B12X_MLA_DSV4_ALLOW_DRAFT_DECODE", "0") + + +def _b12x_dsv4_mla_bf16_qk_enabled() -> bool: + return _env_true("B12X_MLA_DEBUG_QK_BF16", "0") + + +def _b12x_dsv4_mla_needs_bf16_qk( + *, is_draft_worker: bool, compress_ratio: Literal[0, 4, 128] +) -> bool: + return bool(is_draft_worker or compress_ratio != 0) + + +def _b12x_dsv4_mla_allow_extra_tier() -> bool: + return _env_true("B12X_MLA_DECODE_CUTE_ALLOW_EXTRA", "0") + + +def _dsv4_mtp_step_metadata_enabled() -> bool: + return _env_true("SGLANG_DSV4_MTP_STEP_METADATA", "0") + + +def _b12x_dsv4_mla_graph_only() -> bool: + return _env_true("B12X_MLA_DSV4_GRAPH_ONLY", "0") + + +def _b12x_dsv4_mla_eager_only() -> bool: + return _env_true("B12X_MLA_DSV4_EAGER_ONLY", "0") + + +def 
_b12x_dsv4_prefill_rows4_enabled() -> bool: + return _env_true("B12X_MLA_DSV4_PREFILL_ROWS4_DIRECT", "0") + + +def _b12x_dsv4_mla_live_min_rows(compress_ratio: Literal[0, 4, 128]) -> int: + if compress_ratio == 0: + env_name = "B12X_MLA_DSV4_LIVE_MIN_ROWS_MAIN" + default = "32" + else: + env_name = "B12X_MLA_DSV4_LIVE_MIN_ROWS_EXTRA" + default = "1" + return _b12x_dsv4_mla_row_limit(env_name, default) + + +def _b12x_dsv4_mla_trace_path() -> str: + return os.environ.get("B12X_MLA_DSV4_TRACE_PATH", "").strip() + + +def _b12x_dsv4_mla_trace_events_enabled() -> bool: + return _env_true("B12X_MLA_DSV4_TRACE_EVENTS", "0") + + +def _b12x_dsv4_mla_shadow_compare_enabled() -> bool: + return _env_true("B12X_MLA_DSV4_SHADOW_COMPARE", "0") + + +def _b12x_dsv4_mla_shadow_compare_max() -> int: + try: + return max(0, int(os.environ.get("B12X_MLA_DSV4_SHADOW_COMPARE_MAX", "8"))) + except ValueError: + return 8 + + +def _b12x_dsv4_mla_shadow_compare_path() -> str: + return os.environ.get("B12X_MLA_DSV4_SHADOW_COMPARE_PATH", "").strip() + + +def _b12x_dsv4_mla_shadow_return_reference() -> bool: + return _env_true("B12X_MLA_DSV4_SHADOW_RETURN_REFERENCE", "0") + + +def _b12x_dsv4_mla_trace_output_path(raw: str) -> Path: + path = Path(raw) + rank = os.environ.get("RANK") or os.environ.get("LOCAL_RANK") or "x" + name = f"b12x_mla_dsv4_trace.rank{rank}.pid{os.getpid()}.json" + if path.suffix == ".json": + return path.with_name(f"{path.stem}.rank{rank}.pid{os.getpid()}{path.suffix}") + return path / name + + +def _b12x_dsv4_mla_trace_event_path(raw: str) -> Path: + path = Path(raw) + rank = os.environ.get("RANK") or os.environ.get("LOCAL_RANK") or "x" + name = f"b12x_mla_dsv4_trace.rank{rank}.pid{os.getpid()}.jsonl" + if path.suffix == ".jsonl": + return path.with_name(f"{path.stem}.rank{rank}.pid{os.getpid()}{path.suffix}") + if path.suffix == ".json": + return path.with_name(f"{path.stem}.rank{rank}.pid{os.getpid()}.events.jsonl") + return path / name + + +def _b12x_tensor_shape(tensor: Optional[torch.Tensor]) -> Optional[List[int]]: + if tensor is None: + return None + if not isinstance(tensor, torch.Tensor): + try: + return [len(tensor)] + except TypeError: + return None + return [int(dim) for dim in tensor.shape] + + +def _b12x_tensor_dtype(tensor: Optional[torch.Tensor]) -> Optional[str]: + if tensor is None: + return None + if not isinstance(tensor, torch.Tensor): + return type(tensor).__name__ + return str(tensor.dtype) + + +def _b12x_tensor_stride(tensor: Optional[torch.Tensor]) -> Optional[List[int]]: + if tensor is None: + return None + if not isinstance(tensor, torch.Tensor): + return None + return [int(dim) for dim in tensor.stride()] + + +def _b12x_dsv4_mla_forward_mode(forward_batch: Optional[ForwardBatch]) -> str: + if forward_batch is None: + return "unknown" + mode = getattr(forward_batch, "forward_mode", "unknown") + return str(getattr(mode, "name", mode)) + + +def _b12x_dsv4_mla_capturing() -> bool: + state = _b12x_dsv4_mla_capture_state() + return bool( + state.get("cuda_stream_capture", False) + or state.get("sglang_model_capture", False) + ) + + +def _b12x_dsv4_mla_capture_state() -> Dict[str, bool]: + state = { + "cuda_stream_capture": False, + "sglang_model_capture": False, + } + try: + state["cuda_stream_capture"] = bool(torch.cuda.is_current_stream_capturing()) + except Exception: # noqa: BLE001 + pass + try: + from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode + + state["sglang_model_capture"] = bool(get_is_capture_mode()) + except Exception: # noqa: BLE001 + 
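# get_is_capture_mode() lives in cuda_graph_runner, which may not be
# importable in every worker process; any failure is treated as
# "not capturing model graphs".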
pass + return state + + +def _flush_b12x_dsv4_mla_trace() -> None: + global _B12X_DSV4_MLA_TRACE_WRITTEN + raw = _b12x_dsv4_mla_trace_path() + if not raw or _B12X_DSV4_MLA_TRACE_WRITTEN: + return + with _B12X_DSV4_MLA_TRACE_LOCK: + if _B12X_DSV4_MLA_TRACE_WRITTEN: + return + rows = [ + { + "reason": reason, + "q_shape": list(q_shape), + "q_dtype": q_dtype, + "k_shape": list(k_shape), + "k_dtype": k_dtype, + "forward_mode": forward_mode, + "compress_ratio": compress_ratio, + "is_draft_worker": is_draft_worker, + "capturing": capturing, + "count": count, + } + for ( + reason, + q_shape, + q_dtype, + k_shape, + k_dtype, + forward_mode, + compress_ratio, + is_draft_worker, + capturing, + ), count in sorted(_B12X_DSV4_MLA_TRACE_COUNTS.items()) + ] + _B12X_DSV4_MLA_TRACE_WRITTEN = True + try: + output = _b12x_dsv4_mla_trace_output_path(raw) + output.parent.mkdir(parents=True, exist_ok=True) + payload = { + "pid": os.getpid(), + "rank": os.environ.get("RANK"), + "local_rank": os.environ.get("LOCAL_RANK"), + "time": time.time(), + "counts": rows, + } + output.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n") + except Exception as exc: # noqa: BLE001 + logger.warning("[b12x MLA] failed to write trace: %s", exc) + + +def _record_b12x_dsv4_mla_trace( + reason: str, + *, + q: torch.Tensor, + k_cache: torch.Tensor, + indices: Optional[torch.Tensor], + extra_k_cache: Optional[torch.Tensor], + extra_indices: Optional[torch.Tensor], + extra_topk_lengths: Optional[torch.Tensor], + head_dim_v: int, + is_fp8_kvcache: bool, + compress_ratio: Optional[int], + forward_batch: Optional[ForwardBatch], + is_draft_worker: bool, +) -> None: + global _B12X_DSV4_MLA_TRACE_REGISTERED + raw = _b12x_dsv4_mla_trace_path() + if not raw: + return + if not _B12X_DSV4_MLA_TRACE_REGISTERED: + atexit.register(_flush_b12x_dsv4_mla_trace) + if threading.current_thread() is threading.main_thread(): + for signum in (signal.SIGINT, signal.SIGTERM): + try: + previous = signal.getsignal(signum) + + def _handler(sig, frame, prev=previous): # noqa: ANN001 + _flush_b12x_dsv4_mla_trace() + if callable(prev): + prev(sig, frame) + raise SystemExit(128 + int(sig)) + + signal.signal(signum, _handler) + except Exception: # noqa: BLE001 + pass + _B12X_DSV4_MLA_TRACE_REGISTERED = True + + capturing = _b12x_dsv4_mla_capturing() + forward_mode = _b12x_dsv4_mla_forward_mode(forward_batch) + q_shape = tuple(int(dim) for dim in q.shape) + k_shape = tuple(int(dim) for dim in k_cache.shape) + key = ( + reason, + q_shape, + str(q.dtype), + k_shape, + str(k_cache.dtype), + forward_mode, + int(compress_ratio) if compress_ratio is not None else -1, + bool(is_draft_worker), + bool(capturing), + ) + with _B12X_DSV4_MLA_TRACE_LOCK: + _B12X_DSV4_MLA_TRACE_COUNTS[key] = ( + _B12X_DSV4_MLA_TRACE_COUNTS.get(key, 0) + 1 + ) + if not _b12x_dsv4_mla_trace_events_enabled(): + return + try: + output = _b12x_dsv4_mla_trace_event_path(raw) + output.parent.mkdir(parents=True, exist_ok=True) + payload = { + "pid": os.getpid(), + "rank": os.environ.get("RANK"), + "local_rank": os.environ.get("LOCAL_RANK"), + "time": time.time(), + "reason": reason, + "q_shape": list(q_shape), + "q_dtype": str(q.dtype), + "k_shape": list(k_shape), + "k_dtype": str(k_cache.dtype), + "indices_shape": _b12x_tensor_shape(indices), + "indices_dtype": _b12x_tensor_dtype(indices), + "extra_k_shape": _b12x_tensor_shape(extra_k_cache), + "extra_k_dtype": _b12x_tensor_dtype(extra_k_cache), + "extra_indices_shape": _b12x_tensor_shape(extra_indices), + "extra_indices_dtype": 
_b12x_tensor_dtype(extra_indices), + "extra_topk_lengths_shape": _b12x_tensor_shape(extra_topk_lengths), + "extra_topk_lengths_dtype": _b12x_tensor_dtype(extra_topk_lengths), + "head_dim_v": int(head_dim_v), + "is_fp8_kvcache": bool(is_fp8_kvcache), + "compress_ratio": int(compress_ratio) if compress_ratio is not None else None, + "forward_mode": forward_mode, + "is_draft_worker": bool(is_draft_worker), + "target_only": _b12x_dsv4_mla_target_only(), + "allow_target_verify": _b12x_dsv4_mla_allow_target_verify(), + "allow_draft_extend": _b12x_dsv4_mla_allow_draft_extend(), + "allow_draft_decode": _b12x_dsv4_mla_allow_draft_decode(), + "eager_only": _b12x_dsv4_mla_eager_only(), + "capturing": bool(capturing), + } + with output.open("a") as handle: + handle.write(json.dumps(payload, sort_keys=True) + "\n") + except Exception as exc: # noqa: BLE001 + logger.warning("[b12x MLA] failed to append trace event: %s", exc) + + +def _record_b12x_dsv4_prefill_rows4_route(reason: str) -> None: + with _B12X_DSV4_PREFILL_ROWS4_TRACE_LOCK: + _B12X_DSV4_PREFILL_ROWS4_ROUTE_COUNTS[reason] = ( + _B12X_DSV4_PREFILL_ROWS4_ROUTE_COUNTS.get(reason, 0) + 1 + ) + + +@dataclass +class _Rows4PrefillWorkspace: + rows_capacity: int + width: int + extra_width: int + q_ws: Optional[torch.Tensor] + page_ws: torch.Tensor + active_ws: torch.Tensor + union_indices: torch.Tensor + union_masks: torch.Tensor + union_counts: torch.Tensor + output: torch.Tensor + lse_log2: torch.Tensor + lse: torch.Tensor + sink_f32: torch.Tensor + sink_scratch: torch.Tensor + sink_scale: torch.Tensor + sink_scale_cast: torch.Tensor + extra_page_ws: Optional[torch.Tensor] = None + extra_active_ws: Optional[torch.Tensor] = None + extra_union_indices: Optional[torch.Tensor] = None + extra_union_masks: Optional[torch.Tensor] = None + extra_union_counts: Optional[torch.Tensor] = None + + +def _b12x_scalar_stats(tensor: Optional[torch.Tensor]) -> Optional[Dict[str, Any]]: + if tensor is None: + return None + if not isinstance(tensor, torch.Tensor): + try: + tensor = torch.as_tensor(tensor) + except Exception: # noqa: BLE001 + return None + flat = tensor.detach().reshape(-1) + if flat.numel() == 0: + return {"numel": 0} + vals = flat.to(torch.float32) + return { + "numel": int(flat.numel()), + "min": float(vals.min().item()), + "max": float(vals.max().item()), + "mean": float(vals.mean().item()), + } + + +def _b12x_output_compare_stats( + b12x_out: torch.Tensor, + ref_out: torch.Tensor, +) -> Dict[str, Any]: + a = b12x_out.detach().to(torch.float32).reshape(-1) + b = ref_out.detach().to(torch.float32).reshape(-1) + if a.numel() == 0 or b.numel() == 0: + return {"numel": int(min(a.numel(), b.numel()))} + diff = (a - b).abs() + denom = torch.linalg.vector_norm(a) * torch.linalg.vector_norm(b) + cosine = torch.dot(a, b) / torch.clamp(denom, min=1.0e-20) + return { + "numel": int(a.numel()), + "max_abs": float(diff.max().item()) if diff.numel() else 0.0, + "mean_abs": float(diff.mean().item()) if diff.numel() else 0.0, + "rms_abs": float(torch.sqrt(torch.mean(diff * diff)).item()) + if diff.numel() + else 0.0, + "cosine": float(cosine.item()), + } + + +def _b12x_output_compare_row_stats( + b12x_out: torch.Tensor, + ref_out: torch.Tensor, + *, + max_worst_rows: int = 4, +) -> Dict[str, Any]: + if b12x_out.shape[0] == 0 or ref_out.shape[0] == 0: + return {"rows": 0} + a = b12x_out.detach().to(torch.float32).reshape(b12x_out.shape[0], -1) + b = ref_out.detach().to(torch.float32).reshape(ref_out.shape[0], -1) + rows = min(a.shape[0], b.shape[0]) + a = a[:rows] 
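# Hedged standalone equivalent of the per-row statistics computed below
# (max-abs, mean-abs, and cosine similarity against the reference);
# row_compare is a hypothetical helper operating on two 1-D rows:
import torch
def row_compare(a: torch.Tensor, b: torch.Tensor):
    diff = (a - b).abs()
    denom = torch.linalg.vector_norm(a) * torch.linalg.vector_norm(b)
    cosine = torch.dot(a, b) / torch.clamp(denom, min=1.0e-20)
    return float(diff.max()), float(diff.mean()), float(cosine)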
+ b = b[:rows] + diff = (a - b).abs() + row_max = diff.max(dim=1).values + row_mean = diff.mean(dim=1) + denom = torch.linalg.vector_norm(a, dim=1) * torch.linalg.vector_norm(b, dim=1) + row_cos = (a * b).sum(dim=1) / torch.clamp(denom, min=1.0e-20) + worst_n = min(max(0, int(max_worst_rows)), rows) + worst_rows = [] + if worst_n > 0: + worst = torch.topk(row_max, k=worst_n).indices.detach().cpu().tolist() + row_max_cpu = row_max.detach().cpu().tolist() + row_mean_cpu = row_mean.detach().cpu().tolist() + row_cos_cpu = row_cos.detach().cpu().tolist() + worst_rows = [ + { + "row": int(row), + "max_abs": float(row_max_cpu[row]), + "mean_abs": float(row_mean_cpu[row]), + "cosine": float(row_cos_cpu[row]), + } + for row in worst + ] + return { + "rows": int(rows), + "row_max_abs_max": float(row_max.max().item()), + "row_max_abs_p50": float(torch.quantile(row_max, 0.5).item()), + "row_max_abs_p90": float(torch.quantile(row_max, 0.9).item()), + "row_mean_abs_p50": float(torch.quantile(row_mean, 0.5).item()), + "row_cosine_min": float(row_cos.min().item()), + "row_cosine_p50": float(torch.quantile(row_cos, 0.5).item()), + "worst_rows_by_max_abs": worst_rows, + } + + +def _b12x_dsv4_shadow_worst_rows() -> int: + try: + return max(0, int(os.environ.get("B12X_MLA_DSV4_SHADOW_WORST_ROWS", "4"))) + except ValueError: + return 4 + + +def _b12x_tensor_head(tensor: Optional[torch.Tensor], n: int = 8) -> Optional[List[Any]]: + if tensor is None: + return None + if not isinstance(tensor, torch.Tensor): + return list(tensor[:n]) if hasattr(tensor, "__getitem__") else None + flat = tensor.detach().reshape(-1) + if flat.numel() == 0: + return [] + return flat[: min(n, flat.numel())].cpu().tolist() + + +def _b12x_shadow_forward_context( + forward_batch: Optional[ForwardBatch], + *, + q_rows: int, +) -> Dict[str, Any]: + if forward_batch is None: + return { + "batch_size": None, + "valid_q_rows_estimate": int(q_rows), + "padding_q_rows_estimate": 0, + } + mode = getattr(forward_batch, "forward_mode", None) + mode_name = str(getattr(mode, "name", mode)) + batch_size = int(getattr(forward_batch, "batch_size", q_rows)) + spec_info = getattr(forward_batch, "spec_info", None) + draft_token_num = getattr(spec_info, "draft_token_num", None) + num_tokens_per_req = getattr(spec_info, "num_tokens_per_req", None) + expected_rows = batch_size + if mode is not None and mode.is_target_verify(): + expected_rows = batch_size * int(draft_token_num or 1) + elif mode is not None and mode.is_draft_extend(include_v2=True): + expected_rows = batch_size * int(num_tokens_per_req or 1) + elif mode is not None and mode.is_prefill(include_draft_extend_v2=True): + expected_rows = int(getattr(forward_batch, "extend_num_tokens", q_rows) or q_rows) + valid_rows = min(int(q_rows), int(expected_rows)) + context = { + "batch_size": batch_size, + "forward_mode": mode_name, + "valid_q_rows_estimate": int(valid_rows), + "expected_q_rows": int(expected_rows), + "padding_q_rows_estimate": int(max(0, q_rows - valid_rows)), + "seq_lens_shape": _b12x_tensor_shape(getattr(forward_batch, "seq_lens", None)), + "seq_lens_stats": _b12x_scalar_stats(getattr(forward_batch, "seq_lens", None)), + "seq_lens_head": _b12x_tensor_head(getattr(forward_batch, "seq_lens", None)), + "seq_lens_cpu_shape": _b12x_tensor_shape( + getattr(forward_batch, "seq_lens_cpu", None) + ), + "seq_lens_cpu_head": _b12x_tensor_head( + getattr(forward_batch, "seq_lens_cpu", None) + ), + "req_pool_indices_shape": _b12x_tensor_shape( + getattr(forward_batch, "req_pool_indices", None) + ), 
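# Field convention for this shadow-compare record: *_shape/*_stride describe
# tensor layout, *_stats reports min/max/mean, and *_head samples the leading
# elements for eyeballing.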
+ "req_pool_indices_head": _b12x_tensor_head( + getattr(forward_batch, "req_pool_indices", None) + ), + "out_cache_loc_shape": _b12x_tensor_shape( + getattr(forward_batch, "out_cache_loc", None) + ), + "out_cache_loc_head": _b12x_tensor_head( + getattr(forward_batch, "out_cache_loc", None) + ), + } + if spec_info is not None: + context["spec_info_type"] = type(spec_info).__name__ + context["spec_draft_token_num"] = ( + None if draft_token_num is None else int(draft_token_num) + ) + context["spec_num_tokens_per_req"] = ( + None if num_tokens_per_req is None else int(num_tokens_per_req) + ) + return context + + +def _append_b12x_shadow_compare_record(record: Dict[str, Any]) -> None: + raw = _b12x_dsv4_mla_shadow_compare_path() + if not raw: + logger.warning("[b12x MLA shadow] %s", json.dumps(record, sort_keys=True)) + return + try: + path = Path(raw) + if path.suffix != ".jsonl": + rank = os.environ.get("RANK") or os.environ.get("LOCAL_RANK") or "x" + path = path / f"b12x_mla_dsv4_shadow.rank{rank}.pid{os.getpid()}.jsonl" + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("a") as handle: + handle.write(json.dumps(record, sort_keys=True) + "\n") + except Exception as exc: # noqa: BLE001 + logger.warning("[b12x MLA shadow] failed to write record: %s", exc) + + +def _b12x_dsv4_mla_row_limit(name: str, default: str) -> int: + try: + value = int(os.environ.get(name, default)) + except ValueError: + return int(default) + return max(1, value) + + +def _b12x_dsv4_mla_decode_reject_reason( + *, + q: torch.Tensor, + k_cache: torch.Tensor, + indices: Optional[torch.Tensor], + extra_k_cache: Optional[torch.Tensor], + extra_indices: Optional[torch.Tensor], + extra_topk_lengths: Optional[torch.Tensor], + head_dim_v: int, + is_fp8_kvcache: bool, + compress_ratio: Literal[0, 4, 128], + forward_batch: ForwardBatch, + is_draft_worker: bool = False, +) -> Optional[str]: + if not _b12x_dsv4_mla_decode_enabled(): + return "reject_disabled" + capturing = _b12x_dsv4_mla_capturing() + if _b12x_dsv4_mla_graph_only() and not capturing: + return "reject_not_capturing_graph_only" + if _b12x_dsv4_mla_eager_only() and capturing: + return "reject_capturing_eager_only" + if _b12x_dsv4_mla_target_only() and is_draft_worker: + return "reject_draft_worker" + if _b12x_dsv4_mla_needs_bf16_qk( + is_draft_worker=is_draft_worker, compress_ratio=compress_ratio + ) and not _b12x_dsv4_mla_bf16_qk_enabled(): + return "reject_needs_bf16_qk" + if compress_ratio != 0 and not _b12x_dsv4_mla_allow_extra_tier(): + return "reject_compress_ratio" + mode = forward_batch.forward_mode + allow_target_verify = ( + _b12x_dsv4_mla_allow_target_verify() and mode.is_target_verify() + ) + allow_draft_extend = ( + _b12x_dsv4_mla_allow_draft_extend() + and is_draft_worker + and compress_ratio == 0 + and mode.is_draft_extend(include_v2=True) + ) + allow_decode = mode.is_decode() and ( + not is_draft_worker or _b12x_dsv4_mla_allow_draft_decode() + ) + if is_draft_worker and mode.is_decode() and not allow_decode: + return "reject_draft_decode" + if not (allow_decode or allow_target_verify or allow_draft_extend): + return "reject_forward_mode" + if indices is None: + return "reject_indices_missing" + if not is_fp8_kvcache: + return "reject_not_fp8_kvcache" + if q.dtype != torch.bfloat16: + return "reject_q_dtype" + if k_cache.dtype not in (torch.uint8, torch.float8_e4m3fn): + return "reject_k_dtype" + if q.ndim != 4 or q.shape[1] != 1 or q.shape[-1] != 512: + return "reject_q_shape" + if q.shape[2] not in (16, 64): + return "reject_heads" + if 
k_cache.ndim != 4 or k_cache.shape[-1] != 584: + return "reject_k_shape" + if int(head_dim_v) != 512: + return "reject_head_dim_v" + rows = int(q.shape[0]) + min_rows = _b12x_dsv4_mla_row_limit("B12X_MLA_DSV4_CUTE_MIN_ROWS", "1") + max_rows = _b12x_dsv4_mla_row_limit("B12X_MLA_DSV4_CUTE_MAX_ROWS", "65536") + if rows < min_rows: + return "reject_rows_lt_min" + if rows > max_rows: + return "reject_rows_gt_max" + if not capturing and rows < _b12x_dsv4_mla_live_min_rows(compress_ratio): + return "reject_live_rows_lt_min" + if compress_ratio == 0: + if ( + extra_k_cache is not None + or extra_indices is not None + or extra_topk_lengths is not None + ): + return "reject_extra_tier" + return None + if ( + extra_k_cache is None + or extra_indices is None + or extra_topk_lengths is None + ): + return "reject_extra_tier_missing" + if extra_k_cache.dtype not in (torch.uint8, torch.float8_e4m3fn): + return "reject_extra_k_dtype" + if extra_k_cache.ndim != 4 or extra_k_cache.shape[-1] != 584: + return "reject_extra_k_shape" + expected_extra_page = 64 if compress_ratio == 4 else 2 + expected_extra_width = 512 if compress_ratio == 4 else 320 + if int(extra_k_cache.shape[1]) != expected_extra_page: + return "reject_extra_page_shape" + if ( + int(extra_indices.shape[0]) != rows + or int(extra_indices.shape[-1]) != expected_extra_width + ): + return "reject_extra_indices_shape" + if int(extra_topk_lengths.shape[0]) != rows: + return "reject_extra_topk_shape" + return None + + +def _should_run_b12x_dsv4_mla_decode( + *, + q: torch.Tensor, + k_cache: torch.Tensor, + indices: Optional[torch.Tensor], + extra_k_cache: Optional[torch.Tensor], + extra_indices: Optional[torch.Tensor], + extra_topk_lengths: Optional[torch.Tensor], + head_dim_v: int, + is_fp8_kvcache: bool, + compress_ratio: Literal[0, 4, 128], + forward_batch: ForwardBatch, + is_draft_worker: bool = False, +) -> bool: + reason = _b12x_dsv4_mla_decode_reject_reason( + q=q, + k_cache=k_cache, + indices=indices, + extra_k_cache=extra_k_cache, + extra_indices=extra_indices, + extra_topk_lengths=extra_topk_lengths, + head_dim_v=head_dim_v, + is_fp8_kvcache=is_fp8_kvcache, + compress_ratio=compress_ratio, + forward_batch=forward_batch, + is_draft_worker=is_draft_worker, + ) + _record_b12x_dsv4_mla_trace( + "hit" if reason is None else reason, + q=q, + k_cache=k_cache, + indices=indices, + extra_k_cache=extra_k_cache, + extra_indices=extra_indices, + extra_topk_lengths=extra_topk_lengths, + head_dim_v=head_dim_v, + is_fp8_kvcache=is_fp8_kvcache, + compress_ratio=compress_ratio, + forward_batch=forward_batch, + is_draft_worker=is_draft_worker, + ) + return reason is None + + +def _run_b12x_dsv4_mla_decode( + *, + q: torch.Tensor, + k_cache: torch.Tensor, + indices: torch.Tensor, + topk_length: Optional[torch.Tensor], + attn_sink: torch.Tensor, + softmax_scale: float, + head_dim_v: int, + extra_k_cache: Optional[torch.Tensor] = None, + extra_indices: Optional[torch.Tensor] = None, + extra_topk_lengths: Optional[torch.Tensor] = None, + flashmla_metadata: Optional[Any] = None, + forward_batch: Optional[ForwardBatch] = None, + compress_ratio: Literal[0, 4, 128] = 0, + is_draft_worker: bool = False, +) -> torch.Tensor: + global _B12X_DSV4_MLA_DECODE_HITS + global _B12X_DSV4_MLA_DECODE_LOGS + global _B12X_DSV4_MLA_SHADOW_COMPARE_COUNT + + from b12x.cute.sm120.sparse_mla_decode_dsv4 import run_dsv4_sparse_mla_decode + + ref_k_cache = k_cache + ref_extra_k_cache = extra_k_cache + if k_cache.dtype == torch.float8_e4m3fn: + k_cache = k_cache.view(torch.uint8) + if 
extra_k_cache is not None and extra_k_cache.dtype == torch.float8_e4m3fn: + extra_k_cache = extra_k_cache.view(torch.uint8) + out, b12x_lse = run_dsv4_sparse_mla_decode( + q=q, + k_cache=k_cache, + indices=indices, + topk_length=topk_length, + attn_sink=attn_sink, + sm_scale=float(softmax_scale), + v_head_dim=int(head_dim_v), + extra_k_cache=extra_k_cache, + extra_indices=extra_indices, + extra_topk_length=extra_topk_lengths, + qk_bf16=( + _b12x_dsv4_mla_bf16_qk_enabled() + and _b12x_dsv4_mla_needs_bf16_qk( + is_draft_worker=is_draft_worker, compress_ratio=compress_ratio + ) + ), + ) + _B12X_DSV4_MLA_DECODE_HITS += 1 + ref_out = None + if ( + _b12x_dsv4_mla_shadow_compare_enabled() + and flashmla_metadata is not None + and not _b12x_dsv4_mla_capturing() + ): + with _B12X_DSV4_MLA_SHADOW_COMPARE_LOCK: + should_compare = ( + _B12X_DSV4_MLA_SHADOW_COMPARE_COUNT + < _b12x_dsv4_mla_shadow_compare_max() + ) + if should_compare: + _B12X_DSV4_MLA_SHADOW_COMPARE_COUNT += 1 + compare_id = _B12X_DSV4_MLA_SHADOW_COMPARE_COUNT + else: + compare_id = -1 + if should_compare: + try: + import flash_mla + + ref_tuple = flash_mla.flash_mla_with_kvcache( + q=q, + k_cache=ref_k_cache, + head_dim_v=int(head_dim_v), + block_table=None, + cache_seqlens=None, + tile_scheduler_metadata=flashmla_metadata, + softmax_scale=float(softmax_scale), + is_fp8_kvcache=True, + indices=indices, + topk_length=topk_length, + attn_sink=attn_sink, + extra_k_cache=ref_extra_k_cache, + extra_indices_in_kvcache=extra_indices, + extra_topk_length=extra_topk_lengths, + ) + ref_out = ref_tuple[0] + b12x_cmp = out.squeeze(1) + ref_cmp = ref_out.squeeze(1) + forward_context = _b12x_shadow_forward_context( + forward_batch, q_rows=int(b12x_cmp.shape[0]) + ) + valid_rows = int(forward_context["valid_q_rows_estimate"]) + worst_rows = _b12x_dsv4_shadow_worst_rows() + record = { + "pid": os.getpid(), + "rank": os.environ.get("RANK"), + "local_rank": os.environ.get("LOCAL_RANK"), + "time": time.time(), + "compare_id": int(compare_id), + "forward_mode": _b12x_dsv4_mla_forward_mode(forward_batch), + "compress_ratio": int(compress_ratio), + "is_draft_worker": bool(is_draft_worker), + "capturing": False, + "capture_state": _b12x_dsv4_mla_capture_state(), + "forward_context": forward_context, + "q_shape": _b12x_tensor_shape(q), + "q_stride": _b12x_tensor_stride(q), + "q_dtype": _b12x_tensor_dtype(q), + "q_stats": _b12x_scalar_stats(q), + "q_valid_stats": _b12x_scalar_stats(q[:valid_rows]), + "q_padding_stats": ( + _b12x_scalar_stats(q[valid_rows:]) + if valid_rows < int(q.shape[0]) + else None + ), + "k_shape": _b12x_tensor_shape(ref_k_cache), + "k_stride": _b12x_tensor_stride(ref_k_cache), + "k_dtype": _b12x_tensor_dtype(ref_k_cache), + "indices_shape": _b12x_tensor_shape(indices), + "indices_stride": _b12x_tensor_stride(indices), + "indices_stats": _b12x_scalar_stats(indices), + "topk_length_shape": _b12x_tensor_shape(topk_length), + "topk_length_stats": _b12x_scalar_stats(topk_length), + "extra_k_shape": _b12x_tensor_shape(ref_extra_k_cache), + "extra_k_stride": _b12x_tensor_stride(ref_extra_k_cache), + "extra_indices_shape": _b12x_tensor_shape(extra_indices), + "extra_indices_stats": _b12x_scalar_stats(extra_indices), + "extra_topk_length_shape": _b12x_tensor_shape(extra_topk_lengths), + "extra_topk_length_stats": _b12x_scalar_stats(extra_topk_lengths), + "output_shape": _b12x_tensor_shape(b12x_cmp), + "output_compare": _b12x_output_compare_stats(b12x_cmp, ref_cmp), + "output_compare_rows": _b12x_output_compare_row_stats( + b12x_cmp, ref_cmp, 
max_worst_rows=worst_rows + ), + "output_compare_valid": _b12x_output_compare_stats( + b12x_cmp[:valid_rows], ref_cmp[:valid_rows] + ), + "output_compare_valid_rows": _b12x_output_compare_row_stats( + b12x_cmp[:valid_rows], + ref_cmp[:valid_rows], + max_worst_rows=worst_rows, + ), + } + if valid_rows < int(b12x_cmp.shape[0]): + record["output_compare_padding"] = _b12x_output_compare_stats( + b12x_cmp[valid_rows:], ref_cmp[valid_rows:] + ) + record["output_compare_padding_rows"] = ( + _b12x_output_compare_row_stats( + b12x_cmp[valid_rows:], + ref_cmp[valid_rows:], + max_worst_rows=worst_rows, + ) + ) + if len(ref_tuple) > 1 and ref_tuple[1] is not None: + try: + ref_lse = ref_tuple[1] + record["lse_shape"] = _b12x_tensor_shape(b12x_lse) + record["ref_lse_shape"] = _b12x_tensor_shape(ref_lse) + record["lse_compare"] = _b12x_output_compare_stats( + b12x_lse.reshape(-1), + ref_lse.reshape(-1), + ) + record["lse_compare_valid"] = _b12x_output_compare_stats( + b12x_lse[:valid_rows].reshape(-1), + ref_lse[:valid_rows].reshape(-1), + ) + record["lse_compare_valid_rows"] = ( + _b12x_output_compare_row_stats( + b12x_lse[:valid_rows], + ref_lse[:valid_rows], + max_worst_rows=worst_rows, + ) + ) + if valid_rows < int(b12x_lse.shape[0]): + record["lse_compare_padding"] = _b12x_output_compare_stats( + b12x_lse[valid_rows:].reshape(-1), + ref_lse[valid_rows:].reshape(-1), + ) + except Exception as exc: # noqa: BLE001 + record["lse_compare_error"] = repr(exc) + _append_b12x_shadow_compare_record(record) + except Exception as exc: # noqa: BLE001 + _append_b12x_shadow_compare_record( + { + "pid": os.getpid(), + "rank": os.environ.get("RANK"), + "local_rank": os.environ.get("LOCAL_RANK"), + "time": time.time(), + "compare_id": int(compare_id), + "error": repr(exc), + "forward_mode": _b12x_dsv4_mla_forward_mode(forward_batch), + "compress_ratio": int(compress_ratio), + "is_draft_worker": bool(is_draft_worker), + "q_shape": _b12x_tensor_shape(q), + "k_shape": _b12x_tensor_shape(ref_k_cache), + "indices_shape": _b12x_tensor_shape(indices), + } + ) + if _b12x_dsv4_mla_verbose() and _B12X_DSV4_MLA_DECODE_LOGS < 16: + _B12X_DSV4_MLA_DECODE_LOGS += 1 + logger.info( + "b12x DSv4 MLA decode direct: rows=%s heads=%s indices=%s extra=%s", + q.shape[0], + q.shape[2], + tuple(indices.shape), + None if extra_indices is None else tuple(extra_indices.shape), + ) + if ref_out is not None and _b12x_dsv4_mla_shadow_return_reference(): + return ref_out.squeeze(1) + return out.squeeze(1) + + def _pad_last_dim(x: T, multiples_of: int = PAGE_INDEX_ALIGNED_SIZE) -> T: if x is None: return None @@ -361,12 +2165,14 @@ def __init__( model_runner.server_args.speculative_num_draft_tokens ) self.speculative_step_id = speculative_step_id + self.is_draft_worker = bool(getattr(model_runner, "is_draft_worker", False)) self.forward_metadata: Union[ DSV4Metadata, DSV4RawVerifyMetadata, DSV4RawDecodeMetadata, ] = None self._replay_forward_batch: Optional[ForwardBatch] = None # FIXME: out-of-band + self._rows4_prefill_workspace: Dict[Any, _Rows4PrefillWorkspace] = {} def _move_to_device(self, x: List[int]) -> torch.Tensor: pin_tensor = torch.tensor(x, dtype=torch.int32, pin_memory=True) @@ -386,9 +2192,43 @@ def init_forward_metadata_decode( seq_lens: torch.Tensor, out_cache_loc: torch.Tensor, ) -> Union[DSV4Metadata, DSV4RawDecodeMetadata]: + global _DSV4_MTP_METADATA_MISMATCH_LOGS + if ( + ( + req_pool_indices.shape[0] != seq_lens.shape[0] + or req_pool_indices.shape[0] != out_cache_loc.shape[0] + ) + and _DSV4_MTP_METADATA_MISMATCH_LOGS < 16 
+ ): + _DSV4_MTP_METADATA_MISMATCH_LOGS += 1 + logger.warning( + "DSv4 decode metadata shape mismatch: " + "req_pool_indices=%s seq_lens=%s out_cache_loc=%s " + "is_draft_worker=%s speculative_step_id=%s " + "speculative_num_steps=%s topk=%s seq_lens_head=%s " + "out_cache_loc_head=%s", + tuple(req_pool_indices.shape), + tuple(seq_lens.shape), + tuple(out_cache_loc.shape), + self.is_draft_worker, + self.speculative_step_id, + self.speculative_num_steps, + self.topk, + seq_lens[: min(8, seq_lens.numel())].detach().cpu().tolist(), + out_cache_loc[: min(16, out_cache_loc.numel())] + .detach() + .cpu() + .tolist(), + ) assert ( req_pool_indices.shape[0] == seq_lens.shape[0] == out_cache_loc.shape[0] - ), f"{req_pool_indices.shape=} {seq_lens.shape=} {out_cache_loc.shape=}" + ), ( + f"{req_pool_indices.shape=} {seq_lens.shape=} {out_cache_loc.shape=}" + f" is_draft_worker={self.is_draft_worker}" + f" speculative_step_id={self.speculative_step_id}" + f" speculative_num_steps={self.speculative_num_steps}" + f" topk={self.topk}" + ) if envs.SGLANG_PREP_IN_CUDA_GRAPH.get(): return DSV4RawDecodeMetadata( @@ -714,6 +2554,7 @@ def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int) -> None: self.draft_extend_num_tokens_per_bs = ( max_num_tokens // max_bs if max_bs > 0 else 1 ) + self._rows4_prefill_workspace = {} def init_forward_metadata_capture_cuda_graph( self, @@ -843,12 +2684,21 @@ def init_forward_metadata_replay_cuda_graph( ) elif bucket == _GraphBucket.DRAFT_EXTEND: num_tokens_per_bs = self.draft_extend_num_tokens_per_bs + assert out_cache_loc is not None + num_tokens = num_tokens_per_bs * bs + out_cache_loc_padded = torch.nn.functional.pad( + out_cache_loc, + pad=(0, num_tokens - len(out_cache_loc)), + mode="constant", + value=0, + ) temp_metadata = self.init_forward_metadata_draft_extend( max_seq_len=chosen_max_seq_len, req_pool_indices=req_pool_indices, seq_lens=seq_lens, seq_lens_cpu=seq_lens_cpu.tolist(), num_tokens_per_bs=num_tokens_per_bs, + out_cache_loc=out_cache_loc_padded, use_prefill_cuda_graph=True, ) else: @@ -896,6 +2746,12 @@ def store_cache( ) -> None: raw_loc = forward_batch.out_cache_loc if envs.SGLANG_OPT_USE_FUSED_STORE_CACHE.get(): + _trace_dsv4_store_cache( + layer_id=layer_id, + forward_batch=forward_batch, + raw_loc=raw_loc, + swa_k=swa_k, + ) self.token_to_kv_pool.set_swa_key_buffer_radix_fused( layer_id=layer_id, raw_loc=raw_loc, @@ -903,6 +2759,13 @@ def store_cache( ) else: swa_k_pack = quant_to_nope_fp8_rope_bf16_pack_triton(swa_k) + _trace_dsv4_store_cache( + layer_id=layer_id, + forward_batch=forward_batch, + raw_loc=raw_loc, + swa_k=swa_k, + packed=swa_k_pack, + ) self.token_to_kv_pool.set_swa_key_buffer_radix( layer_id=layer_id, raw_loc=raw_loc, @@ -927,6 +2790,349 @@ def _maybe_upgrade_forward_metadata(self) -> None: raw_metadata=self.forward_metadata, ) + def _get_rows4_prefill_workspace( + self, + *, + key: Any, + rows_capacity: int, + width: int, + extra_width: int, + q: torch.Tensor, + capturing: bool, + needs_q_ws: bool, + ) -> Optional[_Rows4PrefillWorkspace]: + ws = self._rows4_prefill_workspace.get(key) + if ( + ws is not None + and ws.rows_capacity >= rows_capacity + and ws.width == width + and ws.extra_width == extra_width + and ((ws.q_ws is not None) == needs_q_ws) + and ws.output.dtype == q.dtype + and ws.output.device == q.device + ): + return ws + if capturing: + return None + groups = rows_capacity // 4 + head_dim_v = int(q.shape[-1]) + heads = int(q.shape[2]) + q_elem = q.element_size() + workspace_bytes = ( + (2 if needs_q_ws else 
1) * rows_capacity * heads * head_dim_v * q_elem + + rows_capacity * width * 4 + + rows_capacity * 4 + + 2 * groups * 4 * width * 4 + + groups * 4 + + 4 * rows_capacity * heads * 4 + + rows_capacity * heads * q_elem + + 1 * heads * 4 + ) + if extra_width > 0: + workspace_bytes += ( + rows_capacity * extra_width * 4 + + rows_capacity * 4 + + 2 * groups * 4 * extra_width * 4 + + groups * 4 + ) + try: + free_bytes, _ = torch.cuda.mem_get_info(q.device) + except Exception: # noqa: BLE001 + free_bytes = workspace_bytes * 2 + headroom = max(256 * 1024 * 1024, workspace_bytes // 8) + if free_bytes < workspace_bytes + headroom: + _record_b12x_dsv4_prefill_rows4_route("fallback_workspace_low_mem") + return None + try: + q_ws = ( + torch.zeros( + (rows_capacity, heads, head_dim_v), dtype=q.dtype, device=q.device + ) + if needs_q_ws + else None + ) + page_ws = torch.zeros((rows_capacity, width), dtype=torch.int32, device=q.device) + active_ws = torch.zeros((rows_capacity,), dtype=torch.int32, device=q.device) + union_indices = torch.empty( + (groups, 4 * width), dtype=torch.int32, device=q.device + ) + union_masks = torch.empty((groups, 4 * width), dtype=torch.int32, device=q.device) + union_counts = torch.empty((groups,), dtype=torch.int32, device=q.device) + output = torch.empty( + (rows_capacity, heads, head_dim_v), dtype=q.dtype, device=q.device + ) + lse_log2 = torch.empty( + (rows_capacity, heads), dtype=torch.float32, device=q.device + ) + lse = torch.empty((rows_capacity, heads), dtype=torch.float32, device=q.device) + sink_f32 = torch.empty((1, heads), dtype=torch.float32, device=q.device) + sink_scratch = torch.empty( + (rows_capacity, heads), dtype=torch.float32, device=q.device + ) + sink_scale = torch.empty( + (rows_capacity, heads), dtype=torch.float32, device=q.device + ) + sink_scale_cast = torch.empty( + (rows_capacity, heads), dtype=q.dtype, device=q.device + ) + ws = _Rows4PrefillWorkspace( + rows_capacity=rows_capacity, + width=width, + extra_width=extra_width, + q_ws=q_ws, + page_ws=page_ws, + active_ws=active_ws, + union_indices=union_indices, + union_masks=union_masks, + union_counts=union_counts, + output=output, + lse_log2=lse_log2, + lse=lse, + sink_f32=sink_f32, + sink_scratch=sink_scratch, + sink_scale=sink_scale, + sink_scale_cast=sink_scale_cast, + ) + if extra_width > 0: + ws.extra_page_ws = torch.zeros( + (rows_capacity, extra_width), dtype=torch.int32, device=q.device + ) + ws.extra_active_ws = torch.zeros( + (rows_capacity,), dtype=torch.int32, device=q.device + ) + ws.extra_union_indices = torch.empty( + (groups, 4 * extra_width), dtype=torch.int32, device=q.device + ) + ws.extra_union_masks = torch.empty( + (groups, 4 * extra_width), dtype=torch.int32, device=q.device + ) + ws.extra_union_counts = torch.empty( + (groups,), dtype=torch.int32, device=q.device + ) + except torch.OutOfMemoryError: + _record_b12x_dsv4_prefill_rows4_route("fallback_workspace_oom") + torch.cuda.empty_cache() + return None + self._rows4_prefill_workspace[key] = ws + return ws + + def _try_run_b12x_rows4_prefill_direct( + self, + *, + q: torch.Tensor, + swa_k_cache: torch.Tensor, + swa_page_indices: torch.Tensor, + swa_topk_lengths: torch.Tensor, + attn_sink: Optional[torch.Tensor], + compress_ratio: Literal[0, 4, 128], + forward_batch: ForwardBatch, + extra_k_cache: Optional[torch.Tensor], + extra_indices: Optional[torch.Tensor], + extra_topk_lengths: Optional[torch.Tensor], + ) -> Optional[torch.Tensor]: + if not _b12x_dsv4_prefill_rows4_enabled(): + return None + if not 
forward_batch.forward_mode.is_extend_without_speculative(): + return None + if ( + q.ndim != 4 + or q.shape[1] != 1 + or q.shape[2] != 64 + or q.shape[3] != 512 + or attn_sink is None + ): + _record_b12x_dsv4_prefill_rows4_route("fallback_shape") + return None + + if swa_page_indices.dtype != torch.int32 or not swa_page_indices.is_cuda: + _record_b12x_dsv4_prefill_rows4_route("fallback_main_indices_dtype") + return None + if swa_topk_lengths.dtype != torch.int32 or not swa_topk_lengths.is_cuda: + _record_b12x_dsv4_prefill_rows4_route("fallback_main_topk_dtype") + return None + if swa_page_indices.ndim != 3 or swa_page_indices.shape[1] != 1: + _record_b12x_dsv4_prefill_rows4_route("fallback_main_indices_shape") + return None + + raw_rows = int(q.shape[0]) + if swa_topk_lengths.ndim != 1 or swa_topk_lengths.shape[0] != raw_rows: + _record_b12x_dsv4_prefill_rows4_route("fallback_bad_active_shape") + return None + + page_2d = swa_page_indices[:, 0] + active_1d = swa_topk_lengths + if not page_2d.is_contiguous() or not active_1d.is_contiguous(): + _record_b12x_dsv4_prefill_rows4_route("fallback_main_noncontiguous") + return None + width = int(page_2d.shape[1]) + if width <= 0 or width > 512: + _record_b12x_dsv4_prefill_rows4_route("fallback_main_width") + return None + + extra_width = 0 + extra_page_2d = None + extra_active_1d = None + if compress_ratio != 0: + if not _b12x_dsv4_mla_allow_extra_tier(): + _record_b12x_dsv4_prefill_rows4_route("fallback_extra_disabled") + return None + if extra_k_cache is None or extra_indices is None or extra_topk_lengths is None: + _record_b12x_dsv4_prefill_rows4_route("fallback_extra_missing") + return None + if extra_indices.dtype != torch.int32 or not extra_indices.is_cuda: + _record_b12x_dsv4_prefill_rows4_route("fallback_extra_indices_dtype") + return None + if extra_topk_lengths.dtype != torch.int32 or not extra_topk_lengths.is_cuda: + _record_b12x_dsv4_prefill_rows4_route("fallback_extra_topk_dtype") + return None + if extra_indices.ndim != 3 or extra_indices.shape[1] != 1: + _record_b12x_dsv4_prefill_rows4_route("fallback_extra_indices_shape") + return None + if extra_topk_lengths.ndim != 1 or extra_topk_lengths.shape[0] != raw_rows: + _record_b12x_dsv4_prefill_rows4_route("fallback_extra_topk_shape") + return None + extra_page_2d = extra_indices[:, 0] + extra_active_1d = extra_topk_lengths + extra_width = int(extra_page_2d.shape[1]) + if extra_width <= 0 or extra_width > 512: + _record_b12x_dsv4_prefill_rows4_route("fallback_extra_width") + return None + if not extra_page_2d.is_contiguous() or not extra_active_1d.is_contiguous(): + _record_b12x_dsv4_prefill_rows4_route("fallback_extra_noncontiguous") + return None + + try: + from b12x.attention.mla.kernel_dsv4_prefill import ( + run_sparse_mla_prefill_rows4_heads16_two_tier_union_kernel, + run_sparse_mla_prefill_rows4_heads16_union_kernel, + ) + from b12x.attention.mla.prefill_union import ( + build_sparse_mla_prefill_rows4_union_tables_cuda_into, + ) + except Exception as exc: # noqa: BLE001 + _record_b12x_dsv4_prefill_rows4_route( + f"fallback_import_{type(exc).__name__}" + ) + return None + + capturing = _b12x_dsv4_mla_capturing() + live_rows = int( + forward_batch.extend_num_tokens + if forward_batch.extend_num_tokens is not None + else raw_rows + ) + live_rows = max(0, min(live_rows, raw_rows)) + padded_rows = ceil_align(raw_rows, 4) + q_direct = q[:, 0] + use_direct_q = ( + live_rows == raw_rows + and padded_rows == raw_rows + and q_direct.is_contiguous() + ) + ws = 
self._get_rows4_prefill_workspace( + key=( + "rows4_prefill", + forward_batch.forward_mode, + compress_ratio, + width, + extra_width, + padded_rows, + "direct_q" if use_direct_q else "q_ws", + ), + rows_capacity=padded_rows, + width=width, + extra_width=extra_width, + q=q, + capturing=capturing, + needs_q_ws=not use_direct_q, + ) + if ws is None: + _record_b12x_dsv4_prefill_rows4_route( + "fallback_workspace_capture_miss" + if capturing + else "fallback_workspace_unavailable" + ) + return None + + try: + ws.active_ws.zero_() + if use_direct_q: + q_all = q_direct + else: + assert ws.q_ws is not None + ws.q_ws[:live_rows].copy_(q[:live_rows, 0]) + q_all = ws.q_ws + ws.page_ws[:live_rows].copy_(page_2d[:live_rows]) + ws.active_ws[:live_rows].copy_(active_1d[:live_rows]) + build_sparse_mla_prefill_rows4_union_tables_cuda_into( + page_table_1=ws.page_ws, + active_token_counts=ws.active_ws, + union_indices=ws.union_indices, + union_masks=ws.union_masks, + union_counts=ws.union_counts, + ) + if compress_ratio == 0: + run_sparse_mla_prefill_rows4_heads16_union_kernel( + q_all=q_all, + kv_cache=swa_k_cache, + union_indices=ws.union_indices, + union_masks=ws.union_masks, + union_counts=ws.union_counts, + sm_scale=float(self.softmax_scale), + output=ws.output, + lse_output=ws.lse_log2, + ) + else: + assert extra_page_2d is not None + assert extra_active_1d is not None + assert extra_k_cache is not None + assert ws.extra_page_ws is not None + assert ws.extra_active_ws is not None + assert ws.extra_union_indices is not None + assert ws.extra_union_masks is not None + assert ws.extra_union_counts is not None + ws.extra_active_ws.zero_() + ws.extra_page_ws[:live_rows].copy_(extra_page_2d[:live_rows]) + ws.extra_active_ws[:live_rows].copy_(extra_active_1d[:live_rows]) + build_sparse_mla_prefill_rows4_union_tables_cuda_into( + page_table_1=ws.extra_page_ws, + active_token_counts=ws.extra_active_ws, + union_indices=ws.extra_union_indices, + union_masks=ws.extra_union_masks, + union_counts=ws.extra_union_counts, + ) + run_sparse_mla_prefill_rows4_heads16_two_tier_union_kernel( + q_all=q_all, + kv_cache=swa_k_cache, + union_indices=ws.union_indices, + union_masks=ws.union_masks, + union_counts=ws.union_counts, + extra_kv_cache=extra_k_cache, + extra_union_indices=ws.extra_union_indices, + extra_union_masks=ws.extra_union_masks, + extra_union_counts=ws.extra_union_counts, + sm_scale=float(self.softmax_scale), + output=ws.output, + lse_output=ws.lse_log2, + ) + ws.sink_f32.copy_(attn_sink[: q.shape[2]].reshape(1, q.shape[2])) + torch.mul(ws.lse_log2, _B12X_DSV4_PREFILL_ROWS4_LN2, out=ws.lse) + torch.sub(ws.lse, ws.sink_f32, out=ws.sink_scratch) + torch.sigmoid(ws.sink_scratch, out=ws.sink_scale) + ws.sink_scale_cast.copy_(ws.sink_scale) + ws.output.mul_(ws.sink_scale_cast.unsqueeze(-1)) + torch.logaddexp(ws.lse, ws.sink_f32, out=ws.lse) + except Exception as exc: # noqa: BLE001 + _record_b12x_dsv4_prefill_rows4_route( + f"fallback_exception_{type(exc).__name__}" + ) + return None + + _record_b12x_dsv4_prefill_rows4_route(f"rows4_kernel_cr{compress_ratio}") + if use_direct_q: + _record_b12x_dsv4_prefill_rows4_route(f"rows4_direct_q_cr{compress_ratio}") + return ws.output[:raw_rows] + def forward( self, q: torch.Tensor, @@ -954,7 +3160,9 @@ def forward( assert isinstance(token_to_kv_pool, DeepSeekV4TokenToKVPool) if isinstance(core_attn_metadata, DSV4AttnMetadata): - if save_kv_cache: + if save_kv_cache and not getattr( + forward_batch, "_dsv4_eagle_recompute_no_kv_store", False + ): self.store_cache(layer_id, 
swa_k, forward_batch) swa_k_cache = token_to_kv_pool.get_swa_key_buffer_radix(layer_id) @@ -992,6 +3200,8 @@ def forward( swa_topk_lengths = core_attn_metadata.swa_topk_lengths if self.mtp_enabled: + indices_before_rows = swa_page_indices.shape[0] + topk_before_rows = swa_topk_lengths.shape[0] if swa_page_indices.shape[0] != q.shape[0]: swa_page_indices = _pad_tensor_to_size( swa_page_indices, q.shape[0], value=0 @@ -1001,6 +3211,16 @@ def forward( swa_topk_lengths = _pad_tensor_to_size( swa_topk_lengths, q.shape[0], value=1 ) + _trace_dsv4_mtp_padding( + layer_id=layer_id, + compress_ratio=compress_ratio, + forward_batch=forward_batch, + q_rows=q.shape[0], + indices_before_rows=indices_before_rows, + topk_before_rows=topk_before_rows, + swa_page_indices=swa_page_indices, + swa_topk_lengths=swa_topk_lengths, + ) if q.ndim == 3: q = q.unsqueeze(1) @@ -1008,6 +3228,13 @@ def forward( swa_page_indices = swa_page_indices.unsqueeze(1) if extra_indices is not None and extra_indices.ndim == 2: extra_indices = extra_indices.unsqueeze(1) + extra_topk_lengths = _maybe_zero_empty_extra_topk( + layer_id=layer_id, + compress_ratio=compress_ratio, + forward_batch=forward_batch, + extra_indices=extra_indices, + extra_topk_lengths=extra_topk_lengths, + ) assert attn_sink is not None @@ -1021,6 +3248,69 @@ def forward( extra_indices.shape[-1] % 64 == 0 ), f"{extra_indices.shape=}'s last dimension is not aligned to 64" + _trace_dsv4_attn_inputs( + layer_id=layer_id, + forward_batch=forward_batch, + compress_ratio=compress_ratio, + softmax_scale=self.softmax_scale, + q=q, + k_cache=swa_k_cache, + indices=swa_page_indices, + topk_length=swa_topk_lengths, + extra_k_cache=extra_k_cache, + extra_indices=extra_indices, + extra_topk_length=extra_topk_lengths, + ) + + if _should_run_b12x_dsv4_mla_decode( + q=q, + k_cache=swa_k_cache, + indices=swa_page_indices, + extra_k_cache=extra_k_cache, + extra_indices=extra_indices, + extra_topk_lengths=extra_topk_lengths, + head_dim_v=self.head_dim_v, + is_fp8_kvcache=True, + compress_ratio=compress_ratio, + forward_batch=forward_batch, + is_draft_worker=self.is_draft_worker, + ): + return _run_b12x_dsv4_mla_decode( + q=q, + k_cache=swa_k_cache, + indices=swa_page_indices, + topk_length=swa_topk_lengths, + attn_sink=attn_sink, + softmax_scale=self.softmax_scale, + head_dim_v=self.head_dim_v, + extra_k_cache=extra_k_cache, + extra_indices=extra_indices, + extra_topk_lengths=extra_topk_lengths, + flashmla_metadata=flashmla_metadata, + forward_batch=forward_batch, + compress_ratio=compress_ratio, + is_draft_worker=self.is_draft_worker, + ) + + rows4_out = self._try_run_b12x_rows4_prefill_direct( + q=q, + swa_k_cache=swa_k_cache, + swa_page_indices=swa_page_indices, + swa_topk_lengths=swa_topk_lengths, + attn_sink=attn_sink, + compress_ratio=compress_ratio, + forward_batch=forward_batch, + extra_k_cache=extra_k_cache, + extra_indices=extra_indices, + extra_topk_lengths=extra_topk_lengths, + ) + if rows4_out is not None: + return rows4_out + + if _b12x_dsv4_prefill_rows4_enabled() and forward_batch.forward_mode.is_prefill( + include_draft_extend_v2=True + ): + _record_b12x_dsv4_prefill_rows4_route("fallback_flashmla") import flash_mla o = flash_mla.flash_mla_with_kvcache( @@ -1177,7 +3467,7 @@ def __init__( self.topk = topk self.speculative_num_steps = speculative_num_steps self.attn_backends: List[DeepseekV4AttnBackend] = [] - for i in range(self.speculative_num_steps): + for i in range(self.speculative_num_steps - 1): self.attn_backends.append( DeepseekV4AttnBackend( 
model_runner, @@ -1187,21 +3477,145 @@ def __init__( ) ) + @staticmethod + def _step_seq_lens(seq_lens: torch.Tensor, step_id: int) -> torch.Tensor: + return seq_lens + (step_id + 1) + + @staticmethod + def _step_seq_lens_cpu(seq_lens_cpu: object, step_id: int) -> object: + delta = step_id + 1 + if seq_lens_cpu is None: + return None + if isinstance(seq_lens_cpu, torch.Tensor): + return seq_lens_cpu + delta + return [int(x) + delta for x in seq_lens_cpu] + + def _step_out_cache_locs(self, forward_batch: ForwardBatch) -> Optional[torch.Tensor]: + out_cache_loc = forward_batch.out_cache_loc + if out_cache_loc is None: + return None + batch_size = int(forward_batch.batch_size) + expected = batch_size * self.topk * self.speculative_num_steps + if out_cache_loc.numel() < expected: + raise ValueError( + "DSv4 MTP draft metadata expected at least " + f"{expected} out_cache_loc entries for batch_size={batch_size}, " + f"topk={self.topk}, speculative_num_steps={self.speculative_num_steps}, " + f"got {out_cache_loc.numel()}" + ) + return ( + out_cache_loc[:expected] + .reshape(batch_size, self.topk, self.speculative_num_steps) + .permute(2, 0, 1) + .reshape(self.speculative_num_steps, batch_size * self.topk) + .contiguous() + ) + + def _needs_step_metadata_split(self, forward_batch: ForwardBatch) -> bool: + out_cache_loc = forward_batch.out_cache_loc + if out_cache_loc is None: + return False + batch_size = int(forward_batch.batch_size) + if batch_size <= 0: + return False + expected = batch_size * self.topk * self.speculative_num_steps + return ( + int(out_cache_loc.numel()) >= expected + and int(out_cache_loc.numel()) != batch_size + ) + + def _log_step_metadata( + self, + *, + forward_batch: ForwardBatch, + step_out_cache_locs: Optional[torch.Tensor], + ) -> None: + global _DSV4_MTP_METADATA_STEP_LOGS + if _DSV4_MTP_METADATA_STEP_LOGS >= 8: + return + _DSV4_MTP_METADATA_STEP_LOGS += 1 + out_shape = ( + None if forward_batch.out_cache_loc is None else tuple(forward_batch.out_cache_loc.shape) + ) + step_shape = None if step_out_cache_locs is None else tuple(step_out_cache_locs.shape) + logger.info( + "DSv4 MTP draft metadata split: batch_size=%s topk=%s " + "speculative_num_steps=%s original_out_cache_loc=%s step_out_cache_locs=%s " + "seq_lens_head=%s first_step_out_cache_loc_head=%s", + int(forward_batch.batch_size), + self.topk, + self.speculative_num_steps, + out_shape, + step_shape, + forward_batch.seq_lens[: min(8, forward_batch.seq_lens.numel())] + .detach() + .cpu() + .tolist(), + ( + None + if step_out_cache_locs is None + else step_out_cache_locs[0, : min(16, step_out_cache_locs.shape[1])] + .detach() + .cpu() + .tolist() + ), + ) + def init_forward_metadata(self, forward_batch: ForwardBatch): - for i in range(self.speculative_num_steps - 1): - self.attn_backends[i].init_forward_metadata(forward_batch) + if ( + not _dsv4_mtp_step_metadata_enabled() + and not self._needs_step_metadata_split(forward_batch) + ): + for i in range(self.speculative_num_steps - 1): + self.attn_backends[i].init_forward_metadata(forward_batch) + return + + step_out_cache_locs = self._step_out_cache_locs(forward_batch) + self._log_step_metadata( + forward_batch=forward_batch, + step_out_cache_locs=step_out_cache_locs, + ) + original_out_cache_loc = forward_batch.out_cache_loc + original_seq_lens = forward_batch.seq_lens + original_seq_lens_cpu = forward_batch.seq_lens_cpu + try: + for i in range(self.speculative_num_steps - 1): + if step_out_cache_locs is not None: + forward_batch.out_cache_loc = 
step_out_cache_locs[i] + forward_batch.seq_lens = self._step_seq_lens(original_seq_lens, i) + forward_batch.seq_lens_cpu = self._step_seq_lens_cpu( + original_seq_lens_cpu, i + ) + self.attn_backends[i].init_forward_metadata(forward_batch) + finally: + forward_batch.out_cache_loc = original_out_cache_loc + forward_batch.seq_lens = original_seq_lens + forward_batch.seq_lens_cpu = original_seq_lens_cpu def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int): - for i in range(self.speculative_num_steps): + for i in range(self.speculative_num_steps - 1): self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens) def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch): - for i in range(self.speculative_num_steps): + if not _dsv4_mtp_step_metadata_enabled(): + for backend in self.attn_backends: + backend.init_forward_metadata_capture_cuda_graph( + forward_batch.batch_size, + forward_batch.batch_size * self.topk, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + encoder_lens=None, + forward_mode=ForwardMode.DECODE, + spec_info=forward_batch.spec_info, + ) + return + + for i in range(self.speculative_num_steps - 1): self.attn_backends[i].init_forward_metadata_capture_cuda_graph( forward_batch.batch_size, forward_batch.batch_size * self.topk, forward_batch.req_pool_indices, - forward_batch.seq_lens, + self._step_seq_lens(forward_batch.seq_lens, i), encoder_lens=None, forward_mode=ForwardMode.DECODE, spec_info=forward_batch.spec_info, @@ -1217,26 +3631,56 @@ def init_forward_metadata_replay_cuda_graph( if self.speculative_num_steps == 1: return - self.attn_backends[0]._replay_forward_batch = forward_batch - self.attn_backends[0].init_forward_metadata_replay_cuda_graph( - bs=bs, - req_pool_indices=forward_batch.req_pool_indices, - seq_lens=forward_batch.seq_lens, - seq_lens_sum=forward_batch.seq_lens_sum, - encoder_lens=None, - forward_mode=ForwardMode.DECODE, - spec_info=forward_batch.spec_info, - seq_lens_cpu=forward_batch.seq_lens_cpu, - ) - self.attn_backends[0]._replay_forward_batch = None - temp_metadata = self.attn_backends[0].forward_metadata - - for i in range(1, self.speculative_num_steps - 1): - self.attn_backends[i].replay_cuda_graph_metadata_from( + if not _dsv4_mtp_step_metadata_enabled(): + self.attn_backends[0]._replay_forward_batch = forward_batch + self.attn_backends[0].init_forward_metadata_replay_cuda_graph( bs=bs, - temp_metadata=temp_metadata, - bucket=_GraphBucket.DECODE_OR_IDLE, + req_pool_indices=forward_batch.req_pool_indices, + seq_lens=forward_batch.seq_lens, + seq_lens_sum=forward_batch.seq_lens_sum, + encoder_lens=None, + forward_mode=ForwardMode.DECODE, + spec_info=forward_batch.spec_info, + seq_lens_cpu=forward_batch.seq_lens_cpu, ) + self.attn_backends[0]._replay_forward_batch = None + temp_metadata = self.attn_backends[0].forward_metadata + for i in range(1, self.speculative_num_steps - 1): + self.attn_backends[i].replay_cuda_graph_metadata_from( + bs=bs, + temp_metadata=temp_metadata, + bucket=_GraphBucket.DECODE_OR_IDLE, + ) + return + + step_out_cache_locs = self._step_out_cache_locs(forward_batch) + self._log_step_metadata( + forward_batch=forward_batch, + step_out_cache_locs=step_out_cache_locs, + ) + original_out_cache_loc = forward_batch.out_cache_loc + try: + for i in range(self.speculative_num_steps - 1): + if step_out_cache_locs is not None: + forward_batch.out_cache_loc = step_out_cache_locs[i] + step_seq_lens = self._step_seq_lens(forward_batch.seq_lens, i) + step_seq_lens_cpu = 
self._step_seq_lens_cpu(forward_batch.seq_lens_cpu, i) + self.attn_backends[i]._replay_forward_batch = forward_batch + self.attn_backends[i].init_forward_metadata_replay_cuda_graph( + bs=bs, + req_pool_indices=forward_batch.req_pool_indices, + seq_lens=step_seq_lens, + seq_lens_sum=forward_batch.seq_lens_sum, + encoder_lens=None, + forward_mode=ForwardMode.DECODE, + spec_info=forward_batch.spec_info, + seq_lens_cpu=step_seq_lens_cpu, + ) + self.attn_backends[i]._replay_forward_batch = None + finally: + forward_batch.out_cache_loc = original_out_cache_loc + for backend in self.attn_backends: + backend._replay_forward_batch = None def _pad_tensor_to_size(tensor: torch.Tensor, size: int, *, value: int = 0): diff --git a/python/sglang/srt/layers/attention/dsv4/indexer.py b/python/sglang/srt/layers/attention/dsv4/indexer.py index 3bc982446..bab5ac05f 100644 --- a/python/sglang/srt/layers/attention/dsv4/indexer.py +++ b/python/sglang/srt/layers/attention/dsv4/indexer.py @@ -93,6 +93,30 @@ def fp8_paged_mqa_logits_torch( return logits +def _normalize_indexer_seq_lens( + c4_seq_lens: torch.Tensor, + *, + expected_rows: int, +) -> torch.Tensor: + if c4_seq_lens.dim() == 2 and c4_seq_lens.shape[-1] == 1: + c4_seq_lens = c4_seq_lens.squeeze(-1) + if c4_seq_lens.dim() != 1: + raise ValueError( + "DSv4 C4 indexer seq_lens must be rank-1 or trailing-1 rank-2, " + f"got shape={tuple(c4_seq_lens.shape)}" + ) + if c4_seq_lens.shape[0] != expected_rows: + raise ValueError( + "DSv4 C4 indexer row mismatch: " + f"seq_lens_rows={c4_seq_lens.shape[0]} q_rows={expected_rows}" + ) + if c4_seq_lens.dtype != torch.int32: + c4_seq_lens = c4_seq_lens.to(torch.int32) + if not c4_seq_lens.is_contiguous(): + c4_seq_lens = c4_seq_lens.contiguous() + return c4_seq_lens + + def topk_transform_512_pytorch_vectorized( scores: torch.Tensor, seq_lens: torch.Tensor, @@ -377,6 +401,7 @@ def forward_c4_indexer( ) assert len(weights.shape) == 3 weights = weights.squeeze(2) + deep_gemm_logits_backend = False if envs.SGLANG_OPT_USE_TILELANG_INDEXER.get(): from sglang.srt.layers.attention.dsv4.tilelang_kernel import ( tilelang_fp8_paged_mqa_logits as fn, @@ -386,14 +411,32 @@ def forward_c4_indexer( else: from deep_gemm import fp8_paged_mqa_logits as fn - _c4sl = indexer_metadata.c4_seq_lens - if _c4sl.dim() == 1: - _c4sl = _c4sl.unsqueeze(-1) + deep_gemm_logits_backend = True + + num_indexer_rows = q_fp8.shape[0] + if weights.shape != (num_indexer_rows, q_fp8.shape[2]): + raise ValueError( + "DSv4 C4 indexer weight row mismatch: " + f"weights_shape={tuple(weights.shape)} q_shape={tuple(q_fp8.shape)}" + ) + if indexer_metadata.page_table.shape[0] != num_indexer_rows: + raise ValueError( + "DSv4 C4 indexer page-table row mismatch: " + f"page_table_rows={indexer_metadata.page_table.shape[0]} " + f"q_rows={num_indexer_rows}" + ) + c4_seq_lens = _normalize_indexer_seq_lens( + indexer_metadata.c4_seq_lens, + expected_rows=num_indexer_rows, + ) + c4_seq_lens_for_logits = ( + c4_seq_lens.unsqueeze(-1) if deep_gemm_logits_backend else c4_seq_lens + ) logits = fn( q_fp8, c4_indexer_kv_cache, weights, - _c4sl, + c4_seq_lens_for_logits, indexer_metadata.page_table, indexer_metadata.deep_gemm_metadata, indexer_metadata.max_c4_seq_len, @@ -423,7 +466,7 @@ def forward_c4_indexer( if envs.SGLANG_TOPK_TRANSFORM_512_TORCH.get(): topk_transform_512_pytorch_vectorized( logits, - indexer_metadata.c4_seq_lens, + c4_seq_lens, core_metadata.page_table, core_metadata.c4_sparse_page_indices, indexer_metadata.c4_page_size, @@ -432,7 +475,7 @@ def 
forward_c4_indexer( elif envs.SGLANG_OPT_USE_TOPK_V2.get() and raw_indices is None: topk_transform_512_v2( logits, - indexer_metadata.c4_seq_lens, + c4_seq_lens, core_metadata.page_table, core_metadata.c4_sparse_page_indices, indexer_metadata.c4_page_size, @@ -441,7 +484,7 @@ def forward_c4_indexer( else: topk_transform_512( logits, - indexer_metadata.c4_seq_lens, + c4_seq_lens, core_metadata.page_table, core_metadata.c4_sparse_page_indices, indexer_metadata.c4_page_size, @@ -455,7 +498,7 @@ def forward_c4_indexer( core_metadata.c4_sparse_page_indices = ( hisparse_coordinator.swap_in_selected_pages( req_pool_indices=forward_batch.req_pool_indices, - compressed_seq_lens=indexer_metadata.c4_seq_lens, + compressed_seq_lens=c4_seq_lens, top_k_result=raw_indices, layer_id=compress_layer_id, ) diff --git a/python/sglang/srt/layers/moe/hash_topk.py b/python/sglang/srt/layers/moe/hash_topk.py index 6b63b286a..0fa3ff661 100644 --- a/python/sglang/srt/layers/moe/hash_topk.py +++ b/python/sglang/srt/layers/moe/hash_topk.py @@ -35,6 +35,9 @@ def __init__( self.num_experts = num_experts self.topk = topk self.routed_scaling_factor = routed_scaling_factor + self.apply_routed_scaling_factor_on_output = ( + apply_routed_scaling_factor_on_output + ) self.num_fused_shared_experts = num_fused_shared_experts self.score_func = scoring_func self.tid2eid = nn.Parameter( @@ -42,8 +45,6 @@ def __init__( requires_grad=False, ) - assert not apply_routed_scaling_factor_on_output, "not implemented" - def empty_topk_output(self, device: torch.device): topk = self.topk - self.num_fused_shared_experts topk_weights = torch.empty((0, topk), dtype=torch.float32, device=device) @@ -94,6 +95,9 @@ def _forward_torch( if self.score_func != "softmax": topk_weights[:, :] /= topk_weights[:, :].sum(dim=-1, keepdim=True) + if self.apply_routed_scaling_factor_on_output: + topk_weights *= self.routed_scaling_factor + return topk_weights, topk_ids def forward( @@ -117,6 +121,9 @@ def forward( tid2eid=self.tid2eid, num_fused_shared_experts=self.num_fused_shared_experts, routed_scaling_factor=self.routed_scaling_factor, + apply_routed_scaling_factor_on_output=( + self.apply_routed_scaling_factor_on_output + ), scoring_func=self.score_func, ) else: diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index b5663e44b..a4850dd2a 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -1312,7 +1312,6 @@ def select_experts( scoring_func=scoring_func, ) elif custom_routing_function is None: - assert not apply_routed_scaling_factor_on_output, "Not implemented" if scoring_func == "sqrtsoftplus": _biased_topk = ( biased_topk_jit_kernel_impl @@ -1346,6 +1345,7 @@ def select_experts( renormalize=renormalize, ) else: + assert not apply_routed_scaling_factor_on_output, "Not implemented" # Qwen3MOE uses fused_topk topk_weights, topk_ids = fused_topk( hidden_states=hidden_states, diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 7b04407c6..3ce7cbe8d 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -2,6 +2,7 @@ from __future__ import annotations import logging +import os from enum import IntEnum from typing import TYPE_CHECKING, Any, Dict, List, Optional @@ -117,6 +118,138 @@ class ActivationType(IntEnum): logger = logging.getLogger(__name__) _B12X_MOE_WORKSPACE_POOLS: dict[int, Any] = {} +_B12X_FP4_GEMM_CALLS = 0 
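# [editor's sketch, not part of the patch] The audit globals declared here
# implement a shape-keyed call histogram plus a bounded recent-call list.
# A minimal self-contained model of the same bookkeeping, using hypothetical
# names and only the standard library:
from collections import Counter, deque

gemm_call_histogram: Counter = Counter()  # (M, N, K) -> hit count
recent_gemm_calls = deque(maxlen=32)      # deque drops the oldest entry itself

def record_gemm_call(m: int, n: int, k: int) -> None:
    key = (m, n, k)
    gemm_call_histogram[key] += 1  # per-shape histogram, O(1) per call
    recent_gemm_calls.append(key)  # bounded history, no manual trimming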
+_B12X_FP4_GEMM_SHAPES: dict[tuple[int, int, int], int] = {} +_B12X_FP4_GEMM_LOGGED_SHAPES: set[tuple[int, int, int]] = set() +_B12X_FP4_LINEAR_AUDIT: dict[str, dict[str, Any]] = {} +_B12X_FP4_GEMM_MODULE_CALLS: dict[str, dict[str, Any]] = {} +_B12X_FP4_GEMM_RECENT_CALLS: list[dict[str, Any]] = [] +_B12X_FP4_GEMM_RECENT_CALL_LIMIT = 32 + + +def _shape_list(tensor: torch.Tensor | None) -> list[int] | None: + if tensor is None: + return None + return [int(dim) for dim in tensor.shape] + + +def _is_current_stream_capturing() -> bool: + try: + return bool(torch.cuda.is_current_stream_capturing()) + except Exception: + return False + + +def _record_fp4_linear_audit( + prefix: str, + *, + layer: torch.nn.Module, + phase: str, + extra: dict[str, Any] | None = None, +) -> None: + if os.environ.get("B12X_FP4_GEMM_AUDIT", "1") == "0": + return + record = _B12X_FP4_LINEAR_AUDIT.setdefault( + prefix, + { + "prefix": prefix, + "quant_method": "ModelOptFp4LinearMethod", + }, + ) + record.update( + { + "phase": phase, + "input_size_per_partition": int( + getattr(layer, "input_size_per_partition", 0) + ), + "output_size_per_partition": int( + getattr(layer, "output_size_per_partition", 0) + ), + "weight_shape": _shape_list(getattr(layer, "weight", None)), + "weight_scale_shape": _shape_list(getattr(layer, "weight_scale", None)), + "weight_scale_interleaved_shape": _shape_list( + getattr(layer, "weight_scale_interleaved", None) + ), + "weights_padding_cols": int( + getattr(layer, "weights_padding_cols", 0) + ), + } + ) + if extra: + record.update(extra) + + +def _record_b12x_fp4_linear_call( + prefix: str, + *, + input: torch.Tensor, + weight: torch.Tensor, + input_sf: torch.Tensor, + weight_sf: torch.Tensor, + output_dtype: torch.dtype, + output_features: int, +) -> None: + if os.environ.get("B12X_FP4_GEMM_AUDIT", "1") == "0": + return + m = int(input.shape[0]) + k = int(input.shape[1]) * 2 + padded_n = int(weight.shape[0]) + original_n = int(output_features) + padded_m = ((m + 127) // 128) * 128 + shape_key = f"M={m},N={original_n},K={k}" + module_record = _B12X_FP4_GEMM_MODULE_CALLS.setdefault( + prefix, + { + "prefix": prefix, + "calls": 0, + "shapes": {}, + }, + ) + module_record["calls"] = int(module_record["calls"]) + 1 + module_record["shapes"][shape_key] = int( + module_record["shapes"].get(shape_key, 0) + ) + 1 + call = { + "prefix": prefix, + "shape": shape_key, + "padded_m": padded_m, + "padded_n": ((padded_n + 127) // 128) * 128, + "weight_rows": padded_n, + "output_features": original_n, + "input_sf_shape": _shape_list(input_sf), + "weight_sf_shape": _shape_list(weight_sf), + "input_dtype": str(input.dtype), + "weight_dtype": str(weight.dtype), + "output_dtype": str(output_dtype), + "capturing": _is_current_stream_capturing(), + } + module_record["last_call"] = call + _B12X_FP4_GEMM_RECENT_CALLS.append(call) + if len(_B12X_FP4_GEMM_RECENT_CALLS) > _B12X_FP4_GEMM_RECENT_CALL_LIMIT: + del _B12X_FP4_GEMM_RECENT_CALLS[ + : len(_B12X_FP4_GEMM_RECENT_CALLS) - _B12X_FP4_GEMM_RECENT_CALL_LIMIT + ] + + +def get_b12x_fp4_gemm_debug_counters() -> dict[str, Any]: + return { + "calls": int(_B12X_FP4_GEMM_CALLS), + "shapes": { + f"M={m},N={n},K={k}": int(count) + for (m, n, k), count in sorted(_B12X_FP4_GEMM_SHAPES.items()) + }, + "linear_audit": { + key: _B12X_FP4_LINEAR_AUDIT[key] + for key in sorted(_B12X_FP4_LINEAR_AUDIT) + }, + "module_calls": { + key: _B12X_FP4_GEMM_MODULE_CALLS[key] + for key in sorted(_B12X_FP4_GEMM_MODULE_CALLS) + }, + "recent_calls": list(_B12X_FP4_GEMM_RECENT_CALLS), + } + + def 
_get_b12x_workspace_pool(device: torch.device): """Return the process-local b12x workspace pool for this device. @@ -259,12 +392,30 @@ def _b12x_fp4_gemm( We infer the padded 2D shape from the total element count and reshape into the 6D MMA view b12x expects. """ + global _B12X_FP4_GEMM_CALLS + from b12x.gemm.dense import dense_gemm from b12x.quant.expert_fp4 import _as_grouped_scale_view M_orig = input.shape[0] K = input.shape[1] * 2 # FP4 packed N_orig = weight.shape[0] + shape_key = (int(M_orig), int(N_orig), int(K)) + _B12X_FP4_GEMM_CALLS += 1 + _B12X_FP4_GEMM_SHAPES[shape_key] = _B12X_FP4_GEMM_SHAPES.get(shape_key, 0) + 1 + if ( + os.environ.get("B12X_FP4_GEMM_TRACE", "0") != "0" + and shape_key not in _B12X_FP4_GEMM_LOGGED_SHAPES + ): + logger.warning( + "b12x dense FP4 GEMM active: M=%d N=%d K=%d input_sf=%s weight_sf=%s", + M_orig, + N_orig, + K, + tuple(input_sf.shape), + tuple(weight_sf.shape), + ) + _B12X_FP4_GEMM_LOGGED_SHAPES.add(shape_key) # b12x requires M and N divisible by 128 def _pad128(x, dim): @@ -1414,12 +1565,22 @@ def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config: ) def get_quant_method(self, layer: torch.nn.Module, prefix: str): - return self._get_quant_method( - layer, - prefix, - Linear=ModelOptFp4LinearMethod, - Moe=ModelOptNvFp4FusedMoEMethod, - ) + from sglang.srt.layers.linear import LinearBase + from sglang.srt.layers.moe.fused_moe_triton import FusedMoE + + if isinstance(layer, LinearBase): + if is_layer_skipped( + prefix, self.exclude_modules, self.packed_modules_mapping + ) or self.is_layer_excluded(prefix): + return UnquantizedLinearMethod() + return ModelOptFp4LinearMethod(self, prefix=prefix) + if self.kv_cache_quant_algo and isinstance(layer, RadixAttention): + return ModelOptFp8KVCacheMethod(self) + if isinstance(layer, FusedMoE): + if self.is_layer_excluded(prefix): + return None + return ModelOptNvFp4FusedMoEMethod(self) + return None class ModelOptFp4LinearMethod(LinearMethodBase): @@ -1437,8 +1598,9 @@ class ModelOptFp4LinearMethod(LinearMethodBase): Args: quant_config: The ModelOpt quantization config. 
""" - def __init__(self, quant_config: ModelOptFp4Config): + def __init__(self, quant_config: ModelOptFp4Config, prefix: str | None = None): self.quant_config = quant_config + self.prefix = prefix or "" def create_weights( self, @@ -1515,6 +1677,14 @@ def create_weights( layer.register_parameter("weight_scale", weight_scale) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + _record_fp4_linear_audit( + self.prefix, + layer=layer, + phase="preprocess", + extra={ + "runner_backend": str(get_fp4_gemm_runner_backend()), + }, + ) input_scale_2 = layer.input_scale.max().to(torch.float32) weight_scale_2 = layer.weight_scale_2.max().to(torch.float32) @@ -1578,6 +1748,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: copy_or_rebind_param(layer, "weight_scale_interleaved", scale) copy_or_rebind_param(layer, "weight", weight) layer.weights_padding_cols = weights_padding_cols + _record_fp4_linear_audit( + self.prefix, + layer=layer, + phase="postprocess_flashinfer_trtllm", + extra={ + "runner_backend": str(get_fp4_gemm_runner_backend()), + }, + ) return # Pad weights for CUTLASS/FlashInfer kernel alignment (K and N divisible by 32) @@ -1608,6 +1786,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: else padded_scales.reshape(B, M_padded, K_padded) ) copy_or_rebind_param(layer, "weight_scale_interleaved", padded_scales) + _record_fp4_linear_audit( + self.prefix, + layer=layer, + phase="postprocess_cutlass_b12x_layout", + extra={ + "runner_backend": str(get_fp4_gemm_runner_backend()), + }, + ) def apply( self, @@ -1643,6 +1829,16 @@ def apply( ): w = layer.weight.T w_scale_interleaved = layer.weight_scale_interleaved.T + if fp4_backend.is_b12x(): + _record_b12x_fp4_linear_call( + self.prefix, + input=x_fp4, + weight=w, + input_sf=x_scale_interleaved, + weight_sf=w_scale_interleaved, + output_dtype=output_dtype, + output_features=output_size, + ) out = fp4_gemm( x_fp4, @@ -2294,6 +2490,10 @@ def apply( output=symm_output, input_scales_are_reciprocal=True, input_scales_static=True, + tp_size=layer.moe_tp_size, + tp_rank=layer.moe_tp_rank, + ep_size=layer.moe_ep_size, + ep_rank=layer.moe_ep_rank, ).to(x.dtype) from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput return StandardCombineInput(hidden_states=output) diff --git a/python/sglang/srt/managers/prefill_delayer.py b/python/sglang/srt/managers/prefill_delayer.py index 1eb21ae64..268e96ece 100644 --- a/python/sglang/srt/managers/prefill_delayer.py +++ b/python/sglang/srt/managers/prefill_delayer.py @@ -197,15 +197,18 @@ def _negotiate_should_allow_prefill_pure( global_waiting_queue_max = int(global_waiting_queue_len.max().item()) # Queue-based trigger: delay prefill until the waiting queue - # reaches queue_min = min(running_req * ratio, max_prefill_bs), + # reaches queue_min = min(running_req * ratio, prefill_cap), # capped by a wall-clock timeout to bound worst-case TTFT. # Targets workloads where decode requests finish one-at-a-time # and fragment prefill into many tiny batches. 
queue_condition = False if self._queue_trigger_enabled and global_running_batch_max > 0: + queue_prefill_cap = global_max_prefill_bs_max + if queue_prefill_cap <= 0: + queue_prefill_cap = max_running_requests queue_min_effective = min( int(global_running_batch_max * self._queue_min_ratio), - global_max_prefill_bs_max, + queue_prefill_cap, ) queue_condition = ( queue_min_effective > 0 diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 8f4dfe996..950aa0b32 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -14,6 +14,7 @@ """A scheduler that manages a tensor parallel GPU worker.""" import faulthandler +import json import logging import os import signal @@ -23,6 +24,7 @@ from contextlib import nullcontext from dataclasses import dataclass from http import HTTPStatus +from pathlib import Path from typing import Any, Deque, Dict, List, Optional, Tuple, Union from sglang.srt.utils.common import suppress_noisy_warnings @@ -258,6 +260,164 @@ class SchedulerMlxOverlapMixin: _is_npu = is_npu() +_DSV4_SCHED_TRACE_COUNT = 0 + + +def _dsv4_scheduler_trace_max_records() -> int: + raw = os.environ.get("SGLANG_DSV4_SCHED_TRACE_MAX_RECORDS", "0") or "0" + try: + return int(raw) + except ValueError: + logger.warning("Ignoring invalid SGLANG_DSV4_SCHED_TRACE_MAX_RECORDS=%r", raw) + return 0 + + +def _dsv4_scheduler_trace_path(scheduler: "Scheduler") -> Optional[Path]: + raw = os.environ.get("SGLANG_DSV4_SCHED_TRACE_PATH", "").strip() + if not raw: + return None + rank = ( + os.environ.get("RANK") + or os.environ.get("LOCAL_RANK") + or str(getattr(scheduler, "tp_rank", "x")) + ) + pid = os.getpid() + if "{rank}" in raw or "{pid}" in raw: + return Path(raw.format(rank=rank, pid=pid)) + path = Path(raw) + if path.suffix == ".jsonl": + return path.with_name(f"{path.stem}.rank{rank}.pid{pid}{path.suffix}") + return path / f"dsv4_scheduler_trace.rank{rank}.pid{pid}.jsonl" + + +def _dsv4_forward_mode_name(batch: Optional["ScheduleBatch"]) -> Optional[str]: + if batch is None: + return None + mode = getattr(batch, "forward_mode", None) + return str(getattr(mode, "name", mode)) + + +def _dsv4_tensor_shape(value: Any) -> Optional[List[int]]: + if isinstance(value, torch.Tensor): + return [int(dim) for dim in value.shape] + return None + + +def _dsv4_batch_trace_payload(batch: Optional["ScheduleBatch"]) -> Dict[str, Any]: + if batch is None: + return { + "has_batch": False, + "forward_mode": None, + "batch_size": 0, + } + + reqs = getattr(batch, "reqs", None) or [] + seq_lens_cpu = getattr(batch, "seq_lens_cpu", None) + seq_lens_summary: Dict[str, Any] = {} + if isinstance(seq_lens_cpu, torch.Tensor) and seq_lens_cpu.device.type == "cpu": + if seq_lens_cpu.numel() > 0: + seq_lens_summary = { + "seq_lens_min": int(seq_lens_cpu.min().item()), + "seq_lens_max": int(seq_lens_cpu.max().item()), + "seq_lens_sum_cpu": int(seq_lens_cpu.sum().item()), + } + else: + seq_lens_summary = { + "seq_lens_min": None, + "seq_lens_max": None, + "seq_lens_sum_cpu": 0, + } + + try: + is_spec_v2 = bool(getattr(batch, "is_spec_v2", False)) + except Exception: # noqa: BLE001 + is_spec_v2 = None + + return { + "has_batch": True, + "forward_mode": _dsv4_forward_mode_name(batch), + "forward_iter": getattr(batch, "forward_iter", None), + "batch_size": len(reqs), + "seq_lens_sum": getattr(batch, "seq_lens_sum", None), + "is_spec_v2": is_spec_v2, + "has_stream": bool(getattr(batch, "has_stream", False)), + "return_logprob": bool(getattr(batch, 
"return_logprob", False)), + "req_output_lens_sample": [ + len(getattr(req, "output_ids", [])) for req in reqs[:8] + ], + **seq_lens_summary, + } + + +def _dsv4_result_trace_payload(result: Any) -> Dict[str, Any]: + if result is None: + return {"has_result": False} + next_token_ids = getattr(result, "next_token_ids", None) + next_draft_input = getattr(result, "next_draft_input", None) + return { + "has_result": True, + "result_type": type(result).__name__, + "next_token_ids_shape": _dsv4_tensor_shape(next_token_ids), + "has_delay_sample_func": getattr(result, "delay_sample_func", None) + is not None, + "has_next_draft_input": next_draft_input is not None, + "copy_done_type": ( + type(getattr(result, "copy_done", None)).__name__ + if getattr(result, "copy_done", None) is not None + else None + ), + } + + +def _trace_dsv4_scheduler_event( + scheduler: "Scheduler", + event: str, + *, + batch: Optional["ScheduleBatch"] = None, + result: Any = None, + elapsed_ms: Optional[float] = None, + **extra: Any, +) -> None: + path = _dsv4_scheduler_trace_path(scheduler) + if path is None: + return + + global _DSV4_SCHED_TRACE_COUNT + max_records = _dsv4_scheduler_trace_max_records() + if max_records > 0 and _DSV4_SCHED_TRACE_COUNT >= max_records: + return + record_index = _DSV4_SCHED_TRACE_COUNT + _DSV4_SCHED_TRACE_COUNT += 1 + + payload = { + "event": event, + "pid": os.getpid(), + "rank": os.environ.get("RANK") or getattr(scheduler, "tp_rank", None), + "local_rank": os.environ.get("LOCAL_RANK"), + "record_index": record_index, + "time": time.time(), + "monotonic_ns": time.monotonic_ns(), + "tp_rank": getattr(scheduler, "tp_rank", None), + "tp_size": getattr(scheduler, "tp_size", None), + "dp_rank": getattr(scheduler, "dp_rank", None), + "attn_tp_rank": getattr(scheduler, "attn_tp_rank", None), + "attn_tp_size": getattr(scheduler, "attn_tp_size", None), + "enable_overlap": bool(getattr(scheduler, "enable_overlap", False)), + "spec_algorithm": str(getattr(scheduler, "spec_algorithm", None)), + } + if elapsed_ms is not None: + payload["elapsed_ms"] = float(elapsed_ms) + payload.update(_dsv4_batch_trace_payload(batch)) + payload.update(_dsv4_result_trace_payload(result)) + payload.update(extra) + + try: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("a", encoding="utf-8") as handle: + handle.write(json.dumps(payload, sort_keys=True) + "\n") + except Exception: # noqa: BLE001 + logger.exception("Failed to write DSv4 scheduler trace record") + @dataclass class EmbeddingBatchResult: @@ -2971,8 +3131,10 @@ def run_batch( pp_proxy_tensors: Optional[PPProxyTensors] = None, ) -> Union[GenerationBatchResult, EmbeddingBatchResult]: """Run a batch.""" + run_start_ns = time.perf_counter_ns() self.forward_ct += 1 batch.forward_iter = self.forward_ct + _trace_dsv4_scheduler_event(self, "run_batch_start", batch=batch) # Whether to run the profiler self._profile_batch_predicate(batch) @@ -2982,7 +3144,16 @@ def run_batch( # Place holder handling for pd-disagg decode event loop if batch.forward_mode.is_prebuilt(): - return self._run_batch_prebuilt(batch) + ret = self._run_batch_prebuilt(batch) + _trace_dsv4_scheduler_event( + self, + "run_batch_end", + batch=batch, + result=ret, + elapsed_ms=(time.perf_counter_ns() - run_start_ns) / 1e6, + branch="prebuilt", + ) + return ret # Run forward if self.is_generation: @@ -3007,16 +3178,55 @@ def run_batch( with self.forward_stream_ctx: self.forward_stream.wait_stream(self.schedule_stream) + resolve_start_ns = time.perf_counter_ns() 
self.future_map.resolve_future(model_worker_batch) + _trace_dsv4_scheduler_event( + self, + "future_resolve_end", + batch=batch, + elapsed_ms=(time.perf_counter_ns() - resolve_start_ns) / 1e6, + branch="generation_overlap", + ) + forward_start_ns = time.perf_counter_ns() + _trace_dsv4_scheduler_event( + self, + "model_forward_start", + batch=batch, + branch="generation_overlap", + ) batch_result = self.model_worker.forward_batch_generation( model_worker_batch # here pp is not compatible with overlap ) + _trace_dsv4_scheduler_event( + self, + "model_forward_end", + batch=batch, + result=batch_result, + elapsed_ms=(time.perf_counter_ns() - forward_start_ns) / 1e6, + branch="generation_overlap", + ) # FIXME(lsyin): maybe move this to forward_batch_generation batch_result.copy_done = self.device_module.Event() if batch_result.delay_sample_func is None: self.future_map.store_to_map(future_indices, batch_result) + copy_start_ns = time.perf_counter_ns() + _trace_dsv4_scheduler_event( + self, + "copy_to_cpu_start", + batch=batch, + result=batch_result, + branch="generation_overlap", + ) batch_result.copy_to_cpu(return_logprob=batch.return_logprob) + _trace_dsv4_scheduler_event( + self, + "copy_to_cpu_end", + batch=batch, + result=batch_result, + elapsed_ms=(time.perf_counter_ns() - copy_start_ns) / 1e6, + branch="generation_overlap", + ) else: batch_result.future_indices = future_indices @@ -3034,7 +3244,22 @@ def run_batch( # Current implementation strictly synchronizes the seq_lens batch.seq_lens = batch_result.next_draft_input.new_seq_lens elif self.enable_pdmux and batch.forward_mode.is_split_prefill(): + forward_start_ns = time.perf_counter_ns() + _trace_dsv4_scheduler_event( + self, + "model_forward_start", + batch=batch, + branch="split_prefill", + ) batch_result = self.tp_worker.forward_batch_split_prefill(batch) + _trace_dsv4_scheduler_event( + self, + "model_forward_end", + batch=batch, + result=batch_result, + elapsed_ms=(time.perf_counter_ns() - forward_start_ns) / 1e6, + branch="split_prefill", + ) future_indices_or_next_token_ids = batch_result.next_token_ids else: kwargs = ( @@ -3042,9 +3267,24 @@ def run_batch( if self.spec_algorithm.is_none() else {} ) + forward_start_ns = time.perf_counter_ns() + _trace_dsv4_scheduler_event( + self, + "model_forward_start", + batch=batch, + branch="generation", + ) batch_result = self.model_worker.forward_batch_generation( worker_batch_or_batch, **kwargs ) + _trace_dsv4_scheduler_event( + self, + "model_forward_end", + batch=batch, + result=batch_result, + elapsed_ms=(time.perf_counter_ns() - forward_start_ns) / 1e6, + branch="generation", + ) future_indices_or_next_token_ids = batch_result.next_token_ids self.update_cache_from_scheduler(batch, batch_result) @@ -3076,18 +3316,46 @@ def run_batch( self.record_batch_in_overlap(model_worker_batch) with self.forward_stream_ctx: self.forward_stream.wait_stream(self.schedule_stream) + forward_start_ns = time.perf_counter_ns() + _trace_dsv4_scheduler_event( + self, + "model_forward_start", + batch=batch, + branch="embedding_overlap", + ) pooler_output = self.tp_worker.forward_batch_embedding( model_worker_batch ) + _trace_dsv4_scheduler_event( + self, + "model_forward_end", + batch=batch, + elapsed_ms=(time.perf_counter_ns() - forward_start_ns) / 1e6, + branch="embedding_overlap", + ) ret = EmbeddingBatchResult( embeddings=pooler_output.embeddings, pooled_hidden_states=pooler_output.pooled_hidden_states, ) ret.copy_to_cpu() else: + forward_start_ns = time.perf_counter_ns() + _trace_dsv4_scheduler_event( 
+ self, + "model_forward_start", + batch=batch, + branch="embedding", + ) pooler_output = self.tp_worker.forward_batch_embedding( model_worker_batch ) + _trace_dsv4_scheduler_event( + self, + "model_forward_end", + batch=batch, + elapsed_ms=(time.perf_counter_ns() - forward_start_ns) / 1e6, + branch="embedding", + ) ret = EmbeddingBatchResult( embeddings=pooler_output.embeddings, pooled_hidden_states=pooler_output.pooled_hidden_states, @@ -3106,6 +3374,13 @@ def run_batch( ActiveRanksOutput(status=dp_active_ranks.tolist()) ) + _trace_dsv4_scheduler_event( + self, + "run_batch_end", + batch=batch, + result=ret, + elapsed_ms=(time.perf_counter_ns() - run_start_ns) / 1e6, + ) return ret def launch_batch_sample_if_needed( @@ -3116,12 +3391,26 @@ def launch_batch_sample_if_needed( if batch_result is None or batch_result.delay_sample_func is None: return + sample_start_ns = time.perf_counter_ns() + _trace_dsv4_scheduler_event( + self, + "delay_sample_start", + batch=self.cur_batch, + result=batch_result, + ) with self.forward_stream_ctx: self.forward_stream.wait_stream(self.schedule_stream) _batch_result = batch_result.delay_sample_func() assert _batch_result is batch_result self.future_map.store_to_map(batch_result.future_indices, batch_result) batch_result.copy_to_cpu(return_logprob=self.cur_batch.return_logprob) + _trace_dsv4_scheduler_event( + self, + "delay_sample_end", + batch=self.cur_batch, + result=batch_result, + elapsed_ms=(time.perf_counter_ns() - sample_start_ns) / 1e6, + ) # Release the closure and large GPU tensors that are no longer needed. # The delay_sample_func closure captures forward_batch (which holds @@ -3138,6 +3427,13 @@ def process_batch_result( batch: ScheduleBatch, result: Union[GenerationBatchResult, EmbeddingBatchResult], ): + process_start_ns = time.perf_counter_ns() + _trace_dsv4_scheduler_event( + self, + "process_batch_result_start", + batch=batch, + result=result, + ) if batch.forward_mode.is_decode(): self.process_batch_result_decode(batch, result) elif batch.forward_mode.is_extend(): @@ -3156,6 +3452,13 @@ def process_batch_result( self._maybe_clear_mm_inputs(batch) self.maybe_send_health_check_signal() self.update_device_timer() + _trace_dsv4_scheduler_event( + self, + "process_batch_result_end", + batch=batch, + result=result, + elapsed_ms=(time.perf_counter_ns() - process_start_ns) / 1e6, + ) def maybe_send_health_check_signal(self): if self.return_health_check_ipcs: @@ -3439,6 +3742,8 @@ def get_internal_state(self, recv_req: GetInternalStateReq): "graph": round(self.tp_worker.model_runner.graph_mem_usage, 2), } ret["effective_max_running_requests_per_dp"] = self.max_running_requests + ret["simulate_acc_len"] = envs.SGLANG_SIMULATE_ACC_LEN.get() + ret["simulate_acc_method"] = envs.SGLANG_SIMULATE_ACC_METHOD.get() if not self.spec_algorithm.is_none() and self.spec_total_num_forward_ct > 0: ret["avg_spec_accept_length"] = ( @@ -3448,6 +3753,44 @@ def get_internal_state(self, recv_req: GetInternalStateReq): if RECORD_STEP_TIME: ret["step_time_dict"] = self.step_time_dict + if os.environ.get("B12X_DEBUG_COUNTERS_IN_SERVER_INFO", "1") != "0": + try: + from b12x.integration.tp_moe import get_tp_moe_debug_counters + + ret["b12x_tp_moe_debug_counters"] = get_tp_moe_debug_counters() + except Exception as exc: # noqa: BLE001 + ret["b12x_tp_moe_debug_counters_error"] = repr(exc) + try: + from sglang.srt.layers.quantization.modelopt_quant import ( + get_b12x_fp4_gemm_debug_counters, + ) + + ret["b12x_fp4_gemm_debug_counters"] = ( + get_b12x_fp4_gemm_debug_counters() 
+ ) + except Exception as exc: # noqa: BLE001 + ret["b12x_fp4_gemm_debug_counters_error"] = repr(exc) + try: + from sglang.srt.layers.attention.deepseek_v4_backend import ( + get_b12x_dsv4_mla_debug_counters, + ) + + ret["b12x_dsv4_mla_debug_counters"] = ( + get_b12x_dsv4_mla_debug_counters() + ) + except Exception as exc: # noqa: BLE001 + ret["b12x_dsv4_mla_debug_counters_error"] = repr(exc) + try: + from b12x.integration.dsv4_sparse_mla_decode_patch import ( + get_dsv4_sparse_mla_patch_counters, + ) + + ret["b12x_dsv4_mla_patch_counters"] = ( + get_dsv4_sparse_mla_patch_counters() + ) + except Exception as exc: # noqa: BLE001 + ret["b12x_dsv4_mla_patch_counters_error"] = repr(exc) + # This field is not serializable. ret.pop("model_config", None) diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index 87981bf3e..eeadb6d4a 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -1,6 +1,10 @@ from __future__ import annotations +import json import logging +import os +from pathlib import Path +import time from typing import TYPE_CHECKING, List, Optional, Tuple, Union import torch @@ -39,6 +43,108 @@ # How often (in decoded tokens) the scheduler force-flushes an intermediate # output batch for non-streaming requests. DEFAULT_FORCE_STREAM_INTERVAL = envs.SGLANG_FORCE_STREAM_INTERVAL.get() +_DSV4_STREAM_TRACE_COUNT = 0 + + +def _dsv4_stream_trace_rank(scheduler: Optional["Scheduler"] = None) -> str: + return ( + os.environ.get("RANK") + or os.environ.get("LOCAL_RANK") + or str(getattr(scheduler, "tp_rank", "x")) + ) + + +def _dsv4_stream_trace_path(scheduler: Optional["Scheduler"] = None) -> Optional[Path]: + raw = os.environ.get("SGLANG_DSV4_STREAM_TRACE_PATH", "").strip() + if not raw: + return None + rank = _dsv4_stream_trace_rank(scheduler) + pid = os.getpid() + if "{rank}" in raw or "{pid}" in raw: + return Path(raw.format(rank=rank, pid=pid)) + path = Path(raw) + if path.suffix == ".jsonl": + return path.with_name(f"{path.stem}.rank{rank}.pid{pid}{path.suffix}") + return path / f"dsv4_stream_trace.rank{rank}.pid{pid}.jsonl" + + +def _dsv4_stream_trace_max_records() -> int: + raw = os.environ.get("SGLANG_DSV4_STREAM_TRACE_MAX_RECORDS", "0") or "0" + try: + return int(raw) + except ValueError: + logger.warning("Ignoring invalid SGLANG_DSV4_STREAM_TRACE_MAX_RECORDS=%r", raw) + return 0 + + +def _trace_dsv4_stream_output( + scheduler: "Scheduler", + *, + input_req_count: int, + is_idle_batch: bool, + rids: List[str], + output_ids: List[List[int]], + completion_tokens: List[int], + finished_reasons: List[Optional[dict]], + spec_verify_ct: List[int], + spec_accepted_drafts: List[int], +) -> None: + path = _dsv4_stream_trace_path(scheduler) + if path is None: + return + + global _DSV4_STREAM_TRACE_COUNT + max_records = _dsv4_stream_trace_max_records() + if max_records > 0 and _DSV4_STREAM_TRACE_COUNT >= max_records: + return + record_index = _DSV4_STREAM_TRACE_COUNT + _DSV4_STREAM_TRACE_COUNT += 1 + + emitted = [] + for idx, rid in enumerate(rids): + emitted.append( + { + "rid": str(rid), + "delta_tokens": len(output_ids[idx]) if idx < len(output_ids) else 0, + "completion_tokens": ( + int(completion_tokens[idx]) if idx < len(completion_tokens) else None + ), + "finished_reason": ( + finished_reasons[idx] if idx < len(finished_reasons) else None + ), + "spec_verify_ct": ( + int(spec_verify_ct[idx]) if idx < 
len(spec_verify_ct) else None + ), + "spec_accepted_drafts": ( + int(spec_accepted_drafts[idx]) + if idx < len(spec_accepted_drafts) + else None + ), + } + ) + + payload = { + "event": "dsv4_stream_output", + "pid": os.getpid(), + "rank": os.environ.get("RANK") or getattr(scheduler, "tp_rank", None), + "local_rank": os.environ.get("LOCAL_RANK"), + "record_index": record_index, + "time": time.time(), + "tp_rank": getattr(scheduler, "tp_rank", None), + "tp_size": getattr(scheduler, "tp_size", None), + "dp_rank": getattr(scheduler, "dp_rank", None), + "attn_tp_rank": getattr(scheduler, "attn_tp_rank", None), + "input_req_count": int(input_req_count), + "is_idle_batch": bool(is_idle_batch), + "emitted_count": len(emitted), + "emitted": emitted, + } + try: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("a", encoding="utf-8") as handle: + handle.write(json.dumps(payload, sort_keys=True) + "\n") + except Exception: # noqa: BLE001 + logger.exception("Failed to write DSv4 stream trace record") class SchedulerOutputProcessorMixin: @@ -1302,6 +1408,17 @@ def stream_output_generation( req.log_time_stats() dp_ranks = [self.dp_rank] * len(rids) if rids else None + _trace_dsv4_stream_output( + self, + input_req_count=len(reqs), + is_idle_batch=is_idle_batch, + rids=rids, + output_ids=output_ids, + completion_tokens=completion_tokens, + finished_reasons=finished_reasons, + spec_verify_ct=spec_verify_ct, + spec_accepted_drafts=spec_accepted_drafts, + ) # Send to detokenizer if reqs or is_idle_batch: diff --git a/python/sglang/srt/managers/scheduler_recv_skipper.py b/python/sglang/srt/managers/scheduler_recv_skipper.py index 69c3e19a5..7f13a2f21 100644 --- a/python/sglang/srt/managers/scheduler_recv_skipper.py +++ b/python/sglang/srt/managers/scheduler_recv_skipper.py @@ -20,6 +20,8 @@ def __init__(self, server_args: ServerArgs): self._weight_of_forward_mode = { ForwardMode.DECODE: envs.SGLANG_SCHEDULER_RECV_SKIPPER_WEIGHT_DECODE.get(), ForwardMode.TARGET_VERIFY: envs.SGLANG_SCHEDULER_RECV_SKIPPER_WEIGHT_TARGET_VERIFY.get(), + ForwardMode.DRAFT_EXTEND: envs.SGLANG_SCHEDULER_RECV_SKIPPER_WEIGHT_TARGET_VERIFY.get(), + ForwardMode.DRAFT_EXTEND_V2: envs.SGLANG_SCHEDULER_RECV_SKIPPER_WEIGHT_TARGET_VERIFY.get(), None: envs.SGLANG_SCHEDULER_RECV_SKIPPER_WEIGHT_NONE.get(), } diff --git a/python/sglang/srt/mem_cache/swa_memory_pool.py b/python/sglang/srt/mem_cache/swa_memory_pool.py index 75b59d0c9..157f4538d 100644 --- a/python/sglang/srt/mem_cache/swa_memory_pool.py +++ b/python/sglang/srt/mem_cache/swa_memory_pool.py @@ -443,6 +443,7 @@ def alloc_extend( ) else: self.full_to_swa_index_mapping[alloc_full_indices] = alloc_swa_indices + self._allocated_mask[alloc_full_indices] = True return alloc_full_indices @@ -471,6 +472,7 @@ def alloc_decode( ) else: self.full_to_swa_index_mapping[alloc_full_indices] = alloc_swa_indices + self._allocated_mask[alloc_full_indices] = True return alloc_full_indices @@ -511,6 +513,7 @@ def set_full_to_swa_mapping( ) else: self.full_to_swa_index_mapping[full_indices] = swa_indices + self._allocated_mask[full_indices] = True def free_swa(self, free_index: torch.Tensor): swa_indices = self.full_to_swa_index_mapping[free_index] @@ -525,15 +528,20 @@ def backup_state(self): self.full_attn_allocator.backup_state(), self.swa_attn_allocator.backup_state(), self.full_to_swa_index_mapping.clone(), + self._allocated_mask.clone(), ] def restore_state(self, state): - assert len(state) == 3 + assert len(state) in (3, 4) self.full_attn_allocator.restore_state(state[0]) 
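+ # `state` may be the legacy 3-tuple [full_allocator, swa_allocator,
+ # full_to_swa_index_mapping] captured before _allocated_mask existed, or
+ # the new 4-tuple with the mask appended; the legacy form resets the mask
+ # to all False below rather than guessing which slots were live.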
self.swa_attn_allocator.restore_state(state[1]) # Restore in-place to preserve shared references (attention backends # hold a reference to full_to_swa_index_mapping via self._kvcache). self.full_to_swa_index_mapping.copy_(state[2]) + if len(state) == 4: + self._allocated_mask.copy_(state[3]) + else: + self._allocated_mask.fill_(False) def clear(self): self.swa_attn_allocator.clear() diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index e5e6b02bd..dabe4096a 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -19,8 +19,10 @@ import contextlib import gc import inspect +import json import logging import os +import time from contextlib import contextmanager from dataclasses import dataclass from functools import partial @@ -103,6 +105,48 @@ logger = logging.getLogger(__name__) +_REPLAY_TRACE_COUNT = 0 + + +def _graph_debug_enabled() -> bool: + return get_bool_env_var("SGLANG_DSV4_GRAPH_DEBUG") or _graph_debug_sync_enabled() + + +def _graph_debug_sync_enabled() -> bool: + return get_bool_env_var("SGLANG_DSV4_GRAPH_DEBUG_SYNC") + + +def _write_replay_trace(record: dict) -> None: + path_template = os.environ.get("SGLANG_DSV4_REPLAY_TRACE_PATH", "") + if not path_template: + return + max_records = int( + os.environ.get("SGLANG_DSV4_REPLAY_TRACE_MAX_RECORDS", "0") or "0" + ) + global _REPLAY_TRACE_COUNT + if max_records > 0 and _REPLAY_TRACE_COUNT >= max_records: + return + _REPLAY_TRACE_COUNT += 1 + try: + rank = get_tensor_model_parallel_rank() + except Exception: # noqa: BLE001 + rank = int(os.environ.get("RANK", "0")) + record = { + "pid": os.getpid(), + "rank": rank, + **record, + } + if "{rank}" in path_template or "{pid}" in path_template: + path = path_template.format(rank=rank, pid=os.getpid()) + else: + path = f"{path_template}.rank{rank}.pid{os.getpid()}.jsonl" + try: + os.makedirs(os.path.dirname(path) or ".", exist_ok=True) + with open(path, "a", encoding="utf-8") as handle: + handle.write(json.dumps(record, sort_keys=True) + "\n") + except Exception: # noqa: BLE001 + logger.exception("Failed to write CUDA graph replay trace record") + if TYPE_CHECKING: from sglang.srt.model_executor.model_runner import ModelRunner @@ -851,6 +895,7 @@ def capture(self) -> None: profile_context = empty_context() if self.enable_profile_cuda_graph: profile_context = self._init_profile_context_and_memory_record() + trace_capture = get_bool_env_var("SGLANG_DSV4_GRAPH_CAPTURE_TRACE") def _capture_one_stream(stream_idx: Optional[int] = None): avail_mem = get_available_gpu_memory( @@ -890,10 +935,26 @@ def _capture_one_stream(stream_idx: Optional[int] = None): num_tokens=bs * self.num_tokens_per_bs, tp_group=self.model_runner.tp_group, ) as forward: + if trace_capture: + logger.info( + "[DSV4 graph trace] rank=%s stream=%s bs=%s variant=%s phase=capture_one_start", + get_tensor_model_parallel_rank(), + stream_idx, + bs, + variant_label, + ) ( graph, output_buffers, ) = self.capture_one_batch_size(bs, forward, stream_idx) + if trace_capture: + logger.info( + "[DSV4 graph trace] rank=%s stream=%s bs=%s variant=%s phase=capture_one_done", + get_tensor_model_parallel_rank(), + stream_idx, + bs, + variant_label, + ) key = _default_make_graph_key(bs, stream_idx, variant_label) self.graphs[key] = graph self.output_buffers[key] = output_buffers @@ -903,16 +964,38 @@ def _capture_one_stream(stream_idx: Optional[int] = None): # can reuse the memory pool allocated for the 
large shapes. with freeze_gc(self.model_runner.server_args.enable_cudagraph_gc): if not self.enable_pdmux: + if trace_capture: + logger.info( + "[DSV4 graph trace] rank=%s phase=graph_context_enter_start", + get_tensor_model_parallel_rank(), + ) with graph_capture() as graph_capture_context, profile_context as prof: + if trace_capture: + logger.info( + "[DSV4 graph trace] rank=%s phase=graph_context_enter_done", + get_tensor_model_parallel_rank(), + ) self.stream = graph_capture_context.stream _capture_one_stream() else: set_pdmux_status(False) for i, sg in enumerate(self.stream_groups): + if trace_capture: + logger.info( + "[DSV4 graph trace] rank=%s stream=%s phase=graph_context_enter_start", + get_tensor_model_parallel_rank(), + i, + ) with ( graph_capture(stream=sg[1]) as graph_capture_context, profile_context as prof, ): + if trace_capture: + logger.info( + "[DSV4 graph trace] rank=%s stream=%s phase=graph_context_enter_done", + get_tensor_model_parallel_rank(), + i, + ) self.stream = graph_capture_context.stream _capture_one_stream(i) @@ -964,6 +1047,15 @@ def _create_device_graph(self): def capture_one_batch_size( self, bs: int, forward: Callable, stream_idx: Optional[int] = None ): + trace_capture = get_bool_env_var("SGLANG_DSV4_GRAPH_CAPTURE_TRACE") + rank = get_tensor_model_parallel_rank() + if trace_capture: + logger.info( + "[DSV4 graph trace] rank=%s stream=%s bs=%s phase=setup_start", + rank, + stream_idx, + bs, + ) buffers: DecodeInputBuffers = self.buffers graph = self._create_device_graph() stream = self.stream @@ -1116,6 +1208,14 @@ def capture_one_batch_size( self.model_runner.lora_manager.prepare_lora_batch(forward_batch) # Attention backend + if trace_capture: + logger.info( + "[DSV4 graph trace] rank=%s stream=%s bs=%s num_tokens=%s phase=metadata_start", + rank, + stream_idx, + bs, + num_tokens, + ) attn_backend.init_forward_metadata_capture_cuda_graph( bs, num_tokens, @@ -1125,6 +1225,14 @@ def capture_one_batch_size( forward_batch.forward_mode, forward_batch.spec_info, ) + if trace_capture: + logger.info( + "[DSV4 graph trace] rank=%s stream=%s bs=%s num_tokens=%s phase=metadata_done", + rank, + stream_idx, + bs, + num_tokens, + ) # Run and capture def run_once(): @@ -1171,20 +1279,54 @@ def run_once(): self.buffers.out_cache_loc_swa[:num_tokens] ) - for _ in range(2): + for warmup_idx in range(2): + if trace_capture: + logger.info( + "[DSV4 graph trace] rank=%s stream=%s bs=%s num_tokens=%s warmup=%s phase=warmup_start", + rank, + stream_idx, + bs, + num_tokens, + warmup_idx, + ) self.device_module.synchronize() self.model_runner.tp_group.barrier() run_once() attn_backend.on_after_cuda_graph_warmup() + if trace_capture: + logger.info( + "[DSV4 graph trace] rank=%s stream=%s bs=%s num_tokens=%s warmup=%s phase=warmup_done", + rank, + stream_idx, + bs, + num_tokens, + warmup_idx, + ) if get_global_graph_memory_pool() is None: set_global_graph_memory_pool(self.device_module.graph_pool_handle()) # Set graph pool id globally to be able to use symmetric memory set_graph_pool_id(get_global_graph_memory_pool()) + if trace_capture: + logger.info( + "[DSV4 graph trace] rank=%s stream=%s bs=%s num_tokens=%s phase=capture_graph_start", + rank, + stream_idx, + bs, + num_tokens, + ) out = self._capture_graph( graph, get_global_graph_memory_pool(), stream, run_once ) + if trace_capture: + logger.info( + "[DSV4 graph trace] rank=%s stream=%s bs=%s num_tokens=%s phase=capture_graph_done", + rank, + stream_idx, + bs, + num_tokens, + ) return graph, out @@ -1310,6 +1452,15 @@ def 
replay( skip_attn_backend_init: bool = False, pp_proxy_tensors: Optional[PPProxyTensors] = None, ) -> Union[LogitsProcessorOutput, PPProxyTensors]: + trace_enabled = bool(os.environ.get("SGLANG_DSV4_REPLAY_TRACE_PATH", "")) + trace_sync_enabled = trace_enabled and bool( + int(os.environ.get("SGLANG_DSV4_REPLAY_TRACE_SYNC", "0") or "0") + ) + phase_trace_enabled = bool(os.environ.get("SGLANG_DSV4_PHASE_TRACE_PATH", "")) + phase_trace_sync_enabled = phase_trace_enabled and bool( + int(os.environ.get("SGLANG_DSV4_PHASE_TRACE_SYNC", "0") or "0") + ) + trace_t0 = time.perf_counter_ns() if trace_enabled else 0 self.deepep_adapter.replay() if not skip_attn_backend_init: @@ -1326,11 +1477,14 @@ def replay( self.buffers.input_embeds[: self.raw_num_token].copy_( forward_batch.input_embeds ) + trace_t1 = time.perf_counter_ns() if trace_enabled else 0 # Replay variant_label = self._resolve_lora_variant(forward_batch) stream_idx = get_current_stream_idx() if self.enable_pdmux else None graph_key = self._make_graph_key(self.bs, stream_idx, variant_label) + debug_enabled = _graph_debug_enabled() + debug_sync_enabled = _graph_debug_sync_enabled() ctx = ( self.model_runner.device_timer.wrap( metadata={ @@ -1340,8 +1494,115 @@ def replay( if self.model_runner.device_timer else contextlib.nullcontext() ) + trace_t2 = time.perf_counter_ns() if trace_enabled else 0 + replay_start = replay_end = None + if trace_enabled: + _write_replay_trace( + { + "event": "cuda_graph_replay_pre", + "forward_mode": forward_batch.forward_mode.name, + "capture_forward_mode": self.capture_forward_mode.name, + "capture_hidden_mode": self.capture_hidden_mode.name, + "skip_attn_backend_init": skip_attn_backend_init, + "raw_bs": int(self.raw_bs), + "bs": int(self.bs), + "raw_num_token": int(self.raw_num_token), + "num_tokens_per_bs": int(self.num_tokens_per_bs), + "graph_key": str(graph_key), + "stream_idx": stream_idx, + "variant_label": variant_label, + "sync_debug": debug_sync_enabled, + } + ) + if debug_enabled: + logger.warning( + "[DSV4 graph debug] target replay pre mode=%s capture_mode=%s " + "raw_bs=%s bs=%s raw_num_token=%s graph_key=%s skip_attn=%s", + forward_batch.forward_mode.name, + self.capture_forward_mode.name, + self.raw_bs, + self.bs, + self.raw_num_token, + graph_key, + skip_attn_backend_init, + ) + if debug_sync_enabled: + torch.cuda.synchronize() + logger.warning( + "[DSV4 graph debug] target replay pre-sync done graph_key=%s", + graph_key, + ) + if trace_sync_enabled or phase_trace_sync_enabled: + replay_start = torch.cuda.Event(enable_timing=True) + replay_end = torch.cuda.Event(enable_timing=True) + replay_start.record() with ctx: self.graphs[graph_key].replay() + if debug_sync_enabled: + torch.cuda.synchronize() + logger.warning( + "[DSV4 graph debug] target replay post-sync done graph_key=%s", + graph_key, + ) + if trace_sync_enabled or phase_trace_sync_enabled: + assert replay_start is not None and replay_end is not None + replay_end.record() + replay_end.synchronize() + replay_device_us = float(replay_start.elapsed_time(replay_end) * 1000.0) + else: + replay_device_us = None + trace_t3 = time.perf_counter_ns() if trace_enabled else 0 + if trace_enabled: + _write_replay_trace( + { + "event": "cuda_graph_replay", + "forward_mode": forward_batch.forward_mode.name, + "capture_forward_mode": self.capture_forward_mode.name, + "capture_hidden_mode": self.capture_hidden_mode.name, + "skip_attn_backend_init": skip_attn_backend_init, + "raw_bs": int(self.raw_bs), + "bs": int(self.bs), + "raw_num_token": 
int(self.raw_num_token), + "num_tokens_per_bs": int(self.num_tokens_per_bs), + "graph_key": str(graph_key), + "stream_idx": stream_idx, + "variant_label": variant_label, + "entry_ns": trace_t0, + "after_prepare_ns": trace_t1, + "before_launch_ns": trace_t2, + "after_launch_ns": trace_t3, + "prepare_us": (trace_t1 - trace_t0) / 1000.0, + "launch_us": (trace_t3 - trace_t2) / 1000.0, + "entry_to_launch_us": (trace_t2 - trace_t0) / 1000.0, + "replay_device_us": replay_device_us, + "sync_trace": trace_sync_enabled, + } + ) + if phase_trace_sync_enabled: + try: + from sglang.srt.models.dsv4_phase_trace import ( + write_dsv4_phase_trace, + ) + + write_dsv4_phase_trace( + { + "forward_mode": forward_batch.forward_mode.name, + "capture_forward_mode": self.capture_forward_mode.name, + "capture_hidden_mode": self.capture_hidden_mode.name, + "skip_attn_backend_init": skip_attn_backend_init, + "raw_bs": int(self.raw_bs), + "bs": int(self.bs), + "raw_num_token": int(self.raw_num_token), + "num_tokens_per_bs": int(self.num_tokens_per_bs), + "graph_key": str(graph_key), + "stream_idx": stream_idx, + "variant_label": variant_label, + "replay_device_us": replay_device_us, + }, + role_filter={"target"}, + ) + except Exception: # noqa: BLE001 + logger.exception("Failed to write DSv4 CUDA graph phase trace") output = self.output_buffers[graph_key] diff --git a/python/sglang/srt/models/deepseek_v4.py b/python/sglang/srt/models/deepseek_v4.py index 4461a0a27..9f10c1bd5 100644 --- a/python/sglang/srt/models/deepseek_v4.py +++ b/python/sglang/srt/models/deepseek_v4.py @@ -1,7 +1,10 @@ from __future__ import annotations import concurrent.futures +import hashlib import logging +import os +from contextlib import nullcontext from typing import TYPE_CHECKING, Iterable, List, Literal, Optional, Set, Tuple import torch @@ -16,6 +19,7 @@ from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size from sglang.srt.environ import envs from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation +from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.layers.attention.dsv4.compressor import Compressor from sglang.srt.layers.attention.dsv4.indexer import C4Indexer from sglang.srt.layers.attention.nsa.utils import ( @@ -63,10 +67,17 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.dbrx import ReplicatedLinear from sglang.srt.models.deepseek_v2 import ParallelLMHead, _is_cuda, _is_hip, _is_npu +from sglang.srt.models.dsv4_phase_trace import ( + dsv4_phase_trace_context, + is_dsv4_phase_trace_enabled, + record_dsv4_phase_end, + record_dsv4_phase_start, +) from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import ( LazyValue, add_prefix, + get_bool_env_var, log_info_on_rank0, make_layers, ) @@ -75,6 +86,143 @@ logger = logging.getLogger(__name__) _FP8_WO_A_GEMM = envs.SGLANG_OPT_FP8_WO_A_GEMM.get() +_PARALLEL_WO_A = get_bool_env_var("SGLANG_OPT_DSV4_PARALLEL_WO_A") +_DIRECT_WQ_B = get_bool_env_var("SGLANG_OPT_DSV4_DIRECT_WQ_B") +_DSV4_PHASE_TRACE = is_dsv4_phase_trace_enabled() +_DSV4_ROW_TRACE = get_bool_env_var("SGLANG_DSV4_ROW_TRACE") +_DSV4_ROW_TRACE_DECODE_ONLY = get_bool_env_var("SGLANG_DSV4_ROW_TRACE_DECODE_ONLY", "true") +_DSV4_ROW_TRACE_CLUSTERS = get_bool_env_var("SGLANG_DSV4_ROW_TRACE_CLUSTERS") + + +def _int_env(name: str, default: int) -> int: + try: + return int(os.environ.get(name, str(default))) + except ValueError: + return default + + 
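+# Row-trace knobs, each read once at import time (a summary of the env vars
+# consumed below; defaults are visible in the code):
+#   SGLANG_DSV4_ROW_TRACE_MAX_RECORDS    total records to emit; 0 emits nothing,
+#                                        since _DSV4_ROW_TRACE_REMAINING starts at 0
+#   SGLANG_DSV4_ROW_TRACE_MIN_ROWS       skip tensors with fewer rows than this
+#   SGLANG_DSV4_ROW_TRACE_SAMPLE         rows/elements sampled into each record
+#   SGLANG_DSV4_ROW_TRACE_CLUSTER_BYTES  leading bytes hashed per row; 0 hashes the full row
+# Illustrative invocation (hedged; the flag values are examples only):
+#   SGLANG_DSV4_ROW_TRACE=1 SGLANG_DSV4_ROW_TRACE_SAMPLE=8 python -m sglang.launch_server ...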
+_DSV4_ROW_TRACE_MAX_RECORDS = max(0, _int_env("SGLANG_DSV4_ROW_TRACE_MAX_RECORDS", 256)) +_DSV4_ROW_TRACE_MIN_ROWS = max(1, _int_env("SGLANG_DSV4_ROW_TRACE_MIN_ROWS", 2)) +_DSV4_ROW_TRACE_SAMPLE = max(1, _int_env("SGLANG_DSV4_ROW_TRACE_SAMPLE", 16)) +_DSV4_ROW_TRACE_CLUSTER_BYTES = max( + 0, _int_env("SGLANG_DSV4_ROW_TRACE_CLUSTER_BYTES", 0) +) +_DSV4_ROW_TRACE_REMAINING = _DSV4_ROW_TRACE_MAX_RECORDS + + +def _trace_tensor_sample(tensor: Optional[torch.Tensor]) -> Optional[list]: + if tensor is None: + return None + return tensor.detach().flatten()[:_DSV4_ROW_TRACE_SAMPLE].cpu().tolist() + + +def _trace_list_sample(value: Optional[list]) -> Optional[list]: + if value is None: + return None + return list(value[:_DSV4_ROW_TRACE_SAMPLE]) + + +def _trace_row_clusters(tensor: torch.Tensor) -> Optional[dict]: + if not _DSV4_ROW_TRACE_CLUSTERS: + return None + rows = min(int(tensor.shape[0]), _DSV4_ROW_TRACE_SAMPLE) + if rows <= 0: + return None + try: + byte_rows = ( + tensor.detach()[:rows] + .contiguous() + .view(torch.uint8) + .reshape(rows, -1) + .cpu() + ) + byte_columns = int(byte_rows.shape[1]) + hash_bytes = _DSV4_ROW_TRACE_CLUSTER_BYTES or byte_columns + hash_bytes = min(hash_bytes, byte_columns) + groups = {} + row_hashes = [] + for row in range(rows): + row_bytes = byte_rows[row, :hash_bytes].numpy().tobytes() + digest = hashlib.blake2s(row_bytes, digest_size=6).hexdigest() + row_hashes.append(digest) + groups.setdefault(digest, []).append(row) + clusters = [ + {"hash": digest, "rows": members} + for digest, members in sorted(groups.items(), key=lambda item: item[1][0]) + ] + return { + "rows": rows, + "byte_columns": byte_columns, + "hash_bytes": hash_bytes, + "row_hashes": row_hashes, + "clusters": clusters, + "singleton_rows": [ + members[0] for members in groups.values() if len(members) == 1 + ], + } + except Exception as exc: + return {"error": repr(exc)} + + +def _trace_row_delta( + tag: str, + layer_id: int, + tensor: torch.Tensor, + forward_batch: "ForwardBatch", +) -> None: + global _DSV4_ROW_TRACE_REMAINING + if ( + not (_DSV4_ROW_TRACE or _DSV4_ROW_TRACE_CLUSTERS) + or _DSV4_ROW_TRACE_REMAINING <= 0 + or tensor is None + or tensor.shape[0] < _DSV4_ROW_TRACE_MIN_ROWS + ): + return + try: + mode = getattr(forward_batch, "forward_mode", None) + is_decode = mode is not None and hasattr(mode, "is_decode") and mode.is_decode() + if _DSV4_ROW_TRACE_DECODE_ONLY and not is_decode: + return + if get_attention_tp_rank() != 0: + return + flat = tensor.detach().reshape(tensor.shape[0], -1).float() + row_delta = (flat - flat[0:1]).abs().amax(dim=1) + max_abs = float(row_delta.max().item()) + clusters = _trace_row_clusters(tensor) + if max_abs == 0.0 and clusters is None: + return + nonzero_rows = int((row_delta > 0).sum().item()) + nonzero_indices = ( + torch.nonzero(row_delta > 0, as_tuple=False) + .flatten()[:_DSV4_ROW_TRACE_SAMPLE] + .cpu() + .tolist() + ) + logger.warning( + "DSV4_ROW_TRACE layer=%s tag=%s mode=%s shape=%s max_abs=%g " + "nonzero_rows=%d nonzero_indices=%s firstN=%s positions=%s " + "seq_lens=%s out_cache_loc=%s input_ids=%s req_pool_indices=%s " + "rids=%s clusters=%s", + layer_id, + tag, + mode, + tuple(tensor.shape), + max_abs, + nonzero_rows, + nonzero_indices, + [float(x) for x in row_delta[:_DSV4_ROW_TRACE_SAMPLE].tolist()], + _trace_tensor_sample(getattr(forward_batch, "positions", None)), + _trace_tensor_sample(getattr(forward_batch, "seq_lens", None)), + _trace_tensor_sample(getattr(forward_batch, "out_cache_loc", None)), + 
_trace_tensor_sample(getattr(forward_batch, "input_ids", None)), + _trace_tensor_sample(getattr(forward_batch, "req_pool_indices", None)), + _trace_list_sample(getattr(forward_batch, "rids", None)), + clusters, + ) + _DSV4_ROW_TRACE_REMAINING -= 1 + except Exception as exc: + logger.warning("DSV4_ROW_TRACE failed at layer=%s tag=%s: %s", layer_id, tag, exc) + _DSV4_ROW_TRACE_REMAINING = 0 if TYPE_CHECKING: @@ -139,6 +287,47 @@ def rms_normalize_triton( return x +@triton.jit +def _rms_normalize_local_heads_kernel( + x_ptr, + eps, + token_stride: tl.constexpr, + local_heads: tl.constexpr, + head_dim: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + token = pid // local_heads + head = pid - token * local_heads + offs = tl.arange(0, BLOCK_SIZE) + mask = offs < head_dim + base = token * token_stride + head * head_dim + x = tl.load(x_ptr + base + offs, mask=mask, other=0.0).to(tl.float32) + mean_sq = tl.sum(x * x, axis=0) / head_dim + rms_inv = tl.rsqrt(mean_sq + eps) + tl.store(x_ptr + base + offs, x * rms_inv, mask=mask) + + +def rms_normalize_local_heads_triton( + x: torch.Tensor, + eps: float, + *, + local_heads: int, + head_dim: int, +) -> torch.Tensor: + rows = x.shape[0] * local_heads + _rms_normalize_local_heads_kernel[(rows,)]( + x, + eps, + x.stride(0), + local_heads, + head_dim, + BLOCK_SIZE=triton.next_power_of_2(head_dim), + num_warps=8, + ) + return x + + class MQALayer(nn.Module): def __init__( self, @@ -317,6 +506,18 @@ def __init__( self.overlap_store_cache = envs.SGLANG_OPT_USE_OVERLAP_STORE_CACHE.get() self.use_jit_norm = envs.SGLANG_OPT_USE_JIT_NORM.get() + self.parallel_wo_a = ( + _PARALLEL_WO_A + and not _FP8_WO_A_GEMM + and self.n_local_groups == 2 + and self.o_lora_rank == 1024 + ) + self.direct_wq_b = ( + _DIRECT_WQ_B + and self.tp_size > 1 + and self.n_local_heads * self.head_dim == self.wq_b.weight.shape[0] + and self.wq_b.weight.dtype == torch.bfloat16 + ) def _compute_q_a( self, @@ -335,7 +536,36 @@ def _compute_q_b( self, q: torch.Tensor, positions: Optional[torch.Tensor] = None, + q_out: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if self.direct_wq_b and q_out is not None: + q_flat = torch.as_strided( + q_out, + (q.shape[0], self.n_local_heads * self.head_dim), + (q_out.stride(0), 1), + ) + torch.mm(q, self.wq_b.weight.t(), out=q_flat) + q = q_out + rms_normalize_local_heads_triton( + q, + self.eps, + local_heads=self.n_local_heads, + head_dim=self.head_dim, + ) + if positions is not None: + fused_rope( + q[..., -self.qk_rope_head_dim :], + None, + self.freqs_cis, + positions=positions, + ) + else: + apply_rotary_emb_triton( + q[..., -self.qk_rope_head_dim :], + self.freqs_cis, + ) + return q + q, _ = self.wq_b(q) q = q.view(-1, self.n_local_heads, self.head_dim) if self.use_jit_norm: @@ -375,6 +605,26 @@ def _compute_kv( apply_rotary_emb_triton(kv[..., -self.qk_rope_head_dim :], self.freqs_cis) return kv + def _compute_wo_a_bf16_flat(self, o: torch.Tensor) -> torch.Tensor: + G = self.n_local_groups + R = self.o_lora_rank + wo_a = self.wo_a.weight.view(G, R, -1) + out = torch.empty(o.shape[0], G * R, dtype=o.dtype, device=o.device) + if G == 2 and self.alt_streams is not None and len(self.alt_streams) >= 1: + current_stream = torch.cuda.current_stream() + aux_stream = self.alt_streams[0] + aux_stream.wait_stream(current_stream) + torch.mm(o[:, 0, :], wo_a[0].t(), out=out[:, :R]) + with torch.cuda.stream(aux_stream): + torch.mm(o[:, 1, :], wo_a[1].t(), out=out[:, R : 2 * R]) + current_stream.wait_stream(aux_stream) + else: + for group 
in range(G): + start = group * R + stop = start + R + torch.mm(o[:, group, :], wo_a[group].t(), out=out[:, start:stop]) + return out + def _forward_prepare_multi_stream( self, x: torch.Tensor, @@ -433,8 +683,8 @@ def _forward_prepare_multi_stream( x, forward_batch, self.layer_id, self.compressor ) - q = self._compute_q_b(q_lora, positions) - if q_out is not None: + q = self._compute_q_b(q_lora, positions, q_out=q_out) + if q_out is not None and q is not q_out: q_out.copy_(q) current_stream.wait_stream(stream_kv) @@ -461,14 +711,18 @@ def _forward_prepare( q, _ = self.wq_a(x) q = self.q_norm(q) q_lora = q + _trace_row_delta("mqa_q_a_norm", self.layer_id, q, forward_batch) q, _ = self.wq_b(q) q = q.view(-1, self.n_local_heads, self.head_dim) + _trace_row_delta("mqa_q_b", self.layer_id, q, forward_batch) if self.use_jit_norm: q = rmsnorm_self(q, self.eps) else: q = rms_normalize_triton(q, self.eps) + _trace_row_delta("mqa_q_b_norm", self.layer_id, q, forward_batch) kv = self.kv_norm(kv) + _trace_row_delta("mqa_kv_norm", self.layer_id, kv, forward_batch) fused_rope( q[..., -self.qk_rope_head_dim :], @@ -476,6 +730,8 @@ def _forward_prepare( self.freqs_cis, positions=positions, ) + _trace_row_delta("mqa_q_rope", self.layer_id, q, forward_batch) + _trace_row_delta("mqa_kv_rope", self.layer_id, kv, forward_batch) if self.nsa_enable_prefill_cp and nsa_use_prefill_cp(forward_batch): kv = cp_all_gather_rerange_output( @@ -545,6 +801,8 @@ def forward( q, kv = self._forward_prepare( x, positions, forward_batch, attn_backend, q_out ) + _trace_row_delta("mqa_q_pre_backend", self.layer_id, q, forward_batch) + _trace_row_delta("mqa_kv_pre_backend", self.layer_id, kv, forward_batch) o = attn_backend.forward( q=q_padded if q_padded is not None else q, @@ -557,6 +815,7 @@ def forward( save_kv_cache=not self.overlap_store_cache, ) o = o[:, tp_slice, :] + _trace_row_delta("mqa_o_backend", self.layer_id, o, forward_batch) fused_rope( o[..., -self.qk_rope_head_dim :], None, @@ -564,6 +823,7 @@ def forward( positions=positions, inverse=True, ) + _trace_row_delta("mqa_o_inverse_rope", self.layer_id, o, forward_batch) o = o.view(o.shape[0], self.n_local_groups, -1) @@ -586,10 +846,15 @@ def forward( ) o = output else: - wo_a = self.wo_a.weight.view(self.n_local_groups, self.o_lora_rank, -1) - o = torch.einsum("tgd,grd->tgr", o, wo_a) + if self.parallel_wo_a: + o = self._compute_wo_a_bf16_flat(o) + else: + wo_a = self.wo_a.weight.view(self.n_local_groups, self.o_lora_rank, -1) + o = torch.einsum("tgd,grd->tgr", o, wo_a).flatten(1) + _trace_row_delta("mqa_wo_a", self.layer_id, o, forward_batch) - o, _ = self.wo_b(o.flatten(1)) + o, _ = self.wo_b(o.flatten(1) if o.ndim == 3 else o) + _trace_row_delta("mqa_wo_b", self.layer_id, o, forward_batch) return o @@ -646,6 +911,24 @@ def __init__( self.hc_ffn_scale = nn.Parameter(torch.empty(3, dtype=torch.float32)) self.rms_norm_eps = config.rms_norm_eps self.nsa_enable_prefill_cp = is_nsa_enable_prefill_cp() + self.is_nextn = bool(is_nextn) + + def _trace_role(self) -> str: + return "draft" if self.is_nextn else "target" + + def _trace_phase_start(self, phase: str): + if not _DSV4_PHASE_TRACE: + return None + return record_dsv4_phase_start(self._trace_role(), self.layer_id, phase) + + def _trace_phase_end(self, handle) -> None: + if _DSV4_PHASE_TRACE: + record_dsv4_phase_end(handle) + + def _trace_context(self, phase: str): + if not _DSV4_PHASE_TRACE: + return nullcontext() + return dsv4_phase_trace_context(self._trace_role(), self.layer_id, phase) def hc_pre( self, @@ -735,6 
+1018,8 @@ def hc_post( residual: torch.Tensor, post: torch.Tensor, comb: torch.Tensor, + *, + allreduce_before_post: bool = False, ): if x.shape[0] == 0: @@ -742,6 +1027,29 @@ def hc_post( (0, self.hc_mult, x.shape[-1]), dtype=x.dtype, device=x.device ) + if allreduce_before_post: + if get_bool_env_var("B12X_USE_PCIE_AR_MHC_POST"): + try: + from b12x.integration.dsv4_allreduce_patch import ( + b12x_all_reduce_mhc_post, + ) + + out = b12x_all_reduce_mhc_post(x, residual, post, comb) + if out is not None: + return out + except Exception as exc: + logger.warning( + "B12X fused allreduce+MHC-post failed; falling back to " + "unfused path: %s", + exc, + ) + + from sglang.srt.distributed.communication_op import ( + tensor_model_parallel_all_reduce, + ) + + x = tensor_model_parallel_all_reduce(x) + if envs.SGLANG_OPT_USE_TILELANG_MHC_POST.get(): from sglang.srt.layers.mhc import mhc_post @@ -768,7 +1076,9 @@ def forward( forward_batch: ForwardBatch, input_ids_global: torch.Tensor, ) -> torch.Tensor: + _trace_row_delta("layer_input", self.layer_id, hidden_states, forward_batch) residual = hidden_states + trace = self._trace_phase_start("attn_hc_pre") hidden_states, post, comb, norm_fused = self.hc_pre( hidden_states, self.hc_attn_fn, @@ -776,17 +1086,29 @@ def forward( self.hc_attn_base, norm=self.input_layernorm, ) + self._trace_phase_end(trace) if not norm_fused: + trace = self._trace_phase_start("attn_norm") hidden_states = self.input_layernorm(hidden_states) + self._trace_phase_end(trace) + _trace_row_delta("attention_input", self.layer_id, hidden_states, forward_batch) - hidden_states = self.self_attn( - x=hidden_states, - positions=positions, - forward_batch=forward_batch, - ) + trace = self._trace_phase_start("attention") + with self._trace_context("attention"): + hidden_states = self.self_attn( + x=hidden_states, + positions=positions, + forward_batch=forward_batch, + ) + self._trace_phase_end(trace) + _trace_row_delta("attention_out", self.layer_id, hidden_states, forward_batch) + trace = self._trace_phase_start("attn_hc_post") hidden_states = self.hc_post(hidden_states, residual, post, comb) + self._trace_phase_end(trace) + _trace_row_delta("attention_hc_post", self.layer_id, hidden_states, forward_batch) residual = hidden_states + trace = self._trace_phase_start("ffn_hc_pre") hidden_states, post, comb, norm_fused = self.hc_pre( hidden_states, self.hc_ffn_fn, @@ -794,8 +1116,12 @@ def forward( self.hc_ffn_base, norm=self.post_attention_layernorm, ) + self._trace_phase_end(trace) if not norm_fused: + trace = self._trace_phase_start("ffn_norm") hidden_states = self.post_attention_layernorm(hidden_states) + self._trace_phase_end(trace) + _trace_row_delta("moe_input", self.layer_id, hidden_states, forward_batch) _use_cp = self.nsa_enable_prefill_cp and nsa_use_prefill_cp(forward_batch) _use_tp_moe_gather = ( @@ -828,12 +1154,19 @@ def forward( hidden_states = _a2a_scatter_chunks[r].contiguous() input_ids = input_ids.tensor_split(s)[r].contiguous() input_ids_global = input_ids_global.tensor_split(s)[r].contiguous() - hidden_states = self.mlp( - hidden_states, - forward_batch, - input_ids=input_ids, - input_ids_global=input_ids_global, - ) + fuse_ar_mhc_post = get_bool_env_var("B12X_USE_PCIE_AR_MHC_POST") + trace = self._trace_phase_start("moe") + with get_global_expert_distribution_recorder().with_current_layer(self.layer_id): + with self._trace_context("moe"): + hidden_states = self.mlp( + hidden_states, + forward_batch, + should_allreduce_fusion=fuse_ar_mhc_post, + input_ids=input_ids, + 
input_ids_global=input_ids_global, + ) + self._trace_phase_end(trace) + _trace_row_delta("moe_out", self.layer_id, hidden_states, forward_batch) if _use_tp_moe_gather: hidden_states, global_hidden_states = get_local_dp_buffer(), hidden_states dp_scatter(hidden_states, global_hidden_states, forward_batch) @@ -843,7 +1176,17 @@ def forward( attn_tp_all_gather(gathered, hidden_states.contiguous()) hidden_states = torch.cat(gathered) - hidden_states = self.hc_post(hidden_states, residual, post, comb) + trace = self._trace_phase_start("ffn_hc_post") + with self._trace_context("ffn_hc_post"): + hidden_states = self.hc_post( + hidden_states, + residual, + post, + comb, + allreduce_before_post=fuse_ar_mhc_post, + ) + self._trace_phase_end(trace) + _trace_row_delta("layer_out", self.layer_id, hidden_states, forward_batch) return hidden_states @@ -932,6 +1275,7 @@ def forward( ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) hidden_states = hidden_states.unsqueeze(1).repeat(1, self.hc_mult, 1) + _trace_row_delta("model_embed", -1, hidden_states, forward_batch) if get_attention_dp_size() > 1 and get_moe_a2a_backend().is_none(): input_ids_global = torch.empty( @@ -967,11 +1311,13 @@ def forward( ) pre_hc_head = hidden_states.flatten(1) + _trace_row_delta("model_pre_hc_head", -1, pre_hc_head, forward_batch) hidden_states = self.hc_head( hidden_states, self.hc_head_fn, self.hc_head_scale, self.hc_head_base ) hidden_states = self.norm(hidden_states) + _trace_row_delta("model_norm", -1, hidden_states, forward_batch) return hidden_states, pre_hc_head diff --git a/python/sglang/srt/models/deepseek_v4_nextn.py b/python/sglang/srt/models/deepseek_v4_nextn.py index 9b220b184..47dea5ec7 100644 --- a/python/sglang/srt/models/deepseek_v4_nextn.py +++ b/python/sglang/srt/models/deepseek_v4_nextn.py @@ -32,6 +32,18 @@ COMPRESS_RATIO_NEXTN_LAYER = 0 +def _split_nextn_quant_config( + quant_config: Optional[QuantizationConfig], +) -> Tuple[Optional[QuantizationConfig], Optional[QuantizationConfig]]: + if quant_config is not None and quant_config.get_name() == "modelopt_fp4": + logger.warning( + "DeepseekV4 NextN keeps routed experts on modelopt_fp4 while " + "forcing dense draft modules to bf16." 
+ ) + return None, quant_config + return quant_config, None + + class DeepseekV4ModelNextN(nn.Module): def __init__( self, @@ -41,6 +53,10 @@ def __init__( ) -> None: super().__init__() self.config = config + dense_quant_config, moe_quant_config_override = _split_nextn_quant_config( + quant_config + ) + self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( @@ -67,14 +83,14 @@ def __init__( config.hidden_size, config.hidden_size, bias=False, - quant_config=quant_config, + quant_config=dense_quant_config, prefix=add_prefix("e_proj", prefix), ) self.h_proj = ReplicatedLinear( config.hidden_size, config.hidden_size, bias=False, - quant_config=quant_config, + quant_config=dense_quant_config, prefix=add_prefix("h_proj", prefix), ) @@ -83,7 +99,8 @@ def __init__( self.decoder = DeepseekV4DecoderLayer( config, layer_id=0, - quant_config=quant_config, + quant_config=dense_quant_config, + moe_quant_config_override=moe_quant_config_override, is_nextn=True, prefix=add_prefix(layer_name, prefix), alt_streams=None, @@ -177,6 +194,7 @@ def __init__( self.pp_group = get_pp_group() self.quant_config = quant_config self.determine_num_fused_shared_experts() + dense_quant_config, _ = _split_nextn_quant_config(quant_config) self.model = DeepseekV4ModelNextN( config, quant_config, prefix=add_prefix("model", prefix) @@ -184,7 +202,7 @@ def __init__( self.lm_head = ParallelLMHead( config.vocab_size, config.hidden_size, - quant_config=quant_config, + quant_config=dense_quant_config, prefix=add_prefix("model.shared_head.head", prefix), use_attn_tp_group=get_global_server_args().enable_dp_lm_head, ) diff --git a/python/sglang/srt/models/dsv4_phase_trace.py b/python/sglang/srt/models/dsv4_phase_trace.py new file mode 100644 index 000000000..e8d1be0ad --- /dev/null +++ b/python/sglang/srt/models/dsv4_phase_trace.py @@ -0,0 +1,208 @@ +from __future__ import annotations + +import contextlib +import contextvars +import json +import logging +import os +import time +from typing import Any, Dict, Iterator, Optional, Set, Tuple + +import torch + +logger = logging.getLogger(__name__) + +_PHASE_TRACE_PATH = os.environ.get("SGLANG_DSV4_PHASE_TRACE_PATH", "").strip() +_PHASE_TRACE_COUNT = 0 +_PHASE_EVENTS: Dict[Tuple[str, int, str], Tuple[torch.cuda.Event, torch.cuda.Event]] = {} +_PHASE_CONTEXT: contextvars.ContextVar[Optional[Tuple[str, int, str]]] = ( + contextvars.ContextVar("dsv4_phase_trace_context", default=None) +) + + +def is_dsv4_phase_trace_enabled() -> bool: + return bool(_PHASE_TRACE_PATH) + + +def _phase_trace_max_records() -> int: + return int(os.environ.get("SGLANG_DSV4_PHASE_TRACE_MAX_RECORDS", "0") or "0") + + +def _phase_trace_capture_only() -> bool: + return bool(int(os.environ.get("SGLANG_DSV4_PHASE_TRACE_CAPTURE_ONLY", "0") or "0")) + + +def _phase_trace_filter_stale() -> bool: + return bool(int(os.environ.get("SGLANG_DSV4_PHASE_TRACE_FILTER_STALE", "1") or "1")) + + +def _phase_trace_stale_threshold_us(record: Dict[str, Any]) -> Optional[float]: + if not _phase_trace_filter_stale(): + return None + replay_device_us = record.get("replay_device_us") + if replay_device_us is None: + return None + try: + replay_us = float(replay_device_us) + except (TypeError, ValueError): + return None + if replay_us <= 0: + return None + try: + ratio = float(os.environ.get("SGLANG_DSV4_PHASE_TRACE_STALE_RATIO", "1.10")) + except ValueError: + ratio = 1.10 + try: + slack_us = float(os.environ.get("SGLANG_DSV4_PHASE_TRACE_STALE_SLACK_US", "50.0")) + except ValueError: + slack_us = 50.0 + return 
replay_us * max(ratio, 1.0) + max(slack_us, 0.0) + + +def _phase_trace_path(rank: int) -> str: + if "{rank}" in _PHASE_TRACE_PATH or "{pid}" in _PHASE_TRACE_PATH: + return _PHASE_TRACE_PATH.format(rank=rank, pid=os.getpid()) + return f"{_PHASE_TRACE_PATH}.rank{rank}.pid{os.getpid()}.jsonl" + + +def _rank() -> int: + try: + from sglang.srt.distributed import get_tensor_model_parallel_rank + + return get_tensor_model_parallel_rank() + except Exception: # noqa: BLE001 + return int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", "0")) or "0") + + +def _get_or_create_events( + role: str, layer_id: int, phase: str +) -> Optional[Tuple[torch.cuda.Event, torch.cuda.Event]]: + key = (role, int(layer_id), phase) + events = _PHASE_EVENTS.get(key) + if events is not None: + return events + + if torch.cuda.is_current_stream_capturing(): + return None + + events = ( + torch.cuda.Event(enable_timing=True, external=True), + torch.cuda.Event(enable_timing=True, external=True), + ) + _PHASE_EVENTS[key] = events + return events + + +def record_dsv4_phase_start( + role: str, layer_id: int, phase: str +) -> Optional[Tuple[str, int, str]]: + if not _PHASE_TRACE_PATH: + return None + events = _get_or_create_events(role, layer_id, phase) + if events is None: + return None + if _phase_trace_capture_only() and not torch.cuda.is_current_stream_capturing(): + return None + events[0].record() + return role, int(layer_id), phase + + +def record_dsv4_phase_end(handle: Optional[Tuple[str, int, str]]) -> None: + if handle is None: + return + events = _PHASE_EVENTS.get(handle) + if events is None: + return + events[1].record() + + +@contextlib.contextmanager +def dsv4_phase_trace_context( + role: str, layer_id: int, phase: str +) -> Iterator[None]: + if not _PHASE_TRACE_PATH: + yield + return + token = _PHASE_CONTEXT.set((role, int(layer_id), phase)) + try: + yield + finally: + _PHASE_CONTEXT.reset(token) + + +def record_current_dsv4_allreduce_start( + kind: str, input_: torch.Tensor +) -> Optional[Tuple[str, int, str]]: + if not _PHASE_TRACE_PATH: + return None + context = _PHASE_CONTEXT.get() + if context is None: + return None + role, layer_id, phase = context + shape = "x".join(str(int(dim)) for dim in input_.shape) + dtype = str(input_.dtype).replace("torch.", "") + label = f"{phase}.{kind}.{shape}.{dtype}" + return record_dsv4_phase_start(role, layer_id, label) + + +def write_dsv4_phase_trace( + record: Dict[str, Any], role_filter: Optional[Set[str]] = None +) -> None: + if not _PHASE_TRACE_PATH: + return + + global _PHASE_TRACE_COUNT + max_records = _phase_trace_max_records() + if max_records > 0 and _PHASE_TRACE_COUNT >= max_records: + return + _PHASE_TRACE_COUNT += 1 + + phases = [] + dropped_stale_phases = 0 + stale_threshold_us = _phase_trace_stale_threshold_us(record) + for (role, layer_id, phase), (start, end) in sorted(_PHASE_EVENTS.items()): + if role_filter is not None and role not in role_filter: + continue + try: + elapsed_us = float(start.elapsed_time(end) * 1000.0) + except Exception as exc: # noqa: BLE001 + elapsed_us = None + error = repr(exc) + else: + error = None + if ( + stale_threshold_us is not None + and elapsed_us is not None + and elapsed_us > stale_threshold_us + ): + dropped_stale_phases += 1 + continue + phase_record = { + "role": role, + "layer": int(layer_id), + "phase": phase, + "elapsed_us": elapsed_us, + } + if error is not None: + phase_record["error"] = error + phases.append(phase_record) + + rank = _rank() + payload = { + "event": "dsv4_phase_trace", + "pid": os.getpid(), + 
"rank": rank, + "record_index": _PHASE_TRACE_COUNT - 1, + "time": time.time(), + **record, + "phases": phases, + "dropped_stale_phases": dropped_stale_phases, + "stale_phase_threshold_us": stale_threshold_us, + } + path = _phase_trace_path(rank) + try: + os.makedirs(os.path.dirname(path) or ".", exist_ok=True) + with open(path, "a", encoding="utf-8") as handle: + handle.write(json.dumps(payload, sort_keys=True) + "\n") + except Exception: # noqa: BLE001 + logger.exception("Failed to write DSv4 phase trace record") diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py index 804d421b1..bb54f815b 100644 --- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -2,6 +2,8 @@ import bisect import contextlib +import logging +import os from dataclasses import dataclass from typing import TYPE_CHECKING, Callable, Optional @@ -37,6 +39,27 @@ if TYPE_CHECKING: from sglang.srt.speculative.eagle_worker import EAGLEWorker +logger = logging.getLogger(__name__) +_DSV4_DRAFT_GRAPH_PADDING_LOGS = 0 + + +def _dsv4_graph_debug_enabled() -> bool: + return os.environ.get("SGLANG_DSV4_GRAPH_DEBUG", "0") not in ( + "", + "0", + "false", + "False", + ) or _dsv4_graph_debug_sync_enabled() + + +def _dsv4_graph_debug_sync_enabled() -> bool: + return os.environ.get("SGLANG_DSV4_GRAPH_DEBUG_SYNC", "0") not in ( + "", + "0", + "false", + "False", + ) + @dataclass class EagleDraftInputBuffers(ForwardInputBuffers): @@ -225,6 +248,28 @@ def _capture_graph(self, graph, pool, stream, run_once_fn): return out def _replay(self, forward_batch: ForwardBatch): + debug_enabled = _dsv4_graph_debug_enabled() + debug_sync_enabled = _dsv4_graph_debug_sync_enabled() + if debug_enabled: + logger.warning( + "[DSV4 graph debug] eagle draft replay pre raw_bs=%s bs=%s " + "num_tokens_per_bs=%s out_cache_loc=%s positions=%s", + getattr(self, "raw_bs", None), + getattr(self, "bs", None), + self.num_tokens_per_bs, + tuple(forward_batch.out_cache_loc.shape) + if forward_batch.out_cache_loc is not None + else None, + tuple(forward_batch.positions.shape) + if forward_batch.positions is not None + else None, + ) + if debug_sync_enabled: + torch.cuda.synchronize() + logger.warning( + "[DSV4 graph debug] eagle draft replay pre-sync done bs=%s", + getattr(self, "bs", None), + ) ctx = ( self.model_runner.device_timer.wrap(metadata={"category": "eagle_draft"}) if self.model_runner.device_timer @@ -232,6 +277,12 @@ def _replay(self, forward_batch: ForwardBatch): ) with ctx: self.graphs[self.bs].replay() + if debug_sync_enabled: + torch.cuda.synchronize() + logger.warning( + "[DSV4 graph debug] eagle draft replay post-sync done bs=%s", + getattr(self, "bs", None), + ) def capture(self): CudaGraphRunner.capture(self) @@ -434,11 +485,16 @@ def replay(self, forward_batch: ForwardBatch): buffers.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs) # Attention backend + original_out_cache_loc = forward_batch.out_cache_loc + original_seq_lens_sum = forward_batch.seq_lens_sum if bs != raw_bs: forward_batch.batch_size = bs forward_batch.seq_lens = buffers.seq_lens[:bs] forward_batch.req_pool_indices = buffers.req_pool_indices[:bs] forward_batch.positions = buffers.positions[:num_tokens] + forward_batch.out_cache_loc = buffers.out_cache_loc[ + : num_tokens * self.speculative_num_steps + ] if forward_batch.seq_lens_cpu is not None: if bs != raw_bs: @@ -446,12 +502,29 @@ def 
replay(self, forward_batch: ForwardBatch): buffers.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu) forward_batch.seq_lens_cpu = buffers.seq_lens_cpu[:bs] + global _DSV4_DRAFT_GRAPH_PADDING_LOGS + if bs != raw_bs and _DSV4_DRAFT_GRAPH_PADDING_LOGS < 8: + _DSV4_DRAFT_GRAPH_PADDING_LOGS += 1 + logger.info( + "EAGLE draft cuda graph replay padding: raw_bs=%s graph_bs=%s " + "raw_out_cache_loc=%s padded_out_cache_loc=%s steps=%s topk=%s", + raw_bs, + bs, + raw_num_token * self.speculative_num_steps, + num_tokens * self.speculative_num_steps, + self.speculative_num_steps, + self.topk, + ) + + padded_seq_lens_sum = ( + original_seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value + ) + forward_batch.seq_lens_sum = padded_seq_lens_sum self.draft_attn_backend.init_forward_metadata_replay_cuda_graph( forward_batch, bs ) self.raw_bs = raw_bs self.bs = bs - # TODO: The forward_batch.seq_len_sum might need to be updated to reflect the padding in the cuda graph # Replay self._replay(forward_batch) @@ -463,7 +536,9 @@ def replay(self, forward_batch: ForwardBatch): forward_batch.positions = buffers.positions[:raw_num_token] forward_batch.seq_lens = buffers.seq_lens[:raw_bs] forward_batch.req_pool_indices = buffers.req_pool_indices[:raw_bs] + forward_batch.out_cache_loc = original_out_cache_loc if forward_batch.seq_lens_cpu is not None: forward_batch.seq_lens_cpu = buffers.seq_lens_cpu[:raw_bs] + forward_batch.seq_lens_sum = original_seq_lens_sum return out diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py index e8f6f5bfa..26a729bbf 100644 --- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py @@ -2,6 +2,9 @@ import bisect import contextlib +import logging +import os +import time from dataclasses import dataclass from typing import TYPE_CHECKING, Callable, Optional @@ -19,6 +22,7 @@ set_global_graph_memory_pool, set_is_extend_in_batch, set_torch_compile_config, + _write_replay_trace, ) from sglang.srt.model_executor.forward_batch_info import ( CaptureHiddenMode, @@ -39,6 +43,27 @@ from sglang.srt.speculative.eagle_worker import EAGLEWorker +logger = logging.getLogger(__name__) + + +def _dsv4_graph_debug_enabled() -> bool: + return os.environ.get("SGLANG_DSV4_GRAPH_DEBUG", "0") not in ( + "", + "0", + "false", + "False", + ) or _dsv4_graph_debug_sync_enabled() + + +def _dsv4_graph_debug_sync_enabled() -> bool: + return os.environ.get("SGLANG_DSV4_GRAPH_DEBUG_SYNC", "0") not in ( + "", + "0", + "false", + "False", + ) + + @dataclass class EagleDraftExtendInputBuffers(ForwardInputBuffers): input_ids: torch.Tensor @@ -285,6 +310,12 @@ def _capture_graph(self, graph, pool, stream, run_once_fn): return out def _replay(self, forward_batch: ForwardBatch): + trace_enabled = bool(os.environ.get("SGLANG_DSV4_REPLAY_TRACE_PATH", "")) + trace_sync_enabled = trace_enabled and bool( + int(os.environ.get("SGLANG_DSV4_REPLAY_TRACE_SYNC", "0") or "0") + ) + debug_enabled = _dsv4_graph_debug_enabled() + debug_sync_enabled = _dsv4_graph_debug_sync_enabled() ctx = ( self.model_runner.device_timer.wrap( metadata={"category": "eagle_draft_extend"} @@ -292,8 +323,86 @@ def _replay(self, forward_batch: ForwardBatch): if self.model_runner.device_timer else contextlib.nullcontext() ) + trace_t0 = time.perf_counter_ns() if trace_enabled else 0 + replay_start = replay_end = None + if trace_enabled: + _write_replay_trace( + { + 
"event": "eagle_draft_extend_cuda_graph_replay_pre", + "forward_mode": self.forward_mode.name, + "capture_forward_mode": self.forward_mode.name, + "capture_hidden_mode": "LAST", + "raw_bs": int(getattr(self, "raw_bs", forward_batch.batch_size)), + "bs": int(self.bs), + "raw_num_token": int(forward_batch.input_ids.shape[0]), + "num_tokens_per_bs": int(self.num_tokens_per_bs), + "graph_key": str(self.bs), + "stream_idx": None, + "variant_label": None, + "sync_debug": debug_sync_enabled, + } + ) + if debug_enabled: + logger.warning( + "[DSV4 graph debug] eagle draft extend replay pre mode=%s " + "raw_bs=%s bs=%s raw_num_token=%s graph_key=%s", + self.forward_mode.name, + getattr(self, "raw_bs", None), + getattr(self, "bs", None), + int(forward_batch.input_ids.shape[0]), + getattr(self, "bs", None), + ) + if debug_sync_enabled: + torch.cuda.synchronize() + logger.warning( + "[DSV4 graph debug] eagle draft extend replay pre-sync done bs=%s", + getattr(self, "bs", None), + ) + if trace_sync_enabled: + replay_start = torch.cuda.Event(enable_timing=True) + replay_end = torch.cuda.Event(enable_timing=True) + replay_start.record() with ctx: self.graphs[self.bs].replay() + if debug_sync_enabled: + torch.cuda.synchronize() + logger.warning( + "[DSV4 graph debug] eagle draft extend replay post-sync done bs=%s", + getattr(self, "bs", None), + ) + if trace_sync_enabled: + assert replay_start is not None and replay_end is not None + replay_end.record() + replay_end.synchronize() + replay_device_us = float(replay_start.elapsed_time(replay_end) * 1000.0) + else: + replay_device_us = None + trace_t1 = time.perf_counter_ns() if trace_enabled else 0 + if trace_enabled: + _write_replay_trace( + { + "event": "eagle_draft_extend_cuda_graph_replay", + "forward_mode": self.forward_mode.name, + "capture_forward_mode": self.forward_mode.name, + "capture_hidden_mode": "LAST", + "raw_bs": int(getattr(self, "raw_bs", forward_batch.batch_size)), + "bs": int(self.bs), + "raw_num_token": int(forward_batch.input_ids.shape[0]), + "num_tokens_per_bs": int(self.num_tokens_per_bs), + "graph_key": str(self.bs), + "stream_idx": None, + "variant_label": None, + "entry_ns": trace_t0, + "after_prepare_ns": trace_t0, + "before_launch_ns": trace_t0, + "after_launch_ns": trace_t1, + "prepare_us": 0.0, + "launch_us": (trace_t1 - trace_t0) / 1000.0, + "entry_to_launch_us": 0.0, + "replay_device_us": replay_device_us, + "sync_trace": trace_sync_enabled, + } + ) def capture(self): CudaGraphRunner.capture(self) @@ -537,17 +646,21 @@ def replay(self, forward_batch: ForwardBatch): :bs ] - self.draft_extend_attn_backend.init_forward_metadata_replay_cuda_graph( - bs=bs, - req_pool_indices=buffers.req_pool_indices, - seq_lens=buffers.seq_lens, - seq_lens_sum=forward_batch.seq_lens_sum - + (bs - raw_bs) * self.seq_len_fill_value, - encoder_lens=None, - forward_mode=self.forward_mode, - spec_info=forward_batch.spec_info, - seq_lens_cpu=buffers.seq_lens_cpu, - ) + self.draft_extend_attn_backend._replay_forward_batch = forward_batch + try: + self.draft_extend_attn_backend.init_forward_metadata_replay_cuda_graph( + bs=bs, + req_pool_indices=buffers.req_pool_indices, + seq_lens=buffers.seq_lens, + seq_lens_sum=forward_batch.seq_lens_sum + + (bs - raw_bs) * self.seq_len_fill_value, + encoder_lens=None, + forward_mode=self.forward_mode, + spec_info=forward_batch.spec_info, + seq_lens_cpu=buffers.seq_lens_cpu, + ) + finally: + self.draft_extend_attn_backend._replay_forward_batch = None # Replay self.raw_bs = raw_bs diff --git 
a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py index 2d2ac2612..0cd28b696 100644 --- a/python/sglang/srt/speculative/eagle_worker_v2.py +++ b/python/sglang/srt/speculative/eagle_worker_v2.py @@ -1,7 +1,12 @@ import contextlib +import copy +import dataclasses +import json import logging +import os import time -from typing import List, Optional, Tuple +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple import torch @@ -81,6 +86,173 @@ _is_hip = is_hip() logger = logging.getLogger(__name__) +_DSV4_EAGLE_TRACE_COUNT = 0 + + +def _dsv4_eagle_trace_max_records() -> int: + raw = os.environ.get("SGLANG_DSV4_EAGLE_TRACE_MAX_RECORDS", "0") or "0" + try: + return int(raw) + except ValueError: + logger.warning("Ignoring invalid SGLANG_DSV4_EAGLE_TRACE_MAX_RECORDS=%r", raw) + return 0 + + +def _dsv4_eagle_rank() -> str: + if os.environ.get("RANK") or os.environ.get("LOCAL_RANK"): + return os.environ.get("RANK") or os.environ.get("LOCAL_RANK") or "x" + try: + from sglang.srt.distributed import get_tensor_model_parallel_rank + + return str(get_tensor_model_parallel_rank()) + except Exception: # noqa: BLE001 + return "x" + + +def _dsv4_eagle_trace_path() -> Optional[Path]: + raw = os.environ.get("SGLANG_DSV4_EAGLE_TRACE_PATH", "").strip() + if not raw: + return None + rank = _dsv4_eagle_rank() + pid = os.getpid() + if "{rank}" in raw or "{pid}" in raw: + return Path(raw.format(rank=rank, pid=pid)) + path = Path(raw) + if path.suffix == ".jsonl": + return path.with_name(f"{path.stem}.rank{rank}.pid{pid}{path.suffix}") + return path / f"dsv4_eagle_trace.rank{rank}.pid{pid}.jsonl" + + +def _dsv4_eagle_batch_payload(batch: Optional[ModelWorkerBatch]) -> Dict[str, Any]: + if batch is None: + return {"has_batch": False} + seq_lens = getattr(batch, "seq_lens", None) + forward_mode = getattr(batch, "forward_mode", None) + spec_info = getattr(batch, "spec_info", None) + return { + "has_batch": True, + "forward_mode": str(getattr(forward_mode, "name", forward_mode)), + "batch_size": len(seq_lens) if seq_lens is not None else None, + "is_extend_in_batch": bool(getattr(batch, "is_extend_in_batch", False)), + "has_grammar": bool(getattr(batch, "has_grammar", False)), + "return_logprob": bool(getattr(batch, "return_logprob", False)), + "spec_info_type": type(spec_info).__name__ if spec_info is not None else None, + } + + +def _dsv4_eagle_generation_payload( + result: Optional[GenerationBatchResult], + draft_tokens_per_req: Optional[int], +) -> Dict[str, Any]: + if result is None: + return {} + + payload: Dict[str, Any] = { + "num_accepted_drafts": int(getattr(result, "num_accepted_drafts", 0) or 0), + } + accept_lens = getattr(result, "accept_lens", None) + if accept_lens is None: + per_req = getattr(result, "num_accepted_drafts_per_req_cpu", None) + if per_req is None: + return payload + accept_lens_list = [int(item) + 1 for item in per_req] + elif torch.is_tensor(accept_lens): + accept_lens_list = [int(item) for item in accept_lens.detach().cpu().tolist()] + else: + accept_lens_list = [int(item) for item in accept_lens] + + if not accept_lens_list: + return payload + + batch_size = len(accept_lens_list) + accepted_token_rows = sum(accept_lens_list) + accepted_draft_rows = sum(max(item - 1, 0) for item in accept_lens_list) + payload.update( + { + "accept_lens_sum": accepted_token_rows, + "accept_lens_min": min(accept_lens_list), + "accept_lens_max": max(accept_lens_list), + "accept_lens_mean": accepted_token_rows / batch_size, + 
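_dsv4_eagle_trace_path accepts three spellings of SGLANG_DSV4_EAGLE_TRACE_PATH: an explicit {rank}/{pid} template, a .jsonl file name, or a directory. A condensed restatement of that branching with illustrative rank and pid values (the helper itself reads them from the environment and distributed state):

from pathlib import Path

def resolve(raw: str, rank: str = "0", pid: int = 1234) -> Path:
    # Condensed restatement of _dsv4_eagle_trace_path's branching.
    if "{rank}" in raw or "{pid}" in raw:
        return Path(raw.format(rank=rank, pid=pid))  # explicit template
    path = Path(raw)
    if path.suffix == ".jsonl":
        return path.with_name(f"{path.stem}.rank{rank}.pid{pid}{path.suffix}")
    return path / f"dsv4_eagle_trace.rank{rank}.pid{pid}.jsonl"

assert str(resolve("/tmp/t-{rank}.jsonl")) == "/tmp/t-0.jsonl"
assert str(resolve("/tmp/trace.jsonl")) == "/tmp/trace.rank0.pid1234.jsonl"
assert str(resolve("/tmp/traces")) == "/tmp/traces/dsv4_eagle_trace.rank0.pid1234.jsonl"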
"accept_lens_sample": accept_lens_list[:64], + "accepted_draft_rows": accepted_draft_rows, + } + ) + + if draft_tokens_per_req is not None: + planned_rows = batch_size * int(draft_tokens_per_req) + selected_state_rows = batch_size + payload.update( + { + "draft_extend_planned_rows": planned_rows, + "draft_extend_accepted_token_rows": accepted_token_rows, + "draft_extend_selected_state_rows": selected_state_rows, + "draft_extend_wasted_rows_vs_accept": max( + planned_rows - accepted_token_rows, 0 + ), + "draft_extend_wasted_rows_vs_selected": max( + planned_rows - selected_state_rows, 0 + ), + "draft_extend_useful_fraction_by_accept": ( + accepted_token_rows / planned_rows if planned_rows else None + ), + "draft_extend_selected_fraction": ( + selected_state_rows / planned_rows if planned_rows else None + ), + } + ) + return payload + + +def _trace_dsv4_eagle_event( + worker: "EagleWorker", + event: str, + *, + batch: Optional[ModelWorkerBatch] = None, + batch_result: Optional[GenerationBatchResult] = None, + elapsed_ms: Optional[float] = None, + **extra: Any, +) -> None: + path = _dsv4_eagle_trace_path() + if path is None: + return + + global _DSV4_EAGLE_TRACE_COUNT + max_records = _dsv4_eagle_trace_max_records() + if max_records > 0 and _DSV4_EAGLE_TRACE_COUNT >= max_records: + return + record_index = _DSV4_EAGLE_TRACE_COUNT + _DSV4_EAGLE_TRACE_COUNT += 1 + + payload = { + "event": event, + "pid": os.getpid(), + "rank": _dsv4_eagle_rank(), + "record_index": record_index, + "time": time.time(), + "monotonic_ns": time.monotonic_ns(), + "speculative_num_steps": getattr(worker, "speculative_num_steps", None), + "speculative_num_draft_tokens": getattr( + worker, "speculative_num_draft_tokens", None + ), + "topk": getattr(worker, "topk", None), + "plan_stream": bool(getattr(worker, "plan_stream", None)), + } + if elapsed_ms is not None: + payload["elapsed_ms"] = float(elapsed_ms) + payload.update(_dsv4_eagle_batch_payload(batch)) + payload.update( + _dsv4_eagle_generation_payload( + batch_result, payload.get("speculative_num_draft_tokens") + ) + ) + payload.update(extra) + + try: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("a", encoding="utf-8") as handle: + handle.write(json.dumps(payload, sort_keys=True) + "\n") + except Exception: # noqa: BLE001 + logger.exception("Failed to write DSv4 EAGLE trace record") def _get_spec_debug_token_ids( @@ -107,6 +279,644 @@ def _get_spec_debug_token_ids( return draft_token_ids, predict_token_ids, accept_lens +def _env_flag(name: str) -> bool: + return (os.environ.get(name, "0") or "0").strip().lower() in { + "1", + "true", + "yes", + "y", + } + + +def _env_int(name: str, default: int) -> int: + raw = os.environ.get(name, str(default)) or str(default) + try: + return int(raw) + except ValueError: + logger.warning("Ignoring invalid %s=%r", name, raw) + return default + + +def _dsv4_eagle_quality_payload( + verify_input: EagleVerifyInput, + logits_output: Any, + predict: torch.Tensor, + accept_lens: torch.Tensor, + bs: int, +) -> Dict[str, Any]: + if not _env_flag("SGLANG_DSV4_EAGLE_QUALITY_TRACE"): + return {} + if bs <= 0 or verify_input.draft_token_num <= 1: + return {} + + max_reqs = max( + 0, min(bs, _env_int("SGLANG_DSV4_EAGLE_QUALITY_TRACE_MAX_REQS", 8)) + ) + if max_reqs == 0: + return {} + + try: + draft_num = int(verify_input.draft_token_num) + logits = logits_output.next_token_logits + if logits is None or logits.numel() == 0: + return {} + + candidates = verify_input.draft_token.reshape(bs, draft_num) + logits_view = 
logits.reshape(bs, draft_num, -1) + target_predict = torch.argmax(logits, dim=-1).to(torch.int64).reshape( + bs, draft_num + ) + + # In EAGLE v2 top-k=1, column 0 is the current/bonus token. Draft + # acceptance compares target predictions at columns 0..N-2 against + # draft candidates at columns 1..N-1. + shifted_matches = target_predict[:, :-1] == candidates[:, 1:].to(torch.int64) + shifted_matches_cpu = shifted_matches[:max_reqs].detach().cpu().tolist() + accept_lens_cpu = accept_lens[:max_reqs].detach().cpu().tolist() + + cand_tokens = candidates[:, 1:].to(torch.long) + cand_logits = torch.gather( + logits_view[:, :-1, :], + dim=-1, + index=cand_tokens.unsqueeze(-1), + ).squeeze(-1) + target_logits = torch.gather( + logits_view[:, :-1, :], + dim=-1, + index=target_predict[:, :-1].unsqueeze(-1), + ).squeeze(-1) + margins = (target_logits - cand_logits)[:max_reqs].detach().float().cpu() + + rows = [] + prefix_hist: Dict[str, int] = {} + first_mismatch_margins = [] + target_predict_cpu = target_predict[:max_reqs].detach().cpu().tolist() + candidates_cpu = candidates[:max_reqs].detach().cpu().tolist() + predict_cpu = predict.reshape(bs, draft_num)[:max_reqs].detach().cpu().tolist() + + for row_idx, matches in enumerate(shifted_matches_cpu): + prefix = 0 + for matched in matches: + if not matched: + break + prefix += 1 + prefix_hist[str(prefix)] = prefix_hist.get(str(prefix), 0) + 1 + + first_mismatch = None + if prefix < draft_num - 1: + margin = float(margins[row_idx, prefix].item()) + first_mismatch_margins.append(margin) + first_mismatch = { + "step": prefix, + "candidate_token": int(candidates_cpu[row_idx][prefix + 1]), + "target_token": int(target_predict_cpu[row_idx][prefix]), + "target_minus_candidate_logit": margin, + } + + rows.append( + { + "row": row_idx, + "accept_len": int(accept_lens_cpu[row_idx]), + "accepted_draft_tokens": int(accept_lens_cpu[row_idx]) - 1, + "shifted_prefix_matches": prefix, + "draft_tokens": [int(x) for x in candidates_cpu[row_idx]], + "target_predict_tokens": [ + int(x) for x in target_predict_cpu[row_idx] + ], + "predict_tokens": [int(x) for x in predict_cpu[row_idx]], + "first_mismatch": first_mismatch, + } + ) + + payload: Dict[str, Any] = { + "quality_trace_enabled": True, + "quality_trace_rows": rows, + "quality_shifted_prefix_hist": prefix_hist, + } + if first_mismatch_margins: + payload["quality_first_mismatch_margin_mean"] = sum( + first_mismatch_margins + ) / len(first_mismatch_margins) + payload["quality_first_mismatch_margin_max"] = max( + first_mismatch_margins + ) + return payload + except Exception as exc: # noqa: BLE001 + logger.warning("Failed to build DSv4 EAGLE quality trace: %s", exc) + return {"quality_trace_error": str(exc)} + + +def _tensor_row_stats_payload( + name: str, + tensor: Optional[torch.Tensor], + max_rows: int, +) -> Dict[str, Any]: + if tensor is None: + return {f"{name}_present": False} + try: + rows = min(max_rows, int(tensor.shape[0])) if tensor.ndim > 0 else 0 + if rows <= 0: + return { + f"{name}_present": True, + f"{name}_shape": list(tensor.shape), + f"{name}_dtype": str(tensor.dtype), + f"{name}_rows_sampled": 0, + } + sample = tensor[:rows].detach().float() + flat = sample.reshape(rows, -1) + row_norm = torch.linalg.vector_norm(flat, dim=1) + row_sum = flat.sum(dim=1) + return { + f"{name}_present": True, + f"{name}_shape": list(tensor.shape), + f"{name}_dtype": str(tensor.dtype), + f"{name}_rows_sampled": rows, + f"{name}_finite": bool(torch.isfinite(flat).all().item()), + f"{name}_abs_max": 
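The column shift described in the quality-trace comment above is easy to get wrong, so here is a worked example with draft_num = 4 and invented token ids; the prefix loop mirrors the payload's per-row scan:

import torch

# draft_num = 4: column 0 is the bonus/current token, columns 1..3 are drafts.
candidates = torch.tensor([[11, 21, 31, 41]])      # draft proposals
target_predict = torch.tensor([[21, 31, 99, 55]])  # target argmax per position

# Target prediction at column i is compared against the draft at column i+1.
shifted = target_predict[:, :-1] == candidates[:, 1:]
assert shifted.tolist() == [[True, True, False]]

# Longest accepted prefix of drafts: 2 (tokens 21 and 31). The mismatch at
# step 2 means candidate 41 was rejected in favor of target token 99, which
# is exactly the pair reported in the first_mismatch record.
prefix = 0
for matched in shifted[0].tolist():
    if not matched:
        break
    prefix += 1
assert prefix == 2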
float(flat.abs().max().item()), + f"{name}_mean": float(flat.mean().item()), + f"{name}_row_norm_sample": [ + float(item) for item in row_norm.detach().cpu().tolist() + ], + f"{name}_row_sum_sample": [ + float(item) for item in row_sum.detach().cpu().tolist() + ], + } + except Exception as exc: # noqa: BLE001 + logger.warning("Failed to build DSv4 EAGLE %s tensor stats: %s", name, exc) + return {f"{name}_trace_error": str(exc)} + + +def _slice_optional_tensor( + value: Optional[torch.Tensor], + start: int, + stop: int, +) -> Optional[torch.Tensor]: + if value is None: + return None + if not torch.is_tensor(value) or value.ndim == 0: + return value + return value[start:stop] + + +def _slice_optional_list( + value: Optional[List[Any]], + start: int, + stop: int, +) -> Optional[List[Any]]: + if value is None: + return None + return list(value[start:stop]) + + +def _single_global_token_counts( + value: Optional[List[int]], + tokens_per_req: int, +) -> Optional[List[int]]: + if value is None: + return None + count_len = max(1, len(value)) + return [int(tokens_per_req)] * count_len + + +def _single_global_token_counts_tensor( + value: Optional[torch.Tensor], + tokens_per_req: int, +) -> Optional[torch.Tensor]: + if value is None: + return None + if not torch.is_tensor(value) or value.ndim == 0: + return value + return torch.full_like(value, int(tokens_per_req)) + + +def _slice_token_or_row_tensor( + value: Optional[torch.Tensor], + *, + row: int, + token_start: int, + token_stop: int, + batch_size: int, +) -> Optional[torch.Tensor]: + if value is None: + return None + if not torch.is_tensor(value) or value.ndim == 0: + return value + if int(value.shape[0]) == int(batch_size): + return value[row : row + 1] + return value[token_start:token_stop] + + +def _slice_optional_mrope_positions( + value: Optional[torch.Tensor], + start: int, + stop: int, +) -> Optional[torch.Tensor]: + if value is None: + return None + if not torch.is_tensor(value) or value.ndim < 2: + return value + return value[:, start:stop] + + +def _single_extend_start_loc( + value: Optional[torch.Tensor], + row: int, +) -> Optional[torch.Tensor]: + if value is None: + return None + if not torch.is_tensor(value) or value.ndim == 0: + return value + return torch.zeros_like(value[row : row + 1]) + + +def _logit_pair_payload( + name: str, + lhs: torch.Tensor, + rhs: torch.Tensor, + *, + topk: int, +) -> Dict[str, Any]: + """Summarize row-wise logit agreement without dumping full vocab tensors.""" + rows = min(int(lhs.shape[0]), int(rhs.shape[0])) + if rows <= 0: + return {f"{name}_rows": 0} + + k = max(1, min(int(topk), int(lhs.shape[-1]), int(rhs.shape[-1]))) + lhs_f = lhs[:rows].detach().float() + rhs_f = rhs[:rows].detach().float() + lhs_vals, lhs_idx = torch.topk(lhs_f, k=k, dim=-1) + rhs_vals, rhs_idx = torch.topk(rhs_f, k=k, dim=-1) + diff = (lhs_f - rhs_f).abs() + + topk_overlap = [] + for row in range(rows): + lhs_set = set(int(item) for item in lhs_idx[row].detach().cpu().tolist()) + rhs_set = set(int(item) for item in rhs_idx[row].detach().cpu().tolist()) + topk_overlap.append(len(lhs_set & rhs_set)) + + return { + f"{name}_rows": rows, + f"{name}_top1_equal_sample": [ + bool(item) + for item in (lhs_idx[:, 0] == rhs_idx[:, 0]).detach().cpu().tolist() + ], + f"{name}_topk_overlap_sample": topk_overlap, + f"{name}_lhs_topk_token_sample": [ + [int(item) for item in row] for row in lhs_idx.detach().cpu().tolist() + ], + f"{name}_rhs_topk_token_sample": [ + [int(item) for item in row] for row in rhs_idx.detach().cpu().tolist() + ], 
+ f"{name}_lhs_topk_logit_sample": [ + [float(item) for item in row] for row in lhs_vals.detach().cpu().tolist() + ], + f"{name}_rhs_topk_logit_sample": [ + [float(item) for item in row] for row in rhs_vals.detach().cpu().tolist() + ], + f"{name}_max_abs_sample": [ + float(item) for item in diff.amax(dim=-1).detach().cpu().tolist() + ], + f"{name}_mean_abs_sample": [ + float(item) for item in diff.mean(dim=-1).detach().cpu().tolist() + ], + } + + +def _hidden_pair_payload( + name: str, + lhs: torch.Tensor, + rhs: torch.Tensor, +) -> Dict[str, Any]: + rows = min(int(lhs.shape[0]), int(rhs.shape[0])) + if rows <= 0: + return {f"{name}_rows": 0} + lhs_f = lhs[:rows].detach().float().reshape(rows, -1) + rhs_f = rhs[:rows].detach().float().reshape(rows, -1) + diff = (lhs_f - rhs_f).abs() + dot = (lhs_f * rhs_f).sum(dim=1) + denom = torch.linalg.vector_norm(lhs_f, dim=1) * torch.linalg.vector_norm( + rhs_f, dim=1 + ) + cosine = dot / torch.clamp(denom, min=1.0e-30) + return { + f"{name}_rows": rows, + f"{name}_cosine_sample": [ + float(item) for item in cosine.detach().cpu().tolist() + ], + f"{name}_max_abs_sample": [ + float(item) for item in diff.amax(dim=1).detach().cpu().tolist() + ], + f"{name}_mean_abs_sample": [ + float(item) for item in diff.mean(dim=1).detach().cpu().tolist() + ], + } + + +def _build_single_row_forward_batch( + forward_batch: ForwardBatch, + *, + row: int, + tokens_per_req: int, +) -> ForwardBatch: + token_start = int(row) * int(tokens_per_req) + token_stop = token_start + int(tokens_per_req) + + spec_info = copy.copy(forward_batch.spec_info) + spec_info.hidden_states = _slice_optional_tensor( + forward_batch.spec_info.hidden_states, token_start, token_stop + ) + spec_info.num_accepted_drafts = _slice_optional_tensor( + getattr(forward_batch.spec_info, "num_accepted_drafts", None), row, row + 1 + ) + spec_info.num_accepted_tokens = _slice_optional_tensor( + getattr(forward_batch.spec_info, "num_accepted_tokens", None), row, row + 1 + ) + + seq_lens_cpu = getattr(forward_batch, "seq_lens_cpu", None) + if torch.is_tensor(seq_lens_cpu): + seq_lens_cpu = seq_lens_cpu[row : row + 1] + elif seq_lens_cpu is not None: + seq_lens_cpu = list(seq_lens_cpu[row : row + 1]) + + return dataclasses.replace( + forward_batch, + batch_size=1, + input_ids=_slice_optional_tensor( + forward_batch.input_ids, token_start, token_stop + ), + req_pool_indices=_slice_optional_tensor( + forward_batch.req_pool_indices, row, row + 1 + ), + seq_lens=_slice_optional_tensor(forward_batch.seq_lens, row, row + 1), + out_cache_loc=_slice_optional_tensor( + forward_batch.out_cache_loc, token_start, token_stop + ), + out_cache_loc_swa=_slice_optional_tensor( + forward_batch.out_cache_loc_swa, token_start, token_stop + ), + seq_lens_sum=int(forward_batch.seq_lens[row].detach().cpu().item()) + if torch.is_tensor(forward_batch.seq_lens) + else int(forward_batch.seq_lens_sum), + seq_lens_cpu=seq_lens_cpu, + positions=_slice_optional_tensor( + forward_batch.positions, token_start, token_stop + ), + mrope_positions=_slice_optional_mrope_positions( + forward_batch.mrope_positions, token_start, token_stop + ), + next_token_logits_buffer=_slice_token_or_row_tensor( + forward_batch.next_token_logits_buffer, + row=row, + token_start=token_start, + token_stop=token_stop, + batch_size=forward_batch.batch_size, + ), + extend_num_tokens=int(tokens_per_req), + extend_seq_lens=_slice_optional_tensor( + forward_batch.extend_seq_lens, row, row + 1 + ), + extend_prefix_lens=_slice_optional_tensor( + 
forward_batch.extend_prefix_lens, row, row + 1 + ), + extend_start_loc=_single_extend_start_loc(forward_batch.extend_start_loc, row), + extend_prefix_lens_cpu=_slice_optional_list( + forward_batch.extend_prefix_lens_cpu, row, row + 1 + ), + extend_seq_lens_cpu=_slice_optional_list( + forward_batch.extend_seq_lens_cpu, row, row + 1 + ), + original_global_num_tokens_cpu=_single_global_token_counts( + forward_batch.original_global_num_tokens_cpu, tokens_per_req + ), + global_num_tokens_cpu=_single_global_token_counts( + forward_batch.global_num_tokens_cpu, tokens_per_req + ), + global_num_tokens_for_logprob_cpu=_single_global_token_counts( + forward_batch.global_num_tokens_for_logprob_cpu, tokens_per_req + ), + global_num_tokens_gpu=_single_global_token_counts_tensor( + forward_batch.global_num_tokens_gpu, tokens_per_req + ), + global_num_tokens_for_logprob_gpu=_single_global_token_counts_tensor( + forward_batch.global_num_tokens_for_logprob_gpu, tokens_per_req + ), + num_token_non_padded=( + torch.tensor( + int(tokens_per_req), + dtype=forward_batch.num_token_non_padded.dtype, + device=forward_batch.num_token_non_padded.device, + ) + if torch.is_tensor(forward_batch.num_token_non_padded) + else forward_batch.num_token_non_padded + ), + num_token_non_padded_cpu=int(tokens_per_req), + rids=_slice_optional_list(forward_batch.rids, row, row + 1), + global_dp_buffer_len=( + int(tokens_per_req) * len(forward_batch.global_num_tokens_cpu) + if forward_batch.global_num_tokens_cpu is not None + else forward_batch.global_dp_buffer_len + ), + dp_local_start_pos=None, + dp_local_num_tokens=None, + spec_info=spec_info, + ) + + +def _dsv4_eagle_recompute_payload( + worker: "EagleDraftWorker", + forward_batch: ForwardBatch, + draft_logits_output: Any, + select_index: torch.Tensor, + *, + rows: int, +) -> Dict[str, Any]: + if not _env_flag("SGLANG_DSV4_EAGLE_RECOMPUTE_TRACE"): + return {} + if rows <= 0: + return {} + + topk = max(1, _env_int("SGLANG_DSV4_EAGLE_RECOMPUTE_TOPK", 5)) + max_rows = max(1, _env_int("SGLANG_DSV4_EAGLE_RECOMPUTE_MAX_ROWS", 1)) + rows = min(rows, max_rows) + tokens_per_req = int(worker.speculative_num_draft_tokens) + + payload: Dict[str, Any] = { + "recompute_trace_enabled": True, + "recompute_rows_requested": rows, + } + live_logits_all = getattr(draft_logits_output, "next_token_logits", None) + live_hidden_all = getattr(draft_logits_output, "hidden_states", None) + if live_logits_all is None or live_hidden_all is None: + payload["recompute_trace_error"] = "missing_live_logits_or_hidden" + return payload + + row_payloads = [] + for row in range(rows): + try: + single = _build_single_row_forward_batch( + forward_batch, row=row, tokens_per_req=tokens_per_req + ) + attn_backend = worker.draft_runner.attn_backend + previous_metadata = getattr(attn_backend, "forward_metadata", None) + single.attn_backend = attn_backend + # Rebuild metadata for the sliced request. This is trace-only and + # intentionally eager so it can compare against the live batched + # draft-extend output without relying on graph replay. The hidden + # flag makes DSv4 attention skip KV stores, so this read-only replay + # cannot poison later speculative steps. 
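_build_single_row_forward_batch addresses two tensor layouts: token-major tensors sliced by [row * tokens_per_req, (row + 1) * tokens_per_req) and request-major tensors sliced by row. A small sketch of that addressing and of the shape-based dispatch in _slice_token_or_row_tensor, with illustrative sizes:

import torch

batch_size, tokens_per_req, row = 4, 3, 2
token_start = row * tokens_per_req           # 6
token_stop = token_start + tokens_per_req    # 9

per_token = torch.arange(batch_size * tokens_per_req)  # e.g. input_ids, positions
per_request = torch.arange(batch_size)                 # e.g. seq_lens, req_pool_indices

# Token-major tensors are sliced by token range, request-major by row:
assert per_token[token_start:token_stop].tolist() == [6, 7, 8]
assert per_request[row : row + 1].tolist() == [2]

# _slice_token_or_row_tensor disambiguates by comparing dim 0 to batch_size:
def slice_auto(t: torch.Tensor) -> torch.Tensor:
    if int(t.shape[0]) == batch_size:
        return t[row : row + 1]
    return t[token_start:token_stop]

assert slice_auto(per_request).tolist() == [2]
assert slice_auto(per_token).tolist() == [6, 7, 8]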
+ setattr(single, "_dsv4_eagle_recompute_no_kv_store", True) + try: + attn_backend.init_forward_metadata(single) + replay = worker.draft_runner.forward( + single, skip_attn_backend_init=True + ).logits_output + finally: + if hasattr(attn_backend, "forward_metadata"): + attn_backend.forward_metadata = previous_metadata + + local_index = ( + int(select_index[row].detach().cpu().item()) - row * tokens_per_req + ) + if local_index < 0 or local_index >= tokens_per_req: + row_payloads.append( + { + "row": row, + "error": f"local_index_out_of_range:{local_index}", + } + ) + continue + + live_index = int(select_index[row].detach().cpu().item()) + row_payload: Dict[str, Any] = { + "row": row, + "live_index": live_index, + "local_index": local_index, + } + row_payload.update( + _logit_pair_payload( + "selected_logit", + live_logits_all[live_index : live_index + 1], + replay.next_token_logits[local_index : local_index + 1], + topk=topk, + ) + ) + row_payload.update( + _hidden_pair_payload( + "selected_hidden", + live_hidden_all[live_index : live_index + 1], + replay.hidden_states[local_index : local_index + 1], + ) + ) + row_payloads.append(row_payload) + except Exception as exc: # noqa: BLE001 + logger.warning("DSv4 EAGLE recompute trace failed for row %s: %s", row, exc) + row_payloads.append({"row": row, "error": repr(exc)}) + + payload["recompute_rows"] = row_payloads + return payload + + +def _dsv4_eagle_handoff_payload( + *, + batch: ModelWorkerBatch, + batch_result: GenerationBatchResult, + draft_logits_output: Any, + select_index: torch.Tensor, + seq_lens_before_extend: Optional[torch.Tensor], + can_cuda_graph: bool, + max_reqs: int, +) -> Dict[str, Any]: + if not _env_flag("SGLANG_DSV4_EAGLE_HANDOFF_TRACE"): + return {} + if max_reqs <= 0: + return {} + + try: + bs = int(select_index.numel()) + rows = min(max_reqs, bs) + payload: Dict[str, Any] = { + "handoff_trace_enabled": True, + "handoff_can_cuda_graph": bool(can_cuda_graph), + "handoff_rows_sampled": rows, + } + if seq_lens_before_extend is not None: + payload["handoff_seq_lens_before_sample"] = [ + int(item) for item in seq_lens_before_extend[:rows].detach().cpu().tolist() + ] + payload["handoff_seq_lens_after_sample"] = [ + int(item) for item in batch.seq_lens[:rows].detach().cpu().tolist() + ] + payload["handoff_accept_lens_sample"] = [ + int(item) for item in batch_result.accept_lens[:rows].detach().cpu().tolist() + ] + payload["handoff_select_index_sample"] = [ + int(item) for item in select_index[:rows].detach().cpu().tolist() + ] + + draft_tokens = getattr(batch_result, "next_token_ids", None) + if draft_tokens is not None and draft_tokens.numel() >= bs: + token_grid = draft_tokens.reshape(bs, -1) + payload["handoff_next_token_grid_shape"] = list(token_grid.shape) + payload["handoff_next_token_grid_sample"] = [ + [int(tok) for tok in row] + for row in token_grid[:rows].detach().cpu().tolist() + ] + selected_pos = (batch_result.accept_lens[:rows] - 1).to(torch.long) + payload["handoff_selected_token_sample"] = [ + int(item) + for item in token_grid[:rows] + .gather(1, selected_pos.unsqueeze(1)) + .squeeze(1) + .detach() + .cpu() + .tolist() + ] + + target_hidden = getattr(batch_result.logits_output, "hidden_states", None) + draft_hidden_all = getattr(draft_logits_output, "hidden_states", None) + selected_target_hidden = ( + target_hidden[select_index] + if target_hidden is not None + and target_hidden.numel() > 0 + and int(target_hidden.shape[0]) > int(select_index.max().item()) + else None + ) + selected_hidden = ( + 
draft_hidden_all[select_index] + if draft_hidden_all is not None and draft_hidden_all.numel() > 0 + else None + ) + payload.update(_tensor_row_stats_payload("handoff_target_hidden", target_hidden, rows)) + payload.update( + _tensor_row_stats_payload( + "handoff_selected_target_hidden", selected_target_hidden, rows + ) + ) + payload.update( + _tensor_row_stats_payload("handoff_selected_draft_hidden", selected_hidden, rows) + ) + if ( + selected_target_hidden is not None + and selected_hidden is not None + and selected_target_hidden.ndim == selected_hidden.ndim + and selected_target_hidden.shape[-1] == selected_hidden.shape[-1] + ): + lhs = selected_target_hidden[:rows].detach().float().reshape(rows, -1) + rhs = selected_hidden[:rows].detach().float().reshape(rows, -1) + dot = (lhs * rhs).sum(dim=1) + denom = torch.linalg.vector_norm(lhs, dim=1) * torch.linalg.vector_norm( + rhs, dim=1 + ) + cosine = dot / torch.clamp(denom, min=1.0e-30) + payload["handoff_selected_target_draft_hidden_cos_sample"] = [ + float(item) for item in cosine.detach().cpu().tolist() + ] + return payload + except Exception as exc: # noqa: BLE001 + logger.warning("Failed to build DSv4 EAGLE handoff trace: %s", exc) + return {"handoff_trace_error": str(exc)} + + def _get_plan_stream( device: str, ) -> Tuple[any, contextlib.AbstractContextManager]: @@ -336,6 +1146,18 @@ def init_cuda_graphs(self): isinstance(self.draft_extend_attn_backend, TritonAttnBackend) or isinstance(self.draft_extend_attn_backend, TRTLLMMLABackend) ) + if ( + (_is_cuda or _is_musa) + and not supports_cuda_draft_extend_graph + and _env_flag("SGLANG_DSV4_ENABLE_DRAFT_EXTEND_CUDA_GRAPH") + ): + from sglang.srt.layers.attention.deepseek_v4_backend import ( + DeepseekV4AttnBackend, + ) + + supports_cuda_draft_extend_graph = isinstance( + self.draft_extend_attn_backend, DeepseekV4AttnBackend + ) # Capture extend # TODO: support draft extend cuda graph for more attention backends if self.draft_extend_attn_backend and ( @@ -590,6 +1412,11 @@ def _draft_extend_for_decode( self, batch: ModelWorkerBatch, batch_result: GenerationBatchResult ): # Batch 2: Draft extend + seq_lens_before_extend = ( + batch.seq_lens.detach().clone() + if _env_flag("SGLANG_DSV4_EAGLE_HANDOFF_TRACE") + else None + ) draft_input = EagleDraftInput( hidden_states=batch_result.logits_output.hidden_states, num_tokens_per_req=self.speculative_num_steps + 1, @@ -641,6 +1468,33 @@ def _draft_extend_for_decode( draft_logits_output.next_token_logits, f"draft_extend_for_decode (cuda_graph={can_cuda_graph})", ) + handoff_payload = _dsv4_eagle_handoff_payload( + batch=batch, + batch_result=batch_result, + draft_logits_output=draft_logits_output, + select_index=select_index, + seq_lens_before_extend=seq_lens_before_extend, + can_cuda_graph=bool(can_cuda_graph), + max_reqs=_env_int("SGLANG_DSV4_EAGLE_HANDOFF_TRACE_MAX_REQS", 8), + ) + if handoff_payload: + handoff_payload.update( + _dsv4_eagle_recompute_payload( + self, + forward_batch, + draft_logits_output, + select_index, + rows=int(handoff_payload.get("handoff_rows_sampled", 0) or 0), + ) + ) + if handoff_payload: + _trace_dsv4_eagle_event( + self, + "draft_extend_handoff", + batch=batch, + batch_result=batch_result, + **handoff_payload, + ) # Reorganize the spec info for the next batch draft_logits_output.next_token_logits = draft_logits_output.next_token_logits[ @@ -759,21 +1613,46 @@ def clear_cache_pool(self): pass def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): + forward_start_ns = time.perf_counter_ns() + 
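The handoff payload's "selected token" gather picks, per request, the token at the last accepted position. A worked example; accept_lens includes the bonus token, hence the -1 when forming positions:

import torch

token_grid = torch.tensor([[10, 11, 12],
                           [20, 21, 22]])       # one row of draft tokens per request
accept_lens = torch.tensor([2, 3])              # accepted tokens incl. bonus
selected_pos = (accept_lens - 1).to(torch.long) # positions 1 and 2
selected = token_grid.gather(1, selected_pos.unsqueeze(1)).squeeze(1)
assert selected.tolist() == [11, 22]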
_trace_dsv4_eagle_event( + self, + "eagle_forward_start", + batch=model_worker_batch, + ) if ( model_worker_batch.forward_mode.is_extend() or model_worker_batch.is_extend_in_batch ): # Target prefill model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL + target_start_ns = time.perf_counter_ns() + _trace_dsv4_eagle_event( + self, + "target_prefill_start", + batch=model_worker_batch, + ) batch_output = self.target_worker.forward_batch_generation( model_worker_batch ) + _trace_dsv4_eagle_event( + self, + "target_prefill_end", + batch=model_worker_batch, + elapsed_ms=(time.perf_counter_ns() - target_start_ns) / 1e6, + can_run_cuda_graph=getattr(batch_output, "can_run_cuda_graph", None), + ) # Draft prefill model_worker_batch.capture_hidden_mode = CaptureHiddenMode.LAST with self.draft_worker.draft_tp_context( self.draft_worker.draft_runner.tp_group ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): + draft_extend_start_ns = time.perf_counter_ns() + _trace_dsv4_eagle_event( + self, + "draft_extend_prefill_start", + batch=model_worker_batch, + ) batch_output.next_draft_input = ( self.draft_worker._draft_extend_for_prefill( model_worker_batch, @@ -782,6 +1661,19 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): batch_output.logits_output.mm_input_embeds, ) ) + _trace_dsv4_eagle_event( + self, + "draft_extend_prefill_end", + batch=model_worker_batch, + elapsed_ms=(time.perf_counter_ns() - draft_extend_start_ns) / 1e6, + ) + _trace_dsv4_eagle_event( + self, + "eagle_forward_end", + batch=model_worker_batch, + elapsed_ms=(time.perf_counter_ns() - forward_start_ns) / 1e6, + path="prefill", + ) return batch_output else: if model_worker_batch.spec_info is None: @@ -795,9 +1687,21 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): with self.draft_worker.draft_tp_context( self.draft_worker.draft_runner.tp_group ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): + draft_start_ns = time.perf_counter_ns() + _trace_dsv4_eagle_event( + self, + "draft_start", + batch=model_worker_batch, + ) verify_input: EagleVerifyInput = self.draft_worker.draft( model_worker_batch ) + _trace_dsv4_eagle_event( + self, + "draft_end", + batch=model_worker_batch, + elapsed_ms=(time.perf_counter_ns() - draft_start_ns) / 1e6, + ) assert verify_input.is_verify_input() # Record a CUDA event after draft() GPU work is dispatched. 
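Every start/end trace pair in this file follows the same wall-clock pattern: perf_counter_ns() at entry, elapsed_ms = delta / 1e6 at exit. A condensed sketch; traced_phase is a hypothetical helper, not part of this diff:

import contextlib
import time
from typing import Iterator

@contextlib.contextmanager
def traced_phase(name: str) -> Iterator[None]:
    # Hypothetical condensation of the start/end event pairs above.
    start_ns = time.perf_counter_ns()
    print(f"{name}_start")
    try:
        yield
    finally:
        elapsed_ms = (time.perf_counter_ns() - start_ns) / 1e6
        print(f"{name}_end elapsed_ms={elapsed_ms:.3f}")

with traced_phase("target_prefill"):
    time.sleep(0.01)

Note this measures host time only; GPU work launched asynchronously inside a phase is invisible unless the CUDA-event sync paths above are also enabled.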
# This event will be waited on by plan_stream in verify() @@ -807,14 +1711,48 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): self._draft_done_event = torch.get_device_module(self.device).Event() self._draft_done_event.record() model_worker_batch.spec_info = verify_input + verify_start_ns = time.perf_counter_ns() + _trace_dsv4_eagle_event( + self, + "verify_start", + batch=model_worker_batch, + ) batch_output = self.verify(model_worker_batch) + _trace_dsv4_eagle_event( + self, + "verify_end", + batch=model_worker_batch, + batch_result=batch_output, + elapsed_ms=(time.perf_counter_ns() - verify_start_ns) / 1e6, + can_run_cuda_graph=getattr(batch_output, "can_run_cuda_graph", None), + ) with self.draft_worker.draft_tp_context( self.draft_worker.draft_runner.tp_group ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): + draft_extend_start_ns = time.perf_counter_ns() + _trace_dsv4_eagle_event( + self, + "draft_extend_decode_start", + batch=model_worker_batch, + ) self.draft_worker._draft_extend_for_decode( model_worker_batch, batch_output ) + _trace_dsv4_eagle_event( + self, + "draft_extend_decode_end", + batch=model_worker_batch, + batch_result=batch_output, + elapsed_ms=(time.perf_counter_ns() - draft_extend_start_ns) / 1e6, + ) + _trace_dsv4_eagle_event( + self, + "eagle_forward_end", + batch=model_worker_batch, + elapsed_ms=(time.perf_counter_ns() - forward_start_ns) / 1e6, + path="decode", + ) return batch_output def on_verify_complete_cpu(self, accepted_draft_tokens: list[int]) -> None: @@ -959,6 +1897,8 @@ def _override_worker_state( ) = backup def verify(self, batch: ModelWorkerBatch): + verify_total_start_ns = time.perf_counter_ns() + _trace_dsv4_eagle_event(self, "verify_detail_start", batch=batch) # Since batch.seq_lens is allocated in another stream, we need # record_stream() to prevent pytorch gc and reuse the gpu memory # while forward_stream is still running. @@ -973,6 +1913,8 @@ def verify(self, batch: ModelWorkerBatch): # Batch 1: Target verify # Prepare for target verify in a separate stream + prepare_start_ns = time.perf_counter_ns() + _trace_dsv4_eagle_event(self, "verify_prepare_start", batch=batch) with self.plan_stream_ctx: # Wait for the draft CUDA graph to finish before plan_stream # begins its work. 
Using an event is more targeted than @@ -987,9 +1929,18 @@ def verify(self, batch: ModelWorkerBatch): self.target_worker, ) ) + _trace_dsv4_eagle_event( + self, + "verify_prepare_end", + batch=batch, + elapsed_ms=(time.perf_counter_ns() - prepare_start_ns) / 1e6, + can_run_cuda_graph=bool(can_run_cuda_graph), + ) # Correct some buffers due to the overlap plan if self.plan_stream: + plan_post_start_ns = time.perf_counter_ns() + _trace_dsv4_eagle_event(self, "verify_plan_post_start", batch=batch) torch.get_device_module(self.device).current_stream().wait_stream( self.plan_stream ) @@ -1017,6 +1968,13 @@ def verify(self, batch: ModelWorkerBatch): else None ), ) + _trace_dsv4_eagle_event( + self, + "verify_plan_post_end", + batch=batch, + elapsed_ms=(time.perf_counter_ns() - plan_post_start_ns) / 1e6, + can_run_cuda_graph=bool(can_run_cuda_graph), + ) # Prepare grammar data on CPU if needed if batch.has_grammar: @@ -1027,12 +1985,29 @@ def verify(self, batch: ModelWorkerBatch): ).cpu() # Run target verify batch in the main compute stream (GPU compute) + target_verify_start_ns = time.perf_counter_ns() + _trace_dsv4_eagle_event( + self, + "target_verify_forward_start", + batch=batch, + can_run_cuda_graph=bool(can_run_cuda_graph), + ) forward_batch_output = self.target_worker.forward_batch_generation( model_worker_batch=None, forward_batch=verify_forward_batch, is_verify=True, skip_attn_backend_init=True, ) + _trace_dsv4_eagle_event( + self, + "target_verify_forward_end", + batch=batch, + elapsed_ms=(time.perf_counter_ns() - target_verify_start_ns) / 1e6, + can_run_cuda_graph=bool(can_run_cuda_graph), + target_output_can_run_cuda_graph=getattr( + forward_batch_output, "can_run_cuda_graph", None + ), + ) logits_output = forward_batch_output.logits_output # Generate vocab mask for constrained decoding @@ -1057,11 +2032,26 @@ def verify(self, batch: ModelWorkerBatch): # Sample maybe_detect_nan(logits_output.next_token_logits, "verify: target model logits") + sample_start_ns = time.perf_counter_ns() + _trace_dsv4_eagle_event(self, "verify_sample_start", batch=batch) ( predict, accept_lens, accept_index, ) = verify_input.sample(batch, logits_output, vocab_mask) + _trace_dsv4_eagle_event( + self, + "verify_sample_end", + batch=batch, + elapsed_ms=(time.perf_counter_ns() - sample_start_ns) / 1e6, + **_dsv4_eagle_quality_payload( + verify_input, + logits_output, + predict, + accept_lens, + bs, + ), + ) new_seq_lens = batch.seq_lens + accept_lens # Update mamba state for hybrid GDN models after verification @@ -1069,14 +2059,24 @@ def verify(self, batch: ModelWorkerBatch): self.target_worker.model_runner.hybrid_gdn_config is not None or self.target_worker.model_runner.mamba2_config is not None ): + mamba_start_ns = time.perf_counter_ns() + _trace_dsv4_eagle_event(self, "mamba_verify_update_start", batch=batch) self._mamba_verify_update( batch, verify_input, accept_lens, accept_index, bs ) + _trace_dsv4_eagle_event( + self, + "mamba_verify_update_end", + batch=batch, + elapsed_ms=(time.perf_counter_ns() - mamba_start_ns) / 1e6, + ) verify_done = torch.get_device_module(self.device).Event() verify_done.record() if not batch.forward_mode.is_idle(): + bonus_start_ns = time.perf_counter_ns() + _trace_dsv4_eagle_event(self, "verify_bonus_tokens_start", batch=batch) accept_tokens = predict[accept_index] bonus_tokens = torch.empty_like(accept_lens, dtype=torch.int32) fill_bonus_tokens[(bs,)]( @@ -1085,6 +2085,12 @@ def verify(self, batch: ModelWorkerBatch): bonus_tokens, self.speculative_num_draft_tokens, ) + 
_trace_dsv4_eagle_event( + self, + "verify_bonus_tokens_end", + batch=batch, + elapsed_ms=(time.perf_counter_ns() - bonus_start_ns) / 1e6, + ) else: bonus_tokens = torch.empty((0,), device=self.device, dtype=torch.int32) @@ -1116,6 +2122,13 @@ def verify(self, batch: ModelWorkerBatch): verify_done=verify_done, ) + _trace_dsv4_eagle_event( + self, + "verify_detail_end", + batch=batch, + elapsed_ms=(time.perf_counter_ns() - verify_total_start_ns) / 1e6, + can_run_cuda_graph=bool(can_run_cuda_graph), + ) return GenerationBatchResult( logits_output=logits_output, next_token_ids=predict, diff --git a/python/sglang/srt/speculative/multi_layer_eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/multi_layer_eagle_draft_extend_cuda_graph_runner.py index 348f9e834..a248fe3d4 100644 --- a/python/sglang/srt/speculative/multi_layer_eagle_draft_extend_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/multi_layer_eagle_draft_extend_cuda_graph_runner.py @@ -16,6 +16,7 @@ import bisect import logging +import os import time from dataclasses import dataclass from typing import TYPE_CHECKING, Callable, List, Optional @@ -34,6 +35,7 @@ set_global_graph_memory_pool, set_is_extend_in_batch, set_torch_compile_config, + _write_replay_trace, ) from sglang.srt.model_executor.forward_batch_info import ( CaptureHiddenMode, @@ -61,6 +63,24 @@ logger = logging.getLogger(__name__) +def _dsv4_graph_debug_enabled() -> bool: + return os.environ.get("SGLANG_DSV4_GRAPH_DEBUG", "0") not in ( + "", + "0", + "false", + "False", + ) or _dsv4_graph_debug_sync_enabled() + + +def _dsv4_graph_debug_sync_enabled() -> bool: + return os.environ.get("SGLANG_DSV4_GRAPH_DEBUG_SYNC", "0") not in ( + "", + "0", + "false", + "False", + ) + + @dataclass class MultiLayerEagleDraftExtendInputBuffers(ForwardInputBuffers): # Sliced from shared parent buffers @@ -287,7 +307,96 @@ def _capture_graph(self, graph, pool, stream, run_once_fn): return out def _replay(self, forward_batch: ForwardBatch): + trace_enabled = bool(os.environ.get("SGLANG_DSV4_REPLAY_TRACE_PATH", "")) + trace_sync_enabled = trace_enabled and bool( + int(os.environ.get("SGLANG_DSV4_REPLAY_TRACE_SYNC", "0") or "0") + ) + debug_enabled = _dsv4_graph_debug_enabled() + debug_sync_enabled = _dsv4_graph_debug_sync_enabled() + trace_t0 = time.perf_counter_ns() if trace_enabled else 0 + replay_start = replay_end = None + if trace_enabled: + _write_replay_trace( + { + "event": "multi_layer_eagle_draft_extend_cuda_graph_replay_pre", + "forward_mode": self.forward_mode.name, + "capture_forward_mode": self.forward_mode.name, + "capture_hidden_mode": "LAST", + "raw_bs": int(getattr(self, "raw_bs", forward_batch.batch_size)), + "bs": int(self.bs), + "raw_num_token": int(forward_batch.input_ids.shape[0]), + "num_tokens_per_bs": int(self.num_tokens_per_bs), + "step": int(self.step), + "graph_key": f"{self.step}:{self.bs}", + "stream_idx": None, + "variant_label": None, + "sync_debug": debug_sync_enabled, + } + ) + if debug_enabled: + logger.warning( + "[DSV4 graph debug] multi-layer eagle draft extend replay pre " + "step=%s mode=%s raw_bs=%s bs=%s raw_num_token=%s graph_key=%s", + self.step, + self.forward_mode.name, + getattr(self, "raw_bs", None), + getattr(self, "bs", None), + int(forward_batch.input_ids.shape[0]), + f"{self.step}:{getattr(self, 'bs', None)}", + ) + if debug_sync_enabled: + torch.cuda.synchronize() + logger.warning( + "[DSV4 graph debug] multi-layer eagle draft extend replay pre-sync done step=%s bs=%s", + self.step, + getattr(self, "bs", None), + 
) + if trace_sync_enabled: + replay_start = torch.cuda.Event(enable_timing=True) + replay_end = torch.cuda.Event(enable_timing=True) + replay_start.record() self.graphs[self.bs].replay() + if debug_sync_enabled: + torch.cuda.synchronize() + logger.warning( + "[DSV4 graph debug] multi-layer eagle draft extend replay post-sync done step=%s bs=%s", + self.step, + getattr(self, "bs", None), + ) + if trace_sync_enabled: + assert replay_start is not None and replay_end is not None + replay_end.record() + replay_end.synchronize() + replay_device_us = float(replay_start.elapsed_time(replay_end) * 1000.0) + else: + replay_device_us = None + trace_t1 = time.perf_counter_ns() if trace_enabled else 0 + if trace_enabled: + _write_replay_trace( + { + "event": "multi_layer_eagle_draft_extend_cuda_graph_replay", + "forward_mode": self.forward_mode.name, + "capture_forward_mode": self.forward_mode.name, + "capture_hidden_mode": "LAST", + "raw_bs": int(getattr(self, "raw_bs", forward_batch.batch_size)), + "bs": int(self.bs), + "raw_num_token": int(forward_batch.input_ids.shape[0]), + "num_tokens_per_bs": int(self.num_tokens_per_bs), + "step": int(self.step), + "graph_key": f"{self.step}:{self.bs}", + "stream_idx": None, + "variant_label": None, + "entry_ns": trace_t0, + "after_prepare_ns": trace_t0, + "before_launch_ns": trace_t0, + "after_launch_ns": trace_t1, + "prepare_us": 0.0, + "launch_us": (trace_t1 - trace_t0) / 1000.0, + "entry_to_launch_us": 0.0, + "replay_device_us": replay_device_us, + "sync_trace": trace_sync_enabled, + } + ) def capture(self): CudaGraphRunner.capture(self) diff --git a/test/registered/dsv4/test_dsv4_indexer_seq_lens_contract.py b/test/registered/dsv4/test_dsv4_indexer_seq_lens_contract.py new file mode 100644 index 000000000..48749f766 --- /dev/null +++ b/test/registered/dsv4/test_dsv4_indexer_seq_lens_contract.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +import pytest +import torch + +from sglang.srt.layers.attention.dsv4.indexer import _normalize_indexer_seq_lens + + +@pytest.mark.parametrize( + "batch_size", + [1, 2, 4, 8, 9, 10, 11, 12, 13, 14, 15, 16, 24, 32], +) +@pytest.mark.parametrize("query_rows_per_request", [1, 4]) +def test_indexer_seq_lens_accepts_cuda_graph_batch_rows( + batch_size: int, + query_rows_per_request: int, +) -> None: + rows = batch_size * query_rows_per_request + seq_lens = torch.arange(1, rows + 1, dtype=torch.int64) + + normalized = _normalize_indexer_seq_lens(seq_lens, expected_rows=rows) + + assert normalized.shape == (rows,) + assert normalized.dtype == torch.int32 + assert normalized.is_contiguous() + torch.testing.assert_close(normalized, seq_lens.to(torch.int32)) + + +def test_indexer_seq_lens_accepts_deepgemm_trailing_dim_metadata() -> None: + seq_lens = torch.arange(1, 33, dtype=torch.int32).reshape(32, 1) + + normalized = _normalize_indexer_seq_lens(seq_lens, expected_rows=32) + + assert normalized.shape == (32,) + torch.testing.assert_close(normalized, seq_lens.squeeze(-1)) + + +@pytest.mark.parametrize( + "seq_lens", + [ + torch.ones((8, 2), dtype=torch.int32), + torch.ones((2, 4, 1), dtype=torch.int32), + ], +) +def test_indexer_seq_lens_rejects_non_row_vector_shapes(seq_lens: torch.Tensor) -> None: + with pytest.raises(ValueError, match="rank-1 or trailing-1 rank-2"): + _normalize_indexer_seq_lens(seq_lens, expected_rows=8) + + +def test_indexer_seq_lens_rejects_row_count_mismatch() -> None: + seq_lens = torch.ones((31,), dtype=torch.int32) + + with pytest.raises(ValueError, match="row mismatch"): + 
_normalize_indexer_seq_lens(seq_lens, expected_rows=32) diff --git a/test/registered/dsv4/test_hash_topk_scaling.py b/test/registered/dsv4/test_hash_topk_scaling.py new file mode 100644 index 000000000..272ffbc27 --- /dev/null +++ b/test/registered/dsv4/test_hash_topk_scaling.py @@ -0,0 +1,65 @@ +import unittest + +import torch + +from sglang.srt.environ import envs +from sglang.srt.layers.moe.hash_topk import HashTopK + + +class TestHashTopKScaling(unittest.TestCase): + def test_apply_routed_scaling_factor_scales_fused_output_weights(self): + vocab_size = 8 + num_experts = 16 + routed_topk = 3 + num_shared = 1 + scale = 1.5 + + router_logits = torch.linspace( + -0.25, 0.75, steps=2 * num_experts, dtype=torch.float32 + ).reshape(2, num_experts) + input_ids = torch.tensor([1, 4], dtype=torch.int64) + hidden_states = torch.zeros((2, 4), dtype=torch.float32) + tid2eid = torch.tensor( + [ + [0, 1, 2], + [3, 4, 5], + [6, 7, 8], + [9, 10, 11], + [12, 13, 14], + [1, 3, 5], + [7, 9, 11], + [0, 2, 4], + ], + dtype=torch.int32, + ) + + def make_topk(apply_scaling: bool) -> HashTopK: + topk = HashTopK( + topk=routed_topk + num_shared, + num_experts=num_experts, + num_fused_shared_experts=num_shared, + vocab_size=vocab_size, + routed_scaling_factor=scale, + apply_routed_scaling_factor_on_output=apply_scaling, + ) + topk.tid2eid.data.copy_(tid2eid) + return topk + + with envs.SGLANG_OPT_USE_FUSED_HASH_TOPK.override(False): + base = make_topk(False)( + hidden_states, router_logits, input_ids=input_ids + ) + scaled = make_topk(True)( + hidden_states, router_logits, input_ids=input_ids + ) + + torch.testing.assert_close(scaled.topk_ids, base.topk_ids) + torch.testing.assert_close(scaled.topk_weights, base.topk_weights * scale) + torch.testing.assert_close( + scaled.topk_weights[:, -1], + torch.ones_like(scaled.topk_weights[:, -1]), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/scheduler/test_prefill_delayer.py b/test/registered/scheduler/test_prefill_delayer.py index 0bd8b7997..438feaea3 100644 --- a/test/registered/scheduler/test_prefill_delayer.py +++ b/test/registered/scheduler/test_prefill_delayer.py @@ -262,6 +262,38 @@ def _run_negotiate_test(rank, test_cases): expected_allow=False, expected_reason="delay", ), + # Before any prefill batch is recorded, max_prefill_bs can still be zero. + # Queue gating should still batch tiny refills by falling back to the + # configured max_running_requests cap instead of silently disabling itself. + NegotiateTestCase( + name="queue_trigger_uses_max_running_before_prefill_bs_is_learned", + max_delay_passes=100, + token_usage_low_watermark=0.8, + queue_min_ratio=0.5, + max_delay_ms=5000, + calls=[ + NegotiateCall( + prefillable=[True, True, True, True], + token_usage=[0.9, 0.9, 0.9, 0.9], + running_batch=[30, 30, 30, 30], + max_prefill_bs=[0, 0, 0, 0], + waiting_queue_len=[3, 3, 3, 3], + max_running_requests=32, + ), + # First queue delay is skipped for compatibility with the + # slot-based path; the second call must delay. + NegotiateCall( + prefillable=[True, True, True, True], + token_usage=[0.9, 0.9, 0.9, 0.9], + running_batch=[30, 30, 30, 30], + max_prefill_bs=[0, 0, 0, 0], + waiting_queue_len=[3, 3, 3, 3], + max_running_requests=32, + ), + ], + expected_allow=False, + expected_reason="delay", + ), # Waiting queue at or above queue_min: queue trigger must not fire. 
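The indexer contract tests above pin down the accepted shapes, the error messages, and the int32-contiguous output. A reference sketch consistent with those assertions; the real _normalize_indexer_seq_lens lives in sglang.srt.layers.attention.dsv4.indexer, and this is only a reading of the contract, not that code:

import torch

def normalize_indexer_seq_lens(seq_lens: torch.Tensor, expected_rows: int) -> torch.Tensor:
    # Accept rank-1 row vectors or rank-2 tensors with a trailing dim of 1
    # (DeepGEMM-style metadata); reject everything else.
    if seq_lens.ndim == 2 and seq_lens.shape[-1] == 1:
        seq_lens = seq_lens.squeeze(-1)
    elif seq_lens.ndim != 1:
        raise ValueError(
            f"indexer seq_lens must be rank-1 or trailing-1 rank-2, got {tuple(seq_lens.shape)}"
        )
    if seq_lens.shape[0] != expected_rows:
        raise ValueError(
            f"indexer seq_lens row mismatch: {seq_lens.shape[0]} != {expected_rows}"
        )
    return seq_lens.to(torch.int32).contiguous()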
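The hash-TopK scaling test earlier in this file asserts two invariants: every output weight scales by routed_scaling_factor, and the fused shared-expert column lands at exactly 1.0. A numeric check of those invariants with invented routed weights:

# routed_scaling_factor = 1.5 with one fused shared expert; routed weights
# are already normalized to sum to 1, as in the kernel output.
scale = 1.5
routed = [0.5, 0.3, 0.2]

base_shared = 1.0 / scale            # shared weight without output scaling
scaled_routed = [w * scale for w in routed]
scaled_shared = base_shared * scale  # == 1.0, the last-column assertion

assert abs(scaled_shared - 1.0) < 1e-12
assert all(abs(s - w * scale) < 1e-12 for s, w in zip(scaled_routed, routed))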
NegotiateTestCase( name="queue_trigger_above_threshold", diff --git a/test/registered/unit/mem_cache/test_swa_unittest.py b/test/registered/unit/mem_cache/test_swa_unittest.py index b6ec6538c..da4f7774a 100644 --- a/test/registered/unit/mem_cache/test_swa_unittest.py +++ b/test/registered/unit/mem_cache/test_swa_unittest.py @@ -676,6 +676,93 @@ def test_swa_backup_restore_eagle3(self): self.assertEqual(alloc.full_available_size(), size) self.assertEqual(alloc.swa_available_size(), size_swa) + def test_paged_swa_allocator_frees_full_and_swa_pages(self): + """Paged SWA allocators must mark full indices as allocated. + + The mapping cannot use a negative free sentinel because CUDA graphs read + it during replay, so a separate allocation mask guards double frees. + Paged extend/decode must update that mask or free() drops the full page + indices before reaching the backing allocators. + """ + size = 1024 + size_swa = 1024 + page_size = 256 + dtype = torch.bfloat16 + device = get_device() + pool = SWAKVPool( + size=size, + size_swa=size_swa, + page_size=page_size, + dtype=dtype, + head_num=1, + head_dim=1, + swa_attention_layer_ids=[1], + full_attention_layer_ids=[0], + enable_kvcache_transpose=False, + device=device, + ) + alloc = SWATokenToKVPoolAllocator( + size=size, + size_swa=size_swa, + page_size=page_size, + dtype=dtype, + device=device, + kvcache=pool, + need_sort=False, + ) + + prefix_lens = torch.zeros(2, dtype=torch.int64, device=device) + prefix_lens_cpu = torch.zeros(2, dtype=torch.int64) + seq_lens = torch.full((2,), page_size, dtype=torch.int64, device=device) + seq_lens_cpu = torch.full((2,), page_size, dtype=torch.int64) + last_loc = torch.full((2,), -1, dtype=torch.int64, device=device) + + extend_indices = alloc.alloc_extend( + prefix_lens, + prefix_lens_cpu, + seq_lens, + seq_lens_cpu, + last_loc, + 2 * page_size, + ) + self.assertIsNotNone(extend_indices) + self.assertEqual(extend_indices.numel(), 2 * page_size) + self.assertEqual(alloc.full_available_size(), size - 2 * page_size) + self.assertEqual(alloc.swa_available_size(), size_swa - 2 * page_size) + self.assertTrue(torch.all(alloc._allocated_mask[extend_indices])) + + alloc.free(extend_indices) + self.assertEqual(alloc.full_available_size(), size) + self.assertEqual(alloc.swa_available_size(), size_swa) + self.assertFalse(torch.any(alloc._allocated_mask[extend_indices])) + + decode_last_loc = torch.tensor( + [page_size - 1, 2 * page_size - 1], dtype=torch.int64, device=device + ) + decode_seq_lens = torch.tensor( + [page_size + 1, 2 * page_size + 1], dtype=torch.int64, device=device + ) + decode_seq_lens_cpu = torch.tensor( + [page_size + 1, 2 * page_size + 1], dtype=torch.int64 + ) + decode_indices = alloc.alloc_decode( + decode_seq_lens, + decode_seq_lens_cpu, + decode_last_loc, + ) + self.assertIsNotNone(decode_indices) + self.assertEqual(decode_indices.numel(), 2) + self.assertEqual(alloc.full_available_size(), size - 2 * page_size) + self.assertEqual(alloc.swa_available_size(), size_swa - 2 * page_size) + self.assertTrue(torch.all(alloc._allocated_mask[decode_indices])) + + saved_state = alloc.backup_state() + self.assertEqual(len(saved_state), 4) + alloc.free(decode_indices) + self.assertEqual(alloc.full_available_size(), size) + self.assertEqual(alloc.swa_available_size(), size_swa) + self.assertFalse(torch.any(alloc._allocated_mask[decode_indices])) + # Optimization: SGLANG_OPT_SWA_SPLIT_LEAF_ON_INSERT. # Splits a freshly-inserted leaf at the (page-aligned) sliding-window