|
1 | 1 | # SPDX-License-Identifier: MIT |
2 | | -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. |
| 2 | +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. |
3 | 3 |
|
4 | 4 | import torch |
5 | 5 | import torch.nn.functional as F |
@@ -68,6 +68,7 @@ def asm_moe( |
68 | 68 | block_shape=None, |
69 | 69 | expert_mask=None, |
70 | 70 | activation=ActivationType.Silu, |
| 71 | + local_expert_hash=None, |
71 | 72 | ): |
72 | 73 | E, model_dim, inter_dim = w2.shape |
73 | 74 | global_E = E |
@@ -187,14 +188,25 @@ def asm_moe( |
187 | 188 | a8_scale = torch.empty((topk * M), dtype=dtypes.fp32, device=device) |
188 | 189 |
|
189 | 190 | # moe_smoothquant_fwd need topk_ids which contains local_expert_id |
190 | | - if expert_mask is not None: |
| 191 | + if expert_mask is not None and local_expert_hash is None: |
191 | 192 | local_expert_hash = expert_mask.cumsum(0, dtype=dtypes.i32) |
192 | 193 | local_expert_hash[local_expert_hash > 0] -= 1 |
193 | | - topk_ids = local_expert_hash[topk_ids] |
194 | | - |
195 | | - aiter.moe_smoothquant_fwd( |
196 | | - a8, hidden_states, fc1_smooth_scale, topk_ids, a8_scale |
| 194 | + local_expert_hash[expert_mask == 0] = -1 |
| 195 | + # topk_ids = local_expert_hash[topk_ids] |
| 196 | + |
| 197 | + # aiter.moe_smoothquant_fwd( |
| 198 | + # a8, hidden_states, fc1_smooth_scale, topk_ids, a8_scale |
| 199 | + # ) |
| 200 | + aiter.smooth_per_token_scaled_quant( |
| 201 | + a8.view(topk, M, model_dim).transpose(0, 1), |
| 202 | + hidden_states.view(M, 1, model_dim).expand(-1, topk, -1), |
| 203 | + a8_scale, |
| 204 | + fc1_smooth_scale, |
| 205 | + topk_ids, |
| 206 | + smooth_scale_map_hash=local_expert_hash, |
| 207 | + enable_ps=True, |
197 | 208 | ) |
| 209 | + a8 = a8.view(-1, model_dim) |
198 | 210 | else: |
199 | 211 | if ( |
200 | 212 | w1.dtype == dtypes.fp8 |
|
0 commit comments