Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions scripts/paddle_all_test_cases.sh
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,19 @@ python3 -m pytest tests/grouped_mm/ --tb=no -q
# All 690 passed tests cover test_dsv3_fused_routing.py and test_dsv3_router_gemm.py
# 4164 skips are environment-level (SM architecture/hardware constraints), not Paddle compat issues.
python3.12 -m pytest tests/model_optimizations/ --tb=no -q

# tests/comm: 29 PASS (2026-05-19)
# Only test_dcp_alltoall.py is adaptable as a single-GPU test.
# All multiprocessing/MPI/MNNVL/NVSHMEM tests skipped (too complex):
# - test_all_gather_matmul.py: SKIP - torch.distributed._symmetric_memory missing at module level (§23) + multiprocessing
# - test_allreduce_fusion_moe_unified_api.py: SKIP - multiprocessing
# - test_allreduce_unified_api.py: SKIP - multiprocessing
# - test_mixed_comm.py: SKIP - multiprocessing
# - test_allreduce_negative.py: SKIP - MPI-based (mpirun)
# - test_mnnvl_*.py: SKIP - MNNVL hardware required
# - test_nvshmem*.py: SKIP - NVSHMEM required
# - test_trtllm_allreduce_fusion.py, test_trtllm_allreduce.py, etc.: SKIP - multiprocessing
# - test_vllm_custom_allreduce.py: SKIP - multiprocessing + NCCL
# Fix: monkey-patches §44-§48 and §52 in tests/comm/conftest.py. Paddle compat assert_close wraps
# ALL internal errors with "resulted in the unexpected exception above", and the isclose kernel is
# missing for bfloat16/float16.
python3 -m pytest tests/comm/test_dcp_alltoall.py --tb=no -q
98 changes: 98 additions & 0 deletions tests/comm/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,101 @@ def node_id(request):
@pytest.fixture
def dist_init_method(request):
    """Return the value of the ``--dist_init_method`` command-line option."""
    option_value = request.config.getoption("--dist_init_method")
    return option_value


# ---------------------------------------------------------------------------
# Paddle compat monkey-patches (para44-para48, para52)
# ---------------------------------------------------------------------------
import functools

import torch

# para44/para45/para52: assert_close bfloat16/float16 fix + Paddle wraps all errors.
# Keep a reference to the unpatched function: the compat wrapper defined in this
# file calls it first and only falls back to a manual comparison on failure.
_orig_assert_close = torch.testing.assert_close


def _is_paddle_isclose_dtype_error(exc):
    """Report whether ``exc`` (or an exception chained to it) should trigger the fallback.

    The chain is walked through ``__cause__`` (preferred) or ``__context__``.
    An exception matches when its message either contains the generic wrapper
    text "resulted in the unexpected exception above" (para52), or mentions a
    half-precision dtype together with "isclose" or "NotFound".

    Args:
        exc: The exception to inspect; ``None`` yields ``False``.

    Returns:
        ``True`` if any exception in the chain matches, otherwise ``False``.
    """
    visited = set()
    node = exc
    while node is not None:
        key = id(node)
        if key in visited:
            # Cyclic chain: every link has already been examined.
            break
        visited.add(key)

        text = str(node)
        # para52: Paddle wraps any assert_close internal error with this message
        if "resulted in the unexpected exception above" in text:
            return True

        mentions_half_dtype = "bfloat16" in text or "float16" in text
        mentions_kernel = "isclose" in text or "NotFound" in text
        if mentions_half_dtype and mentions_kernel:
            return True

        node = getattr(node, "__cause__", None) or getattr(node, "__context__", None)
    return False


def _manual_allclose(actual, expected, rtol, atol):
    """Assert that ``actual`` and ``expected`` are element-wise close, using NumPy.

    An element is considered close when the two values are exactly equal
    (which covers matching infinities) or when their difference is finite and
    satisfies ``|actual - expected| <= atol + rtol * |expected|``. NaN is
    never close to anything.

    Args:
        actual: Tensor (anything exposing ``.detach()``) or array-like value.
        expected: Tensor or array-like value with the same shape as ``actual``.
        rtol: Relative tolerance.
        atol: Absolute tolerance.

    Raises:
        AssertionError: If the shapes differ or any element is not close.
    """
    # Local import keeps this block independent of the file's import section.
    import numpy as np

    def _as_array(value):
        # Tensors are upcast to float32 before leaving the framework
        # (presumably because bfloat16 has no NumPy equivalent -- TODO confirm).
        if hasattr(value, "detach"):
            return value.float().detach().cpu().numpy()
        return np.asarray(value, dtype=np.float64)

    a = _as_array(actual)
    e = _as_array(expected)

    # Refuse to broadcast: differently shaped inputs must not compare as close.
    if a.shape != e.shape:
        raise AssertionError(
            f"Tensors are not close! Shape mismatch: {a.shape} vs {e.shape}"
        )

    # inf - inf yields NaN; silence the warning and handle it via the equality test.
    with np.errstate(invalid="ignore"):
        diff = np.abs(a - e)
        tol = atol + rtol * np.abs(e)
        close = (a == e) | (np.isfinite(diff) & (diff <= tol))

    if not close.all():
        max_diff = float(diff[~close].max())
        raise AssertionError(
            f"Tensors are not close! Max diff: {max_diff:.6f}, rtol={rtol}, atol={atol}"
        )


@functools.wraps(_orig_assert_close)
def _paddle_compat_assert_close(actual, expected, *args, **kwargs):
    """Call the original ``assert_close`` and fall back to a manual check on known failures.

    A ``RuntimeError`` recognised by ``_is_paddle_isclose_dtype_error`` is
    replaced by ``_manual_allclose``; any other ``RuntimeError`` is re-raised
    unchanged. When ``rtol``/``atol`` are not passed as keyword arguments, the
    fallback uses ``rtol`` of 0.016 for bfloat16, 0.001 for float16 and 1.3e-6
    otherwise, with ``atol`` of 1e-5.
    """
    try:
        _orig_assert_close(actual, expected, *args, **kwargs)
    except RuntimeError as err:
        if not _is_paddle_isclose_dtype_error(err):
            raise

        # Non-tensor inputs are treated as float32 when choosing tolerances.
        if isinstance(actual, torch.Tensor):
            dtype = actual.dtype
        else:
            dtype = torch.float32

        rel_tol = kwargs.get("rtol")
        if rel_tol is None:
            if dtype == torch.bfloat16:
                rel_tol = 0.016
            elif dtype == torch.float16:
                rel_tol = 0.001
            else:
                rel_tol = 1.3e-6

        abs_tol = kwargs.get("atol")
        if abs_tol is None:
            abs_tol = 1e-5

        _manual_allclose(actual, expected, rtol=rel_tol, atol=abs_tol)


torch.testing.assert_close = _paddle_compat_assert_close

# para46: torch.equal returns Tensor not bool in Paddle compat
_orig_equal = torch.equal


@functools.wraps(_orig_equal)
def _paddle_compat_equal(input, other):
    """Return a plain ``bool`` telling whether two tensors are equal.

    Tensors of different shapes are never equal. If the underlying ``equal``
    returns a tensor instead of a bool, it is reduced to a single bool; an
    empty result (empty tensors of identical shape) counts as equal instead of
    failing on ``.item()``.

    Args:
        input: First operand.
        other: Second operand.

    Returns:
        ``True`` if the operands compare equal, otherwise ``False``.
    """
    if isinstance(input, torch.Tensor) and isinstance(other, torch.Tensor):
        if input.shape != other.shape:
            return False
    result = _orig_equal(input, other)
    if isinstance(result, torch.Tensor):
        count = result.numel()
        if count == 0:
            # Nothing differs between two empty tensors of the same shape.
            return True
        if count == 1:
            return bool(result.item())
        return bool(result.all().item())
    return bool(result)


torch.equal = _paddle_compat_equal

# para47: tensor.multiply(scalar) -- Paddle compat may not accept Python scalar
_orig_tensor_multiply = torch.Tensor.multiply


@functools.wraps(_orig_tensor_multiply)
def _paddle_compat_tensor_multiply(self, other):
    """Multiply a tensor by ``other``, wrapping Python scalars in a tensor first.

    An ``int``/``float`` operand is converted to a tensor on ``self.device``.
    It normally takes ``self.dtype``; a float scalar applied to a non-floating,
    non-complex tensor would be truncated by that cast, so in that case
    ``self`` is converted to the scalar tensor's dtype instead.

    Args:
        other: Tensor or Python scalar to multiply by.

    Returns:
        The result of the original ``multiply``.
    """
    if isinstance(other, (int, float)):
        would_truncate = isinstance(other, float) and not (
            self.is_floating_point() or self.is_complex()
        )
        if would_truncate:
            # Let the framework pick the float dtype for the scalar and promote self to it.
            scalar = torch.tensor(other, device=self.device)
            return _orig_tensor_multiply(self.to(scalar.dtype), scalar)
        other = torch.tensor(other, dtype=self.dtype, device=self.device)
    return _orig_tensor_multiply(self, other)


torch.Tensor.multiply = _paddle_compat_tensor_multiply

# para48: clamp_min / clamp_max missing on Tensor in Paddle compat
def _paddle_compat_clamp_min(self, v):
    """Clamp every element of ``self`` to be at least ``v`` via ``torch.clamp``."""
    return torch.clamp(self, min=v)


def _paddle_compat_clamp_max(self, v):
    """Clamp every element of ``self`` to be at most ``v`` via ``torch.clamp``."""
    return torch.clamp(self, max=v)


# Install the shims only when the method is absent, so an existing native
# implementation is never overwritten.
if not hasattr(torch.Tensor, "clamp_min"):
    torch.Tensor.clamp_min = _paddle_compat_clamp_min
if not hasattr(torch.Tensor, "clamp_max"):
    torch.Tensor.clamp_max = _paddle_compat_clamp_max
Loading