
Commit d5315cc

preserve qwen shared expert execution order (#47)
* preserve qwen shared expert execution order
* Update pyproject.toml

1 parent 9eae89e · commit d5315cc

4 files changed: 95 additions, 3 deletions
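
In the fused HF blocks the shared expert runs before the router; the defused blocks previously ran it after the routed experts. The outputs are identical either way, since the shared expert reads the same input, but anything that observes module execution (forward hooks, profilers, layer-streaming tools) sees a different trace. After this commit both implementations produce the same semantic sequence, which the new test asserts; sketched here with the event names the new test helper uses:

# Illustrative only; event names follow _semantic_sparse_moe_execution_order.
expected = ["shared_expert", "gate", "routed_experts", "shared_expert_gate"]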


defuser/modeling/unfused_moe/qwen2_moe.py

1 addition, 1 deletion

@@ -34,6 +34,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         """Route tokens exactly like HF Qwen2 MoE, then run explicit expert modules."""
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
+        shared_expert_output = self.shared_expert(hidden_states)
         _, routing_weights, selected_experts = self.gate(hidden_states)
         routing_weights = routing_weights.to(hidden_states.dtype)
         final_hidden_states = run_routed_experts(
@@ -44,7 +45,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             self.num_experts,
         )
 
-        shared_expert_output = self.shared_expert(hidden_states)
         shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output
 
         final_hidden_states = final_hidden_states + shared_expert_output

defuser/modeling/unfused_moe/qwen3_next.py

1 addition, 1 deletion

@@ -33,6 +33,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         """Route tokens exactly like HF Qwen3-Next MoE, then run explicit experts."""
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
+        shared_expert_output = self.shared_expert(hidden_states)
         _, routing_weights, selected_experts = self.gate(hidden_states)
         routing_weights = routing_weights.to(hidden_states.dtype)
         final_hidden_states = run_routed_experts(
@@ -43,7 +44,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             self.num_experts,
         )
 
-        shared_expert_output = self.shared_expert(hidden_states)
         shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output
 
         final_hidden_states = final_hidden_states + shared_expert_output
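
Both files get the same one-line move: the shared expert now fires before the router, as in the fused HF blocks, while the arithmetic is unchanged because the shared expert reads the same flattened input either way. A self-contained toy check of that equivalence; the two-branch block below is invented for illustration and is not the repository's class:

import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)

class ToyBlock(nn.Module):
    """Minimal stand-in: one 'routed' path plus a gated shared expert."""

    def __init__(self, shared_first: bool):
        super().__init__()
        self.routed = nn.Linear(8, 8)
        self.shared_expert = nn.Linear(8, 8)
        self.shared_expert_gate = nn.Linear(8, 1)
        self.shared_first = shared_first

    def forward(self, x):
        if self.shared_first:  # new order: shared expert before routing
            shared = self.shared_expert(x)
            routed = self.routed(x)
        else:  # old order: shared expert after routing
            routed = self.routed(x)
            shared = self.shared_expert(x)
        return routed + F.sigmoid(self.shared_expert_gate(x)) * shared

new_order, old_order = ToyBlock(True), ToyBlock(False)
old_order.load_state_dict(new_order.state_dict())  # identical weights
x = torch.randn(4, 8)
torch.testing.assert_close(new_order(x), old_order(x))  # same output either way

Only the observable module order changes, which is exactly what the new execution-order test pins down.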

pyproject.toml

1 addition, 1 deletion

@@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "Defuser"
-version = "0.0.20"
+version = "0.0.21"
 description = "Model defuser helper for HF Transformers."
 readme = "README.md"
 requires-python = ">=3.9"

tests/test_convert_model.py

92 additions, 0 deletions

@@ -402,6 +402,76 @@ def _assert_sparse_moe_defused_matches_fused_math(
     torch.testing.assert_close(actual, expected, **assert_close_kwargs)
 
 
+def _force_route_all_experts(block: nn.Module) -> None:
+    """Set MoE routers to select all experts so execution-order hooks always fire."""
+
+    router = getattr(block, "gate", None)
+    num_experts = getattr(block, "num_experts", None)
+    if router is None or num_experts is None:
+        return
+
+    for name in ("top_k", "num_experts_per_tok"):
+        if hasattr(router, name):
+            setattr(router, name, num_experts)
+            return
+
+
+def _semantic_sparse_moe_execution_order(block: nn.Module, hidden_states: torch.Tensor) -> list[str]:
+    """Record semantic MoE execution order for shared expert, router, routed experts, and shared gate."""
+
+    _force_route_all_experts(block)
+    raw_events: list[str] = []
+    handles = []
+
+    def _record(event_name: str):
+        def _hook(_module, _inputs):
+            raw_events.append(event_name)
+        return _hook
+
+    if hasattr(block, "shared_expert"):
+        handles.append(block.shared_expert.register_forward_pre_hook(_record("shared_expert")))
+    if hasattr(block, "gate"):
+        handles.append(block.gate.register_forward_pre_hook(_record("gate")))
+    if hasattr(block, "shared_expert_gate"):
+        handles.append(block.shared_expert_gate.register_forward_pre_hook(_record("shared_expert_gate")))
+
+    experts = getattr(block, "experts", None)
+    if isinstance(experts, nn.ModuleList):
+        for idx, expert in enumerate(experts):
+            handles.append(expert.register_forward_pre_hook(_record(f"expert_{idx}")))
+    elif isinstance(experts, nn.Module):
+        handles.append(experts.register_forward_pre_hook(_record("experts")))
+
+    try:
+        with torch.inference_mode():
+            block.eval()(hidden_states)
+    finally:
+        for handle in handles:
+            handle.remove()
+
+    semantic_events: list[str] = []
+    for event in raw_events:
+        normalized = "routed_experts" if event.startswith("expert_") or event == "experts" else event
+        if not semantic_events or semantic_events[-1] != normalized:
+            semantic_events.append(normalized)
+    return semantic_events
+
+
+def _assert_sparse_moe_defused_matches_fused_execution_order(
+    original_block: nn.Module,
+    defused_block: nn.Module,
+    hidden_states: torch.Tensor,
+) -> None:
+    """Defused blocks must preserve the same semantic execution order as fused HF blocks."""
+
+    _seed_floating_tensors(original_block)
+    _copy_sparse_moe_weights(original_block, defused_block)
+
+    expected = _semantic_sparse_moe_execution_order(original_block, hidden_states)
+    actual = _semantic_sparse_moe_execution_order(defused_block, hidden_states)
+    assert actual == expected
+
+
 def test_qwen2_moe():
     model_type = "qwen2_moe"
     replace_fused_blocks(model_type)
@@ -858,6 +928,17 @@ def test_qwen2_moe_defused_forward_matches_fused_math():
     )
 
 
+def test_qwen2_moe_defused_forward_matches_fused_execution_order():
+    config = _tiny_moe_config(Qwen2MoeConfig)
+    hidden_states = torch.randn(2, 3, config.hidden_size, dtype=torch.float32)
+
+    _assert_sparse_moe_defused_matches_fused_execution_order(
+        Qwen2MoeSparseMoeBlock(config),
+        LinearQwen2MoeSparseMoeBlock(config),
+        hidden_states,
+    )
+
+
 def test_qwen3_moe_defused_forward_matches_fused_math():
     config = _tiny_moe_config(Qwen3MoeConfig)
     hidden_states = torch.randn(2, 3, config.hidden_size, dtype=torch.float32)
@@ -880,6 +961,17 @@ def test_qwen3_next_defused_forward_matches_fused_math():
     )
 
 
+def test_qwen3_next_defused_forward_matches_fused_execution_order():
+    config = _tiny_moe_config(Qwen3NextConfig)
+    hidden_states = torch.randn(2, 3, config.hidden_size, dtype=torch.float32)
+
+    _assert_sparse_moe_defused_matches_fused_execution_order(
+        Qwen3NextSparseMoeBlock(config),
+        LinearQwen3NextSparseMoeBlock(config),
+        hidden_states,
+    )
+
+
 def test_qwen3_omni_defused_forward_matches_fused_math():
     config = _tiny_qwen3_omni_config().thinker_config.text_config
     hidden_states = torch.randn(2, 3, config.hidden_size, dtype=torch.float32)
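
The recording technique in the new helpers generalizes: force the router to select every expert so the per-expert hooks always fire, attach forward pre-hooks to record events, then collapse consecutive expert events into one semantic "routed_experts" step. A self-contained toy demonstration of the hook-ordering idea; the Toy model and helper names below are invented for illustration:

import torch
from torch import nn

class Toy(nn.Module):
    """Toy module with two submodules, executed first then second."""

    def __init__(self):
        super().__init__()
        self.first = nn.Linear(4, 4)
        self.second = nn.Linear(4, 4)

    def forward(self, x):
        return self.second(self.first(x))

def execution_order(model: nn.Module, x: torch.Tensor) -> list[str]:
    """Record the order in which named submodules start executing."""
    events: list[str] = []
    handles = [
        module.register_forward_pre_hook(
            lambda _m, _inp, name=name: events.append(name)
        )
        for name, module in model.named_children()
    ]
    try:
        with torch.inference_mode():
            model.eval()(x)
    finally:
        for handle in handles:
            handle.remove()
    return events

print(execution_order(Toy(), torch.randn(2, 4)))  # ['first', 'second']

Collapsing consecutive duplicate events, as _semantic_sparse_moe_execution_order does, keeps the comparison independent of how many routed experts actually run.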
