Skip to content

Commit cae3a65

Browse files
committed
test
1 parent e083eb3 commit cae3a65

1 file changed

Lines changed: 35 additions & 1 deletion

File tree

tests/test_convert_model.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from transformers.models.llama4.modeling_llama4 import Llama4Config, Llama4ForConditionalGeneration
3838
from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import (
3939
Qwen3OmniMoeForConditionalGeneration,
40+
Qwen3OmniMoeTalkerTextSparseMoeBlock,
4041
Qwen3OmniMoeThinkerTextSparseMoeBlock,
4142
)
4243

@@ -51,7 +52,10 @@
5152
from defuser.modeling.unfused_moe.qwen2_moe import LinearQwen2MoeSparseMoeBlock
5253
from defuser.modeling.unfused_moe.qwen3_moe import LinearQwen3MoeSparseMoeBlock
5354
from defuser.modeling.unfused_moe.qwen3_next import LinearQwen3NextSparseMoeBlock
54-
from defuser.modeling.unfused_moe.qwen3_omni_moe import LinearQwen3OmniMoeThinkerTextSparseMoeBlock
55+
from defuser.modeling.unfused_moe.qwen3_omni_moe import (
56+
LinearQwen3OmniMoeTalkerTextSparseMoeBlock,
57+
LinearQwen3OmniMoeThinkerTextSparseMoeBlock,
58+
)
5559
from defuser.utils.common import MIN_SUPPORTED_TRANSFORMERS_VERSION
5660

5761

@@ -116,6 +120,25 @@ def _tiny_qwen3_omni_config():
116120
)
117121

118122

123+
def _tiny_qwen3_omni_talker_text_config():
    """Build a miniature Qwen3-Omni talker text config for fast MoE tests.

    Starts from the default talker text config (with audio output enabled)
    and shrinks every size-related field so that model construction and a
    forward pass stay cheap in unit tests.
    """
    cfg = Qwen3OmniMoeConfig(enable_audio_output=True).talker_config.text_config
    # Override in a single pass; dict order matches the original assignments.
    overrides = {
        "hidden_size": 64,
        "intermediate_size": 128,
        "moe_intermediate_size": 32,
        "shared_expert_intermediate_size": 32,
        "num_hidden_layers": 1,
        "num_attention_heads": 4,
        "num_key_value_heads": 4,
        "head_dim": 16,
        "num_experts": 4,
        "num_experts_per_tok": 2,
        "vocab_size": 128,
        "pad_token_id": 0,
        "bos_token_id": 1,
        "eos_token_id": 2,
    }
    for attr, value in overrides.items():
        setattr(cfg, attr, value)
    return cfg
140+
141+
119142
def _tiny_qwen3_5_moe_config():
120143
return Qwen3_5MoeConfig(
121144
text_config={
@@ -861,6 +884,17 @@ def test_qwen3_omni_defused_forward_matches_fused_math():
861884
)
862885

863886

887+
def test_qwen3_omni_talker_defused_forward_matches_fused_math():
    """Defused talker-text MoE block must match the fused block numerically."""
    config = _tiny_qwen3_omni_talker_text_config()
    # Small random batch: (batch=2, seq=3, hidden_size) in float32.
    hidden_states = torch.randn(2, 3, config.hidden_size, dtype=torch.float32)

    fused_block = Qwen3OmniMoeTalkerTextSparseMoeBlock(config)
    defused_block = LinearQwen3OmniMoeTalkerTextSparseMoeBlock(config)
    _assert_sparse_moe_defused_matches_fused_math(
        fused_block,
        defused_block,
        hidden_states,
    )
896+
897+
864898
def test_glm4_moe_defused_forward_matches_fused_math():
865899
config = _tiny_glm4_moe_config()
866900
hidden_states = torch.randn(2, 3, config.hidden_size, dtype=torch.float32)

0 commit comments

Comments
 (0)