     patch_parallel_experts,
     patch_split_gate_up_mlp,
 )
-from defuser.utils.common import is_within_max_layers
+from defuser.utils.common import compile_module_name_filter, is_within_max_layers, matches_module_name_filter
 import torch

 logger = LogBar(__name__)
@@ -87,7 +87,7 @@ def patched_init_weights(self, module):


 @register_model_patch("qwen3_omni_moe")
-def patch_qwen3_omni_text_runtime(model, max_layers: int | None = None) -> list[str]:
+def patch_qwen3_omni_text_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]:
     """Restore text-only ``forward`` and ``generate`` behavior after class swapping."""
     model_cls = type(model)
     if not getattr(model_cls, "__module__", "").startswith("transformers.models.qwen3_omni_moe."):
@@ -122,11 +122,15 @@ def _patch_modules_by_class(
     patchers: dict[str, Callable],
     *,
     max_layers: int | None = None,
+    filter_rules=None,
 ) -> list[str]:
+    module_name_filter = compile_module_name_filter(filter_rules)
     applied = []
     for name, module in list(model.named_modules()):
         if not is_within_max_layers(name, max_layers):
             continue
+        if not matches_module_name_filter(name, module_name_filter):
+            continue
         class_path = f"{module.__class__.__module__}.{module.__class__.__name__}"
         patcher = patchers.get(class_path)
         if patcher is None:
@@ -141,6 +145,7 @@ def _patch_split_gate_up_mlps(
     patchers: dict[str, str],
     *,
     max_layers: int | None = None,
+    filter_rules=None,
 ) -> list[str]:
     return _patch_modules_by_class(
         model,
@@ -149,6 +154,7 @@ def _patch_split_gate_up_mlps(
             for class_path, variant in patchers.items()
         },
         max_layers=max_layers,
+        filter_rules=filter_rules,
     )


@@ -166,61 +172,67 @@ def _patch_split_gate_up_mlps(


 @register_model_patch("dia")
-def patch_dia_runtime(model, max_layers: int | None = None) -> list[str]:
+def patch_dia_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]:
     return _patch_split_gate_up_mlps(
         model,
         {"transformers.models.dia.modeling_dia.DiaMLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.dia.modeling_dia.DiaMLP"]},
         max_layers=max_layers,
+        filter_rules=filter_rules,
     )


 @register_model_patch("glm")
-def patch_glm_runtime(model, max_layers: int | None = None) -> list[str]:
+def patch_glm_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]:
     return _patch_split_gate_up_mlps(
         model,
         {"transformers.models.glm.modeling_glm.GlmMLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.glm.modeling_glm.GlmMLP"]},
         max_layers=max_layers,
+        filter_rules=filter_rules,
     )


 @register_model_patch("glm4")
-def patch_glm4_runtime(model, max_layers: int | None = None) -> list[str]:
+def patch_glm4_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]:
     return _patch_split_gate_up_mlps(
         model,
         {"transformers.models.glm4.modeling_glm4.Glm4MLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.glm4.modeling_glm4.Glm4MLP"]},
         max_layers=max_layers,
+        filter_rules=filter_rules,
     )


 @register_model_patch("glm_image")
-def patch_glm_image_runtime(model, max_layers: int | None = None) -> list[str]:
+def patch_glm_image_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]:
     return _patch_split_gate_up_mlps(
         model,
         {"transformers.models.glm_image.modeling_glm_image.GlmImageTextMLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.glm_image.modeling_glm_image.GlmImageTextMLP"]},
         max_layers=max_layers,
+        filter_rules=filter_rules,
     )


 @register_model_patch("glm_ocr")
-def patch_glm_ocr_runtime(model, max_layers: int | None = None) -> list[str]:
+def patch_glm_ocr_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]:
     return _patch_split_gate_up_mlps(
         model,
         {"transformers.models.glm_ocr.modeling_glm_ocr.GlmOcrTextMLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.glm_ocr.modeling_glm_ocr.GlmOcrTextMLP"]},
         max_layers=max_layers,
+        filter_rules=filter_rules,
     )


 @register_model_patch("phi3")
-def patch_phi3_runtime(model, max_layers: int | None = None) -> list[str]:
+def patch_phi3_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]:
     return _patch_split_gate_up_mlps(
         model,
         {"transformers.models.phi3.modeling_phi3.Phi3MLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.phi3.modeling_phi3.Phi3MLP"]},
         max_layers=max_layers,
+        filter_rules=filter_rules,
     )


 @register_model_patch("phi4_multimodal")
-def patch_phi4_multimodal_runtime(model, max_layers: int | None = None) -> list[str]:
+def patch_phi4_multimodal_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]:
     return _patch_split_gate_up_mlps(
         model,
         {
@@ -234,73 +246,86 @@ def patch_phi4_multimodal_runtime(model, max_layers: int | None = None) -> list[str]:
             ],
         },
         max_layers=max_layers,
+        filter_rules=filter_rules,
     )


 @register_model_patch("zamba2")
-def patch_zamba2_runtime(model, max_layers: int | None = None) -> list[str]:
+def patch_zamba2_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]:
     return _patch_split_gate_up_mlps(
         model,
         {"transformers.models.zamba2.modeling_zamba2.Zamba2MLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.zamba2.modeling_zamba2.Zamba2MLP"]},
         max_layers=max_layers,
+        filter_rules=filter_rules,
     )


 @register_model_patch("dbrx")
-def patch_dbrx_runtime(model, max_layers: int | None = None) -> list[str]:
+def patch_dbrx_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]:
     return _patch_modules_by_class(
         model,
         {"transformers.models.dbrx.modeling_dbrx.DbrxExperts": patch_dbrx_experts},
         max_layers=max_layers,
+        filter_rules=filter_rules,
     )


-def _patch_parallel_runtime(model, class_path: str, *, max_layers: int | None = None) -> list[str]:
-    return _patch_modules_by_class(model, {class_path: patch_parallel_experts}, max_layers=max_layers)
+def _patch_parallel_runtime(model, class_path: str, *, max_layers: int | None = None, filter_rules=None) -> list[str]:
+    return _patch_modules_by_class(
+        model,
+        {class_path: patch_parallel_experts},
+        max_layers=max_layers,
+        filter_rules=filter_rules,
+    )


 @register_model_patch("granitemoe")
-def patch_granitemoe_runtime(model, max_layers: int | None = None) -> list[str]:
+def patch_granitemoe_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]:
     return _patch_parallel_runtime(
         model,
         "transformers.models.granitemoe.modeling_granitemoe.GraniteMoeParallelExperts",
         max_layers=max_layers,
+        filter_rules=filter_rules,
     )


 @register_model_patch("granitemoehybrid")
-def patch_granitemoehybrid_runtime(model, max_layers: int | None = None) -> list[str]:
+def patch_granitemoehybrid_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]:
     return _patch_parallel_runtime(
         model,
         "transformers.models.granitemoehybrid.modeling_granitemoehybrid.GraniteMoeHybridParallelExperts",
         max_layers=max_layers,
+        filter_rules=filter_rules,
     )


 @register_model_patch("granitemoeshared")
-def patch_granitemoeshared_runtime(model, max_layers: int | None = None) -> list[str]:
+def patch_granitemoeshared_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]:
     return _patch_parallel_runtime(
         model,
         "transformers.models.granitemoeshared.modeling_granitemoeshared.GraniteMoeSharedParallelExperts",
         max_layers=max_layers,
+        filter_rules=filter_rules,
     )


 @register_model_patch("jetmoe")
-def patch_jetmoe_runtime(model, max_layers: int | None = None) -> list[str]:
+def patch_jetmoe_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]:
     return _patch_parallel_runtime(
         model,
         "transformers.models.jetmoe.modeling_jetmoe.JetMoeParallelExperts",
         max_layers=max_layers,
+        filter_rules=filter_rules,
     )


 @register_model_patch("longcat_flash")
-def patch_longcat_flash_runtime(model, max_layers: int | None = None) -> list[str]:
+def patch_longcat_flash_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]:
     return _patch_modules_by_class(
         model,
         {"transformers.models.longcat_flash.modeling_longcat_flash.LongcatFlashExperts": patch_longcat_flash_experts},
         max_layers=max_layers,
+        filter_rules=filter_rules,
     )


@@ -316,15 +341,15 @@ def apply_model_class_patches(model_type) -> list[str]:
     return applied


-def apply_model_patches(model, max_layers: int | None = None) -> list[str]:
+def apply_model_patches(model, max_layers: int | None = None, filter_rules=None) -> list[str]:
     """Run any registered runtime patch for the instantiated ``model``."""
     config = getattr(model, "config", None)
     model_type = getattr(config, "model_type", None)
     patch = _MODEL_PATCH_REGISTRY.get(model_type)
     if patch is None:
         return []

-    applied = patch(model, max_layers=max_layers)
+    applied = patch(model, max_layers=max_layers, filter_rules=filter_rules)
     if applied and DEBUG_ON:
         logger.debug(f"Applied model patches for model_type={model_type}: {', '.join(applied)}")
     return applied
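
The diff only imports compile_module_name_filter and matches_module_name_filter from defuser.utils.common; their implementations are not shown. A minimal sketch of how such helpers could behave, assuming filter_rules is an optional iterable of regular-expression strings and None means "patch everything" (these names exist in the import, but the semantics below are an assumption, not the project's actual code):

import re


def compile_module_name_filter(filter_rules=None):
    """Compile rule strings once so the per-module check stays cheap."""
    if not filter_rules:
        return None
    return [re.compile(rule) for rule in filter_rules]


def matches_module_name_filter(name: str, module_name_filter=None) -> bool:
    """A module passes when no filter is set or any compiled rule matches its name."""
    if module_name_filter is None:
        return True
    return any(pattern.search(name) for pattern in module_name_filter)

Under this reading, hoisting compile_module_name_filter above the named_modules() loop (as the _patch_modules_by_class hunk does) keeps the per-module cost to a single any() over prebuilt patterns.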
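On the caller side, apply_model_patches now accepts the same keyword and forwards it to whichever registered patch matches config.model_type. A usage sketch with a placeholder checkpoint path and an assumed import path for the patching module; the rule string follows the regex assumption from the sketch above:

from transformers import AutoModelForCausalLM

# Import path is assumed for illustration; use the project's real module path.
from defuser.patches import apply_model_patches

# Placeholder checkpoint id for a MoE model whose model_type is registered (e.g. "granitemoe").
model = AutoModelForCausalLM.from_pretrained("path/to/granitemoe-checkpoint")

# Restrict patching to the first four layers and to modules whose names mention "experts".
applied = apply_model_patches(model, max_layers=4, filter_rules=["experts"])
print(applied)  # the returned list[str] reports which modules were actually patched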