-
Notifications
You must be signed in to change notification settings - Fork 572
Adding adapter tests for Qwen2 #1309
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Rishik00
wants to merge
6
commits into
TransformerLensOrg:dev
Choose a base branch
from
Rishik00:qwen-adapter-test
base: dev
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+256
−0
Open
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
58b007f
Fix type of HookedTransformerConfig.device (#1230)
brendanlong 6f56518
Merge pull request #1277 from TransformerLensOrg/dev
jlarson4 31d4f6a
Merge pull request #1294 from TransformerLensOrg/dev
jlarson4 5f7b02e
Merge pull request #1295 from TransformerLensOrg/dev
jlarson4 4570fe6
qwen2 adapter tests
Rishik00 2f8a436
qwen 2 adapter tests
Rishik00 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
256 changes: 256 additions & 0 deletions
256
tests/unit/model_bridge/supported_architectures/test_qwen2_adapter.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,256 @@ | ||
| """Unit tests for Qwen2ArchitectureAdapter. | ||
|
|
||
| Tests cover: | ||
| - Config attributes set by the adapter | ||
| - Component mapping structure and HF module paths | ||
| - Standard Q/K/V/O weight conversion rules, including GQA K/V head counts | ||
| - A fake Qwen2-style GQA attention module intended for narrow hook-shape coverage (defined here, not yet exercised by a test) | ||
| - Factory registration | ||
| """ | ||
|
|
||
| from typing import Any | ||
|
|
||
| import pytest | ||
| import torch | ||
| import torch.nn as nn | ||
|
|
||
| from transformer_lens.config import TransformerBridgeConfig | ||
| from transformer_lens.conversion_utils.conversion_steps.rearrange_tensor_conversion import ( | ||
| RearrangeTensorConversion, | ||
| ) | ||
| from transformer_lens.conversion_utils.param_processing_conversion import ( | ||
| ParamProcessingConversion, | ||
| ) | ||
| from transformer_lens.factories.architecture_adapter_factory import ( | ||
| SUPPORTED_ARCHITECTURES, | ||
| ArchitectureAdapterFactory, | ||
| ) | ||
| from transformer_lens.model_bridge.generalized_components import ( | ||
| BlockBridge, | ||
| EmbeddingBridge, | ||
| LinearBridge, | ||
| MLPBridge, | ||
| PositionEmbeddingsAttentionBridge, | ||
| RMSNormalizationBridge, | ||
| RotaryEmbeddingBridge, | ||
| UnembeddingBridge, | ||
| ) | ||
| from transformer_lens.model_bridge.supported_architectures.qwen2 import ( | ||
| Qwen2ArchitectureAdapter, | ||
| ) | ||
|
|
||
|
|
||
def _make_cfg(
    n_heads: int = 4,
    n_key_value_heads: int = 2,
    d_model: int = 64,
    n_layers: int = 2,
    d_mlp: int = 256,
    d_vocab: int = 100,
    n_ctx: int = 64,
) -> TransformerBridgeConfig:
    """Build a miniature Qwen2 bridge config for adapter unit tests.

    The defaults are deliberately tiny so the tests need neither Hugging Face
    downloads nor real checkpoints.

    Args:
        n_heads: Number of query heads.
        n_key_value_heads: Number of key/value heads (GQA).
        d_model: Residual stream width.
        n_layers: Number of transformer blocks.
        d_mlp: MLP hidden width.
        d_vocab: Vocabulary size.
        n_ctx: Context length.

    Returns:
        A ``TransformerBridgeConfig`` whose ``d_head`` is ``d_model // n_heads``,
        with ``default_prepend_bos=False`` and the ``"Qwen2ForCausalLM"``
        architecture string.
    """
    # Per-head width is derived rather than passed in, so it always agrees
    # with the requested model width and head count.
    per_head_width = d_model // n_heads
    settings: dict[str, Any] = {
        "d_model": d_model,
        "d_head": per_head_width,
        "n_layers": n_layers,
        "n_ctx": n_ctx,
        "n_heads": n_heads,
        "n_key_value_heads": n_key_value_heads,
        "d_vocab": d_vocab,
        "d_mlp": d_mlp,
        "default_prepend_bos": False,
        "architecture": "Qwen2ForCausalLM",
    }
    return TransformerBridgeConfig(**settings)
|
|
||
|
|
||
@pytest.fixture
def cfg() -> TransformerBridgeConfig:
    """Provide the default tiny Qwen2 config produced by ``_make_cfg()``."""
    default_cfg = _make_cfg()
    return default_cfg
|
|
||
|
|
||
@pytest.fixture
def adapter(cfg: TransformerBridgeConfig) -> Qwen2ArchitectureAdapter:
    """Provide a ``Qwen2ArchitectureAdapter`` built from the ``cfg`` fixture."""
    qwen2_adapter = Qwen2ArchitectureAdapter(cfg)
    return qwen2_adapter
|
|
||
|
|
||
class FakeQwen2Attention(nn.Module):
    """Stand-in for a Qwen2 attention layer, sized from a bridge config.

    Carries a handful of HF-style scalar attributes plus bias-free q/k/v/o
    projections. No ``forward`` is defined on this class.
    """

    def __init__(self, cfg: TransformerBridgeConfig) -> None:
        """Create the attributes and projections from ``cfg``.

        Args:
            cfg: Config supplying ``d_model``, ``d_head``, ``n_heads`` and
                ``n_key_value_heads``.
        """
        super().__init__()

        head_width = cfg.d_head
        # A falsy n_key_value_heads (None or 0) falls back to n_heads.
        kv_head_count = cfg.n_key_value_heads or cfg.n_heads
        q_width = cfg.n_heads * head_width
        # Qwen2 uses GQA: Q has n_heads, while K/V have n_key_value_heads.
        kv_width = kv_head_count * head_width

        # PositionEmbeddingsAttentionBridge reads these HF-style attributes during forward.
        self.head_dim = head_width
        self.num_key_value_groups = cfg.n_heads // kv_head_count
        self.scaling = head_width**-0.5
        self.attention_dropout = 0.0

        # Registration order (q, k, v, o) is kept so parameter initialisation
        # consumes the RNG in the same sequence.
        self.q_proj = nn.Linear(cfg.d_model, q_width, bias=False)
        self.k_proj = nn.Linear(cfg.d_model, kv_width, bias=False)
        self.v_proj = nn.Linear(cfg.d_model, kv_width, bias=False)
        self.o_proj = nn.Linear(q_width, cfg.d_model, bias=False)
|
|
||
|
|
||
class TestQwen2AdapterConfig:
    """Config values expected on ``adapter.cfg`` after the adapter is built."""

    def test_normalization_type_is_rms(self, adapter: Qwen2ArchitectureAdapter) -> None:
        """``normalization_type`` is ``"RMS"``."""
        adapter_cfg = adapter.cfg
        assert adapter_cfg.normalization_type == "RMS"

    def test_positional_embedding_type_is_rotary(
        self, adapter: Qwen2ArchitectureAdapter
    ) -> None:
        """``positional_embedding_type`` is ``"rotary"``."""
        adapter_cfg = adapter.cfg
        assert adapter_cfg.positional_embedding_type == "rotary"

    def test_final_rms_is_true(self, adapter: Qwen2ArchitectureAdapter) -> None:
        """``final_rms`` is exactly ``True``."""
        adapter_cfg = adapter.cfg
        assert adapter_cfg.final_rms is True

    def test_gated_mlp_is_true(self, adapter: Qwen2ArchitectureAdapter) -> None:
        """``gated_mlp`` is exactly ``True``."""
        adapter_cfg = adapter.cfg
        assert adapter_cfg.gated_mlp is True

    def test_attn_only_is_false(self, adapter: Qwen2ArchitectureAdapter) -> None:
        """``attn_only`` is exactly ``False``."""
        adapter_cfg = adapter.cfg
        assert adapter_cfg.attn_only is False

    def test_default_prepend_bos_is_false(
        self, adapter: Qwen2ArchitectureAdapter
    ) -> None:
        """``default_prepend_bos`` is exactly ``False``."""
        adapter_cfg = adapter.cfg
        assert adapter_cfg.default_prepend_bos is False

    def test_uses_rms_norm_is_true(self, adapter: Qwen2ArchitectureAdapter) -> None:
        """``uses_rms_norm`` is exactly ``True``."""
        adapter_cfg = adapter.cfg
        assert adapter_cfg.uses_rms_norm is True

    def test_n_key_value_heads_propagated(self) -> None:
        """A non-default GQA head count is still present on the adapter config."""
        gqa_cfg = _make_cfg(n_heads=8, n_key_value_heads=2)
        gqa_adapter = Qwen2ArchitectureAdapter(gqa_cfg)
        assert gqa_adapter.cfg.n_key_value_heads == 2
|
|
||
|
|
||
class TestQwen2ComponentMapping:
    """Checks the adapter's component mapping: keys, HF module names, bridge types."""

    def test_top_level_keys(self, adapter: Qwen2ArchitectureAdapter) -> None:
        """The mapping exposes exactly the five expected top-level components."""
        expected_keys = {"embed", "rotary_emb", "blocks", "ln_final", "unembed"}
        actual_keys = set(adapter.component_mapping.keys())
        assert actual_keys == expected_keys

    def test_embed_path(self, adapter: Qwen2ArchitectureAdapter) -> None:
        """``embed`` is named ``model.embed_tokens``."""
        embed = adapter.component_mapping["embed"]
        assert embed.name == "model.embed_tokens"

    def test_rotary_emb_path(self, adapter: Qwen2ArchitectureAdapter) -> None:
        """``rotary_emb`` is named ``model.rotary_emb``."""
        rotary = adapter.component_mapping["rotary_emb"]
        assert rotary.name == "model.rotary_emb"

    def test_blocks_path(self, adapter: Qwen2ArchitectureAdapter) -> None:
        """``blocks`` is named ``model.layers``."""
        block_bridge = adapter.component_mapping["blocks"]
        assert block_bridge.name == "model.layers"

    def test_ln_final_path(self, adapter: Qwen2ArchitectureAdapter) -> None:
        """``ln_final`` is named ``model.norm``."""
        final_norm = adapter.component_mapping["ln_final"]
        assert final_norm.name == "model.norm"

    def test_unembed_path(self, adapter: Qwen2ArchitectureAdapter) -> None:
        """``unembed`` is named ``lm_head``."""
        unembed = adapter.component_mapping["unembed"]
        assert unembed.name == "lm_head"

    def test_block_submodule_keys(self, adapter: Qwen2ArchitectureAdapter) -> None:
        """Each block exposes exactly ``ln1``, ``ln2``, ``attn`` and ``mlp``."""
        block_children = adapter.component_mapping["blocks"].submodules
        assert set(block_children.keys()) == {"ln1", "ln2", "attn", "mlp"}

    def test_attention_submodule_keys(self, adapter: Qwen2ArchitectureAdapter) -> None:
        """Attention exposes exactly ``q``, ``k``, ``v`` and ``o``."""
        block_children = adapter.component_mapping["blocks"].submodules
        attn_children = block_children["attn"].submodules
        assert set(attn_children.keys()) == {"q", "k", "v", "o"}

    def test_mlp_submodule_keys(self, adapter: Qwen2ArchitectureAdapter) -> None:
        """The MLP exposes exactly ``gate``, ``in`` and ``out``."""
        block_children = adapter.component_mapping["blocks"].submodules
        mlp_children = block_children["mlp"].submodules
        assert set(mlp_children.keys()) == {"gate", "in", "out"}

    def test_bridge_types(self, adapter: Qwen2ArchitectureAdapter) -> None:
        """Every top-level and per-block component uses the expected bridge class."""
        top_level = adapter.component_mapping
        block_bridge = top_level["blocks"]
        per_block = block_bridge.submodules
        expectations = [
            (top_level["embed"], EmbeddingBridge),
            (top_level["rotary_emb"], RotaryEmbeddingBridge),
            (block_bridge, BlockBridge),
            (top_level["ln_final"], RMSNormalizationBridge),
            (top_level["unembed"], UnembeddingBridge),
            (per_block["ln1"], RMSNormalizationBridge),
            (per_block["ln2"], RMSNormalizationBridge),
            (per_block["attn"], PositionEmbeddingsAttentionBridge),
            (per_block["mlp"], MLPBridge),
        ]
        for component, bridge_cls in expectations:
            assert isinstance(component, bridge_cls)

    def test_attention_hf_paths(self, adapter: Qwen2ArchitectureAdapter) -> None:
        """Attention and its projections carry the HF ``self_attn`` / ``*_proj`` names."""
        attn = adapter.component_mapping["blocks"].submodules["attn"]
        assert attn.name == "self_attn"
        projection_names = (
            ("q", "q_proj"),
            ("k", "k_proj"),
            ("v", "v_proj"),
            ("o", "o_proj"),
        )
        for tl_key, hf_name in projection_names:
            assert attn.submodules[tl_key].name == hf_name

    def test_mlp_hf_paths(self, adapter: Qwen2ArchitectureAdapter) -> None:
        """The MLP and its projections carry the HF gate/up/down names."""
        mlp = adapter.component_mapping["blocks"].submodules["mlp"]
        assert mlp.name == "mlp"
        projection_names = (
            ("gate", "gate_proj"),
            ("in", "up_proj"),
            ("out", "down_proj"),
        )
        for tl_key, hf_name in projection_names:
            assert mlp.submodules[tl_key].name == hf_name

    def test_linear_submodule_bridge_types(
        self, adapter: Qwen2ArchitectureAdapter
    ) -> None:
        """All attention and MLP projections are ``LinearBridge`` instances."""
        per_block = adapter.component_mapping["blocks"].submodules
        projections = list(per_block["attn"].submodules.values())
        projections += list(per_block["mlp"].submodules.values())
        for projection in projections:
            assert isinstance(projection, LinearBridge)
|
|
||
|
|
||
class TestQwen2WeightConversions:
    """Checks the Q/K/V/O weight conversion entries and their head-count axis."""

    def test_has_qkvo_keys(self, adapter: Qwen2ArchitectureAdapter) -> None:
        """Conversions exist and cover exactly the four attention weights."""
        expected_keys = {
            "blocks.{i}.attn.q.weight",
            "blocks.{i}.attn.k.weight",
            "blocks.{i}.attn.v.weight",
            "blocks.{i}.attn.o.weight",
        }
        conversions = adapter.weight_processing_conversions
        assert conversions is not None
        assert set(conversions.keys()) == expected_keys

    def test_q_uses_n_heads(self, adapter: Qwen2ArchitectureAdapter) -> None:
        """The Q rearrange uses ``n_heads`` for its ``n`` axis."""
        q_conv = adapter.weight_processing_conversions["blocks.{i}.attn.q.weight"]
        assert isinstance(q_conv, ParamProcessingConversion)
        rearrange = q_conv.tensor_conversion
        assert isinstance(rearrange, RearrangeTensorConversion)
        assert rearrange.axes_lengths["n"] == adapter.cfg.n_heads

    def test_kv_use_n_key_value_heads(self, adapter: Qwen2ArchitectureAdapter) -> None:
        """The K and V rearranges use ``n_key_value_heads`` for their ``n`` axis."""
        kv_keys = ("blocks.{i}.attn.k.weight", "blocks.{i}.attn.v.weight")
        for weight_key in kv_keys:
            kv_conv = adapter.weight_processing_conversions[weight_key]
            assert isinstance(kv_conv, ParamProcessingConversion)
            rearrange = kv_conv.tensor_conversion
            assert isinstance(rearrange, RearrangeTensorConversion)
            head_axis = rearrange.axes_lengths["n"]
            assert head_axis == adapter.cfg.n_key_value_heads

    def test_o_uses_n_heads(self, adapter: Qwen2ArchitectureAdapter) -> None:
        """The O rearrange uses ``n_heads`` for its ``n`` axis."""
        o_conv = adapter.weight_processing_conversions["blocks.{i}.attn.o.weight"]
        assert isinstance(o_conv, ParamProcessingConversion)
        rearrange = o_conv.tensor_conversion
        assert isinstance(rearrange, RearrangeTensorConversion)
        assert rearrange.axes_lengths["n"] == adapter.cfg.n_heads
|
|
||
|
|
||
class TestQwen2FactoryRegistration:
    """Checks that the ``"Qwen2ForCausalLM"`` architecture string resolves to this adapter."""

    def test_factory_key_present(self) -> None:
        """The architecture string is a key of ``SUPPORTED_ARCHITECTURES``."""
        architecture = "Qwen2ForCausalLM"
        assert architecture in SUPPORTED_ARCHITECTURES

    def test_factory_maps_to_correct_adapter_class(self) -> None:
        """The registry entry is the ``Qwen2ArchitectureAdapter`` class itself."""
        registered_cls = SUPPORTED_ARCHITECTURES["Qwen2ForCausalLM"]
        assert registered_cls is Qwen2ArchitectureAdapter

    def test_factory_returns_correct_instance(self) -> None:
        """Selecting an adapter for the default test config yields a Qwen2 adapter."""
        selected = ArchitectureAdapterFactory.select_architecture_adapter(_make_cfg())
        assert isinstance(selected, Qwen2ArchitectureAdapter)
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is excellent, and it would be great to write some tests that use it, but at present it does not appear to be wired into anything.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My apologies. Will wire it up and update the PR!