Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions quark/torch/export/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,17 @@

def _check_scaled_mm_available_dev() -> str | None:
"""
Determine if torch._scaled_mm is available, there are three return values, None, "hip", "cuda"
Determine if torch._scaled_mm is available.

Return values:
- None: no supported scaled-mm path.
- "cuda": CUDA scaled-mm path.
- "hip_fnuz": ROCm scaled-mm path that requires E4M3 FNUZ conversion.
- "hip_ocp": ROCm scaled-mm path that uses OCP E4M3 directly.

ROCm 7.2 on gfx12 reports both the concrete target, for example gfx1201,
and a generic alias, gfx12-generic. The generic alias is not a real ISA
version for this check and must not be interpreted as gfx12 < gfx940.
"""
scaled_mm_available_dev = None

Expand All @@ -113,17 +123,19 @@ def _check_scaled_mm_available_dev() -> str | None:
raise RuntimeError("The `rocminfo` command failed or was not found.")

output = result.stdout.strip()
matches = re.findall(r"gfx(\d+)", output.lower())
matches = [int(match) for match in re.findall(r"gfx(\d{3,4})", output.lower())]

scaled_mm_available_dev = "hip" if len(matches) > 0 else None
for match in matches:
version_number = int(match)
if version_number < 940:
if len(matches) > 0:
if any(version_number < 940 for version_number in matches):
# In general, all video card models should be the same,
# All graphics cards must be eligible
# All graphics cards must be eligible.
scaled_mm_available_dev = None
break
if scaled_mm_available_dev == "hip":
elif all(version_number >= 950 for version_number in matches):
scaled_mm_available_dev = "hip_ocp"
else:
scaled_mm_available_dev = "hip_fnuz"

if scaled_mm_available_dev == "hip_fnuz":
print(
"[Warning] When the dtype of your model is float32 and custom_mode = 'fp8', a version of torch (rocm) lower than 2.4.0 will result in calculation errors of 'torch._scaled_mm', \n"
"If you find that the ppl value is large, try to increase the version of torch. Besides, you should ensure your torch version matches your rocm to prevent errors."
Expand Down
4 changes: 2 additions & 2 deletions quark/torch/export/nn/modules/qparamslinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
output_shape = [*input.shape[:-1], weight.shape[1]]
input_scale = self.input_quantizer.scale
weight_scale = self.weight_quantizer.scale
if SCALED_MM_AVAILABLE_DEV == "hip":
if SCALED_MM_AVAILABLE_DEV == "hip_fnuz":
weight, qinput, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=weight, qinput=qinput, weight_scale=weight_scale, input_scale=input_scale
)
Expand Down Expand Up @@ -405,7 +405,7 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
weight = weight.permute(1, 0)

output_shape = [*input.shape[:-1], weight.shape[1]]
if SCALED_MM_AVAILABLE_DEV == "hip":
if SCALED_MM_AVAILABLE_DEV == "hip_fnuz":
qinput, input_scale = e4m3fn_to_e4m3fnuz(tensor=qinput, tensor_scale=input_scale)

output = torch._scaled_mm(
Expand Down
55 changes: 54 additions & 1 deletion test/test_for_torch/test_fp8_pertensor_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,22 @@ def test_check_scaled_mm_available_dev():
mock_subprocess.return_value = Mock(returncode=0, stdout="gfx940")
_ = _check_scaled_mm_available_dev()

with (
patch("torch.cuda.is_available", return_value=True),
patch("torch.version.cuda", new=None),
patch("torch.version.hip", new="7.0.0"),
patch("subprocess.run") as mock_subprocess,
):
mock_subprocess.return_value = Mock(
returncode=0,
stdout=(
"Name: gfx1201\n"
"Name: amdgcn-amd-amdhsa--gfx1201\n"
"Name: amdgcn-amd-amdhsa--gfx12-generic\n"
),
)
assert _check_scaled_mm_available_dev() == "hip_ocp"

with patch("torch.cuda.get_device_capability", return_value=(9, 0)), patch("torch.version.cuda", new=True):
_ = _check_scaled_mm_available_dev()

Expand All @@ -50,7 +66,7 @@ def test_check_scaled_mm_available_dev():
)
input = torch.randn([512, 512], dtype=dtype, device=device)

with PatchEverywhere("SCALED_MM_AVAILABLE_DEV", "hip", module_name_prefix="quark"):
with PatchEverywhere("SCALED_MM_AVAILABLE_DEV", "hip_fnuz", module_name_prefix="quark"):
try:
_ = qparam_linear(input)
except ValueError as e:
Expand Down Expand Up @@ -86,3 +102,40 @@ def test_check_scaled_mm_available_dev():
) as mock_scaled_mm:
output = qparam_linear(input)
mock_scaled_mm.assert_not_called()


def test_fp8_scaled_mm_rocm_modes():
    """Check the FP8 operand dtype handed to ``torch._scaled_mm`` per ROCm mode.

    With ``SCALED_MM_AVAILABLE_DEV`` patched to ``"hip_ocp"`` the two positional
    operands are expected to stay ``torch.float8_e4m3fn``; with ``"hip_fnuz"``
    they are expected to arrive as ``torch.float8_e4m3fnuz``.
    ``torch._scaled_mm`` itself is mocked, so only the operands passed to it
    are inspected.
    """
    fp8_spec = QTensorConfig(
        dtype=Dtype.fp8_e4m3,
        qscheme=QSchemeType.per_tensor,
        observer_cls=PerTensorMinMaxObserver,
        is_dynamic=False,
    )
    layer_config = QLayerConfig(input_tensors=fp8_spec, weight=fp8_spec)
    float_linear = nn.Linear(in_features=512, out_features=512, bias=False, dtype=torch.float16)
    layer = QParamsLinear.from_module(
        float_linear,
        custom_mode="fp8",
        pack_method=None,
        quant_config=layer_config,
    )
    # Swap in an (uninitialised) FP8 weight so the layer starts from OCP E4M3 storage.
    layer.weight = nn.Parameter(torch.empty_like(layer.weight, dtype=torch.float8_e4m3fn))
    activations = torch.randn([512, 512], dtype=torch.float16)

    # (patched device mode, dtype both scaled-mm operands must have)
    expected_dtype_by_mode = (
        ("hip_ocp", torch.float8_e4m3fn),
        ("hip_fnuz", torch.float8_e4m3fnuz),
    )

    for device_mode, operand_dtype in expected_dtype_by_mode:
        with (
            PatchEverywhere("SCALED_MM_AVAILABLE_DEV", device_mode, module_name_prefix="quark"),
            patch(
                "torch._scaled_mm",
                return_value=(torch.randn(512, 512, dtype=torch.float32), torch.tensor(1.0)),
            ) as scaled_mm_mock,
        ):
            result = layer(activations)
            scaled_mm_mock.assert_called_once()
            first_operand, second_operand = scaled_mm_mock.call_args.args[:2]
            assert first_operand.dtype == operand_dtype
            assert second_operand.dtype == operand_dtype
            assert result is not None