Skip to content

MXFP8 training bug fixes for quantized_model_init and Torch FSDP fp8 all gather#587

Open
sudhu2k wants to merge 2 commits into
devfrom
sudhu/mxfp8_bug_fixes
Open

MXFP8 training bug fixes for quantized_model_init and Torch FSDP fp8 all gather#587
sudhu2k wants to merge 2 commits into
devfrom
sudhu/mxfp8_bug_fixes

Conversation

@sudhu2k
Copy link
Copy Markdown
Contributor

@sudhu2k sudhu2k commented May 15, 2026

Description

Ensure keep_fp8_weight_transpose_cache flag is set to True not only for autocast but also for quantized_model_init.
Fix padding during fp8 all-gather

Fixes: #15425
#15420

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

…el_init case and not just autocast case.

Fix padding during fp8 all-gather
@sudhu2k sudhu2k self-assigned this May 15, 2026
@sudhu2k sudhu2k added the ci-level 3 CI test level 3 label May 15, 2026
# NOTE: ROCm/HIP backend uses an unpadded scale-inv layout (see `MXFP8Quantizer.make_empty`),
# so applying the padding here would produce a per-shard scale-inv whose dim-0
# does not match the destination scale-inv allocated for the FSDP2 local shard.
padding_multiples = [128, 4] if not IS_HIP_EXTENSION else [1, 1]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think for gfx1250 we have some other padding requirements, this should be unified with #568

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. These changes should also be present in that PR accordingly. But I think for now, let's fix the issue on existing archs and make the appropriate changes along with the #568 PR.

Copy link
Copy Markdown
Contributor

@alextmagro alextmagro May 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, @matthiasdiener can you work with Sudharshan make sure this is in your PR one way or another?

Comment thread transformer_engine/pytorch/tensor/mxfp8_tensor.py Outdated
@sudhu2k sudhu2k requested a review from alextmagro May 19, 2026 20:10
@alextmagro
Copy link
Copy Markdown
Contributor

LGTM! Just sync with Matthias on that one padding thing please.

@sudhu2k sudhu2k requested a review from Micky774 May 26, 2026 23:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci-level 3 CI test level 3

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants