Commit dad5f7b

extend model support
1 parent c1f4bca · commit dad5f7b

6 files changed: 1300 additions & 10 deletions

defuser/defuser.py

Lines changed: 1 addition & 1 deletion
@@ -200,7 +200,7 @@ def convert_model(
     if not check_model_compatibility(model):
         return False
 
-    apply_model_patches(model)
+    apply_model_patches(model, max_layers=max_layers)
 
     # If fused blocks have already been structurally replaced at model load,
     # there is no need to perform runtime defusing again
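Note on usage: this hunk threads the new max_layers argument from convert_model into apply_model_patches, so patching can stop at a given layer depth. Below is a minimal sketch, assuming convert_model accepts the instantiated model as its first argument and exposes max_layers as a keyword; the checkpoint id and loading call are placeholders, not taken from this commit.

# Usage sketch, not part of this commit. Assumptions: the convert_model call
# shape shown here and the placeholder checkpoint id.
from transformers import AutoModelForCausalLM

from defuser.defuser import convert_model  # import path assumed from this file's location

model = AutoModelForCausalLM.from_pretrained("org/some-moe-checkpoint")  # placeholder id
ok = convert_model(model, max_layers=4)  # patch/defuse only modules in the first 4 layers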

defuser/model_registry.py

Lines changed: 99 additions & 0 deletions
@@ -16,6 +16,39 @@ class PATCH(str, Enum):
 
 
 MODEL_CONFIG = {
+    "dbrx": {
+        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
+    },
+    "deepseek_v2": {
+        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
+    },
+    "deepseek_v3": {
+        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
+    },
+    "dia": {
+        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
+    },
+    "dots1": {
+        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
+    },
+    "ernie4_5_moe": {
+        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
+    },
+    "ernie4_5_vl_moe": {
+        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
+    },
+    "exaone_moe": {
+        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
+    },
+    "flex_olmo": {
+        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
+    },
+    "glm": {
+        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
+    },
+    "glm4": {
+        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
+    },
     "mixtral": {
         "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
         PATCH.REPLACE_MODULE: [
@@ -96,6 +129,9 @@ class PATCH(str, Enum):
             )
         ],
     },
+    "glm4_moe_lite": {
+        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
+    },
     "glm4v": {
         "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
         PATCH.REPLACE_MODULE: [
@@ -116,9 +152,39 @@ class PATCH(str, Enum):
             ),
         ],
     },
+    "glm4v_moe": {
+        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
+    },
+    "glm_image": {
+        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
+    },
+    "glm_moe_dsa": {
+        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
+    },
+    "glm_ocr": {
+        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
+    },
     "gpt_oss": {
         "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
     },
+    "granitemoe": {
+        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
+    },
+    "granitemoehybrid": {
+        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
+    },
+    "granitemoeshared": {
+        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
+    },
+    "hunyuan_v1_moe": {
+        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
+    },
+    "jamba": {
+        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
+    },
+    "jetmoe": {
+        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
+    },
     "llama4": {
         "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
         PATCH.EXPERTS_DEFUSE: [
@@ -128,7 +194,40 @@ class PATCH(str, Enum):
             }
         ],
     },
+    "lfm2_moe": {
+        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
+    },
+    "longcat_flash": {
+        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
+    },
+    "minimax": {
+        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
+    },
+    "minimax_m2": {
+        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
+    },
+    "nemotron_h": {
+        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
+    },
+    "olmoe": {
+        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
+    },
+    "phi3": {
+        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
+    },
+    "phi4_multimodal": {
+        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
+    },
     "phimoe": {
         "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
     },
+    "qwen3_vl_moe": {
+        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
+    },
+    "solar_open": {
+        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
+    },
+    "zamba2": {
+        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
+    },
 }
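Every entry added here pins only a minimum transformers version and carries no PATCH.REPLACE_MODULE or PATCH.EXPERTS_DEFUSE list, so these families rely entirely on the runtime patches registered in model_patches.py. The commit does not show how the version gate is consumed; the sketch below illustrates one plausible check, where is_supported is a hypothetical helper rather than the project's actual API.

# Illustrative sketch only; is_supported is hypothetical. The real gate
# (e.g. inside check_model_compatibility) is not part of this diff.
import transformers
from packaging import version

from defuser.model_registry import MODEL_CONFIG


def is_supported(model_type: str) -> bool:
    cfg = MODEL_CONFIG.get(model_type)
    if cfg is None:
        return False  # model family not registered at all
    required = cfg["min_transformers_version"]
    return version.parse(transformers.__version__) >= version.parse(required)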

defuser/modeling/model_patches.py

Lines changed: 197 additions & 3 deletions
@@ -8,6 +8,13 @@
 from logbar import LogBar
 
 from defuser import DEBUG_ON
+from defuser.modeling.runtime_defusion import (
+    patch_dbrx_experts,
+    patch_longcat_flash_experts,
+    patch_parallel_experts,
+    patch_split_gate_up_mlp,
+)
+from defuser.utils.common import is_within_max_layers
 import torch
 
 logger = LogBar(__name__)
@@ -68,7 +75,7 @@ def patched_init_weights(self, module):
 
 
 @register_model_patch("qwen3_omni_moe")
-def patch_qwen3_omni_text_runtime(model) -> list[str]:
+def patch_qwen3_omni_text_runtime(model, max_layers: int | None = None) -> list[str]:
     """Restore text-only ``forward`` and ``generate`` behavior after class swapping."""
     model_cls = type(model)
     if not getattr(model_cls, "__module__", "").startswith("transformers.models.qwen3_omni_moe."):
@@ -98,6 +105,193 @@ def forward(self, *args, **kwargs):
     return applied
 
 
+def _patch_modules_by_class(
+    model,
+    patchers: dict[str, Callable],
+    *,
+    max_layers: int | None = None,
+) -> list[str]:
+    applied = []
+    for name, module in list(model.named_modules()):
+        if not is_within_max_layers(name, max_layers):
+            continue
+        class_path = f"{module.__class__.__module__}.{module.__class__.__name__}"
+        patcher = patchers.get(class_path)
+        if patcher is None:
+            continue
+        if patcher(module):
+            applied.append(name)
+    return applied
+
+
+def _patch_split_gate_up_mlps(
+    model,
+    patchers: dict[str, str],
+    *,
+    max_layers: int | None = None,
+) -> list[str]:
+    return _patch_modules_by_class(
+        model,
+        {
+            class_path: (lambda module, variant=variant: patch_split_gate_up_mlp(module, variant=variant))
+            for class_path, variant in patchers.items()
+        },
+        max_layers=max_layers,
+    )
+
+
+_STANDARD_SPLIT_GATE_UP_CLASSES = {
+    "transformers.models.dia.modeling_dia.DiaMLP": "standard",
+    "transformers.models.glm.modeling_glm.GlmMLP": "standard",
+    "transformers.models.glm4.modeling_glm4.Glm4MLP": "standard",
+    "transformers.models.glm_image.modeling_glm_image.GlmImageTextMLP": "standard",
+    "transformers.models.glm_ocr.modeling_glm_ocr.GlmOcrTextMLP": "standard",
+    "transformers.models.phi3.modeling_phi3.Phi3MLP": "standard",
+    "transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalMLP": "standard",
+    "transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalAudioMLP": "phi4_audio",
+    "transformers.models.zamba2.modeling_zamba2.Zamba2MLP": "zamba2",
+}
+
+
+@register_model_patch("dia")
+def patch_dia_runtime(model, max_layers: int | None = None) -> list[str]:
+    return _patch_split_gate_up_mlps(
+        model,
+        {"transformers.models.dia.modeling_dia.DiaMLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.dia.modeling_dia.DiaMLP"]},
+        max_layers=max_layers,
+    )
+
+
+@register_model_patch("glm")
+def patch_glm_runtime(model, max_layers: int | None = None) -> list[str]:
+    return _patch_split_gate_up_mlps(
+        model,
+        {"transformers.models.glm.modeling_glm.GlmMLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.glm.modeling_glm.GlmMLP"]},
+        max_layers=max_layers,
+    )
+
+
+@register_model_patch("glm4")
+def patch_glm4_runtime(model, max_layers: int | None = None) -> list[str]:
+    return _patch_split_gate_up_mlps(
+        model,
+        {"transformers.models.glm4.modeling_glm4.Glm4MLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.glm4.modeling_glm4.Glm4MLP"]},
+        max_layers=max_layers,
+    )
+
+
+@register_model_patch("glm_image")
+def patch_glm_image_runtime(model, max_layers: int | None = None) -> list[str]:
+    return _patch_split_gate_up_mlps(
+        model,
+        {"transformers.models.glm_image.modeling_glm_image.GlmImageTextMLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.glm_image.modeling_glm_image.GlmImageTextMLP"]},
+        max_layers=max_layers,
+    )
+
+
+@register_model_patch("glm_ocr")
+def patch_glm_ocr_runtime(model, max_layers: int | None = None) -> list[str]:
+    return _patch_split_gate_up_mlps(
+        model,
+        {"transformers.models.glm_ocr.modeling_glm_ocr.GlmOcrTextMLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.glm_ocr.modeling_glm_ocr.GlmOcrTextMLP"]},
+        max_layers=max_layers,
+    )
+
+
+@register_model_patch("phi3")
+def patch_phi3_runtime(model, max_layers: int | None = None) -> list[str]:
+    return _patch_split_gate_up_mlps(
+        model,
+        {"transformers.models.phi3.modeling_phi3.Phi3MLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.phi3.modeling_phi3.Phi3MLP"]},
+        max_layers=max_layers,
+    )
+
+
+@register_model_patch("phi4_multimodal")
+def patch_phi4_multimodal_runtime(model, max_layers: int | None = None) -> list[str]:
+    return _patch_split_gate_up_mlps(
+        model,
+        {
+            "transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalMLP":
+                _STANDARD_SPLIT_GATE_UP_CLASSES[
+                    "transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalMLP"
+                ],
+            "transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalAudioMLP":
+                _STANDARD_SPLIT_GATE_UP_CLASSES[
+                    "transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalAudioMLP"
+                ],
+        },
+        max_layers=max_layers,
+    )
+
+
+@register_model_patch("zamba2")
+def patch_zamba2_runtime(model, max_layers: int | None = None) -> list[str]:
+    return _patch_split_gate_up_mlps(
+        model,
+        {"transformers.models.zamba2.modeling_zamba2.Zamba2MLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.zamba2.modeling_zamba2.Zamba2MLP"]},
+        max_layers=max_layers,
+    )
+
+
+@register_model_patch("dbrx")
+def patch_dbrx_runtime(model, max_layers: int | None = None) -> list[str]:
+    return _patch_modules_by_class(
+        model,
+        {"transformers.models.dbrx.modeling_dbrx.DbrxExperts": patch_dbrx_experts},
+        max_layers=max_layers,
+    )
+
+
+def _patch_parallel_runtime(model, class_path: str, *, max_layers: int | None = None) -> list[str]:
+    return _patch_modules_by_class(model, {class_path: patch_parallel_experts}, max_layers=max_layers)
+
+
+@register_model_patch("granitemoe")
+def patch_granitemoe_runtime(model, max_layers: int | None = None) -> list[str]:
+    return _patch_parallel_runtime(
+        model,
+        "transformers.models.granitemoe.modeling_granitemoe.GraniteMoeParallelExperts",
+        max_layers=max_layers,
+    )
+
+
+@register_model_patch("granitemoehybrid")
+def patch_granitemoehybrid_runtime(model, max_layers: int | None = None) -> list[str]:
+    return _patch_parallel_runtime(
+        model,
+        "transformers.models.granitemoehybrid.modeling_granitemoehybrid.GraniteMoeHybridParallelExperts",
+        max_layers=max_layers,
+    )
+
+
+@register_model_patch("granitemoeshared")
+def patch_granitemoeshared_runtime(model, max_layers: int | None = None) -> list[str]:
+    return _patch_parallel_runtime(
+        model,
+        "transformers.models.granitemoeshared.modeling_granitemoeshared.GraniteMoeSharedParallelExperts",
+        max_layers=max_layers,
+    )
+
+
+@register_model_patch("jetmoe")
+def patch_jetmoe_runtime(model, max_layers: int | None = None) -> list[str]:
+    return _patch_parallel_runtime(
+        model,
+        "transformers.models.jetmoe.modeling_jetmoe.JetMoeParallelExperts",
+        max_layers=max_layers,
+    )
+
+
+@register_model_patch("longcat_flash")
+def patch_longcat_flash_runtime(model, max_layers: int | None = None) -> list[str]:
+    return _patch_modules_by_class(
+        model,
+        {"transformers.models.longcat_flash.modeling_longcat_flash.LongcatFlashExperts": patch_longcat_flash_experts},
+        max_layers=max_layers,
+    )
+
+
 def apply_model_class_patches(model_type) -> list[str]:
     """Run any registered pre-construction patch for ``model_type``."""
     patch_model_class = _MODEL_CLASS_PATCH_REGISTRY.get(model_type)
@@ -110,15 +304,15 @@ def apply_model_class_patches(model_type) -> list[str]:
     return applied
 
 
-def apply_model_patches(model) -> list[str]:
+def apply_model_patches(model, max_layers: int | None = None) -> list[str]:
     """Run any registered runtime patch for the instantiated ``model``."""
     config = getattr(model, "config", None)
     model_type = getattr(config, "model_type", None)
     patch = _MODEL_PATCH_REGISTRY.get(model_type)
     if patch is None:
         return []
 
-    applied = patch(model)
+    applied = patch(model, max_layers=max_layers)
     if applied and DEBUG_ON:
         logger.debug(f"Applied model patches for model_type={model_type}: {', '.join(applied)}")
     return applied
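The new _patch_modules_by_class helper walks model.named_modules(), dispatches on the fully qualified class path of each module, and skips anything rejected by is_within_max_layers. That helper is imported from defuser.utils.common but its body is not part of this diff; the stand-in below only illustrates the assumed contract, and the ".layers.<n>" naming convention is an assumption about transformers module names.

# Stand-in sketch of the is_within_max_layers contract relied on above; the real
# implementation lives in defuser.utils.common and is not shown in this commit.
import re


def is_within_max_layers(name: str, max_layers: int | None) -> bool:
    if max_layers is None:
        return True  # no cap configured: every matching module may be patched
    m = re.search(r"\.layers\.(\d+)(\.|$)", name)
    if m is None:
        return True  # non-layer modules (embeddings, lm_head, ...) are not gated
    return int(m.group(1)) < max_layers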
