@@ -402,6 +402,76 @@ def _assert_sparse_moe_defused_matches_fused_math(
402402 torch .testing .assert_close (actual , expected , ** assert_close_kwargs )
403403
404404
405+ def _force_route_all_experts (block : nn .Module ) -> None :
406+ """Set MoE routers to select all experts so execution-order hooks always fire."""
407+
408+ router = getattr (block , "gate" , None )
409+ num_experts = getattr (block , "num_experts" , None )
410+ if router is None or num_experts is None :
411+ return
412+
413+ for name in ("top_k" , "num_experts_per_tok" ):
414+ if hasattr (router , name ):
415+ setattr (router , name , num_experts )
416+ return
417+
418+
def _semantic_sparse_moe_execution_order(block: nn.Module, hidden_states: torch.Tensor) -> list[str]:
    """Run ``block`` once and return the semantic order of its MoE sub-module calls.

    Forward-pre hooks are attached to whichever of ``shared_expert``, ``gate``,
    ``shared_expert_gate``, and the routed ``experts`` exist on ``block``.  All
    per-expert events are collapsed into a single ``"routed_experts"`` label and
    consecutive duplicates are merged, so the result describes semantic phases
    rather than individual expert invocations.  The router is first forced to
    select every expert so routed-expert hooks are guaranteed to fire.
    """
    _force_route_all_experts(block)

    events: list[str] = []
    handles = []

    def _watch(module: nn.Module, label: str) -> None:
        # A forward-pre hook returning None leaves the module's input untouched.
        handles.append(module.register_forward_pre_hook(lambda *_: events.append(label)))

    for attr in ("shared_expert", "gate", "shared_expert_gate"):
        if hasattr(block, attr):
            _watch(getattr(block, attr), attr)

    experts = getattr(block, "experts", None)
    # ModuleList is itself an nn.Module, so the list case must be checked first.
    if isinstance(experts, nn.ModuleList):
        for idx, expert in enumerate(experts):
            _watch(expert, f"expert_{idx}")
    elif isinstance(experts, nn.Module):
        _watch(experts, "experts")

    try:
        with torch.inference_mode():
            block.eval()(hidden_states)
    finally:
        # Always detach the hooks, even if the forward pass raises.
        while handles:
            handles.pop().remove()

    order: list[str] = []
    for event in events:
        label = "routed_experts" if event == "experts" or event.startswith("expert_") else event
        if not order or order[-1] != label:
            order.append(label)
    return order
458+
459+
def _assert_sparse_moe_defused_matches_fused_execution_order(
    original_block: nn.Module,
    defused_block: nn.Module,
    hidden_states: torch.Tensor,
) -> None:
    """Check that defusing preserves the fused HF block's semantic execution order.

    The fused block is seeded with deterministic floating-point weights and
    those weights are copied into the defused block first, so a single forward
    pass through each yields comparable hook traces.
    """
    _seed_floating_tensors(original_block)
    _copy_sparse_moe_weights(original_block, defused_block)

    fused_order = _semantic_sparse_moe_execution_order(original_block, hidden_states)
    defused_order = _semantic_sparse_moe_execution_order(defused_block, hidden_states)
    assert defused_order == fused_order
473+
474+
405475def test_qwen2_moe ():
406476 model_type = "qwen2_moe"
407477 replace_fused_blocks (model_type )
@@ -858,6 +928,17 @@ def test_qwen2_moe_defused_forward_matches_fused_math():
858928 )
859929
860930
def test_qwen2_moe_defused_forward_matches_fused_execution_order():
    """Qwen2-MoE: defusing must preserve the fused block's sub-module call order."""
    config = _tiny_moe_config(Qwen2MoeConfig)
    sample = torch.randn(2, 3, config.hidden_size, dtype=torch.float32)

    fused = Qwen2MoeSparseMoeBlock(config)
    defused = LinearQwen2MoeSparseMoeBlock(config)
    _assert_sparse_moe_defused_matches_fused_execution_order(fused, defused, sample)
940+
941+
861942def test_qwen3_moe_defused_forward_matches_fused_math ():
862943 config = _tiny_moe_config (Qwen3MoeConfig )
863944 hidden_states = torch .randn (2 , 3 , config .hidden_size , dtype = torch .float32 )
@@ -880,6 +961,17 @@ def test_qwen3_next_defused_forward_matches_fused_math():
880961 )
881962
882963
def test_qwen3_next_defused_forward_matches_fused_execution_order():
    """Qwen3-Next: defusing must preserve the fused block's sub-module call order."""
    config = _tiny_moe_config(Qwen3NextConfig)
    sample = torch.randn(2, 3, config.hidden_size, dtype=torch.float32)

    fused = Qwen3NextSparseMoeBlock(config)
    defused = LinearQwen3NextSparseMoeBlock(config)
    _assert_sparse_moe_defused_matches_fused_execution_order(fused, defused, sample)
973+
974+
883975def test_qwen3_omni_defused_forward_matches_fused_math ():
884976 config = _tiny_qwen3_omni_config ().thinker_config .text_config
885977 hidden_states = torch .randn (2 , 3 , config .hidden_size , dtype = torch .float32 )
0 commit comments