Enable OCP FP8 scaled-mm path on gfx12 #32

Draft
TashaSkyUp wants to merge 1 commit into amd:release/0.11 from TashaSkyUp:codex/gfx1201-ocp-fp8-scaled-mm

Conversation

@TashaSkyUp

Summary

  • distinguish ROCm FP8 scaled-mm modes between OCP E4M3 and FNUZ
  • ignore generic rocminfo aliases such as gfx12-generic when checking concrete gfx targets
  • keep gfx950/gfx12 on the OCP E4M3 path instead of converting to FNUZ
  • add FP8 pertensor tests for gfx1201 detection and ROCm mode-specific dtype behavior
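A minimal sketch of the mode split described in the list above, not the code in this PR. The function name `select_fp8_mode`, the `ocp_e4m3` label, and the fallback to FNUZ for every other target are assumptions made for illustration; `hip_fnuz` is the label used in the validation notes, and the dtype names are the PyTorch names `float8_e4m3fn` and `float8_e4m3fnuz`. The function expects a concrete target name, with generic aliases already filtered out.

```python
# Hypothetical mode labels. "hip_fnuz" mirrors the name used in the PR's
# validation notes; "ocp_e4m3" is invented here for illustration.
OCP_MODE = "ocp_e4m3"
FNUZ_MODE = "hip_fnuz"

# PyTorch dtype names associated with each mode.
FP8_DTYPE_NAME = {
    OCP_MODE: "float8_e4m3fn",     # OCP E4M3
    FNUZ_MODE: "float8_e4m3fnuz",  # FNUZ variant
}


def select_fp8_mode(target: str) -> str:
    """Pick an FP8 scaled-mm mode for a concrete gfx target name.

    Assumption: gfx950 and gfx12xx stay on the OCP path (as the PR states);
    every other target falls back to FNUZ, which is a simplification.
    """
    if target == "gfx950" or target.startswith("gfx12"):
        return OCP_MODE
    return FNUZ_MODE


if __name__ == "__main__":
    for name in ("gfx1201", "gfx950", "gfx942"):
        mode = select_fp8_mode(name)
        print(name, mode, FP8_DTYPE_NAME[mode])
```

Keeping the mode as an explicit value, rather than inferring it from the dtype at each call site, makes it straightforward to test the dtype behavior per mode without a GPU.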

Root cause

ROCm 7.2 reports gfx1201 together with the generic alias gfx12-generic. The previous regex parsed both as numeric gfx IDs, yielding 1201 and 12, so Quark disabled torch._scaled_mm because 12 < 940. Forcing the old HIP path also converted OCP float8_e4m3fn tensors to float8_e4m3fnuz, which gfx1201 did not support in local testing.
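The parsing failure can be reproduced with a small sketch. Everything here is illustrative: the sample string, both regexes, and the helper names are assumptions, not the actual Quark code, and only the 940 threshold comes from the description above. The concrete-target pattern handles purely numeric IDs only; names with a letter suffix such as gfx90a are skipped by this simplified version.

```python
import re

# Illustrative stand-in for the target names reported on a gfx1201 system.
SAMPLE = "gfx1201\ngfx12-generic\n"

# Naive pattern: grabs the digits after every "gfx", including the alias.
NAIVE = re.compile(r"gfx(\d+)")

# Stricter pattern: the digits must end the token, so "gfx12-generic"
# (digits followed by "-") does not match.
CONCRETE = re.compile(r"gfx(\d+)(?![\w-])")


def parse_ids(pattern: re.Pattern, text: str) -> list[int]:
    """Return every numeric gfx ID the pattern finds in text."""
    return [int(m) for m in pattern.findall(text)]


def scaled_mm_allowed(ids: list[int], minimum: int = 940) -> bool:
    """Hypothetical gate: allow only if every detected ID meets the minimum."""
    return bool(ids) and all(i >= minimum for i in ids)


if __name__ == "__main__":
    naive = parse_ids(NAIVE, SAMPLE)
    concrete = parse_ids(CONCRETE, SAMPLE)
    print(naive, scaled_mm_allowed(naive))
    print(concrete, scaled_mm_allowed(concrete))
```

With the naive pattern the alias contributes the ID 12 and the gate closes; with the stricter pattern only 1201 is seen and the gate stays open.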

Validation

  • pytest -q test/test_for_torch/test_fp8_pertensor_kernel.py
  • Local gfx1201 runtime validation outside this PR: Quark QParamsLinear uses torch._scaled_mm with torch.float8_e4m3fn and returns finite BF16 output on AMD Radeon AI PRO R9700.
  • Local negative validation: forced hip_fnuz path passes float8_e4m3fnuz to torch._scaled_mm and fails with HIPBLAS_STATUS_NOT_SUPPORTED on gfx1201.
