From fe3d49a07a6a80f27120c60c8adfca48c9250f9a Mon Sep 17 00:00:00 2001 From: Tasha Upchurch Date: Sun, 17 May 2026 06:17:10 +0000 Subject: [PATCH] Enable OCP FP8 scaled mm on gfx12 --- quark/torch/export/constants.py | 30 +++++++--- .../torch/export/nn/modules/qparamslinear.py | 4 +- .../test_fp8_pertensor_kernel.py | 55 ++++++++++++++++++- 3 files changed, 77 insertions(+), 12 deletions(-) diff --git a/quark/torch/export/constants.py b/quark/torch/export/constants.py index a482da3..114e864 100644 --- a/quark/torch/export/constants.py +++ b/quark/torch/export/constants.py @@ -93,7 +93,17 @@ def _check_scaled_mm_available_dev() -> str | None: """ - Determine if torch._scaled_mm is available, there are three return values, None, "hip", "cuda" + Determine if torch._scaled_mm is available. + + Return values: + - None: no supported scaled-mm path. + - "cuda": CUDA scaled-mm path. + - "hip_fnuz": ROCm scaled-mm path that requires E4M3 FNUZ conversion. + - "hip_ocp": ROCm scaled-mm path that uses OCP E4M3 directly. + + ROCm 7.2 on gfx12 reports both the concrete target, for example gfx1201, + and a generic alias, gfx12-generic. The generic alias is not a real ISA + version for this check and must not be interpreted as gfx12 < gfx940. """ scaled_mm_available_dev = None @@ -113,17 +123,19 @@ def _check_scaled_mm_available_dev() -> str | None: raise RuntimeError("The `rocminfo` command failed or was not found.") output = result.stdout.strip() - matches = re.findall(r"gfx(\d+)", output.lower()) + matches = [int(match) for match in re.findall(r"gfx(\d{3,4})", output.lower())] - scaled_mm_available_dev = "hip" if len(matches) > 0 else None - for match in matches: - version_number = int(match) - if version_number < 940: + if len(matches) > 0: + if any(version_number < 940 for version_number in matches): # In general, all video card models should be the same, - # All graphics cards must be eligible + # All graphics cards must be eligible. scaled_mm_available_dev = None - break - if scaled_mm_available_dev == "hip": + elif all(version_number >= 950 for version_number in matches): + scaled_mm_available_dev = "hip_ocp" + else: + scaled_mm_available_dev = "hip_fnuz" + + if scaled_mm_available_dev == "hip_fnuz": print( "[Warning] When the dtype of your model is float32 and custom_mode = 'fp8', a version of torch (rocm) lower than 2.4.0 will result in calculation errors of 'torch._scaled_mm', \n" "If you find that the ppl value is large, try to increase the version of torch. Besides, you should ensure your torch version matches your rocm to prevent errors." diff --git a/quark/torch/export/nn/modules/qparamslinear.py b/quark/torch/export/nn/modules/qparamslinear.py index 0a13395..4ae7f22 100644 --- a/quark/torch/export/nn/modules/qparamslinear.py +++ b/quark/torch/export/nn/modules/qparamslinear.py @@ -352,7 +352,7 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: output_shape = [*input.shape[:-1], weight.shape[1]] input_scale = self.input_quantizer.scale weight_scale = self.weight_quantizer.scale - if SCALED_MM_AVAILABLE_DEV == "hip": + if SCALED_MM_AVAILABLE_DEV == "hip_fnuz": weight, qinput, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( weight=weight, qinput=qinput, weight_scale=weight_scale, input_scale=input_scale ) @@ -405,7 +405,7 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: weight = weight.permute(1, 0) output_shape = [*input.shape[:-1], weight.shape[1]] - if SCALED_MM_AVAILABLE_DEV == "hip": + if SCALED_MM_AVAILABLE_DEV == "hip_fnuz": qinput, input_scale = e4m3fn_to_e4m3fnuz(tensor=qinput, tensor_scale=input_scale) output = torch._scaled_mm( diff --git a/test/test_for_torch/test_fp8_pertensor_kernel.py b/test/test_for_torch/test_fp8_pertensor_kernel.py index 9db23a8..713c591 100644 --- a/test/test_for_torch/test_fp8_pertensor_kernel.py +++ b/test/test_for_torch/test_fp8_pertensor_kernel.py @@ -30,6 +30,22 @@ def test_check_scaled_mm_available_dev(): mock_subprocess.return_value = Mock(returncode=0, stdout="gfx940") _ = _check_scaled_mm_available_dev() + with ( + patch("torch.cuda.is_available", return_value=True), + patch("torch.version.cuda", new=None), + patch("torch.version.hip", new="7.0.0"), + patch("subprocess.run") as mock_subprocess, + ): + mock_subprocess.return_value = Mock( + returncode=0, + stdout=( + "Name: gfx1201\n" + "Name: amdgcn-amd-amdhsa--gfx1201\n" + "Name: amdgcn-amd-amdhsa--gfx12-generic\n" + ), + ) + assert _check_scaled_mm_available_dev() == "hip_ocp" + with patch("torch.cuda.get_device_capability", return_value=(9, 0)), patch("torch.version.cuda", new=True): _ = _check_scaled_mm_available_dev() @@ -50,7 +66,7 @@ def test_check_scaled_mm_available_dev(): ) input = torch.randn([512, 512], dtype=dtype, device=device) - with PatchEverywhere("SCALED_MM_AVAILABLE_DEV", "hip", module_name_prefix="quark"): + with PatchEverywhere("SCALED_MM_AVAILABLE_DEV", "hip_fnuz", module_name_prefix="quark"): try: _ = qparam_linear(input) except ValueError as e: @@ -86,3 +102,40 @@ def test_check_scaled_mm_available_dev(): ) as mock_scaled_mm: output = qparam_linear(input) mock_scaled_mm.assert_not_called() + + +def test_fp8_scaled_mm_rocm_modes(): + FP8_PER_TENSOR_SPEC = QTensorConfig( + dtype=Dtype.fp8_e4m3, qscheme=QSchemeType.per_tensor, observer_cls=PerTensorMinMaxObserver, is_dynamic=False + ) + config = QLayerConfig(input_tensors=FP8_PER_TENSOR_SPEC, weight=FP8_PER_TENSOR_SPEC) + qparam_linear = QParamsLinear.from_module( + nn.Linear(in_features=512, out_features=512, bias=False, dtype=torch.float16), + custom_mode="fp8", + pack_method=None, + quant_config=config, + ) + qparam_linear.weight = nn.Parameter(torch.empty_like(qparam_linear.weight, dtype=torch.float8_e4m3fn)) + input_tensor = torch.randn([512, 512], dtype=torch.float16) + + with PatchEverywhere("SCALED_MM_AVAILABLE_DEV", "hip_ocp", module_name_prefix="quark"): + with patch( + "torch._scaled_mm", return_value=(torch.randn(512, 512, dtype=torch.float32), torch.tensor(1.0)) + ) as mock_scaled_mm: + output = qparam_linear(input_tensor) + mock_scaled_mm.assert_called_once() + scaled_mm_args = mock_scaled_mm.call_args.args + assert scaled_mm_args[0].dtype == torch.float8_e4m3fn + assert scaled_mm_args[1].dtype == torch.float8_e4m3fn + assert output is not None + + with PatchEverywhere("SCALED_MM_AVAILABLE_DEV", "hip_fnuz", module_name_prefix="quark"): + with patch( + "torch._scaled_mm", return_value=(torch.randn(512, 512, dtype=torch.float32), torch.tensor(1.0)) + ) as mock_scaled_mm: + output = qparam_linear(input_tensor) + mock_scaled_mm.assert_called_once() + scaled_mm_args = mock_scaled_mm.call_args.args + assert scaled_mm_args[0].dtype == torch.float8_e4m3fnuz + assert scaled_mm_args[1].dtype == torch.float8_e4m3fnuz + assert output is not None