@@ -45,15 +45,18 @@ def decorator(func: Callable):
 def patch_qwen3_omni_text_class() -> list[str]:
     """Teach HF init code how to initialize unfused qwen3-omni thinker experts."""
     from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import Qwen3OmniMoeForConditionalGeneration, Qwen3OmniMoePreTrainedModel
-    from defuser.modeling.unfused_moe.qwen3_omni_moe import LinearQwen3OmniMoeThinkerTextSparseMoeBlock
+    from defuser.modeling.unfused_moe.qwen3_omni_moe import (
+        LinearQwen3OmniMoeTalkerTextSparseMoeBlock,
+        LinearQwen3OmniMoeThinkerTextSparseMoeBlock,
+    )
     orig_init_weights = Qwen3OmniMoePreTrainedModel._init_weights

     def patched_init_weights(self, module):
         try:
             orig_init_weights(self, module)
         except AttributeError as e:
             # fallback for unfused experts
-            if isinstance(module, LinearQwen3OmniMoeThinkerTextSparseMoeBlock):
+            if isinstance(module, (LinearQwen3OmniMoeThinkerTextSparseMoeBlock, LinearQwen3OmniMoeTalkerTextSparseMoeBlock)):
                 std = self.config.initializer_range
                 experts = module.experts

@@ -63,9 +66,18 @@ def patched_init_weights(self, module):
                     torch.nn.init.normal_(experts.up_proj.weight, 0.0, std)
                 if hasattr(experts, "down_proj"):
                     torch.nn.init.normal_(experts.down_proj.weight, 0.0, std)
+                if isinstance(experts, torch.nn.ModuleList):
+                    for expert in experts:
+                        torch.nn.init.normal_(expert.gate_proj.weight, 0.0, std)
+                        torch.nn.init.normal_(expert.up_proj.weight, 0.0, std)
+                        torch.nn.init.normal_(expert.down_proj.weight, 0.0, std)

                 if hasattr(module, "gate"):
                     torch.nn.init.normal_(module.gate.weight, 0.0, std)
+                if hasattr(module, "shared_expert"):
+                    module.shared_expert._is_hf_initialized = True
+                if hasattr(module, "shared_expert_gate"):
+                    torch.nn.init.normal_(module.shared_expert_gate.weight, 0.0, std)
             else:
                 raise e

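For context: setting _is_hf_initialized = True on the shared expert presumably tells HF's weight-init pass to skip that submodule, since that flag is what PreTrainedModel uses to mark modules as already initialized. Below is a minimal, self-contained sketch of the wrap-and-fallback pattern the hunks above implement. The Toy* classes and the 0.02 initializer range are assumptions for illustration, not part of the commit; the real classes are Qwen3OmniMoePreTrainedModel and the Linear*SparseMoeBlock imports shown in the diff.

import torch
import torch.nn as nn

class ToyExpert(nn.Module):
    """Assumed shape of one unfused expert: three plain Linear projections."""
    def __init__(self, hidden=8):
        super().__init__()
        self.gate_proj = nn.Linear(hidden, hidden, bias=False)
        self.up_proj = nn.Linear(hidden, hidden, bias=False)
        self.down_proj = nn.Linear(hidden, hidden, bias=False)

class ToyMoeBlock(nn.Module):
    """Assumed shape of the sparse MoE block: a router plus a ModuleList of experts."""
    def __init__(self, hidden=8, n_experts=2):
        super().__init__()
        self.gate = nn.Linear(hidden, n_experts, bias=False)
        self.experts = nn.ModuleList(ToyExpert(hidden) for _ in range(n_experts))

class ToyPreTrainedModel:
    class config:
        initializer_range = 0.02  # assumed value; HF reads this off the model config

    def _init_weights(self, module):
        # Stand-in for the stock HF initializer: it does not recognize the
        # unfused block, so it raises and the patched fallback takes over.
        raise AttributeError("unfused experts not recognized")

orig_init_weights = ToyPreTrainedModel._init_weights

def patched_init_weights(self, module):
    try:
        orig_init_weights(self, module)  # defer to the original initializer first
    except AttributeError:
        if isinstance(module, ToyMoeBlock):
            std = self.config.initializer_range
            # Per-expert init, mirroring the ModuleList branch in the diff.
            for expert in module.experts:
                torch.nn.init.normal_(expert.gate_proj.weight, 0.0, std)
                torch.nn.init.normal_(expert.up_proj.weight, 0.0, std)
                torch.nn.init.normal_(expert.down_proj.weight, 0.0, std)
            torch.nn.init.normal_(module.gate.weight, 0.0, std)
        else:
            raise  # unrelated failure: surface it, as the diff's `raise e` does

ToyPreTrainedModel._init_weights = patched_init_weights  # install the patch

model = ToyPreTrainedModel()
model._init_weights(ToyMoeBlock())  # falls through to the fallback branch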