Test mistral architecture adapter#1312
* Fix type of HookedTransformerConfig.device
  This is typed as `Optional[str]` but sometimes returns `torch.device`. Updated the code to just return the `str` instead of wrapping with a device. I'm not confident that every function which takes a device will always be passed a string, so I didn't change functions like warn_if_mps. Found while working on TransformerLensOrg#1219
* more cleanup
* 3.0 CI Bugs (TransformerLensOrg#1261)
* Fixing `utils` imports
* skip gated notebooks on PR from forks
* Updating notebooks
* Ensure LLaMA only runs when HF_TOKEN is available
---------
Co-authored-by: jlarson4 <jonahalarson@comcast.net>
TransformerLens 3.1.0
Hi @indrayani21! As you continue to work on more of these, please start each new set of changes on Also, make sure to run
    # Factory registration tests
    # ---------------------------------------------------------------------------


    class TestMistralFactory:
Can we add a factory dispatch test? Currently this only checks if "MistralForCausalLM" in SUPPORTED_ARCHITECTURES. It never calls ArchitectureAdapterFactory.select_architecture_adapter(cfg) to verify dispatch actually returns a MistralArchitectureAdapter. You should be able to match the same test you wrote in test_bloom_adapter.py
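A minimal sketch of the shape such a dispatch test could take. To stay self-contained it uses stand-ins: `StandInMistralAdapter`, `STAND_IN_REGISTRY`, `StandInFactory`, and the `architecture` field on the config are all assumptions made for illustration, not the library's actual classes or config layout. In the real test the adapter, `SUPPORTED_ARCHITECTURES`, and `ArchitectureAdapterFactory` would be imported from the library, and the config would be built the same way as in test_bloom_adapter.py.

```python
from types import SimpleNamespace


class StandInMistralAdapter:
    """Hypothetical stand-in for MistralArchitectureAdapter."""

    def __init__(self, cfg):
        self.cfg = cfg


# Hypothetical stand-in for SUPPORTED_ARCHITECTURES (assumed here to map
# an architecture name to an adapter class).
STAND_IN_REGISTRY = {"MistralForCausalLM": StandInMistralAdapter}


class StandInFactory:
    """Hypothetical stand-in for ArchitectureAdapterFactory."""

    @staticmethod
    def select_architecture_adapter(cfg):
        # Look up the adapter class by name and construct it with the config.
        return STAND_IN_REGISTRY[cfg.architecture](cfg)


def test_factory_returns_mistral_adapter():
    # The `architecture` attribute is an assumed config field.
    cfg = SimpleNamespace(architecture="MistralForCausalLM", n_key_value_heads=4)
    adapter = StandInFactory.select_architecture_adapter(cfg)
    # Registration alone is not enough: check the dispatched type.
    assert isinstance(adapter, StandInMistralAdapter)
    assert adapter.cfg is cfg
```

The point of the second assertion is that the factory passes the same config object through to the adapter, which a membership check on the registry cannot show.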
    def test_n_key_value_heads(self, adapter: MistralArchitectureAdapter) -> None:
        assert adapter.cfg.n_key_value_heads == 4
        assert adapter.default_config["n_key_value_heads"] == 4
The adapter unconditionally accesses cfg.n_key_value_heads when setting n_key_value_heads on default_config. This means a cfg without n_key_value_heads will raise an AttributeError at adapter construction. Can we add an additional test that covers this error case?
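One way such a test might look, as a sketch only. `StandInMistralAdapter` below is a hypothetical stand-in that mimics the behavior described in this comment (reading `cfg.n_key_value_heads` unconditionally during construction); it is not the real adapter, and the `n_heads` field is an arbitrary placeholder. The real test would construct `MistralArchitectureAdapter` and would more idiomatically use `pytest.raises(AttributeError)`; a plain `try`/`except` is used here so the sketch runs without pytest.

```python
from types import SimpleNamespace


class StandInMistralAdapter:
    """Hypothetical stand-in mimicking the unconditional attribute access."""

    def __init__(self, cfg):
        self.cfg = cfg
        # Unconditional access: fails if cfg lacks n_key_value_heads.
        self.default_config = {"n_key_value_heads": cfg.n_key_value_heads}


def construction_error(cfg):
    """Return the AttributeError raised during construction, or None."""
    try:
        StandInMistralAdapter(cfg)
    except AttributeError as exc:
        return exc
    return None


def test_missing_n_key_value_heads_raises():
    # Config deliberately omits n_key_value_heads.
    cfg = SimpleNamespace(n_heads=8)
    assert isinstance(construction_error(cfg), AttributeError)


def test_present_n_key_value_heads_constructs():
    cfg = SimpleNamespace(n_heads=8, n_key_value_heads=4)
    assert construction_error(cfg) is None
    assert StandInMistralAdapter(cfg).default_config["n_key_value_heads"] == 4
```

Pairing the failing case with a passing one keeps the test from succeeding for the wrong reason, for example if construction raised on every config.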
Description
Adds focused unit tests for the Mistral architecture adapter in
tests/unit/model_bridge/supported_architectures/test_mistral_adapter.py. This PR improves test coverage for the Mistral architecture adapter by validating key architectural behaviors and configuration settings.
What this PR includes
Config validation for Mistral-specific settings:
n_key_value_heads) support
Component mapping validation:
requires_attention_mask, requires_position_embeddings)
Weight conversion validation:
Factory registration check:
MistralForCausalLM is correctly registered in SUPPORTED_ARCHITECTURES
Motivation
Mistral is a key decoder-only architecture with grouped-query attention and RMSNorm-based design. This test suite ensures the architecture adapter correctly represents these structural properties and aligns with existing patterns used for Llama and CodeGen adapters.
Related to #1302