Skip to content

Add HunYuan Dense V1 (hunyuan_v1_dense) model support#2144

Open
anilmartha wants to merge 22 commits into
microsoft:mainfrom
anilmartha:tencent_hanyuan1
Open

Add HunYuan Dense V1 (hunyuan_v1_dense) model support#2144
anilmartha wants to merge 22 commits into
microsoft:mainfrom
anilmartha:tencent_hanyuan1

Conversation

@anilmartha
Copy link
Copy Markdown
Contributor

Summary

Adds model-builder and runtime support for Tencent's HY-MT series (architecture class: HunYuanDenseV1ForCausalLM), covering both the 1.8B and 7B parameter variants.

Changes

src/python/py/models/builders/hunyuan.py — New HunyuanDenseV1Model builder subclassing Model (base). Key overrides:

  • Overrides make_attention_qk_subgraph to apply QK norms (query/key LayerNorm) after RoPE — the correct order for this architecture (base class applies them before).
  • Bakes Hunyuan's Dynamic NTK-alpha RoPE scaling into a static rope_theta at export time:
    • effective_theta = rope_theta × α^(head_dim / (head_dim − 2))
    • then clears rope_scaling so the standard RoPE codepath is used.
  • Forces disable_qkv_fusion=True and use_rope_in_attn=False to create the separate Q/K paths required for post-RoPE QK norm insertion.

src/models/model_type.h — Registers "hunyuan_v1_dense" in the LLM model-type array (21 → 22 entries).

src/python/py/models/builder.py and src/python/py/models/builders/__init__.py — Wire HunyuanDenseV1Model under the hunyuan_v1_dense model-type key.

examples/python/test_hy_mt.py — Example inference script using the HF tokenizer for correct special-token handling.

Architecture Notes

HunYuan Dense V1 differs from Llama-style models in two ways:

  • Post-RoPE QK norm: Q and K tensors are normalized after rotary position embedding, not before.
  • NTK-alpha RoPE scaling: Uses a dynamic alpha-based formula resolved at export time to avoid runtime overhead.

All weight names are standard (no custom mapping needed).

Requirement: transformers >= 4.57 for HunYuanDenseV1ForCausalLM.

amdrajeevp1 and others added 12 commits March 20, 2026 19:46
Adds builder and runtime support for tencent/HY-MT series models
(HunYuanDenseV1ForCausalLM). Key implementation details:

- New HunyuanDenseV1Model builder (src/python/py/models/builders/hunyuan.py)
  that overrides make_attention_qk_subgraph to apply QK norms AFTER RoPE,
  matching the model's architecture (standard base class applies QK norm before).
- Dynamic NTK-alpha RoPE scaling baked into static rope_theta at export time:
  effective_theta = rope_theta * alpha^(head_dim/(head_dim-2))
- Forces disable_qkv_fusion and use_rope_in_attn=False to enable separate
  Q/K paths required for post-RoPE QK norm insertion.
- Registers "hunyuan_v1_dense" in the LLM array in model_type.h (size 21->22).
- Includes example inference script using HF tokenizer for correct special-token handling.

Supports 1.8B and 7B variants of tencent/HY-MT1.5 (same architecture class).
Requires transformers>=4.57 for native HunYuanDenseV1 support.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Copilot AI review requested due to automatic review settings May 8, 2026 07:02
@anilmartha anilmartha requested a review from a team as a code owner May 8, 2026 07:02
@anilmartha
Copy link
Copy Markdown
Contributor Author

@microsoft-github-policy-service agree company="AMD"

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.

@kunal-vaishnavi
Copy link
Copy Markdown
Contributor

In the older PR, I added several review comments. Let's fix them in this PR.

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.

Comment thread src/python/py/models/builders/hunyuan.py Outdated
Comment thread src/python/py/models/builders/hunyuan.py
Comment thread src/python/py/models/builder.py
Comment thread src/python/py/models/builders/hunyuan.py Fixed
Anil Kumar Martha and others added 3 commits May 18, 2026 04:41
…oft#2045)

Split the monolithic make_attention_qk_subgraph in base.py into three
composable methods that subclasses can override individually:

  make_attention_qk_norm(layer_id, attention)
      Makes Q/K SimplifiedLayerNorm nodes when q_norm/k_norm are set.

  make_attention_qk_rope(layer_id, **kwargs) -> (cos_cache, sin_cache)
      Makes RotaryEmbedding nodes (or caches for use_rope_in_attn).

  make_attention_qk_rope_and_norm(layer_id, attention, **kwargs)
      Calls norm then rope (base order); returns cache names.

make_attention_qk_subgraph now delegates to make_attention_qk_rope_and_norm
so the rest of the method (repeat_kv, sinks, attention op) is unchanged.

HunyuanDenseV1Model no longer overrides the full make_attention_qk_subgraph.
It only overrides make_attention_qk_rope_and_norm to reverse the order
(RoPE first, then QK norm) which is the Hunyuan-specific requirement.

Also explicitly sets attention_attrs["q_norm"] = True and ["k_norm"] = True
after super().__init__() so make_attention_qk_norm correctly emits the
QK norm nodes — these flags default to False in base.py and are never
auto-detected, every model using QK norm must set them explicitly.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
…ft#2045)

Replace explicit disable_qkv_fusion extra option with make_attention_init
override that sets q_norm and k_norm before the base class packed matmul
check, matching the Qwen3Model pattern.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
@anilmartha
Copy link
Copy Markdown
Contributor Author

Hi @kunal-vaishnavi
I have incorporated the changes suggested in the PR. Could you please review it again?

Comment thread src/python/py/models/builders/hunyuan.py Fixed
Comment thread src/python/py/models/builders/hunyuan.py Outdated
Comment thread src/python/py/models/README.md
anilmartha and others added 2 commits May 20, 2026 06:35
Extract the fused RoPE support check into an overridable method on the base Model class, replacing the inline self.ep not in [dml] check in make_attention_init(). HunyuanDenseV1Model overrides it to return False since it needs explicit RotaryEmbedding nodes to insert QK norms between RoPE and attention.

Addresses PR microsoft#2144 review feedback.

Co-authored-by: Cursor <cursoragent@cursor.com>
@kunal-vaishnavi kunal-vaishnavi enabled auto-merge (squash) May 21, 2026 08:46
auto-merge was automatically disabled May 21, 2026 09:35

Head branch was pushed to by a user without write access

@kunal-vaishnavi kunal-vaishnavi enabled auto-merge (squash) May 21, 2026 09:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants