Skip to content

Test mistral architecture adapter#1312

Open
indrayani21 wants to merge 8 commits into
TransformerLensOrg:devfrom
indrayani21:test-mistral-architecture-adapter
Open

Test mistral architecture adapter#1312
indrayani21 wants to merge 8 commits into
TransformerLensOrg:devfrom
indrayani21:test-mistral-architecture-adapter

Conversation

@indrayani21
Copy link
Copy Markdown

Description

Adds focused unit tests for the Mistral architecture adapter in tests/unit/model_bridge/supported_architectures/test_mistral_adapter.py.

This PR improves test coverage for the Mistral architecture adapter by validating key architectural behaviors and configuration settings.

What this PR includes

  • Config validation for Mistral-specific settings:

    • RMS normalization
    • Rotary positional embeddings
    • Gated MLP configuration
    • Final RMS disabled
    • Attention-only disabled
    • GQA (n_key_value_heads) support
  • Component mapping validation:

    • Embedding, attention, MLP, and normalization bridges
    • Correct HuggingFace module name mappings
    • Attention flags (requires_attention_mask, requires_position_embeddings)
  • Weight conversion validation:

    • Ensures Q/K/V/O projection keys exist in conversion mappings
  • Factory registration check:

    • Verifies MistralForCausalLM is correctly registered in SUPPORTED_ARCHITECTURES

Motivation

Mistral is a key decoder-only architecture with grouped-query attention and RMSNorm-based design. This test suite ensures the architecture adapter correctly represents these structural properties and aligns with existing patterns used for Llama and CodeGen adapters.

Related to #1302

Type of change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Checklist

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests affecting core interfaces or backward compatibility

brendanlong and others added 8 commits April 20, 2026 14:50
* Fix type of HookedTransformerConfig.device

This is typed as `Optional[str]` but sometimes returns `torch.device`.
Updated the code to just return the `str` instead of wrapping with a
device.

I'm not confident that every function which takes a device will
always be passed a string, so I didn't change functions like
warn_if_mps.

Found while working on TransformerLensOrg#1219

* more cleanup

* 3.0 CI Bugs (TransformerLensOrg#1261)

* Fixing `utils` imports

* skip gated notebooks on PR from forks

* Updating notebooks

* Ensure LLaMA only runs when HF_TOKEN is available

---------

Co-authored-by: jlarson4 <jonahalarson@comcast.net>
@jlarson4
Copy link
Copy Markdown
Collaborator

Hi @indrayani21! As you continue to work on more of these, please start each new set of changes on dev or main. When you stack your PRs like this, it makes it difficult to manage their integration into the project.

Also, make sure to run make check-format & resolve any format requests on your code for each PR

@jlarson4 jlarson4 changed the base branch from main to dev May 18, 2026 15:32
# Factory registration tests
# ---------------------------------------------------------------------------

class TestMistralFactory:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a factory dispatch test? Currently this only checks if "MistralForCausalLM" in SUPPORTED_ARCHITECTURES. It never calls ArchitectureAdapterFactory.select_architecture_adapter(cfg) to verify dispatch actually returns a MistralArchitectureAdapter. You should be able to match the same test you wrote in test_bloom_adapter.py


def test_n_key_value_heads(self, adapter: MistralArchitectureAdapter) -> None:
assert adapter.cfg.n_key_value_heads == 4
assert adapter.default_config["n_key_value_heads"] == 4
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The adapter unconditionally accesses cfg.n_key_value_heads when setting n_key_value_heads on default_config. This means a cfg without n_key_value_heads will raise an AttributeError at adapter construction. Can we add an additional test that catches this assertion?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants