Add granitemode support#2124
Open
amdrajeevp1 wants to merge 3 commits into
Open
Conversation
ibm-granite/granite-4.0-1b uses GraniteMoeHybridForCausalLM which was previously unsupported, causing NotImplementedError or incorrect exports with 20-40% accuracy loss when users forced it through GraniteForCausalLM. For granite-4.0-1b all layers are attention-type and MoE routing is disabled (num_local_experts=0), so computation is equivalent to GraniteForCausalLM. The only structural difference is the MLP layout: shared_mlp.input_linear (fused gate+up) + shared_mlp.output_linear (down) instead of mlp.gate_proj/up_proj/down_proj, and the MLP width is read from shared_intermediate_size rather than intermediate_size. GraniteMoeHybridModel subclasses GraniteModel and adapts the module layout in make_layer so the existing make_mlp_unpacked_regular and make_mlp_proj infrastructure can be reused without modification. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Add _mlp_intermediate_size attribute set before super().__init__ so self.intermediate_size gets the correct shared_intermediate_size value - Override make_mlp_unpacked_regular to split input_linear at _mlp_intermediate_size explicitly, independent of self.intermediate_size - Add rope_parameters dict fallback in base.py rope_theta resolution to support models that store rope config as a dict attribute Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Contributor
There was a problem hiding this comment.
Pull request overview
Adds ONNX Runtime GenAI model-builder support for Hugging Face Granite 4.0 “MoE-hybrid” configs (e.g., ibm-granite/granite-4.0-1b) by introducing a dedicated builder that correctly maps the fused MLP weight layout and config differences, and registers the resulting model type for runtime classification.
Changes:
- Add
GraniteMoeHybridModelbuilder to adaptshared_mlp/ fusedinput_linear(gate+up) andshared_intermediate_sizehandling. - Extend RoPE theta resolution to support configs that store
rope_thetaunderrope_parameters. - Wire the new architecture through the Python builder/export surface and register
granitemoehybridas an LLM model type in C++.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 1 comment.
Show a summary per file
| File | Description |
|---|---|
| src/python/py/models/builders/granite.py | Introduces GraniteMoeHybridModel to remap Granite 4.0 hybrid MLP structure for export. |
| src/python/py/models/builders/base.py | Adds rope_parameters fallback when resolving rope_theta. |
| src/python/py/models/builders/init.py | Exports GraniteMoeHybridModel from the builders package. |
| src/python/py/models/builder.py | Routes GraniteMoeHybridForCausalLM configs to the new builder class. |
| src/models/model_type.h | Registers granitemoehybrid as an LLM type for runtime classification. |
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
| def make_mlp_unpacked_regular(self, layer_id, mlp, gate_up_linear, root_input): | ||
| # Override to split input_linear at _mlp_intermediate_size explicitly, | ||
| # making the slicing independent of self.intermediate_size. | ||
| import torch |
Contributor
There was a problem hiding this comment.
Can we move all imports to be global?
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds support for exporting
ibm-granite/granite-4.0-1b(and compatible granite-4.0 variants) via the OGA modelbuilder.
granite-4.0-1buses theGraniteMoeHybridForCausalLMarchitecture, which was previously unsupported — falling backto
GraniteForCausalLMcaused incorrect weight mapping and significant accuracy loss. This PR adds a dedicatedGraniteMoeHybridModelbuilder that handles the structural differences while reusing existing granite infrastructure.Key differences from GraniteForCausalLM:
shared_mlp.input_linear(fused gate+up) andshared_mlp.output_linear(down) instead ofmlp.gate_proj/up_proj/down_projshared_intermediate_sizerather thanintermediate_sizerope_thetais stored under arope_parametersdict in the configFor the 1B variant, all layers are attention-type and MoE routing is inactive (
num_local_experts=0), so it iscomputationally equivalent to GraniteForCausalLM with only the weight layout differing.
Changes
src/python/py/models/builders/granite.py— addsGraniteMoeHybridModelsubclassingGraniteModel; overridesmake_layerandmake_mlp_unpacked_regularto handle the fused input_linear weight and correct intermediate sizesrc/python/py/models/builders/base.py— addsrope_parametersdict fallback in rope_theta resolutionsrc/python/py/models/builder.py— mapsGraniteMoeHybridForCausalLM→GraniteMoeHybridModelsrc/python/py/models/builders/__init__.py— exportsGraniteMoeHybridModelsrc/models/model_type.h— registersgranitemoehybridinIsLLM()Test plan
ibm-granite/granite-4.0-1b-basewithpython builder.py -m ibm-granite/granite-4.0-1b-base -o out/ -p fp32 -e cpuGraniteForCausalLMmodels (granite-3.x) are unaffected