|
33 | 33 | ) |
34 | 34 | from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import Qwen3OmniMoeConfig |
35 | 35 | from transformers.models.gpt_oss.modeling_gpt_oss import GptOssConfig, GptOssForCausalLM |
| 36 | +from transformers.models.phimoe.modeling_phimoe import PhimoeConfig, PhimoeForCausalLM |
36 | 37 | from transformers.models.llama4.modeling_llama4 import Llama4Config, Llama4ForConditionalGeneration |
37 | 38 | from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import ( |
38 | 39 | Qwen3OmniMoeForConditionalGeneration, |
@@ -252,6 +253,22 @@ def _tiny_llama4_config(): |
252 | 253 | ) |
253 | 254 |
|
254 | 255 |
|
def _tiny_phimoe_config():
    """Build a minimal PhiMoE configuration sized for fast unit tests."""
    tiny = dict(
        vocab_size=128,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=1,
        num_attention_heads=4,
        num_key_value_heads=4,
        num_local_experts=4,
        num_experts_per_tok=2,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
    )
    return PhimoeConfig(**tiny)
| 270 | + |
| 271 | + |
255 | 272 | def _write_single_safetensors_checkpoint(path, state_dict: dict[str, torch.Tensor], config) -> None: |
256 | 273 | config.save_pretrained(path) |
257 | 274 | save_file({name: tensor.detach().cpu().contiguous() for name, tensor in state_dict.items()}, str(path / "model.safetensors")) |
@@ -1076,3 +1093,61 @@ def test_llama4_split_forward_matches_fused_math(): |
1076 | 1093 |
|
1077 | 1094 | # The split module should exactly reproduce the original fused MLP math. |
1078 | 1095 | torch.testing.assert_close(mlp(hidden_states), expected) |
| 1096 | + |
| 1097 | + |
def test_phimoe():
    """Converting a tiny PhiMoE model unfuses expert weights without altering them."""
    from transformers.models.phimoe.modeling_phimoe import PhimoeSparseMoeBlock

    model = PhimoeForCausalLM(_tiny_phimoe_config())
    assert model.config.model_type == "phimoe"

    fused_block = model.model.layers[0].mlp
    assert isinstance(fused_block, PhimoeSparseMoeBlock)

    # Fused layout: gate_up_proj is (num_experts, 2 * intermediate, hidden);
    # rows [:intermediate] are the gate projection, rows [intermediate:] are "up".
    gate_up = fused_block.experts.gate_up_proj
    hidden = gate_up.shape[-1]
    inter = gate_up.shape[1] // 2

    want_gate = gate_up[0, :inter, :hidden].contiguous().clone()
    want_up = gate_up[0, inter:, :hidden].contiguous().clone()
    want_down = fused_block.experts.down_proj[0, :hidden, :inter].contiguous().clone()

    assert convert_model(model, cleanup_original=False, max_layers=1)

    experts = model.model.layers[0].mlp.experts
    _assert_unfused_expert_module(experts)
    first_expert = getattr(experts, "0")

    materialize_model(model.model.layers[0])

    # Expert 0's split projections must carry exactly the original fused slices.
    torch.testing.assert_close(first_expert.gate_proj.weight, want_gate)
    torch.testing.assert_close(first_expert.up_proj.weight, want_up)
    torch.testing.assert_close(first_expert.down_proj.weight, want_down)
| 1128 | + |
def test_phimoe_split_forward_matches_fused_math():
    """The unfused experts' forward pass exactly matches the fused implementation."""
    from transformers.models.phimoe.modeling_phimoe import PhimoeExperts

    model = PhimoeForCausalLM(_tiny_phimoe_config())
    fused = model.model.layers[0].mlp.experts
    assert isinstance(fused, PhimoeExperts)

    num_tokens = 5
    inputs = torch.randn(num_tokens, model.config.hidden_size, dtype=torch.float32)
    # Route every token to expert 0 with weight 1 so both paths are comparable.
    routing_index = torch.zeros((num_tokens, 1), dtype=torch.long)
    routing_weights = torch.ones((num_tokens, 1), dtype=inputs.dtype)

    with torch.no_grad():
        reference = fused(inputs, routing_index, routing_weights)

    assert convert_model(model, cleanup_original=False, max_layers=1)

    split = model.model.layers[0].mlp.experts
    _assert_unfused_expert_module(split)
    materialize_model(model.model.layers[0])
    with torch.no_grad():
        actual = split(inputs, routing_index, routing_weights)

    # The split experts path should exactly reproduce the original fused experts math.
    torch.testing.assert_close(actual, reference)
0 commit comments