
Commit 4f5c909

add filter control
1 parent 4dbd864 commit 4f5c909

10 files changed: 380 additions & 31 deletions


README.md

Lines changed: 23 additions & 1 deletion
@@ -33,9 +33,17 @@ from defuser import convert_model, replace_fused_blocks
 ```
 
 - `replace_fused_blocks(model_type)` patches supported HF model classes before `from_pretrained()` or direct model construction.
-- `convert_model(model, cleanup_original=True, max_layers=None)` converts an already loaded model in place. This is the runtime defusion path for supported post-load expert and MLP conversions, including `qwen3_5_moe` style checkpoints.
+- `convert_model(model, cleanup_original=True, max_layers=None, filter=None)` converts an already loaded model in place. This is the runtime defusion path for supported post-load expert and MLP conversions, including `qwen3_5_moe` style checkpoints.
 - Defuser is designed and CI-tested for `transformers>=5.3.0`, and support is only offered for that version range. Older versions log a warning on these public APIs and are skipped as unsupported.
 
+`filter` is an optional list of PCRE regex rules evaluated against full module paths such as `model.layers.0.mlp.experts`:
+
+- `+:regex` explicitly includes matching candidate module paths
+- `-:regex` explicitly excludes matching candidate module paths
+- `regex` is shorthand for `+:regex`
+- negative rules take priority over positive rules
+- when `filter` is provided, a candidate module is defused only if it matches at least one positive rule and no negative rules
+
 ## Supported Models
 
 Defuser currently supports the following `transformers==5.3.0` `model_type` values.
@@ -91,6 +99,20 @@ converted = convert_model(model)
 print(converted) # True when runtime defusion happened
 ```
 
+Use `filter` when only specific blocks should be defused:
+
+```python
+from defuser import convert_model
+
+convert_model(
+    model,
+    filter=[
+        r"+:^model\.layers\.0\.mlp\.experts$",
+        r"-:^model\.layers\.0\.mlp\.experts\.shared_",
+    ],
+)
+```
+
 ## Real Qwen3.5 MoE Example
 
 The example below is written for the `transformers==5.3.0` public API surface and uses the real Hugging Face model `Qwen/Qwen3.5-35B-A3B-Instruct`. Defuser supports `transformers>=5.3.0`.
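For orientation, the sketch below shows one way the `+:regex` / `-:regex` rule semantics documented above could be compiled and matched. The shipped helpers are `compile_module_name_filter` and `matches_module_name_filter` in `defuser.utils.common`, whose bodies are not part of this commit; the function names used here, the use of Python's `re` module in place of a full PCRE engine, and the `re.search` matching semantics are assumptions, not the actual implementation.

```python
import re

# Hypothetical stand-ins for defuser.utils.common.compile_module_name_filter /
# matches_module_name_filter. Illustrative only.


def compile_rules(rules):
    """Split "+:regex" / "-:regex" / "regex" rules into (positive, negative) pattern lists."""
    if not rules:
        return None  # no filter given: every candidate module stays eligible
    positive, negative = [], []
    for rule in rules:
        if rule.startswith("-:"):
            negative.append(re.compile(rule[2:]))
        elif rule.startswith("+:"):
            positive.append(re.compile(rule[2:]))
        else:
            positive.append(re.compile(rule))  # bare regex is shorthand for "+:regex"
    return positive, negative


def rule_match(module_path, compiled):
    """Return True when the module path passes the filter; negative rules always win."""
    if compiled is None:
        return True
    positive, negative = compiled
    if any(p.search(module_path) for p in negative):
        return False
    return any(p.search(module_path) for p in positive)
```

Under the README example above, `model.layers.0.mlp.experts` is defused (it matches the positive rule and no negative rule), while any path under `model.layers.0.mlp.experts` that continues with `shared_` is rejected by the negative rule.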

defuser/defuser.py

Lines changed: 3 additions & 1 deletion
@@ -117,6 +117,7 @@ def convert_model(
     model: nn.Module,
     cleanup_original: bool = False,
     max_layers: int | None = None,
+    filter: list[str] | None = None,
 ) -> bool:
     """Convert one loaded model in place from fused experts to defused modules."""
     if warn_if_public_api_transformers_unsupported("convert_model()", logger):
@@ -200,7 +201,7 @@ def convert_model(
     if not check_model_compatibility(model):
         return False
 
-    apply_model_patches(model, max_layers=max_layers)
+    apply_model_patches(model, max_layers=max_layers, filter_rules=filter)
 
     # If fused blocks have already been structurally replaced at load model before,
     # there is no need to perform runtime defusing again
@@ -214,6 +215,7 @@ def convert_model(
         model,
         cleanup_original=cleanup_original,
        max_layers=max_layers,
+        filter_rules=filter,
     )
 
     return True
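As the `_patch_modules_by_class` change below shows, the filter is applied on top of `max_layers`: a module is only patched if it first passes `is_within_max_layers` and then `matches_module_name_filter`. A minimal sketch combining the two knobs, assuming `model` is an already loaded supported model:

```python
from defuser import convert_model

# A module is defused only if it survives the max_layers cutoff AND
# matches at least one positive filter rule (and no negative rule).
convert_model(
    model,
    max_layers=2,
    filter=[r"\.mlp\.experts$"],  # bare regex acts as a "+:" include rule
)
```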

defuser/modeling/model_patches.py

Lines changed: 45 additions & 20 deletions
@@ -14,7 +14,7 @@
     patch_parallel_experts,
     patch_split_gate_up_mlp,
 )
-from defuser.utils.common import is_within_max_layers
+from defuser.utils.common import compile_module_name_filter, is_within_max_layers, matches_module_name_filter
 import torch
 
 logger = LogBar(__name__)
@@ -87,7 +87,7 @@ def patched_init_weights(self, module):
 
 
 @register_model_patch("qwen3_omni_moe")
-def patch_qwen3_omni_text_runtime(model, max_layers: int | None = None) -> list[str]:
+def patch_qwen3_omni_text_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]:
     """Restore text-only ``forward`` and ``generate`` behavior after class swapping."""
     model_cls = type(model)
     if not getattr(model_cls, "__module__", "").startswith("transformers.models.qwen3_omni_moe."):
@@ -122,11 +122,15 @@ def _patch_modules_by_class(
     patchers: dict[str, Callable],
     *,
     max_layers: int | None = None,
+    filter_rules=None,
 ) -> list[str]:
+    module_name_filter = compile_module_name_filter(filter_rules)
     applied = []
     for name, module in list(model.named_modules()):
         if not is_within_max_layers(name, max_layers):
             continue
+        if not matches_module_name_filter(name, module_name_filter):
+            continue
         class_path = f"{module.__class__.__module__}.{module.__class__.__name__}"
         patcher = patchers.get(class_path)
         if patcher is None:
@@ -141,6 +145,7 @@ def _patch_split_gate_up_mlps(
     patchers: dict[str, str],
     *,
     max_layers: int | None = None,
+    filter_rules=None,
 ) -> list[str]:
     return _patch_modules_by_class(
         model,
@@ -149,6 +154,7 @@ def _patch_split_gate_up_mlps(
             for class_path, variant in patchers.items()
         },
         max_layers=max_layers,
+        filter_rules=filter_rules,
     )
 
 
@@ -166,61 +172,67 @@ def _patch_split_gate_up_mlps(
 
 
 @register_model_patch("dia")
-def patch_dia_runtime(model, max_layers: int | None = None) -> list[str]:
+def patch_dia_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]:
     return _patch_split_gate_up_mlps(
         model,
         {"transformers.models.dia.modeling_dia.DiaMLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.dia.modeling_dia.DiaMLP"]},
         max_layers=max_layers,
+        filter_rules=filter_rules,
     )
 
 
 @register_model_patch("glm")
-def patch_glm_runtime(model, max_layers: int | None = None) -> list[str]:
+def patch_glm_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]:
     return _patch_split_gate_up_mlps(
         model,
         {"transformers.models.glm.modeling_glm.GlmMLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.glm.modeling_glm.GlmMLP"]},
         max_layers=max_layers,
+        filter_rules=filter_rules,
     )
 
 
 @register_model_patch("glm4")
-def patch_glm4_runtime(model, max_layers: int | None = None) -> list[str]:
+def patch_glm4_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]:
     return _patch_split_gate_up_mlps(
         model,
         {"transformers.models.glm4.modeling_glm4.Glm4MLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.glm4.modeling_glm4.Glm4MLP"]},
         max_layers=max_layers,
+        filter_rules=filter_rules,
     )
 
 
 @register_model_patch("glm_image")
-def patch_glm_image_runtime(model, max_layers: int | None = None) -> list[str]:
+def patch_glm_image_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]:
     return _patch_split_gate_up_mlps(
         model,
         {"transformers.models.glm_image.modeling_glm_image.GlmImageTextMLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.glm_image.modeling_glm_image.GlmImageTextMLP"]},
         max_layers=max_layers,
+        filter_rules=filter_rules,
     )
 
 
 @register_model_patch("glm_ocr")
-def patch_glm_ocr_runtime(model, max_layers: int | None = None) -> list[str]:
+def patch_glm_ocr_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]:
     return _patch_split_gate_up_mlps(
         model,
         {"transformers.models.glm_ocr.modeling_glm_ocr.GlmOcrTextMLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.glm_ocr.modeling_glm_ocr.GlmOcrTextMLP"]},
         max_layers=max_layers,
+        filter_rules=filter_rules,
     )
 
 
 @register_model_patch("phi3")
-def patch_phi3_runtime(model, max_layers: int | None = None) -> list[str]:
+def patch_phi3_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]:
     return _patch_split_gate_up_mlps(
         model,
         {"transformers.models.phi3.modeling_phi3.Phi3MLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.phi3.modeling_phi3.Phi3MLP"]},
         max_layers=max_layers,
+        filter_rules=filter_rules,
     )
 
 
 @register_model_patch("phi4_multimodal")
-def patch_phi4_multimodal_runtime(model, max_layers: int | None = None) -> list[str]:
+def patch_phi4_multimodal_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]:
     return _patch_split_gate_up_mlps(
         model,
         {
@@ -234,73 +246,86 @@ def patch_phi4_multimodal_runtime(model, max_layers: int | None = None) -> list[str]:
             ],
         },
         max_layers=max_layers,
+        filter_rules=filter_rules,
     )
 
 
 @register_model_patch("zamba2")
-def patch_zamba2_runtime(model, max_layers: int | None = None) -> list[str]:
+def patch_zamba2_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]:
     return _patch_split_gate_up_mlps(
         model,
         {"transformers.models.zamba2.modeling_zamba2.Zamba2MLP": _STANDARD_SPLIT_GATE_UP_CLASSES["transformers.models.zamba2.modeling_zamba2.Zamba2MLP"]},
         max_layers=max_layers,
+        filter_rules=filter_rules,
     )
 
 
 @register_model_patch("dbrx")
-def patch_dbrx_runtime(model, max_layers: int | None = None) -> list[str]:
+def patch_dbrx_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]:
     return _patch_modules_by_class(
         model,
         {"transformers.models.dbrx.modeling_dbrx.DbrxExperts": patch_dbrx_experts},
         max_layers=max_layers,
+        filter_rules=filter_rules,
     )
 
 
-def _patch_parallel_runtime(model, class_path: str, *, max_layers: int | None = None) -> list[str]:
-    return _patch_modules_by_class(model, {class_path: patch_parallel_experts}, max_layers=max_layers)
+def _patch_parallel_runtime(model, class_path: str, *, max_layers: int | None = None, filter_rules=None) -> list[str]:
+    return _patch_modules_by_class(
+        model,
+        {class_path: patch_parallel_experts},
+        max_layers=max_layers,
+        filter_rules=filter_rules,
+    )
 
 
 @register_model_patch("granitemoe")
-def patch_granitemoe_runtime(model, max_layers: int | None = None) -> list[str]:
+def patch_granitemoe_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]:
     return _patch_parallel_runtime(
         model,
         "transformers.models.granitemoe.modeling_granitemoe.GraniteMoeParallelExperts",
         max_layers=max_layers,
+        filter_rules=filter_rules,
     )
 
 
 @register_model_patch("granitemoehybrid")
-def patch_granitemoehybrid_runtime(model, max_layers: int | None = None) -> list[str]:
+def patch_granitemoehybrid_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]:
     return _patch_parallel_runtime(
         model,
         "transformers.models.granitemoehybrid.modeling_granitemoehybrid.GraniteMoeHybridParallelExperts",
         max_layers=max_layers,
+        filter_rules=filter_rules,
     )
 
 
 @register_model_patch("granitemoeshared")
-def patch_granitemoeshared_runtime(model, max_layers: int | None = None) -> list[str]:
+def patch_granitemoeshared_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]:
     return _patch_parallel_runtime(
         model,
         "transformers.models.granitemoeshared.modeling_granitemoeshared.GraniteMoeSharedParallelExperts",
         max_layers=max_layers,
+        filter_rules=filter_rules,
     )
 
 
 @register_model_patch("jetmoe")
-def patch_jetmoe_runtime(model, max_layers: int | None = None) -> list[str]:
+def patch_jetmoe_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]:
     return _patch_parallel_runtime(
         model,
         "transformers.models.jetmoe.modeling_jetmoe.JetMoeParallelExperts",
         max_layers=max_layers,
+        filter_rules=filter_rules,
     )
 
 
 @register_model_patch("longcat_flash")
-def patch_longcat_flash_runtime(model, max_layers: int | None = None) -> list[str]:
+def patch_longcat_flash_runtime(model, max_layers: int | None = None, filter_rules=None) -> list[str]:
     return _patch_modules_by_class(
         model,
         {"transformers.models.longcat_flash.modeling_longcat_flash.LongcatFlashExperts": patch_longcat_flash_experts},
         max_layers=max_layers,
+        filter_rules=filter_rules,
     )
 
 
@@ -316,15 +341,15 @@ def apply_model_class_patches(model_type) -> list[str]:
     return applied
 
 
-def apply_model_patches(model, max_layers: int | None = None) -> list[str]:
+def apply_model_patches(model, max_layers: int | None = None, filter_rules=None) -> list[str]:
     """Run any registered runtime patch for the instantiated ``model``."""
     config = getattr(model, "config", None)
     model_type = getattr(config, "model_type", None)
     patch = _MODEL_PATCH_REGISTRY.get(model_type)
     if patch is None:
         return []
 
-    applied = patch(model, max_layers=max_layers)
+    applied = patch(model, max_layers=max_layers, filter_rules=filter_rules)
     if applied and DEBUG_ON:
         logger.debug(f"Applied model patches for model_type={model_type}: {', '.join(applied)}")
     return applied
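Because the rules are evaluated against the names yielded by `model.named_modules()` (the same loop patched in `_patch_modules_by_class` above), a practical way to draft filter regexes is to print the candidate paths first. An illustrative snippet, assuming `model` is already loaded; the `.experts` / `.mlp` suffix check is only a convenience for narrowing the listing:

```python
# Print the module paths that filter rules would be matched against.
for name, module in model.named_modules():
    if name.endswith(".experts") or name.endswith(".mlp"):
        print(name, module.__class__.__name__)
```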

defuser/modeling/moe_experts_interface.py

Lines changed: 9 additions & 1 deletion
@@ -32,6 +32,7 @@
 from torch import nn
 
 from defuser.model_registry import MODEL_CONFIG, PATCH
+from defuser.utils.common import compile_module_name_filter, matches_module_name_filter
 from defuser.utils.device import clear_memory, to_meta
 
 from defuser import DEBUG_ON
@@ -693,7 +694,11 @@ def _unfuse_experts_weights_inplace(
     return True
 
 
-def prepare_model_for_moe_quantization(model: nn.Module, implementation: str = LINEAR_LOOP_IMPL) -> list[str]:
+def prepare_model_for_moe_quantization(
+    model: nn.Module,
+    implementation: str = LINEAR_LOOP_IMPL,
+    filter_rules=None,
+) -> list[str]:
     """Prepare a model for MOE quantization using transformers' experts interface.
 
     This function:
@@ -722,7 +727,10 @@ def prepare_model_for_moe_quantization(model: nn.Module, implementation: str = LINEAR_LOOP_IMPL) -> list[str]:
     unfused_modules = []
     decorated_unfused_modules = []
     experts_defuse_specs = _model_experts_defuse_specs(model)
+    module_name_filter = compile_module_name_filter(filter_rules)
     for name, module in model.named_modules():
+        if not matches_module_name_filter(name, module_name_filter):
+            continue
         spec = _matching_experts_defuse_spec(module, experts_defuse_specs)
         if spec is not None and _unfuse_experts_weights_inplace(
             module,
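`prepare_model_for_moe_quantization` accepts the same rule list directly via `filter_rules`, skipping any module whose path fails the filter before its expert weights are unfused. A hedged sketch of a direct call; the import path mirrors the file location above, and whether direct use is intended (rather than going through `convert_model(filter=...)`) is an assumption:

```python
from defuser.modeling.moe_experts_interface import prepare_model_for_moe_quantization

# Unfuse only expert blocks, excluding shared-expert submodules,
# mirroring the rule style from the README example.
prepare_model_for_moe_quantization(
    model,
    filter_rules=[
        r"+:\.mlp\.experts$",
        r"-:\.mlp\.experts\.shared_",
    ],
)
```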
