Add granitemode support #2124

Open
amdrajeevp1 wants to merge 3 commits into microsoft:main from amdrajeevp1:add-granitemode-support

@amdrajeevp1
Contributor

Summary

Adds support for exporting ibm-granite/granite-4.0-1b (and compatible granite-4.0 variants) via the OGA model
builder.

granite-4.0-1b uses the GraniteMoeHybridForCausalLM architecture, which was previously unsupported. Forcing it
through GraniteForCausalLM caused incorrect weight mapping and a 20-40% accuracy loss. This PR adds a dedicated
GraniteMoeHybridModel builder that handles the structural differences while reusing existing granite infrastructure.

Key differences from GraniteForCausalLM:

  • MLP weights are in shared_mlp.input_linear (fused gate+up) and shared_mlp.output_linear (down) instead of
    mlp.gate_proj/up_proj/down_proj
  • MLP width comes from shared_intermediate_size rather than intermediate_size
  • rope_theta is stored under a rope_parameters dict in the config

For the 1B variant, all layers are attention-type and MoE routing is inactive (num_local_experts=0), so it is
computationally equivalent to GraniteForCausalLM with only the weight layout differing.
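
The fused layout described above can be illustrated with a minimal sketch. It assumes the fused input_linear weight has shape [2 * shared_intermediate_size, hidden_size] with the gate rows first and the up rows second; that ordering is an assumption and is not stated in this PR. The helper name split_fused_gate_up is hypothetical, and plain Python lists stand in for the tensors the real builder would handle.

```python
def split_fused_gate_up(fused_weight, intermediate_size):
    """Split a fused [2 * intermediate_size, hidden] weight into (gate, up).

    Hypothetical helper: assumes gate rows come first, then up rows.
    """
    rows = len(fused_weight)
    if rows != 2 * intermediate_size:
        raise ValueError(
            f"expected {2 * intermediate_size} rows, got {rows}"
        )
    gate = fused_weight[:intermediate_size]  # first half: gate projection
    up = fused_weight[intermediate_size:]    # second half: up projection
    return gate, up


# Toy example: hidden size 2, shared_intermediate_size 3 (illustrative values).
fused = [[float(r), float(r) + 0.5] for r in range(6)]
gate_w, up_w = split_fused_gate_up(fused, intermediate_size=3)
```

Splitting at an explicit size, rather than at half the row count, makes a mismatch between the config and the checkpoint fail loudly instead of producing silently wrong weights.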

Changes

  • src/python/py/models/builders/granite.py — adds GraniteMoeHybridModel subclassing GraniteModel; overrides
    make_layer and make_mlp_unpacked_regular to handle the fused input_linear weight and correct intermediate size
  • src/python/py/models/builders/base.py — adds rope_parameters dict fallback in rope_theta resolution
  • src/python/py/models/builder.py — maps GraniteMoeHybridForCausalLM to GraniteMoeHybridModel
  • src/python/py/models/builders/__init__.py — exports GraniteMoeHybridModel
  • src/models/model_type.h — registers granitemoehybrid in IsLLM()
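
The rope_parameters fallback listed for base.py could look roughly like the sketch below. This is not the code from the PR: resolve_rope_theta and DEFAULT_ROPE_THETA are hypothetical names, the default of 10000.0 is an assumption, and SimpleNamespace objects stand in for a Hugging Face config.

```python
from types import SimpleNamespace

DEFAULT_ROPE_THETA = 10000.0  # assumed default, not taken from the PR


def resolve_rope_theta(config):
    """Return rope_theta from a flat attribute or a rope_parameters dict."""
    # Prefer the flat attribute when the config defines it.
    theta = getattr(config, "rope_theta", None)
    if theta is not None:
        return float(theta)
    # Fallback: some configs nest the value inside a rope_parameters dict.
    params = getattr(config, "rope_parameters", None)
    if isinstance(params, dict) and "rope_theta" in params:
        return float(params["rope_theta"])
    return DEFAULT_ROPE_THETA


# Illustrative configs; the numeric values are made up.
flat_cfg = SimpleNamespace(rope_theta=500000.0)
nested_cfg = SimpleNamespace(rope_parameters={"rope_theta": 10000000})
empty_cfg = SimpleNamespace()
```

Checking the flat attribute first keeps behavior unchanged for configs that already expose rope_theta directly.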

Test plan

  • Export ibm-granite/granite-4.0-1b-base with python builder.py -m ibm-granite/granite-4.0-1b-base -o out/ -p fp32 -e cpu
  • Verify model loads and generates text without errors
  • Confirm output matches PyTorch reference
  • Confirm existing GraniteForCausalLM models (granite-3.x) are unaffected
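
One way to make the "matches PyTorch reference" check concrete is to compare logits for the same prompt from both runtimes. The sketch below is an assumption about how such a check might be written, not part of the PR: the function names, the 1e-3 tolerance, and the toy logit values are all illustrative.

```python
def max_abs_diff(ref, test):
    """Largest element-wise absolute difference between two logit vectors."""
    if len(ref) != len(test):
        raise ValueError("logit vectors must have the same length")
    return max(abs(r - t) for r, t in zip(ref, test))


def argmax(values):
    """Index of the largest value (first one wins on ties)."""
    return max(range(len(values)), key=values.__getitem__)


def outputs_match(ref_logits, test_logits, atol=1e-3):
    """True if logits agree within atol and pick the same top-1 token."""
    close = max_abs_diff(ref_logits, test_logits) <= atol
    same_token = argmax(ref_logits) == argmax(test_logits)
    return close and same_token


# Toy logits standing in for PyTorch (reference) and exported-model outputs.
reference = [0.1, 2.0, -1.0]
exported = [0.1004, 1.9996, -1.0002]
```

A reasonable tolerance depends on precision and execution provider; fp32 on CPU should normally allow a tighter bound than quantized exports.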

amdrajeevp1 and others added 2 commits May 5, 2026 16:54
ibm-granite/granite-4.0-1b uses GraniteMoeHybridForCausalLM which was
previously unsupported, causing NotImplementedError or incorrect exports
with 20-40% accuracy loss when users forced it through GraniteForCausalLM.

For granite-4.0-1b all layers are attention-type and MoE routing is
disabled (num_local_experts=0), so computation is equivalent to
GraniteForCausalLM. The only structural difference is the MLP layout:
shared_mlp.input_linear (fused gate+up) + shared_mlp.output_linear (down)
instead of mlp.gate_proj/up_proj/down_proj, and the MLP width is read
from shared_intermediate_size rather than intermediate_size.

GraniteMoeHybridModel subclasses GraniteModel and adapts the module
layout in make_layer so the existing make_mlp_unpacked_regular and
make_mlp_proj infrastructure can be reused without modification.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

- Add _mlp_intermediate_size attribute set before super().__init__ so
  self.intermediate_size gets the correct shared_intermediate_size value
- Override make_mlp_unpacked_regular to split input_linear at
  _mlp_intermediate_size explicitly, independent of self.intermediate_size
- Add rope_parameters dict fallback in base.py rope_theta resolution to
  support models that store rope config as a dict attribute

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
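
Why the attribute has to be set before super().__init__ can be shown with a small, hypothetical pattern. The classes below are stand-ins and do not mirror the real builder classes: the sketch assumes the base constructor reads the MLP width through an overridable hook, which is one possible arrangement and not necessarily the one used in this PR.

```python
class BaseBuilder:
    """Hypothetical stand-in for a base model builder."""

    def __init__(self, config):
        # The base constructor resolves the MLP width during initialization.
        self.intermediate_size = self._resolve_intermediate_size(config)

    def _resolve_intermediate_size(self, config):
        return config["intermediate_size"]


class HybridBuilder(BaseBuilder):
    """Hypothetical subclass that takes its width from shared_intermediate_size."""

    def __init__(self, config):
        # Must be set first: the base constructor calls the hook below.
        self._mlp_intermediate_size = config["shared_intermediate_size"]
        super().__init__(config)

    def _resolve_intermediate_size(self, config):
        return self._mlp_intermediate_size


# Illustrative config values only.
cfg = {"intermediate_size": 4096, "shared_intermediate_size": 2048}
builder = HybridBuilder(cfg)
```

If the assignment came after super().__init__, the hook would raise AttributeError because the attribute would not exist yet when the base constructor runs.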
@amdrajeevp1 amdrajeevp1 requested a review from a team as a code owner May 6, 2026 00:00
Copilot AI review requested due to automatic review settings May 6, 2026 00:00
Contributor

Copilot AI left a comment

Pull request overview

Adds ONNX Runtime GenAI model-builder support for Hugging Face Granite 4.0 “MoE-hybrid” configs (e.g., ibm-granite/granite-4.0-1b) by introducing a dedicated builder that correctly maps the fused MLP weight layout and config differences, and registers the resulting model type for runtime classification.

Changes:

  • Add GraniteMoeHybridModel builder to adapt shared_mlp / fused input_linear (gate+up) and shared_intermediate_size handling.
  • Extend RoPE theta resolution to support configs that store rope_theta under rope_parameters.
  • Wire the new architecture through the Python builder/export surface and register granitemoehybrid as an LLM model type in C++.
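
The routing step can be pictured as a lookup from the Hugging Face architecture string to a builder class. The sketch below uses a dictionary for brevity, which may differ from how builder.py actually dispatches; the placeholder classes, BUILDERS, and select_builder are hypothetical names.

```python
class GraniteModel:
    """Placeholder for the existing Granite builder."""


class GraniteMoeHybridModel(GraniteModel):
    """Placeholder for the new hybrid builder."""


# Architecture string (from the model config) mapped to a builder class.
BUILDERS = {
    "GraniteForCausalLM": GraniteModel,
    "GraniteMoeHybridForCausalLM": GraniteMoeHybridModel,
}


def select_builder(architecture):
    """Return the builder class for an architecture, or fail explicitly."""
    try:
        return BUILDERS[architecture]
    except KeyError:
        raise NotImplementedError(
            f"architecture {architecture!r} is not supported"
        ) from None


hybrid_cls = select_builder("GraniteMoeHybridForCausalLM")
```

Failing with NotImplementedError for unknown architectures avoids silently falling back to a builder with a different weight layout.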

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| src/python/py/models/builders/granite.py | Introduces GraniteMoeHybridModel to remap Granite 4.0 hybrid MLP structure for export. |
| src/python/py/models/builders/base.py | Adds rope_parameters fallback when resolving rope_theta. |
| src/python/py/models/builders/__init__.py | Exports GraniteMoeHybridModel from the builders package. |
| src/python/py/models/builder.py | Routes GraniteMoeHybridForCausalLM configs to the new builder class. |
| src/models/model_type.h | Registers granitemoehybrid as an LLM type for runtime classification. |

Comment thread src/python/py/models/builders/granite.py Outdated
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
    def make_mlp_unpacked_regular(self, layer_id, mlp, gate_up_linear, root_input):
        # Override to split input_linear at _mlp_intermediate_size explicitly,
        # making the slicing independent of self.intermediate_size.
        import torch
Contributor

Can we move all imports to be global?

3 participants