Skip to content

Commit 0cb2bb2

Browse files
junhaha666valarLip
andauthored
use hip smoothquant suppot moe ep (ROCm#1922)
* smoothquant suppot moe ep * fix * fix acc * add persistent mode for smoothquant * optimize kernel performance * default use persistent mode * fix test --------- Co-authored-by: Lingpeng Jin <103567126+valarLip@users.noreply.github.com>
1 parent 22504cc commit 0cb2bb2

8 files changed

Lines changed: 189 additions & 83 deletions

File tree

aiter/fused_moe_bf16_asm.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# SPDX-License-Identifier: MIT
2-
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
2+
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
33

44
import torch
55
import torch.nn.functional as F
@@ -68,6 +68,7 @@ def asm_moe(
6868
block_shape=None,
6969
expert_mask=None,
7070
activation=ActivationType.Silu,
71+
local_expert_hash=None,
7172
):
7273
E, model_dim, inter_dim = w2.shape
7374
global_E = E
@@ -187,14 +188,25 @@ def asm_moe(
187188
a8_scale = torch.empty((topk * M), dtype=dtypes.fp32, device=device)
188189

189190
# moe_smoothquant_fwd need topk_ids which contains local_expert_id
190-
if expert_mask is not None:
191+
if expert_mask is not None and local_expert_hash is None:
191192
local_expert_hash = expert_mask.cumsum(0, dtype=dtypes.i32)
192193
local_expert_hash[local_expert_hash > 0] -= 1
193-
topk_ids = local_expert_hash[topk_ids]
194-
195-
aiter.moe_smoothquant_fwd(
196-
a8, hidden_states, fc1_smooth_scale, topk_ids, a8_scale
194+
local_expert_hash[expert_mask == 0] = -1
195+
# topk_ids = local_expert_hash[topk_ids]
196+
197+
# aiter.moe_smoothquant_fwd(
198+
# a8, hidden_states, fc1_smooth_scale, topk_ids, a8_scale
199+
# )
200+
aiter.smooth_per_token_scaled_quant(
201+
a8.view(topk, M, model_dim).transpose(0, 1),
202+
hidden_states.view(M, 1, model_dim).expand(-1, topk, -1),
203+
a8_scale,
204+
fc1_smooth_scale,
205+
topk_ids,
206+
smooth_scale_map_hash=local_expert_hash,
207+
enable_ps=True,
197208
)
209+
a8 = a8.view(-1, model_dim)
198210
else:
199211
if (
200212
w1.dtype == dtypes.fp8

aiter/jit/optCompilerConfig.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -738,7 +738,8 @@
738738
],
739739
"extra_ldflags": "None",
740740
"extra_include": [
741-
"f'{AITER_CSRC_DIR}/include/ck_tile'"
741+
"f'{AITER_CSRC_DIR}/include/ck_tile'",
742+
"f'{AITER_CSRC_DIR}/include/opus'"
742743
],
743744
"verbose": "False",
744745
"blob_gen_cmd": "''"

aiter/ops/quant.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,8 @@ def smooth_per_token_scaled_quant(
447447
shuffle_scale: bool = False,
448448
num_rows: Optional[torch.Tensor] = None,
449449
num_rows_factor: int = 1,
450+
smooth_scale_map_hash: Optional[torch.Tensor] = None,
451+
enable_ps: bool = True,
450452
) -> None: ...
451453

452454

csrc/include/quant.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// SPDX-License-Identifier: MIT
2-
// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
2+
// Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
33
#pragma once
44

55
#include <torch/extension.h>
@@ -35,10 +35,12 @@ void smooth_per_token_scaled_quant(
3535
torch::Tensor const& input, // [..., d]
3636
torch::Tensor& scales,
3737
torch::Tensor const& smooth_scale,
38-
std::optional<torch::Tensor> const& smooth_scale_map = std::nullopt,
39-
bool shuffle_scale = false,
40-
std::optional<torch::Tensor> const& num_rows = std::nullopt,
41-
int num_rows_factor = 1);
38+
std::optional<torch::Tensor> const& smooth_scale_map = std::nullopt,
39+
bool shuffle_scale = false,
40+
std::optional<torch::Tensor> const& num_rows = std::nullopt,
41+
int num_rows_factor = 1,
42+
std::optional<torch::Tensor> const& smooth_scale_map_hash = std::nullopt,
43+
bool enable_ps = true);
4244

4345
void partial_transpose(torch::Tensor& out, // [rows, d]
4446
torch::Tensor const& input, // [rows, d]

csrc/include/rocm_ops.hpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1386,10 +1386,12 @@ namespace py = pybind11;
13861386
py::arg("input"), \
13871387
py::arg("scales"), \
13881388
py::arg("smooth_scale"), \
1389-
py::arg("smooth_scale_map") = std::nullopt, \
1390-
py::arg("shuffle_scale") = false, \
1391-
py::arg("num_rows") = std::nullopt, \
1392-
py::arg("num_rows_factor") = 1); \
1389+
py::arg("smooth_scale_map") = std::nullopt, \
1390+
py::arg("shuffle_scale") = false, \
1391+
py::arg("num_rows") = std::nullopt, \
1392+
py::arg("num_rows_factor") = 1, \
1393+
py::arg("smooth_scale_map_hash") = std::nullopt, \
1394+
py::arg("enable_ps") = true); \
13931395
m.def("partial_transpose", \
13941396
&aiter::partial_transpose, \
13951397
py::arg("out"), \

0 commit comments

Comments
 (0)