
Commit e083eb3

extend model support
1 parent dad5f7b commit e083eb3

4 files changed: 825 additions & 2 deletions


defuser/model_registry.py

Lines changed: 4 additions & 0 deletions
@@ -117,6 +117,10 @@ class PATCH(str, Enum):
         (
             "transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextSparseMoeBlock",
             "defuser.modeling.unfused_moe.qwen3_omni_moe.LinearQwen3OmniMoeThinkerTextSparseMoeBlock",
+        ),
+        (
+            "transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe.Qwen3OmniMoeTalkerTextSparseMoeBlock",
+            "defuser.modeling.unfused_moe.qwen3_omni_moe.LinearQwen3OmniMoeTalkerTextSparseMoeBlock",
         )
     ],
 },
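Each registry entry pairs the fully-qualified path of a fused transformers MoE block with the path of its unfused defuser replacement. As a rough illustration (not the repository's actual patching code), such a pair could be resolved into classes at patch time like this, where resolve_class is a hypothetical helper:

import importlib

def resolve_class(dotted_path: str):
    # split "package.module.ClassName" into module path and class name
    module_path, _, class_name = dotted_path.rpartition(".")
    return getattr(importlib.import_module(module_path), class_name)

# resolve the new pair registered above
fused_cls = resolve_class(
    "transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe.Qwen3OmniMoeTalkerTextSparseMoeBlock"
)
unfused_cls = resolve_class(
    "defuser.modeling.unfused_moe.qwen3_omni_moe.LinearQwen3OmniMoeTalkerTextSparseMoeBlock"
)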

defuser/modeling/model_patches.py

Lines changed: 14 additions & 2 deletions
@@ -45,15 +45,18 @@ def decorator(func: Callable):
 def patch_qwen3_omni_text_class() -> list[str]:
     """Teach HF init code how to initialize unfused qwen3-omni thinker experts."""
     from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import Qwen3OmniMoeForConditionalGeneration, Qwen3OmniMoePreTrainedModel
-    from defuser.modeling.unfused_moe.qwen3_omni_moe import LinearQwen3OmniMoeThinkerTextSparseMoeBlock
+    from defuser.modeling.unfused_moe.qwen3_omni_moe import (
+        LinearQwen3OmniMoeTalkerTextSparseMoeBlock,
+        LinearQwen3OmniMoeThinkerTextSparseMoeBlock,
+    )
     orig_init_weights = Qwen3OmniMoePreTrainedModel._init_weights

     def patched_init_weights(self, module):
         try:
             orig_init_weights(self, module)
         except AttributeError as e:
             # fallback for unfused experts
-            if isinstance(module, LinearQwen3OmniMoeThinkerTextSparseMoeBlock):
+            if isinstance(module, (LinearQwen3OmniMoeThinkerTextSparseMoeBlock, LinearQwen3OmniMoeTalkerTextSparseMoeBlock)):
                 std = self.config.initializer_range
                 experts = module.experts

@@ -63,9 +66,18 @@ def patched_init_weights(self, module):
                     torch.nn.init.normal_(experts.up_proj.weight, 0.0, std)
                 if hasattr(experts, "down_proj"):
                     torch.nn.init.normal_(experts.down_proj.weight, 0.0, std)
+                if isinstance(experts, torch.nn.ModuleList):
+                    for expert in experts:
+                        torch.nn.init.normal_(expert.gate_proj.weight, 0.0, std)
+                        torch.nn.init.normal_(expert.up_proj.weight, 0.0, std)
+                        torch.nn.init.normal_(expert.down_proj.weight, 0.0, std)

                 if hasattr(module, "gate"):
                     torch.nn.init.normal_(module.gate.weight, 0.0, std)
+                if hasattr(module, "shared_expert"):
+                    module.shared_expert._is_hf_initialized = True
+                if hasattr(module, "shared_expert_gate"):
+                    torch.nn.init.normal_(module.shared_expert_gate.weight, 0.0, std)
             else:
                 raise e
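The talker block added below keeps its experts as an nn.ModuleList of separate MLPs rather than one fused module with stacked gate_proj/up_proj/down_proj tensors, which is why the fallback initializer gains the ModuleList branch; the shared expert is flagged with _is_hf_initialized, the marker transformers checks so its weight-init loop does not revisit that submodule. A self-contained toy version of the new branch, with SimpleExpert as a hypothetical stand-in for Qwen3OmniMoeTalkerTextMLP:

from torch import nn

class SimpleExpert(nn.Module):
    # hypothetical stand-in for Qwen3OmniMoeTalkerTextMLP: just the three projections
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

experts = nn.ModuleList(SimpleExpert(8, 16) for _ in range(4))
std = 0.02  # plays the role of config.initializer_range

# mirrors the ModuleList branch added to patched_init_weights
if isinstance(experts, nn.ModuleList):
    for expert in experts:
        nn.init.normal_(expert.gate_proj.weight, 0.0, std)
        nn.init.normal_(expert.up_proj.weight, 0.0, std)
        nn.init.normal_(expert.down_proj.weight, 0.0, std)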

defuser/modeling/unfused_moe/qwen3_omni_moe.py

Lines changed: 44 additions & 0 deletions
@@ -46,3 +46,47 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         )
         final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
         return final_hidden_states
+
+
+class LinearQwen3OmniMoeTalkerTextSparseMoeBlock(nn.Module):
+    """Text talker MoE block for qwen3-omni with explicit per-expert modules."""
+
+    def __init__(self, config):
+        super().__init__()
+        from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import (
+            Qwen3OmniMoeTalkerTextMLP,
+            Qwen3OmniMoeTalkerTextTopKRouter,
+        )
+
+        self.num_experts = config.num_experts
+        self.top_k = config.num_experts_per_tok
+        self.norm_topk_prob = config.norm_topk_prob
+
+        self.gate = Qwen3OmniMoeTalkerTextTopKRouter(config)
+        self.experts = nn.ModuleList(
+            [
+                Qwen3OmniMoeTalkerTextMLP(config, intermediate_size=config.moe_intermediate_size)
+                for _ in range(self.num_experts)
+            ]
+        )
+        self.shared_expert = Qwen3OmniMoeTalkerTextMLP(
+            config,
+            intermediate_size=config.shared_expert_intermediate_size,
+        )
+        self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        shared_expert_output = self.shared_expert(hidden_states)
+        _, routing_weights, selected_experts = self.gate(hidden_states)
+        final_hidden_states = run_routed_experts(
+            self.experts,
+            hidden_states,
+            routing_weights.to(hidden_states.dtype),
+            selected_experts,
+            self.num_experts,
+        )
+        shared_expert_output = torch.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output
+        final_hidden_states = final_hidden_states + shared_expert_output
+        return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
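The forward pass calls run_routed_experts, which lives earlier in qwen3_omni_moe.py and is not shown in this hunk. A minimal sketch of what such a dispatch helper conventionally looks like, assuming routing_weights and selected_experts have shape [num_tokens, top_k] as in the standard Hugging Face sparse-MoE loop (the actual defuser implementation may differ):

import torch

def run_routed_experts(experts, hidden_states, routing_weights, selected_experts, num_experts):
    # hidden_states: [num_tokens, hidden_dim]; routing_weights / selected_experts: [num_tokens, top_k]
    num_tokens, hidden_dim = hidden_states.shape
    out = torch.zeros(num_tokens, hidden_dim, dtype=hidden_states.dtype, device=hidden_states.device)
    # expert_mask[e, k, t] == 1 when token t picked expert e as its k-th choice
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)
    for expert_idx in range(num_experts):
        k_idx, token_idx = torch.where(expert_mask[expert_idx])
        if token_idx.numel() == 0:
            continue  # no tokens routed to this expert
        expert_out = experts[expert_idx](hidden_states[token_idx])
        # scale each token's expert output by its routing weight, then scatter-add back
        out.index_add_(0, token_idx, expert_out * routing_weights[token_idx, k_idx, None])
    return out

In the new forward, this routed result is then combined with the always-on shared expert, whose contribution is scaled by a per-token sigmoid gate (shared_expert_gate) before the sum.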
