|
37 | 37 | from transformers.models.llama4.modeling_llama4 import Llama4Config, Llama4ForConditionalGeneration |
38 | 38 | from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import ( |
39 | 39 | Qwen3OmniMoeForConditionalGeneration, |
| 40 | + Qwen3OmniMoeTalkerTextSparseMoeBlock, |
40 | 41 | Qwen3OmniMoeThinkerTextSparseMoeBlock, |
41 | 42 | ) |
42 | 43 |
|
|
51 | 52 | from defuser.modeling.unfused_moe.qwen2_moe import LinearQwen2MoeSparseMoeBlock |
52 | 53 | from defuser.modeling.unfused_moe.qwen3_moe import LinearQwen3MoeSparseMoeBlock |
53 | 54 | from defuser.modeling.unfused_moe.qwen3_next import LinearQwen3NextSparseMoeBlock |
54 | | -from defuser.modeling.unfused_moe.qwen3_omni_moe import LinearQwen3OmniMoeThinkerTextSparseMoeBlock |
| 55 | +from defuser.modeling.unfused_moe.qwen3_omni_moe import ( |
| 56 | + LinearQwen3OmniMoeTalkerTextSparseMoeBlock, |
| 57 | + LinearQwen3OmniMoeThinkerTextSparseMoeBlock, |
| 58 | +) |
55 | 59 | from defuser.utils.common import MIN_SUPPORTED_TRANSFORMERS_VERSION |
56 | 60 |
|
57 | 61 |
|
@@ -116,6 +120,25 @@ def _tiny_qwen3_omni_config(): |
116 | 120 | ) |
117 | 121 |
|
118 | 122 |
|
def _tiny_qwen3_omni_talker_text_config():
    """Return a shrunken Qwen3-Omni talker text config for fast unit tests.

    Starts from the default talker text config (with audio output enabled)
    and overrides every size-related field so the resulting model is tiny.
    """
    cfg = Qwen3OmniMoeConfig(enable_audio_output=True).talker_config.text_config
    overrides = {
        "hidden_size": 64,
        "intermediate_size": 128,
        "moe_intermediate_size": 32,
        "shared_expert_intermediate_size": 32,
        "num_hidden_layers": 1,
        "num_attention_heads": 4,
        "num_key_value_heads": 4,
        "head_dim": 16,
        "num_experts": 4,
        "num_experts_per_tok": 2,
        "vocab_size": 128,
        "pad_token_id": 0,
        "bos_token_id": 1,
        "eos_token_id": 2,
    }
    for attr, value in overrides.items():
        setattr(cfg, attr, value)
    return cfg
| 140 | + |
| 141 | + |
119 | 142 | def _tiny_qwen3_5_moe_config(): |
120 | 143 | return Qwen3_5MoeConfig( |
121 | 144 | text_config={ |
@@ -861,6 +884,17 @@ def test_qwen3_omni_defused_forward_matches_fused_math(): |
861 | 884 | ) |
862 | 885 |
|
863 | 886 |
|
def test_qwen3_omni_talker_defused_forward_matches_fused_math():
    """Check the linear (defused) talker-text MoE block reproduces the fused math."""
    cfg = _tiny_qwen3_omni_talker_text_config()
    sample = torch.randn(2, 3, cfg.hidden_size, dtype=torch.float32)

    fused_block = Qwen3OmniMoeTalkerTextSparseMoeBlock(cfg)
    defused_block = LinearQwen3OmniMoeTalkerTextSparseMoeBlock(cfg)
    _assert_sparse_moe_defused_matches_fused_math(fused_block, defused_block, sample)
| 897 | + |
864 | 898 | def test_glm4_moe_defused_forward_matches_fused_math(): |
865 | 899 | config = _tiny_glm4_moe_config() |
866 | 900 | hidden_states = torch.randn(2, 3, config.hidden_size, dtype=torch.float32) |
|
0 commit comments