From 877d459f5f44abea03a36fb7537d8da1ee0528c9 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 20 Apr 2026 11:27:53 -0700 Subject: [PATCH] [PyTorch] Add fp32_partial_output support for CP P2P ring attention MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extends context-parallel ring P2P attention (AttnFuncWithCPAndKVP2P) with a new `fp32_partial_output` flag that accumulates per-step partial attention outputs in float32 before LSE-correction merging, improving numerical stability across CP ranks for fp16/bf16 inputs. The flag only reaches the FusedAttention per-step call (cp_p2p_fwd_fused_attn); the FlashAttention per-step call (cp_p2p_fwd_flash_attn) is unchanged. In this patch the float32 request is expressed by passing fake_dtype=torch.float32 to the fused-attention forward call whenever fp32_partial_output is set and FP8 is not in use; the diff itself adds no explicit software cast (out_per_step.to(float32)) after the cuDNN kernel call, so any such cast must come from code outside these hunks. The fake_dtype=torch.float32 path is intended for when the cuDNN kernel gains native fp16/bf16→fp32 output support. Changes: - context_parallel.py: add fp32_partial_output param to cp_p2p_fwd_fused_attn, AttnFuncWithCPAndKVP2P.forward and attn_forward_func_with_cp; fix THD out-buffer dtype; update backward return tuple for new forward input count - backends.py: thread fp32_partial_output through FusedAttentionBackend - dot_product_attention.py: expose fp32_partial_output in DotProductAttention.forward - tests: add test_cp_with_fused_attention_fp32_partial_output covering bf16/fp16, MHA/GQA, causal/non-causal, sbhd/thd, p2p/a2a+p2p Co-Authored-By: Claude Sonnet 4.6 --- .../attention/run_attention_with_cp.py | 3 ++ .../attention/test_attention_with_cp.py | 54 +++++++++++++++++++ .../dot_product_attention/backends.py | 2 + .../dot_product_attention/context_parallel.py | 31 +++++++---- .../dot_product_attention.py | 3 ++ 5 files changed, 82 insertions(+), 11 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 0f36a8816d..f54219fdae 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -180,12 +180,14 @@ def run_dpa_with_cp( scaling_mode="delayed", f16_O="False", is_training="True", + 
fp32_partial_output="False", log_level=logging.WARNING, ): """Test DotProductAttention module with context parallelism""" logging.root.setLevel(log_level) # When is_training is False, gradient outputs are None. is_training = is_training == "True" + fp32_partial_output = fp32_partial_output == "True" # set up environment variables and config fp8_bwd = fp8_bwd == "True" and dtype == "fp8" @@ -469,6 +471,7 @@ def run_dpa_with_cp( cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, fp8_output=fp8_mha, + fp32_partial_output=fp32_partial_output, ) if config.return_max_logit: out_, max_logit_ = out_ diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 5aaf67061b..9dcc265587 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -384,3 +384,57 @@ def test_cp_with_fused_attention( log_level=pytest_logging_level, ), ) + + +model_configs_fp32_partial_output = { + "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", return_max_logit=True), + "cp_1_1": ModelConfig(2, 4096, 12, 128, return_max_logit=True), + "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), +} + + +@pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.") +@pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.") +@pytest.mark.parametrize("dtype", ["bf16", "fp16"]) +@pytest.mark.parametrize("model", model_configs_fp32_partial_output.keys()) +@pytest.mark.parametrize("qkv_format", ["sbhd", "thd"]) +@pytest.mark.parametrize("cp_comm_type", ["p2p", "a2a+p2p"]) +def test_cp_with_fused_attention_fp32_partial_output(dtype, model, qkv_format, cp_comm_type): + """Test that FP32 partial outputs in fprop CP P2P produce numerically correct results.""" + num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 + if num_gpus > torch.cuda.device_count(): + 
pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}") + + if qkv_format == "thd" and get_device_compute_capability() < (9, 0): + pytest.skip("THD format is only supported on sm90+!") + if cp_comm_type == "a2a+p2p" and qkv_format == "thd": + pytest.skip( + "CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format yet!" + ) + + config = model_configs_fp32_partial_output[model] + config.context_parallel = True + config.cp_comm_type = cp_comm_type + + dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16} + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=dtypes[dtype], + qkv_layout="_".join([qkv_format] * 3), + ) + _, fused_attn_supported, _ = available_backends + if not fused_attn_supported: + pytest.skip("No fused attention backend available.") + + run_distributed( + get_bash_arguments( + num_gpus_per_node=num_gpus, + dtype=dtype, + model=model, + qkv_format=qkv_format, + kernel_backend="FusedAttention", + cp_comm_type=cp_comm_type, + fp32_partial_output=True, + log_level=pytest_logging_level, + ), + ) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index ecf3af2bf0..8efce94bb5 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1867,6 +1867,7 @@ def forward( inference_params: Optional[InferenceParams] = None, softmax_offset: torch.Tensor = None, fp8_output: bool = False, + fp32_partial_output: bool = False, ) -> torch.Tensor: """fused attention fprop""" assert ( @@ -2034,6 +2035,7 @@ def forward( fp8_output=fp8_output, layer_number=self.layer_number, return_max_logit=self.return_max_logit, + fp32_partial_output=fp32_partial_output, ) else: with self.attention_dropout_ctx(): diff --git 
a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 64cccaac6e..d5c7c06758 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -794,6 +794,7 @@ def cp_p2p_fwd_fused_attn( cu_seqlens_q_per_step, cu_seqlens_kv_per_step, section, + fp32_partial_output=False, ): """Per-tile forward call of CP P2P with FusedAttention backend""" attn_bias_inputs = None @@ -884,7 +885,7 @@ def cp_p2p_fwd_fused_attn( q_part, k_part, v_part, - fake_dtype=fwd_nominal_dtype, + fake_dtype=torch.float32 if (fp32_partial_output and not fp8) else fwd_nominal_dtype, fused_attention_backend=fused_attn_backend, attn_scale=softmax_scale, dropout=dropout_p, @@ -1287,6 +1288,7 @@ def forward( use_flash_attn_3, fp8_output, layer_number, + fp32_partial_output=False, ): # pylint: disable=missing-function-docstring @@ -1680,7 +1682,8 @@ def forward( attn_biases[i], max_logit_per_step[i], ) = cp_p2p_fwd_fused_attn( - *fused_attn_inputs, *prepare_outputs, section + *fused_attn_inputs, *prepare_outputs, section, + fp32_partial_output=fp32_partial_output, ) else: out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( @@ -1707,7 +1710,8 @@ def forward( attn_biases[i], max_logit_per_step[i], ) = cp_p2p_fwd_fused_attn( - *fused_attn_inputs, *prepare_outputs, section + *fused_attn_inputs, *prepare_outputs, section, + fp32_partial_output=fp32_partial_output, ) else: out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( @@ -1734,7 +1738,8 @@ def forward( attn_biases[i], max_logit_per_step[i], ) = cp_p2p_fwd_fused_attn( - *fused_attn_inputs, *prepare_outputs, section + *fused_attn_inputs, *prepare_outputs, section, + fp32_partial_output=fp32_partial_output, ) else: out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( @@ -1761,7 +1766,10 @@ def forward( 
rng_states[i], attn_biases[i], max_logit_per_step[i], - ) = cp_p2p_fwd_fused_attn(*fused_attn_inputs, *prepare_outputs, section) + ) = cp_p2p_fwd_fused_attn( + *fused_attn_inputs, *prepare_outputs, section, + fp32_partial_output=fp32_partial_output, + ) else: out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( cp_p2p_fwd_flash_attn(*flash_attn_inputs, *prepare_outputs, section) @@ -1795,14 +1803,12 @@ def forward( softmax_lse = torch.clone(softmax_lse_per_step[0]) if qkv_format == "thd": if enable_mla: - out = torch.zeros_like(v if not fp8 else out_per_step[0]).view( - v_shape - ) + ref = out_per_step[0] if (fp8 or fp32_partial_output) else v + out = torch.zeros_like(ref).view(v_shape) else: # MHA or GQA - out = torch.zeros_like(q if not fp8 else out_per_step[0]).view( - q.shape - ) + ref = out_per_step[0] if (fp8 or fp32_partial_output) else q + out = torch.zeros_like(ref).view(q.shape) elif (i - 1) <= rank or not causal: flash_attn_fwd_softmax_lse_correction( softmax_lse, softmax_lse_per_step[i - 1] @@ -2761,6 +2767,7 @@ def backward(ctx, dout, *_args): None, None, None, + None, # fp32_partial_output ) @@ -3962,6 +3969,7 @@ def attn_forward_func_with_cp( fp8_output=False, layer_number=1, return_max_logit=False, + fp32_partial_output=False, ) -> torch.Tensor: """ Attention implementation with context parallelism (CP). 
CP partitions tensors along the sequence @@ -4126,6 +4134,7 @@ def attn_forward_func_with_cp( use_flash_attn_3, fp8_output, layer_number, + fp32_partial_output, ] out = AttnFuncWithCPAndKVP2P.apply(*args) elif cp_comm_type == "all_gather": diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 588c708e10..4d3355b297 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -828,6 +828,7 @@ def forward( pad_between_seqs: Optional[bool] = None, fp8_output: Optional[bool] = False, num_splits: Optional[int] = 1, + fp32_partial_output: Optional[bool] = False, ) -> torch.Tensor: r""" Dot Product Attention Layer. @@ -1521,6 +1522,7 @@ def forward( inference_params=inference_params, softmax_offset=softmax_offset, fp8_output=fp8_output, + fp32_partial_output=fp32_partial_output, ) return self.fused_attention( query_layer, @@ -1552,6 +1554,7 @@ def forward( inference_params=inference_params, softmax_offset=softmax_offset, fp8_output=fp8_output, + fp32_partial_output=fp32_partial_output, ) if use_unfused_attention: