Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
f00c843
dsv4: wire sm120 backend instrumentation
bhaktatejas922 May 12, 2026
31767ad
dsv4: add graph phase trace diagnostics
bhaktatejas922 May 12, 2026
0d68570
dsv4: filter phase trace to target graph
bhaktatejas922 May 12, 2026
3cef7c5
dsv4: add capture-only phase tracing
bhaktatejas922 May 12, 2026
1af32ee
dsv4: add graph-only b12x mla gate
bhaktatejas922 May 12, 2026
6672912
dsv4: add b12x fp4 dense audit counters
bhaktatejas922 May 12, 2026
b2682dd
dsv4: expose b12x mla debug counters
bhaktatejas922 May 12, 2026
5e00840
dsv4: trace scheduler stream emission
bhaktatejas922 May 12, 2026
8318de9
dsv4: add scheduler timing trace
bhaktatejas922 May 12, 2026
a8a6eab
dsv4: trace eagle decode phases
bhaktatejas922 May 12, 2026
cc001db
dsv4: gate mla trace events
bhaktatejas922 May 12, 2026
e2419c0
dsv4: gate live b12x sparse mla rows
bhaktatejas922 May 13, 2026
68df8d6
dsv4: allow b12x mla draft extend gate
bhaktatejas922 May 13, 2026
7eb61de
dsv4: trace eagle draft graph replay
bhaktatejas922 May 13, 2026
84cc5d5
dsv4: throttle spec scheduler refills
bhaktatejas922 May 13, 2026
7c06e94
dsv4: implement HashTopK routed scaling fusion
bhaktatejas922 May 13, 2026
9287b36
dsv4: normalize indexer seq lens contract
bhaktatejas922 May 14, 2026
9b7b4b7
Fix DSv4 MTP graph metadata diagnostics
bhaktatejas922 May 14, 2026
c84f23e
Fix EAGLE draft graph cache loc padding
bhaktatejas922 May 14, 2026
3f5c8e0
Gate DSv4 sparse MLA draft decode
bhaktatejas922 May 14, 2026
40a050e
Default DSv4 MTP replay to proven metadata path
bhaktatejas922 May 14, 2026
e53dd58
Skip DSv4 sparse MLA shadow during graph capture
bhaktatejas922 May 14, 2026
673df0b
Use accurate B12X QK for DSv4 sparse draft paths
bhaktatejas922 May 14, 2026
f717121
Gate DSv4 sparse MLA BF16 QK path
bhaktatejas922 May 14, 2026
6577f99
Add DSv4 graph replay diagnostics
bhaktatejas922 May 14, 2026
6420f5c
Expose speculative simulation state
bhaktatejas922 May 14, 2026
08d2f6b
Handle DSv4 eager draft metadata fallback
bhaktatejas922 May 14, 2026
7c45012
Add DSv4 attention metadata diagnostics
bhaktatejas922 May 15, 2026
ed5b1b0
Add DSv4 row cluster diagnostics
bhaktatejas922 May 15, 2026
1ff911f
Add DSv4 MQA substage diagnostics
bhaktatejas922 May 15, 2026
42024ac
Trace DSv4 EAGLE draft row waste
bhaktatejas922 May 15, 2026
6f298ea
Fix DSv4 phase trace stale events
bhaktatejas922 May 15, 2026
de6dd22
Trace DSv4 sparse MLA QK proxy
bhaktatejas922 May 15, 2026
4ff297f
Skip DSv4 trace during graph capture
bhaktatejas922 May 15, 2026
c37982c
Add DSv4 EAGLE quality trace
bhaktatejas922 May 15, 2026
ad86c5a
Add DSv4 EAGLE handoff trace
bhaktatejas922 May 15, 2026
c36dd99
Gate DSv4 draft extend CUDA graph replay
bhaktatejas922 May 15, 2026
89cc571
Add DSv4 EAGLE recompute trace
bhaktatejas922 May 15, 2026
1478c47
Pass TP rank metadata to B12X MoE
bhaktatejas922 May 15, 2026
c1cb567
Expose B12X DSv4 MLA patch counters
bhaktatejas922 May 16, 2026
8c6d345
Add opt-in DSv4 rows4 prefill direct path
bhaktatejas922 May 16, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions python/sglang/jit_kernel/csrc/deepseek_v4/hash_topk.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,16 @@ struct MoEHashTopKParams {
uint32_t num_routed_experts;
uint32_t num_shared_experts;
float routed_scaling_factor;
bool apply_routed_scaling_factor_on_output;
};

template <auto Fn, bool kUsePDL>
__global__ void moe_hash_topk_fused(const MoEHashTopKParams __grid_constant__ params) {
using namespace device;
const auto& [
router_logits, input_id, tid2eid, topk_ids, topk_weights, // pointers
num_tokens, topk, num_routed_experts, num_shared_experts, routed_scaling_factor] =
num_tokens, topk, num_routed_experts, num_shared_experts, routed_scaling_factor,
apply_routed_scaling_factor_on_output] =
params;

const uint32_t topk_fused = topk + num_shared_experts;
Expand All @@ -60,8 +62,9 @@ __global__ void moe_hash_topk_fused(const MoEHashTopKParams __grid_constant__ pa
if (lane_id < topk_fused) {
const bool is_shared = lane_id >= topk;
const auto output_offset = warp_id * topk_fused + lane_id;
const auto scale = apply_routed_scaling_factor_on_output ? routed_scaling_factor : 1.0f;
topk_ids[output_offset] = is_shared ? num_routed_experts + lane_id - topk : expert_id;
topk_weights[output_offset] = is_shared ? 1.0f / routed_scaling_factor : routed_weight / routed_sum;
topk_weights[output_offset] = is_shared ? scale / routed_scaling_factor : (routed_weight / routed_sum) * scale;
}

PDLTriggerSecondary<kUsePDL>();
Expand Down Expand Up @@ -100,7 +103,8 @@ struct HashTopKKernel {
const tvm::ffi::TensorView tid2eid,
const tvm::ffi::TensorView topk_weights,
const tvm::ffi::TensorView topk_ids,
float routed_scaling_factor) {
float routed_scaling_factor,
bool apply_routed_scaling_factor_on_output) {
using namespace host;

auto N = SymbolicSize{"num_tokens"};
Expand Down Expand Up @@ -148,6 +152,7 @@ struct HashTopKKernel {
.num_routed_experts = static_cast<uint32_t>(E.unwrap()),
.num_shared_experts = shared_experts,
.routed_scaling_factor = routed_scaling_factor,
.apply_routed_scaling_factor_on_output = apply_routed_scaling_factor_on_output,
};
const auto kBlockSize = 128u;
const auto kNumWarps = kBlockSize / device::kWarpThreads;
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/jit_kernel/deepseek_v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ def hash_topk(
tid2eid: torch.Tensor,
num_fused_shared_experts: int = 0,
routed_scaling_factor: float = 1.0,
apply_routed_scaling_factor_on_output: bool = False,
scoring_func: str = "sqrtsoftplus",
) -> Tuple[torch.Tensor, torch.Tensor]:
assert scoring_func == "sqrtsoftplus"
Expand All @@ -369,6 +370,7 @@ def hash_topk(
topk_weights,
topk_ids,
routed_scaling_factor,
apply_routed_scaling_factor_on_output,
)
return topk_weights, topk_ids

Expand Down
35 changes: 30 additions & 5 deletions python/sglang/srt/distributed/communication_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
import torch
import torch.distributed

from sglang.srt.models.dsv4_phase_trace import (
record_current_dsv4_allreduce_start,
record_dsv4_phase_end,
)

from .parallel_state import (
get_attn_tp_group,
get_moe_ep_group,
Expand All @@ -15,12 +20,20 @@

def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce ``input_`` across the model (tensor) parallel group.

    The collective is bracketed by DSv4 phase tracing: a "tp_allreduce"
    phase is opened before the collective and closed in a ``finally``
    block, so the end marker is always recorded even if the all-reduce
    raises.

    Args:
        input_: Tensor to be reduced in place across the TP group.

    Returns:
        The reduced tensor as returned by the TP group's ``all_reduce``.
    """
    # NOTE(review): trace handle semantics come from dsv4_phase_trace;
    # presumably it is a no-op token when tracing is disabled — confirm.
    trace = record_current_dsv4_allreduce_start("tp_allreduce", input_)
    try:
        return get_tp_group().all_reduce(input_)
    finally:
        record_dsv4_phase_end(trace)


def tensor_model_parallel_quant_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """Quantized all-reduce of ``input_`` across the model parallel group.

    Same contract as :func:`tensor_model_parallel_all_reduce` but routes
    through the TP group's ``quant_all_reduce``. DSv4 phase tracing wraps
    the collective ("tp_quant_allreduce"), with the end marker recorded in
    a ``finally`` block so tracing stays balanced on error paths.

    Args:
        input_: Tensor to be reduced across the TP group.

    Returns:
        The reduced tensor from the TP group's ``quant_all_reduce``.
    """
    trace = record_current_dsv4_allreduce_start("tp_quant_allreduce", input_)
    try:
        return get_tp_group().quant_all_reduce(input_)
    finally:
        record_dsv4_phase_end(trace)


def tensor_model_parallel_fused_allreduce_rmsnorm(
Expand Down Expand Up @@ -62,19 +75,31 @@ def broadcast_tensor_dict(

def attention_tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce ``input_`` across the attention tensor-parallel group.

    Mirrors :func:`tensor_model_parallel_all_reduce` but operates on the
    attention TP group. The collective is wrapped in DSv4 phase tracing
    ("attn_tp_allreduce"); the end marker is emitted in ``finally`` so the
    trace is closed even when the collective raises.

    Args:
        input_: Tensor to be reduced across the attention TP group.

    Returns:
        The reduced tensor from the attention TP group's ``all_reduce``.
    """
    trace = record_current_dsv4_allreduce_start("attn_tp_allreduce", input_)
    try:
        return get_attn_tp_group().all_reduce(input_)
    finally:
        record_dsv4_phase_end(trace)


def attention_tensor_model_parallel_quant_all_reduce(
    input_: torch.Tensor,
) -> torch.Tensor:
    """Quantized all-reduce of ``input_`` across the attention TP group.

    Same shape as :func:`attention_tensor_model_parallel_all_reduce` but
    uses the group's ``quant_all_reduce``. DSv4 phase tracing brackets the
    collective ("attn_tp_quant_allreduce"), with the end marker recorded
    in ``finally`` to keep traces balanced on exceptions.

    Args:
        input_: Tensor to be reduced across the attention TP group.

    Returns:
        The reduced tensor from the group's ``quant_all_reduce``.
    """
    trace = record_current_dsv4_allreduce_start("attn_tp_quant_allreduce", input_)
    try:
        return get_attn_tp_group().quant_all_reduce(input_)
    finally:
        record_dsv4_phase_end(trace)


def moe_tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce ``input_`` across the MoE tensor-parallel group.

    Follows the same tracing pattern as the other all-reduce wrappers in
    this module: open a DSv4 "moe_tp_allreduce" phase, run the collective
    on the MoE TP group, and close the phase in ``finally`` so the end
    marker is recorded even if the collective raises.

    Args:
        input_: Tensor to be reduced across the MoE TP group.

    Returns:
        The reduced tensor from the MoE TP group's ``all_reduce``.
    """
    trace = record_current_dsv4_allreduce_start("moe_tp_allreduce", input_)
    try:
        return get_moe_tp_group().all_reduce(input_)
    finally:
        record_dsv4_phase_end(trace)


def moe_expert_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
Expand Down
138 changes: 136 additions & 2 deletions python/sglang/srt/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,24 +497,91 @@ def graph_capture(
graph_capture_context: Optional[GraphCaptureContext] = None,
stream: Optional[torch.cuda.Stream] = None,
):
trace_capture = os.environ.get("SGLANG_DSV4_GRAPH_CAPTURE_TRACE", "0") not in (
"",
"0",
"false",
"False",
)
trace_rank = (
torch.distributed.get_rank()
if torch.distributed.is_initialized()
else self.rank
)
trace_group = getattr(self, "unique_name", "unknown")

def _trace(phase: str, **details):
if not trace_capture:
return
detail_str = " ".join(f"{key}={value}" for key, value in details.items())
logger.info(
"[DSV4 graph trace] rank=%s group=%s phase=%s %s",
trace_rank,
trace_group,
phase,
detail_str,
)

_trace(
"group_capture_init",
has_context=graph_capture_context is not None,
has_stream=stream is not None,
world_size=getattr(self, "world_size", None),
)
if graph_capture_context is None:
if stream is None:
_trace("stream_create_start")
stream = self.device_module.Stream()
_trace("stream_create_done", stream=stream)
graph_capture_context = GraphCaptureContext(stream)
else:
stream = graph_capture_context.stream
# We don't need the context of custom quick allreduce because the ipc access
# is already collected in init() and we can capture the quick allreduce directly.
ca_comm = self.ca_comm
_trace(
"ca_context_create_start",
present=ca_comm is not None,
disabled=getattr(ca_comm, "disabled", None),
comm_type=type(ca_comm).__name__ if ca_comm is not None else None,
)
maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture()
_trace(
"ca_context_create_done",
present=ca_comm is not None,
disabled=getattr(ca_comm, "disabled", None),
)

# ensure all initialization operations complete before attempting to
# capture the graph on another stream
_trace("current_stream_get_start")
curr_stream = get_current_device_stream_fast()
_trace(
"current_stream_get_done",
same_stream=curr_stream == stream,
curr_stream=curr_stream,
capture_stream=stream,
)
if curr_stream != stream:
_trace("stream_wait_start")
stream.wait_stream(curr_stream)
_trace("stream_wait_done")

with self.device_module.stream(stream), maybe_ca_context:
with contextlib.ExitStack() as stack:
_trace("device_stream_enter_start")
stack.enter_context(self.device_module.stream(stream))
_trace("device_stream_enter_done")
_trace(
"ca_enter_start",
present=ca_comm is not None,
disabled=getattr(ca_comm, "disabled", None),
)
stack.enter_context(maybe_ca_context)
_trace(
"ca_enter_done",
present=ca_comm is not None,
disabled=getattr(ca_comm, "disabled", None),
)
# In graph mode, we have to be very careful about the collective
# operations. The current status is:
# allreduce \ Mode | Eager | Graph |
Expand Down Expand Up @@ -543,16 +610,60 @@ def graph_capture(
if not pynccl_comm:
maybe_pynccl_context = nullcontext()
else:
_trace(
"pynccl_context_create_start",
disabled=getattr(pynccl_comm, "disabled", None),
comm_type=type(pynccl_comm).__name__,
)
maybe_pynccl_context = pynccl_comm.change_state(enable=True)
_trace(
"pynccl_context_create_done",
disabled=getattr(pynccl_comm, "disabled", None),
)

pymscclpp_comm = self.pymscclpp_comm
maybe_pymscclpp_context: Any
if not pymscclpp_comm:
maybe_pymscclpp_context = nullcontext()
else:
_trace(
"pymscclpp_context_create_start",
disabled=getattr(pymscclpp_comm, "disabled", None),
comm_type=type(pymscclpp_comm).__name__,
)
maybe_pymscclpp_context = pymscclpp_comm.change_state(enable=True)
with maybe_pynccl_context, maybe_pymscclpp_context:
_trace(
"pymscclpp_context_create_done",
disabled=getattr(pymscclpp_comm, "disabled", None),
)
_trace(
"pynccl_enter_start",
present=pynccl_comm is not None,
disabled=getattr(pynccl_comm, "disabled", None),
)
stack.enter_context(maybe_pynccl_context)
_trace(
"pynccl_enter_done",
present=pynccl_comm is not None,
disabled=getattr(pynccl_comm, "disabled", None),
)
_trace(
"pymscclpp_enter_start",
present=pymscclpp_comm is not None,
disabled=getattr(pymscclpp_comm, "disabled", None),
)
stack.enter_context(maybe_pymscclpp_context)
_trace(
"pymscclpp_enter_done",
present=pymscclpp_comm is not None,
disabled=getattr(pymscclpp_comm, "disabled", None),
)
try:
_trace("yield_start")
yield graph_capture_context
finally:
_trace("yield_exit_start")
_trace("contexts_exit_done")

def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -1560,15 +1671,38 @@ def graph_capture(stream: Optional[torch.cuda.Stream] = None):
in order to explicitly distinguish the kernels to capture
from other kernels possibly launched on background in the default stream.
"""
trace_capture = os.environ.get("SGLANG_DSV4_GRAPH_CAPTURE_TRACE", "0") not in (
"",
"0",
"false",
"False",
)
rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else -1
if trace_capture:
logger.info("[DSV4 graph trace] rank=%s group=tp phase=enter_start", rank)
with get_tp_group().graph_capture(
stream=stream
) as context, get_pp_group().graph_capture(context):
if trace_capture:
logger.info("[DSV4 graph trace] rank=%s group=tp_pp phase=enter_done", rank)
with contextlib.ExitStack() as stack:
seen = {id(_TP)}
for group in (_MOE_EP, _MOE_TP):
if group is not None and id(group) not in seen:
seen.add(id(group))
if trace_capture:
logger.info(
"[DSV4 graph trace] rank=%s group=%s phase=enter_start",
rank,
getattr(group, "group_name", None),
)
stack.enter_context(group.graph_capture(context))
if trace_capture:
logger.info(
"[DSV4 graph trace] rank=%s group=%s phase=enter_done",
rank,
getattr(group, "group_name", None),
)
yield context


Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,8 @@ class Envs:

# GEMM / kernel fusion
SGLANG_OPT_FP8_WO_A_GEMM = EnvBool(False)
SGLANG_OPT_DSV4_PARALLEL_WO_A = EnvBool(False)
SGLANG_OPT_DSV4_DIRECT_WQ_B = EnvBool(False)
SGLANG_OPT_BF16_FP32_GEMM_ALGO = EnvStr("cublas")
SGLANG_OPT_USE_JIT_EP_ACTIVATION = EnvBool(True)
SGLANG_OPT_USE_JIT_NORM = EnvBool(False)
Expand Down
Loading
Loading