Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions scripts/paddle_all_test_cases.sh
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,6 @@ python3.12 -m pytest tests/model_optimizations/ --tb=no -q
# Fix: conftest.py §44-§48 + §52 monkey-patches (Paddle compat assert_close wraps ALL errors with
# "resulted in the unexpected exception above"; bfloat16/float16 isclose kernel missing)
python3 -m pytest tests/comm/test_dcp_alltoall.py --tb=no -q

# PASS (2026-05-19) §53+§54: CUDAGraphMoE ExternalStream + CUDA graph capture fix
python3.12 -m pytest -rs "tests/moe/test_trtllm_gen_fused_moe.py::test_renormalize_routing[BF16_logits-Swiglu-Shuffled_MajorK-Renorm-NvFP4xNvFP4-384-1024-8-RandomHiddenStates]"
80 changes: 67 additions & 13 deletions tests/moe/test_trtllm_gen_fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ def __init__(self, moe_impl, static_data, **config):
self.input_tensor = None
self.output_tensor = None
self.is_captured = False
self._static_quantized = (
None # §54: pre-allocated quantized input buffer for CUDA graph capture
)

def capture(self, hidden_states_sample, **runtime_args):
"""Capture CUDA graph with the given sample input."""
Expand All @@ -107,12 +110,18 @@ def capture(self, hidden_states_sample, **runtime_args):
)

# Create stream
err, self.stream = runtime.cudaStreamCreate()
check_cuda(err)

# Get the raw stream pointer for PyTorch
stream_ptr = int(self.stream)
torch_stream = torch.cuda.ExternalStream(stream_ptr)
if hasattr(torch.cuda, "ExternalStream"):
err, self.stream = runtime.cudaStreamCreate()
check_cuda(err)
stream_ptr = int(self.stream)
torch_stream = torch.cuda.ExternalStream(stream_ptr)
else:
# §53: Paddle compat - torch.cuda.ExternalStream not available.
# Use torch.cuda.Stream() (Paddle-managed) + wrap raw ptr as cudaStream_t.
_ts = torch.cuda.Stream()
self._torch_stream_ref = _ts # prevent GC
torch_stream = _ts
self.stream = runtime.cudaStream_t(_ts.stream_base.cuda_stream)

# Store input tensor reference (will be updated in place during launch)
self.input_tensor = hidden_states_sample.clone()
Expand All @@ -126,6 +135,22 @@ def capture(self, hidden_states_sample, **runtime_args):
err = runtime.cudaStreamSynchronize(self.stream)[0]
check_cuda(err)

# §54: Pre-allocate static quantized input buffers before CUDA graph capture.
# In Paddle compat, torch.empty() triggers cudaMemAlloc which is not allowed
# during stream capture (cudaErrorStreamCaptureUnsupported / error 900).
# We run quantize_inputs once here (outside capture) and store the result tensors
# as static buffers; during capture, _run_moe_computation reuses them via copy_().
if not hasattr(torch.cuda, "ExternalStream"):
_q = self.moe_impl.quantize_inputs(
self.input_tensor,
self.config["hidden_states_scale_global"],
is_swizzling=False,
)
self._static_quantized = {
k: v.clone() if isinstance(v, torch.Tensor) else v
for k, v in _q.items()
}

# Begin capture
err, self.graph = runtime.cudaGraphCreate(0)
check_cuda(err)
Expand Down Expand Up @@ -155,6 +180,23 @@ def launch(self, hidden_states_new):
# Update input tensor in place
self.input_tensor.copy_(hidden_states_new)

# §54: Paddle compat - re-quantize inputs before graph replay.
# quantize_inputs is outside the captured graph (see _run_moe_computation),
# so we must update self._static_quantized with the new input here.
if self._static_quantized is not None:
_q = self.moe_impl.quantize_inputs(
self.input_tensor,
self.config["hidden_states_scale_global"],
is_swizzling=False,
)
for k, v in _q.items():
if isinstance(v, torch.Tensor) and isinstance(
self._static_quantized.get(k), torch.Tensor
):
self._static_quantized[k].copy_(v)
else:
self._static_quantized[k] = v

# Launch graph
err = runtime.cudaGraphLaunch(self.graph_exec, self.stream)[0]
check_cuda(err)
Expand All @@ -175,20 +217,32 @@ def cleanup(self):
check_cuda(err)
self.graph = None
if self.stream is not None:
err = runtime.cudaStreamDestroy(self.stream)[0]
check_cuda(err)
if not hasattr(self, "_torch_stream_ref"):
# Only destroy cudart-created streams, not Paddle-managed ones
err = runtime.cudaStreamDestroy(self.stream)[0]
check_cuda(err)
self.stream = None
self._torch_stream_ref = None # release Paddle stream ref
self.input_tensor = None
self.output_tensor = None
self.is_captured = False

def _run_moe_computation(self, runtime_args):
"""Run the MoE computation."""
input_quantized = self.moe_impl.quantize_inputs(
self.input_tensor,
self.config["hidden_states_scale_global"],
is_swizzling=False,
)
if self._static_quantized is not None:
# §54: Paddle compat CUDA graph capture fix.
# torch.empty() in Paddle compat calls cudaMemAlloc, which is forbidden
# during stream capture (cudaErrorStreamCaptureUnsupported / error 900).
# Solution: quantize_inputs runs OUTSIDE the capture window (pre-populated
# into self._static_quantized), and _run_moe_computation during capture
# uses those static buffers directly - no torch.empty() inside capture.
input_quantized = self._static_quantized
else:
input_quantized = self.moe_impl.quantize_inputs(
self.input_tensor,
self.config["hidden_states_scale_global"],
is_swizzling=False,
)

output = trtllm_fp4_block_scale_moe(
routing_logits=runtime_args["expert_logits"],
Expand Down
Loading