 from logbar import LogBar
 
 from defuser import DEBUG_ON
+from defuser.modeling.runtime_defusion import (
+    patch_dbrx_experts,
+    patch_longcat_flash_experts,
+    patch_parallel_experts,
+    patch_split_gate_up_mlp,
+)
+from defuser.utils.common import is_within_max_layers
+from typing import Callable  # used by the dict[str, Callable] annotations added below
 import torch
 
 logger = LogBar(__name__)
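The `@register_model_patch` decorator used throughout this diff is defined elsewhere in the module. Judging from how `apply_model_patches` consumes `_MODEL_PATCH_REGISTRY.get(model_type)` at the bottom of the file, it presumably has roughly this shape (a minimal sketch, not the actual definition):

```python
from typing import Callable

# Assumed shape of the registry/decorator pair; the real definitions are
# outside this diff's hunks.
_MODEL_PATCH_REGISTRY: dict[str, Callable] = {}

def register_model_patch(model_type: str) -> Callable:
    """Register a runtime patch function under a transformers model_type."""
    def decorator(fn: Callable) -> Callable:
        _MODEL_PATCH_REGISTRY[model_type] = fn
        return fn
    return decorator
```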
@@ -68,7 +75,7 @@ def patched_init_weights(self, module):
 
 
 @register_model_patch("qwen3_omni_moe")
-def patch_qwen3_omni_text_runtime(model) -> list[str]:
+def patch_qwen3_omni_text_runtime(model, max_layers: int | None = None) -> list[str]:
     """Restore text-only ``forward`` and ``generate`` behavior after class swapping."""
     model_cls = type(model)
     if not getattr(model_cls, "__module__", "").startswith("transformers.models.qwen3_omni_moe."):
@@ -98,6 +105,193 @@ def forward(self, *args, **kwargs):
     return applied
 
 
+def _patch_modules_by_class(
+    model,
+    patchers: dict[str, Callable],
+    *,
+    max_layers: int | None = None,
+) -> list[str]:
+    """Dispatch each named submodule to a patcher keyed by its fully qualified class path."""
+    applied = []
+    for name, module in list(model.named_modules()):
+        if not is_within_max_layers(name, max_layers):
+            continue
+        class_path = f"{module.__class__.__module__}.{module.__class__.__name__}"
+        patcher = patchers.get(class_path)
+        if patcher is None:
+            continue
+        if patcher(module):
+            applied.append(name)
+    return applied
+
+
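`_patch_modules_by_class` matches submodules by their fully qualified class path string rather than by `isinstance`, so no model-specific classes need to be imported just to test for them. A self-contained sketch of the key construction (the toy model is illustrative, not from the diff):

```python
import torch.nn as nn

def class_path(module: nn.Module) -> str:
    # Same key construction as in _patch_modules_by_class.
    return f"{module.__class__.__module__}.{module.__class__.__name__}"

toy = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
for name, module in toy.named_modules():
    print(name or "<root>", "->", class_path(module))
# <root> -> torch.nn.modules.container.Sequential
# 0 -> torch.nn.modules.linear.Linear
# 1 -> torch.nn.modules.activation.ReLU
```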
+def _patch_split_gate_up_mlps(
+    model,
+    patchers: dict[str, str],
+    *,
+    max_layers: int | None = None,
+) -> list[str]:
+    return _patch_modules_by_class(
+        model,
+        {
+            class_path: (lambda module, variant=variant: patch_split_gate_up_mlp(module, variant=variant))
+            for class_path, variant in patchers.items()
+        },
+        max_layers=max_layers,
+    )
+
+
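The `variant=variant` default argument in the comprehension above is deliberate: a plain closure would late-bind `variant`, so every generated patcher would see the value from the comprehension's final iteration. A quick standalone demonstration of the pitfall:

```python
# Late binding: every closure reads v after the loop has finished.
late = {k: (lambda: v) for k, v in {"a": 1, "b": 2}.items()}
assert late["a"]() == 2  # not 1

# Default-argument binding captures the current value at definition time.
bound = {k: (lambda v=v: v) for k, v in {"a": 1, "b": 2}.items()}
assert bound["a"]() == 1
```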
+_STANDARD_SPLIT_GATE_UP_CLASSES = {
+    "transformers.models.dia.modeling_dia.DiaMLP": "standard",
+    "transformers.models.glm.modeling_glm.GlmMLP": "standard",
+    "transformers.models.glm4.modeling_glm4.Glm4MLP": "standard",
+    "transformers.models.glm_image.modeling_glm_image.GlmImageTextMLP": "standard",
+    "transformers.models.glm_ocr.modeling_glm_ocr.GlmOcrTextMLP": "standard",
+    "transformers.models.phi3.modeling_phi3.Phi3MLP": "standard",
+    "transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalMLP": "standard",
+    "transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalAudioMLP": "phi4_audio",
+    "transformers.models.zamba2.modeling_zamba2.Zamba2MLP": "zamba2",
+}
+
+
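Each registration below pulls one or two entries out of `_STANDARD_SPLIT_GATE_UP_CLASSES` and forwards them to `_patch_split_gate_up_mlps`. For comparison, the repetition could be collapsed with a small factory; this is a hypothetical alternative, not the author's code:

```python
# Hypothetical factory that would generate the wrappers below from the table.
def _make_split_gate_up_patch(*class_paths: str):
    def _patch(model, max_layers: int | None = None) -> list[str]:
        return _patch_split_gate_up_mlps(
            model,
            {path: _STANDARD_SPLIT_GATE_UP_CLASSES[path] for path in class_paths},
            max_layers=max_layers,
        )
    return _patch

# e.g. register_model_patch("glm")(
#     _make_split_gate_up_patch("transformers.models.glm.modeling_glm.GlmMLP"))
```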
+@register_model_patch("dia")
+def patch_dia_runtime(model, max_layers: int | None = None) -> list[str]:
+    return _patch_split_gate_up_mlps(
+        model,
+        {"transformers.models.dia.modeling_dia.DiaMLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.dia.modeling_dia.DiaMLP"]},
+        max_layers=max_layers,
+    )
+
+
+@register_model_patch("glm")
+def patch_glm_runtime(model, max_layers: int | None = None) -> list[str]:
+    return _patch_split_gate_up_mlps(
+        model,
+        {"transformers.models.glm.modeling_glm.GlmMLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.glm.modeling_glm.GlmMLP"]},
+        max_layers=max_layers,
+    )
+
+
+@register_model_patch("glm4")
+def patch_glm4_runtime(model, max_layers: int | None = None) -> list[str]:
+    return _patch_split_gate_up_mlps(
+        model,
+        {"transformers.models.glm4.modeling_glm4.Glm4MLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.glm4.modeling_glm4.Glm4MLP"]},
+        max_layers=max_layers,
+    )
+
+
+@register_model_patch("glm_image")
+def patch_glm_image_runtime(model, max_layers: int | None = None) -> list[str]:
+    return _patch_split_gate_up_mlps(
+        model,
+        {"transformers.models.glm_image.modeling_glm_image.GlmImageTextMLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.glm_image.modeling_glm_image.GlmImageTextMLP"]},
+        max_layers=max_layers,
+    )
+
+
+@register_model_patch("glm_ocr")
+def patch_glm_ocr_runtime(model, max_layers: int | None = None) -> list[str]:
+    return _patch_split_gate_up_mlps(
+        model,
+        {"transformers.models.glm_ocr.modeling_glm_ocr.GlmOcrTextMLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.glm_ocr.modeling_glm_ocr.GlmOcrTextMLP"]},
+        max_layers=max_layers,
+    )
+
+
+@register_model_patch("phi3")
+def patch_phi3_runtime(model, max_layers: int | None = None) -> list[str]:
+    return _patch_split_gate_up_mlps(
+        model,
+        {"transformers.models.phi3.modeling_phi3.Phi3MLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.phi3.modeling_phi3.Phi3MLP"]},
+        max_layers=max_layers,
+    )
+
+
+@register_model_patch("phi4_multimodal")
+def patch_phi4_multimodal_runtime(model, max_layers: int | None = None) -> list[str]:
+    return _patch_split_gate_up_mlps(
+        model,
+        {
+            "transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalMLP": _STANDARD_SPLIT_GATE_UP_CLASSES[
+                "transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalMLP"
+            ],
+            "transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalAudioMLP": _STANDARD_SPLIT_GATE_UP_CLASSES[
+                "transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalAudioMLP"
+            ],
+        },
+        max_layers=max_layers,
+    )
+
+
+@register_model_patch("zamba2")
+def patch_zamba2_runtime(model, max_layers: int | None = None) -> list[str]:
+    return _patch_split_gate_up_mlps(
+        model,
+        {"transformers.models.zamba2.modeling_zamba2.Zamba2MLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.zamba2.modeling_zamba2.Zamba2MLP"]},
+        max_layers=max_layers,
+    )
+
+
+@register_model_patch("dbrx")
+def patch_dbrx_runtime(model, max_layers: int | None = None) -> list[str]:
+    return _patch_modules_by_class(
+        model,
+        {"transformers.models.dbrx.modeling_dbrx.DbrxExperts": patch_dbrx_experts},
+        max_layers=max_layers,
+    )
+
+
+def _patch_parallel_runtime(model, class_path: str, *, max_layers: int | None = None) -> list[str]:
+    return _patch_modules_by_class(model, {class_path: patch_parallel_experts}, max_layers=max_layers)
+
+
+@register_model_patch("granitemoe")
+def patch_granitemoe_runtime(model, max_layers: int | None = None) -> list[str]:
+    return _patch_parallel_runtime(
+        model,
+        "transformers.models.granitemoe.modeling_granitemoe.GraniteMoeParallelExperts",
+        max_layers=max_layers,
+    )
+
+
+@register_model_patch("granitemoehybrid")
+def patch_granitemoehybrid_runtime(model, max_layers: int | None = None) -> list[str]:
+    return _patch_parallel_runtime(
+        model,
+        "transformers.models.granitemoehybrid.modeling_granitemoehybrid.GraniteMoeHybridParallelExperts",
+        max_layers=max_layers,
+    )
+
+
+@register_model_patch("granitemoeshared")
+def patch_granitemoeshared_runtime(model, max_layers: int | None = None) -> list[str]:
+    return _patch_parallel_runtime(
+        model,
+        "transformers.models.granitemoeshared.modeling_granitemoeshared.GraniteMoeSharedParallelExperts",
+        max_layers=max_layers,
+    )
+
+
+@register_model_patch("jetmoe")
+def patch_jetmoe_runtime(model, max_layers: int | None = None) -> list[str]:
+    return _patch_parallel_runtime(
+        model,
+        "transformers.models.jetmoe.modeling_jetmoe.JetMoeParallelExperts",
+        max_layers=max_layers,
+    )
+
+
+@register_model_patch("longcat_flash")
+def patch_longcat_flash_runtime(model, max_layers: int | None = None) -> list[str]:
+    return _patch_modules_by_class(
+        model,
+        {"transformers.models.longcat_flash.modeling_longcat_flash.LongcatFlashExperts": patch_longcat_flash_experts},
+        max_layers=max_layers,
+    )
+
+
 def apply_model_class_patches(model_type) -> list[str]:
     """Run any registered pre-construction patch for ``model_type``."""
     patch_model_class = _MODEL_CLASS_PATCH_REGISTRY.get(model_type)
@@ -110,15 +304,15 @@ def apply_model_class_patches(model_type) -> list[str]:
     return applied
 
 
-def apply_model_patches(model) -> list[str]:
+def apply_model_patches(model, max_layers: int | None = None) -> list[str]:
     """Run any registered runtime patch for the instantiated ``model``."""
     config = getattr(model, "config", None)
     model_type = getattr(config, "model_type", None)
     patch = _MODEL_PATCH_REGISTRY.get(model_type)
     if patch is None:
         return []
 
-    applied = patch(model)
+    applied = patch(model, max_layers=max_layers)
     if applied and DEBUG_ON:
         logger.debug(f"Applied model patches for model_type={model_type}: {', '.join(applied)}")
     return applied
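End to end, applying the registered runtime patch to a freshly loaded model might look like the sketch below. The checkpoint name is illustrative, and `max_layers` is assumed (via `is_within_max_layers`) to restrict patching to modules in the first N layers:

```python
from transformers import AutoModelForCausalLM

# Illustrative checkpoint; any model_type with a registered patch works
# (phi3 is registered above).
model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")

# Returns the names of the modules that were actually patched.
patched = apply_model_patches(model, max_layers=8)
print(f"patched {len(patched)} modules: {patched}")
```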