From 88e6509b7d927fb8a5089ba5adc8675e02141396 Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Mon, 11 May 2026 14:01:33 -0500 Subject: [PATCH 01/10] Resolution for #644 and #341 --- .../_static/model_properties_table.jsonl | 6 +-- docs/source/content/model_structure.md | 8 ++++ tests/unit/factored_matrix/test_properties.py | 12 ++++- transformer_lens/FactoredMatrix.py | 44 +++++++++++++------ transformer_lens/SVDInterpreter.py | 2 +- 5 files changed, 53 insertions(+), 19 deletions(-) diff --git a/docs/source/_static/model_properties_table.jsonl b/docs/source/_static/model_properties_table.jsonl index 6a9d3a46c..d3e820a33 100644 --- a/docs/source/_static/model_properties_table.jsonl +++ b/docs/source/_static/model_properties_table.jsonl @@ -23,7 +23,7 @@ {"name.default_alias":"CodeLlamallama-2-7b","name.huggingface":null,"name.aliases":"","model_type":"llama","name.from_cfg":"CodeLlama-7b-hf","n_params.as_str":"6.5B","n_params.as_int":6476005376,"n_params.from_name":"7b","cfg.n_params":6476005376,"cfg.n_layers":32,"cfg.n_heads":32,"cfg.d_model":4096,"cfg.d_vocab":32016,"cfg.act_fn":"silu","cfg.positional_embedding_type":"rotary","cfg.parallel_attn_mlp":false,"cfg.original_architecture":"LlamaForCausalLM","cfg.normalization_type":"RMS","config.raw__":{"d_model":4096,"d_head":128,"n_layers":32,"n_ctx":4096,"n_heads":32,"d_mlp":11008,"d_vocab":32016,"device":"cpu","use_attn_result":false,"use_split_qkv_input":false,"default_prepend_bos":true,"positional_embedding_type":"rotary","n_key_value_heads":null,"attn_only":false,"gated_mlp":true,"uses_rms_norm":false,"eps":0.00001,"layer_norm_folding":false,"act_fn":"silu","normalization_type":"RMS","num_experts":null,"experts_per_token":null,"final_rms":true,"dtype":"torch.float32","model_name":"CodeLlama-7b-hf","use_attn_scale":true,"attn_scale":11.313708499,"use_hook_mlp_in":false,"use_attn_in":false,"use_qk_norm":false,"use_local_attn":false,"ungroup_grouped_query_attention":false,"original_architecture":"LlamaForCausalLM","from_checkpoint":false,"checkpoint_index":null,"checkpoint_label_type":null,"checkpoint_value":null,"tokenizer_name":"codellama\/CodeLlama-7b-hf","window_size":null,"attn_types":null,"init_mode":"gpt2","n_devices":1,"attention_dir":"causal","seed":null,"initializer_range":0.0125,"init_weights":false,"scale_attn_by_inverse_layer_idx":false,"d_vocab_out":32016,"parallel_attn_mlp":false,"rotary_dim":128,"n_params":6476005376,"use_hook_tokens":false,"tokenizer_prepends_bos":null,"post_embedding_ln":false,"rotary_base":1000000,"rotary_base_local":null,"rotary_scaling_factor":1.0,"trust_remote_code":false,"rotary_adjacent_pairs":false,"load_in_4bit":false,"relative_attention_max_distance":null,"relative_attention_num_buckets":null,"decoder_start_token_id":null,"tie_word_embeddings":false,"use_normalization_before_and_after":false,"attn_scores_soft_cap":-1.0,"output_logits_soft_cap":-1.0,"use_NTK_by_parts_rope":false,"NTK_by_parts_low_freq_factor":1.0,"NTK_by_parts_high_freq_factor":4.0,"NTK_by_parts_factor":8.0,"NTK_original_ctx_len":8192,"use_yarn_rope":false,"yarn_factor":1.0,"yarn_attention_factor":1.0,"yarn_beta_fast":32.0,"yarn_beta_slow":1.0,"yarn_original_max_position_embeddings":4096,"norm_topk_prob":false},"config":"d_model: 4096\nd_head: 128\nn_layers: 32\nn_ctx: 4096\nn_heads: 32\nd_mlp: 11008\nd_vocab: 32016\ndevice: cpu\nuse_attn_result: false\nuse_split_qkv_input: false\ndefault_prepend_bos: true\npositional_embedding_type: rotary\nn_key_value_heads: null\nattn_only: false\ngated_mlp: true\nuses_rms_norm: false\neps: 1.0e-05\nlayer_norm_folding: false\nact_fn: silu\nnormalization_type: RMS\nnum_experts: null\nexperts_per_token: null\nfinal_rms: true\ndtype: torch.float32\nmodel_name: CodeLlama-7b-hf\nuse_attn_scale: true\nattn_scale: !!python\/object\/apply:numpy.core.multiarray.scalar\n- !!python\/object\/apply:numpy.dtype\n args:\n - f8\n - false\n - true\n state: !!python\/tuple\n - 3\n - <\n - null\n - null\n - null\n - -1\n - -1\n - 0\n- !!binary |\n zTt\/Zp6gJkA=\nuse_hook_mlp_in: false\nuse_attn_in: false\nuse_qk_norm: false\nuse_local_attn: false\nungroup_grouped_query_attention: false\noriginal_architecture: LlamaForCausalLM\nfrom_checkpoint: false\ncheckpoint_index: null\ncheckpoint_label_type: null\ncheckpoint_value: null\ntokenizer_name: codellama\/CodeLlama-7b-hf\nwindow_size: null\nattn_types: null\ninit_mode: gpt2\nn_devices: 1\nattention_dir: causal\nseed: null\ninitializer_range: 0.0125\ninit_weights: false\nscale_attn_by_inverse_layer_idx: false\nd_vocab_out: 32016\nparallel_attn_mlp: false\nrotary_dim: 128\nn_params: 6476005376\nuse_hook_tokens: false\ntokenizer_prepends_bos: null\npost_embedding_ln: false\nrotary_base: 1000000\nrotary_base_local: null\nrotary_scaling_factor: 1.0\ntrust_remote_code: false\nrotary_adjacent_pairs: false\nload_in_4bit: false\nrelative_attention_max_distance: null\nrelative_attention_num_buckets: null\ndecoder_start_token_id: null\ntie_word_embeddings: false\nuse_normalization_before_and_after: false\nattn_scores_soft_cap: -1.0\noutput_logits_soft_cap: -1.0\nuse_NTK_by_parts_rope: false\nNTK_by_parts_low_freq_factor: 1.0\nNTK_by_parts_high_freq_factor: 4.0\nNTK_by_parts_factor: 8.0\nNTK_original_ctx_len: 8192\nuse_yarn_rope: false\nyarn_factor: 1.0\nyarn_attention_factor: 1.0\nyarn_beta_fast: 32.0\nyarn_beta_slow: 1.0\nyarn_original_max_position_embeddings: 4096\nnorm_topk_prob: false\n","tokenizer.name":"codellama\/CodeLlama-7b-hf","tokenizer.vocab_size":32016.0,"tokenizer.max_len":null,"tokenizer.class":"TokenizersBackend","tokenizer.vocab_hash":"Tq7bUWJcm1X5kj9R-2uR1o7lSq8=","tensor_shapes.state_dict":"embed:\n W_E: (32016, 4096)\nblocks:\n '[0-31]':\n ln1:\n w: (4096,)\n ln2:\n w: (4096,)\n attn:\n '[W_Q, W_K, W_V]': (32, 4096, 128)\n W_O: (32, 128, 4096)\n '[b_Q, b_K, b_V]': (32, 128)\n b_O: (4096,)\n mask: (4096, 4096)\n IGNORE: ()\n '[rotary_sin, rotary_cos]': (4096, 128)\n mlp:\n '[W_in, W_gate]': (4096, 11008)\n W_out: (11008, 4096)\n b_in: (11008,)\n b_out: (4096,)\nln_final:\n w: (4096,)\nunembed:\n W_U: (4096, 32016)\n b_U: (32016,)\n","tensor_shapes.state_dict.raw__":{"embed":{"W_E":"(32016, 4096)"},"blocks":{"[0-31]":{"ln1":{"w":"(4096,)"},"ln2":{"w":"(4096,)"},"attn":{"[W_Q, W_K, W_V]":"(32, 4096, 128)","W_O":"(32, 128, 4096)","[b_Q, b_K, b_V]":"(32, 128)","b_O":"(4096,)","mask":"(4096, 4096)","IGNORE":"()","[rotary_sin, rotary_cos]":"(4096, 128)"},"mlp":{"[W_in, W_gate]":"(4096, 11008)","W_out":"(11008, 4096)","b_in":"(11008,)","b_out":"(4096,)"}}},"ln_final":{"w":"(4096,)"},"unembed":{"W_U":"(4096, 32016)","b_U":"(32016,)"}},"tensor_shapes.activation_cache":"blocks:\n '[0-31]':\n ln1:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 4096)\n attn:\n '[hook_q, hook_k, hook_v, hook_rot_q, hook_rot_k, hook_z]': (batch, seq_len,\n 32, 128)\n '[hook_attn_scores, hook_pattern]': (batch, 32, seq_len, seq_len)\n ln2:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 4096)\n mlp:\n '[hook_pre, hook_pre_linear, hook_post]': (batch, seq_len, 11008)\n '[hook_resid_pre, hook_attn_out, hook_resid_mid, hook_mlp_out, hook_resid_post]': (batch,\n seq_len, 4096)\nln_final:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 4096)\nunembed:\n hook_in: (batch, seq_len, 4096)\n hook_out: (batch, seq_len, 32016)\nhook_embed: (batch, seq_len, 4096)\n","tensor_shapes.activation_cache.raw__":{"blocks":{"[0-31]":{"ln1":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 4096)"},"attn":{"[hook_q, hook_k, hook_v, hook_rot_q, hook_rot_k, hook_z]":"(batch, seq_len, 32, 128)","[hook_attn_scores, hook_pattern]":"(batch, 32, seq_len, seq_len)"},"ln2":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 4096)"},"mlp":{"[hook_pre, hook_pre_linear, hook_post]":"(batch, seq_len, 11008)"},"[hook_resid_pre, hook_attn_out, hook_resid_mid, hook_mlp_out, hook_resid_post]":"(batch, seq_len, 4096)"}},"ln_final":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 4096)"},"unembed":{"hook_in":"(batch, seq_len, 4096)","hook_out":"(batch, seq_len, 32016)"},"hook_embed":"(batch, seq_len, 4096)"}} {"name.default_alias":"CodeLlama-7b-instruct","name.huggingface":null,"name.aliases":"","model_type":"CodeLlama","name.from_cfg":"CodeLlama-7b-Instruct-hf","n_params.as_str":"6.5B","n_params.as_int":6476005376,"n_params.from_name":"7b","cfg.n_params":6476005376,"cfg.n_layers":32,"cfg.n_heads":32,"cfg.d_model":4096,"cfg.d_vocab":32016,"cfg.act_fn":"silu","cfg.positional_embedding_type":"rotary","cfg.parallel_attn_mlp":false,"cfg.original_architecture":"LlamaForCausalLM","cfg.normalization_type":"RMS","config.raw__":{"d_model":4096,"d_head":128,"n_layers":32,"n_ctx":4096,"n_heads":32,"d_mlp":11008,"d_vocab":32016,"device":"cpu","use_attn_result":false,"use_split_qkv_input":false,"default_prepend_bos":true,"positional_embedding_type":"rotary","n_key_value_heads":null,"attn_only":false,"gated_mlp":true,"uses_rms_norm":false,"eps":0.00001,"layer_norm_folding":false,"act_fn":"silu","normalization_type":"RMS","num_experts":null,"experts_per_token":null,"final_rms":true,"dtype":"torch.float32","model_name":"CodeLlama-7b-Instruct-hf","use_attn_scale":true,"attn_scale":11.313708499,"use_hook_mlp_in":false,"use_attn_in":false,"use_qk_norm":false,"use_local_attn":false,"ungroup_grouped_query_attention":false,"original_architecture":"LlamaForCausalLM","from_checkpoint":false,"checkpoint_index":null,"checkpoint_label_type":null,"checkpoint_value":null,"tokenizer_name":"codellama\/CodeLlama-7b-Instruct-hf","window_size":null,"attn_types":null,"init_mode":"gpt2","n_devices":1,"attention_dir":"causal","seed":null,"initializer_range":0.0125,"init_weights":false,"scale_attn_by_inverse_layer_idx":false,"d_vocab_out":32016,"parallel_attn_mlp":false,"rotary_dim":128,"n_params":6476005376,"use_hook_tokens":false,"tokenizer_prepends_bos":null,"post_embedding_ln":false,"rotary_base":1000000,"rotary_base_local":null,"rotary_scaling_factor":1.0,"trust_remote_code":false,"rotary_adjacent_pairs":false,"load_in_4bit":false,"relative_attention_max_distance":null,"relative_attention_num_buckets":null,"decoder_start_token_id":null,"tie_word_embeddings":false,"use_normalization_before_and_after":false,"attn_scores_soft_cap":-1.0,"output_logits_soft_cap":-1.0,"use_NTK_by_parts_rope":false,"NTK_by_parts_low_freq_factor":1.0,"NTK_by_parts_high_freq_factor":4.0,"NTK_by_parts_factor":8.0,"NTK_original_ctx_len":8192,"use_yarn_rope":false,"yarn_factor":1.0,"yarn_attention_factor":1.0,"yarn_beta_fast":32.0,"yarn_beta_slow":1.0,"yarn_original_max_position_embeddings":4096,"norm_topk_prob":false},"config":"d_model: 4096\nd_head: 128\nn_layers: 32\nn_ctx: 4096\nn_heads: 32\nd_mlp: 11008\nd_vocab: 32016\ndevice: cpu\nuse_attn_result: false\nuse_split_qkv_input: false\ndefault_prepend_bos: true\npositional_embedding_type: rotary\nn_key_value_heads: null\nattn_only: false\ngated_mlp: true\nuses_rms_norm: false\neps: 1.0e-05\nlayer_norm_folding: false\nact_fn: silu\nnormalization_type: RMS\nnum_experts: null\nexperts_per_token: null\nfinal_rms: true\ndtype: torch.float32\nmodel_name: CodeLlama-7b-Instruct-hf\nuse_attn_scale: true\nattn_scale: !!python\/object\/apply:numpy.core.multiarray.scalar\n- !!python\/object\/apply:numpy.dtype\n args:\n - f8\n - false\n - true\n state: !!python\/tuple\n - 3\n - <\n - null\n - null\n - null\n - -1\n - -1\n - 0\n- !!binary |\n zTt\/Zp6gJkA=\nuse_hook_mlp_in: false\nuse_attn_in: false\nuse_qk_norm: false\nuse_local_attn: false\nungroup_grouped_query_attention: false\noriginal_architecture: LlamaForCausalLM\nfrom_checkpoint: false\ncheckpoint_index: null\ncheckpoint_label_type: null\ncheckpoint_value: null\ntokenizer_name: codellama\/CodeLlama-7b-Instruct-hf\nwindow_size: null\nattn_types: null\ninit_mode: gpt2\nn_devices: 1\nattention_dir: causal\nseed: null\ninitializer_range: 0.0125\ninit_weights: false\nscale_attn_by_inverse_layer_idx: false\nd_vocab_out: 32016\nparallel_attn_mlp: false\nrotary_dim: 128\nn_params: 6476005376\nuse_hook_tokens: false\ntokenizer_prepends_bos: null\npost_embedding_ln: false\nrotary_base: 1000000\nrotary_base_local: null\nrotary_scaling_factor: 1.0\ntrust_remote_code: false\nrotary_adjacent_pairs: false\nload_in_4bit: false\nrelative_attention_max_distance: null\nrelative_attention_num_buckets: null\ndecoder_start_token_id: null\ntie_word_embeddings: false\nuse_normalization_before_and_after: false\nattn_scores_soft_cap: -1.0\noutput_logits_soft_cap: -1.0\nuse_NTK_by_parts_rope: false\nNTK_by_parts_low_freq_factor: 1.0\nNTK_by_parts_high_freq_factor: 4.0\nNTK_by_parts_factor: 8.0\nNTK_original_ctx_len: 8192\nuse_yarn_rope: false\nyarn_factor: 1.0\nyarn_attention_factor: 1.0\nyarn_beta_fast: 32.0\nyarn_beta_slow: 1.0\nyarn_original_max_position_embeddings: 4096\nnorm_topk_prob: false\n","tokenizer.name":"codellama\/CodeLlama-7b-Instruct-hf","tokenizer.vocab_size":32016.0,"tokenizer.max_len":null,"tokenizer.class":"TokenizersBackend","tokenizer.vocab_hash":"Tq7bUWJcm1X5kj9R-2uR1o7lSq8=","tensor_shapes.state_dict":"embed:\n W_E: (32016, 4096)\nblocks:\n '[0-31]':\n ln1:\n w: (4096,)\n ln2:\n w: (4096,)\n attn:\n '[W_Q, W_K, W_V]': (32, 4096, 128)\n W_O: (32, 128, 4096)\n '[b_Q, b_K, b_V]': (32, 128)\n b_O: (4096,)\n mask: (4096, 4096)\n IGNORE: ()\n '[rotary_sin, rotary_cos]': (4096, 128)\n mlp:\n '[W_in, W_gate]': (4096, 11008)\n W_out: (11008, 4096)\n b_in: (11008,)\n b_out: (4096,)\nln_final:\n w: (4096,)\nunembed:\n W_U: (4096, 32016)\n b_U: (32016,)\n","tensor_shapes.state_dict.raw__":{"embed":{"W_E":"(32016, 4096)"},"blocks":{"[0-31]":{"ln1":{"w":"(4096,)"},"ln2":{"w":"(4096,)"},"attn":{"[W_Q, W_K, W_V]":"(32, 4096, 128)","W_O":"(32, 128, 4096)","[b_Q, b_K, b_V]":"(32, 128)","b_O":"(4096,)","mask":"(4096, 4096)","IGNORE":"()","[rotary_sin, rotary_cos]":"(4096, 128)"},"mlp":{"[W_in, W_gate]":"(4096, 11008)","W_out":"(11008, 4096)","b_in":"(11008,)","b_out":"(4096,)"}}},"ln_final":{"w":"(4096,)"},"unembed":{"W_U":"(4096, 32016)","b_U":"(32016,)"}},"tensor_shapes.activation_cache":"blocks:\n '[0-31]':\n ln1:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 4096)\n attn:\n '[hook_q, hook_k, hook_v, hook_rot_q, hook_rot_k, hook_z]': (batch, seq_len,\n 32, 128)\n '[hook_attn_scores, hook_pattern]': (batch, 32, seq_len, seq_len)\n ln2:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 4096)\n mlp:\n '[hook_pre, hook_pre_linear, hook_post]': (batch, seq_len, 11008)\n '[hook_resid_pre, hook_attn_out, hook_resid_mid, hook_mlp_out, hook_resid_post]': (batch,\n seq_len, 4096)\nln_final:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 4096)\nunembed:\n hook_in: (batch, seq_len, 4096)\n hook_out: (batch, seq_len, 32016)\nhook_embed: (batch, seq_len, 4096)\n","tensor_shapes.activation_cache.raw__":{"blocks":{"[0-31]":{"ln1":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 4096)"},"attn":{"[hook_q, hook_k, hook_v, hook_rot_q, hook_rot_k, hook_z]":"(batch, seq_len, 32, 128)","[hook_attn_scores, hook_pattern]":"(batch, 32, seq_len, seq_len)"},"ln2":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 4096)"},"mlp":{"[hook_pre, hook_pre_linear, hook_post]":"(batch, seq_len, 11008)"},"[hook_resid_pre, hook_attn_out, hook_resid_mid, hook_mlp_out, hook_resid_post]":"(batch, seq_len, 4096)"}},"ln_final":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 4096)"},"unembed":{"hook_in":"(batch, seq_len, 4096)","hook_out":"(batch, seq_len, 32016)"},"hook_embed":"(batch, seq_len, 4096)"}} {"name.default_alias":"CodeLlama-7b-python","name.huggingface":null,"name.aliases":"","model_type":"CodeLlama","name.from_cfg":"CodeLlama-7b-Python-hf","n_params.as_str":"6.5B","n_params.as_int":6476005376,"n_params.from_name":"7b","cfg.n_params":6476005376,"cfg.n_layers":32,"cfg.n_heads":32,"cfg.d_model":4096,"cfg.d_vocab":32000,"cfg.act_fn":"silu","cfg.positional_embedding_type":"rotary","cfg.parallel_attn_mlp":false,"cfg.original_architecture":"LlamaForCausalLM","cfg.normalization_type":"RMS","config.raw__":{"d_model":4096,"d_head":128,"n_layers":32,"n_ctx":4096,"n_heads":32,"d_mlp":11008,"d_vocab":32000,"device":"cpu","use_attn_result":false,"use_split_qkv_input":false,"default_prepend_bos":true,"positional_embedding_type":"rotary","n_key_value_heads":null,"attn_only":false,"gated_mlp":true,"uses_rms_norm":false,"eps":0.00001,"layer_norm_folding":false,"act_fn":"silu","normalization_type":"RMS","num_experts":null,"experts_per_token":null,"final_rms":true,"dtype":"torch.float32","model_name":"CodeLlama-7b-Python-hf","use_attn_scale":true,"attn_scale":11.313708499,"use_hook_mlp_in":false,"use_attn_in":false,"use_qk_norm":false,"use_local_attn":false,"ungroup_grouped_query_attention":false,"original_architecture":"LlamaForCausalLM","from_checkpoint":false,"checkpoint_index":null,"checkpoint_label_type":null,"checkpoint_value":null,"tokenizer_name":"codellama\/CodeLlama-7b-Python-hf","window_size":null,"attn_types":null,"init_mode":"gpt2","n_devices":1,"attention_dir":"causal","seed":null,"initializer_range":0.0125,"init_weights":false,"scale_attn_by_inverse_layer_idx":false,"d_vocab_out":32000,"parallel_attn_mlp":false,"rotary_dim":128,"n_params":6476005376,"use_hook_tokens":false,"tokenizer_prepends_bos":null,"post_embedding_ln":false,"rotary_base":1000000,"rotary_base_local":null,"rotary_scaling_factor":1.0,"trust_remote_code":false,"rotary_adjacent_pairs":false,"load_in_4bit":false,"relative_attention_max_distance":null,"relative_attention_num_buckets":null,"decoder_start_token_id":null,"tie_word_embeddings":false,"use_normalization_before_and_after":false,"attn_scores_soft_cap":-1.0,"output_logits_soft_cap":-1.0,"use_NTK_by_parts_rope":false,"NTK_by_parts_low_freq_factor":1.0,"NTK_by_parts_high_freq_factor":4.0,"NTK_by_parts_factor":8.0,"NTK_original_ctx_len":8192,"use_yarn_rope":false,"yarn_factor":1.0,"yarn_attention_factor":1.0,"yarn_beta_fast":32.0,"yarn_beta_slow":1.0,"yarn_original_max_position_embeddings":4096,"norm_topk_prob":false},"config":"d_model: 4096\nd_head: 128\nn_layers: 32\nn_ctx: 4096\nn_heads: 32\nd_mlp: 11008\nd_vocab: 32000\ndevice: cpu\nuse_attn_result: false\nuse_split_qkv_input: false\ndefault_prepend_bos: true\npositional_embedding_type: rotary\nn_key_value_heads: null\nattn_only: false\ngated_mlp: true\nuses_rms_norm: false\neps: 1.0e-05\nlayer_norm_folding: false\nact_fn: silu\nnormalization_type: RMS\nnum_experts: null\nexperts_per_token: null\nfinal_rms: true\ndtype: torch.float32\nmodel_name: CodeLlama-7b-Python-hf\nuse_attn_scale: true\nattn_scale: !!python\/object\/apply:numpy.core.multiarray.scalar\n- !!python\/object\/apply:numpy.dtype\n args:\n - f8\n - false\n - true\n state: !!python\/tuple\n - 3\n - <\n - null\n - null\n - null\n - -1\n - -1\n - 0\n- !!binary |\n zTt\/Zp6gJkA=\nuse_hook_mlp_in: false\nuse_attn_in: false\nuse_qk_norm: false\nuse_local_attn: false\nungroup_grouped_query_attention: false\noriginal_architecture: LlamaForCausalLM\nfrom_checkpoint: false\ncheckpoint_index: null\ncheckpoint_label_type: null\ncheckpoint_value: null\ntokenizer_name: codellama\/CodeLlama-7b-Python-hf\nwindow_size: null\nattn_types: null\ninit_mode: gpt2\nn_devices: 1\nattention_dir: causal\nseed: null\ninitializer_range: 0.0125\ninit_weights: false\nscale_attn_by_inverse_layer_idx: false\nd_vocab_out: 32000\nparallel_attn_mlp: false\nrotary_dim: 128\nn_params: 6476005376\nuse_hook_tokens: false\ntokenizer_prepends_bos: null\npost_embedding_ln: false\nrotary_base: 1000000\nrotary_base_local: null\nrotary_scaling_factor: 1.0\ntrust_remote_code: false\nrotary_adjacent_pairs: false\nload_in_4bit: false\nrelative_attention_max_distance: null\nrelative_attention_num_buckets: null\ndecoder_start_token_id: null\ntie_word_embeddings: false\nuse_normalization_before_and_after: false\nattn_scores_soft_cap: -1.0\noutput_logits_soft_cap: -1.0\nuse_NTK_by_parts_rope: false\nNTK_by_parts_low_freq_factor: 1.0\nNTK_by_parts_high_freq_factor: 4.0\nNTK_by_parts_factor: 8.0\nNTK_original_ctx_len: 8192\nuse_yarn_rope: false\nyarn_factor: 1.0\nyarn_attention_factor: 1.0\nyarn_beta_fast: 32.0\nyarn_beta_slow: 1.0\nyarn_original_max_position_embeddings: 4096\nnorm_topk_prob: false\n","tokenizer.name":"codellama\/CodeLlama-7b-Python-hf","tokenizer.vocab_size":32000.0,"tokenizer.max_len":null,"tokenizer.class":"TokenizersBackend","tokenizer.vocab_hash":"e3A7wYziNQPAWcJ15GMAQY8qZqw=","tensor_shapes.state_dict":"embed:\n W_E: (32000, 4096)\nblocks:\n '[0-31]':\n ln1:\n w: (4096,)\n ln2:\n w: (4096,)\n attn:\n '[W_Q, W_K, W_V]': (32, 4096, 128)\n W_O: (32, 128, 4096)\n '[b_Q, b_K, b_V]': (32, 128)\n b_O: (4096,)\n mask: (4096, 4096)\n IGNORE: ()\n '[rotary_sin, rotary_cos]': (4096, 128)\n mlp:\n '[W_in, W_gate]': (4096, 11008)\n W_out: (11008, 4096)\n b_in: (11008,)\n b_out: (4096,)\nln_final:\n w: (4096,)\nunembed:\n W_U: (4096, 32000)\n b_U: (32000,)\n","tensor_shapes.state_dict.raw__":{"embed":{"W_E":"(32000, 4096)"},"blocks":{"[0-31]":{"ln1":{"w":"(4096,)"},"ln2":{"w":"(4096,)"},"attn":{"[W_Q, W_K, W_V]":"(32, 4096, 128)","W_O":"(32, 128, 4096)","[b_Q, b_K, b_V]":"(32, 128)","b_O":"(4096,)","mask":"(4096, 4096)","IGNORE":"()","[rotary_sin, rotary_cos]":"(4096, 128)"},"mlp":{"[W_in, W_gate]":"(4096, 11008)","W_out":"(11008, 4096)","b_in":"(11008,)","b_out":"(4096,)"}}},"ln_final":{"w":"(4096,)"},"unembed":{"W_U":"(4096, 32000)","b_U":"(32000,)"}},"tensor_shapes.activation_cache":"blocks:\n '[0-31]':\n ln1:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 4096)\n attn:\n '[hook_q, hook_k, hook_v, hook_rot_q, hook_rot_k, hook_z]': (batch, seq_len,\n 32, 128)\n '[hook_attn_scores, hook_pattern]': (batch, 32, seq_len, seq_len)\n ln2:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 4096)\n mlp:\n '[hook_pre, hook_pre_linear, hook_post]': (batch, seq_len, 11008)\n '[hook_resid_pre, hook_attn_out, hook_resid_mid, hook_mlp_out, hook_resid_post]': (batch,\n seq_len, 4096)\nln_final:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 4096)\nunembed:\n hook_in: (batch, seq_len, 4096)\n hook_out: (batch, seq_len, 32000)\nhook_embed: (batch, seq_len, 4096)\n","tensor_shapes.activation_cache.raw__":{"blocks":{"[0-31]":{"ln1":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 4096)"},"attn":{"[hook_q, hook_k, hook_v, hook_rot_q, hook_rot_k, hook_z]":"(batch, seq_len, 32, 128)","[hook_attn_scores, hook_pattern]":"(batch, 32, seq_len, seq_len)"},"ln2":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 4096)"},"mlp":{"[hook_pre, hook_pre_linear, hook_post]":"(batch, seq_len, 11008)"},"[hook_resid_pre, hook_attn_out, hook_resid_mid, hook_mlp_out, hook_resid_post]":"(batch, seq_len, 4096)"}},"ln_final":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 4096)"},"unembed":{"hook_in":"(batch, seq_len, 4096)","hook_out":"(batch, seq_len, 32000)"},"hook_embed":"(batch, seq_len, 4096)"}} -{"name.default_alias":"distillgpt2","name.huggingface":"distilgpt2","name.aliases":"distillgpt2, distill-gpt2, distil-gpt2, gpt2-xs","model_type":"gpt2","name.from_cfg":"distilgpt2","n_params.as_str":"42M","n_params.as_int":42467328,"n_params.from_name":null,"cfg.n_params":42467328,"cfg.n_layers":6,"cfg.n_heads":12,"cfg.d_model":768,"cfg.d_vocab":50257,"cfg.act_fn":"gelu_new","cfg.positional_embedding_type":"standard","cfg.parallel_attn_mlp":false,"cfg.original_architecture":"GPT2LMHeadModel","cfg.normalization_type":"LN","config.raw__":{"d_model":768,"d_head":64,"n_layers":6,"n_ctx":1024,"n_heads":12,"d_mlp":3072,"d_vocab":50257,"device":"cpu","use_attn_result":false,"use_split_qkv_input":false,"default_prepend_bos":true,"positional_embedding_type":"standard","n_key_value_heads":null,"attn_only":false,"gated_mlp":false,"uses_rms_norm":false,"eps":0.00001,"layer_norm_folding":false,"act_fn":"gelu_new","normalization_type":"LN","num_experts":null,"experts_per_token":null,"final_rms":false,"dtype":"torch.float32","model_name":"distilgpt2","use_attn_scale":true,"attn_scale":8.0,"use_hook_mlp_in":false,"use_attn_in":false,"use_qk_norm":false,"use_local_attn":false,"ungroup_grouped_query_attention":false,"original_architecture":"GPT2LMHeadModel","from_checkpoint":false,"checkpoint_index":null,"checkpoint_label_type":null,"checkpoint_value":null,"tokenizer_name":"distilgpt2","window_size":null,"attn_types":null,"init_mode":"gpt2","n_devices":1,"attention_dir":"causal","seed":null,"initializer_range":0.0288675135,"init_weights":false,"scale_attn_by_inverse_layer_idx":false,"d_vocab_out":50257,"parallel_attn_mlp":false,"rotary_dim":null,"n_params":42467328,"use_hook_tokens":false,"tokenizer_prepends_bos":null,"post_embedding_ln":false,"rotary_base":10000,"rotary_base_local":null,"rotary_scaling_factor":1.0,"trust_remote_code":false,"rotary_adjacent_pairs":false,"load_in_4bit":false,"relative_attention_max_distance":null,"relative_attention_num_buckets":null,"decoder_start_token_id":null,"tie_word_embeddings":false,"use_normalization_before_and_after":false,"attn_scores_soft_cap":-1.0,"output_logits_soft_cap":-1.0,"use_NTK_by_parts_rope":false,"NTK_by_parts_low_freq_factor":1.0,"NTK_by_parts_high_freq_factor":4.0,"NTK_by_parts_factor":8.0,"NTK_original_ctx_len":8192,"use_yarn_rope":false,"yarn_factor":1.0,"yarn_attention_factor":1.0,"yarn_beta_fast":32.0,"yarn_beta_slow":1.0,"yarn_original_max_position_embeddings":4096,"norm_topk_prob":false},"config":"d_model: 768\nd_head: 64\nn_layers: 6\nn_ctx: 1024\nn_heads: 12\nd_mlp: 3072\nd_vocab: 50257\ndevice: cpu\nuse_attn_result: false\nuse_split_qkv_input: false\ndefault_prepend_bos: true\npositional_embedding_type: standard\nn_key_value_heads: null\nattn_only: false\ngated_mlp: false\nuses_rms_norm: false\neps: 1.0e-05\nlayer_norm_folding: false\nact_fn: gelu_new\nnormalization_type: LN\nnum_experts: null\nexperts_per_token: null\nfinal_rms: false\ndtype: torch.float32\nmodel_name: distilgpt2\nuse_attn_scale: true\nattn_scale: !!python\/object\/apply:numpy.core.multiarray.scalar\n- !!python\/object\/apply:numpy.dtype\n args:\n - f8\n - false\n - true\n state: !!python\/tuple\n - 3\n - <\n - null\n - null\n - null\n - -1\n - -1\n - 0\n- !!binary |\n AAAAAAAAIEA=\nuse_hook_mlp_in: false\nuse_attn_in: false\nuse_qk_norm: false\nuse_local_attn: false\nungroup_grouped_query_attention: false\noriginal_architecture: GPT2LMHeadModel\nfrom_checkpoint: false\ncheckpoint_index: null\ncheckpoint_label_type: null\ncheckpoint_value: null\ntokenizer_name: distilgpt2\nwindow_size: null\nattn_types: null\ninit_mode: gpt2\nn_devices: 1\nattention_dir: causal\nseed: null\ninitializer_range: 0.02886751345948129\ninit_weights: false\nscale_attn_by_inverse_layer_idx: false\nd_vocab_out: 50257\nparallel_attn_mlp: false\nrotary_dim: null\nn_params: 42467328\nuse_hook_tokens: false\ntokenizer_prepends_bos: null\npost_embedding_ln: false\nrotary_base: 10000\nrotary_base_local: null\nrotary_scaling_factor: 1.0\ntrust_remote_code: false\nrotary_adjacent_pairs: false\nload_in_4bit: false\nrelative_attention_max_distance: null\nrelative_attention_num_buckets: null\ndecoder_start_token_id: null\ntie_word_embeddings: false\nuse_normalization_before_and_after: false\nattn_scores_soft_cap: -1.0\noutput_logits_soft_cap: -1.0\nuse_NTK_by_parts_rope: false\nNTK_by_parts_low_freq_factor: 1.0\nNTK_by_parts_high_freq_factor: 4.0\nNTK_by_parts_factor: 8.0\nNTK_original_ctx_len: 8192\nuse_yarn_rope: false\nyarn_factor: 1.0\nyarn_attention_factor: 1.0\nyarn_beta_fast: 32.0\nyarn_beta_slow: 1.0\nyarn_original_max_position_embeddings: 4096\nnorm_topk_prob: false\n","tokenizer.name":null,"tokenizer.vocab_size":null,"tokenizer.max_len":null,"tokenizer.class":null,"tokenizer.vocab_hash":null,"tensor_shapes.state_dict":null,"tensor_shapes.state_dict.raw__":null,"tensor_shapes.activation_cache":null,"tensor_shapes.activation_cache.raw__":null} +{"name.default_alias":"distillgpt2","name.huggingface":"distilgpt2","name.aliases":"distillgpt2, distill-gpt2, distil-gpt2, gpt2-xs","model_type":"gpt2","name.from_cfg":"distilgpt2","n_params.as_str":"42M","n_params.as_int":42467328,"n_params.from_name":null,"cfg.n_params":42467328,"cfg.n_layers":6,"cfg.n_heads":12,"cfg.d_model":768,"cfg.d_vocab":50257,"cfg.act_fn":"gelu_new","cfg.positional_embedding_type":"standard","cfg.parallel_attn_mlp":false,"cfg.original_architecture":"GPT2LMHeadModel","cfg.normalization_type":"LN","config.raw__":{"d_model":768,"d_head":64,"n_layers":6,"n_ctx":1024,"n_heads":12,"d_mlp":3072,"d_vocab":50257,"device":"cpu","use_attn_result":false,"use_split_qkv_input":false,"default_prepend_bos":true,"positional_embedding_type":"standard","n_key_value_heads":null,"attn_only":false,"gated_mlp":false,"uses_rms_norm":false,"eps":0.00001,"layer_norm_folding":false,"act_fn":"gelu_new","normalization_type":"LN","num_experts":null,"experts_per_token":null,"final_rms":false,"dtype":"torch.float32","model_name":"distilgpt2","use_attn_scale":true,"attn_scale":8.0,"use_hook_mlp_in":false,"use_attn_in":false,"use_qk_norm":false,"use_local_attn":false,"ungroup_grouped_query_attention":false,"original_architecture":"GPT2LMHeadModel","from_checkpoint":false,"checkpoint_index":null,"checkpoint_label_type":null,"checkpoint_value":null,"tokenizer_name":"distilgpt2","window_size":null,"attn_types":null,"init_mode":"gpt2","n_devices":1,"attention_dir":"causal","seed":null,"initializer_range":0.0288675135,"init_weights":false,"scale_attn_by_inverse_layer_idx":false,"d_vocab_out":50257,"parallel_attn_mlp":false,"rotary_dim":null,"n_params":42467328,"use_hook_tokens":false,"tokenizer_prepends_bos":null,"post_embedding_ln":false,"rotary_base":10000,"rotary_base_local":null,"rotary_scaling_factor":1.0,"trust_remote_code":false,"rotary_adjacent_pairs":false,"load_in_4bit":false,"relative_attention_max_distance":null,"relative_attention_num_buckets":null,"decoder_start_token_id":null,"tie_word_embeddings":false,"use_normalization_before_and_after":false,"attn_scores_soft_cap":-1.0,"output_logits_soft_cap":-1.0,"use_NTK_by_parts_rope":false,"NTK_by_parts_low_freq_factor":1.0,"NTK_by_parts_high_freq_factor":4.0,"NTK_by_parts_factor":8.0,"NTK_original_ctx_len":8192,"use_yarn_rope":false,"yarn_factor":1.0,"yarn_attention_factor":1.0,"yarn_beta_fast":32.0,"yarn_beta_slow":1.0,"yarn_original_max_position_embeddings":4096,"norm_topk_prob":false},"config":"d_model: 768\nd_head: 64\nn_layers: 6\nn_ctx: 1024\nn_heads: 12\nd_mlp: 3072\nd_vocab: 50257\ndevice: cpu\nuse_attn_result: false\nuse_split_qkv_input: false\ndefault_prepend_bos: true\npositional_embedding_type: standard\nn_key_value_heads: null\nattn_only: false\ngated_mlp: false\nuses_rms_norm: false\neps: 1.0e-05\nlayer_norm_folding: false\nact_fn: gelu_new\nnormalization_type: LN\nnum_experts: null\nexperts_per_token: null\nfinal_rms: false\ndtype: torch.float32\nmodel_name: distilgpt2\nuse_attn_scale: true\nattn_scale: !!python\/object\/apply:numpy.core.multiarray.scalar\n- !!python\/object\/apply:numpy.dtype\n args:\n - f8\n - false\n - true\n state: !!python\/tuple\n - 3\n - <\n - null\n - null\n - null\n - -1\n - -1\n - 0\n- !!binary |\n AAAAAAAAIEA=\nuse_hook_mlp_in: false\nuse_attn_in: false\nuse_qk_norm: false\nuse_local_attn: false\nungroup_grouped_query_attention: false\noriginal_architecture: GPT2LMHeadModel\nfrom_checkpoint: false\ncheckpoint_index: null\ncheckpoint_label_type: null\ncheckpoint_value: null\ntokenizer_name: distilgpt2\nwindow_size: null\nattn_types: null\ninit_mode: gpt2\nn_devices: 1\nattention_dir: causal\nseed: null\ninitializer_range: 0.02886751345948129\ninit_weights: false\nscale_attn_by_inverse_layer_idx: false\nd_vocab_out: 50257\nparallel_attn_mlp: false\nrotary_dim: null\nn_params: 42467328\nuse_hook_tokens: false\ntokenizer_prepends_bos: null\npost_embedding_ln: false\nrotary_base: 10000\nrotary_base_local: null\nrotary_scaling_factor: 1.0\ntrust_remote_code: false\nrotary_adjacent_pairs: false\nload_in_4bit: false\nrelative_attention_max_distance: null\nrelative_attention_num_buckets: null\ndecoder_start_token_id: null\ntie_word_embeddings: false\nuse_normalization_before_and_after: false\nattn_scores_soft_cap: -1.0\noutput_logits_soft_cap: -1.0\nuse_NTK_by_parts_rope: false\nNTK_by_parts_low_freq_factor: 1.0\nNTK_by_parts_high_freq_factor: 4.0\nNTK_by_parts_factor: 8.0\nNTK_original_ctx_len: 8192\nuse_yarn_rope: false\nyarn_factor: 1.0\nyarn_attention_factor: 1.0\nyarn_beta_fast: 32.0\nyarn_beta_slow: 1.0\nyarn_original_max_position_embeddings: 4096\nnorm_topk_prob: false\n","tokenizer.name":"distilgpt2","tokenizer.vocab_size":50257.0,"tokenizer.max_len":1024.0,"tokenizer.class":"GPT2Tokenizer","tokenizer.vocab_hash":"v8xfIj5kwZX5RwgLU66lZNZUlE4=","tensor_shapes.state_dict":"embed:\n W_E: (50257, 768)\npos_embed:\n W_pos: (1024, 768)\nblocks:\n '[0-5]':\n ln1:\n '[w, b]': (768,)\n ln2:\n '[w, b]': (768,)\n attn:\n '[W_Q, W_K, W_V]': (12, 768, 64)\n W_O: (12, 64, 768)\n '[b_Q, b_K, b_V]': (12, 64)\n b_O: (768,)\n mask: (1024, 1024)\n IGNORE: ()\n mlp:\n W_in: (768, 3072)\n b_in: (3072,)\n W_out: (3072, 768)\n b_out: (768,)\nln_final:\n '[w, b]': (768,)\nunembed:\n W_U: (768, 50257)\n b_U: (50257,)\n","tensor_shapes.state_dict.raw__":{"embed":{"W_E":"(50257, 768)"},"pos_embed":{"W_pos":"(1024, 768)"},"blocks":{"[0-5]":{"ln1":{"[w, b]":"(768,)"},"ln2":{"[w, b]":"(768,)"},"attn":{"[W_Q, W_K, W_V]":"(12, 768, 64)","W_O":"(12, 64, 768)","[b_Q, b_K, b_V]":"(12, 64)","b_O":"(768,)","mask":"(1024, 1024)","IGNORE":"()"},"mlp":{"W_in":"(768, 3072)","b_in":"(3072,)","W_out":"(3072, 768)","b_out":"(768,)"}}},"ln_final":{"[w, b]":"(768,)"},"unembed":{"W_U":"(768, 50257)","b_U":"(50257,)"}},"tensor_shapes.activation_cache":"blocks:\n '[0-5]':\n ln1:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 768)\n attn:\n '[hook_q, hook_k, hook_v, hook_z]': (batch, seq_len, 12, 64)\n '[hook_attn_scores, hook_pattern]': (batch, 12, seq_len, seq_len)\n ln2:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 768)\n mlp:\n '[hook_pre, hook_post]': (batch, seq_len, 3072)\n '[hook_resid_pre, hook_attn_out, hook_resid_mid, hook_mlp_out, hook_resid_post]': (batch,\n seq_len, 768)\nln_final:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 768)\nunembed:\n hook_in: (batch, seq_len, 768)\n hook_out: (batch, seq_len, 50257)\n'[hook_embed, hook_pos_embed]': (batch, seq_len, 768)\n","tensor_shapes.activation_cache.raw__":{"blocks":{"[0-5]":{"ln1":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 768)"},"attn":{"[hook_q, hook_k, hook_v, hook_z]":"(batch, seq_len, 12, 64)","[hook_attn_scores, hook_pattern]":"(batch, 12, seq_len, seq_len)"},"ln2":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 768)"},"mlp":{"[hook_pre, hook_post]":"(batch, seq_len, 3072)"},"[hook_resid_pre, hook_attn_out, hook_resid_mid, hook_mlp_out, hook_resid_post]":"(batch, seq_len, 768)"}},"ln_final":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 768)"},"unembed":{"hook_in":"(batch, seq_len, 768)","hook_out":"(batch, seq_len, 50257)"},"[hook_embed, hook_pos_embed]":"(batch, seq_len, 768)"}} {"name.default_alias":"gpt-j-6B","name.huggingface":null,"name.aliases":"","model_type":"gpt-j","name.from_cfg":"gpt-j-6B","n_params.as_str":"5.6B","n_params.as_int":5637144576,"n_params.from_name":"6B","cfg.n_params":5637144576,"cfg.n_layers":28,"cfg.n_heads":16,"cfg.d_model":4096,"cfg.d_vocab":50400,"cfg.act_fn":"gelu_new","cfg.positional_embedding_type":"rotary","cfg.parallel_attn_mlp":true,"cfg.original_architecture":"GPTJForCausalLM","cfg.normalization_type":"LN","config.raw__":{"d_model":4096,"d_head":256,"n_layers":28,"n_ctx":2048,"n_heads":16,"d_mlp":16384,"d_vocab":50400,"device":"cpu","use_attn_result":false,"use_split_qkv_input":false,"default_prepend_bos":true,"positional_embedding_type":"rotary","n_key_value_heads":null,"attn_only":false,"gated_mlp":false,"uses_rms_norm":false,"eps":0.00001,"layer_norm_folding":false,"act_fn":"gelu_new","normalization_type":"LN","num_experts":null,"experts_per_token":null,"final_rms":false,"dtype":"torch.float32","model_name":"gpt-j-6B","use_attn_scale":true,"attn_scale":16.0,"use_hook_mlp_in":false,"use_attn_in":false,"use_qk_norm":false,"use_local_attn":false,"ungroup_grouped_query_attention":false,"original_architecture":"GPTJForCausalLM","from_checkpoint":false,"checkpoint_index":null,"checkpoint_label_type":null,"checkpoint_value":null,"tokenizer_name":"EleutherAI\/gpt-j-6B","window_size":null,"attn_types":null,"init_mode":"gpt2","n_devices":1,"attention_dir":"causal","seed":null,"initializer_range":0.0125,"init_weights":false,"scale_attn_by_inverse_layer_idx":false,"d_vocab_out":50400,"parallel_attn_mlp":true,"rotary_dim":64,"n_params":5637144576,"use_hook_tokens":false,"tokenizer_prepends_bos":null,"post_embedding_ln":false,"rotary_base":10000,"rotary_base_local":null,"rotary_scaling_factor":1.0,"trust_remote_code":false,"rotary_adjacent_pairs":true,"load_in_4bit":false,"relative_attention_max_distance":null,"relative_attention_num_buckets":null,"decoder_start_token_id":null,"tie_word_embeddings":false,"use_normalization_before_and_after":false,"attn_scores_soft_cap":-1.0,"output_logits_soft_cap":-1.0,"use_NTK_by_parts_rope":false,"NTK_by_parts_low_freq_factor":1.0,"NTK_by_parts_high_freq_factor":4.0,"NTK_by_parts_factor":8.0,"NTK_original_ctx_len":8192,"use_yarn_rope":false,"yarn_factor":1.0,"yarn_attention_factor":1.0,"yarn_beta_fast":32.0,"yarn_beta_slow":1.0,"yarn_original_max_position_embeddings":4096,"norm_topk_prob":false},"config":"d_model: 4096\nd_head: 256\nn_layers: 28\nn_ctx: 2048\nn_heads: 16\nd_mlp: 16384\nd_vocab: 50400\ndevice: cpu\nuse_attn_result: false\nuse_split_qkv_input: false\ndefault_prepend_bos: true\npositional_embedding_type: rotary\nn_key_value_heads: null\nattn_only: false\ngated_mlp: false\nuses_rms_norm: false\neps: 1.0e-05\nlayer_norm_folding: false\nact_fn: gelu_new\nnormalization_type: LN\nnum_experts: null\nexperts_per_token: null\nfinal_rms: false\ndtype: torch.float32\nmodel_name: gpt-j-6B\nuse_attn_scale: true\nattn_scale: !!python\/object\/apply:numpy.core.multiarray.scalar\n- !!python\/object\/apply:numpy.dtype\n args:\n - f8\n - false\n - true\n state: !!python\/tuple\n - 3\n - <\n - null\n - null\n - null\n - -1\n - -1\n - 0\n- !!binary |\n AAAAAAAAMEA=\nuse_hook_mlp_in: false\nuse_attn_in: false\nuse_qk_norm: false\nuse_local_attn: false\nungroup_grouped_query_attention: false\noriginal_architecture: GPTJForCausalLM\nfrom_checkpoint: false\ncheckpoint_index: null\ncheckpoint_label_type: null\ncheckpoint_value: null\ntokenizer_name: EleutherAI\/gpt-j-6B\nwindow_size: null\nattn_types: null\ninit_mode: gpt2\nn_devices: 1\nattention_dir: causal\nseed: null\ninitializer_range: 0.0125\ninit_weights: false\nscale_attn_by_inverse_layer_idx: false\nd_vocab_out: 50400\nparallel_attn_mlp: true\nrotary_dim: 64\nn_params: 5637144576\nuse_hook_tokens: false\ntokenizer_prepends_bos: null\npost_embedding_ln: false\nrotary_base: 10000\nrotary_base_local: null\nrotary_scaling_factor: 1.0\ntrust_remote_code: false\nrotary_adjacent_pairs: true\nload_in_4bit: false\nrelative_attention_max_distance: null\nrelative_attention_num_buckets: null\ndecoder_start_token_id: null\ntie_word_embeddings: false\nuse_normalization_before_and_after: false\nattn_scores_soft_cap: -1.0\noutput_logits_soft_cap: -1.0\nuse_NTK_by_parts_rope: false\nNTK_by_parts_low_freq_factor: 1.0\nNTK_by_parts_high_freq_factor: 4.0\nNTK_by_parts_factor: 8.0\nNTK_original_ctx_len: 8192\nuse_yarn_rope: false\nyarn_factor: 1.0\nyarn_attention_factor: 1.0\nyarn_beta_fast: 32.0\nyarn_beta_slow: 1.0\nyarn_original_max_position_embeddings: 4096\nnorm_topk_prob: false\n","tokenizer.name":"EleutherAI\/gpt-j-6B","tokenizer.vocab_size":50257.0,"tokenizer.max_len":2048.0,"tokenizer.class":"GPT2Tokenizer","tokenizer.vocab_hash":"aKfp-BCA9d3W27qknxFiS0DGC5s=","tensor_shapes.state_dict":"embed:\n W_E: (50400, 4096)\nblocks:\n '[0-27]':\n ln1:\n '[w, b]': (4096,)\n ln2:\n '[w, b]': (4096,)\n attn:\n '[W_Q, W_K, W_V]': (16, 4096, 256)\n W_O: (16, 256, 4096)\n '[b_Q, b_K, b_V]': (16, 256)\n b_O: (4096,)\n mask: (2048, 2048)\n IGNORE: ()\n '[rotary_sin, rotary_cos]': (2048, 64)\n mlp:\n W_in: (4096, 16384)\n b_in: (16384,)\n W_out: (16384, 4096)\n b_out: (4096,)\nln_final:\n '[w, b]': (4096,)\nunembed:\n W_U: (4096, 50400)\n b_U: (50400,)\n","tensor_shapes.state_dict.raw__":{"embed":{"W_E":"(50400, 4096)"},"blocks":{"[0-27]":{"ln1":{"[w, b]":"(4096,)"},"ln2":{"[w, b]":"(4096,)"},"attn":{"[W_Q, W_K, W_V]":"(16, 4096, 256)","W_O":"(16, 256, 4096)","[b_Q, b_K, b_V]":"(16, 256)","b_O":"(4096,)","mask":"(2048, 2048)","IGNORE":"()","[rotary_sin, rotary_cos]":"(2048, 64)"},"mlp":{"W_in":"(4096, 16384)","b_in":"(16384,)","W_out":"(16384, 4096)","b_out":"(4096,)"}}},"ln_final":{"[w, b]":"(4096,)"},"unembed":{"W_U":"(4096, 50400)","b_U":"(50400,)"}},"tensor_shapes.activation_cache":"blocks:\n '[0-27]':\n ln1:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 4096)\n attn:\n '[hook_q, hook_k, hook_v, hook_rot_q, hook_rot_k, hook_z]': (batch, seq_len,\n 16, 256)\n '[hook_attn_scores, hook_pattern]': (batch, 16, seq_len, seq_len)\n ln2:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 4096)\n mlp:\n '[hook_pre, hook_post]': (batch, seq_len, 16384)\n '[hook_resid_pre, hook_attn_out, hook_mlp_out, hook_resid_post]': (batch, seq_len,\n 4096)\nln_final:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 4096)\nunembed:\n hook_in: (batch, seq_len, 4096)\n hook_out: (batch, seq_len, 50400)\nhook_embed: (batch, seq_len, 4096)\n","tensor_shapes.activation_cache.raw__":{"blocks":{"[0-27]":{"ln1":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 4096)"},"attn":{"[hook_q, hook_k, hook_v, hook_rot_q, hook_rot_k, hook_z]":"(batch, seq_len, 16, 256)","[hook_attn_scores, hook_pattern]":"(batch, 16, seq_len, seq_len)"},"ln2":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 4096)"},"mlp":{"[hook_pre, hook_post]":"(batch, seq_len, 16384)"},"[hook_resid_pre, hook_attn_out, hook_mlp_out, hook_resid_post]":"(batch, seq_len, 4096)"}},"ln_final":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 4096)"},"unembed":{"hook_in":"(batch, seq_len, 4096)","hook_out":"(batch, seq_len, 50400)"},"hook_embed":"(batch, seq_len, 4096)"}} {"name.default_alias":"gpt-neo-1.3B","name.huggingface":null,"name.aliases":"","model_type":"gpt-neo","name.from_cfg":"gpt-neo-1.3B","n_params.as_str":"1.2B","n_params.as_int":1207959552,"n_params.from_name":"1.3B","cfg.n_params":1207959552,"cfg.n_layers":24,"cfg.n_heads":16,"cfg.d_model":2048,"cfg.d_vocab":50257,"cfg.act_fn":"gelu_new","cfg.positional_embedding_type":"standard","cfg.parallel_attn_mlp":false,"cfg.original_architecture":"GPTNeoForCausalLM","cfg.normalization_type":"LN","config.raw__":{"d_model":2048,"d_head":128,"n_layers":24,"n_ctx":2048,"n_heads":16,"d_mlp":8192,"d_vocab":50257,"device":"cpu","use_attn_result":false,"use_split_qkv_input":false,"default_prepend_bos":true,"positional_embedding_type":"standard","n_key_value_heads":null,"attn_only":false,"gated_mlp":false,"uses_rms_norm":false,"eps":0.00001,"layer_norm_folding":false,"act_fn":"gelu_new","normalization_type":"LN","num_experts":null,"experts_per_token":null,"final_rms":false,"dtype":"torch.float32","model_name":"gpt-neo-1.3B","use_attn_scale":false,"attn_scale":-1.0,"use_hook_mlp_in":false,"use_attn_in":false,"use_qk_norm":false,"use_local_attn":true,"ungroup_grouped_query_attention":false,"original_architecture":"GPTNeoForCausalLM","from_checkpoint":false,"checkpoint_index":null,"checkpoint_label_type":null,"checkpoint_value":null,"tokenizer_name":"EleutherAI\/gpt-neo-1.3B","window_size":256,"attn_types":["global","local","global","local","global","local","global","local","global","local","global","local","global","local","global","local","global","local","global","local","global","local","global","local"],"init_mode":"gpt2","n_devices":1,"attention_dir":"causal","seed":null,"initializer_range":0.0176776695,"init_weights":false,"scale_attn_by_inverse_layer_idx":false,"d_vocab_out":50257,"parallel_attn_mlp":false,"rotary_dim":null,"n_params":1207959552,"use_hook_tokens":false,"tokenizer_prepends_bos":null,"post_embedding_ln":false,"rotary_base":10000,"rotary_base_local":null,"rotary_scaling_factor":1.0,"trust_remote_code":false,"rotary_adjacent_pairs":false,"load_in_4bit":false,"relative_attention_max_distance":null,"relative_attention_num_buckets":null,"decoder_start_token_id":null,"tie_word_embeddings":false,"use_normalization_before_and_after":false,"attn_scores_soft_cap":-1.0,"output_logits_soft_cap":-1.0,"use_NTK_by_parts_rope":false,"NTK_by_parts_low_freq_factor":1.0,"NTK_by_parts_high_freq_factor":4.0,"NTK_by_parts_factor":8.0,"NTK_original_ctx_len":8192,"use_yarn_rope":false,"yarn_factor":1.0,"yarn_attention_factor":1.0,"yarn_beta_fast":32.0,"yarn_beta_slow":1.0,"yarn_original_max_position_embeddings":4096,"norm_topk_prob":false},"config":"d_model: 2048\nd_head: 128\nn_layers: 24\nn_ctx: 2048\nn_heads: 16\nd_mlp: 8192\nd_vocab: 50257\ndevice: cpu\nuse_attn_result: false\nuse_split_qkv_input: false\ndefault_prepend_bos: true\npositional_embedding_type: standard\nn_key_value_heads: null\nattn_only: false\ngated_mlp: false\nuses_rms_norm: false\neps: 1.0e-05\nlayer_norm_folding: false\nact_fn: gelu_new\nnormalization_type: LN\nnum_experts: null\nexperts_per_token: null\nfinal_rms: false\ndtype: torch.float32\nmodel_name: gpt-neo-1.3B\nuse_attn_scale: false\nattn_scale: -1.0\nuse_hook_mlp_in: false\nuse_attn_in: false\nuse_qk_norm: false\nuse_local_attn: true\nungroup_grouped_query_attention: false\noriginal_architecture: GPTNeoForCausalLM\nfrom_checkpoint: false\ncheckpoint_index: null\ncheckpoint_label_type: null\ncheckpoint_value: null\ntokenizer_name: EleutherAI\/gpt-neo-1.3B\nwindow_size: 256\nattn_types:\n- global\n- local\n- global\n- local\n- global\n- local\n- global\n- local\n- global\n- local\n- global\n- local\n- global\n- local\n- global\n- local\n- global\n- local\n- global\n- local\n- global\n- local\n- global\n- local\ninit_mode: gpt2\nn_devices: 1\nattention_dir: causal\nseed: null\ninitializer_range: 0.017677669529663688\ninit_weights: false\nscale_attn_by_inverse_layer_idx: false\nd_vocab_out: 50257\nparallel_attn_mlp: false\nrotary_dim: null\nn_params: 1207959552\nuse_hook_tokens: false\ntokenizer_prepends_bos: null\npost_embedding_ln: false\nrotary_base: 10000\nrotary_base_local: null\nrotary_scaling_factor: 1.0\ntrust_remote_code: false\nrotary_adjacent_pairs: false\nload_in_4bit: false\nrelative_attention_max_distance: null\nrelative_attention_num_buckets: null\ndecoder_start_token_id: null\ntie_word_embeddings: false\nuse_normalization_before_and_after: false\nattn_scores_soft_cap: -1.0\noutput_logits_soft_cap: -1.0\nuse_NTK_by_parts_rope: false\nNTK_by_parts_low_freq_factor: 1.0\nNTK_by_parts_high_freq_factor: 4.0\nNTK_by_parts_factor: 8.0\nNTK_original_ctx_len: 8192\nuse_yarn_rope: false\nyarn_factor: 1.0\nyarn_attention_factor: 1.0\nyarn_beta_fast: 32.0\nyarn_beta_slow: 1.0\nyarn_original_max_position_embeddings: 4096\nnorm_topk_prob: false\n","tokenizer.name":"EleutherAI\/gpt-neo-1.3B","tokenizer.vocab_size":50257.0,"tokenizer.max_len":2048.0,"tokenizer.class":"GPT2Tokenizer","tokenizer.vocab_hash":"v8xfIj5kwZX5RwgLU66lZNZUlE4=","tensor_shapes.state_dict":"embed:\n W_E: (50257, 2048)\npos_embed:\n W_pos: (2048, 2048)\nblocks:\n '[0-23]':\n ln1:\n '[w, b]': (2048,)\n ln2:\n '[w, b]': (2048,)\n attn:\n '[W_Q, W_K, W_V]': (16, 2048, 128)\n W_O: (16, 128, 2048)\n '[b_Q, b_K, b_V]': (16, 128)\n b_O: (2048,)\n mask: (2048, 2048)\n IGNORE: ()\n mlp:\n W_in: (2048, 8192)\n b_in: (8192,)\n W_out: (8192, 2048)\n b_out: (2048,)\nln_final:\n '[w, b]': (2048,)\nunembed:\n W_U: (2048, 50257)\n b_U: (50257,)\n","tensor_shapes.state_dict.raw__":{"embed":{"W_E":"(50257, 2048)"},"pos_embed":{"W_pos":"(2048, 2048)"},"blocks":{"[0-23]":{"ln1":{"[w, b]":"(2048,)"},"ln2":{"[w, b]":"(2048,)"},"attn":{"[W_Q, W_K, W_V]":"(16, 2048, 128)","W_O":"(16, 128, 2048)","[b_Q, b_K, b_V]":"(16, 128)","b_O":"(2048,)","mask":"(2048, 2048)","IGNORE":"()"},"mlp":{"W_in":"(2048, 8192)","b_in":"(8192,)","W_out":"(8192, 2048)","b_out":"(2048,)"}}},"ln_final":{"[w, b]":"(2048,)"},"unembed":{"W_U":"(2048, 50257)","b_U":"(50257,)"}},"tensor_shapes.activation_cache":"blocks:\n '[0-23]':\n ln1:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 2048)\n attn:\n '[hook_q, hook_k, hook_v, hook_z]': (batch, seq_len, 16, 128)\n '[hook_attn_scores, hook_pattern]': (batch, 16, seq_len, seq_len)\n ln2:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 2048)\n mlp:\n '[hook_pre, hook_post]': (batch, seq_len, 8192)\n '[hook_resid_pre, hook_attn_out, hook_resid_mid, hook_mlp_out, hook_resid_post]': (batch,\n seq_len, 2048)\nln_final:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 2048)\nunembed:\n hook_in: (batch, seq_len, 2048)\n hook_out: (batch, seq_len, 50257)\n'[hook_embed, hook_pos_embed]': (batch, seq_len, 2048)\n","tensor_shapes.activation_cache.raw__":{"blocks":{"[0-23]":{"ln1":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 2048)"},"attn":{"[hook_q, hook_k, hook_v, hook_z]":"(batch, seq_len, 16, 128)","[hook_attn_scores, hook_pattern]":"(batch, 16, seq_len, seq_len)"},"ln2":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 2048)"},"mlp":{"[hook_pre, hook_post]":"(batch, seq_len, 8192)"},"[hook_resid_pre, hook_attn_out, hook_resid_mid, hook_mlp_out, hook_resid_post]":"(batch, seq_len, 2048)"}},"ln_final":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 2048)"},"unembed":{"hook_in":"(batch, seq_len, 2048)","hook_out":"(batch, seq_len, 50257)"},"[hook_embed, hook_pos_embed]":"(batch, seq_len, 2048)"}} {"name.default_alias":"gpt-neo-125M","name.huggingface":null,"name.aliases":"","model_type":"gpt-neo","name.from_cfg":"gpt-neo-125M","n_params.as_str":"85M","n_params.as_int":84934656,"n_params.from_name":"125M","cfg.n_params":84934656,"cfg.n_layers":12,"cfg.n_heads":12,"cfg.d_model":768,"cfg.d_vocab":50257,"cfg.act_fn":"gelu_new","cfg.positional_embedding_type":"standard","cfg.parallel_attn_mlp":false,"cfg.original_architecture":"GPTNeoForCausalLM","cfg.normalization_type":"LN","config.raw__":{"d_model":768,"d_head":64,"n_layers":12,"n_ctx":2048,"n_heads":12,"d_mlp":3072,"d_vocab":50257,"device":"cpu","use_attn_result":false,"use_split_qkv_input":false,"default_prepend_bos":true,"positional_embedding_type":"standard","n_key_value_heads":null,"attn_only":false,"gated_mlp":false,"uses_rms_norm":false,"eps":0.00001,"layer_norm_folding":false,"act_fn":"gelu_new","normalization_type":"LN","num_experts":null,"experts_per_token":null,"final_rms":false,"dtype":"torch.float32","model_name":"gpt-neo-125M","use_attn_scale":false,"attn_scale":-1.0,"use_hook_mlp_in":false,"use_attn_in":false,"use_qk_norm":false,"use_local_attn":true,"ungroup_grouped_query_attention":false,"original_architecture":"GPTNeoForCausalLM","from_checkpoint":false,"checkpoint_index":null,"checkpoint_label_type":null,"checkpoint_value":null,"tokenizer_name":"EleutherAI\/gpt-neo-125M","window_size":256,"attn_types":["global","local","global","local","global","local","global","local","global","local","global","local"],"init_mode":"gpt2","n_devices":1,"attention_dir":"causal","seed":null,"initializer_range":0.0288675135,"init_weights":false,"scale_attn_by_inverse_layer_idx":false,"d_vocab_out":50257,"parallel_attn_mlp":false,"rotary_dim":null,"n_params":84934656,"use_hook_tokens":false,"tokenizer_prepends_bos":null,"post_embedding_ln":false,"rotary_base":10000,"rotary_base_local":null,"rotary_scaling_factor":1.0,"trust_remote_code":false,"rotary_adjacent_pairs":false,"load_in_4bit":false,"relative_attention_max_distance":null,"relative_attention_num_buckets":null,"decoder_start_token_id":null,"tie_word_embeddings":false,"use_normalization_before_and_after":false,"attn_scores_soft_cap":-1.0,"output_logits_soft_cap":-1.0,"use_NTK_by_parts_rope":false,"NTK_by_parts_low_freq_factor":1.0,"NTK_by_parts_high_freq_factor":4.0,"NTK_by_parts_factor":8.0,"NTK_original_ctx_len":8192,"use_yarn_rope":false,"yarn_factor":1.0,"yarn_attention_factor":1.0,"yarn_beta_fast":32.0,"yarn_beta_slow":1.0,"yarn_original_max_position_embeddings":4096,"norm_topk_prob":false},"config":"d_model: 768\nd_head: 64\nn_layers: 12\nn_ctx: 2048\nn_heads: 12\nd_mlp: 3072\nd_vocab: 50257\ndevice: cpu\nuse_attn_result: false\nuse_split_qkv_input: false\ndefault_prepend_bos: true\npositional_embedding_type: standard\nn_key_value_heads: null\nattn_only: false\ngated_mlp: false\nuses_rms_norm: false\neps: 1.0e-05\nlayer_norm_folding: false\nact_fn: gelu_new\nnormalization_type: LN\nnum_experts: null\nexperts_per_token: null\nfinal_rms: false\ndtype: torch.float32\nmodel_name: gpt-neo-125M\nuse_attn_scale: false\nattn_scale: -1.0\nuse_hook_mlp_in: false\nuse_attn_in: false\nuse_qk_norm: false\nuse_local_attn: true\nungroup_grouped_query_attention: false\noriginal_architecture: GPTNeoForCausalLM\nfrom_checkpoint: false\ncheckpoint_index: null\ncheckpoint_label_type: null\ncheckpoint_value: null\ntokenizer_name: EleutherAI\/gpt-neo-125M\nwindow_size: 256\nattn_types:\n- global\n- local\n- global\n- local\n- global\n- local\n- global\n- local\n- global\n- local\n- global\n- local\ninit_mode: gpt2\nn_devices: 1\nattention_dir: causal\nseed: null\ninitializer_range: 0.02886751345948129\ninit_weights: false\nscale_attn_by_inverse_layer_idx: false\nd_vocab_out: 50257\nparallel_attn_mlp: false\nrotary_dim: null\nn_params: 84934656\nuse_hook_tokens: false\ntokenizer_prepends_bos: null\npost_embedding_ln: false\nrotary_base: 10000\nrotary_base_local: null\nrotary_scaling_factor: 1.0\ntrust_remote_code: false\nrotary_adjacent_pairs: false\nload_in_4bit: false\nrelative_attention_max_distance: null\nrelative_attention_num_buckets: null\ndecoder_start_token_id: null\ntie_word_embeddings: false\nuse_normalization_before_and_after: false\nattn_scores_soft_cap: -1.0\noutput_logits_soft_cap: -1.0\nuse_NTK_by_parts_rope: false\nNTK_by_parts_low_freq_factor: 1.0\nNTK_by_parts_high_freq_factor: 4.0\nNTK_by_parts_factor: 8.0\nNTK_original_ctx_len: 8192\nuse_yarn_rope: false\nyarn_factor: 1.0\nyarn_attention_factor: 1.0\nyarn_beta_fast: 32.0\nyarn_beta_slow: 1.0\nyarn_original_max_position_embeddings: 4096\nnorm_topk_prob: false\n","tokenizer.name":"EleutherAI\/gpt-neo-125M","tokenizer.vocab_size":50257.0,"tokenizer.max_len":2048.0,"tokenizer.class":"GPT2Tokenizer","tokenizer.vocab_hash":"v8xfIj5kwZX5RwgLU66lZNZUlE4=","tensor_shapes.state_dict":"embed:\n W_E: (50257, 768)\npos_embed:\n W_pos: (2048, 768)\nblocks:\n '[0-11]':\n ln1:\n '[w, b]': (768,)\n ln2:\n '[w, b]': (768,)\n attn:\n '[W_Q, W_K, W_V]': (12, 768, 64)\n W_O: (12, 64, 768)\n '[b_Q, b_K, b_V]': (12, 64)\n b_O: (768,)\n mask: (2048, 2048)\n IGNORE: ()\n mlp:\n W_in: (768, 3072)\n b_in: (3072,)\n W_out: (3072, 768)\n b_out: (768,)\nln_final:\n '[w, b]': (768,)\nunembed:\n W_U: (768, 50257)\n b_U: (50257,)\n","tensor_shapes.state_dict.raw__":{"embed":{"W_E":"(50257, 768)"},"pos_embed":{"W_pos":"(2048, 768)"},"blocks":{"[0-11]":{"ln1":{"[w, b]":"(768,)"},"ln2":{"[w, b]":"(768,)"},"attn":{"[W_Q, W_K, W_V]":"(12, 768, 64)","W_O":"(12, 64, 768)","[b_Q, b_K, b_V]":"(12, 64)","b_O":"(768,)","mask":"(2048, 2048)","IGNORE":"()"},"mlp":{"W_in":"(768, 3072)","b_in":"(3072,)","W_out":"(3072, 768)","b_out":"(768,)"}}},"ln_final":{"[w, b]":"(768,)"},"unembed":{"W_U":"(768, 50257)","b_U":"(50257,)"}},"tensor_shapes.activation_cache":"blocks:\n '[0-11]':\n ln1:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 768)\n attn:\n '[hook_q, hook_k, hook_v, hook_z]': (batch, seq_len, 12, 64)\n '[hook_attn_scores, hook_pattern]': (batch, 12, seq_len, seq_len)\n ln2:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 768)\n mlp:\n '[hook_pre, hook_post]': (batch, seq_len, 3072)\n '[hook_resid_pre, hook_attn_out, hook_resid_mid, hook_mlp_out, hook_resid_post]': (batch,\n seq_len, 768)\nln_final:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 768)\nunembed:\n hook_in: (batch, seq_len, 768)\n hook_out: (batch, seq_len, 50257)\n'[hook_embed, hook_pos_embed]': (batch, seq_len, 768)\n","tensor_shapes.activation_cache.raw__":{"blocks":{"[0-11]":{"ln1":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 768)"},"attn":{"[hook_q, hook_k, hook_v, hook_z]":"(batch, seq_len, 12, 64)","[hook_attn_scores, hook_pattern]":"(batch, 12, seq_len, seq_len)"},"ln2":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 768)"},"mlp":{"[hook_pre, hook_post]":"(batch, seq_len, 3072)"},"[hook_resid_pre, hook_attn_out, hook_resid_mid, hook_mlp_out, hook_resid_post]":"(batch, seq_len, 768)"}},"ln_final":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 768)"},"unembed":{"hook_in":"(batch, seq_len, 768)","hook_out":"(batch, seq_len, 50257)"},"[hook_embed, hook_pos_embed]":"(batch, seq_len, 768)"}} @@ -62,7 +62,7 @@ {"name.default_alias":"pythia-6.9b-deduped","name.huggingface":"EleutherAI\/pythia-6.9b-deduped","name.aliases":"pythia-6.9b-deduped, EleutherAI\/pythia-6.7b-deduped, pythia-6.7b-deduped","model_type":"pythia","name.from_cfg":"pythia-6.9b-deduped","n_params.as_str":"6.4B","n_params.as_int":6442450944,"n_params.from_name":"6.9b","cfg.n_params":6442450944,"cfg.n_layers":32,"cfg.n_heads":32,"cfg.d_model":4096,"cfg.d_vocab":50432,"cfg.act_fn":"gelu","cfg.positional_embedding_type":"rotary","cfg.parallel_attn_mlp":true,"cfg.original_architecture":"GPTNeoXForCausalLM","cfg.normalization_type":"LN","config.raw__":{"d_model":4096,"d_head":128,"n_layers":32,"n_ctx":2048,"n_heads":32,"d_mlp":16384,"d_vocab":50432,"device":"cpu","use_attn_result":false,"use_split_qkv_input":false,"default_prepend_bos":false,"positional_embedding_type":"rotary","n_key_value_heads":null,"attn_only":false,"gated_mlp":false,"uses_rms_norm":false,"eps":0.00001,"layer_norm_folding":false,"act_fn":"gelu","normalization_type":"LN","num_experts":null,"experts_per_token":null,"final_rms":false,"dtype":"torch.float32","model_name":"pythia-6.9b-deduped","use_attn_scale":true,"attn_scale":11.313708499,"use_hook_mlp_in":false,"use_attn_in":false,"use_qk_norm":false,"use_local_attn":false,"ungroup_grouped_query_attention":false,"original_architecture":"GPTNeoXForCausalLM","from_checkpoint":false,"checkpoint_index":null,"checkpoint_label_type":null,"checkpoint_value":null,"tokenizer_name":"EleutherAI\/pythia-6.9b-deduped","window_size":null,"attn_types":null,"init_mode":"gpt2","n_devices":1,"attention_dir":"causal","seed":null,"initializer_range":0.0125,"init_weights":false,"scale_attn_by_inverse_layer_idx":false,"d_vocab_out":50432,"parallel_attn_mlp":true,"rotary_dim":32,"n_params":6442450944,"use_hook_tokens":false,"tokenizer_prepends_bos":null,"post_embedding_ln":false,"rotary_base":10000,"rotary_base_local":null,"rotary_scaling_factor":1.0,"trust_remote_code":false,"rotary_adjacent_pairs":false,"load_in_4bit":false,"relative_attention_max_distance":null,"relative_attention_num_buckets":null,"decoder_start_token_id":null,"tie_word_embeddings":false,"use_normalization_before_and_after":false,"attn_scores_soft_cap":-1.0,"output_logits_soft_cap":-1.0,"use_NTK_by_parts_rope":false,"NTK_by_parts_low_freq_factor":1.0,"NTK_by_parts_high_freq_factor":4.0,"NTK_by_parts_factor":8.0,"NTK_original_ctx_len":8192,"use_yarn_rope":false,"yarn_factor":1.0,"yarn_attention_factor":1.0,"yarn_beta_fast":32.0,"yarn_beta_slow":1.0,"yarn_original_max_position_embeddings":4096,"norm_topk_prob":false},"config":"d_model: 4096\nd_head: 128\nn_layers: 32\nn_ctx: 2048\nn_heads: 32\nd_mlp: 16384\nd_vocab: 50432\ndevice: cpu\nuse_attn_result: false\nuse_split_qkv_input: false\ndefault_prepend_bos: false\npositional_embedding_type: rotary\nn_key_value_heads: null\nattn_only: false\ngated_mlp: false\nuses_rms_norm: false\neps: 1.0e-05\nlayer_norm_folding: false\nact_fn: gelu\nnormalization_type: LN\nnum_experts: null\nexperts_per_token: null\nfinal_rms: false\ndtype: torch.float32\nmodel_name: pythia-6.9b-deduped\nuse_attn_scale: true\nattn_scale: !!python\/object\/apply:numpy.core.multiarray.scalar\n- !!python\/object\/apply:numpy.dtype\n args:\n - f8\n - false\n - true\n state: !!python\/tuple\n - 3\n - <\n - null\n - null\n - null\n - -1\n - -1\n - 0\n- !!binary |\n zTt\/Zp6gJkA=\nuse_hook_mlp_in: false\nuse_attn_in: false\nuse_qk_norm: false\nuse_local_attn: false\nungroup_grouped_query_attention: false\noriginal_architecture: GPTNeoXForCausalLM\nfrom_checkpoint: false\ncheckpoint_index: null\ncheckpoint_label_type: null\ncheckpoint_value: null\ntokenizer_name: EleutherAI\/pythia-6.9b-deduped\nwindow_size: null\nattn_types: null\ninit_mode: gpt2\nn_devices: 1\nattention_dir: causal\nseed: null\ninitializer_range: 0.0125\ninit_weights: false\nscale_attn_by_inverse_layer_idx: false\nd_vocab_out: 50432\nparallel_attn_mlp: true\nrotary_dim: 32\nn_params: 6442450944\nuse_hook_tokens: false\ntokenizer_prepends_bos: null\npost_embedding_ln: false\nrotary_base: 10000\nrotary_base_local: null\nrotary_scaling_factor: 1.0\ntrust_remote_code: false\nrotary_adjacent_pairs: false\nload_in_4bit: false\nrelative_attention_max_distance: null\nrelative_attention_num_buckets: null\ndecoder_start_token_id: null\ntie_word_embeddings: false\nuse_normalization_before_and_after: false\nattn_scores_soft_cap: -1.0\noutput_logits_soft_cap: -1.0\nuse_NTK_by_parts_rope: false\nNTK_by_parts_low_freq_factor: 1.0\nNTK_by_parts_high_freq_factor: 4.0\nNTK_by_parts_factor: 8.0\nNTK_original_ctx_len: 8192\nuse_yarn_rope: false\nyarn_factor: 1.0\nyarn_attention_factor: 1.0\nyarn_beta_fast: 32.0\nyarn_beta_slow: 1.0\nyarn_original_max_position_embeddings: 4096\nnorm_topk_prob: false\n","tokenizer.name":"EleutherAI\/pythia-6.9b-deduped","tokenizer.vocab_size":50254.0,"tokenizer.max_len":null,"tokenizer.class":"GPTNeoXTokenizer","tokenizer.vocab_hash":"96EawM8Lij99W7OBTk0KW2ELUrQ=","tensor_shapes.state_dict":"embed:\n W_E: (50432, 4096)\nblocks:\n '[0-31]':\n ln1:\n '[w, b]': (4096,)\n ln2:\n '[w, b]': (4096,)\n attn:\n '[W_Q, W_K, W_V]': (32, 4096, 128)\n W_O: (32, 128, 4096)\n '[b_Q, b_K, b_V]': (32, 128)\n b_O: (4096,)\n mask: (2048, 2048)\n IGNORE: ()\n '[rotary_sin, rotary_cos]': (2048, 32)\n mlp:\n W_in: (4096, 16384)\n b_in: (16384,)\n W_out: (16384, 4096)\n b_out: (4096,)\nln_final:\n '[w, b]': (4096,)\nunembed:\n W_U: (4096, 50432)\n b_U: (50432,)\n","tensor_shapes.state_dict.raw__":{"embed":{"W_E":"(50432, 4096)"},"blocks":{"[0-31]":{"ln1":{"[w, b]":"(4096,)"},"ln2":{"[w, b]":"(4096,)"},"attn":{"[W_Q, W_K, W_V]":"(32, 4096, 128)","W_O":"(32, 128, 4096)","[b_Q, b_K, b_V]":"(32, 128)","b_O":"(4096,)","mask":"(2048, 2048)","IGNORE":"()","[rotary_sin, rotary_cos]":"(2048, 32)"},"mlp":{"W_in":"(4096, 16384)","b_in":"(16384,)","W_out":"(16384, 4096)","b_out":"(4096,)"}}},"ln_final":{"[w, b]":"(4096,)"},"unembed":{"W_U":"(4096, 50432)","b_U":"(50432,)"}},"tensor_shapes.activation_cache":"blocks:\n '[0-31]':\n ln1:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 4096)\n attn:\n '[hook_q, hook_k, hook_v, hook_rot_q, hook_rot_k, hook_z]': (batch, seq_len,\n 32, 128)\n '[hook_attn_scores, hook_pattern]': (batch, 32, seq_len, seq_len)\n ln2:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 4096)\n mlp:\n '[hook_pre, hook_post]': (batch, seq_len, 16384)\n '[hook_resid_pre, hook_attn_out, hook_mlp_out, hook_resid_post]': (batch, seq_len,\n 4096)\nln_final:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 4096)\nunembed:\n hook_in: (batch, seq_len, 4096)\n hook_out: (batch, seq_len, 50432)\nhook_embed: (batch, seq_len, 4096)\n","tensor_shapes.activation_cache.raw__":{"blocks":{"[0-31]":{"ln1":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 4096)"},"attn":{"[hook_q, hook_k, hook_v, hook_rot_q, hook_rot_k, hook_z]":"(batch, seq_len, 32, 128)","[hook_attn_scores, hook_pattern]":"(batch, 32, seq_len, seq_len)"},"ln2":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 4096)"},"mlp":{"[hook_pre, hook_post]":"(batch, seq_len, 16384)"},"[hook_resid_pre, hook_attn_out, hook_mlp_out, hook_resid_post]":"(batch, seq_len, 4096)"}},"ln_final":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 4096)"},"unembed":{"hook_in":"(batch, seq_len, 4096)","hook_out":"(batch, seq_len, 50432)"},"hook_embed":"(batch, seq_len, 4096)"}} {"name.default_alias":"pythia-6.9b-deduped-v0","name.huggingface":"EleutherAI\/pythia-6.9b-deduped-v0","name.aliases":"pythia-6.9b-deduped-v0, EleutherAI\/pythia-6.7b-deduped-v0, pythia-6.7b-deduped-v0","model_type":"pythia","name.from_cfg":"pythia-6.9b-deduped-v0","n_params.as_str":"6.4B","n_params.as_int":6442450944,"n_params.from_name":"6.9b","cfg.n_params":6442450944,"cfg.n_layers":32,"cfg.n_heads":32,"cfg.d_model":4096,"cfg.d_vocab":50432,"cfg.act_fn":"gelu","cfg.positional_embedding_type":"rotary","cfg.parallel_attn_mlp":true,"cfg.original_architecture":"GPTNeoXForCausalLM","cfg.normalization_type":"LN","config.raw__":{"d_model":4096,"d_head":128,"n_layers":32,"n_ctx":2048,"n_heads":32,"d_mlp":16384,"d_vocab":50432,"device":"cpu","use_attn_result":false,"use_split_qkv_input":false,"default_prepend_bos":false,"positional_embedding_type":"rotary","n_key_value_heads":null,"attn_only":false,"gated_mlp":false,"uses_rms_norm":false,"eps":0.00001,"layer_norm_folding":false,"act_fn":"gelu","normalization_type":"LN","num_experts":null,"experts_per_token":null,"final_rms":false,"dtype":"torch.float32","model_name":"pythia-6.9b-deduped-v0","use_attn_scale":true,"attn_scale":11.313708499,"use_hook_mlp_in":false,"use_attn_in":false,"use_qk_norm":false,"use_local_attn":false,"ungroup_grouped_query_attention":false,"original_architecture":"GPTNeoXForCausalLM","from_checkpoint":false,"checkpoint_index":null,"checkpoint_label_type":null,"checkpoint_value":null,"tokenizer_name":"EleutherAI\/pythia-6.9b-deduped-v0","window_size":null,"attn_types":null,"init_mode":"gpt2","n_devices":1,"attention_dir":"causal","seed":null,"initializer_range":0.0125,"init_weights":false,"scale_attn_by_inverse_layer_idx":false,"d_vocab_out":50432,"parallel_attn_mlp":true,"rotary_dim":32,"n_params":6442450944,"use_hook_tokens":false,"tokenizer_prepends_bos":null,"post_embedding_ln":false,"rotary_base":10000,"rotary_base_local":null,"rotary_scaling_factor":1.0,"trust_remote_code":false,"rotary_adjacent_pairs":false,"load_in_4bit":false,"relative_attention_max_distance":null,"relative_attention_num_buckets":null,"decoder_start_token_id":null,"tie_word_embeddings":false,"use_normalization_before_and_after":false,"attn_scores_soft_cap":-1.0,"output_logits_soft_cap":-1.0,"use_NTK_by_parts_rope":false,"NTK_by_parts_low_freq_factor":1.0,"NTK_by_parts_high_freq_factor":4.0,"NTK_by_parts_factor":8.0,"NTK_original_ctx_len":8192,"use_yarn_rope":false,"yarn_factor":1.0,"yarn_attention_factor":1.0,"yarn_beta_fast":32.0,"yarn_beta_slow":1.0,"yarn_original_max_position_embeddings":4096,"norm_topk_prob":false},"config":"d_model: 4096\nd_head: 128\nn_layers: 32\nn_ctx: 2048\nn_heads: 32\nd_mlp: 16384\nd_vocab: 50432\ndevice: cpu\nuse_attn_result: false\nuse_split_qkv_input: false\ndefault_prepend_bos: false\npositional_embedding_type: rotary\nn_key_value_heads: null\nattn_only: false\ngated_mlp: false\nuses_rms_norm: false\neps: 1.0e-05\nlayer_norm_folding: false\nact_fn: gelu\nnormalization_type: LN\nnum_experts: null\nexperts_per_token: null\nfinal_rms: false\ndtype: torch.float32\nmodel_name: pythia-6.9b-deduped-v0\nuse_attn_scale: true\nattn_scale: !!python\/object\/apply:numpy.core.multiarray.scalar\n- !!python\/object\/apply:numpy.dtype\n args:\n - f8\n - false\n - true\n state: !!python\/tuple\n - 3\n - <\n - null\n - null\n - null\n - -1\n - -1\n - 0\n- !!binary |\n zTt\/Zp6gJkA=\nuse_hook_mlp_in: false\nuse_attn_in: false\nuse_qk_norm: false\nuse_local_attn: false\nungroup_grouped_query_attention: false\noriginal_architecture: GPTNeoXForCausalLM\nfrom_checkpoint: false\ncheckpoint_index: null\ncheckpoint_label_type: null\ncheckpoint_value: null\ntokenizer_name: EleutherAI\/pythia-6.9b-deduped-v0\nwindow_size: null\nattn_types: null\ninit_mode: gpt2\nn_devices: 1\nattention_dir: causal\nseed: null\ninitializer_range: 0.0125\ninit_weights: false\nscale_attn_by_inverse_layer_idx: false\nd_vocab_out: 50432\nparallel_attn_mlp: true\nrotary_dim: 32\nn_params: 6442450944\nuse_hook_tokens: false\ntokenizer_prepends_bos: null\npost_embedding_ln: false\nrotary_base: 10000\nrotary_base_local: null\nrotary_scaling_factor: 1.0\ntrust_remote_code: false\nrotary_adjacent_pairs: false\nload_in_4bit: false\nrelative_attention_max_distance: null\nrelative_attention_num_buckets: null\ndecoder_start_token_id: null\ntie_word_embeddings: false\nuse_normalization_before_and_after: false\nattn_scores_soft_cap: -1.0\noutput_logits_soft_cap: -1.0\nuse_NTK_by_parts_rope: false\nNTK_by_parts_low_freq_factor: 1.0\nNTK_by_parts_high_freq_factor: 4.0\nNTK_by_parts_factor: 8.0\nNTK_original_ctx_len: 8192\nuse_yarn_rope: false\nyarn_factor: 1.0\nyarn_attention_factor: 1.0\nyarn_beta_fast: 32.0\nyarn_beta_slow: 1.0\nyarn_original_max_position_embeddings: 4096\nnorm_topk_prob: false\n","tokenizer.name":"EleutherAI\/pythia-6.9b-deduped-v0","tokenizer.vocab_size":50254.0,"tokenizer.max_len":null,"tokenizer.class":"GPTNeoXTokenizer","tokenizer.vocab_hash":"96EawM8Lij99W7OBTk0KW2ELUrQ=","tensor_shapes.state_dict":"embed:\n W_E: (50432, 4096)\nblocks:\n '[0-31]':\n ln1:\n '[w, b]': (4096,)\n ln2:\n '[w, b]': (4096,)\n attn:\n '[W_Q, W_K, W_V]': (32, 4096, 128)\n W_O: (32, 128, 4096)\n '[b_Q, b_K, b_V]': (32, 128)\n b_O: (4096,)\n mask: (2048, 2048)\n IGNORE: ()\n '[rotary_sin, rotary_cos]': (2048, 32)\n mlp:\n W_in: (4096, 16384)\n b_in: (16384,)\n W_out: (16384, 4096)\n b_out: (4096,)\nln_final:\n '[w, b]': (4096,)\nunembed:\n W_U: (4096, 50432)\n b_U: (50432,)\n","tensor_shapes.state_dict.raw__":{"embed":{"W_E":"(50432, 4096)"},"blocks":{"[0-31]":{"ln1":{"[w, b]":"(4096,)"},"ln2":{"[w, b]":"(4096,)"},"attn":{"[W_Q, W_K, W_V]":"(32, 4096, 128)","W_O":"(32, 128, 4096)","[b_Q, b_K, b_V]":"(32, 128)","b_O":"(4096,)","mask":"(2048, 2048)","IGNORE":"()","[rotary_sin, rotary_cos]":"(2048, 32)"},"mlp":{"W_in":"(4096, 16384)","b_in":"(16384,)","W_out":"(16384, 4096)","b_out":"(4096,)"}}},"ln_final":{"[w, b]":"(4096,)"},"unembed":{"W_U":"(4096, 50432)","b_U":"(50432,)"}},"tensor_shapes.activation_cache":"blocks:\n '[0-31]':\n ln1:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 4096)\n attn:\n '[hook_q, hook_k, hook_v, hook_rot_q, hook_rot_k, hook_z]': (batch, seq_len,\n 32, 128)\n '[hook_attn_scores, hook_pattern]': (batch, 32, seq_len, seq_len)\n ln2:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 4096)\n mlp:\n '[hook_pre, hook_post]': (batch, seq_len, 16384)\n '[hook_resid_pre, hook_attn_out, hook_mlp_out, hook_resid_post]': (batch, seq_len,\n 4096)\nln_final:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 4096)\nunembed:\n hook_in: (batch, seq_len, 4096)\n hook_out: (batch, seq_len, 50432)\nhook_embed: (batch, seq_len, 4096)\n","tensor_shapes.activation_cache.raw__":{"blocks":{"[0-31]":{"ln1":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 4096)"},"attn":{"[hook_q, hook_k, hook_v, hook_rot_q, hook_rot_k, hook_z]":"(batch, seq_len, 32, 128)","[hook_attn_scores, hook_pattern]":"(batch, 32, seq_len, seq_len)"},"ln2":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 4096)"},"mlp":{"[hook_pre, hook_post]":"(batch, seq_len, 16384)"},"[hook_resid_pre, hook_attn_out, hook_mlp_out, hook_resid_post]":"(batch, seq_len, 4096)"}},"ln_final":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 4096)"},"unembed":{"hook_in":"(batch, seq_len, 4096)","hook_out":"(batch, seq_len, 50432)"},"hook_embed":"(batch, seq_len, 4096)"}} {"name.default_alias":"pythia-6.9b-v0","name.huggingface":"EleutherAI\/pythia-6.9b-v0","name.aliases":"pythia-6.9b-v0, EleutherAI\/pythia-6.7b-v0, pythia-6.7b-v0","model_type":"pythia","name.from_cfg":"pythia-6.9b-v0","n_params.as_str":"6.4B","n_params.as_int":6442450944,"n_params.from_name":"6.9b","cfg.n_params":6442450944,"cfg.n_layers":32,"cfg.n_heads":32,"cfg.d_model":4096,"cfg.d_vocab":50432,"cfg.act_fn":"gelu","cfg.positional_embedding_type":"rotary","cfg.parallel_attn_mlp":true,"cfg.original_architecture":"GPTNeoXForCausalLM","cfg.normalization_type":"LN","config.raw__":{"d_model":4096,"d_head":128,"n_layers":32,"n_ctx":2048,"n_heads":32,"d_mlp":16384,"d_vocab":50432,"device":"cpu","use_attn_result":false,"use_split_qkv_input":false,"default_prepend_bos":false,"positional_embedding_type":"rotary","n_key_value_heads":null,"attn_only":false,"gated_mlp":false,"uses_rms_norm":false,"eps":0.00001,"layer_norm_folding":false,"act_fn":"gelu","normalization_type":"LN","num_experts":null,"experts_per_token":null,"final_rms":false,"dtype":"torch.float32","model_name":"pythia-6.9b-v0","use_attn_scale":true,"attn_scale":11.313708499,"use_hook_mlp_in":false,"use_attn_in":false,"use_qk_norm":false,"use_local_attn":false,"ungroup_grouped_query_attention":false,"original_architecture":"GPTNeoXForCausalLM","from_checkpoint":false,"checkpoint_index":null,"checkpoint_label_type":null,"checkpoint_value":null,"tokenizer_name":"EleutherAI\/pythia-6.9b-v0","window_size":null,"attn_types":null,"init_mode":"gpt2","n_devices":1,"attention_dir":"causal","seed":null,"initializer_range":0.0125,"init_weights":false,"scale_attn_by_inverse_layer_idx":false,"d_vocab_out":50432,"parallel_attn_mlp":true,"rotary_dim":32,"n_params":6442450944,"use_hook_tokens":false,"tokenizer_prepends_bos":null,"post_embedding_ln":false,"rotary_base":10000,"rotary_base_local":null,"rotary_scaling_factor":1.0,"trust_remote_code":false,"rotary_adjacent_pairs":false,"load_in_4bit":false,"relative_attention_max_distance":null,"relative_attention_num_buckets":null,"decoder_start_token_id":null,"tie_word_embeddings":false,"use_normalization_before_and_after":false,"attn_scores_soft_cap":-1.0,"output_logits_soft_cap":-1.0,"use_NTK_by_parts_rope":false,"NTK_by_parts_low_freq_factor":1.0,"NTK_by_parts_high_freq_factor":4.0,"NTK_by_parts_factor":8.0,"NTK_original_ctx_len":8192,"use_yarn_rope":false,"yarn_factor":1.0,"yarn_attention_factor":1.0,"yarn_beta_fast":32.0,"yarn_beta_slow":1.0,"yarn_original_max_position_embeddings":4096,"norm_topk_prob":false},"config":"d_model: 4096\nd_head: 128\nn_layers: 32\nn_ctx: 2048\nn_heads: 32\nd_mlp: 16384\nd_vocab: 50432\ndevice: cpu\nuse_attn_result: false\nuse_split_qkv_input: false\ndefault_prepend_bos: false\npositional_embedding_type: rotary\nn_key_value_heads: null\nattn_only: false\ngated_mlp: false\nuses_rms_norm: false\neps: 1.0e-05\nlayer_norm_folding: false\nact_fn: gelu\nnormalization_type: LN\nnum_experts: null\nexperts_per_token: null\nfinal_rms: false\ndtype: torch.float32\nmodel_name: pythia-6.9b-v0\nuse_attn_scale: true\nattn_scale: !!python\/object\/apply:numpy.core.multiarray.scalar\n- !!python\/object\/apply:numpy.dtype\n args:\n - f8\n - false\n - true\n state: !!python\/tuple\n - 3\n - <\n - null\n - null\n - null\n - -1\n - -1\n - 0\n- !!binary |\n zTt\/Zp6gJkA=\nuse_hook_mlp_in: false\nuse_attn_in: false\nuse_qk_norm: false\nuse_local_attn: false\nungroup_grouped_query_attention: false\noriginal_architecture: GPTNeoXForCausalLM\nfrom_checkpoint: false\ncheckpoint_index: null\ncheckpoint_label_type: null\ncheckpoint_value: null\ntokenizer_name: EleutherAI\/pythia-6.9b-v0\nwindow_size: null\nattn_types: null\ninit_mode: gpt2\nn_devices: 1\nattention_dir: causal\nseed: null\ninitializer_range: 0.0125\ninit_weights: false\nscale_attn_by_inverse_layer_idx: false\nd_vocab_out: 50432\nparallel_attn_mlp: true\nrotary_dim: 32\nn_params: 6442450944\nuse_hook_tokens: false\ntokenizer_prepends_bos: null\npost_embedding_ln: false\nrotary_base: 10000\nrotary_base_local: null\nrotary_scaling_factor: 1.0\ntrust_remote_code: false\nrotary_adjacent_pairs: false\nload_in_4bit: false\nrelative_attention_max_distance: null\nrelative_attention_num_buckets: null\ndecoder_start_token_id: null\ntie_word_embeddings: false\nuse_normalization_before_and_after: false\nattn_scores_soft_cap: -1.0\noutput_logits_soft_cap: -1.0\nuse_NTK_by_parts_rope: false\nNTK_by_parts_low_freq_factor: 1.0\nNTK_by_parts_high_freq_factor: 4.0\nNTK_by_parts_factor: 8.0\nNTK_original_ctx_len: 8192\nuse_yarn_rope: false\nyarn_factor: 1.0\nyarn_attention_factor: 1.0\nyarn_beta_fast: 32.0\nyarn_beta_slow: 1.0\nyarn_original_max_position_embeddings: 4096\nnorm_topk_prob: false\n","tokenizer.name":"EleutherAI\/pythia-6.9b-v0","tokenizer.vocab_size":50254.0,"tokenizer.max_len":null,"tokenizer.class":"GPTNeoXTokenizer","tokenizer.vocab_hash":"96EawM8Lij99W7OBTk0KW2ELUrQ=","tensor_shapes.state_dict":"embed:\n W_E: (50432, 4096)\nblocks:\n '[0-31]':\n ln1:\n '[w, b]': (4096,)\n ln2:\n '[w, b]': (4096,)\n attn:\n '[W_Q, W_K, W_V]': (32, 4096, 128)\n W_O: (32, 128, 4096)\n '[b_Q, b_K, b_V]': (32, 128)\n b_O: (4096,)\n mask: (2048, 2048)\n IGNORE: ()\n '[rotary_sin, rotary_cos]': (2048, 32)\n mlp:\n W_in: (4096, 16384)\n b_in: (16384,)\n W_out: (16384, 4096)\n b_out: (4096,)\nln_final:\n '[w, b]': (4096,)\nunembed:\n W_U: (4096, 50432)\n b_U: (50432,)\n","tensor_shapes.state_dict.raw__":{"embed":{"W_E":"(50432, 4096)"},"blocks":{"[0-31]":{"ln1":{"[w, b]":"(4096,)"},"ln2":{"[w, b]":"(4096,)"},"attn":{"[W_Q, W_K, W_V]":"(32, 4096, 128)","W_O":"(32, 128, 4096)","[b_Q, b_K, b_V]":"(32, 128)","b_O":"(4096,)","mask":"(2048, 2048)","IGNORE":"()","[rotary_sin, rotary_cos]":"(2048, 32)"},"mlp":{"W_in":"(4096, 16384)","b_in":"(16384,)","W_out":"(16384, 4096)","b_out":"(4096,)"}}},"ln_final":{"[w, b]":"(4096,)"},"unembed":{"W_U":"(4096, 50432)","b_U":"(50432,)"}},"tensor_shapes.activation_cache":"blocks:\n '[0-31]':\n ln1:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 4096)\n attn:\n '[hook_q, hook_k, hook_v, hook_rot_q, hook_rot_k, hook_z]': (batch, seq_len,\n 32, 128)\n '[hook_attn_scores, hook_pattern]': (batch, 32, seq_len, seq_len)\n ln2:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 4096)\n mlp:\n '[hook_pre, hook_post]': (batch, seq_len, 16384)\n '[hook_resid_pre, hook_attn_out, hook_mlp_out, hook_resid_post]': (batch, seq_len,\n 4096)\nln_final:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 4096)\nunembed:\n hook_in: (batch, seq_len, 4096)\n hook_out: (batch, seq_len, 50432)\nhook_embed: (batch, seq_len, 4096)\n","tensor_shapes.activation_cache.raw__":{"blocks":{"[0-31]":{"ln1":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 4096)"},"attn":{"[hook_q, hook_k, hook_v, hook_rot_q, hook_rot_k, hook_z]":"(batch, seq_len, 32, 128)","[hook_attn_scores, hook_pattern]":"(batch, 32, seq_len, seq_len)"},"ln2":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 4096)"},"mlp":{"[hook_pre, hook_post]":"(batch, seq_len, 16384)"},"[hook_resid_pre, hook_attn_out, hook_mlp_out, hook_resid_post]":"(batch, seq_len, 4096)"}},"ln_final":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 4096)"},"unembed":{"hook_in":"(batch, seq_len, 4096)","hook_out":"(batch, seq_len, 50432)"},"hook_embed":"(batch, seq_len, 4096)"}} -{"name.default_alias":"pythia-70m","name.huggingface":"EleutherAI\/pythia-70m","name.aliases":"pythia-70m, pythia, EleutherAI\/pythia-19m, pythia-19m","model_type":"pythia","name.from_cfg":"pythia-70m","n_params.as_str":"19M","n_params.as_int":18874368,"n_params.from_name":"70m","cfg.n_params":18874368,"cfg.n_layers":6,"cfg.n_heads":8,"cfg.d_model":512,"cfg.d_vocab":50304,"cfg.act_fn":"gelu","cfg.positional_embedding_type":"rotary","cfg.parallel_attn_mlp":true,"cfg.original_architecture":"GPTNeoXForCausalLM","cfg.normalization_type":"LN","config.raw__":{"d_model":512,"d_head":64,"n_layers":6,"n_ctx":2048,"n_heads":8,"d_mlp":2048,"d_vocab":50304,"device":"cpu","use_attn_result":false,"use_split_qkv_input":false,"default_prepend_bos":false,"positional_embedding_type":"rotary","n_key_value_heads":null,"attn_only":false,"gated_mlp":false,"uses_rms_norm":false,"eps":0.00001,"layer_norm_folding":false,"act_fn":"gelu","normalization_type":"LN","num_experts":null,"experts_per_token":null,"final_rms":false,"dtype":"torch.float32","model_name":"pythia-70m","use_attn_scale":true,"attn_scale":8.0,"use_hook_mlp_in":false,"use_attn_in":false,"use_qk_norm":false,"use_local_attn":false,"ungroup_grouped_query_attention":false,"original_architecture":"GPTNeoXForCausalLM","from_checkpoint":false,"checkpoint_index":null,"checkpoint_label_type":null,"checkpoint_value":null,"tokenizer_name":"EleutherAI\/pythia-70m","window_size":null,"attn_types":null,"init_mode":"gpt2","n_devices":1,"attention_dir":"causal","seed":null,"initializer_range":0.0353553391,"init_weights":false,"scale_attn_by_inverse_layer_idx":false,"d_vocab_out":50304,"parallel_attn_mlp":true,"rotary_dim":16,"n_params":18874368,"use_hook_tokens":false,"tokenizer_prepends_bos":null,"post_embedding_ln":false,"rotary_base":10000,"rotary_base_local":null,"rotary_scaling_factor":1.0,"trust_remote_code":false,"rotary_adjacent_pairs":false,"load_in_4bit":false,"relative_attention_max_distance":null,"relative_attention_num_buckets":null,"decoder_start_token_id":null,"tie_word_embeddings":false,"use_normalization_before_and_after":false,"attn_scores_soft_cap":-1.0,"output_logits_soft_cap":-1.0,"use_NTK_by_parts_rope":false,"NTK_by_parts_low_freq_factor":1.0,"NTK_by_parts_high_freq_factor":4.0,"NTK_by_parts_factor":8.0,"NTK_original_ctx_len":8192,"use_yarn_rope":false,"yarn_factor":1.0,"yarn_attention_factor":1.0,"yarn_beta_fast":32.0,"yarn_beta_slow":1.0,"yarn_original_max_position_embeddings":4096,"norm_topk_prob":false},"config":"d_model: 512\nd_head: 64\nn_layers: 6\nn_ctx: 2048\nn_heads: 8\nd_mlp: 2048\nd_vocab: 50304\ndevice: cpu\nuse_attn_result: false\nuse_split_qkv_input: false\ndefault_prepend_bos: false\npositional_embedding_type: rotary\nn_key_value_heads: null\nattn_only: false\ngated_mlp: false\nuses_rms_norm: false\neps: 1.0e-05\nlayer_norm_folding: false\nact_fn: gelu\nnormalization_type: LN\nnum_experts: null\nexperts_per_token: null\nfinal_rms: false\ndtype: torch.float32\nmodel_name: pythia-70m\nuse_attn_scale: true\nattn_scale: !!python\/object\/apply:numpy.core.multiarray.scalar\n- !!python\/object\/apply:numpy.dtype\n args:\n - f8\n - false\n - true\n state: !!python\/tuple\n - 3\n - <\n - null\n - null\n - null\n - -1\n - -1\n - 0\n- !!binary |\n AAAAAAAAIEA=\nuse_hook_mlp_in: false\nuse_attn_in: false\nuse_qk_norm: false\nuse_local_attn: false\nungroup_grouped_query_attention: false\noriginal_architecture: GPTNeoXForCausalLM\nfrom_checkpoint: false\ncheckpoint_index: null\ncheckpoint_label_type: null\ncheckpoint_value: null\ntokenizer_name: EleutherAI\/pythia-70m\nwindow_size: null\nattn_types: null\ninit_mode: gpt2\nn_devices: 1\nattention_dir: causal\nseed: null\ninitializer_range: 0.035355339059327376\ninit_weights: false\nscale_attn_by_inverse_layer_idx: false\nd_vocab_out: 50304\nparallel_attn_mlp: true\nrotary_dim: 16\nn_params: 18874368\nuse_hook_tokens: false\ntokenizer_prepends_bos: null\npost_embedding_ln: false\nrotary_base: 10000\nrotary_base_local: null\nrotary_scaling_factor: 1.0\ntrust_remote_code: false\nrotary_adjacent_pairs: false\nload_in_4bit: false\nrelative_attention_max_distance: null\nrelative_attention_num_buckets: null\ndecoder_start_token_id: null\ntie_word_embeddings: false\nuse_normalization_before_and_after: false\nattn_scores_soft_cap: -1.0\noutput_logits_soft_cap: -1.0\nuse_NTK_by_parts_rope: false\nNTK_by_parts_low_freq_factor: 1.0\nNTK_by_parts_high_freq_factor: 4.0\nNTK_by_parts_factor: 8.0\nNTK_original_ctx_len: 8192\nuse_yarn_rope: false\nyarn_factor: 1.0\nyarn_attention_factor: 1.0\nyarn_beta_fast: 32.0\nyarn_beta_slow: 1.0\nyarn_original_max_position_embeddings: 4096\nnorm_topk_prob: false\n","tokenizer.name":null,"tokenizer.vocab_size":null,"tokenizer.max_len":null,"tokenizer.class":null,"tokenizer.vocab_hash":null,"tensor_shapes.state_dict":null,"tensor_shapes.state_dict.raw__":null,"tensor_shapes.activation_cache":null,"tensor_shapes.activation_cache.raw__":null} +{"name.default_alias":"pythia-70m","name.huggingface":"EleutherAI\/pythia-70m","name.aliases":"pythia-70m, pythia, EleutherAI\/pythia-19m, pythia-19m","model_type":"pythia","name.from_cfg":"pythia-70m","n_params.as_str":"19M","n_params.as_int":18874368,"n_params.from_name":"70m","cfg.n_params":18874368,"cfg.n_layers":6,"cfg.n_heads":8,"cfg.d_model":512,"cfg.d_vocab":50304,"cfg.act_fn":"gelu","cfg.positional_embedding_type":"rotary","cfg.parallel_attn_mlp":true,"cfg.original_architecture":"GPTNeoXForCausalLM","cfg.normalization_type":"LN","config.raw__":{"d_model":512,"d_head":64,"n_layers":6,"n_ctx":2048,"n_heads":8,"d_mlp":2048,"d_vocab":50304,"device":"cpu","use_attn_result":false,"use_split_qkv_input":false,"default_prepend_bos":false,"positional_embedding_type":"rotary","n_key_value_heads":null,"attn_only":false,"gated_mlp":false,"uses_rms_norm":false,"eps":0.00001,"layer_norm_folding":false,"act_fn":"gelu","normalization_type":"LN","num_experts":null,"experts_per_token":null,"final_rms":false,"dtype":"torch.float32","model_name":"pythia-70m","use_attn_scale":true,"attn_scale":8.0,"use_hook_mlp_in":false,"use_attn_in":false,"use_qk_norm":false,"use_local_attn":false,"ungroup_grouped_query_attention":false,"original_architecture":"GPTNeoXForCausalLM","from_checkpoint":false,"checkpoint_index":null,"checkpoint_label_type":null,"checkpoint_value":null,"tokenizer_name":"EleutherAI\/pythia-70m","window_size":null,"attn_types":null,"init_mode":"gpt2","n_devices":1,"attention_dir":"causal","seed":null,"initializer_range":0.0353553391,"init_weights":false,"scale_attn_by_inverse_layer_idx":false,"d_vocab_out":50304,"parallel_attn_mlp":true,"rotary_dim":16,"n_params":18874368,"use_hook_tokens":false,"tokenizer_prepends_bos":null,"post_embedding_ln":false,"rotary_base":10000,"rotary_base_local":null,"rotary_scaling_factor":1.0,"trust_remote_code":false,"rotary_adjacent_pairs":false,"load_in_4bit":false,"relative_attention_max_distance":null,"relative_attention_num_buckets":null,"decoder_start_token_id":null,"tie_word_embeddings":false,"use_normalization_before_and_after":false,"attn_scores_soft_cap":-1.0,"output_logits_soft_cap":-1.0,"use_NTK_by_parts_rope":false,"NTK_by_parts_low_freq_factor":1.0,"NTK_by_parts_high_freq_factor":4.0,"NTK_by_parts_factor":8.0,"NTK_original_ctx_len":8192,"use_yarn_rope":false,"yarn_factor":1.0,"yarn_attention_factor":1.0,"yarn_beta_fast":32.0,"yarn_beta_slow":1.0,"yarn_original_max_position_embeddings":4096,"norm_topk_prob":false},"config":"d_model: 512\nd_head: 64\nn_layers: 6\nn_ctx: 2048\nn_heads: 8\nd_mlp: 2048\nd_vocab: 50304\ndevice: cpu\nuse_attn_result: false\nuse_split_qkv_input: false\ndefault_prepend_bos: false\npositional_embedding_type: rotary\nn_key_value_heads: null\nattn_only: false\ngated_mlp: false\nuses_rms_norm: false\neps: 1.0e-05\nlayer_norm_folding: false\nact_fn: gelu\nnormalization_type: LN\nnum_experts: null\nexperts_per_token: null\nfinal_rms: false\ndtype: torch.float32\nmodel_name: pythia-70m\nuse_attn_scale: true\nattn_scale: !!python\/object\/apply:numpy.core.multiarray.scalar\n- !!python\/object\/apply:numpy.dtype\n args:\n - f8\n - false\n - true\n state: !!python\/tuple\n - 3\n - <\n - null\n - null\n - null\n - -1\n - -1\n - 0\n- !!binary |\n AAAAAAAAIEA=\nuse_hook_mlp_in: false\nuse_attn_in: false\nuse_qk_norm: false\nuse_local_attn: false\nungroup_grouped_query_attention: false\noriginal_architecture: GPTNeoXForCausalLM\nfrom_checkpoint: false\ncheckpoint_index: null\ncheckpoint_label_type: null\ncheckpoint_value: null\ntokenizer_name: EleutherAI\/pythia-70m\nwindow_size: null\nattn_types: null\ninit_mode: gpt2\nn_devices: 1\nattention_dir: causal\nseed: null\ninitializer_range: 0.035355339059327376\ninit_weights: false\nscale_attn_by_inverse_layer_idx: false\nd_vocab_out: 50304\nparallel_attn_mlp: true\nrotary_dim: 16\nn_params: 18874368\nuse_hook_tokens: false\ntokenizer_prepends_bos: null\npost_embedding_ln: false\nrotary_base: 10000\nrotary_base_local: null\nrotary_scaling_factor: 1.0\ntrust_remote_code: false\nrotary_adjacent_pairs: false\nload_in_4bit: false\nrelative_attention_max_distance: null\nrelative_attention_num_buckets: null\ndecoder_start_token_id: null\ntie_word_embeddings: false\nuse_normalization_before_and_after: false\nattn_scores_soft_cap: -1.0\noutput_logits_soft_cap: -1.0\nuse_NTK_by_parts_rope: false\nNTK_by_parts_low_freq_factor: 1.0\nNTK_by_parts_high_freq_factor: 4.0\nNTK_by_parts_factor: 8.0\nNTK_original_ctx_len: 8192\nuse_yarn_rope: false\nyarn_factor: 1.0\nyarn_attention_factor: 1.0\nyarn_beta_fast: 32.0\nyarn_beta_slow: 1.0\nyarn_original_max_position_embeddings: 4096\nnorm_topk_prob: false\n","tokenizer.name":"EleutherAI\/pythia-70m","tokenizer.vocab_size":50254.0,"tokenizer.max_len":null,"tokenizer.class":"GPTNeoXTokenizer","tokenizer.vocab_hash":"96EawM8Lij99W7OBTk0KW2ELUrQ=","tensor_shapes.state_dict":"embed:\n W_E: (50304, 512)\nblocks:\n '[0-5]':\n ln1:\n '[w, b]': (512,)\n ln2:\n '[w, b]': (512,)\n attn:\n '[W_Q, W_K, W_V]': (8, 512, 64)\n W_O: (8, 64, 512)\n '[b_Q, b_K, b_V]': (8, 64)\n b_O: (512,)\n mask: (2048, 2048)\n IGNORE: ()\n '[rotary_sin, rotary_cos]': (2048, 16)\n mlp:\n W_in: (512, 2048)\n b_in: (2048,)\n W_out: (2048, 512)\n b_out: (512,)\nln_final:\n '[w, b]': (512,)\nunembed:\n W_U: (512, 50304)\n b_U: (50304,)\n","tensor_shapes.state_dict.raw__":{"embed":{"W_E":"(50304, 512)"},"blocks":{"[0-5]":{"ln1":{"[w, b]":"(512,)"},"ln2":{"[w, b]":"(512,)"},"attn":{"[W_Q, W_K, W_V]":"(8, 512, 64)","W_O":"(8, 64, 512)","[b_Q, b_K, b_V]":"(8, 64)","b_O":"(512,)","mask":"(2048, 2048)","IGNORE":"()","[rotary_sin, rotary_cos]":"(2048, 16)"},"mlp":{"W_in":"(512, 2048)","b_in":"(2048,)","W_out":"(2048, 512)","b_out":"(512,)"}}},"ln_final":{"[w, b]":"(512,)"},"unembed":{"W_U":"(512, 50304)","b_U":"(50304,)"}},"tensor_shapes.activation_cache":"blocks:\n '[0-5]':\n ln1:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 512)\n attn:\n '[hook_q, hook_k, hook_v, hook_rot_q, hook_rot_k, hook_z]': (batch, seq_len,\n 8, 64)\n '[hook_attn_scores, hook_pattern]': (batch, 8, seq_len, seq_len)\n ln2:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 512)\n mlp:\n '[hook_pre, hook_post]': (batch, seq_len, 2048)\n '[hook_resid_pre, hook_attn_out, hook_mlp_out, hook_resid_post]': (batch, seq_len,\n 512)\nln_final:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 512)\nunembed:\n hook_in: (batch, seq_len, 512)\n hook_out: (batch, seq_len, 50304)\nhook_embed: (batch, seq_len, 512)\n","tensor_shapes.activation_cache.raw__":{"blocks":{"[0-5]":{"ln1":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 512)"},"attn":{"[hook_q, hook_k, hook_v, hook_rot_q, hook_rot_k, hook_z]":"(batch, seq_len, 8, 64)","[hook_attn_scores, hook_pattern]":"(batch, 8, seq_len, seq_len)"},"ln2":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 512)"},"mlp":{"[hook_pre, hook_post]":"(batch, seq_len, 2048)"},"[hook_resid_pre, hook_attn_out, hook_mlp_out, hook_resid_post]":"(batch, seq_len, 512)"}},"ln_final":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 512)"},"unembed":{"hook_in":"(batch, seq_len, 512)","hook_out":"(batch, seq_len, 50304)"},"hook_embed":"(batch, seq_len, 512)"}} {"name.default_alias":"pythia-70m-deduped","name.huggingface":"EleutherAI\/pythia-70m-deduped","name.aliases":"pythia-70m-deduped, EleutherAI\/pythia-19m-deduped, pythia-19m-deduped","model_type":"pythia","name.from_cfg":"pythia-70m-deduped","n_params.as_str":"19M","n_params.as_int":18874368,"n_params.from_name":"70m","cfg.n_params":18874368,"cfg.n_layers":6,"cfg.n_heads":8,"cfg.d_model":512,"cfg.d_vocab":50304,"cfg.act_fn":"gelu","cfg.positional_embedding_type":"rotary","cfg.parallel_attn_mlp":true,"cfg.original_architecture":"GPTNeoXForCausalLM","cfg.normalization_type":"LN","config.raw__":{"d_model":512,"d_head":64,"n_layers":6,"n_ctx":2048,"n_heads":8,"d_mlp":2048,"d_vocab":50304,"device":"cpu","use_attn_result":false,"use_split_qkv_input":false,"default_prepend_bos":false,"positional_embedding_type":"rotary","n_key_value_heads":null,"attn_only":false,"gated_mlp":false,"uses_rms_norm":false,"eps":0.00001,"layer_norm_folding":false,"act_fn":"gelu","normalization_type":"LN","num_experts":null,"experts_per_token":null,"final_rms":false,"dtype":"torch.float32","model_name":"pythia-70m-deduped","use_attn_scale":true,"attn_scale":8.0,"use_hook_mlp_in":false,"use_attn_in":false,"use_qk_norm":false,"use_local_attn":false,"ungroup_grouped_query_attention":false,"original_architecture":"GPTNeoXForCausalLM","from_checkpoint":false,"checkpoint_index":null,"checkpoint_label_type":null,"checkpoint_value":null,"tokenizer_name":"EleutherAI\/pythia-70m-deduped","window_size":null,"attn_types":null,"init_mode":"gpt2","n_devices":1,"attention_dir":"causal","seed":null,"initializer_range":0.0353553391,"init_weights":false,"scale_attn_by_inverse_layer_idx":false,"d_vocab_out":50304,"parallel_attn_mlp":true,"rotary_dim":16,"n_params":18874368,"use_hook_tokens":false,"tokenizer_prepends_bos":null,"post_embedding_ln":false,"rotary_base":10000,"rotary_base_local":null,"rotary_scaling_factor":1.0,"trust_remote_code":false,"rotary_adjacent_pairs":false,"load_in_4bit":false,"relative_attention_max_distance":null,"relative_attention_num_buckets":null,"decoder_start_token_id":null,"tie_word_embeddings":false,"use_normalization_before_and_after":false,"attn_scores_soft_cap":-1.0,"output_logits_soft_cap":-1.0,"use_NTK_by_parts_rope":false,"NTK_by_parts_low_freq_factor":1.0,"NTK_by_parts_high_freq_factor":4.0,"NTK_by_parts_factor":8.0,"NTK_original_ctx_len":8192,"use_yarn_rope":false,"yarn_factor":1.0,"yarn_attention_factor":1.0,"yarn_beta_fast":32.0,"yarn_beta_slow":1.0,"yarn_original_max_position_embeddings":4096,"norm_topk_prob":false},"config":"d_model: 512\nd_head: 64\nn_layers: 6\nn_ctx: 2048\nn_heads: 8\nd_mlp: 2048\nd_vocab: 50304\ndevice: cpu\nuse_attn_result: false\nuse_split_qkv_input: false\ndefault_prepend_bos: false\npositional_embedding_type: rotary\nn_key_value_heads: null\nattn_only: false\ngated_mlp: false\nuses_rms_norm: false\neps: 1.0e-05\nlayer_norm_folding: false\nact_fn: gelu\nnormalization_type: LN\nnum_experts: null\nexperts_per_token: null\nfinal_rms: false\ndtype: torch.float32\nmodel_name: pythia-70m-deduped\nuse_attn_scale: true\nattn_scale: !!python\/object\/apply:numpy.core.multiarray.scalar\n- !!python\/object\/apply:numpy.dtype\n args:\n - f8\n - false\n - true\n state: !!python\/tuple\n - 3\n - <\n - null\n - null\n - null\n - -1\n - -1\n - 0\n- !!binary |\n AAAAAAAAIEA=\nuse_hook_mlp_in: false\nuse_attn_in: false\nuse_qk_norm: false\nuse_local_attn: false\nungroup_grouped_query_attention: false\noriginal_architecture: GPTNeoXForCausalLM\nfrom_checkpoint: false\ncheckpoint_index: null\ncheckpoint_label_type: null\ncheckpoint_value: null\ntokenizer_name: EleutherAI\/pythia-70m-deduped\nwindow_size: null\nattn_types: null\ninit_mode: gpt2\nn_devices: 1\nattention_dir: causal\nseed: null\ninitializer_range: 0.035355339059327376\ninit_weights: false\nscale_attn_by_inverse_layer_idx: false\nd_vocab_out: 50304\nparallel_attn_mlp: true\nrotary_dim: 16\nn_params: 18874368\nuse_hook_tokens: false\ntokenizer_prepends_bos: null\npost_embedding_ln: false\nrotary_base: 10000\nrotary_base_local: null\nrotary_scaling_factor: 1.0\ntrust_remote_code: false\nrotary_adjacent_pairs: false\nload_in_4bit: false\nrelative_attention_max_distance: null\nrelative_attention_num_buckets: null\ndecoder_start_token_id: null\ntie_word_embeddings: false\nuse_normalization_before_and_after: false\nattn_scores_soft_cap: -1.0\noutput_logits_soft_cap: -1.0\nuse_NTK_by_parts_rope: false\nNTK_by_parts_low_freq_factor: 1.0\nNTK_by_parts_high_freq_factor: 4.0\nNTK_by_parts_factor: 8.0\nNTK_original_ctx_len: 8192\nuse_yarn_rope: false\nyarn_factor: 1.0\nyarn_attention_factor: 1.0\nyarn_beta_fast: 32.0\nyarn_beta_slow: 1.0\nyarn_original_max_position_embeddings: 4096\nnorm_topk_prob: false\n","tokenizer.name":"EleutherAI\/pythia-70m-deduped","tokenizer.vocab_size":50254.0,"tokenizer.max_len":null,"tokenizer.class":"GPTNeoXTokenizer","tokenizer.vocab_hash":"96EawM8Lij99W7OBTk0KW2ELUrQ=","tensor_shapes.state_dict":"embed:\n W_E: (50304, 512)\nblocks:\n '[0-5]':\n ln1:\n '[w, b]': (512,)\n ln2:\n '[w, b]': (512,)\n attn:\n '[W_Q, W_K, W_V]': (8, 512, 64)\n W_O: (8, 64, 512)\n '[b_Q, b_K, b_V]': (8, 64)\n b_O: (512,)\n mask: (2048, 2048)\n IGNORE: ()\n '[rotary_sin, rotary_cos]': (2048, 16)\n mlp:\n W_in: (512, 2048)\n b_in: (2048,)\n W_out: (2048, 512)\n b_out: (512,)\nln_final:\n '[w, b]': (512,)\nunembed:\n W_U: (512, 50304)\n b_U: (50304,)\n","tensor_shapes.state_dict.raw__":{"embed":{"W_E":"(50304, 512)"},"blocks":{"[0-5]":{"ln1":{"[w, b]":"(512,)"},"ln2":{"[w, b]":"(512,)"},"attn":{"[W_Q, W_K, W_V]":"(8, 512, 64)","W_O":"(8, 64, 512)","[b_Q, b_K, b_V]":"(8, 64)","b_O":"(512,)","mask":"(2048, 2048)","IGNORE":"()","[rotary_sin, rotary_cos]":"(2048, 16)"},"mlp":{"W_in":"(512, 2048)","b_in":"(2048,)","W_out":"(2048, 512)","b_out":"(512,)"}}},"ln_final":{"[w, b]":"(512,)"},"unembed":{"W_U":"(512, 50304)","b_U":"(50304,)"}},"tensor_shapes.activation_cache":"blocks:\n '[0-5]':\n ln1:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 512)\n attn:\n '[hook_q, hook_k, hook_v, hook_rot_q, hook_rot_k, hook_z]': (batch, seq_len,\n 8, 64)\n '[hook_attn_scores, hook_pattern]': (batch, 8, seq_len, seq_len)\n ln2:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 512)\n mlp:\n '[hook_pre, hook_post]': (batch, seq_len, 2048)\n '[hook_resid_pre, hook_attn_out, hook_mlp_out, hook_resid_post]': (batch, seq_len,\n 512)\nln_final:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 512)\nunembed:\n hook_in: (batch, seq_len, 512)\n hook_out: (batch, seq_len, 50304)\nhook_embed: (batch, seq_len, 512)\n","tensor_shapes.activation_cache.raw__":{"blocks":{"[0-5]":{"ln1":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 512)"},"attn":{"[hook_q, hook_k, hook_v, hook_rot_q, hook_rot_k, hook_z]":"(batch, seq_len, 8, 64)","[hook_attn_scores, hook_pattern]":"(batch, 8, seq_len, seq_len)"},"ln2":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 512)"},"mlp":{"[hook_pre, hook_post]":"(batch, seq_len, 2048)"},"[hook_resid_pre, hook_attn_out, hook_mlp_out, hook_resid_post]":"(batch, seq_len, 512)"}},"ln_final":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 512)"},"unembed":{"hook_in":"(batch, seq_len, 512)","hook_out":"(batch, seq_len, 50304)"},"hook_embed":"(batch, seq_len, 512)"}} {"name.default_alias":"pythia-70m-deduped-v0","name.huggingface":"EleutherAI\/pythia-70m-deduped-v0","name.aliases":"pythia-70m-deduped-v0, EleutherAI\/pythia-19m-deduped-v0, pythia-19m-deduped-v0","model_type":"pythia","name.from_cfg":"pythia-70m-deduped-v0","n_params.as_str":"19M","n_params.as_int":18874368,"n_params.from_name":"70m","cfg.n_params":18874368,"cfg.n_layers":6,"cfg.n_heads":8,"cfg.d_model":512,"cfg.d_vocab":50304,"cfg.act_fn":"gelu","cfg.positional_embedding_type":"rotary","cfg.parallel_attn_mlp":true,"cfg.original_architecture":"GPTNeoXForCausalLM","cfg.normalization_type":"LN","config.raw__":{"d_model":512,"d_head":64,"n_layers":6,"n_ctx":2048,"n_heads":8,"d_mlp":2048,"d_vocab":50304,"device":"cpu","use_attn_result":false,"use_split_qkv_input":false,"default_prepend_bos":false,"positional_embedding_type":"rotary","n_key_value_heads":null,"attn_only":false,"gated_mlp":false,"uses_rms_norm":false,"eps":0.00001,"layer_norm_folding":false,"act_fn":"gelu","normalization_type":"LN","num_experts":null,"experts_per_token":null,"final_rms":false,"dtype":"torch.float32","model_name":"pythia-70m-deduped-v0","use_attn_scale":true,"attn_scale":8.0,"use_hook_mlp_in":false,"use_attn_in":false,"use_qk_norm":false,"use_local_attn":false,"ungroup_grouped_query_attention":false,"original_architecture":"GPTNeoXForCausalLM","from_checkpoint":false,"checkpoint_index":null,"checkpoint_label_type":null,"checkpoint_value":null,"tokenizer_name":"EleutherAI\/pythia-70m-deduped-v0","window_size":null,"attn_types":null,"init_mode":"gpt2","n_devices":1,"attention_dir":"causal","seed":null,"initializer_range":0.0353553391,"init_weights":false,"scale_attn_by_inverse_layer_idx":false,"d_vocab_out":50304,"parallel_attn_mlp":true,"rotary_dim":16,"n_params":18874368,"use_hook_tokens":false,"tokenizer_prepends_bos":null,"post_embedding_ln":false,"rotary_base":10000,"rotary_base_local":null,"rotary_scaling_factor":1.0,"trust_remote_code":false,"rotary_adjacent_pairs":false,"load_in_4bit":false,"relative_attention_max_distance":null,"relative_attention_num_buckets":null,"decoder_start_token_id":null,"tie_word_embeddings":false,"use_normalization_before_and_after":false,"attn_scores_soft_cap":-1.0,"output_logits_soft_cap":-1.0,"use_NTK_by_parts_rope":false,"NTK_by_parts_low_freq_factor":1.0,"NTK_by_parts_high_freq_factor":4.0,"NTK_by_parts_factor":8.0,"NTK_original_ctx_len":8192,"use_yarn_rope":false,"yarn_factor":1.0,"yarn_attention_factor":1.0,"yarn_beta_fast":32.0,"yarn_beta_slow":1.0,"yarn_original_max_position_embeddings":4096,"norm_topk_prob":false},"config":"d_model: 512\nd_head: 64\nn_layers: 6\nn_ctx: 2048\nn_heads: 8\nd_mlp: 2048\nd_vocab: 50304\ndevice: cpu\nuse_attn_result: false\nuse_split_qkv_input: false\ndefault_prepend_bos: false\npositional_embedding_type: rotary\nn_key_value_heads: null\nattn_only: false\ngated_mlp: false\nuses_rms_norm: false\neps: 1.0e-05\nlayer_norm_folding: false\nact_fn: gelu\nnormalization_type: LN\nnum_experts: null\nexperts_per_token: null\nfinal_rms: false\ndtype: torch.float32\nmodel_name: pythia-70m-deduped-v0\nuse_attn_scale: true\nattn_scale: !!python\/object\/apply:numpy.core.multiarray.scalar\n- !!python\/object\/apply:numpy.dtype\n args:\n - f8\n - false\n - true\n state: !!python\/tuple\n - 3\n - <\n - null\n - null\n - null\n - -1\n - -1\n - 0\n- !!binary |\n AAAAAAAAIEA=\nuse_hook_mlp_in: false\nuse_attn_in: false\nuse_qk_norm: false\nuse_local_attn: false\nungroup_grouped_query_attention: false\noriginal_architecture: GPTNeoXForCausalLM\nfrom_checkpoint: false\ncheckpoint_index: null\ncheckpoint_label_type: null\ncheckpoint_value: null\ntokenizer_name: EleutherAI\/pythia-70m-deduped-v0\nwindow_size: null\nattn_types: null\ninit_mode: gpt2\nn_devices: 1\nattention_dir: causal\nseed: null\ninitializer_range: 0.035355339059327376\ninit_weights: false\nscale_attn_by_inverse_layer_idx: false\nd_vocab_out: 50304\nparallel_attn_mlp: true\nrotary_dim: 16\nn_params: 18874368\nuse_hook_tokens: false\ntokenizer_prepends_bos: null\npost_embedding_ln: false\nrotary_base: 10000\nrotary_base_local: null\nrotary_scaling_factor: 1.0\ntrust_remote_code: false\nrotary_adjacent_pairs: false\nload_in_4bit: false\nrelative_attention_max_distance: null\nrelative_attention_num_buckets: null\ndecoder_start_token_id: null\ntie_word_embeddings: false\nuse_normalization_before_and_after: false\nattn_scores_soft_cap: -1.0\noutput_logits_soft_cap: -1.0\nuse_NTK_by_parts_rope: false\nNTK_by_parts_low_freq_factor: 1.0\nNTK_by_parts_high_freq_factor: 4.0\nNTK_by_parts_factor: 8.0\nNTK_original_ctx_len: 8192\nuse_yarn_rope: false\nyarn_factor: 1.0\nyarn_attention_factor: 1.0\nyarn_beta_fast: 32.0\nyarn_beta_slow: 1.0\nyarn_original_max_position_embeddings: 4096\nnorm_topk_prob: false\n","tokenizer.name":"EleutherAI\/pythia-70m-deduped-v0","tokenizer.vocab_size":50254.0,"tokenizer.max_len":null,"tokenizer.class":"GPTNeoXTokenizer","tokenizer.vocab_hash":"96EawM8Lij99W7OBTk0KW2ELUrQ=","tensor_shapes.state_dict":"embed:\n W_E: (50304, 512)\nblocks:\n '[0-5]':\n ln1:\n '[w, b]': (512,)\n ln2:\n '[w, b]': (512,)\n attn:\n '[W_Q, W_K, W_V]': (8, 512, 64)\n W_O: (8, 64, 512)\n '[b_Q, b_K, b_V]': (8, 64)\n b_O: (512,)\n mask: (2048, 2048)\n IGNORE: ()\n '[rotary_sin, rotary_cos]': (2048, 16)\n mlp:\n W_in: (512, 2048)\n b_in: (2048,)\n W_out: (2048, 512)\n b_out: (512,)\nln_final:\n '[w, b]': (512,)\nunembed:\n W_U: (512, 50304)\n b_U: (50304,)\n","tensor_shapes.state_dict.raw__":{"embed":{"W_E":"(50304, 512)"},"blocks":{"[0-5]":{"ln1":{"[w, b]":"(512,)"},"ln2":{"[w, b]":"(512,)"},"attn":{"[W_Q, W_K, W_V]":"(8, 512, 64)","W_O":"(8, 64, 512)","[b_Q, b_K, b_V]":"(8, 64)","b_O":"(512,)","mask":"(2048, 2048)","IGNORE":"()","[rotary_sin, rotary_cos]":"(2048, 16)"},"mlp":{"W_in":"(512, 2048)","b_in":"(2048,)","W_out":"(2048, 512)","b_out":"(512,)"}}},"ln_final":{"[w, b]":"(512,)"},"unembed":{"W_U":"(512, 50304)","b_U":"(50304,)"}},"tensor_shapes.activation_cache":"blocks:\n '[0-5]':\n ln1:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 512)\n attn:\n '[hook_q, hook_k, hook_v, hook_rot_q, hook_rot_k, hook_z]': (batch, seq_len,\n 8, 64)\n '[hook_attn_scores, hook_pattern]': (batch, 8, seq_len, seq_len)\n ln2:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 512)\n mlp:\n '[hook_pre, hook_post]': (batch, seq_len, 2048)\n '[hook_resid_pre, hook_attn_out, hook_mlp_out, hook_resid_post]': (batch, seq_len,\n 512)\nln_final:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 512)\nunembed:\n hook_in: (batch, seq_len, 512)\n hook_out: (batch, seq_len, 50304)\nhook_embed: (batch, seq_len, 512)\n","tensor_shapes.activation_cache.raw__":{"blocks":{"[0-5]":{"ln1":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 512)"},"attn":{"[hook_q, hook_k, hook_v, hook_rot_q, hook_rot_k, hook_z]":"(batch, seq_len, 8, 64)","[hook_attn_scores, hook_pattern]":"(batch, 8, seq_len, seq_len)"},"ln2":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 512)"},"mlp":{"[hook_pre, hook_post]":"(batch, seq_len, 2048)"},"[hook_resid_pre, hook_attn_out, hook_mlp_out, hook_resid_post]":"(batch, seq_len, 512)"}},"ln_final":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 512)"},"unembed":{"hook_in":"(batch, seq_len, 512)","hook_out":"(batch, seq_len, 50304)"},"hook_embed":"(batch, seq_len, 512)"}} {"name.default_alias":"pythia-70m-v0","name.huggingface":"EleutherAI\/pythia-70m-v0","name.aliases":"pythia-70m-v0, pythia-v0, EleutherAI\/pythia-19m-v0, pythia-19m-v0","model_type":"pythia","name.from_cfg":"pythia-70m-v0","n_params.as_str":"19M","n_params.as_int":18874368,"n_params.from_name":"70m","cfg.n_params":18874368,"cfg.n_layers":6,"cfg.n_heads":8,"cfg.d_model":512,"cfg.d_vocab":50304,"cfg.act_fn":"gelu","cfg.positional_embedding_type":"rotary","cfg.parallel_attn_mlp":true,"cfg.original_architecture":"GPTNeoXForCausalLM","cfg.normalization_type":"LN","config.raw__":{"d_model":512,"d_head":64,"n_layers":6,"n_ctx":2048,"n_heads":8,"d_mlp":2048,"d_vocab":50304,"device":"cpu","use_attn_result":false,"use_split_qkv_input":false,"default_prepend_bos":false,"positional_embedding_type":"rotary","n_key_value_heads":null,"attn_only":false,"gated_mlp":false,"uses_rms_norm":false,"eps":0.00001,"layer_norm_folding":false,"act_fn":"gelu","normalization_type":"LN","num_experts":null,"experts_per_token":null,"final_rms":false,"dtype":"torch.float32","model_name":"pythia-70m-v0","use_attn_scale":true,"attn_scale":8.0,"use_hook_mlp_in":false,"use_attn_in":false,"use_qk_norm":false,"use_local_attn":false,"ungroup_grouped_query_attention":false,"original_architecture":"GPTNeoXForCausalLM","from_checkpoint":false,"checkpoint_index":null,"checkpoint_label_type":null,"checkpoint_value":null,"tokenizer_name":"EleutherAI\/pythia-70m-v0","window_size":null,"attn_types":null,"init_mode":"gpt2","n_devices":1,"attention_dir":"causal","seed":null,"initializer_range":0.0353553391,"init_weights":false,"scale_attn_by_inverse_layer_idx":false,"d_vocab_out":50304,"parallel_attn_mlp":true,"rotary_dim":16,"n_params":18874368,"use_hook_tokens":false,"tokenizer_prepends_bos":null,"post_embedding_ln":false,"rotary_base":10000,"rotary_base_local":null,"rotary_scaling_factor":1.0,"trust_remote_code":false,"rotary_adjacent_pairs":false,"load_in_4bit":false,"relative_attention_max_distance":null,"relative_attention_num_buckets":null,"decoder_start_token_id":null,"tie_word_embeddings":false,"use_normalization_before_and_after":false,"attn_scores_soft_cap":-1.0,"output_logits_soft_cap":-1.0,"use_NTK_by_parts_rope":false,"NTK_by_parts_low_freq_factor":1.0,"NTK_by_parts_high_freq_factor":4.0,"NTK_by_parts_factor":8.0,"NTK_original_ctx_len":8192,"use_yarn_rope":false,"yarn_factor":1.0,"yarn_attention_factor":1.0,"yarn_beta_fast":32.0,"yarn_beta_slow":1.0,"yarn_original_max_position_embeddings":4096,"norm_topk_prob":false},"config":"d_model: 512\nd_head: 64\nn_layers: 6\nn_ctx: 2048\nn_heads: 8\nd_mlp: 2048\nd_vocab: 50304\ndevice: cpu\nuse_attn_result: false\nuse_split_qkv_input: false\ndefault_prepend_bos: false\npositional_embedding_type: rotary\nn_key_value_heads: null\nattn_only: false\ngated_mlp: false\nuses_rms_norm: false\neps: 1.0e-05\nlayer_norm_folding: false\nact_fn: gelu\nnormalization_type: LN\nnum_experts: null\nexperts_per_token: null\nfinal_rms: false\ndtype: torch.float32\nmodel_name: pythia-70m-v0\nuse_attn_scale: true\nattn_scale: !!python\/object\/apply:numpy.core.multiarray.scalar\n- !!python\/object\/apply:numpy.dtype\n args:\n - f8\n - false\n - true\n state: !!python\/tuple\n - 3\n - <\n - null\n - null\n - null\n - -1\n - -1\n - 0\n- !!binary |\n AAAAAAAAIEA=\nuse_hook_mlp_in: false\nuse_attn_in: false\nuse_qk_norm: false\nuse_local_attn: false\nungroup_grouped_query_attention: false\noriginal_architecture: GPTNeoXForCausalLM\nfrom_checkpoint: false\ncheckpoint_index: null\ncheckpoint_label_type: null\ncheckpoint_value: null\ntokenizer_name: EleutherAI\/pythia-70m-v0\nwindow_size: null\nattn_types: null\ninit_mode: gpt2\nn_devices: 1\nattention_dir: causal\nseed: null\ninitializer_range: 0.035355339059327376\ninit_weights: false\nscale_attn_by_inverse_layer_idx: false\nd_vocab_out: 50304\nparallel_attn_mlp: true\nrotary_dim: 16\nn_params: 18874368\nuse_hook_tokens: false\ntokenizer_prepends_bos: null\npost_embedding_ln: false\nrotary_base: 10000\nrotary_base_local: null\nrotary_scaling_factor: 1.0\ntrust_remote_code: false\nrotary_adjacent_pairs: false\nload_in_4bit: false\nrelative_attention_max_distance: null\nrelative_attention_num_buckets: null\ndecoder_start_token_id: null\ntie_word_embeddings: false\nuse_normalization_before_and_after: false\nattn_scores_soft_cap: -1.0\noutput_logits_soft_cap: -1.0\nuse_NTK_by_parts_rope: false\nNTK_by_parts_low_freq_factor: 1.0\nNTK_by_parts_high_freq_factor: 4.0\nNTK_by_parts_factor: 8.0\nNTK_original_ctx_len: 8192\nuse_yarn_rope: false\nyarn_factor: 1.0\nyarn_attention_factor: 1.0\nyarn_beta_fast: 32.0\nyarn_beta_slow: 1.0\nyarn_original_max_position_embeddings: 4096\nnorm_topk_prob: false\n","tokenizer.name":"EleutherAI\/pythia-70m-v0","tokenizer.vocab_size":50254.0,"tokenizer.max_len":null,"tokenizer.class":"GPTNeoXTokenizer","tokenizer.vocab_hash":"96EawM8Lij99W7OBTk0KW2ELUrQ=","tensor_shapes.state_dict":"embed:\n W_E: (50304, 512)\nblocks:\n '[0-5]':\n ln1:\n '[w, b]': (512,)\n ln2:\n '[w, b]': (512,)\n attn:\n '[W_Q, W_K, W_V]': (8, 512, 64)\n W_O: (8, 64, 512)\n '[b_Q, b_K, b_V]': (8, 64)\n b_O: (512,)\n mask: (2048, 2048)\n IGNORE: ()\n '[rotary_sin, rotary_cos]': (2048, 16)\n mlp:\n W_in: (512, 2048)\n b_in: (2048,)\n W_out: (2048, 512)\n b_out: (512,)\nln_final:\n '[w, b]': (512,)\nunembed:\n W_U: (512, 50304)\n b_U: (50304,)\n","tensor_shapes.state_dict.raw__":{"embed":{"W_E":"(50304, 512)"},"blocks":{"[0-5]":{"ln1":{"[w, b]":"(512,)"},"ln2":{"[w, b]":"(512,)"},"attn":{"[W_Q, W_K, W_V]":"(8, 512, 64)","W_O":"(8, 64, 512)","[b_Q, b_K, b_V]":"(8, 64)","b_O":"(512,)","mask":"(2048, 2048)","IGNORE":"()","[rotary_sin, rotary_cos]":"(2048, 16)"},"mlp":{"W_in":"(512, 2048)","b_in":"(2048,)","W_out":"(2048, 512)","b_out":"(512,)"}}},"ln_final":{"[w, b]":"(512,)"},"unembed":{"W_U":"(512, 50304)","b_U":"(50304,)"}},"tensor_shapes.activation_cache":"blocks:\n '[0-5]':\n ln1:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 512)\n attn:\n '[hook_q, hook_k, hook_v, hook_rot_q, hook_rot_k, hook_z]': (batch, seq_len,\n 8, 64)\n '[hook_attn_scores, hook_pattern]': (batch, 8, seq_len, seq_len)\n ln2:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 512)\n mlp:\n '[hook_pre, hook_post]': (batch, seq_len, 2048)\n '[hook_resid_pre, hook_attn_out, hook_mlp_out, hook_resid_post]': (batch, seq_len,\n 512)\nln_final:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 512)\nunembed:\n hook_in: (batch, seq_len, 512)\n hook_out: (batch, seq_len, 50304)\nhook_embed: (batch, seq_len, 512)\n","tensor_shapes.activation_cache.raw__":{"blocks":{"[0-5]":{"ln1":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 512)"},"attn":{"[hook_q, hook_k, hook_v, hook_rot_q, hook_rot_k, hook_z]":"(batch, seq_len, 8, 64)","[hook_attn_scores, hook_pattern]":"(batch, 8, seq_len, seq_len)"},"ln2":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 512)"},"mlp":{"[hook_pre, hook_post]":"(batch, seq_len, 2048)"},"[hook_resid_pre, hook_attn_out, hook_mlp_out, hook_resid_post]":"(batch, seq_len, 512)"}},"ln_final":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 512)"},"unembed":{"hook_in":"(batch, seq_len, 512)","hook_out":"(batch, seq_len, 50304)"},"hook_embed":"(batch, seq_len, 512)"}} @@ -118,7 +118,7 @@ {"name.default_alias":"Llama-2-13b-chat","name.huggingface":null,"name.aliases":"","model_type":"Llama-2","name.from_cfg":"Llama-2-13b-chat-hf","n_params.as_str":"13B","n_params.as_int":12687769600,"n_params.from_name":"13b","cfg.n_params":12687769600,"cfg.n_layers":40,"cfg.n_heads":40,"cfg.d_model":5120,"cfg.d_vocab":32000,"cfg.act_fn":"silu","cfg.positional_embedding_type":"rotary","cfg.parallel_attn_mlp":false,"cfg.original_architecture":"LlamaForCausalLM","cfg.normalization_type":"RMS","config.raw__":{"d_model":5120,"d_head":128,"n_layers":40,"n_ctx":4096,"n_heads":40,"d_mlp":13824,"d_vocab":32000,"device":"cpu","use_attn_result":false,"use_split_qkv_input":false,"default_prepend_bos":true,"positional_embedding_type":"rotary","n_key_value_heads":null,"attn_only":false,"gated_mlp":true,"uses_rms_norm":false,"eps":0.00001,"layer_norm_folding":false,"act_fn":"silu","normalization_type":"RMS","num_experts":null,"experts_per_token":null,"final_rms":true,"dtype":"torch.float32","model_name":"Llama-2-13b-chat-hf","use_attn_scale":true,"attn_scale":11.313708499,"use_hook_mlp_in":false,"use_attn_in":false,"use_qk_norm":false,"use_local_attn":false,"ungroup_grouped_query_attention":false,"original_architecture":"LlamaForCausalLM","from_checkpoint":false,"checkpoint_index":null,"checkpoint_label_type":null,"checkpoint_value":null,"tokenizer_name":"meta-llama\/Llama-2-13b-chat-hf","window_size":null,"attn_types":null,"init_mode":"gpt2","n_devices":1,"attention_dir":"causal","seed":null,"initializer_range":0.0111803399,"init_weights":false,"scale_attn_by_inverse_layer_idx":false,"d_vocab_out":32000,"parallel_attn_mlp":false,"rotary_dim":128,"n_params":12687769600,"use_hook_tokens":false,"tokenizer_prepends_bos":null,"post_embedding_ln":false,"rotary_base":10000,"rotary_base_local":null,"rotary_scaling_factor":1.0,"trust_remote_code":false,"rotary_adjacent_pairs":false,"load_in_4bit":false,"relative_attention_max_distance":null,"relative_attention_num_buckets":null,"decoder_start_token_id":null,"tie_word_embeddings":false,"use_normalization_before_and_after":false,"attn_scores_soft_cap":-1.0,"output_logits_soft_cap":-1.0,"use_NTK_by_parts_rope":false,"NTK_by_parts_low_freq_factor":1.0,"NTK_by_parts_high_freq_factor":4.0,"NTK_by_parts_factor":8.0,"NTK_original_ctx_len":8192,"use_yarn_rope":false,"yarn_factor":1.0,"yarn_attention_factor":1.0,"yarn_beta_fast":32.0,"yarn_beta_slow":1.0,"yarn_original_max_position_embeddings":4096,"norm_topk_prob":false},"config":"d_model: 5120\nd_head: 128\nn_layers: 40\nn_ctx: 4096\nn_heads: 40\nd_mlp: 13824\nd_vocab: 32000\ndevice: cpu\nuse_attn_result: false\nuse_split_qkv_input: false\ndefault_prepend_bos: true\npositional_embedding_type: rotary\nn_key_value_heads: null\nattn_only: false\ngated_mlp: true\nuses_rms_norm: false\neps: 1.0e-05\nlayer_norm_folding: false\nact_fn: silu\nnormalization_type: RMS\nnum_experts: null\nexperts_per_token: null\nfinal_rms: true\ndtype: torch.float32\nmodel_name: Llama-2-13b-chat-hf\nuse_attn_scale: true\nattn_scale: !!python\/object\/apply:numpy.core.multiarray.scalar\n- !!python\/object\/apply:numpy.dtype\n args:\n - f8\n - false\n - true\n state: !!python\/tuple\n - 3\n - <\n - null\n - null\n - null\n - -1\n - -1\n - 0\n- !!binary |\n zTt\/Zp6gJkA=\nuse_hook_mlp_in: false\nuse_attn_in: false\nuse_qk_norm: false\nuse_local_attn: false\nungroup_grouped_query_attention: false\noriginal_architecture: LlamaForCausalLM\nfrom_checkpoint: false\ncheckpoint_index: null\ncheckpoint_label_type: null\ncheckpoint_value: null\ntokenizer_name: meta-llama\/Llama-2-13b-chat-hf\nwindow_size: null\nattn_types: null\ninit_mode: gpt2\nn_devices: 1\nattention_dir: causal\nseed: null\ninitializer_range: 0.011180339887498949\ninit_weights: false\nscale_attn_by_inverse_layer_idx: false\nd_vocab_out: 32000\nparallel_attn_mlp: false\nrotary_dim: 128\nn_params: 12687769600\nuse_hook_tokens: false\ntokenizer_prepends_bos: null\npost_embedding_ln: false\nrotary_base: 10000\nrotary_base_local: null\nrotary_scaling_factor: 1.0\ntrust_remote_code: false\nrotary_adjacent_pairs: false\nload_in_4bit: false\nrelative_attention_max_distance: null\nrelative_attention_num_buckets: null\ndecoder_start_token_id: null\ntie_word_embeddings: false\nuse_normalization_before_and_after: false\nattn_scores_soft_cap: -1.0\noutput_logits_soft_cap: -1.0\nuse_NTK_by_parts_rope: false\nNTK_by_parts_low_freq_factor: 1.0\nNTK_by_parts_high_freq_factor: 4.0\nNTK_by_parts_factor: 8.0\nNTK_original_ctx_len: 8192\nuse_yarn_rope: false\nyarn_factor: 1.0\nyarn_attention_factor: 1.0\nyarn_beta_fast: 32.0\nyarn_beta_slow: 1.0\nyarn_original_max_position_embeddings: 4096\nnorm_topk_prob: false\n","tokenizer.name":"meta-llama\/Llama-2-13b-chat-hf","tokenizer.vocab_size":32000.0,"tokenizer.max_len":null,"tokenizer.class":"TokenizersBackend","tokenizer.vocab_hash":"e3A7wYziNQPAWcJ15GMAQY8qZqw=","tensor_shapes.state_dict":"embed:\n W_E: (32000, 5120)\nblocks:\n '[0-39]':\n ln1:\n w: (5120,)\n ln2:\n w: (5120,)\n attn:\n '[W_Q, W_K, W_V]': (40, 5120, 128)\n W_O: (40, 128, 5120)\n '[b_Q, b_K, b_V]': (40, 128)\n b_O: (5120,)\n mask: (4096, 4096)\n IGNORE: ()\n '[rotary_sin, rotary_cos]': (4096, 128)\n mlp:\n '[W_in, W_gate]': (5120, 13824)\n W_out: (13824, 5120)\n b_in: (13824,)\n b_out: (5120,)\nln_final:\n w: (5120,)\nunembed:\n W_U: (5120, 32000)\n b_U: (32000,)\n","tensor_shapes.state_dict.raw__":{"embed":{"W_E":"(32000, 5120)"},"blocks":{"[0-39]":{"ln1":{"w":"(5120,)"},"ln2":{"w":"(5120,)"},"attn":{"[W_Q, W_K, W_V]":"(40, 5120, 128)","W_O":"(40, 128, 5120)","[b_Q, b_K, b_V]":"(40, 128)","b_O":"(5120,)","mask":"(4096, 4096)","IGNORE":"()","[rotary_sin, rotary_cos]":"(4096, 128)"},"mlp":{"[W_in, W_gate]":"(5120, 13824)","W_out":"(13824, 5120)","b_in":"(13824,)","b_out":"(5120,)"}}},"ln_final":{"w":"(5120,)"},"unembed":{"W_U":"(5120, 32000)","b_U":"(32000,)"}},"tensor_shapes.activation_cache":"blocks:\n '[0-39]':\n ln1:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 5120)\n attn:\n '[hook_q, hook_k, hook_v, hook_rot_q, hook_rot_k, hook_z]': (batch, seq_len,\n 40, 128)\n '[hook_attn_scores, hook_pattern]': (batch, 40, seq_len, seq_len)\n ln2:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 5120)\n mlp:\n '[hook_pre, hook_pre_linear, hook_post]': (batch, seq_len, 13824)\n '[hook_resid_pre, hook_attn_out, hook_resid_mid, hook_mlp_out, hook_resid_post]': (batch,\n seq_len, 5120)\nln_final:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 5120)\nunembed:\n hook_in: (batch, seq_len, 5120)\n hook_out: (batch, seq_len, 32000)\nhook_embed: (batch, seq_len, 5120)\n","tensor_shapes.activation_cache.raw__":{"blocks":{"[0-39]":{"ln1":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 5120)"},"attn":{"[hook_q, hook_k, hook_v, hook_rot_q, hook_rot_k, hook_z]":"(batch, seq_len, 40, 128)","[hook_attn_scores, hook_pattern]":"(batch, 40, seq_len, seq_len)"},"ln2":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 5120)"},"mlp":{"[hook_pre, hook_pre_linear, hook_post]":"(batch, seq_len, 13824)"},"[hook_resid_pre, hook_attn_out, hook_resid_mid, hook_mlp_out, hook_resid_post]":"(batch, seq_len, 5120)"}},"ln_final":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 5120)"},"unembed":{"hook_in":"(batch, seq_len, 5120)","hook_out":"(batch, seq_len, 32000)"},"hook_embed":"(batch, seq_len, 5120)"}} {"name.default_alias":"Llama-2-13b","name.huggingface":null,"name.aliases":"","model_type":"Llama-2","name.from_cfg":"Llama-2-13b-hf","n_params.as_str":"13B","n_params.as_int":12687769600,"n_params.from_name":"13b","cfg.n_params":12687769600,"cfg.n_layers":40,"cfg.n_heads":40,"cfg.d_model":5120,"cfg.d_vocab":32000,"cfg.act_fn":"silu","cfg.positional_embedding_type":"rotary","cfg.parallel_attn_mlp":false,"cfg.original_architecture":"LlamaForCausalLM","cfg.normalization_type":"RMS","config.raw__":{"d_model":5120,"d_head":128,"n_layers":40,"n_ctx":4096,"n_heads":40,"d_mlp":13824,"d_vocab":32000,"device":"cpu","use_attn_result":false,"use_split_qkv_input":false,"default_prepend_bos":true,"positional_embedding_type":"rotary","n_key_value_heads":null,"attn_only":false,"gated_mlp":true,"uses_rms_norm":false,"eps":0.00001,"layer_norm_folding":false,"act_fn":"silu","normalization_type":"RMS","num_experts":null,"experts_per_token":null,"final_rms":true,"dtype":"torch.float32","model_name":"Llama-2-13b-hf","use_attn_scale":true,"attn_scale":11.313708499,"use_hook_mlp_in":false,"use_attn_in":false,"use_qk_norm":false,"use_local_attn":false,"ungroup_grouped_query_attention":false,"original_architecture":"LlamaForCausalLM","from_checkpoint":false,"checkpoint_index":null,"checkpoint_label_type":null,"checkpoint_value":null,"tokenizer_name":"meta-llama\/Llama-2-13b-hf","window_size":null,"attn_types":null,"init_mode":"gpt2","n_devices":1,"attention_dir":"causal","seed":null,"initializer_range":0.0111803399,"init_weights":false,"scale_attn_by_inverse_layer_idx":false,"d_vocab_out":32000,"parallel_attn_mlp":false,"rotary_dim":128,"n_params":12687769600,"use_hook_tokens":false,"tokenizer_prepends_bos":null,"post_embedding_ln":false,"rotary_base":10000,"rotary_base_local":null,"rotary_scaling_factor":1.0,"trust_remote_code":false,"rotary_adjacent_pairs":false,"load_in_4bit":false,"relative_attention_max_distance":null,"relative_attention_num_buckets":null,"decoder_start_token_id":null,"tie_word_embeddings":false,"use_normalization_before_and_after":false,"attn_scores_soft_cap":-1.0,"output_logits_soft_cap":-1.0,"use_NTK_by_parts_rope":false,"NTK_by_parts_low_freq_factor":1.0,"NTK_by_parts_high_freq_factor":4.0,"NTK_by_parts_factor":8.0,"NTK_original_ctx_len":8192,"use_yarn_rope":false,"yarn_factor":1.0,"yarn_attention_factor":1.0,"yarn_beta_fast":32.0,"yarn_beta_slow":1.0,"yarn_original_max_position_embeddings":4096,"norm_topk_prob":false},"config":"d_model: 5120\nd_head: 128\nn_layers: 40\nn_ctx: 4096\nn_heads: 40\nd_mlp: 13824\nd_vocab: 32000\ndevice: cpu\nuse_attn_result: false\nuse_split_qkv_input: false\ndefault_prepend_bos: true\npositional_embedding_type: rotary\nn_key_value_heads: null\nattn_only: false\ngated_mlp: true\nuses_rms_norm: false\neps: 1.0e-05\nlayer_norm_folding: false\nact_fn: silu\nnormalization_type: RMS\nnum_experts: null\nexperts_per_token: null\nfinal_rms: true\ndtype: torch.float32\nmodel_name: Llama-2-13b-hf\nuse_attn_scale: true\nattn_scale: !!python\/object\/apply:numpy.core.multiarray.scalar\n- !!python\/object\/apply:numpy.dtype\n args:\n - f8\n - false\n - true\n state: !!python\/tuple\n - 3\n - <\n - null\n - null\n - null\n - -1\n - -1\n - 0\n- !!binary |\n zTt\/Zp6gJkA=\nuse_hook_mlp_in: false\nuse_attn_in: false\nuse_qk_norm: false\nuse_local_attn: false\nungroup_grouped_query_attention: false\noriginal_architecture: LlamaForCausalLM\nfrom_checkpoint: false\ncheckpoint_index: null\ncheckpoint_label_type: null\ncheckpoint_value: null\ntokenizer_name: meta-llama\/Llama-2-13b-hf\nwindow_size: null\nattn_types: null\ninit_mode: gpt2\nn_devices: 1\nattention_dir: causal\nseed: null\ninitializer_range: 0.011180339887498949\ninit_weights: false\nscale_attn_by_inverse_layer_idx: false\nd_vocab_out: 32000\nparallel_attn_mlp: false\nrotary_dim: 128\nn_params: 12687769600\nuse_hook_tokens: false\ntokenizer_prepends_bos: null\npost_embedding_ln: false\nrotary_base: 10000\nrotary_base_local: null\nrotary_scaling_factor: 1.0\ntrust_remote_code: false\nrotary_adjacent_pairs: false\nload_in_4bit: false\nrelative_attention_max_distance: null\nrelative_attention_num_buckets: null\ndecoder_start_token_id: null\ntie_word_embeddings: false\nuse_normalization_before_and_after: false\nattn_scores_soft_cap: -1.0\noutput_logits_soft_cap: -1.0\nuse_NTK_by_parts_rope: false\nNTK_by_parts_low_freq_factor: 1.0\nNTK_by_parts_high_freq_factor: 4.0\nNTK_by_parts_factor: 8.0\nNTK_original_ctx_len: 8192\nuse_yarn_rope: false\nyarn_factor: 1.0\nyarn_attention_factor: 1.0\nyarn_beta_fast: 32.0\nyarn_beta_slow: 1.0\nyarn_original_max_position_embeddings: 4096\nnorm_topk_prob: false\n","tokenizer.name":"meta-llama\/Llama-2-13b-hf","tokenizer.vocab_size":32000.0,"tokenizer.max_len":null,"tokenizer.class":"TokenizersBackend","tokenizer.vocab_hash":"e3A7wYziNQPAWcJ15GMAQY8qZqw=","tensor_shapes.state_dict":"embed:\n W_E: (32000, 5120)\nblocks:\n '[0-39]':\n ln1:\n w: (5120,)\n ln2:\n w: (5120,)\n attn:\n '[W_Q, W_K, W_V]': (40, 5120, 128)\n W_O: (40, 128, 5120)\n '[b_Q, b_K, b_V]': (40, 128)\n b_O: (5120,)\n mask: (4096, 4096)\n IGNORE: ()\n '[rotary_sin, rotary_cos]': (4096, 128)\n mlp:\n '[W_in, W_gate]': (5120, 13824)\n W_out: (13824, 5120)\n b_in: (13824,)\n b_out: (5120,)\nln_final:\n w: (5120,)\nunembed:\n W_U: (5120, 32000)\n b_U: (32000,)\n","tensor_shapes.state_dict.raw__":{"embed":{"W_E":"(32000, 5120)"},"blocks":{"[0-39]":{"ln1":{"w":"(5120,)"},"ln2":{"w":"(5120,)"},"attn":{"[W_Q, W_K, W_V]":"(40, 5120, 128)","W_O":"(40, 128, 5120)","[b_Q, b_K, b_V]":"(40, 128)","b_O":"(5120,)","mask":"(4096, 4096)","IGNORE":"()","[rotary_sin, rotary_cos]":"(4096, 128)"},"mlp":{"[W_in, W_gate]":"(5120, 13824)","W_out":"(13824, 5120)","b_in":"(13824,)","b_out":"(5120,)"}}},"ln_final":{"w":"(5120,)"},"unembed":{"W_U":"(5120, 32000)","b_U":"(32000,)"}},"tensor_shapes.activation_cache":"blocks:\n '[0-39]':\n ln1:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 5120)\n attn:\n '[hook_q, hook_k, hook_v, hook_rot_q, hook_rot_k, hook_z]': (batch, seq_len,\n 40, 128)\n '[hook_attn_scores, hook_pattern]': (batch, 40, seq_len, seq_len)\n ln2:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 5120)\n mlp:\n '[hook_pre, hook_pre_linear, hook_post]': (batch, seq_len, 13824)\n '[hook_resid_pre, hook_attn_out, hook_resid_mid, hook_mlp_out, hook_resid_post]': (batch,\n seq_len, 5120)\nln_final:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 5120)\nunembed:\n hook_in: (batch, seq_len, 5120)\n hook_out: (batch, seq_len, 32000)\nhook_embed: (batch, seq_len, 5120)\n","tensor_shapes.activation_cache.raw__":{"blocks":{"[0-39]":{"ln1":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 5120)"},"attn":{"[hook_q, hook_k, hook_v, hook_rot_q, hook_rot_k, hook_z]":"(batch, seq_len, 40, 128)","[hook_attn_scores, hook_pattern]":"(batch, 40, seq_len, seq_len)"},"ln2":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 5120)"},"mlp":{"[hook_pre, hook_pre_linear, hook_post]":"(batch, seq_len, 13824)"},"[hook_resid_pre, hook_attn_out, hook_resid_mid, hook_mlp_out, hook_resid_post]":"(batch, seq_len, 5120)"}},"ln_final":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 5120)"},"unembed":{"hook_in":"(batch, seq_len, 5120)","hook_out":"(batch, seq_len, 32000)"},"hook_embed":"(batch, seq_len, 5120)"}} {"name.default_alias":"Llama-2-70b-chat","name.huggingface":null,"name.aliases":"","model_type":"Llama-2","name.from_cfg":"Llama-2-70b-chat-hf","n_params.as_str":"78B","n_params.as_int":77846282240,"n_params.from_name":"70b","cfg.n_params":77846282240,"cfg.n_layers":80,"cfg.n_heads":64,"cfg.d_model":8192,"cfg.d_vocab":32000,"cfg.act_fn":"silu","cfg.positional_embedding_type":"rotary","cfg.parallel_attn_mlp":false,"cfg.original_architecture":"LlamaForCausalLM","cfg.normalization_type":"RMS","config.raw__":{"d_model":8192,"d_head":128,"n_layers":80,"n_ctx":4096,"n_heads":64,"d_mlp":28672,"d_vocab":32000,"device":"cpu","use_attn_result":false,"use_split_qkv_input":false,"default_prepend_bos":true,"positional_embedding_type":"rotary","n_key_value_heads":8,"attn_only":false,"gated_mlp":true,"uses_rms_norm":false,"eps":0.00001,"layer_norm_folding":false,"act_fn":"silu","normalization_type":"RMS","num_experts":null,"experts_per_token":null,"final_rms":true,"dtype":"torch.float32","model_name":"Llama-2-70b-chat-hf","use_attn_scale":true,"attn_scale":11.313708499,"use_hook_mlp_in":false,"use_attn_in":false,"use_qk_norm":false,"use_local_attn":false,"ungroup_grouped_query_attention":false,"original_architecture":"LlamaForCausalLM","from_checkpoint":false,"checkpoint_index":null,"checkpoint_label_type":null,"checkpoint_value":null,"tokenizer_name":"meta-llama\/Llama-2-70b-chat-hf","window_size":null,"attn_types":null,"init_mode":"gpt2","n_devices":1,"attention_dir":"causal","seed":null,"initializer_range":0.0088388348,"init_weights":false,"scale_attn_by_inverse_layer_idx":false,"d_vocab_out":32000,"parallel_attn_mlp":false,"rotary_dim":128,"n_params":77846282240,"use_hook_tokens":false,"tokenizer_prepends_bos":null,"post_embedding_ln":false,"rotary_base":10000,"rotary_base_local":null,"rotary_scaling_factor":1.0,"trust_remote_code":false,"rotary_adjacent_pairs":false,"load_in_4bit":false,"relative_attention_max_distance":null,"relative_attention_num_buckets":null,"decoder_start_token_id":null,"tie_word_embeddings":false,"use_normalization_before_and_after":false,"attn_scores_soft_cap":-1.0,"output_logits_soft_cap":-1.0,"use_NTK_by_parts_rope":false,"NTK_by_parts_low_freq_factor":1.0,"NTK_by_parts_high_freq_factor":4.0,"NTK_by_parts_factor":8.0,"NTK_original_ctx_len":8192,"use_yarn_rope":false,"yarn_factor":1.0,"yarn_attention_factor":1.0,"yarn_beta_fast":32.0,"yarn_beta_slow":1.0,"yarn_original_max_position_embeddings":4096,"norm_topk_prob":false},"config":"d_model: 8192\nd_head: 128\nn_layers: 80\nn_ctx: 4096\nn_heads: 64\nd_mlp: 28672\nd_vocab: 32000\ndevice: cpu\nuse_attn_result: false\nuse_split_qkv_input: false\ndefault_prepend_bos: true\npositional_embedding_type: rotary\nn_key_value_heads: 8\nattn_only: false\ngated_mlp: true\nuses_rms_norm: false\neps: 1.0e-05\nlayer_norm_folding: false\nact_fn: silu\nnormalization_type: RMS\nnum_experts: null\nexperts_per_token: null\nfinal_rms: true\ndtype: torch.float32\nmodel_name: Llama-2-70b-chat-hf\nuse_attn_scale: true\nattn_scale: !!python\/object\/apply:numpy.core.multiarray.scalar\n- !!python\/object\/apply:numpy.dtype\n args:\n - f8\n - false\n - true\n state: !!python\/tuple\n - 3\n - <\n - null\n - null\n - null\n - -1\n - -1\n - 0\n- !!binary |\n zTt\/Zp6gJkA=\nuse_hook_mlp_in: false\nuse_attn_in: false\nuse_qk_norm: false\nuse_local_attn: false\nungroup_grouped_query_attention: false\noriginal_architecture: LlamaForCausalLM\nfrom_checkpoint: false\ncheckpoint_index: null\ncheckpoint_label_type: null\ncheckpoint_value: null\ntokenizer_name: meta-llama\/Llama-2-70b-chat-hf\nwindow_size: null\nattn_types: null\ninit_mode: gpt2\nn_devices: 1\nattention_dir: causal\nseed: null\ninitializer_range: 0.008838834764831844\ninit_weights: false\nscale_attn_by_inverse_layer_idx: false\nd_vocab_out: 32000\nparallel_attn_mlp: false\nrotary_dim: 128\nn_params: 77846282240\nuse_hook_tokens: false\ntokenizer_prepends_bos: null\npost_embedding_ln: false\nrotary_base: 10000\nrotary_base_local: null\nrotary_scaling_factor: 1.0\ntrust_remote_code: false\nrotary_adjacent_pairs: false\nload_in_4bit: false\nrelative_attention_max_distance: null\nrelative_attention_num_buckets: null\ndecoder_start_token_id: null\ntie_word_embeddings: false\nuse_normalization_before_and_after: false\nattn_scores_soft_cap: -1.0\noutput_logits_soft_cap: -1.0\nuse_NTK_by_parts_rope: false\nNTK_by_parts_low_freq_factor: 1.0\nNTK_by_parts_high_freq_factor: 4.0\nNTK_by_parts_factor: 8.0\nNTK_original_ctx_len: 8192\nuse_yarn_rope: false\nyarn_factor: 1.0\nyarn_attention_factor: 1.0\nyarn_beta_fast: 32.0\nyarn_beta_slow: 1.0\nyarn_original_max_position_embeddings: 4096\nnorm_topk_prob: false\n","tokenizer.name":"meta-llama\/Llama-2-70b-chat-hf","tokenizer.vocab_size":32000.0,"tokenizer.max_len":null,"tokenizer.class":"TokenizersBackend","tokenizer.vocab_hash":"e3A7wYziNQPAWcJ15GMAQY8qZqw=","tensor_shapes.state_dict":"embed:\n W_E: (32000, 8192)\nblocks:\n '[0-79]':\n ln1:\n w: (8192,)\n ln2:\n w: (8192,)\n attn:\n W_Q: (64, 8192, 128)\n W_O: (64, 128, 8192)\n b_Q: (64, 128)\n b_O: (8192,)\n '[_W_K, _W_V]': (8, 8192, 128)\n '[_b_K, _b_V]': (8, 128)\n mask: (4096, 4096)\n IGNORE: ()\n '[rotary_sin, rotary_cos]': (4096, 128)\n mlp:\n '[W_in, W_gate]': (8192, 28672)\n W_out: (28672, 8192)\n b_in: (28672,)\n b_out: (8192,)\nln_final:\n w: (8192,)\nunembed:\n W_U: (8192, 32000)\n b_U: (32000,)\n","tensor_shapes.state_dict.raw__":{"embed":{"W_E":"(32000, 8192)"},"blocks":{"[0-79]":{"ln1":{"w":"(8192,)"},"ln2":{"w":"(8192,)"},"attn":{"W_Q":"(64, 8192, 128)","W_O":"(64, 128, 8192)","b_Q":"(64, 128)","b_O":"(8192,)","[_W_K, _W_V]":"(8, 8192, 128)","[_b_K, _b_V]":"(8, 128)","mask":"(4096, 4096)","IGNORE":"()","[rotary_sin, rotary_cos]":"(4096, 128)"},"mlp":{"[W_in, W_gate]":"(8192, 28672)","W_out":"(28672, 8192)","b_in":"(28672,)","b_out":"(8192,)"}}},"ln_final":{"w":"(8192,)"},"unembed":{"W_U":"(8192, 32000)","b_U":"(32000,)"}},"tensor_shapes.activation_cache":"blocks:\n '[0-79]':\n ln1:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 8192)\n attn:\n '[hook_q, hook_rot_q, hook_z]': (batch, seq_len, 64, 128)\n '[hook_k, hook_v, hook_rot_k]': (batch, seq_len, 8, 128)\n '[hook_attn_scores, hook_pattern]': (batch, 64, seq_len, seq_len)\n ln2:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 8192)\n mlp:\n '[hook_pre, hook_pre_linear, hook_post]': (batch, seq_len, 28672)\n '[hook_resid_pre, hook_attn_out, hook_resid_mid, hook_mlp_out, hook_resid_post]': (batch,\n seq_len, 8192)\nln_final:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 8192)\nunembed:\n hook_in: (batch, seq_len, 8192)\n hook_out: (batch, seq_len, 32000)\nhook_embed: (batch, seq_len, 8192)\n","tensor_shapes.activation_cache.raw__":{"blocks":{"[0-79]":{"ln1":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 8192)"},"attn":{"[hook_q, hook_rot_q, hook_z]":"(batch, seq_len, 64, 128)","[hook_k, hook_v, hook_rot_k]":"(batch, seq_len, 8, 128)","[hook_attn_scores, hook_pattern]":"(batch, 64, seq_len, seq_len)"},"ln2":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 8192)"},"mlp":{"[hook_pre, hook_pre_linear, hook_post]":"(batch, seq_len, 28672)"},"[hook_resid_pre, hook_attn_out, hook_resid_mid, hook_mlp_out, hook_resid_post]":"(batch, seq_len, 8192)"}},"ln_final":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 8192)"},"unembed":{"hook_in":"(batch, seq_len, 8192)","hook_out":"(batch, seq_len, 32000)"},"hook_embed":"(batch, seq_len, 8192)"}} -{"name.default_alias":"Llama-2-7b-chat","name.huggingface":null,"name.aliases":"","model_type":"Llama-2","name.from_cfg":"Llama-2-7b-chat-hf","n_params.as_str":"6.5B","n_params.as_int":6476005376,"n_params.from_name":"7b","cfg.n_params":6476005376,"cfg.n_layers":32,"cfg.n_heads":32,"cfg.d_model":4096,"cfg.d_vocab":32000,"cfg.act_fn":"silu","cfg.positional_embedding_type":"rotary","cfg.parallel_attn_mlp":false,"cfg.original_architecture":"LlamaForCausalLM","cfg.normalization_type":"RMS","config.raw__":{"d_model":4096,"d_head":128,"n_layers":32,"n_ctx":4096,"n_heads":32,"d_mlp":11008,"d_vocab":32000,"device":"cpu","use_attn_result":false,"use_split_qkv_input":false,"default_prepend_bos":true,"positional_embedding_type":"rotary","n_key_value_heads":null,"attn_only":false,"gated_mlp":true,"uses_rms_norm":false,"eps":0.00001,"layer_norm_folding":false,"act_fn":"silu","normalization_type":"RMS","num_experts":null,"experts_per_token":null,"final_rms":true,"dtype":"torch.float32","model_name":"Llama-2-7b-chat-hf","use_attn_scale":true,"attn_scale":11.313708499,"use_hook_mlp_in":false,"use_attn_in":false,"use_qk_norm":false,"use_local_attn":false,"ungroup_grouped_query_attention":false,"original_architecture":"LlamaForCausalLM","from_checkpoint":false,"checkpoint_index":null,"checkpoint_label_type":null,"checkpoint_value":null,"tokenizer_name":"meta-llama\/Llama-2-7b-chat-hf","window_size":null,"attn_types":null,"init_mode":"gpt2","n_devices":1,"attention_dir":"causal","seed":null,"initializer_range":0.0125,"init_weights":false,"scale_attn_by_inverse_layer_idx":false,"d_vocab_out":32000,"parallel_attn_mlp":false,"rotary_dim":128,"n_params":6476005376,"use_hook_tokens":false,"tokenizer_prepends_bos":null,"post_embedding_ln":false,"rotary_base":10000,"rotary_base_local":null,"rotary_scaling_factor":1.0,"trust_remote_code":false,"rotary_adjacent_pairs":false,"load_in_4bit":false,"relative_attention_max_distance":null,"relative_attention_num_buckets":null,"decoder_start_token_id":null,"tie_word_embeddings":false,"use_normalization_before_and_after":false,"attn_scores_soft_cap":-1.0,"output_logits_soft_cap":-1.0,"use_NTK_by_parts_rope":false,"NTK_by_parts_low_freq_factor":1.0,"NTK_by_parts_high_freq_factor":4.0,"NTK_by_parts_factor":8.0,"NTK_original_ctx_len":8192,"use_yarn_rope":false,"yarn_factor":1.0,"yarn_attention_factor":1.0,"yarn_beta_fast":32.0,"yarn_beta_slow":1.0,"yarn_original_max_position_embeddings":4096,"norm_topk_prob":false},"config":"d_model: 4096\nd_head: 128\nn_layers: 32\nn_ctx: 4096\nn_heads: 32\nd_mlp: 11008\nd_vocab: 32000\ndevice: cpu\nuse_attn_result: false\nuse_split_qkv_input: false\ndefault_prepend_bos: true\npositional_embedding_type: rotary\nn_key_value_heads: null\nattn_only: false\ngated_mlp: true\nuses_rms_norm: false\neps: 1.0e-05\nlayer_norm_folding: false\nact_fn: silu\nnormalization_type: RMS\nnum_experts: null\nexperts_per_token: null\nfinal_rms: true\ndtype: torch.float32\nmodel_name: Llama-2-7b-chat-hf\nuse_attn_scale: true\nattn_scale: !!python\/object\/apply:numpy.core.multiarray.scalar\n- !!python\/object\/apply:numpy.dtype\n args:\n - f8\n - false\n - true\n state: !!python\/tuple\n - 3\n - <\n - null\n - null\n - null\n - -1\n - -1\n - 0\n- !!binary |\n zTt\/Zp6gJkA=\nuse_hook_mlp_in: false\nuse_attn_in: false\nuse_qk_norm: false\nuse_local_attn: false\nungroup_grouped_query_attention: false\noriginal_architecture: LlamaForCausalLM\nfrom_checkpoint: false\ncheckpoint_index: null\ncheckpoint_label_type: null\ncheckpoint_value: null\ntokenizer_name: meta-llama\/Llama-2-7b-chat-hf\nwindow_size: null\nattn_types: null\ninit_mode: gpt2\nn_devices: 1\nattention_dir: causal\nseed: null\ninitializer_range: 0.0125\ninit_weights: false\nscale_attn_by_inverse_layer_idx: false\nd_vocab_out: 32000\nparallel_attn_mlp: false\nrotary_dim: 128\nn_params: 6476005376\nuse_hook_tokens: false\ntokenizer_prepends_bos: null\npost_embedding_ln: false\nrotary_base: 10000\nrotary_base_local: null\nrotary_scaling_factor: 1.0\ntrust_remote_code: false\nrotary_adjacent_pairs: false\nload_in_4bit: false\nrelative_attention_max_distance: null\nrelative_attention_num_buckets: null\ndecoder_start_token_id: null\ntie_word_embeddings: false\nuse_normalization_before_and_after: false\nattn_scores_soft_cap: -1.0\noutput_logits_soft_cap: -1.0\nuse_NTK_by_parts_rope: false\nNTK_by_parts_low_freq_factor: 1.0\nNTK_by_parts_high_freq_factor: 4.0\nNTK_by_parts_factor: 8.0\nNTK_original_ctx_len: 8192\nuse_yarn_rope: false\nyarn_factor: 1.0\nyarn_attention_factor: 1.0\nyarn_beta_fast: 32.0\nyarn_beta_slow: 1.0\nyarn_original_max_position_embeddings: 4096\nnorm_topk_prob: false\n","tokenizer.name":"meta-llama\/Llama-2-7b-chat-hf","tokenizer.vocab_size":32000.0,"tokenizer.max_len":null,"tokenizer.class":"LlamaTokenizer","tokenizer.vocab_hash":"e3A7wYziNQPAWcJ15GMAQY8qZqw=","tensor_shapes.state_dict":"embed:\n W_E: (32000, 4096)\nblocks:\n '[0-31]':\n ln1:\n w: (4096,)\n ln2:\n w: (4096,)\n attn:\n '[W_Q, W_K, W_V]': (32, 4096, 128)\n W_O: (32, 128, 4096)\n '[b_Q, b_K, b_V]': (32, 128)\n b_O: (4096,)\n mask: (4096, 4096)\n IGNORE: ()\n '[rotary_sin, rotary_cos]': (4096, 128)\n mlp:\n '[W_in, W_gate]': (4096, 11008)\n W_out: (11008, 4096)\n b_in: (11008,)\n b_out: (4096,)\nln_final:\n w: (4096,)\nunembed:\n W_U: (4096, 32000)\n b_U: (32000,)\n","tensor_shapes.state_dict.raw__":{"embed":{"W_E":"(32000, 4096)"},"blocks":{"[0-31]":{"ln1":{"w":"(4096,)"},"ln2":{"w":"(4096,)"},"attn":{"[W_Q, W_K, W_V]":"(32, 4096, 128)","W_O":"(32, 128, 4096)","[b_Q, b_K, b_V]":"(32, 128)","b_O":"(4096,)","mask":"(4096, 4096)","IGNORE":"()","[rotary_sin, rotary_cos]":"(4096, 128)"},"mlp":{"[W_in, W_gate]":"(4096, 11008)","W_out":"(11008, 4096)","b_in":"(11008,)","b_out":"(4096,)"}}},"ln_final":{"w":"(4096,)"},"unembed":{"W_U":"(4096, 32000)","b_U":"(32000,)"}},"tensor_shapes.activation_cache":"blocks:\n '[0-31]':\n ln1:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 4096)\n attn:\n '[hook_q, hook_k, hook_v, hook_rot_q, hook_rot_k, hook_z]': (batch, seq_len,\n 32, 128)\n '[hook_attn_scores, hook_pattern]': (batch, 32, seq_len, seq_len)\n ln2:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 4096)\n mlp:\n '[hook_pre, hook_pre_linear, hook_post]': (batch, seq_len, 11008)\n '[hook_resid_pre, hook_attn_out, hook_resid_mid, hook_mlp_out, hook_resid_post]': (batch,\n seq_len, 4096)\nln_final:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 4096)\nunembed:\n hook_in: (batch, seq_len, 4096)\n hook_out: (batch, seq_len, 32000)\nhook_embed: (batch, seq_len, 4096)\n","tensor_shapes.activation_cache.raw__":{"blocks":{"[0-31]":{"ln1":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 4096)"},"attn":{"[hook_q, hook_k, hook_v, hook_rot_q, hook_rot_k, hook_z]":"(batch, seq_len, 32, 128)","[hook_attn_scores, hook_pattern]":"(batch, 32, seq_len, seq_len)"},"ln2":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 4096)"},"mlp":{"[hook_pre, hook_pre_linear, hook_post]":"(batch, seq_len, 11008)"},"[hook_resid_pre, hook_attn_out, hook_resid_mid, hook_mlp_out, hook_resid_post]":"(batch, seq_len, 4096)"}},"ln_final":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 4096)"},"unembed":{"hook_in":"(batch, seq_len, 4096)","hook_out":"(batch, seq_len, 32000)"},"hook_embed":"(batch, seq_len, 4096)"}} +{"name.default_alias":"Llama-2-7b-chat","name.huggingface":null,"name.aliases":"","model_type":"Llama-2","name.from_cfg":"Llama-2-7b-chat-hf","n_params.as_str":"6.5B","n_params.as_int":6476005376,"n_params.from_name":"7b","cfg.n_params":6476005376,"cfg.n_layers":32,"cfg.n_heads":32,"cfg.d_model":4096,"cfg.d_vocab":32000,"cfg.act_fn":"silu","cfg.positional_embedding_type":"rotary","cfg.parallel_attn_mlp":false,"cfg.original_architecture":"LlamaForCausalLM","cfg.normalization_type":"RMS","config.raw__":{"d_model":4096,"d_head":128,"n_layers":32,"n_ctx":4096,"n_heads":32,"d_mlp":11008,"d_vocab":32000,"device":"cpu","use_attn_result":false,"use_split_qkv_input":false,"default_prepend_bos":true,"positional_embedding_type":"rotary","n_key_value_heads":null,"attn_only":false,"gated_mlp":true,"uses_rms_norm":false,"eps":0.00001,"layer_norm_folding":false,"act_fn":"silu","normalization_type":"RMS","num_experts":null,"experts_per_token":null,"final_rms":true,"dtype":"torch.float32","model_name":"Llama-2-7b-chat-hf","use_attn_scale":true,"attn_scale":11.313708499,"use_hook_mlp_in":false,"use_attn_in":false,"use_qk_norm":false,"use_local_attn":false,"ungroup_grouped_query_attention":false,"original_architecture":"LlamaForCausalLM","from_checkpoint":false,"checkpoint_index":null,"checkpoint_label_type":null,"checkpoint_value":null,"tokenizer_name":"meta-llama\/Llama-2-7b-chat-hf","window_size":null,"attn_types":null,"init_mode":"gpt2","n_devices":1,"attention_dir":"causal","seed":null,"initializer_range":0.0125,"init_weights":false,"scale_attn_by_inverse_layer_idx":false,"d_vocab_out":32000,"parallel_attn_mlp":false,"rotary_dim":128,"n_params":6476005376,"use_hook_tokens":false,"tokenizer_prepends_bos":null,"post_embedding_ln":false,"rotary_base":10000,"rotary_base_local":null,"rotary_scaling_factor":1.0,"trust_remote_code":false,"rotary_adjacent_pairs":false,"load_in_4bit":false,"relative_attention_max_distance":null,"relative_attention_num_buckets":null,"decoder_start_token_id":null,"tie_word_embeddings":false,"use_normalization_before_and_after":false,"attn_scores_soft_cap":-1.0,"output_logits_soft_cap":-1.0,"use_NTK_by_parts_rope":false,"NTK_by_parts_low_freq_factor":1.0,"NTK_by_parts_high_freq_factor":4.0,"NTK_by_parts_factor":8.0,"NTK_original_ctx_len":8192,"use_yarn_rope":false,"yarn_factor":1.0,"yarn_attention_factor":1.0,"yarn_beta_fast":32.0,"yarn_beta_slow":1.0,"yarn_original_max_position_embeddings":4096,"norm_topk_prob":false},"config":"d_model: 4096\nd_head: 128\nn_layers: 32\nn_ctx: 4096\nn_heads: 32\nd_mlp: 11008\nd_vocab: 32000\ndevice: cpu\nuse_attn_result: false\nuse_split_qkv_input: false\ndefault_prepend_bos: true\npositional_embedding_type: rotary\nn_key_value_heads: null\nattn_only: false\ngated_mlp: true\nuses_rms_norm: false\neps: 1.0e-05\nlayer_norm_folding: false\nact_fn: silu\nnormalization_type: RMS\nnum_experts: null\nexperts_per_token: null\nfinal_rms: true\ndtype: torch.float32\nmodel_name: Llama-2-7b-chat-hf\nuse_attn_scale: true\nattn_scale: !!python\/object\/apply:numpy.core.multiarray.scalar\n- !!python\/object\/apply:numpy.dtype\n args:\n - f8\n - false\n - true\n state: !!python\/tuple\n - 3\n - <\n - null\n - null\n - null\n - -1\n - -1\n - 0\n- !!binary |\n zTt\/Zp6gJkA=\nuse_hook_mlp_in: false\nuse_attn_in: false\nuse_qk_norm: false\nuse_local_attn: false\nungroup_grouped_query_attention: false\noriginal_architecture: LlamaForCausalLM\nfrom_checkpoint: false\ncheckpoint_index: null\ncheckpoint_label_type: null\ncheckpoint_value: null\ntokenizer_name: meta-llama\/Llama-2-7b-chat-hf\nwindow_size: null\nattn_types: null\ninit_mode: gpt2\nn_devices: 1\nattention_dir: causal\nseed: null\ninitializer_range: 0.0125\ninit_weights: false\nscale_attn_by_inverse_layer_idx: false\nd_vocab_out: 32000\nparallel_attn_mlp: false\nrotary_dim: 128\nn_params: 6476005376\nuse_hook_tokens: false\ntokenizer_prepends_bos: null\npost_embedding_ln: false\nrotary_base: 10000\nrotary_base_local: null\nrotary_scaling_factor: 1.0\ntrust_remote_code: false\nrotary_adjacent_pairs: false\nload_in_4bit: false\nrelative_attention_max_distance: null\nrelative_attention_num_buckets: null\ndecoder_start_token_id: null\ntie_word_embeddings: false\nuse_normalization_before_and_after: false\nattn_scores_soft_cap: -1.0\noutput_logits_soft_cap: -1.0\nuse_NTK_by_parts_rope: false\nNTK_by_parts_low_freq_factor: 1.0\nNTK_by_parts_high_freq_factor: 4.0\nNTK_by_parts_factor: 8.0\nNTK_original_ctx_len: 8192\nuse_yarn_rope: false\nyarn_factor: 1.0\nyarn_attention_factor: 1.0\nyarn_beta_fast: 32.0\nyarn_beta_slow: 1.0\nyarn_original_max_position_embeddings: 4096\nnorm_topk_prob: false\n","tokenizer.name":"meta-llama\/Llama-2-7b-chat-hf","tokenizer.vocab_size":32000.0,"tokenizer.max_len":null,"tokenizer.class":"TokenizersBackend","tokenizer.vocab_hash":"e3A7wYziNQPAWcJ15GMAQY8qZqw=","tensor_shapes.state_dict":"embed:\n W_E: (32000, 4096)\nblocks:\n '[0-31]':\n ln1:\n w: (4096,)\n ln2:\n w: (4096,)\n attn:\n '[W_Q, W_K, W_V]': (32, 4096, 128)\n W_O: (32, 128, 4096)\n '[b_Q, b_K, b_V]': (32, 128)\n b_O: (4096,)\n mask: (4096, 4096)\n IGNORE: ()\n '[rotary_sin, rotary_cos]': (4096, 128)\n mlp:\n '[W_in, W_gate]': (4096, 11008)\n W_out: (11008, 4096)\n b_in: (11008,)\n b_out: (4096,)\nln_final:\n w: (4096,)\nunembed:\n W_U: (4096, 32000)\n b_U: (32000,)\n","tensor_shapes.state_dict.raw__":{"embed":{"W_E":"(32000, 4096)"},"blocks":{"[0-31]":{"ln1":{"w":"(4096,)"},"ln2":{"w":"(4096,)"},"attn":{"[W_Q, W_K, W_V]":"(32, 4096, 128)","W_O":"(32, 128, 4096)","[b_Q, b_K, b_V]":"(32, 128)","b_O":"(4096,)","mask":"(4096, 4096)","IGNORE":"()","[rotary_sin, rotary_cos]":"(4096, 128)"},"mlp":{"[W_in, W_gate]":"(4096, 11008)","W_out":"(11008, 4096)","b_in":"(11008,)","b_out":"(4096,)"}}},"ln_final":{"w":"(4096,)"},"unembed":{"W_U":"(4096, 32000)","b_U":"(32000,)"}},"tensor_shapes.activation_cache":"blocks:\n '[0-31]':\n ln1:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 4096)\n attn:\n '[hook_q, hook_k, hook_v, hook_rot_q, hook_rot_k, hook_z]': (batch, seq_len,\n 32, 128)\n '[hook_attn_scores, hook_pattern]': (batch, 32, seq_len, seq_len)\n ln2:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 4096)\n mlp:\n '[hook_pre, hook_pre_linear, hook_post]': (batch, seq_len, 11008)\n '[hook_resid_pre, hook_attn_out, hook_resid_mid, hook_mlp_out, hook_resid_post]': (batch,\n seq_len, 4096)\nln_final:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 4096)\nunembed:\n hook_in: (batch, seq_len, 4096)\n hook_out: (batch, seq_len, 32000)\nhook_embed: (batch, seq_len, 4096)\n","tensor_shapes.activation_cache.raw__":{"blocks":{"[0-31]":{"ln1":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 4096)"},"attn":{"[hook_q, hook_k, hook_v, hook_rot_q, hook_rot_k, hook_z]":"(batch, seq_len, 32, 128)","[hook_attn_scores, hook_pattern]":"(batch, 32, seq_len, seq_len)"},"ln2":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 4096)"},"mlp":{"[hook_pre, hook_pre_linear, hook_post]":"(batch, seq_len, 11008)"},"[hook_resid_pre, hook_attn_out, hook_resid_mid, hook_mlp_out, hook_resid_post]":"(batch, seq_len, 4096)"}},"ln_final":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 4096)"},"unembed":{"hook_in":"(batch, seq_len, 4096)","hook_out":"(batch, seq_len, 32000)"},"hook_embed":"(batch, seq_len, 4096)"}} {"name.default_alias":"Llama-2-7b","name.huggingface":null,"name.aliases":"","model_type":"Llama-2","name.from_cfg":"Llama-2-7b-hf","n_params.as_str":"6.5B","n_params.as_int":6476005376,"n_params.from_name":"7b","cfg.n_params":6476005376,"cfg.n_layers":32,"cfg.n_heads":32,"cfg.d_model":4096,"cfg.d_vocab":32000,"cfg.act_fn":"silu","cfg.positional_embedding_type":"rotary","cfg.parallel_attn_mlp":false,"cfg.original_architecture":"LlamaForCausalLM","cfg.normalization_type":"RMS","config.raw__":{"d_model":4096,"d_head":128,"n_layers":32,"n_ctx":4096,"n_heads":32,"d_mlp":11008,"d_vocab":32000,"device":"cpu","use_attn_result":false,"use_split_qkv_input":false,"default_prepend_bos":true,"positional_embedding_type":"rotary","n_key_value_heads":null,"attn_only":false,"gated_mlp":true,"uses_rms_norm":false,"eps":0.00001,"layer_norm_folding":false,"act_fn":"silu","normalization_type":"RMS","num_experts":null,"experts_per_token":null,"final_rms":true,"dtype":"torch.float32","model_name":"Llama-2-7b-hf","use_attn_scale":true,"attn_scale":11.313708499,"use_hook_mlp_in":false,"use_attn_in":false,"use_qk_norm":false,"use_local_attn":false,"ungroup_grouped_query_attention":false,"original_architecture":"LlamaForCausalLM","from_checkpoint":false,"checkpoint_index":null,"checkpoint_label_type":null,"checkpoint_value":null,"tokenizer_name":"meta-llama\/Llama-2-7b-hf","window_size":null,"attn_types":null,"init_mode":"gpt2","n_devices":1,"attention_dir":"causal","seed":null,"initializer_range":0.0125,"init_weights":false,"scale_attn_by_inverse_layer_idx":false,"d_vocab_out":32000,"parallel_attn_mlp":false,"rotary_dim":128,"n_params":6476005376,"use_hook_tokens":false,"tokenizer_prepends_bos":null,"post_embedding_ln":false,"rotary_base":10000,"rotary_base_local":null,"rotary_scaling_factor":1.0,"trust_remote_code":false,"rotary_adjacent_pairs":false,"load_in_4bit":false,"relative_attention_max_distance":null,"relative_attention_num_buckets":null,"decoder_start_token_id":null,"tie_word_embeddings":false,"use_normalization_before_and_after":false,"attn_scores_soft_cap":-1.0,"output_logits_soft_cap":-1.0,"use_NTK_by_parts_rope":false,"NTK_by_parts_low_freq_factor":1.0,"NTK_by_parts_high_freq_factor":4.0,"NTK_by_parts_factor":8.0,"NTK_original_ctx_len":8192,"use_yarn_rope":false,"yarn_factor":1.0,"yarn_attention_factor":1.0,"yarn_beta_fast":32.0,"yarn_beta_slow":1.0,"yarn_original_max_position_embeddings":4096,"norm_topk_prob":false},"config":"d_model: 4096\nd_head: 128\nn_layers: 32\nn_ctx: 4096\nn_heads: 32\nd_mlp: 11008\nd_vocab: 32000\ndevice: cpu\nuse_attn_result: false\nuse_split_qkv_input: false\ndefault_prepend_bos: true\npositional_embedding_type: rotary\nn_key_value_heads: null\nattn_only: false\ngated_mlp: true\nuses_rms_norm: false\neps: 1.0e-05\nlayer_norm_folding: false\nact_fn: silu\nnormalization_type: RMS\nnum_experts: null\nexperts_per_token: null\nfinal_rms: true\ndtype: torch.float32\nmodel_name: Llama-2-7b-hf\nuse_attn_scale: true\nattn_scale: !!python\/object\/apply:numpy.core.multiarray.scalar\n- !!python\/object\/apply:numpy.dtype\n args:\n - f8\n - false\n - true\n state: !!python\/tuple\n - 3\n - <\n - null\n - null\n - null\n - -1\n - -1\n - 0\n- !!binary |\n zTt\/Zp6gJkA=\nuse_hook_mlp_in: false\nuse_attn_in: false\nuse_qk_norm: false\nuse_local_attn: false\nungroup_grouped_query_attention: false\noriginal_architecture: LlamaForCausalLM\nfrom_checkpoint: false\ncheckpoint_index: null\ncheckpoint_label_type: null\ncheckpoint_value: null\ntokenizer_name: meta-llama\/Llama-2-7b-hf\nwindow_size: null\nattn_types: null\ninit_mode: gpt2\nn_devices: 1\nattention_dir: causal\nseed: null\ninitializer_range: 0.0125\ninit_weights: false\nscale_attn_by_inverse_layer_idx: false\nd_vocab_out: 32000\nparallel_attn_mlp: false\nrotary_dim: 128\nn_params: 6476005376\nuse_hook_tokens: false\ntokenizer_prepends_bos: null\npost_embedding_ln: false\nrotary_base: 10000\nrotary_base_local: null\nrotary_scaling_factor: 1.0\ntrust_remote_code: false\nrotary_adjacent_pairs: false\nload_in_4bit: false\nrelative_attention_max_distance: null\nrelative_attention_num_buckets: null\ndecoder_start_token_id: null\ntie_word_embeddings: false\nuse_normalization_before_and_after: false\nattn_scores_soft_cap: -1.0\noutput_logits_soft_cap: -1.0\nuse_NTK_by_parts_rope: false\nNTK_by_parts_low_freq_factor: 1.0\nNTK_by_parts_high_freq_factor: 4.0\nNTK_by_parts_factor: 8.0\nNTK_original_ctx_len: 8192\nuse_yarn_rope: false\nyarn_factor: 1.0\nyarn_attention_factor: 1.0\nyarn_beta_fast: 32.0\nyarn_beta_slow: 1.0\nyarn_original_max_position_embeddings: 4096\nnorm_topk_prob: false\n","tokenizer.name":"meta-llama\/Llama-2-7b-hf","tokenizer.vocab_size":32000.0,"tokenizer.max_len":null,"tokenizer.class":"TokenizersBackend","tokenizer.vocab_hash":"e3A7wYziNQPAWcJ15GMAQY8qZqw=","tensor_shapes.state_dict":"embed:\n W_E: (32000, 4096)\nblocks:\n '[0-31]':\n ln1:\n w: (4096,)\n ln2:\n w: (4096,)\n attn:\n '[W_Q, W_K, W_V]': (32, 4096, 128)\n W_O: (32, 128, 4096)\n '[b_Q, b_K, b_V]': (32, 128)\n b_O: (4096,)\n mask: (4096, 4096)\n IGNORE: ()\n '[rotary_sin, rotary_cos]': (4096, 128)\n mlp:\n '[W_in, W_gate]': (4096, 11008)\n W_out: (11008, 4096)\n b_in: (11008,)\n b_out: (4096,)\nln_final:\n w: (4096,)\nunembed:\n W_U: (4096, 32000)\n b_U: (32000,)\n","tensor_shapes.state_dict.raw__":{"embed":{"W_E":"(32000, 4096)"},"blocks":{"[0-31]":{"ln1":{"w":"(4096,)"},"ln2":{"w":"(4096,)"},"attn":{"[W_Q, W_K, W_V]":"(32, 4096, 128)","W_O":"(32, 128, 4096)","[b_Q, b_K, b_V]":"(32, 128)","b_O":"(4096,)","mask":"(4096, 4096)","IGNORE":"()","[rotary_sin, rotary_cos]":"(4096, 128)"},"mlp":{"[W_in, W_gate]":"(4096, 11008)","W_out":"(11008, 4096)","b_in":"(11008,)","b_out":"(4096,)"}}},"ln_final":{"w":"(4096,)"},"unembed":{"W_U":"(4096, 32000)","b_U":"(32000,)"}},"tensor_shapes.activation_cache":"blocks:\n '[0-31]':\n ln1:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 4096)\n attn:\n '[hook_q, hook_k, hook_v, hook_rot_q, hook_rot_k, hook_z]': (batch, seq_len,\n 32, 128)\n '[hook_attn_scores, hook_pattern]': (batch, 32, seq_len, seq_len)\n ln2:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 4096)\n mlp:\n '[hook_pre, hook_pre_linear, hook_post]': (batch, seq_len, 11008)\n '[hook_resid_pre, hook_attn_out, hook_resid_mid, hook_mlp_out, hook_resid_post]': (batch,\n seq_len, 4096)\nln_final:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 4096)\nunembed:\n hook_in: (batch, seq_len, 4096)\n hook_out: (batch, seq_len, 32000)\nhook_embed: (batch, seq_len, 4096)\n","tensor_shapes.activation_cache.raw__":{"blocks":{"[0-31]":{"ln1":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 4096)"},"attn":{"[hook_q, hook_k, hook_v, hook_rot_q, hook_rot_k, hook_z]":"(batch, seq_len, 32, 128)","[hook_attn_scores, hook_pattern]":"(batch, 32, seq_len, seq_len)"},"ln2":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 4096)"},"mlp":{"[hook_pre, hook_pre_linear, hook_post]":"(batch, seq_len, 11008)"},"[hook_resid_pre, hook_attn_out, hook_resid_mid, hook_mlp_out, hook_resid_post]":"(batch, seq_len, 4096)"}},"ln_final":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 4096)"},"unembed":{"hook_in":"(batch, seq_len, 4096)","hook_out":"(batch, seq_len, 32000)"},"hook_embed":"(batch, seq_len, 4096)"}} {"name.default_alias":"meta-llama\/Llama-3.1-70B","name.huggingface":null,"name.aliases":"","model_type":"llama","name.from_cfg":"Llama-3.1-70B","n_params.as_str":"78B","n_params.as_int":77846282240,"n_params.from_name":"70B","cfg.n_params":77846282240,"cfg.n_layers":80,"cfg.n_heads":64,"cfg.d_model":8192,"cfg.d_vocab":128256,"cfg.act_fn":"silu","cfg.positional_embedding_type":"rotary","cfg.parallel_attn_mlp":false,"cfg.original_architecture":"LlamaForCausalLM","cfg.normalization_type":"RMS","config.raw__":{"d_model":8192,"d_head":128,"n_layers":80,"n_ctx":2048,"n_heads":64,"d_mlp":28672,"d_vocab":128256,"device":"cpu","use_attn_result":false,"use_split_qkv_input":false,"default_prepend_bos":true,"positional_embedding_type":"rotary","n_key_value_heads":8,"attn_only":false,"gated_mlp":true,"uses_rms_norm":false,"eps":0.00001,"layer_norm_folding":false,"act_fn":"silu","normalization_type":"RMS","num_experts":null,"experts_per_token":null,"final_rms":true,"dtype":"torch.float32","model_name":"Llama-3.1-70B","use_attn_scale":true,"attn_scale":11.313708499,"use_hook_mlp_in":false,"use_attn_in":false,"use_qk_norm":false,"use_local_attn":false,"ungroup_grouped_query_attention":false,"original_architecture":"LlamaForCausalLM","from_checkpoint":false,"checkpoint_index":null,"checkpoint_label_type":null,"checkpoint_value":null,"tokenizer_name":"meta-llama\/Llama-3.1-70B","window_size":null,"attn_types":null,"init_mode":"gpt2","n_devices":1,"attention_dir":"causal","seed":null,"initializer_range":0.0088388348,"init_weights":false,"scale_attn_by_inverse_layer_idx":false,"d_vocab_out":128256,"parallel_attn_mlp":false,"rotary_dim":128,"n_params":77846282240,"use_hook_tokens":false,"tokenizer_prepends_bos":null,"post_embedding_ln":false,"rotary_base":500000.0,"rotary_base_local":null,"rotary_scaling_factor":1.0,"trust_remote_code":false,"rotary_adjacent_pairs":false,"load_in_4bit":false,"relative_attention_max_distance":null,"relative_attention_num_buckets":null,"decoder_start_token_id":null,"tie_word_embeddings":false,"use_normalization_before_and_after":false,"attn_scores_soft_cap":-1.0,"output_logits_soft_cap":-1.0,"use_NTK_by_parts_rope":true,"NTK_by_parts_low_freq_factor":1.0,"NTK_by_parts_high_freq_factor":4.0,"NTK_by_parts_factor":8.0,"NTK_original_ctx_len":8192,"use_yarn_rope":false,"yarn_factor":1.0,"yarn_attention_factor":1.0,"yarn_beta_fast":32.0,"yarn_beta_slow":1.0,"yarn_original_max_position_embeddings":4096,"norm_topk_prob":false},"config":"d_model: 8192\nd_head: 128\nn_layers: 80\nn_ctx: 2048\nn_heads: 64\nd_mlp: 28672\nd_vocab: 128256\ndevice: cpu\nuse_attn_result: false\nuse_split_qkv_input: false\ndefault_prepend_bos: true\npositional_embedding_type: rotary\nn_key_value_heads: 8\nattn_only: false\ngated_mlp: true\nuses_rms_norm: false\neps: 1.0e-05\nlayer_norm_folding: false\nact_fn: silu\nnormalization_type: RMS\nnum_experts: null\nexperts_per_token: null\nfinal_rms: true\ndtype: torch.float32\nmodel_name: Llama-3.1-70B\nuse_attn_scale: true\nattn_scale: !!python\/object\/apply:numpy.core.multiarray.scalar\n- !!python\/object\/apply:numpy.dtype\n args:\n - f8\n - false\n - true\n state: !!python\/tuple\n - 3\n - <\n - null\n - null\n - null\n - -1\n - -1\n - 0\n- !!binary |\n zTt\/Zp6gJkA=\nuse_hook_mlp_in: false\nuse_attn_in: false\nuse_qk_norm: false\nuse_local_attn: false\nungroup_grouped_query_attention: false\noriginal_architecture: LlamaForCausalLM\nfrom_checkpoint: false\ncheckpoint_index: null\ncheckpoint_label_type: null\ncheckpoint_value: null\ntokenizer_name: meta-llama\/Llama-3.1-70B\nwindow_size: null\nattn_types: null\ninit_mode: gpt2\nn_devices: 1\nattention_dir: causal\nseed: null\ninitializer_range: 0.008838834764831844\ninit_weights: false\nscale_attn_by_inverse_layer_idx: false\nd_vocab_out: 128256\nparallel_attn_mlp: false\nrotary_dim: 128\nn_params: 77846282240\nuse_hook_tokens: false\ntokenizer_prepends_bos: null\npost_embedding_ln: false\nrotary_base: 500000.0\nrotary_base_local: null\nrotary_scaling_factor: 1.0\ntrust_remote_code: false\nrotary_adjacent_pairs: false\nload_in_4bit: false\nrelative_attention_max_distance: null\nrelative_attention_num_buckets: null\ndecoder_start_token_id: null\ntie_word_embeddings: false\nuse_normalization_before_and_after: false\nattn_scores_soft_cap: -1.0\noutput_logits_soft_cap: -1.0\nuse_NTK_by_parts_rope: true\nNTK_by_parts_low_freq_factor: 1.0\nNTK_by_parts_high_freq_factor: 4.0\nNTK_by_parts_factor: 8.0\nNTK_original_ctx_len: 8192\nuse_yarn_rope: false\nyarn_factor: 1.0\nyarn_attention_factor: 1.0\nyarn_beta_fast: 32.0\nyarn_beta_slow: 1.0\nyarn_original_max_position_embeddings: 4096\nnorm_topk_prob: false\n","tokenizer.name":"meta-llama\/Llama-3.1-70B","tokenizer.vocab_size":128000.0,"tokenizer.max_len":131072.0,"tokenizer.class":"TokenizersBackend","tokenizer.vocab_hash":"j9N50ddC7mjCgS4GseU9LmKZDKk=","tensor_shapes.state_dict":"embed:\n W_E: (128256, 8192)\nblocks:\n '[0-79]':\n ln1:\n w: (8192,)\n ln2:\n w: (8192,)\n attn:\n W_Q: (64, 8192, 128)\n W_O: (64, 128, 8192)\n b_Q: (64, 128)\n b_O: (8192,)\n '[_W_K, _W_V]': (8, 8192, 128)\n '[_b_K, _b_V]': (8, 128)\n mask: (2048, 2048)\n IGNORE: ()\n '[rotary_sin, rotary_cos]': (2048, 128)\n mlp:\n '[W_in, W_gate]': (8192, 28672)\n W_out: (28672, 8192)\n b_in: (28672,)\n b_out: (8192,)\nln_final:\n w: (8192,)\nunembed:\n W_U: (8192, 128256)\n b_U: (128256,)\n","tensor_shapes.state_dict.raw__":{"embed":{"W_E":"(128256, 8192)"},"blocks":{"[0-79]":{"ln1":{"w":"(8192,)"},"ln2":{"w":"(8192,)"},"attn":{"W_Q":"(64, 8192, 128)","W_O":"(64, 128, 8192)","b_Q":"(64, 128)","b_O":"(8192,)","[_W_K, _W_V]":"(8, 8192, 128)","[_b_K, _b_V]":"(8, 128)","mask":"(2048, 2048)","IGNORE":"()","[rotary_sin, rotary_cos]":"(2048, 128)"},"mlp":{"[W_in, W_gate]":"(8192, 28672)","W_out":"(28672, 8192)","b_in":"(28672,)","b_out":"(8192,)"}}},"ln_final":{"w":"(8192,)"},"unembed":{"W_U":"(8192, 128256)","b_U":"(128256,)"}},"tensor_shapes.activation_cache":"blocks:\n '[0-79]':\n ln1:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 8192)\n attn:\n '[hook_q, hook_rot_q, hook_z]': (batch, seq_len, 64, 128)\n '[hook_k, hook_v, hook_rot_k]': (batch, seq_len, 8, 128)\n '[hook_attn_scores, hook_pattern]': (batch, 64, seq_len, seq_len)\n ln2:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 8192)\n mlp:\n '[hook_pre, hook_pre_linear, hook_post]': (batch, seq_len, 28672)\n '[hook_resid_pre, hook_attn_out, hook_resid_mid, hook_mlp_out, hook_resid_post]': (batch,\n seq_len, 8192)\nln_final:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 8192)\nunembed:\n hook_in: (batch, seq_len, 8192)\n hook_out: (batch, seq_len, 128256)\nhook_embed: (batch, seq_len, 8192)\n","tensor_shapes.activation_cache.raw__":{"blocks":{"[0-79]":{"ln1":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 8192)"},"attn":{"[hook_q, hook_rot_q, hook_z]":"(batch, seq_len, 64, 128)","[hook_k, hook_v, hook_rot_k]":"(batch, seq_len, 8, 128)","[hook_attn_scores, hook_pattern]":"(batch, 64, seq_len, seq_len)"},"ln2":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 8192)"},"mlp":{"[hook_pre, hook_pre_linear, hook_post]":"(batch, seq_len, 28672)"},"[hook_resid_pre, hook_attn_out, hook_resid_mid, hook_mlp_out, hook_resid_post]":"(batch, seq_len, 8192)"}},"ln_final":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 8192)"},"unembed":{"hook_in":"(batch, seq_len, 8192)","hook_out":"(batch, seq_len, 128256)"},"hook_embed":"(batch, seq_len, 8192)"}} {"name.default_alias":"meta-llama\/Llama-3.1-70B-Instruct","name.huggingface":null,"name.aliases":"","model_type":"llama","name.from_cfg":"Llama-3.1-70B-Instruct","n_params.as_str":"78B","n_params.as_int":77846282240,"n_params.from_name":"70B","cfg.n_params":77846282240,"cfg.n_layers":80,"cfg.n_heads":64,"cfg.d_model":8192,"cfg.d_vocab":128256,"cfg.act_fn":"silu","cfg.positional_embedding_type":"rotary","cfg.parallel_attn_mlp":false,"cfg.original_architecture":"LlamaForCausalLM","cfg.normalization_type":"RMS","config.raw__":{"d_model":8192,"d_head":128,"n_layers":80,"n_ctx":2048,"n_heads":64,"d_mlp":28672,"d_vocab":128256,"device":"cpu","use_attn_result":false,"use_split_qkv_input":false,"default_prepend_bos":true,"positional_embedding_type":"rotary","n_key_value_heads":8,"attn_only":false,"gated_mlp":true,"uses_rms_norm":false,"eps":0.00001,"layer_norm_folding":false,"act_fn":"silu","normalization_type":"RMS","num_experts":null,"experts_per_token":null,"final_rms":true,"dtype":"torch.float32","model_name":"Llama-3.1-70B-Instruct","use_attn_scale":true,"attn_scale":11.313708499,"use_hook_mlp_in":false,"use_attn_in":false,"use_qk_norm":false,"use_local_attn":false,"ungroup_grouped_query_attention":false,"original_architecture":"LlamaForCausalLM","from_checkpoint":false,"checkpoint_index":null,"checkpoint_label_type":null,"checkpoint_value":null,"tokenizer_name":"meta-llama\/Llama-3.1-70B-Instruct","window_size":null,"attn_types":null,"init_mode":"gpt2","n_devices":1,"attention_dir":"causal","seed":null,"initializer_range":0.0088388348,"init_weights":false,"scale_attn_by_inverse_layer_idx":false,"d_vocab_out":128256,"parallel_attn_mlp":false,"rotary_dim":128,"n_params":77846282240,"use_hook_tokens":false,"tokenizer_prepends_bos":null,"post_embedding_ln":false,"rotary_base":500000.0,"rotary_base_local":null,"rotary_scaling_factor":1.0,"trust_remote_code":false,"rotary_adjacent_pairs":false,"load_in_4bit":false,"relative_attention_max_distance":null,"relative_attention_num_buckets":null,"decoder_start_token_id":null,"tie_word_embeddings":false,"use_normalization_before_and_after":false,"attn_scores_soft_cap":-1.0,"output_logits_soft_cap":-1.0,"use_NTK_by_parts_rope":true,"NTK_by_parts_low_freq_factor":1.0,"NTK_by_parts_high_freq_factor":4.0,"NTK_by_parts_factor":8.0,"NTK_original_ctx_len":8192,"use_yarn_rope":false,"yarn_factor":1.0,"yarn_attention_factor":1.0,"yarn_beta_fast":32.0,"yarn_beta_slow":1.0,"yarn_original_max_position_embeddings":4096,"norm_topk_prob":false},"config":"d_model: 8192\nd_head: 128\nn_layers: 80\nn_ctx: 2048\nn_heads: 64\nd_mlp: 28672\nd_vocab: 128256\ndevice: cpu\nuse_attn_result: false\nuse_split_qkv_input: false\ndefault_prepend_bos: true\npositional_embedding_type: rotary\nn_key_value_heads: 8\nattn_only: false\ngated_mlp: true\nuses_rms_norm: false\neps: 1.0e-05\nlayer_norm_folding: false\nact_fn: silu\nnormalization_type: RMS\nnum_experts: null\nexperts_per_token: null\nfinal_rms: true\ndtype: torch.float32\nmodel_name: Llama-3.1-70B-Instruct\nuse_attn_scale: true\nattn_scale: !!python\/object\/apply:numpy.core.multiarray.scalar\n- !!python\/object\/apply:numpy.dtype\n args:\n - f8\n - false\n - true\n state: !!python\/tuple\n - 3\n - <\n - null\n - null\n - null\n - -1\n - -1\n - 0\n- !!binary |\n zTt\/Zp6gJkA=\nuse_hook_mlp_in: false\nuse_attn_in: false\nuse_qk_norm: false\nuse_local_attn: false\nungroup_grouped_query_attention: false\noriginal_architecture: LlamaForCausalLM\nfrom_checkpoint: false\ncheckpoint_index: null\ncheckpoint_label_type: null\ncheckpoint_value: null\ntokenizer_name: meta-llama\/Llama-3.1-70B-Instruct\nwindow_size: null\nattn_types: null\ninit_mode: gpt2\nn_devices: 1\nattention_dir: causal\nseed: null\ninitializer_range: 0.008838834764831844\ninit_weights: false\nscale_attn_by_inverse_layer_idx: false\nd_vocab_out: 128256\nparallel_attn_mlp: false\nrotary_dim: 128\nn_params: 77846282240\nuse_hook_tokens: false\ntokenizer_prepends_bos: null\npost_embedding_ln: false\nrotary_base: 500000.0\nrotary_base_local: null\nrotary_scaling_factor: 1.0\ntrust_remote_code: false\nrotary_adjacent_pairs: false\nload_in_4bit: false\nrelative_attention_max_distance: null\nrelative_attention_num_buckets: null\ndecoder_start_token_id: null\ntie_word_embeddings: false\nuse_normalization_before_and_after: false\nattn_scores_soft_cap: -1.0\noutput_logits_soft_cap: -1.0\nuse_NTK_by_parts_rope: true\nNTK_by_parts_low_freq_factor: 1.0\nNTK_by_parts_high_freq_factor: 4.0\nNTK_by_parts_factor: 8.0\nNTK_original_ctx_len: 8192\nuse_yarn_rope: false\nyarn_factor: 1.0\nyarn_attention_factor: 1.0\nyarn_beta_fast: 32.0\nyarn_beta_slow: 1.0\nyarn_original_max_position_embeddings: 4096\nnorm_topk_prob: false\n","tokenizer.name":"meta-llama\/Llama-3.1-70B-Instruct","tokenizer.vocab_size":128000.0,"tokenizer.max_len":131072.0,"tokenizer.class":"TokenizersBackend","tokenizer.vocab_hash":"j9N50ddC7mjCgS4GseU9LmKZDKk=","tensor_shapes.state_dict":"embed:\n W_E: (128256, 8192)\nblocks:\n '[0-79]':\n ln1:\n w: (8192,)\n ln2:\n w: (8192,)\n attn:\n W_Q: (64, 8192, 128)\n W_O: (64, 128, 8192)\n b_Q: (64, 128)\n b_O: (8192,)\n '[_W_K, _W_V]': (8, 8192, 128)\n '[_b_K, _b_V]': (8, 128)\n mask: (2048, 2048)\n IGNORE: ()\n '[rotary_sin, rotary_cos]': (2048, 128)\n mlp:\n '[W_in, W_gate]': (8192, 28672)\n W_out: (28672, 8192)\n b_in: (28672,)\n b_out: (8192,)\nln_final:\n w: (8192,)\nunembed:\n W_U: (8192, 128256)\n b_U: (128256,)\n","tensor_shapes.state_dict.raw__":{"embed":{"W_E":"(128256, 8192)"},"blocks":{"[0-79]":{"ln1":{"w":"(8192,)"},"ln2":{"w":"(8192,)"},"attn":{"W_Q":"(64, 8192, 128)","W_O":"(64, 128, 8192)","b_Q":"(64, 128)","b_O":"(8192,)","[_W_K, _W_V]":"(8, 8192, 128)","[_b_K, _b_V]":"(8, 128)","mask":"(2048, 2048)","IGNORE":"()","[rotary_sin, rotary_cos]":"(2048, 128)"},"mlp":{"[W_in, W_gate]":"(8192, 28672)","W_out":"(28672, 8192)","b_in":"(28672,)","b_out":"(8192,)"}}},"ln_final":{"w":"(8192,)"},"unembed":{"W_U":"(8192, 128256)","b_U":"(128256,)"}},"tensor_shapes.activation_cache":"blocks:\n '[0-79]':\n ln1:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 8192)\n attn:\n '[hook_q, hook_rot_q, hook_z]': (batch, seq_len, 64, 128)\n '[hook_k, hook_v, hook_rot_k]': (batch, seq_len, 8, 128)\n '[hook_attn_scores, hook_pattern]': (batch, 64, seq_len, seq_len)\n ln2:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 8192)\n mlp:\n '[hook_pre, hook_pre_linear, hook_post]': (batch, seq_len, 28672)\n '[hook_resid_pre, hook_attn_out, hook_resid_mid, hook_mlp_out, hook_resid_post]': (batch,\n seq_len, 8192)\nln_final:\n hook_scale: (batch, seq_len, 1)\n hook_normalized: (batch, seq_len, 8192)\nunembed:\n hook_in: (batch, seq_len, 8192)\n hook_out: (batch, seq_len, 128256)\nhook_embed: (batch, seq_len, 8192)\n","tensor_shapes.activation_cache.raw__":{"blocks":{"[0-79]":{"ln1":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 8192)"},"attn":{"[hook_q, hook_rot_q, hook_z]":"(batch, seq_len, 64, 128)","[hook_k, hook_v, hook_rot_k]":"(batch, seq_len, 8, 128)","[hook_attn_scores, hook_pattern]":"(batch, 64, seq_len, seq_len)"},"ln2":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 8192)"},"mlp":{"[hook_pre, hook_pre_linear, hook_post]":"(batch, seq_len, 28672)"},"[hook_resid_pre, hook_attn_out, hook_resid_mid, hook_mlp_out, hook_resid_post]":"(batch, seq_len, 8192)"}},"ln_final":{"hook_scale":"(batch, seq_len, 1)","hook_normalized":"(batch, seq_len, 8192)"},"unembed":{"hook_in":"(batch, seq_len, 8192)","hook_out":"(batch, seq_len, 128256)"},"hook_embed":"(batch, seq_len, 8192)"}} diff --git a/docs/source/content/model_structure.md b/docs/source/content/model_structure.md index 593860edd..ce5630d80 100644 --- a/docs/source/content/model_structure.md +++ b/docs/source/content/model_structure.md @@ -19,6 +19,14 @@ bridge = TransformerBridge.boot_transformers("gpt2", device="cpu") You can then call the familiar APIs: `to_tokens`, `to_string`, `generate`, `generate_stream`, `run_with_hooks`, `run_with_cache`. +## Architecture diagram + +The diagram below maps weight matrices and activation tensors to their TransformerLens names. Hook points sit on the activation arrows — the canonical hook names in the rest of this document correspond directly to the labeled tensors here. + +![TransformerLens architecture diagram with weight matrices and activation tensors labeled](../_static/TransformerLens_Diagram.svg) + +*Diagram by [Austin Kozlowski](https://github.com/akozlo). Click for full resolution.* + ## Top-Level Components Typical decoder-only models expose these top-level components (names vary by architecture): diff --git a/tests/unit/factored_matrix/test_properties.py b/tests/unit/factored_matrix/test_properties.py index e7e36fdca..091db15e8 100644 --- a/tests/unit/factored_matrix/test_properties.py +++ b/tests/unit/factored_matrix/test_properties.py @@ -138,9 +138,19 @@ def test_ndim(self, factored_matrices): def test_collapse_l(self, factored_matrices): for factored_matrix in factored_matrices: result = factored_matrix.collapse_l() - expected = factored_matrix.S[..., :, None] * utils.transpose(factored_matrix.Vh) + expected = factored_matrix.S[..., :, None] * utils.transpose(factored_matrix.V) assert torch.allclose(result, expected) + def test_V_and_Vh_alias_match(self, factored_matrices): + import warnings + + for factored_matrix in factored_matrices: + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + vh_value = factored_matrix.Vh + assert any(issubclass(w.category, DeprecationWarning) for w in caught) + assert torch.equal(vh_value, factored_matrix.V) + def test_collapse_r(self, factored_matrices): for factored_matrix in factored_matrices: result = factored_matrix.collapse_r() diff --git a/transformer_lens/FactoredMatrix.py b/transformer_lens/FactoredMatrix.py index 0c7ce3610..674dabf43 100644 --- a/transformer_lens/FactoredMatrix.py +++ b/transformer_lens/FactoredMatrix.py @@ -222,19 +222,20 @@ def svd( Float[torch.Tensor, "*leading_dims mdim"], Float[torch.Tensor, "*leading_dims rdim mdim"], ]: - """ - Efficient algorithm for finding Singular Value Decomposition, a tuple (U, S, Vh) for matrix M st S is a vector and U, Vh are orthogonal matrices, and U @ S.diag() @ Vh.T == M - - (Note that Vh is given as the transpose of the obvious thing) - """ - Ua, Sa, Vha = torch.svd(self.A) - Ub, Sb, Vhb = torch.svd(self.B) - middle = Sa[..., :, None] * tensor_utils.transpose(Vha) @ Ub * Sb[..., None, :] - Um, Sm, Vhm = torch.svd(middle) + """Singular Value Decomposition: returns ``(U, S, V)`` such that ``U @ S.diag() @ V.transpose(-2, -1) == M``.""" + # Transpose Vh back to V — the long-standing return convention; downstream + # callers transposed the old `.Vh` result, so preserving V keeps them working. + Ua, Sa, Vha = torch.linalg.svd(self.A, full_matrices=False) + Ub, Sb, Vhb = torch.linalg.svd(self.B, full_matrices=False) + Va = tensor_utils.transpose(Vha) + Vb = tensor_utils.transpose(Vhb) + middle = Sa[..., :, None] * tensor_utils.transpose(Va) @ Ub * Sb[..., None, :] + Um, Sm, Vhm = torch.linalg.svd(middle, full_matrices=False) + Vm = tensor_utils.transpose(Vhm) U = Ua @ Um - Vh = Vhb @ Vhm + V = Vb @ Vm S = Sm - return U, S, Vh + return U, S, V @property def U(self) -> Float[torch.Tensor, "*leading_dims ldim mdim"]: @@ -244,8 +245,23 @@ def U(self) -> Float[torch.Tensor, "*leading_dims ldim mdim"]: def S(self) -> Float[torch.Tensor, "*leading_dims mdim"]: return self.svd()[1] + @property + def V(self) -> Float[torch.Tensor, "*leading_dims rdim mdim"]: + """Right singular vectors. ``M == U @ S.diag() @ V.transpose(-2, -1)``.""" + return self.svd()[2] + @property def Vh(self) -> Float[torch.Tensor, "*leading_dims rdim mdim"]: + """Deprecated alias for :attr:`V` — historically misnamed; returns V, not its conjugate transpose.""" + import warnings + + warnings.warn( + "FactoredMatrix.Vh has always returned V (right singular vectors), not Vh. " + "Use .V for the canonical name; for the actual Hermitian transpose use " + ".V.transpose(-2, -1). The .Vh alias will be removed in a future release.", + DeprecationWarning, + stacklevel=2, + ) return self.svd()[2] @property @@ -302,11 +318,11 @@ def __repr__(self): def make_even(self) -> FactoredMatrix: """ - Returns the factored form of (U @ S.sqrt().diag(), S.sqrt().diag() @ Vh) where U, S, Vh are the SVD of the matrix. This is an equivalent factorisation, but more even - each half has half the singular values, and orthogonal rows/cols + Returns the factored form of (U @ S.sqrt().diag(), S.sqrt().diag() @ V.T) where U, S, V are the SVD of the matrix. This is an equivalent factorisation, but more even - each half has half the singular values, and orthogonal rows/cols """ return FactoredMatrix( self.U * self.S.sqrt()[..., None, :], - self.S.sqrt()[..., :, None] * tensor_utils.transpose(self.Vh), + self.S.sqrt()[..., :, None] * tensor_utils.transpose(self.V), ) def get_corner(self, k=3): @@ -320,7 +336,7 @@ def collapse_l(self) -> Float[torch.Tensor, "*leading_dims mdim rdim"]: """ Collapses the left side of the factorization by removing the orthogonal factor (given by self.U). Returns a (..., mdim, rdim) tensor """ - return self.S[..., :, None] * tensor_utils.transpose(self.Vh) + return self.S[..., :, None] * tensor_utils.transpose(self.V) def collapse_r(self) -> Float[torch.Tensor, "*leading_dims ldim mdim"]: """ diff --git a/transformer_lens/SVDInterpreter.py b/transformer_lens/SVDInterpreter.py index bf69beec4..49580e865 100644 --- a/transformer_lens/SVDInterpreter.py +++ b/transformer_lens/SVDInterpreter.py @@ -88,7 +88,7 @@ def plot_matrix(matrix, tokens, k=10, filter="topk"): if vector_type == "OV": assert head_index is not None # keep mypy happy matrix = self._get_OV_matrix(layer_index, head_index) - V = matrix.Vh.T + V = matrix.V.T elif vector_type == "w_in": matrix = self._get_w_in_matrix(layer_index) From 481a86f3bd7c40d62ecc9c5c1584fa5616daba16 Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Mon, 11 May 2026 17:24:27 -0500 Subject: [PATCH 02/10] Started activation cache improvement --- tests/acceptance/test_activation_cache.py | 189 +++++++++++++++++++++- transformer_lens/ActivationCache.py | 136 ++++++++++++---- 2 files changed, 295 insertions(+), 30 deletions(-) diff --git a/tests/acceptance/test_activation_cache.py b/tests/acceptance/test_activation_cache.py index 8e1c16891..33a53ff31 100644 --- a/tests/acceptance/test_activation_cache.py +++ b/tests/acceptance/test_activation_cache.py @@ -786,4 +786,191 @@ def test_get_neuron_results_without_slice(): pos_slice=None, ) - assert torch.isclose(ref_neuron_acts, neuron_acts).all() + assert torch.equal(ref_neuron_acts, neuron_acts) + + +@torch.no_grad +def test_get_neuron_results_project_output_onto_1d(): + """1D projection: contract W_out with [d_model] vector, drop the d_model dim.""" + model = load_model("solu-2l") + tokens, _ = get_ioi_tokens_and_answer_tokens(model) + _, cache = model.run_with_cache(tokens) + + layer = 1 + direction = torch.randn(model.cfg.d_model, device=model.cfg.device) + + full = cache.get_neuron_results(layer) # [batch, pos, d_mlp, d_model] + expected = full @ direction # [batch, pos, d_mlp] + projected = cache.get_neuron_results(layer, project_output_onto=direction) + assert projected.shape == expected.shape + assert torch.allclose(projected, expected, atol=1e-5) + + +@torch.no_grad +def test_get_neuron_results_project_output_onto_2d(): + """2D projection: contract W_out with [d_model, n_outs], keep n_outs as last dim.""" + model = load_model("solu-2l") + tokens, _ = get_ioi_tokens_and_answer_tokens(model) + _, cache = model.run_with_cache(tokens) + + layer = 1 + n_outs = 3 + directions = torch.randn(model.cfg.d_model, n_outs, device=model.cfg.device) + + full = cache.get_neuron_results(layer) + expected = full @ directions # [batch, pos, d_mlp, n_outs] + projected = cache.get_neuron_results(layer, project_output_onto=directions) + assert projected.shape == expected.shape + assert torch.allclose(projected, expected, atol=1e-5) + + +@torch.no_grad +def test_stack_neuron_results_project_output_onto_matches_unprojected(): + """stack_neuron_results with projection equals unprojected output then @ projection.""" + model = load_model("solu-2l") + tokens, _ = get_ioi_tokens_and_answer_tokens(model) + _, cache = model.run_with_cache(tokens) + + direction = torch.randn(model.cfg.d_model, device=model.cfg.device) + directions = torch.randn(model.cfg.d_model, 4, device=model.cfg.device) + + layer = model.cfg.n_layers + full = cache.stack_neuron_results(layer, pos_slice=-1) + # 1D + expected_1d = full @ direction + projected_1d = cache.stack_neuron_results(layer, pos_slice=-1, project_output_onto=direction) + assert projected_1d.shape == expected_1d.shape + assert torch.allclose(projected_1d, expected_1d, atol=1e-5) + # 2D + expected_2d = full @ directions + projected_2d = cache.stack_neuron_results(layer, pos_slice=-1, project_output_onto=directions) + assert projected_2d.shape == expected_2d.shape + assert torch.allclose(projected_2d, expected_2d, atol=1e-5) + + +@torch.no_grad +def test_stack_neuron_results_project_incl_remainder(): + """Projection commutes with the incl_remainder branch.""" + model = load_model("solu-2l") + tokens, _ = get_ioi_tokens_and_answer_tokens(model) + _, cache = model.run_with_cache(tokens) + + direction = torch.randn(model.cfg.d_model, device=model.cfg.device) + + layer = model.cfg.n_layers + full = cache.stack_neuron_results(layer, pos_slice=-1, incl_remainder=True) + expected = full @ direction + projected = cache.stack_neuron_results( + layer, pos_slice=-1, incl_remainder=True, project_output_onto=direction + ) + assert projected.shape == expected.shape + assert torch.allclose(projected, expected, atol=1e-5) + + +@torch.no_grad +def test_get_full_resid_decomposition_project_output_onto_1d(): + """Full decomposition projection equals unprojected stack @ direction.""" + model = load_model("solu-2l") + tokens, _ = get_ioi_tokens_and_answer_tokens(model) + _, cache = model.run_with_cache(tokens) + + direction = torch.randn(model.cfg.d_model, device=model.cfg.device) + + full = cache.get_full_resid_decomposition(layer=-1, pos_slice=-1, expand_neurons=True) + expected = full @ direction # [num_components, batch, ...] + projected = cache.get_full_resid_decomposition( + layer=-1, pos_slice=-1, expand_neurons=True, project_output_onto=direction + ) + assert projected.shape == expected.shape + assert torch.allclose(projected, expected, atol=1e-4) + + +@torch.no_grad +def test_get_full_resid_decomposition_project_output_onto_2d(): + """2D projection: last dim is num_outputs, not squeezed.""" + model = load_model("solu-2l") + tokens, _ = get_ioi_tokens_and_answer_tokens(model) + _, cache = model.run_with_cache(tokens) + + n_outs = 2 + directions = torch.randn(model.cfg.d_model, n_outs, device=model.cfg.device) + + full = cache.get_full_resid_decomposition(layer=-1, pos_slice=-1, expand_neurons=True) + expected = full @ directions + projected = cache.get_full_resid_decomposition( + layer=-1, pos_slice=-1, expand_neurons=True, project_output_onto=directions + ) + assert projected.shape == expected.shape + assert torch.allclose(projected, expected, atol=1e-4) + + +@torch.no_grad +def test_get_full_resid_decomposition_project_apply_ln_raises(): + """apply_ln=True combined with project_output_onto must raise.""" + model = load_model("solu-2l") + tokens, _ = get_ioi_tokens_and_answer_tokens(model) + _, cache = model.run_with_cache(tokens) + + direction = torch.randn(model.cfg.d_model, device=model.cfg.device) + with pytest.raises(NotImplementedError): + cache.get_full_resid_decomposition( + layer=-1, pos_slice=-1, apply_ln=True, project_output_onto=direction + ) + + +@torch.no_grad +def test_stack_neuron_results_project_apply_ln_raises(): + model = load_model("solu-2l") + tokens, _ = get_ioi_tokens_and_answer_tokens(model) + _, cache = model.run_with_cache(tokens) + + direction = torch.randn(model.cfg.d_model, device=model.cfg.device) + with pytest.raises(NotImplementedError): + cache.stack_neuron_results( + layer=-1, pos_slice=-1, apply_ln=True, project_output_onto=direction + ) + + +@torch.no_grad +def test_stack_neuron_results_projection_skips_dmodel_materialization(): + """Projected path's largest tensor must be much smaller than unprojected's, proving the + [..., d_mlp, d_model] intermediate is not materialized. Uses TorchDispatchMode to intercept + tensor ops; tracemalloc doesn't catch torch CPU allocations.""" + from torch.utils._python_dispatch import TorchDispatchMode + + class MaxTensorWatcher(TorchDispatchMode): + def __init__(self): + self.max_numel = 0 + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + result = func(*args, **(kwargs or {})) + tensors = result if isinstance(result, (list, tuple)) else (result,) + for t in tensors: + if isinstance(t, torch.Tensor): + self.max_numel = max(self.max_numel, t.numel()) + return result + + model = load_model("solu-2l") + tokens, _ = get_ioi_tokens_and_answer_tokens(model) + _, cache = model.run_with_cache(tokens) + + direction = torch.randn(model.cfg.d_model, device=model.cfg.device) + + # Warm-up calls (some kernel selection only happens on first dispatch). + _ = cache.stack_neuron_results(layer=-1, pos_slice=-1) + _ = cache.stack_neuron_results(layer=-1, pos_slice=-1, project_output_onto=direction) + + with MaxTensorWatcher() as watcher_unproj: + _ = cache.stack_neuron_results(layer=-1, pos_slice=-1) + with MaxTensorWatcher() as watcher_proj: + _ = cache.stack_neuron_results(layer=-1, pos_slice=-1, project_output_onto=direction) + + d_model = model.cfg.d_model + ratio = watcher_unproj.max_numel / max(watcher_proj.max_numel, 1) + # Expected ratio is ~d_model (~512); we require at least 10x to allow headroom. + assert ratio > 10, ( + f"Memory optimization not detected: max_numel unprojected={watcher_unproj.max_numel:,}, " + f"projected={watcher_proj.max_numel:,}, ratio={ratio:.1f}x. Expected projected path's " + f"largest tensor to be >>10x smaller than unprojected (the [..., d_mlp, d_model] " + f"intermediate has d_model={d_model} as a dim it shouldn't have)." + ) diff --git a/transformer_lens/ActivationCache.py b/transformer_lens/ActivationCache.py index 09d92d6ca..e0b01f739 100644 --- a/transformer_lens/ActivationCache.py +++ b/transformer_lens/ActivationCache.py @@ -27,6 +27,17 @@ class first, including the examples, and then skimming the available methods. Yo from transformer_lens.utilities import Slice, SliceInput, warn_if_mps +def _normalize_projection_to_2d( + project: Optional[torch.Tensor], +) -> Tuple[Optional[torch.Tensor], bool]: + """Return ``(project_2d, squeeze_at_end)`` — 1D projections are reshaped to 2D for uniform internal handling and squeezed back at the user-facing return.""" + if project is None: + return None, False + if project.ndim == 1: + return project.unsqueeze(-1), True + return project, False + + class ActivationCache: """Activation Cache. @@ -854,7 +865,8 @@ def get_neuron_results( layer: int, neuron_slice: Union[Slice, SliceInput] = None, pos_slice: Union[Slice, SliceInput] = None, - ) -> Float[torch.Tensor, "*batch_and_pos_dims num_neurons d_model"]: + project_output_onto: Optional[torch.Tensor] = None, + ) -> torch.Tensor: """Get Neuron Results. Get the results of for neurons in a specific layer (i.e, how much each neuron contributes to @@ -868,9 +880,14 @@ def get_neuron_results( Slice of the neuron. pos_slice: Slice of the positions. + project_output_onto: + Optional ``[d_model]`` or ``[d_model, num_outputs]`` projection. Contracted with + ``W_out`` *before* the per-neuron expansion so the ``[..., d_mlp, d_model]`` + intermediate is never materialized. Returns: - Tensor of the results. + Last-dim is ``d_model`` (default), ``num_outputs`` (2D projection), or squeezed + (1D projection). """ if not isinstance(neuron_slice, Slice): neuron_slice = Slice(neuron_slice) @@ -886,7 +903,13 @@ def get_neuron_results( if neuron_slice is not None: neuron_acts = neuron_slice.apply(neuron_acts, dim=-1) W_out = neuron_slice.apply(W_out, dim=0) - return neuron_acts[..., None] * W_out + if project_output_onto is None: + return neuron_acts[..., None] * W_out + # W_out: [d_mlp, d_model]; project: [d_model] or [d_model, n_outs] + projected = W_out @ project_output_onto + if projected.ndim == 1: + return neuron_acts * projected + return neuron_acts[..., None] * projected def stack_neuron_results( self, @@ -896,10 +919,8 @@ def stack_neuron_results( return_labels: bool = False, incl_remainder: bool = False, apply_ln: bool = False, - ) -> Union[ - Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], - Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]], - ]: + project_output_onto: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[str]]]: """Stack Neuron Results Returns a stack of all neuron results (ie residual stream contribution) up to layer L - ie @@ -908,7 +929,8 @@ def stack_neuron_results( into attribution by specific neurons. Note that doing this for all neurons is SUPER expensive on GPU memory and only works for - small models or short inputs. + small models or short inputs. Pass ``project_output_onto`` to fold the projection into the + per-neuron expansion and avoid the ``[..., d_mlp, d_model]`` intermediate. Args: layer: @@ -923,8 +945,19 @@ def stack_neuron_results( incl_remainder: Whether to return a final term which is "the rest of the residual stream". apply_ln: - Whether to apply LayerNorm to the stack. + Whether to apply LayerNorm to the stack. Not yet supported in combination with + ``project_output_onto`` (raises ``NotImplementedError``). + project_output_onto: + Optional ``[d_model]`` or ``[d_model, num_outputs]`` tensor. When set, each + component's last d_model dim is replaced by the projection (memory-efficient for + direction analyses; see ``get_neuron_results``). """ + if apply_ln and project_output_onto is not None: + raise NotImplementedError( + "stack_neuron_results does not yet support apply_ln=True together with " + "project_output_onto. Call without apply_ln, or apply the projection after a " + "non-projected call." + ) if layer is None or layer == -1: # Default to the residual stream immediately pre unembed @@ -938,6 +971,8 @@ def stack_neuron_results( if not isinstance(pos_slice, Slice): pos_slice = Slice(pos_slice) + project_2d, squeeze_projected = _normalize_projection_to_2d(project_output_onto) + neuron_labels: Union[torch.Tensor, np.ndarray] = neuron_slice.apply( torch.arange(self.model.cfg.d_mlp), dim=0 ) @@ -947,7 +982,12 @@ def stack_neuron_results( for l in range(layer): # Note that this has shape batch x pos x head_index x d_model components.append( - self.get_neuron_results(l, pos_slice=pos_slice, neuron_slice=neuron_slice) + self.get_neuron_results( + l, + pos_slice=pos_slice, + neuron_slice=neuron_slice, + project_output_onto=project_2d, + ) ) labels.extend([f"L{l}N{h}" for h in neuron_labels]) if components: @@ -958,27 +998,31 @@ def stack_neuron_results( ) if incl_remainder: - remainder = pos_slice.apply( - self[("resid_post", layer - 1)], dim=-2 - ) - components.sum(dim=0) + remainder_full = pos_slice.apply(self[("resid_post", layer - 1)], dim=-2) + if project_2d is not None: + remainder_full = remainder_full @ project_2d + remainder = remainder_full - components.sum(dim=0) components = torch.cat([components, remainder[None]], dim=0) labels.append("remainder") elif incl_remainder: - components = torch.cat( - [pos_slice.apply(self[("resid_post", layer - 1)], dim=-2)[None]], dim=0 - ) + remainder_full = pos_slice.apply(self[("resid_post", layer - 1)], dim=-2) + if project_2d is not None: + remainder_full = remainder_full @ project_2d + components = torch.cat([remainder_full[None]], dim=0) labels.append("remainder") else: # Returning empty, give it the right shape to stack properly - components = torch.zeros( - 0, - *pos_slice.apply(self["hook_embed"], dim=-2).shape, - device=self.model.cfg.device, - ) + empty_shape_src = pos_slice.apply(self["hook_embed"], dim=-2) + if project_2d is not None: + empty_shape_src = empty_shape_src @ project_2d + components = torch.zeros(0, *empty_shape_src.shape, device=self.model.cfg.device) if apply_ln: components = self.apply_ln_to_stack(components, layer, pos_slice=pos_slice) + if squeeze_projected: + components = components.squeeze(-1) + if return_labels: return components, labels else: @@ -1096,10 +1140,8 @@ def get_full_resid_decomposition( apply_ln: bool = False, pos_slice: Union[Slice, SliceInput] = None, return_labels: bool = False, - ) -> Union[ - Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], - Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]], - ]: + project_output_onto: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[str]]]: """Get the full Residual Decomposition. Decomposes the residual stream that is input into some layer into its @@ -1134,12 +1176,25 @@ def get_full_resid_decomposition( Whether to expand the MLP outputs to give every neuron's result or just return the MLP layer outputs. apply_ln: - Whether to apply LayerNorm to the stack. + Whether to apply LayerNorm to the stack. Not yet supported in combination with + ``project_output_onto`` (raises ``NotImplementedError``). pos_slice: Slice of the positions to take. return_labels: Whether to return the labels. + project_output_onto: + Optional ``[d_model]`` or ``[d_model, num_outputs]`` projection. Folded in + *before* the per-neuron expansion, so the ``[..., d_mlp, d_model]`` intermediate + is never materialized (memory saving applies only with ``expand_neurons=True``). + Output last-dim is squeezed for a 1D projection; ``num_outputs`` for 2D. """ + if apply_ln and project_output_onto is not None: + raise NotImplementedError( + "get_full_resid_decomposition does not yet support apply_ln=True together with " + "project_output_onto. Call without apply_ln, or apply the projection after a " + "non-projected call." + ) + if layer is None or layer == -1: # Default to the residual stream immediately pre unembed layer = self.model.cfg.n_layers @@ -1147,15 +1202,24 @@ def get_full_resid_decomposition( if not isinstance(pos_slice, Slice): pos_slice = Slice(pos_slice) + + project_2d, squeeze_projected = _normalize_projection_to_2d(project_output_onto) + head_stack, head_labels = self.stack_head_results( layer + (1 if mlp_input else 0), pos_slice=pos_slice, return_labels=True ) + if project_2d is not None: + head_stack = head_stack @ project_2d labels = head_labels components = [head_stack] if not self.model.cfg.attn_only and layer > 0: if expand_neurons: + # Pass projection through so the d_mlp×d_model expansion is avoided downstream. neuron_stack, neuron_labels = self.stack_neuron_results( - layer, pos_slice=pos_slice, return_labels=True + layer, + pos_slice=pos_slice, + return_labels=True, + project_output_onto=project_2d, ) labels.extend(neuron_labels) components.append(neuron_stack) @@ -1171,17 +1235,28 @@ def get_full_resid_decomposition( mode="mlp", return_labels=True, ) + if project_2d is not None: + mlp_stack = mlp_stack @ project_2d labels.extend(mlp_labels) components.append(mlp_stack) if self.has_embed: + embed = pos_slice.apply(self["embed"], -2)[None] + if project_2d is not None: + embed = embed @ project_2d labels.append("embed") - components.append(pos_slice.apply(self["embed"], -2)[None]) + components.append(embed) if self.has_pos_embed: + pos_embed = pos_slice.apply(self["pos_embed"], -2)[None] + if project_2d is not None: + pos_embed = pos_embed @ project_2d labels.append("pos_embed") - components.append(pos_slice.apply(self["pos_embed"], -2)[None]) + components.append(pos_embed) # If we didn't expand the neurons, the MLP biases are already included in the MLP outputs. bias = self.model.accumulated_bias(layer, mlp_input, include_mlp_biases=expand_neurons) + if project_2d is not None: + # Bias is [d_model], so project post-hoc for shape compatibility — no memory win here. + bias = bias @ project_2d bias = bias.expand((1,) + head_stack.shape[1:]) labels.append("bias") components.append(bias) @@ -1191,6 +1266,9 @@ def get_full_resid_decomposition( residual_stack, layer, pos_slice=pos_slice, mlp_input=mlp_input ) + if squeeze_projected: + residual_stack = residual_stack.squeeze(-1) + if return_labels: return residual_stack, labels else: From 0cc6ff3e9619836b9e1d16e45ddf403051580756 Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Mon, 11 May 2026 20:44:53 -0500 Subject: [PATCH 03/10] Full resolution for 210 + a demo notebook --- demos/Direct_Logit_Attribution_Demo.ipynb | 1078 +++++++++++++++++++++ tests/acceptance/test_activation_cache.py | 141 ++- transformer_lens/ActivationCache.py | 248 +++-- 3 files changed, 1380 insertions(+), 87 deletions(-) create mode 100644 demos/Direct_Logit_Attribution_Demo.ipynb diff --git a/demos/Direct_Logit_Attribution_Demo.ipynb b/demos/Direct_Logit_Attribution_Demo.ipynb new file mode 100644 index 000000000..a64d277b3 --- /dev/null +++ b/demos/Direct_Logit_Attribution_Demo.ipynb @@ -0,0 +1,1078 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "707daead", + "metadata": {}, + "source": [ + "\n", + " \"Open\n", + "" + ] + }, + { + "cell_type": "markdown", + "id": "c4784038", + "metadata": {}, + "source": [ + "# TransformerLens Direct Logit Attribution Demo" + ] + }, + { + "cell_type": "markdown", + "id": "a8bbfb02", + "metadata": {}, + "source": [ + "Direct Logit Attribution (DLA) is one of the most common analyses in mechanistic interpretability: given a model and a target token (or the logit difference between two tokens), figure out how much each component of the residual stream \u2014 every attention head, every MLP neuron, the embeddings \u2014 contributes to that target logit. The math is straightforward (apply LayerNorm to the component, then project onto the unembed direction), but the standard way of computing it has a memory problem.\n", + "\n", + "`ActivationCache.get_full_resid_decomposition` returns a stack of shape `[num_components, batch, pos, d_model]`. For per-neuron decomposition on a typical model, `num_components` is in the thousands (every neuron in every MLP). Multiplied by `batch * pos * d_model`, the intermediate tensor can be many GB even for moderately-sized models \u2014 enough to OOM the user before they ever get to the projection.\n", + "\n", + "This notebook demonstrates the `project_output_onto` parameter, which folds the projection into the decomposition so the `[..., d_mlp, d_model]` intermediate is never materialized. The output last-dim shrinks from `d_model` to `num_outputs` (or to a scalar if you project onto a single direction), and the memory profile no longer scales with `d_model`." + ] + }, + { + "cell_type": "markdown", + "id": "b5bda06c", + "metadata": {}, + "source": [ + "## How to use this notebook\n", + "\n", + "Go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.\n", + "\n", + "Tips for reading this Colab:\n", + "\n", + "* You can run all this code for yourself!\n", + "* Use the table of contents pane in the sidebar to navigate\n", + "* Collapse irrelevant sections with the dropdown arrows\n", + "* Search the page using the search in the sidebar, not CTRL+F" + ] + }, + { + "cell_type": "markdown", + "id": "f4a21923", + "metadata": {}, + "source": [ + "## Setup (Ignore)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "0a64f026", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T00:49:08.827750Z", + "iopub.status.busy": "2026-05-12T00:49:08.827556Z", + "iopub.status.idle": "2026-05-12T00:49:08.854512Z", + "shell.execute_reply": "2026-05-12T00:49:08.854270Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running as a Jupyter notebook - intended for development only!\n" + ] + } + ], + "source": [ + "# NBVAL_IGNORE_OUTPUT\n", + "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", + "import os\n", + "\n", + "DEVELOPMENT_MODE = False\n", + "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", + "try:\n", + " import google.colab\n", + " IN_COLAB = True\n", + " print(\"Running as a Colab notebook\")\n", + "except:\n", + " IN_COLAB = False\n", + "\n", + "if not IN_GITHUB and not IN_COLAB:\n", + " print(\"Running as a Jupyter notebook - intended for development only!\")\n", + " from IPython import get_ipython\n", + "\n", + " ipython = get_ipython()\n", + " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n", + " ipython.run_line_magic(\"load_ext\", \"autoreload\")\n", + " ipython.run_line_magic(\"autoreload\", \"2\")\n", + "\n", + "if IN_COLAB:\n", + " %pip install transformer_lens\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "3b63b88e", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T00:49:08.855741Z", + "iopub.status.busy": "2026-05-12T00:49:08.855666Z", + "iopub.status.idle": "2026-05-12T00:49:11.713160Z", + "shell.execute_reply": "2026-05-12T00:49:11.712926Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "device = 'cpu'\n" + ] + } + ], + "source": [ + "# NBVAL_IGNORE_OUTPUT\n", + "import torch\n", + "from torch.utils._python_dispatch import TorchDispatchMode\n", + "\n", + "from transformer_lens.model_bridge import TransformerBridge\n", + "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "print(f\"{device = }\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "a334cc18", + "metadata": {}, + "source": [ + "## Load a model" + ] + }, + { + "cell_type": "markdown", + "id": "6f3450cb", + "metadata": {}, + "source": [ + "We use `gpt2` (small) because it's a familiar model and the technique is the same regardless of choice. The memory savings scale with `d_mlp \u00d7 d_model`, so the optimization matters more on larger models." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "4a53e552", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T00:49:11.714396Z", + "iopub.status.busy": "2026-05-12T00:49:11.714247Z", + "iopub.status.idle": "2026-05-12T00:49:12.940983Z", + "shell.execute_reply": "2026-05-12T00:49:12.940727Z" + } + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8c474c0cd53c4598bc2e1a9463763647", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading weights: 0%| | 0/148 [00:00 *When John and Mary went to the shops, John gave the bag to ___*\n", + "\n", + "The \"right\" answer (from the model's perspective) is *Mary*. We can quantify each residual stream component's contribution to that decision by projecting it onto the difference between the *Mary* and *John* unembed directions." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "bbb3afaf", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T00:49:12.942123Z", + "iopub.status.busy": "2026-05-12T00:49:12.942032Z", + "iopub.status.idle": "2026-05-12T00:49:12.992191Z", + "shell.execute_reply": "2026-05-12T00:49:12.991987Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Logit direction shape: torch.Size([768]) # [d_model]\n" + ] + } + ], + "source": [ + "prompt = \"When John and Mary went to the shops, John gave the bag to\"\n", + "tokens = model.to_tokens(prompt)\n", + "\n", + "with torch.no_grad():\n", + " _, cache = model.run_with_cache(tokens)\n", + "\n", + "mary_id = model.to_single_token(\" Mary\")\n", + "john_id = model.to_single_token(\" John\")\n", + "logit_direction = (\n", + " model.tokens_to_residual_directions(torch.tensor(mary_id))\n", + " - model.tokens_to_residual_directions(torch.tensor(john_id))\n", + ")\n", + "print(f\"Logit direction shape: {logit_direction.shape} # [d_model]\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "140871a6", + "metadata": {}, + "source": [ + "## Without `project_output_onto`: the memory problem\n", + "\n", + "The naive approach is to get the full decomposition, then project at the end:\n", + "\n", + "```python\n", + "contributions = cache.get_full_resid_decomposition(apply_ln=True) @ logit_direction\n", + "```\n", + "\n", + "This works, but `get_full_resid_decomposition` builds a tensor of shape `[num_components, batch, pos, d_model]` in memory before the projection. For `gpt2` with `d_mlp = 3072` and `n_layers = 12`, that's ~37k components \u00d7 `d_model = 768` per position. Manageable on this tiny model, but for a 7B model it would be tens of gigabytes." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "2996da4f", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T00:49:12.993242Z", + "iopub.status.busy": "2026-05-12T00:49:12.993178Z", + "iopub.status.idle": "2026-05-12T00:49:13.071465Z", + "shell.execute_reply": "2026-05-12T00:49:13.071238Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Naive output shape: torch.Size([37011, 1]) # [num_components, batch]\n" + ] + } + ], + "source": [ + "# Naive: build full decomposition, then project\n", + "naive = cache.get_full_resid_decomposition(layer=-1, pos_slice=-1, apply_ln=True) @ logit_direction\n", + "print(f\"Naive output shape: {naive.shape} # [num_components, batch]\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "6322fb7d", + "metadata": {}, + "source": [ + "## With `project_output_onto`: same result, fraction of the memory\n", + "\n", + "Pass the logit direction directly to `get_full_resid_decomposition`. The projection is folded into the per-neuron expansion, so the `[..., d_mlp, d_model]` intermediate is never built." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "f9231081", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T00:49:13.072479Z", + "iopub.status.busy": "2026-05-12T00:49:13.072419Z", + "iopub.status.idle": "2026-05-12T00:49:13.181990Z", + "shell.execute_reply": "2026-05-12T00:49:13.181777Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:Tried to compute head results when they were already cached\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Projected output shape: torch.Size([37011, 1]) # same as naive\n", + "Number of components: 37011\n", + "Outputs match the naive computation within 1e-4 tolerance.\n" + ] + } + ], + "source": [ + "# Folded: projection happens during the decomposition\n", + "projected, labels = cache.get_full_resid_decomposition(\n", + " layer=-1,\n", + " pos_slice=-1,\n", + " apply_ln=True,\n", + " project_output_onto=logit_direction,\n", + " return_labels=True,\n", + ")\n", + "print(f\"Projected output shape: {projected.shape} # same as naive\")\n", + "print(f\"Number of components: {len(labels)}\")\n", + "\n", + "assert torch.allclose(naive, projected, atol=1e-4), \"Outputs should match!\"\n", + "print(\"Outputs match the naive computation within 1e-4 tolerance.\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "aa2f172d", + "metadata": {}, + "source": [ + "## Top contributors" + ] + }, + { + "cell_type": "markdown", + "id": "82fb3313", + "metadata": {}, + "source": [ + "Now we can ask which components most strongly push the model toward *Mary* vs *John*. Positive values push toward Mary; negative push toward John." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "3b13bf64", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T00:49:13.183021Z", + "iopub.status.busy": "2026-05-12T00:49:13.182946Z", + "iopub.status.idle": "2026-05-12T00:49:13.203035Z", + "shell.execute_reply": "2026-05-12T00:49:13.202800Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Top 8 components by absolute contribution:\n", + "\n", + " 1. L9H9: +1.9583\n", + " 2. L9H6: +1.6281\n", + " 3. L10H7: -1.4267\n", + " 4. L11H1: -0.9117\n", + " 5. L11H10: -0.9042\n", + " 6. L10H0: +0.6370\n", + " 7. L10H10: +0.4524\n", + " 8. L8H10: +0.4264\n" + ] + } + ], + "source": [ + "flat = projected[:, 0] # [num_components]\n", + "top_indices = flat.abs().topk(8).indices\n", + "print(\"Top 8 components by absolute contribution:\")\n", + "print()\n", + "for rank, i in enumerate(top_indices, start=1):\n", + " sign = \"+\" if flat[i].item() >= 0 else \" \"\n", + " print(f\" {rank}. {labels[i]:>10s}: {sign}{flat[i].item():.4f}\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "e2226bbf", + "metadata": {}, + "source": [ + "`L9H9` and `L9H6` (attention heads in layer 9) are the dominant Mary-pushers \u2014 these are the well-known **name mover heads** from the IOI circuit ([Wang et al., 2022](https://arxiv.org/abs/2211.00593)), which copy the indirect object's representation into the prediction. The negative contributors `L10H7`, `L11H1`, and `L11H10` push in the opposite direction \u2014 these are the **S-inhibition / negative name mover** heads, part of the same circuit. The fact that attention heads dominate this analysis on `gpt2-small` is consistent with the IOI literature: MLP neurons contribute more diffusely and appear further down the ranking." + ] + }, + { + "cell_type": "markdown", + "id": "ae21801e", + "metadata": {}, + "source": [ + "## Projecting onto multiple directions at once\n", + "\n", + "`project_output_onto` also accepts a `[d_model, num_outputs]` tensor, letting you analyze contributions to many output directions in a single pass. The output shape becomes `[num_components, batch, num_outputs]`." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "8613b7e5", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T00:49:13.204044Z", + "iopub.status.busy": "2026-05-12T00:49:13.203981Z", + "iopub.status.idle": "2026-05-12T00:49:13.243470Z", + "shell.execute_reply": "2026-05-12T00:49:13.243246Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:Tried to compute head results when they were already cached\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Target directions shape: torch.Size([768, 3])\n", + "Output shape: torch.Size([37011, 1, 3]) # [num_components, batch, num_targets]\n" + ] + } + ], + "source": [ + "# Three target tokens: Mary, John, and the\n", + "target_ids = torch.tensor([mary_id, john_id, model.to_single_token(\" the\")])\n", + "target_directions = model.tokens_to_residual_directions(target_ids).T # [d_model, 3]\n", + "print(f\"Target directions shape: {target_directions.shape}\")\n", + "\n", + "multi = cache.get_full_resid_decomposition(\n", + " layer=-1,\n", + " pos_slice=-1,\n", + " apply_ln=True,\n", + " project_output_onto=target_directions,\n", + ")\n", + "print(f\"Output shape: {multi.shape} # [num_components, batch, num_targets]\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "bb81fe06", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T00:49:13.244426Z", + "iopub.status.busy": "2026-05-12T00:49:13.244370Z", + "iopub.status.idle": "2026-05-12T00:49:13.266000Z", + "shell.execute_reply": "2026-05-12T00:49:13.265784Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Top MLP neuron contribution per target token:\n", + "\n", + " ' Mary': L11N611 (+0.3188)\n", + " ' John': L11N621 (-0.4359)\n", + " ' the': L10N2214 (+0.4369)\n" + ] + } + ], + "source": [ + "# Top neuron contribution per target\n", + "# Find the neuron index range in the component list\n", + "neuron_idx_start = next(i for i, l in enumerate(labels) if l.startswith(\"L\") and \"N\" in l)\n", + "neuron_idx_end = next(\n", + " (i for i, l in enumerate(labels) if i > neuron_idx_start and not (l.startswith(\"L\") and \"N\" in l)),\n", + " len(labels),\n", + ")\n", + "neuron_slab = multi[neuron_idx_start:neuron_idx_end, 0] # [n_neurons, 3]\n", + "\n", + "print(\"Top MLP neuron contribution per target token:\")\n", + "print()\n", + "for col, tid in enumerate(target_ids):\n", + " top_neuron_local = neuron_slab[:, col].abs().argmax().item()\n", + " top_label = labels[neuron_idx_start + top_neuron_local]\n", + " val = neuron_slab[top_neuron_local, col].item()\n", + " print(f\" '{model.to_string(tid)}': {top_label} ({val:+.4f})\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "c74a101a", + "metadata": {}, + "source": [ + "## Verifying the memory savings" + ] + }, + { + "cell_type": "markdown", + "id": "d8024a48", + "metadata": {}, + "source": [ + "We can measure the largest tensor allocated during each call using `TorchDispatchMode`. The naive path materializes the `[..., d_mlp, d_model]` intermediate; the projected path stays bounded by the model's `W_out` matrix (a constant in the model, independent of batch/pos)." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "02a1f512", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T00:49:13.267103Z", + "iopub.status.busy": "2026-05-12T00:49:13.267027Z", + "iopub.status.idle": "2026-05-12T00:49:13.286322Z", + "shell.execute_reply": "2026-05-12T00:49:13.286133Z" + } + }, + "outputs": [], + "source": [ + "class MaxTensorWatcher(TorchDispatchMode):\n", + " def __init__(self):\n", + " self.max_numel = 0\n", + "\n", + " def __torch_dispatch__(self, func, types, args=(), kwargs=None):\n", + " result = func(*args, **(kwargs or {}))\n", + " tensors = result if isinstance(result, (list, tuple)) else (result,)\n", + " for t in tensors:\n", + " if isinstance(t, torch.Tensor):\n", + " self.max_numel = max(self.max_numel, t.numel())\n", + " return result\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "bd0bfb42", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T00:49:13.287203Z", + "iopub.status.busy": "2026-05-12T00:49:13.287151Z", + "iopub.status.idle": "2026-05-12T00:49:13.447752Z", + "shell.execute_reply": "2026-05-12T00:49:13.447535Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:Tried to compute head results when they were already cached\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:Tried to compute head results when they were already cached\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Single position (pos_slice=-1):\n", + " Naive: max tensor = 28,424,448 elements\n", + " Projected: max tensor = 2,359,296 elements\n", + " Ratio: 12.0x\n" + ] + } + ], + "source": [ + "# Single position (typical DLA-at-final-token setup)\n", + "with MaxTensorWatcher() as naive_watcher:\n", + " _ = cache.get_full_resid_decomposition(layer=-1, pos_slice=-1, apply_ln=True) @ logit_direction\n", + "with MaxTensorWatcher() as projected_watcher:\n", + " _ = cache.get_full_resid_decomposition(\n", + " layer=-1, pos_slice=-1, apply_ln=True, project_output_onto=logit_direction\n", + " )\n", + "print(\"Single position (pos_slice=-1):\")\n", + "print(f\" Naive: max tensor = {naive_watcher.max_numel:>12,} elements\")\n", + "print(f\" Projected: max tensor = {projected_watcher.max_numel:>12,} elements\")\n", + "print(f\" Ratio: {naive_watcher.max_numel / projected_watcher.max_numel:.1f}x\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "bb64925d", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T00:49:13.448829Z", + "iopub.status.busy": "2026-05-12T00:49:13.448766Z", + "iopub.status.idle": "2026-05-12T00:49:14.024336Z", + "shell.execute_reply": "2026-05-12T00:49:14.024129Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:Tried to compute head results when they were already cached\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:Tried to compute head results when they were already cached\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "All positions:\n", + " Naive: max tensor = 426,366,720 elements\n", + " Projected: max tensor = 2,359,296 elements\n", + " Ratio: 180.7x\n" + ] + } + ], + "source": [ + "# All positions \u2014 the optimization shows more dramatically as batch*pos grows\n", + "with MaxTensorWatcher() as naive_all:\n", + " _ = cache.get_full_resid_decomposition(layer=-1, apply_ln=True) @ logit_direction\n", + "with MaxTensorWatcher() as proj_all:\n", + " _ = cache.get_full_resid_decomposition(\n", + " layer=-1, apply_ln=True, project_output_onto=logit_direction\n", + " )\n", + "print(\"All positions:\")\n", + "print(f\" Naive: max tensor = {naive_all.max_numel:>12,} elements\")\n", + "print(f\" Projected: max tensor = {proj_all.max_numel:>12,} elements\")\n", + "print(f\" Ratio: {naive_all.max_numel / proj_all.max_numel:.1f}x\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "48b1304e", + "metadata": {}, + "source": [ + "On `gpt2` (small) with this prompt, the projected path keeps the max tensor bounded by `d_mlp \u00d7 d_model = 3072 \u00d7 768 \u2248 2.4M elements` (just the `W_out` matrix \u2014 a model constant). The naive path scales with `batch \u00d7 pos \u00d7 d_mlp \u00d7 d_model`. For real workloads on larger models (batch 32, pos 128, 7B params with `d_mlp \u2248 11K`, `d_model \u2248 4K`), the ratio approaches 500\u20134000x \u2014 the difference between a comfortable run and an OOM." + ] + }, + { + "cell_type": "markdown", + "id": "d7206154", + "metadata": {}, + "source": [ + "## When the memory benefit doesn't apply\n", + "\n", + "A few caveats worth flagging:\n", + "\n", + "- The memory saving applies to the **per-neuron expansion** in MLPs. If you set `expand_neurons=False`, the decomposition stops at MLP-layer granularity (no per-neuron breakdown) and there's no big intermediate to avoid. Projection still works on this path, it just doesn't save memory.\n", + "- For 1D projections (`[d_model]`), the output last-dim is squeezed away. For 2D projections (`[d_model, num_outputs]`), it's preserved as `num_outputs`. This matches PyTorch's usual matmul broadcasting rules.\n", + "- The same `project_output_onto` parameter exists on `cache.stack_neuron_results` and `cache.get_neuron_results` if you only need a subset of components." + ] + }, + { + "cell_type": "markdown", + "id": "7e0bacdf", + "metadata": {}, + "source": [ + "## Further reading\n", + "\n", + "- Activation Patching demos for component-level causal interventions\n", + "- [issue #210](https://github.com/TransformerLensOrg/TransformerLens/issues/210) for the original feature request and discussion" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "transformer-lens", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": { + "2f34578fd40e4aa79a405350ffebf546": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "background": null, + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "35e2126222094b72a52eec44daf442ec": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_cff4f45be6ed4ba09fd2aad6918c13af", + "IPY_MODEL_3d373b5d43a3415fa9a97dc719925dce", + "IPY_MODEL_3c43cfdacfdf4ea184b53af40615455d" + ], + "layout": "IPY_MODEL_8f328344caef43fda146f610696513ec", + "tabbable": null, + "tooltip": null + } + }, + "3c43cfdacfdf4ea184b53af40615455d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HTMLView", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_e904035b19e4441ea24375d4082fe6f1", + "placeholder": "\u200b", + "style": "IPY_MODEL_f530fbebb9014c17bc0401eb50f5c9c9", + "tabbable": null, + "tooltip": null, + "value": "\u2007148/148\u2007[00:00<00:00,\u20074069.74it/s,\u2007Materializing\u2007param=transformer.wte.weight]" + } + }, + "3d373b5d43a3415fa9a97dc719925dce": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_f939efa0f18446919b70834b812ebeda", + "max": 148, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_f8a98e5a79da4db1ba18361d331cd59b", + "tabbable": null, + "tooltip": null, + "value": 148 + } + }, + "8f328344caef43fda146f610696513ec": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "cff4f45be6ed4ba09fd2aad6918c13af": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HTMLView", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_e7f25e4e59be4e39a0329b5679e6aaa2", + "placeholder": "\u200b", + "style": "IPY_MODEL_2f34578fd40e4aa79a405350ffebf546", + "tabbable": null, + "tooltip": null, + "value": "Loading\u2007weights:\u2007100%" + } + }, + "e7f25e4e59be4e39a0329b5679e6aaa2": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e904035b19e4441ea24375d4082fe6f1": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f530fbebb9014c17bc0401eb50f5c9c9": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "background": null, + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "f8a98e5a79da4db1ba18361d331cd59b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "f939efa0f18446919b70834b812ebeda": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + } + }, + "version_major": 2, + "version_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tests/acceptance/test_activation_cache.py b/tests/acceptance/test_activation_cache.py index 33a53ff31..9709d013a 100644 --- a/tests/acceptance/test_activation_cache.py +++ b/tests/acceptance/test_activation_cache.py @@ -905,31 +905,158 @@ def test_get_full_resid_decomposition_project_output_onto_2d(): @torch.no_grad -def test_get_full_resid_decomposition_project_apply_ln_raises(): - """apply_ln=True combined with project_output_onto must raise.""" +def test_stack_neuron_results_apply_ln_and_project_1d(): + """LN-fused projection equals non-fused (apply_ln) then projected, for 1D direction.""" model = load_model("solu-2l") tokens, _ = get_ioi_tokens_and_answer_tokens(model) _, cache = model.run_with_cache(tokens) direction = torch.randn(model.cfg.d_model, device=model.cfg.device) - with pytest.raises(NotImplementedError): + + layer = model.cfg.n_layers + ref = cache.stack_neuron_results(layer, pos_slice=-1, apply_ln=True) @ direction + fused = cache.stack_neuron_results( + layer, pos_slice=-1, apply_ln=True, project_output_onto=direction + ) + assert fused.shape == ref.shape + assert torch.allclose(fused, ref, atol=1e-5) + + +@torch.no_grad +def test_stack_neuron_results_apply_ln_and_project_2d(): + """LN-fused projection equals non-fused (apply_ln) then projected, for 2D directions.""" + model = load_model("solu-2l") + tokens, _ = get_ioi_tokens_and_answer_tokens(model) + _, cache = model.run_with_cache(tokens) + + directions = torch.randn(model.cfg.d_model, 3, device=model.cfg.device) + + layer = model.cfg.n_layers + ref = cache.stack_neuron_results(layer, pos_slice=-1, apply_ln=True) @ directions + fused = cache.stack_neuron_results( + layer, pos_slice=-1, apply_ln=True, project_output_onto=directions + ) + assert fused.shape == ref.shape + assert torch.allclose(fused, ref, atol=1e-5) + + +@torch.no_grad +def test_stack_neuron_results_apply_ln_and_project_incl_remainder(): + """LN-fused projection respects incl_remainder via cached-scale LN linearity.""" + model = load_model("solu-2l") + tokens, _ = get_ioi_tokens_and_answer_tokens(model) + _, cache = model.run_with_cache(tokens) + + direction = torch.randn(model.cfg.d_model, device=model.cfg.device) + + layer = model.cfg.n_layers + ref = ( + cache.stack_neuron_results(layer, pos_slice=-1, apply_ln=True, incl_remainder=True) + @ direction + ) + fused = cache.stack_neuron_results( + layer, + pos_slice=-1, + apply_ln=True, + incl_remainder=True, + project_output_onto=direction, + ) + assert fused.shape == ref.shape + assert torch.allclose(fused, ref, atol=1e-5) + + +@torch.no_grad +def test_get_full_resid_decomposition_apply_ln_and_project_1d(): + """Full decomposition with apply_ln + 1D projection matches non-projected then projected.""" + model = load_model("solu-2l") + tokens, _ = get_ioi_tokens_and_answer_tokens(model) + _, cache = model.run_with_cache(tokens) + + direction = torch.randn(model.cfg.d_model, device=model.cfg.device) + ref = ( cache.get_full_resid_decomposition( - layer=-1, pos_slice=-1, apply_ln=True, project_output_onto=direction + layer=-1, pos_slice=-1, apply_ln=True, expand_neurons=True ) + @ direction + ) + fused = cache.get_full_resid_decomposition( + layer=-1, + pos_slice=-1, + apply_ln=True, + expand_neurons=True, + project_output_onto=direction, + ) + assert fused.shape == ref.shape + assert torch.allclose(fused, ref, atol=1e-4) @torch.no_grad -def test_stack_neuron_results_project_apply_ln_raises(): +def test_get_full_resid_decomposition_apply_ln_and_project_2d(): + """Full decomposition with apply_ln + 2D projection matches non-projected then projected.""" + model = load_model("solu-2l") + tokens, _ = get_ioi_tokens_and_answer_tokens(model) + _, cache = model.run_with_cache(tokens) + + directions = torch.randn(model.cfg.d_model, 2, device=model.cfg.device) + ref = ( + cache.get_full_resid_decomposition( + layer=-1, pos_slice=-1, apply_ln=True, expand_neurons=True + ) + @ directions + ) + fused = cache.get_full_resid_decomposition( + layer=-1, + pos_slice=-1, + apply_ln=True, + expand_neurons=True, + project_output_onto=directions, + ) + assert fused.shape == ref.shape + assert torch.allclose(fused, ref, atol=1e-4) + + +@torch.no_grad +def test_stack_neuron_results_apply_ln_projection_skips_dmodel_materialization(): + """LN-fused projection still avoids the d_mlp×d_model intermediate (memory watermark).""" + from torch.utils._python_dispatch import TorchDispatchMode + + class MaxTensorWatcher(TorchDispatchMode): + def __init__(self): + self.max_numel = 0 + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + result = func(*args, **(kwargs or {})) + tensors = result if isinstance(result, (list, tuple)) else (result,) + for t in tensors: + if isinstance(t, torch.Tensor): + self.max_numel = max(self.max_numel, t.numel()) + return result + model = load_model("solu-2l") tokens, _ = get_ioi_tokens_and_answer_tokens(model) _, cache = model.run_with_cache(tokens) direction = torch.randn(model.cfg.d_model, device=model.cfg.device) - with pytest.raises(NotImplementedError): - cache.stack_neuron_results( + + _ = cache.stack_neuron_results(layer=-1, pos_slice=-1, apply_ln=True) + _ = cache.stack_neuron_results( + layer=-1, pos_slice=-1, apply_ln=True, project_output_onto=direction + ) + + with MaxTensorWatcher() as watcher_unproj: + _ = cache.stack_neuron_results(layer=-1, pos_slice=-1, apply_ln=True) + with MaxTensorWatcher() as watcher_proj: + _ = cache.stack_neuron_results( layer=-1, pos_slice=-1, apply_ln=True, project_output_onto=direction ) + ratio = watcher_unproj.max_numel / max(watcher_proj.max_numel, 1) + assert ratio > 10, ( + f"LN-fused projection regressed: max_numel unprojected={watcher_unproj.max_numel:,}, " + f"projected={watcher_proj.max_numel:,}, ratio={ratio:.1f}x. " + f"Expected projected path's largest tensor to be >>10x smaller than unprojected." + ) + @torch.no_grad def test_stack_neuron_results_projection_skips_dmodel_materialization(): diff --git a/transformer_lens/ActivationCache.py b/transformer_lens/ActivationCache.py index e0b01f739..6a0d68e63 100644 --- a/transformer_lens/ActivationCache.py +++ b/transformer_lens/ActivationCache.py @@ -911,6 +911,79 @@ def get_neuron_results( return neuron_acts * projected return neuron_acts[..., None] * projected + def _get_cached_ln_scale( + self, + layer: Optional[int], + mlp_input: bool, + pos_slice: Slice, + batch_slice: Optional[Slice] = None, + ) -> torch.Tensor: + """Look up the cached LN scale and apply pos/batch slicing. Surfaces a clearer error + when the expected hook isn't in the cache (some non-decoder-only architectures expose + LN scale at a different path or not at all). + """ + if layer == self.model.cfg.n_layers or layer is None: + key = "ln_final.hook_scale" + else: + key = f"blocks.{layer}.ln{2 if mlp_input else 1}.hook_scale" + try: + scale = self[key] + except KeyError as e: + raise KeyError( + f"Cached LN scale not found at '{key}'. apply_ln operations require the model " + f"to have cached this hook (some non-decoder-only architectures expose LN scale " + f"under different module paths)." + ) from e + scale = pos_slice.apply(scale, dim=-2) + if batch_slice is not None and self.has_batch_dim: + scale = batch_slice.apply(scale) + return scale + + def _stack_neuron_results_apply_ln_projected( + self, + layer: int, + pos_slice: Slice, + neuron_slice: Slice, + project_2d: torch.Tensor, + ) -> torch.Tensor: + """LN-applied neuron stack with projection folded in — no d_mlp×d_model intermediate. + + Analytical formula (LN models, cached scale ``s``): + ``LN_s(a_n * W_out_n) @ p = (a_n / s) * (W_out_n @ p - mean(W_out_n) * sum_p)`` + RMS models drop the ``mean(W_out_n) * sum_p`` term (no centering). Always uses the + ln1 scale (mlp_input=False) since ``stack_neuron_results`` doesn't expose mlp_input. + """ + scale = self._get_cached_ln_scale(layer, mlp_input=False, pos_slice=pos_slice) + + apply_centering = self.model.cfg.normalization_type in ["LN", "LNPre"] + sum_p = project_2d.sum(dim=0) if apply_centering else None # [n_outs] + + components: list = [] + for l in range(layer): + W_out_l = self.model.blocks[l].mlp.W_out # [d_mlp, d_model] + W_out_l_sliced = neuron_slice.apply(W_out_l, dim=0) + W_proj_l = W_out_l_sliced @ project_2d # [d_mlp, n_outs] + if apply_centering: + assert sum_p is not None # set when apply_centering, narrow for mypy + W_means_l = W_out_l_sliced.mean(dim=-1) # [d_mlp] + lin_form_l = W_proj_l - W_means_l[:, None] * sum_p[None, :] + else: + lin_form_l = W_proj_l + a_l = self[("post", l, "mlp")] + a_l = pos_slice.apply(a_l, dim=-2) + a_l = neuron_slice.apply(a_l, dim=-1) + # (a_l / s)[..., None] is [..., d_mlp, 1]; broadcast with lin_form_l [d_mlp, n_outs] + components.append((a_l / scale)[..., None] * lin_form_l) + if not components: + empty_src = pos_slice.apply(self["hook_embed"], dim=-2) + return torch.zeros( + 0, *empty_src.shape[:-1], project_2d.shape[-1], device=self.model.cfg.device + ) + stacked = torch.cat(components, dim=-2) + return einops.rearrange( + stacked, "... concat_neuron_index n_outs -> concat_neuron_index ... n_outs" + ) + def stack_neuron_results( self, layer: int, @@ -945,27 +1018,18 @@ def stack_neuron_results( incl_remainder: Whether to return a final term which is "the rest of the residual stream". apply_ln: - Whether to apply LayerNorm to the stack. Not yet supported in combination with - ``project_output_onto`` (raises ``NotImplementedError``). + Whether to apply LayerNorm to the stack. project_output_onto: Optional ``[d_model]`` or ``[d_model, num_outputs]`` tensor. When set, each component's last d_model dim is replaced by the projection (memory-efficient for - direction analyses; see ``get_neuron_results``). + direction analyses; see ``get_neuron_results``). Combined with ``apply_ln=True``, + the projection is folded into the analytical cached-scale LN so the + ``[..., d_mlp, d_model]`` intermediate is still never materialized. """ - if apply_ln and project_output_onto is not None: - raise NotImplementedError( - "stack_neuron_results does not yet support apply_ln=True together with " - "project_output_onto. Call without apply_ln, or apply the projection after a " - "non-projected call." - ) - if layer is None or layer == -1: # Default to the residual stream immediately pre unembed layer = self.model.cfg.n_layers - components: Any = [] # TODO: fix typing properly - labels = [] - if not isinstance(neuron_slice, Slice): neuron_slice = Slice(neuron_slice) if not isinstance(pos_slice, Slice): @@ -979,46 +1043,64 @@ def stack_neuron_results( if isinstance(neuron_labels, int): neuron_labels = np.array([neuron_labels]) - for l in range(layer): - # Note that this has shape batch x pos x head_index x d_model - components.append( - self.get_neuron_results( - l, - pos_slice=pos_slice, - neuron_slice=neuron_slice, - project_output_onto=project_2d, - ) - ) - labels.extend([f"L{l}N{h}" for h in neuron_labels]) - if components: - components = torch.cat(components, dim=-2) - components = einops.rearrange( - components, - "... concat_neuron_index d_model -> concat_neuron_index ... d_model", + labels = [f"L{l}N{h}" for l in range(layer) for h in neuron_labels] + components: Any + ln_folded = apply_ln and project_2d is not None + if ln_folded: + assert project_2d is not None # narrow for mypy + # Analytical LN+projection — no d_mlp×d_model intermediate. + components = self._stack_neuron_results_apply_ln_projected( + layer, pos_slice, neuron_slice, project_2d ) - if incl_remainder: + # Linearity of cached-scale LN: remainder is LN_s(resid_post) @ p - sum(neurons). + resid_post = pos_slice.apply(self[("resid_post", layer - 1)], dim=-2) + resid_post_ln = self.apply_ln_to_stack( + resid_post[None], layer, pos_slice=pos_slice + )[0] + remainder = resid_post_ln @ project_2d + if components.shape[0] > 0: + remainder = remainder - components.sum(dim=0) + components = torch.cat([components, remainder[None]], dim=0) + labels.append("remainder") + else: + per_layer: list = [] + for l in range(layer): + per_layer.append( + self.get_neuron_results( + l, + pos_slice=pos_slice, + neuron_slice=neuron_slice, + project_output_onto=project_2d, + ) + ) + if per_layer: + components = torch.cat(per_layer, dim=-2) + components = einops.rearrange( + components, + "... concat_neuron_index d_model -> concat_neuron_index ... d_model", + ) + if incl_remainder: + remainder_full = pos_slice.apply(self[("resid_post", layer - 1)], dim=-2) + if project_2d is not None: + remainder_full = remainder_full @ project_2d + remainder = remainder_full - components.sum(dim=0) + components = torch.cat([components, remainder[None]], dim=0) + labels.append("remainder") + elif incl_remainder: remainder_full = pos_slice.apply(self[("resid_post", layer - 1)], dim=-2) if project_2d is not None: remainder_full = remainder_full @ project_2d - remainder = remainder_full - components.sum(dim=0) - components = torch.cat([components, remainder[None]], dim=0) + components = torch.cat([remainder_full[None]], dim=0) labels.append("remainder") - elif incl_remainder: - remainder_full = pos_slice.apply(self[("resid_post", layer - 1)], dim=-2) - if project_2d is not None: - remainder_full = remainder_full @ project_2d - components = torch.cat([remainder_full[None]], dim=0) - labels.append("remainder") - else: - # Returning empty, give it the right shape to stack properly - empty_shape_src = pos_slice.apply(self["hook_embed"], dim=-2) - if project_2d is not None: - empty_shape_src = empty_shape_src @ project_2d - components = torch.zeros(0, *empty_shape_src.shape, device=self.model.cfg.device) + else: + empty_shape_src = pos_slice.apply(self["hook_embed"], dim=-2) + if project_2d is not None: + empty_shape_src = empty_shape_src @ project_2d + components = torch.zeros(0, *empty_shape_src.shape, device=self.model.cfg.device) - if apply_ln: - components = self.apply_ln_to_stack(components, layer, pos_slice=pos_slice) + if apply_ln: + components = self.apply_ln_to_stack(components, layer, pos_slice=pos_slice) if squeeze_projected: components = components.squeeze(-1) @@ -1116,19 +1198,8 @@ def apply_ln_to_stack( if self.model.cfg.normalization_type in ["LN", "LNPre"]: residual_stack = residual_stack - residual_stack.mean(dim=-1, keepdim=True) - if layer == self.model.cfg.n_layers or layer is None: - scale = self["ln_final.hook_scale"] - else: - hook_name = f"blocks.{layer}.ln{2 if mlp_input else 1}.hook_scale" - scale = self[hook_name] - - # The shape of scale is [batch, position, 1] or [position, 1] - final dimension is a dummy - # thing to get broadcoasting to work nicely. - scale = pos_slice.apply(scale, dim=-2) - - if self.has_batch_dim: - # Apply batch slice to the scale - scale = batch_slice.apply(scale) + # Shape is [batch, position, 1] or [position, 1]; final dim is a dummy for broadcasting. + scale = self._get_cached_ln_scale(layer, mlp_input, pos_slice, batch_slice) return residual_stack / scale @@ -1176,8 +1247,7 @@ def get_full_resid_decomposition( Whether to expand the MLP outputs to give every neuron's result or just return the MLP layer outputs. apply_ln: - Whether to apply LayerNorm to the stack. Not yet supported in combination with - ``project_output_onto`` (raises ``NotImplementedError``). + Whether to apply LayerNorm to the stack. pos_slice: Slice of the positions to take. return_labels: @@ -1186,15 +1256,10 @@ def get_full_resid_decomposition( Optional ``[d_model]`` or ``[d_model, num_outputs]`` projection. Folded in *before* the per-neuron expansion, so the ``[..., d_mlp, d_model]`` intermediate is never materialized (memory saving applies only with ``expand_neurons=True``). - Output last-dim is squeezed for a 1D projection; ``num_outputs`` for 2D. + Combined with ``apply_ln=True``, the projection is fused into the analytical + cached-scale LN so the same memory benefit holds. Output last-dim is squeezed + for a 1D projection; ``num_outputs`` for 2D. """ - if apply_ln and project_output_onto is not None: - raise NotImplementedError( - "get_full_resid_decomposition does not yet support apply_ln=True together with " - "project_output_onto. Call without apply_ln, or apply the projection after a " - "non-projected call." - ) - if layer is None or layer == -1: # Default to the residual stream immediately pre unembed layer = self.model.cfg.n_layers @@ -1204,21 +1269,33 @@ def get_full_resid_decomposition( pos_slice = Slice(pos_slice) project_2d, squeeze_projected = _normalize_projection_to_2d(project_output_onto) + # When both apply_ln and projection are requested, LN is applied per-component (in + # d_model space for the small ones, analytically for neurons) before projection, so the + # final apply_ln_to_stack call is skipped — last-dim is already n_outs. + ln_folded = apply_ln and project_2d is not None + + def _ln_then_project(stack: torch.Tensor) -> torch.Tensor: + stack = self.apply_ln_to_stack(stack, layer, pos_slice=pos_slice, mlp_input=mlp_input) + return stack @ project_2d if project_2d is not None else stack head_stack, head_labels = self.stack_head_results( layer + (1 if mlp_input else 0), pos_slice=pos_slice, return_labels=True ) - if project_2d is not None: + if ln_folded: + head_stack = _ln_then_project(head_stack) + elif project_2d is not None: head_stack = head_stack @ project_2d labels = head_labels components = [head_stack] if not self.model.cfg.attn_only and layer > 0: if expand_neurons: - # Pass projection through so the d_mlp×d_model expansion is avoided downstream. + # Only ask stack_neuron_results to apply LN when we want the fused analytical + # path (ln_folded). For the unfolded case the outer apply_ln_to_stack handles it. neuron_stack, neuron_labels = self.stack_neuron_results( layer, pos_slice=pos_slice, return_labels=True, + apply_ln=ln_folded, project_output_onto=project_2d, ) labels.extend(neuron_labels) @@ -1235,33 +1312,44 @@ def get_full_resid_decomposition( mode="mlp", return_labels=True, ) - if project_2d is not None: + if ln_folded: + mlp_stack = _ln_then_project(mlp_stack) + elif project_2d is not None: mlp_stack = mlp_stack @ project_2d labels.extend(mlp_labels) components.append(mlp_stack) if self.has_embed: embed = pos_slice.apply(self["embed"], -2)[None] - if project_2d is not None: + if ln_folded: + embed = _ln_then_project(embed) + elif project_2d is not None: embed = embed @ project_2d labels.append("embed") components.append(embed) if self.has_pos_embed: pos_embed = pos_slice.apply(self["pos_embed"], -2)[None] - if project_2d is not None: + if ln_folded: + pos_embed = _ln_then_project(pos_embed) + elif project_2d is not None: pos_embed = pos_embed @ project_2d labels.append("pos_embed") components.append(pos_embed) # If we didn't expand the neurons, the MLP biases are already included in the MLP outputs. - bias = self.model.accumulated_bias(layer, mlp_input, include_mlp_biases=expand_neurons) - if project_2d is not None: - # Bias is [d_model], so project post-hoc for shape compatibility — no memory win here. - bias = bias @ project_2d - bias = bias.expand((1,) + head_stack.shape[1:]) + bias_full = self.model.accumulated_bias(layer, mlp_input, include_mlp_biases=expand_neurons) + if ln_folded: + # Expand bias to per-position d_model shape so LN can center, then project. + expand_shape: tuple = (1,) + tuple(head_stack.shape[1:-1]) + (self.model.cfg.d_model,) + bias = _ln_then_project(bias_full.expand(expand_shape)) + else: + if project_2d is not None: + # Bias is [d_model], so project post-hoc for shape compatibility — no memory win here. + bias_full = bias_full @ project_2d + bias = bias_full.expand((1,) + head_stack.shape[1:]) labels.append("bias") components.append(bias) residual_stack = torch.cat(components, dim=0) - if apply_ln: + if apply_ln and not ln_folded: residual_stack = self.apply_ln_to_stack( residual_stack, layer, pos_slice=pos_slice, mlp_input=mlp_input ) From 22cfecbb477222ac33164f0105051f94a80d524e Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Tue, 12 May 2026 08:49:34 -0500 Subject: [PATCH 04/10] Resolution for #796, Factored Matrix memory leak --- tests/unit/factored_matrix/test_properties.py | 24 +++++++++++++++++ transformer_lens/FactoredMatrix.py | 27 ++++++++++++------- 2 files changed, 41 insertions(+), 10 deletions(-) diff --git a/tests/unit/factored_matrix/test_properties.py b/tests/unit/factored_matrix/test_properties.py index 091db15e8..d77be1cad 100644 --- a/tests/unit/factored_matrix/test_properties.py +++ b/tests/unit/factored_matrix/test_properties.py @@ -151,6 +151,30 @@ def test_V_and_Vh_alias_match(self, factored_matrices): assert any(issubclass(w.category, DeprecationWarning) for w in caught) assert torch.equal(vh_value, factored_matrix.V) + def test_svd_caches_per_instance(self): + """svd() should cache its result on the instance — repeated calls return the same tensors.""" + m = FactoredMatrix(randn(4, 3), randn(3, 4)) + first_U, first_S, first_V = m.svd() + second_U, second_S, second_V = m.svd() + # Same object identity confirms the cache returns the stored value rather than recomputing. + assert first_U is second_U + assert first_S is second_S + assert first_V is second_V + + def test_svd_does_not_prevent_gc(self): + """svd's cache must not hold a strong reference that prevents the instance from being GC'd.""" + import gc + import weakref + + m = FactoredMatrix(randn(4, 3), randn(3, 4)) + _ = m.svd() # populate the cache + ref = weakref.ref(m) + del m + gc.collect() + assert ( + ref() is None + ), "FactoredMatrix instance survived deletion — svd cache is leaking references." + def test_collapse_r(self, factored_matrices): for factored_matrix in factored_matrices: result = factored_matrix.collapse_r() diff --git a/transformer_lens/FactoredMatrix.py b/transformer_lens/FactoredMatrix.py index 674dabf43..067092dd4 100644 --- a/transformer_lens/FactoredMatrix.py +++ b/transformer_lens/FactoredMatrix.py @@ -6,7 +6,7 @@ from __future__ import annotations -from functools import lru_cache +from functools import cached_property from typing import Any, List, Protocol, Tuple, Union, cast, overload, runtime_checkable import torch @@ -214,17 +214,17 @@ def BA(self) -> Float[torch.Tensor, "*leading_dims rdim ldim"]: def T(self) -> FactoredMatrix: return FactoredMatrix(self.B.transpose(-2, -1), self.A.transpose(-2, -1)) - @lru_cache(maxsize=None) - def svd( + @cached_property + def _svd_cached( self, ) -> Tuple[ Float[torch.Tensor, "*leading_dims ldim mdim"], Float[torch.Tensor, "*leading_dims mdim"], Float[torch.Tensor, "*leading_dims rdim mdim"], ]: - """Singular Value Decomposition: returns ``(U, S, V)`` such that ``U @ S.diag() @ V.transpose(-2, -1) == M``.""" - # Transpose Vh back to V — the long-standing return convention; downstream - # callers transposed the old `.Vh` result, so preserving V keeps them working. + # cached_property stores the result on the instance, so it's freed with the instance. + # Avoids the lru_cache-on-method GC leak where every FactoredMatrix that ever + # had .svd() called on it was retained by the function-level cache. Ua, Sa, Vha = torch.linalg.svd(self.A, full_matrices=False) Ub, Sb, Vhb = torch.linalg.svd(self.B, full_matrices=False) Va = tensor_utils.transpose(Vha) @@ -232,10 +232,17 @@ def svd( middle = Sa[..., :, None] * tensor_utils.transpose(Va) @ Ub * Sb[..., None, :] Um, Sm, Vhm = torch.linalg.svd(middle, full_matrices=False) Vm = tensor_utils.transpose(Vhm) - U = Ua @ Um - V = Vb @ Vm - S = Sm - return U, S, V + return Ua @ Um, Sm, Vb @ Vm + + def svd( + self, + ) -> Tuple[ + Float[torch.Tensor, "*leading_dims ldim mdim"], + Float[torch.Tensor, "*leading_dims mdim"], + Float[torch.Tensor, "*leading_dims rdim mdim"], + ]: + """Singular Value Decomposition: returns ``(U, S, V)`` such that ``U @ S.diag() @ V.transpose(-2, -1) == M``.""" + return self._svd_cached @property def U(self) -> Float[torch.Tensor, "*leading_dims ldim mdim"]: From cc2968f68787ed66d397a0b625cb171729771718 Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Tue, 12 May 2026 09:51:47 -0500 Subject: [PATCH 05/10] Resolved #453, underlying issue --- tests/unit/factored_matrix/test_properties.py | 83 ++++---- .../model_bridge/test_checkpoint_revision.py | 180 ++++++++++++++++++ transformer_lens/FactoredMatrix.py | 4 +- transformer_lens/HookedTransformer.py | 10 + transformer_lens/model_bridge/bridge.py | 14 ++ .../model_bridge/sources/transformers.py | 70 +++++++ 6 files changed, 316 insertions(+), 45 deletions(-) create mode 100644 tests/unit/model_bridge/test_checkpoint_revision.py diff --git a/tests/unit/factored_matrix/test_properties.py b/tests/unit/factored_matrix/test_properties.py index dd3fe3913..d8f13d450 100644 --- a/tests/unit/factored_matrix/test_properties.py +++ b/tests/unit/factored_matrix/test_properties.py @@ -78,19 +78,52 @@ def test_transpose_property(self, factored_matrices): def test_svd_property(self, factored_matrices): for factored_matrix in factored_matrices: - U, S, Vh = factored_matrix.svd() - assert torch.allclose(factored_matrix.AB, U @ torch.diag_embed(S) @ Vh.T, atol=1e-5) - # test that U and Vh are unitary + U, S, V = factored_matrix.svd() + assert torch.allclose(factored_matrix.AB, U @ torch.diag_embed(S) @ V.T, atol=1e-5) + # test that U and V are unitary assert torch.allclose(U.T @ U, torch.eye(U.shape[-1]), atol=1e-5) - assert torch.allclose(Vh.T @ Vh, torch.eye(Vh.shape[-1]), atol=1e-5) + assert torch.allclose(V.T @ V, torch.eye(V.shape[-1]), atol=1e-5) def test_svd_property_leading_ones(self, factored_matrices_leading_ones): for factored_matrix in factored_matrices_leading_ones: - U, S, Vh = factored_matrix.svd() - assert torch.allclose(factored_matrix.AB, U @ torch.diag_embed(S) @ Vh.mT, atol=1e-5) - # test that U and Vh are unitary + U, S, V = factored_matrix.svd() + assert torch.allclose(factored_matrix.AB, U @ torch.diag_embed(S) @ V.mT, atol=1e-5) + # test that U and V are unitary assert torch.allclose(U.mT @ U, torch.eye(U.shape[-1]), atol=1e-5) - assert torch.allclose(Vh.mT @ Vh, torch.eye(Vh.shape[-1]), atol=1e-5) + assert torch.allclose(V.mT @ V, torch.eye(V.shape[-1]), atol=1e-5) + + def test_V_and_Vh_alias_match(self, factored_matrices): + import warnings + + for factored_matrix in factored_matrices: + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + vh_value = factored_matrix.Vh + assert any(issubclass(w.category, DeprecationWarning) for w in caught) + assert torch.equal(vh_value, factored_matrix.V) + + def test_svd_caches_per_instance(self): + """svd() should cache its result on the instance — repeated calls return the same tensors.""" + m = FactoredMatrix(randn(4, 3), randn(3, 4)) + first_U, first_S, first_V = m.svd() + second_U, second_S, second_V = m.svd() + assert first_U is second_U + assert first_S is second_S + assert first_V is second_V + + def test_svd_does_not_prevent_gc(self): + """svd's cache must not hold a strong reference that prevents the instance from being GC'd""" + import gc + import weakref + + m = FactoredMatrix(randn(4, 3), randn(3, 4)) + _ = m.svd() # populate the cache + ref = weakref.ref(m) + del m + gc.collect() + assert ( + ref() is None + ), "FactoredMatrix instance survived deletion — svd cache is leaking references." def test_eigenvalues_property(self, factored_matrices): for factored_matrix in factored_matrices: @@ -141,40 +174,6 @@ def test_collapse_l(self, factored_matrices): expected = factored_matrix.S[..., :, None] * utils.transpose(factored_matrix.V) assert torch.allclose(result, expected) - def test_V_and_Vh_alias_match(self, factored_matrices): - import warnings - - for factored_matrix in factored_matrices: - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always") - vh_value = factored_matrix.Vh - assert any(issubclass(w.category, DeprecationWarning) for w in caught) - assert torch.equal(vh_value, factored_matrix.V) - - def test_svd_caches_per_instance(self): - """svd() should cache its result on the instance — repeated calls return the same tensors.""" - m = FactoredMatrix(randn(4, 3), randn(3, 4)) - first_U, first_S, first_V = m.svd() - second_U, second_S, second_V = m.svd() - # Same object identity confirms the cache returns the stored value rather than recomputing. - assert first_U is second_U - assert first_S is second_S - assert first_V is second_V - - def test_svd_does_not_prevent_gc(self): - """svd's cache must not hold a strong reference that prevents the instance from being GC'd""" - import gc - import weakref - - m = FactoredMatrix(randn(4, 3), randn(3, 4)) - _ = m.svd() # populate the cache - ref = weakref.ref(m) - del m - gc.collect() - assert ( - ref() is None - ), "FactoredMatrix instance survived deletion — svd cache is leaking references." - def test_collapse_r(self, factored_matrices): for factored_matrix in factored_matrices: result = factored_matrix.collapse_r() diff --git a/tests/unit/model_bridge/test_checkpoint_revision.py b/tests/unit/model_bridge/test_checkpoint_revision.py new file mode 100644 index 000000000..de9c35235 --- /dev/null +++ b/tests/unit/model_bridge/test_checkpoint_revision.py @@ -0,0 +1,180 @@ +"""Unit tests for the bridge revision/checkpoint API (issue #453).""" + +from unittest.mock import MagicMock, patch + +import pytest + +from transformer_lens.model_bridge.sources.transformers import ( + _CHECKPOINT_REVISION_FORMATS, + _resolve_checkpoint_to_revision, +) + + +class TestResolveCheckpointToRevision: + def test_pythia_index_resolves_to_step_revision(self): + labels = [0, 1000, 3000, 10000] + with patch( + "transformer_lens.loading_from_pretrained.get_checkpoint_labels", + return_value=(labels, "step"), + ): + revision = _resolve_checkpoint_to_revision( + "EleutherAI/pythia-70m", checkpoint_index=2, checkpoint_value=None + ) + assert revision == "step3000" + + def test_pythia_value_resolves_to_step_revision(self): + labels = [0, 1000, 3000, 10000] + with patch( + "transformer_lens.loading_from_pretrained.get_checkpoint_labels", + return_value=(labels, "step"), + ): + revision = _resolve_checkpoint_to_revision( + "EleutherAI/pythia-70m", checkpoint_index=None, checkpoint_value=10000 + ) + assert revision == "step10000" + + def test_stanford_crfm_uses_checkpoint_prefix(self): + labels = [100, 200, 400] + with patch( + "transformer_lens.loading_from_pretrained.get_checkpoint_labels", + return_value=(labels, "step"), + ): + revision = _resolve_checkpoint_to_revision( + "stanford-crfm/alias-gpt2-small-x21", checkpoint_index=1, checkpoint_value=None + ) + assert revision == "checkpoint-200" + + def test_unknown_family_raises(self): + with pytest.raises(ValueError, match="known checkpoint revision convention"): + _resolve_checkpoint_to_revision("gpt2", checkpoint_index=0, checkpoint_value=None) + + def test_index_out_of_range_raises(self): + labels = [0, 1000] + with patch( + "transformer_lens.loading_from_pretrained.get_checkpoint_labels", + return_value=(labels, "step"), + ): + with pytest.raises(ValueError, match="out of range"): + _resolve_checkpoint_to_revision( + "EleutherAI/pythia-70m", checkpoint_index=5, checkpoint_value=None + ) + + def test_unknown_value_raises(self): + labels = [0, 1000] + with patch( + "transformer_lens.loading_from_pretrained.get_checkpoint_labels", + return_value=(labels, "step"), + ): + with pytest.raises(ValueError, match="not in available checkpoints"): + _resolve_checkpoint_to_revision( + "EleutherAI/pythia-70m", checkpoint_index=None, checkpoint_value=99999 + ) + + def test_neither_provided_raises(self): + with pytest.raises(ValueError, match="Must specify"): + _resolve_checkpoint_to_revision( + "EleutherAI/pythia-70m", checkpoint_index=None, checkpoint_value=None + ) + + def test_known_families_registered(self): + assert "EleutherAI/pythia" in _CHECKPOINT_REVISION_FORMATS + assert "stanford-crfm" in _CHECKPOINT_REVISION_FORMATS + + +class TestBootRevisionPlumbing: + """Verify that ``revision`` and ``checkpoint_*`` reach HF's from_pretrained calls.""" + + def _patched_boot(self, **boot_kwargs): + """Call boot() with all the side-effect HF calls patched out. + + Returns ``(autoconfig_kwargs, model_from_pretrained_kwargs)``. + """ + from transformer_lens.model_bridge.sources import transformers as bridge_src + + captured: dict = {} + + def fake_autoconfig_from_pretrained(*args, **kwargs): + captured["autoconfig_args"] = args + captured["autoconfig_kwargs"] = kwargs + cfg = MagicMock() + cfg.architectures = ["GPT2LMHeadModel"] + cfg.n_positions = 1024 + cfg.pad_token_id = 0 + cfg.eos_token_id = 0 + cfg.to_dict = lambda: {"model_type": "gpt2"} + cfg.__dict__["pad_token_id"] = 0 + return cfg + + def fake_model_from_pretrained(*args, **kwargs): + captured["model_args"] = args + captured["model_kwargs"] = kwargs + raise _AbortBoot() + + class _AbortBoot(Exception): + pass + + with patch.object( + bridge_src.AutoConfig, "from_pretrained", side_effect=fake_autoconfig_from_pretrained + ), patch( + "transformers.AutoModelForCausalLM.from_pretrained", + side_effect=fake_model_from_pretrained, + ): + with pytest.raises(Exception): + bridge_src.boot(model_name="EleutherAI/pythia-70m", **boot_kwargs) + + return captured + + def test_revision_forwarded_to_autoconfig(self): + captured = self._patched_boot(revision="step3000") + assert captured["autoconfig_kwargs"].get("revision") == "step3000" + + def test_revision_forwarded_to_model_load(self): + captured = self._patched_boot(revision="step3000") + assert captured.get("model_kwargs", {}).get("revision") == "step3000" + + def test_checkpoint_index_resolves_to_revision(self): + labels = [0, 1000, 3000, 10000] + with patch( + "transformer_lens.loading_from_pretrained.get_checkpoint_labels", + return_value=(labels, "step"), + ): + captured = self._patched_boot(checkpoint_index=2) + assert captured["autoconfig_kwargs"].get("revision") == "step3000" + assert captured.get("model_kwargs", {}).get("revision") == "step3000" + + def test_conflicting_revision_and_checkpoint_raises(self): + from transformer_lens.model_bridge.sources import transformers as bridge_src + + with pytest.raises(ValueError, match="not both"): + bridge_src.boot( + model_name="EleutherAI/pythia-70m", + revision="step1000", + checkpoint_index=2, + ) + + def test_default_revision_is_none(self): + """With no revision/checkpoint args, revision is not added to model_kwargs.""" + captured = self._patched_boot() + assert captured["autoconfig_kwargs"].get("revision") is None + assert "revision" not in captured.get("model_kwargs", {}) + + +class TestHookedTransformerCheckpointLabelAlias: + def test_checkpoint_label_routes_to_checkpoint_value(self): + from transformer_lens import HookedTransformer + + with patch("transformer_lens.loading.get_pretrained_model_config") as mock_get_cfg: + mock_get_cfg.side_effect = RuntimeError("stop after config call") + with pytest.raises(RuntimeError, match="stop after config call"): + HookedTransformer.from_pretrained("EleutherAI/pythia-70m", checkpoint_label=3000) + + _, kwargs = mock_get_cfg.call_args + assert kwargs["checkpoint_value"] == 3000 + + def test_checkpoint_label_and_value_together_raises(self): + from transformer_lens import HookedTransformer + + with pytest.raises(ValueError, match="aliases"): + HookedTransformer.from_pretrained( + "EleutherAI/pythia-70m", checkpoint_label=3000, checkpoint_value=1000 + ) diff --git a/transformer_lens/FactoredMatrix.py b/transformer_lens/FactoredMatrix.py index 067092dd4..712394f55 100644 --- a/transformer_lens/FactoredMatrix.py +++ b/transformer_lens/FactoredMatrix.py @@ -222,9 +222,7 @@ def _svd_cached( Float[torch.Tensor, "*leading_dims mdim"], Float[torch.Tensor, "*leading_dims rdim mdim"], ]: - # cached_property stores the result on the instance, so it's freed with the instance. - # Avoids the lru_cache-on-method GC leak where every FactoredMatrix that ever - # had .svd() called on it was retained by the function-level cache. + # Cache on the instance (frees with it) rather than class-level — fixes the lru_cache leak. Ua, Sa, Vha = torch.linalg.svd(self.A, full_matrices=False) Ub, Sb, Vhb = torch.linalg.svd(self.B, full_matrices=False) Va = tensor_utils.transpose(Vha) diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 9a7db35ae..103e532bf 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -1157,6 +1157,7 @@ def from_pretrained( refactor_factored_attn_matrices: bool = False, checkpoint_index: Optional[int] = None, checkpoint_value: Optional[int] = None, + checkpoint_label: Optional[int] = None, hf_model: Optional[PreTrainedModel] = None, device: Optional[Union[str, torch.device]] = None, n_devices: int = 1, @@ -1254,6 +1255,8 @@ def from_pretrained( labelled with exactly one of these). E.g. ``1000`` for a checkpoint taken at step 1000 or after 1000 tokens. If `checkpoint_index` is also specified, this will be ignored. + checkpoint_label: Alias for ``checkpoint_value`` kept for backwards compatibility with + older docs and downstream code. Cannot be combined with ``checkpoint_value``. hf_model: If you have already loaded in the HuggingFace model, you can pass it in here rather than needing to recreate the object. Defaults to None. @@ -1311,6 +1314,13 @@ def from_pretrained( 3. Global default ("right") first_n_layers: If specified, only load the first n layers of the model. """ + if checkpoint_value is not None and checkpoint_label is not None: + raise ValueError( + "Specify checkpoint_value or checkpoint_label, not both — they are aliases." + ) + elif checkpoint_label is not None: + checkpoint_value = checkpoint_label + if model_name.lower().startswith("t5"): raise RuntimeError( "Execution stopped: Please use HookedEncoderDecoder to load T5 models instead of HookedTransformer." diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index cb3895368..096879c31 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -196,6 +196,9 @@ def boot_transformers( n_devices: Optional[int] = None, max_memory: Optional[Dict[Union[str, int], str]] = None, n_ctx: Optional[int] = None, + revision: Optional[str] = None, + checkpoint_index: Optional[int] = None, + checkpoint_value: Optional[int] = None, ) -> "TransformerBridge": """Boot a model from HuggingFace (alias for sources.transformers.boot). @@ -231,6 +234,14 @@ def boot_transformers( n_ctx: Optional context length override. Writes to the appropriate HF config field for this model automatically (callers don't need to know the field name). Warns if larger than the model's default context length. + revision: Optional HF revision (branch, tag, or commit). Forwarded to the underlying + ``AutoConfig.from_pretrained`` and ``AutoModelForCausalLM.from_pretrained`` calls. + Mutually exclusive with ``checkpoint_index`` / ``checkpoint_value``. + checkpoint_index: Index into the available training checkpoints for the model family + (currently ``EleutherAI/pythia*`` and ``stanford-crfm/*``). Resolved to a revision + string via known per-family naming conventions. + checkpoint_value: Training step or token count of the desired checkpoint. Alternative + to ``checkpoint_index``; must match an entry in the family's checkpoint label list. Returns: The bridge to the loaded model. @@ -251,6 +262,9 @@ def boot_transformers( n_devices=n_devices, max_memory=max_memory, n_ctx=n_ctx, + revision=revision, + checkpoint_index=checkpoint_index, + checkpoint_value=checkpoint_value, ) @property diff --git a/transformer_lens/model_bridge/sources/transformers.py b/transformer_lens/model_bridge/sources/transformers.py index b7c4656f4..be2659e89 100644 --- a/transformer_lens/model_bridge/sources/transformers.py +++ b/transformer_lens/model_bridge/sources/transformers.py @@ -277,6 +277,55 @@ def get_hf_model_class_for_architecture(architecture: str): return AutoModelForCausalLM +# Known training-checkpoint revision conventions on HF. +_CHECKPOINT_REVISION_FORMATS: dict[str, str] = { + "EleutherAI/pythia": "step{value}", + "stanford-crfm": "checkpoint-{value}", +} + + +def _resolve_checkpoint_to_revision( + model_name: str, + checkpoint_index: int | None, + checkpoint_value: int | None, +) -> str: + """Convert a checkpoint index/value into an HF revision string, validated against ``get_checkpoint_labels``.""" + if checkpoint_index is None and checkpoint_value is None: + raise ValueError("Must specify either checkpoint_index or checkpoint_value.") + + format_str: str | None = None + for prefix, fmt in _CHECKPOINT_REVISION_FORMATS.items(): + if model_name.startswith(prefix): + format_str = fmt + break + if format_str is None: + raise ValueError( + f"Model {model_name!r} does not have a known checkpoint revision convention. " + f"Pass revision= directly if your model uses HF revisions. Known checkpoint " + f"families: {list(_CHECKPOINT_REVISION_FORMATS.keys())}." + ) + + from transformer_lens.loading_from_pretrained import get_checkpoint_labels + + labels, _ = get_checkpoint_labels(model_name) + if checkpoint_value is not None: + if checkpoint_value not in labels: + raise ValueError( + f"checkpoint_value={checkpoint_value} not in available checkpoints for " + f"{model_name!r}. {len(labels)} labels available, " + f"first/last: {labels[0]}..{labels[-1]}." + ) + else: + assert checkpoint_index is not None # narrowed by initial guard + if not 0 <= checkpoint_index < len(labels): + raise ValueError( + f"checkpoint_index={checkpoint_index} out of range [0, {len(labels)}) " + f"for {model_name!r}." + ) + checkpoint_value = labels[checkpoint_index] + return format_str.format(value=checkpoint_value) + + def boot( model_name: str, hf_config_overrides: dict | None = None, @@ -288,6 +337,9 @@ def boot( model_class: Any | None = None, hf_model: Any | None = None, n_ctx: int | None = None, + revision: str | None = None, + checkpoint_index: int | None = None, + checkpoint_value: int | None = None, # Experimental – Have not been fully tested on multi-gpu devices # Use at your own risk, report any issues here: https://github.com/TransformerLensOrg/TransformerLens/issues device_map: str | dict[str, str | int] | None = None, @@ -321,6 +373,15 @@ def boot( uses (n_positions / max_position_embeddings / etc.), so callers don't need to know the field name. If larger than the model's default, a warning is emitted — quality may degrade past the trained length for rotary models. + revision: Optional HF revision string (branch, tag, or commit). Forwarded to + ``AutoConfig.from_pretrained`` and ``AutoModelForCausalLM.from_pretrained``. + Mutually exclusive with ``checkpoint_index`` and ``checkpoint_value``. + checkpoint_index: Index into the available training checkpoints for the model family. + Convenience over ``revision`` for checkpointed models like EleutherAI/pythia* and + stanford-crfm/*. Resolved to a revision string via the known per-family naming + conventions (``step{value}`` for Pythia, ``checkpoint-{value}`` for stanford-crfm). + checkpoint_value: Training step or token count of the desired checkpoint. Alternative to + ``checkpoint_index``; must be one of the labels returned by ``get_checkpoint_labels``. Returns: The bridge to the loaded model. @@ -332,6 +393,12 @@ def boot( ) model_name = official_name break + if checkpoint_index is not None or checkpoint_value is not None: + if revision is not None: + raise ValueError( + "Specify either revision= or checkpoint_index/checkpoint_value, not both." + ) + revision = _resolve_checkpoint_to_revision(model_name, checkpoint_index, checkpoint_value) # Pass HF token for gated model access (e.g. meta-llama/*) from transformer_lens.utilities.hf_utils import get_hf_token @@ -346,6 +413,7 @@ def boot( output_attentions=True, trust_remote_code=trust_remote_code, token=_hf_token, + revision=revision, ) _n_ctx_field: str | None = None if n_ctx is not None: @@ -505,6 +573,8 @@ def boot( model_kwargs["token"] = _hf_token if trust_remote_code: model_kwargs["trust_remote_code"] = True + if revision is not None: + model_kwargs["revision"] = revision if resolved_device_map is not None: model_kwargs["device_map"] = resolved_device_map if resolved_max_memory is not None: From b60f37e308835979254d258ff2fe3c7ca85073c7 Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Tue, 12 May 2026 10:29:16 -0500 Subject: [PATCH 06/10] Resolution for Issue #385, added notes about forced eager, added a test to check for future drift --- .../test_bridge_vs_hf_eager_parity.py | 125 ++++++++++++++++++ transformer_lens/model_bridge/bridge.py | 5 + 2 files changed, 130 insertions(+) create mode 100644 tests/integration/model_bridge/test_bridge_vs_hf_eager_parity.py diff --git a/tests/integration/model_bridge/test_bridge_vs_hf_eager_parity.py b/tests/integration/model_bridge/test_bridge_vs_hf_eager_parity.py new file mode 100644 index 000000000..180bd9fd0 --- /dev/null +++ b/tests/integration/model_bridge/test_bridge_vs_hf_eager_parity.py @@ -0,0 +1,125 @@ +"""Asserts ``TransformerBridge`` reproduces ``AutoModelForCausalLM`` eager-attention logits. + +Issue #385 reported drift between bridge and HF for rotary models like Pythia. The drift +was an attention-implementation mismatch — bridge always uses eager, default HF loads use +SDPA, which reorders ops in a fused kernel. Bridge vs HF *eager* matches to fp32-noise. +""" + +from typing import Callable + +import pytest +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from transformer_lens.model_bridge import TransformerBridge + +MODEL_NAME = "EleutherAI/pythia-70m" + +# Op-reorder noise floor for fp32 transformer forward passes. We currently +# measure 0.0 on this model, but allow a small epsilon so harmless refactors +# (intermediate allocations, equivalent op reorderings) don't break the test. +FP32_NOISE_TOL = 1e-5 + + +@pytest.fixture(scope="module") +def tokenizer(): + return AutoTokenizer.from_pretrained(MODEL_NAME) + + +@pytest.fixture(scope="module") +def bridge(): + return TransformerBridge.boot_transformers(MODEL_NAME, device="cpu", dtype=torch.float32) + + +@pytest.fixture(scope="module") +def hf_eager(): + """HF model loaded independently of the bridge's wrapped instance.""" + return AutoModelForCausalLM.from_pretrained( + MODEL_NAME, torch_dtype=torch.float32, attn_implementation="eager" + ).eval() + + +@pytest.fixture +def tokenize(tokenizer) -> Callable[[str], torch.Tensor]: + def _tok(prompt: str) -> torch.Tensor: + return tokenizer(prompt, return_tensors="pt").input_ids + + return _tok + + +@pytest.mark.parametrize("prompt", ["Hello, world!", "The quick brown fox jumps"]) +def test_bridge_logits_match_hf_eager(bridge, hf_eager, tokenize, prompt): + tokens = tokenize(prompt) + with torch.inference_mode(): + bridge_logits = bridge(tokens) + hf_logits = hf_eager(tokens).logits + max_diff = (bridge_logits - hf_logits).abs().max().item() + assert max_diff < FP32_NOISE_TOL, ( + f"{MODEL_NAME!r} bridge vs HF eager drift={max_diff:.2e} on {prompt!r} " + f"exceeds fp32-noise tolerance {FP32_NOISE_TOL:.0e} — bridge's " + f"_reconstruct_attention may have regressed (see issue #385)." + ) + + +def test_bridge_residual_stream_matches_hf_eager(bridge, hf_eager, tokenize): + """Per-layer parity catches compensating errors that wash out at the final logits.""" + tokens = tokenize("Hello, world!") + n_layers = len(hf_eager.gpt_neox.layers) + + hf_layer_out: dict[int, torch.Tensor] = {} + + def _make_hf_hook(idx): + def _h(_m, _i, o): + hf_layer_out[idx] = (o[0] if isinstance(o, tuple) else o).detach() + + return _h + + handles = [ + layer.register_forward_hook(_make_hf_hook(i)) + for i, layer in enumerate(hf_eager.gpt_neox.layers) + ] + try: + with torch.inference_mode(): + hf_eager(tokens) + finally: + for h in handles: + h.remove() + + bridge_layer_out: dict[int, torch.Tensor] = {} + fwd_hooks = [ + ( + f"blocks.{i}.hook_resid_post", + lambda v, hook, idx=i: bridge_layer_out.__setitem__(idx, v.detach()), + ) + for i in range(n_layers) + ] + with torch.inference_mode(): + bridge.run_with_hooks(tokens, fwd_hooks=fwd_hooks) + + for i in range(n_layers): + d = (hf_layer_out[i] - bridge_layer_out[i]).abs().max().item() + assert d < FP32_NOISE_TOL, ( + f"layer {i} residual drift={d:.2e} exceeds fp32-noise tolerance " + f"{FP32_NOISE_TOL:.0e} — bridge layer output diverges from HF eager." + ) + + +def test_bridge_attention_reconstruction_actually_runs(bridge, tokenize): + """Guard against tautology: prove bridge's custom attention path executes. + + If a future refactor made the bridge delegate to HF directly, the previous + parity tests would pass trivially. This one fails fast in that case by + asserting bridge-specific hooks fire during forward. + """ + tokens = tokenize("Hello, world!") + attn_scores_fired: list[bool] = [] + bridge.run_with_hooks( + tokens, + fwd_hooks=[ + ("blocks.0.attn.hook_attn_scores", lambda v, hook: attn_scores_fired.append(True)), + ], + ) + assert attn_scores_fired, ( + "blocks.0.attn.hook_attn_scores did not fire — bridge no longer runs its " + "own attention reconstruction, making the parity tests tautological." + ) diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index 096879c31..a92421473 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -207,6 +207,11 @@ def boot_transformers( Call ``enable_compatibility_mode()`` on the result for HookedTransformer- equivalent numerics. Generation, argmax, and CE loss are unaffected. + Attention implementation is forced to ``"eager"`` so hooks can capture scores + and patterns. For an apples-to-apples HF comparison, load the HF model with + ``attn_implementation="eager"`` too; comparing against the default ``"sdpa"`` + shows ~1e-3 fp32 drift from kernel-level op reordering, not a bridge bug. + Args: model_name: The name of the model to load. hf_config_overrides: Optional overrides applied to the HuggingFace config before model load. From ddf50cb4f2b628850827ffa229def0a33706f4c1 Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Tue, 12 May 2026 11:35:36 -0500 Subject: [PATCH 07/10] Added hook introspection mixin for #297 --- tests/unit/test_hook_introspection.py | 156 ++++++++++++++++++++++++ tests/unit/test_hook_points.py | 3 +- transformer_lens/hook_points.py | 63 +++++++++- transformer_lens/model_bridge/bridge.py | 4 +- 4 files changed, 221 insertions(+), 5 deletions(-) create mode 100644 tests/unit/test_hook_introspection.py diff --git a/tests/unit/test_hook_introspection.py b/tests/unit/test_hook_introspection.py new file mode 100644 index 000000000..620195094 --- /dev/null +++ b/tests/unit/test_hook_introspection.py @@ -0,0 +1,156 @@ +"""Tests for the hook-introspection API added for issue #297.""" + +from unittest import mock + +from transformer_lens.hook_points import ( + HookedRootModule, + HookIntrospectionMixin, + HookPoint, +) + + +class _ToyModel(HookedRootModule): + """Minimal HookedRootModule with two hook points for testing.""" + + def __init__(self): + super().__init__() + self.hook_a = HookPoint() + self.hook_b = HookPoint() + self.setup() + + +def _my_named_hook(activation, hook): + return activation + + +def _other_hook(activation, hook): + return activation + + +@mock.patch("torch.utils.hooks.RemovableHandle", autospec=True) +def test_lens_handle_stores_user_hook(mock_handle): + mock_handle.return_value.id = 0 + hp = HookPoint() + hp.add_hook(_my_named_hook, dir="fwd") + assert hp.fwd_hooks[0].user_hook is _my_named_hook + + +@mock.patch("torch.utils.hooks.RemovableHandle", autospec=True) +def test_hookpoint_repr_includes_hook_count(mock_handle): + mock_handle.return_value.id = 0 + hp = HookPoint() + hp.name = "blocks.0.hook_resid_post" + assert "blocks.0.hook_resid_post" in repr(hp) + hp.add_hook(_my_named_hook, dir="fwd") + hp.add_hook(_other_hook, dir="fwd") + rep = repr(hp) + assert "2 fwd" in rep + assert "bwd" not in rep + + +def test_hookpoint_repr_with_no_name_and_no_hooks(): + assert repr(HookPoint()) == "HookPoint()" + + +@mock.patch("torch.utils.hooks.RemovableHandle", autospec=True) +def test_list_hooks_empty_model_returns_empty_dict(mock_handle): + model = _ToyModel() + assert model.list_hooks() == {} + + +@mock.patch("torch.utils.hooks.RemovableHandle", autospec=True) +def test_list_hooks_returns_user_callable(mock_handle): + mock_handle.return_value.id = 0 + model = _ToyModel() + model.hook_a.add_hook(_my_named_hook, dir="fwd") + result = model.list_hooks() + assert set(result.keys()) == {"hook_a"} + handles = result["hook_a"] + assert len(handles) == 1 + assert handles[0].user_hook is _my_named_hook + + +@mock.patch("torch.utils.hooks.RemovableHandle", autospec=True) +def test_list_hooks_name_filter_string(mock_handle): + mock_handle.return_value.id = 0 + model = _ToyModel() + model.hook_a.add_hook(_my_named_hook, dir="fwd") + model.hook_b.add_hook(_other_hook, dir="fwd") + assert set(model.list_hooks(name_filter="hook_a").keys()) == {"hook_a"} + + +@mock.patch("torch.utils.hooks.RemovableHandle", autospec=True) +def test_list_hooks_name_filter_list(mock_handle): + mock_handle.return_value.id = 0 + model = _ToyModel() + model.hook_a.add_hook(_my_named_hook, dir="fwd") + model.hook_b.add_hook(_other_hook, dir="fwd") + assert set(model.list_hooks(name_filter=["hook_a", "hook_b"]).keys()) == {"hook_a", "hook_b"} + assert set(model.list_hooks(name_filter=["hook_b"]).keys()) == {"hook_b"} + + +@mock.patch("torch.utils.hooks.RemovableHandle", autospec=True) +def test_list_hooks_name_filter_callable(mock_handle): + mock_handle.return_value.id = 0 + model = _ToyModel() + model.hook_a.add_hook(_my_named_hook, dir="fwd") + model.hook_b.add_hook(_other_hook, dir="fwd") + result = model.list_hooks(name_filter=lambda n: n.endswith("a")) + assert set(result.keys()) == {"hook_a"} + + +@mock.patch("torch.utils.hooks.RemovableHandle", autospec=True) +def test_list_hooks_direction_filter(mock_handle): + mock_handle.return_value.id = 0 + model = _ToyModel() + model.hook_a.add_hook(_my_named_hook, dir="fwd") + model.hook_a.add_hook(_other_hook, dir="bwd") + assert len(model.list_hooks(dir="fwd")["hook_a"]) == 1 + assert len(model.list_hooks(dir="bwd")["hook_a"]) == 1 + assert len(model.list_hooks(dir="both")["hook_a"]) == 2 + + +@mock.patch("torch.utils.hooks.RemovableHandle", autospec=True) +def test_list_hooks_excludes_permanent_when_requested(mock_handle): + mock_handle.return_value.id = 0 + model = _ToyModel() + model.hook_a.add_hook(_my_named_hook, dir="fwd", is_permanent=True) + model.hook_a.add_hook(_other_hook, dir="fwd", is_permanent=False) + assert len(model.list_hooks(including_permanent=True)["hook_a"]) == 2 + handles = model.list_hooks(including_permanent=False)["hook_a"] + assert len(handles) == 1 + assert handles[0].user_hook is _other_hook + + +@mock.patch("torch.utils.hooks.RemovableHandle", autospec=True) +def test_mixin_works_on_class_with_hook_dict_attribute(mock_handle): + """Pin the duck-typed contract: mixin reads ``hook_dict`` off any class that exposes it.""" + mock_handle.return_value.id = 0 + + class Bag(HookIntrospectionMixin): + def __init__(self): + hp = HookPoint() + hp.add_hook(_my_named_hook, dir="fwd") + self.hook_dict = {"only_hook": hp} + + result = Bag().list_hooks() + assert set(result.keys()) == {"only_hook"} + assert result["only_hook"][0].user_hook is _my_named_hook + + +@mock.patch("torch.utils.hooks.RemovableHandle", autospec=True) +def test_mixin_works_on_class_with_hook_dict_property(mock_handle): + """``getattr`` indirection must accept a ``@property`` provider too (bridge case).""" + mock_handle.return_value.id = 0 + + class PropertyBag(HookIntrospectionMixin): + def __init__(self): + self._hooks = {"only_hook": HookPoint()} + self._hooks["only_hook"].add_hook(_my_named_hook, dir="fwd") + + @property + def hook_dict(self): + return self._hooks + + result = PropertyBag().list_hooks() + assert result["only_hook"][0].user_hook is _my_named_hook diff --git a/tests/unit/test_hook_points.py b/tests/unit/test_hook_points.py index 4e7828450..df85849f0 100644 --- a/tests/unit/test_hook_points.py +++ b/tests/unit/test_hook_points.py @@ -60,10 +60,11 @@ def hook2(activation, hook): # Make LensHandle constructor return a simple container capturing the pt_handle ('hook') class _LensHandleBox: - def __init__(self, handle, is_permanent, context_level): + def __init__(self, handle, is_permanent, context_level, user_hook=None): self.hook = handle self.is_permanent = is_permanent self.context_level = context_level + self.user_hook = user_hook mock_lens_handle.side_effect = _LensHandleBox diff --git a/transformer_lens/hook_points.py b/transformer_lens/hook_points.py index a942ec39d..7ec1ed7c1 100644 --- a/transformer_lens/hook_points.py +++ b/transformer_lens/hook_points.py @@ -48,6 +48,9 @@ class LensHandle: context_level: Optional[int] = None """Context level associated with the hooks context manager for the given hook.""" + user_hook: Optional[Callable] = None + """The original hook callable, before ``add_hook`` wraps it.""" + # Define type aliases NamesFilter = Optional[Union[Callable[[str], bool], Sequence[str], str]] @@ -167,6 +170,14 @@ def __init__(self): # This scales the SUM of gradients, not element-wise (to avoid PyTorch bugs) self.backward_scale: float = 1.0 + def __repr__(self) -> str: + bits = [f"name={self.name!r}"] if self.name is not None else [] + if self.fwd_hooks: + bits.append(f"{len(self.fwd_hooks)} fwd") + if self.bwd_hooks: + bits.append(f"{len(self.bwd_hooks)} bwd") + return f"HookPoint({', '.join(bits)})" if bits else "HookPoint()" + def add_perma_hook(self, hook: HookFunction, dir: Literal["fwd", "bwd"] = "fwd") -> None: self.add_hook(hook, dir=dir, is_permanent=True) @@ -273,7 +284,7 @@ def _bwd_hook_wrapper( else: raise ValueError(f"Invalid direction {dir}") - handle = LensHandle(pt_handle, is_permanent, level) + handle = LensHandle(pt_handle, is_permanent, level, user_hook=hook) if prepend: # we could just pass this as an argument in PyTorch 2.0, but for now we manually do this... @@ -376,7 +387,55 @@ def layer(self): # %% -class HookedRootModule(nn.Module): +class HookIntrospectionMixin: + """``list_hooks()`` mixin for any class exposing a ``hook_dict``. + + Accessed via ``getattr`` so subclasses can provide ``hook_dict`` as either + an instance attribute (``HookedRootModule``) or a ``@property`` (``TransformerBridge``). + """ + + def list_hooks( + self, + name_filter: NamesFilter = None, + dir: Literal["fwd", "bwd", "both"] = "both", + including_permanent: bool = True, + ) -> dict[str, list[LensHandle]]: + """Return attached hooks grouped by HookPoint name; empty HookPoints are omitted. + + Args: + name_filter: A hook name, list of names, or predicate. ``None`` matches all. + dir: Restrict to forward, backward, or both directions. + including_permanent: If False, drop permanent hooks from the result. + """ + if name_filter is None: + matches: Callable[[str], bool] = lambda _: True + elif callable(name_filter): + matches = name_filter + elif isinstance(name_filter, str): + target = name_filter + matches = lambda n: n == target + else: + allowed = set(name_filter) + matches = lambda n: n in allowed + + out: dict[str, list[LensHandle]] = {} + hook_dict: dict[str, HookPoint] = getattr(self, "hook_dict") + for name, hp in hook_dict.items(): + if not matches(name): + continue + handles: list[LensHandle] = [] + if dir in ("fwd", "both"): + handles.extend(hp.fwd_hooks) + if dir in ("bwd", "both"): + handles.extend(hp.bwd_hooks) + if not including_permanent: + handles = [h for h in handles if not h.is_permanent] + if handles: + out[name] = handles + return out + + +class HookedRootModule(HookIntrospectionMixin, nn.Module): """A class building on nn.Module to interface nicely with HookPoints. Adds various nice utilities, most notably run_with_hooks to run the model with temporary hooks, diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index a92421473..91e685204 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -33,7 +33,7 @@ from transformer_lens import utilities as utils from transformer_lens.ActivationCache import ActivationCache from transformer_lens.FactoredMatrix import FactoredMatrix -from transformer_lens.hook_points import HookPoint +from transformer_lens.hook_points import HookIntrospectionMixin, HookPoint from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter from transformer_lens.model_bridge.component_setup import set_original_components from transformer_lens.model_bridge.composition_scores import CompositionScores @@ -94,7 +94,7 @@ def build_alias_to_canonical_map(hook_dict, prefix=""): return aliases -class TransformerBridge(nn.Module): +class TransformerBridge(HookIntrospectionMixin, nn.Module): """Bridge between HuggingFace and TransformerLens models. This class provides a standardized interface to access components of a transformer From dea2f188425e3e7b803c8ae149ca968845c65f38 Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Tue, 12 May 2026 12:30:08 -0500 Subject: [PATCH 08/10] Made improvements to booting training revisions --- .../model_bridge/test_checkpoint_revision.py | 49 +++++++++---------- 1 file changed, 22 insertions(+), 27 deletions(-) diff --git a/tests/unit/model_bridge/test_checkpoint_revision.py b/tests/unit/model_bridge/test_checkpoint_revision.py index de9c35235..a3a4919dc 100644 --- a/tests/unit/model_bridge/test_checkpoint_revision.py +++ b/tests/unit/model_bridge/test_checkpoint_revision.py @@ -1,6 +1,6 @@ """Unit tests for the bridge revision/checkpoint API (issue #453).""" -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest @@ -81,46 +81,41 @@ def test_known_families_registered(self): assert "stanford-crfm" in _CHECKPOINT_REVISION_FORMATS +class _AbortBoot(Exception): + """Raised by the model-load patch to short-circuit ``boot()`` before any real load.""" + + class TestBootRevisionPlumbing: - """Verify that ``revision`` and ``checkpoint_*`` reach HF's from_pretrained calls.""" + """Verify that ``revision`` and ``checkpoint_*`` reach HF's from_pretrained calls. - def _patched_boot(self, **boot_kwargs): - """Call boot() with all the side-effect HF calls patched out. + Uses pythia-70m's real cached config (avoids MagicMock fragility through the + adapter/config-mapping path) and aborts at the model-load step. + """ - Returns ``(autoconfig_kwargs, model_from_pretrained_kwargs)``. - """ + def _patched_boot(self, **boot_kwargs): from transformer_lens.model_bridge.sources import transformers as bridge_src captured: dict = {} + real_autoconfig = bridge_src.AutoConfig.from_pretrained + + def capture_autoconfig(name, **kwargs): + captured["autoconfig_kwargs"] = dict(kwargs) + # Strip the (possibly fake) revision so the real call hits the CI cache. + kwargs.pop("revision", None) + return real_autoconfig(name, **kwargs) - def fake_autoconfig_from_pretrained(*args, **kwargs): - captured["autoconfig_args"] = args - captured["autoconfig_kwargs"] = kwargs - cfg = MagicMock() - cfg.architectures = ["GPT2LMHeadModel"] - cfg.n_positions = 1024 - cfg.pad_token_id = 0 - cfg.eos_token_id = 0 - cfg.to_dict = lambda: {"model_type": "gpt2"} - cfg.__dict__["pad_token_id"] = 0 - return cfg - - def fake_model_from_pretrained(*args, **kwargs): - captured["model_args"] = args + def capture_model_load(*args, **kwargs): captured["model_kwargs"] = kwargs raise _AbortBoot() - class _AbortBoot(Exception): - pass - with patch.object( - bridge_src.AutoConfig, "from_pretrained", side_effect=fake_autoconfig_from_pretrained + bridge_src.AutoConfig, "from_pretrained", side_effect=capture_autoconfig ), patch( "transformers.AutoModelForCausalLM.from_pretrained", - side_effect=fake_model_from_pretrained, + side_effect=capture_model_load, ): - with pytest.raises(Exception): - bridge_src.boot(model_name="EleutherAI/pythia-70m", **boot_kwargs) + with pytest.raises(_AbortBoot): + bridge_src.boot(model_name="EleutherAI/pythia-70m", device="cpu", **boot_kwargs) return captured From f1c5a5b9da1f03a31d204b542fbb0ecad004e991 Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Tue, 12 May 2026 15:20:54 -0500 Subject: [PATCH 09/10] Adapter test improvements --- .../test_baichuan_adapter.py | 194 ++++++++-- .../test_codegen_adapter.py | 190 +++++++--- .../test_cohere_adapter.py | 214 ++++++----- .../test_gemma3_adapter.py | 322 ++++++++++++++++ .../test_gemma3_config.py | 117 +----- .../test_gemma3_multimodal_adapter.py | 347 +++++++++++++++++ .../test_gpt_bigcode_adapter.py | 169 ++++++--- .../test_internlm2_adapter.py | 250 +++++++----- .../test_llava_adapter.py | 328 ++++++++++++++++ .../test_mpt_adapter.py | 336 +++++++++++++--- .../test_qwen3_5_adapter.py | 299 ++++++++------- .../test_qwen3_moe_adapter.py | 358 ++++++++++++++++++ .../test_qwen3_next_adapter.py | 338 +++++++++-------- .../test_xglm_adapter.py | 246 +++++++++--- .../model_bridge/test_qwen3_moe_adapter.py | 194 ---------- tests/unit/test_gemma3_multimodal_adapter.py | 106 ------ tests/unit/test_llava_config.py | 116 ------ 17 files changed, 2843 insertions(+), 1281 deletions(-) create mode 100644 tests/unit/model_bridge/supported_architectures/test_gemma3_adapter.py rename tests/unit/{ => model_bridge/supported_architectures}/test_gemma3_config.py (72%) create mode 100644 tests/unit/model_bridge/supported_architectures/test_gemma3_multimodal_adapter.py create mode 100644 tests/unit/model_bridge/supported_architectures/test_llava_adapter.py rename tests/unit/{ => model_bridge/supported_architectures}/test_qwen3_5_adapter.py (71%) create mode 100644 tests/unit/model_bridge/supported_architectures/test_qwen3_moe_adapter.py rename tests/unit/{ => model_bridge/supported_architectures}/test_qwen3_next_adapter.py (63%) delete mode 100644 tests/unit/model_bridge/test_qwen3_moe_adapter.py delete mode 100644 tests/unit/test_gemma3_multimodal_adapter.py delete mode 100644 tests/unit/test_llava_config.py diff --git a/tests/unit/model_bridge/supported_architectures/test_baichuan_adapter.py b/tests/unit/model_bridge/supported_architectures/test_baichuan_adapter.py index d1e2348df..f7f065085 100644 --- a/tests/unit/model_bridge/supported_architectures/test_baichuan_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_baichuan_adapter.py @@ -1,13 +1,4 @@ -"""Unit tests for BaichuanArchitectureAdapter. - -Tests cover: -- Config attributes -- Component mapping structure and HF module names -- Weight conversion keys/types -- split_qkv_matrix (W_pack) numerical correctness -- preprocess_weights (QKV split, fold_ln, NormHead normalization) -- Factory registration (both v1 and v2 class names) -""" +"""Unit tests for BaichuanArchitectureAdapter.""" from types import SimpleNamespace from typing import Any @@ -26,6 +17,7 @@ EmbeddingBridge, GatedMLPBridge, JointQKVPositionEmbeddingsAttentionBridge, + LinearBridge, RMSNormalizationBridge, UnembeddingBridge, ) @@ -59,12 +51,12 @@ def _make_cfg( ) -@pytest.fixture +@pytest.fixture(scope="class") def cfg() -> TransformerBridgeConfig: return _make_cfg(n_heads=8, d_model=64) -@pytest.fixture +@pytest.fixture(scope="class") def adapter(cfg: TransformerBridgeConfig) -> BaichuanArchitectureAdapter: return BaichuanArchitectureAdapter(cfg) @@ -230,7 +222,7 @@ def test_q_rearrange_n_equals_n_heads(self, adapter: BaichuanArchitectureAdapter assert conv.tensor_conversion.axes_lengths["n"] == adapter.cfg.n_heads def test_k_rearrange_n_equals_n_heads(self, adapter: BaichuanArchitectureAdapter) -> None: - # Baichuan is MHA (no GQA), so K also uses n_heads + # Baichuan is MHA (no GQA): K uses n_heads. convs = adapter.weight_processing_conversions assert convs is not None conv = convs["blocks.{i}.attn.k.weight"] @@ -253,6 +245,150 @@ def test_no_source_key_on_q(self, adapter: BaichuanArchitectureAdapter) -> None: assert isinstance(conv, ParamProcessingConversion) assert conv.source_key is None + def test_k_conversion_type(self, adapter: BaichuanArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + assert convs is not None + conv = convs["blocks.{i}.attn.k.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + + def test_v_conversion_type(self, adapter: BaichuanArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + assert convs is not None + conv = convs["blocks.{i}.attn.v.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + + def test_o_conversion_type(self, adapter: BaichuanArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + assert convs is not None + conv = convs["blocks.{i}.attn.o.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + + def test_k_rearrange_pattern(self, adapter: BaichuanArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + assert convs is not None + conv = convs["blocks.{i}.attn.k.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.pattern == "(n h) m -> n m h" + + def test_v_rearrange_pattern(self, adapter: BaichuanArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + assert convs is not None + conv = convs["blocks.{i}.attn.v.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.pattern == "(n h) m -> n m h" + + def test_v_rearrange_n_equals_n_heads(self, adapter: BaichuanArchitectureAdapter) -> None: + # Baichuan is MHA: V uses n_heads. + convs = adapter.weight_processing_conversions + assert convs is not None + conv = convs["blocks.{i}.attn.v.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.axes_lengths["n"] == adapter.cfg.n_heads + + def test_o_rearrange_n_equals_n_heads(self, adapter: BaichuanArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + assert convs is not None + conv = convs["blocks.{i}.attn.o.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.axes_lengths["n"] == adapter.cfg.n_heads + + +class TestBaichuanAdapterComponentTypesExtras: + """Type-level checks for the joint-QKV attention submodules and block submodules.""" + + @pytest.fixture(scope="class") + def adapter(self) -> BaichuanArchitectureAdapter: + return BaichuanArchitectureAdapter(_make_cfg(n_heads=8, d_model=64)) + + def test_attn_qkv_is_linear_bridge(self, adapter: BaichuanArchitectureAdapter) -> None: + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert isinstance(attn.submodules["qkv"], LinearBridge) + + def test_attn_o_is_linear_bridge(self, adapter: BaichuanArchitectureAdapter) -> None: + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert isinstance(attn.submodules["o"], LinearBridge) + + def test_attn_has_joint_qkv_and_post_split_q_k_v( + self, adapter: BaichuanArchitectureAdapter + ) -> None: + # JointQKV bridges auto-create placeholder q/k/v alongside the joint qkv slot. + attn = adapter.component_mapping["blocks"].submodules["attn"] + for name in ("qkv", "o", "q", "k", "v"): + assert name in attn.submodules + assert attn.submodules["qkv"].name == "W_pack" + assert attn.submodules["o"].name == "o_proj" + + def test_mlp_gate_is_linear_bridge(self, adapter: BaichuanArchitectureAdapter) -> None: + mlp = adapter.component_mapping["blocks"].submodules["mlp"] + assert isinstance(mlp.submodules["gate"], LinearBridge) + + def test_mlp_in_is_linear_bridge(self, adapter: BaichuanArchitectureAdapter) -> None: + mlp = adapter.component_mapping["blocks"].submodules["mlp"] + assert isinstance(mlp.submodules["in"], LinearBridge) + + def test_mlp_out_is_linear_bridge(self, adapter: BaichuanArchitectureAdapter) -> None: + mlp = adapter.component_mapping["blocks"].submodules["mlp"] + assert isinstance(mlp.submodules["out"], LinearBridge) + + +class TestBaichuanArchitectureGuards: + """What must NOT be there: Baichuan is LLaMA-pattern RoPE, no learned pos, no Gemma offsets.""" + + @pytest.fixture(scope="class") + def adapter(self) -> BaichuanArchitectureAdapter: + return BaichuanArchitectureAdapter(_make_cfg(n_heads=8, d_model=64)) + + def test_no_pos_embed_component(self, adapter: BaichuanArchitectureAdapter) -> None: + # Rotary architecture: no learned positional embeddings. + assert "pos_embed" not in adapter.component_mapping + + def test_no_norm_offset_conversions(self, adapter: BaichuanArchitectureAdapter) -> None: + # RMSNorm has no Gemma-style offset. + convs = adapter.weight_processing_conversions + assert convs is not None + for key in convs: + assert "ln1.weight" not in key, f"Unexpected ln1 conversion: {key}" + assert "ln2.weight" not in key, f"Unexpected ln2 conversion: {key}" + assert "ln_final.weight" not in key, f"Unexpected ln_final conversion: {key}" + + def test_only_qkvo_conversion_keys(self, adapter: BaichuanArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + assert convs is not None + assert set(convs.keys()) == { + "blocks.{i}.attn.q.weight", + "blocks.{i}.attn.k.weight", + "blocks.{i}.attn.v.weight", + "blocks.{i}.attn.o.weight", + } + + +class TestBaichuanGQAFallback: + """Baichuan uses cfg.n_heads directly for all QKVO; n_key_value_heads on cfg is ignored.""" + + def test_kv_conversions_still_use_n_heads_when_n_kv_heads_set(self) -> None: + # Baichuan is MHA-only: pin K/V to n_heads regardless of n_key_value_heads, + # guarding against a silent switch to the GQA helper that would change K/V layout. + cfg = _make_cfg(n_heads=8, d_model=64) + cfg.n_key_value_heads = 2 # type: ignore[attr-defined] + adapter = BaichuanArchitectureAdapter(cfg) + convs = adapter.weight_processing_conversions + assert convs is not None + k = convs["blocks.{i}.attn.k.weight"] + v = convs["blocks.{i}.attn.v.weight"] + assert isinstance(k, ParamProcessingConversion) + assert isinstance(v, ParamProcessingConversion) + assert isinstance(k.tensor_conversion, RearrangeTensorConversion) + assert isinstance(v.tensor_conversion, RearrangeTensorConversion) + assert k.tensor_conversion.axes_lengths["n"] == adapter.cfg.n_heads + assert v.tensor_conversion.axes_lengths["n"] == adapter.cfg.n_heads + # --------------------------------------------------------------------------- # split_qkv_matrix (W_pack) tests @@ -293,7 +429,6 @@ def test_concatenated_split_correctness(self) -> None: d_model = 32 adapter = self._adapter(n_heads=4, d_model=d_model) attn = _make_w_pack_component(d_model) - # Fill W_pack: Q=1.0, K=2.0, V=3.0 w = torch.zeros(3 * d_model, d_model) w[:d_model, :] = 1.0 w[d_model : 2 * d_model, :] = 2.0 @@ -367,12 +502,12 @@ def test_fused_key_removed_and_split_keys_written(self) -> None: assert "blocks.0.attn.v.weight" in result def test_split_shapes(self) -> None: + # MHA: Q/K/V each [d_model, d_model]. d_model = 64 adapter = BaichuanArchitectureAdapter(_make_cfg(n_heads=8, d_model=d_model)) adapter._fold_ln_requested = True sd = self._make_state_dict(adapter, d_model=d_model) result = adapter.preprocess_weights(sd) - # Baichuan is MHA: Q, K, V each have shape [d_model, d_model] assert result["blocks.0.attn.q.weight"].shape == (d_model, d_model) assert result["blocks.0.attn.k.weight"].shape == (d_model, d_model) assert result["blocks.0.attn.v.weight"].shape == (d_model, d_model) @@ -455,7 +590,7 @@ def _adapter(self) -> BaichuanArchitectureAdapter: return BaichuanArchitectureAdapter(_make_cfg(n_heads=8, d_model=64)) def test_normhead_weights_normalized(self) -> None: - """NormHead (has first_flag) should have row-normalized weights after prepare_model.""" + """NormHead (has first_flag) row-normalizes weights at prepare_model.""" adapter = self._adapter() lm_head = SimpleNamespace( weight=nn.Parameter(torch.full((100, 64), 2.0)), @@ -476,17 +611,16 @@ def test_regular_linear_unchanged(self) -> None: assert torch.equal(lm_head.weight.data, original_w) def test_no_lm_head_is_noop(self) -> None: - """Model without lm_head should not raise.""" adapter = self._adapter() hf_model = SimpleNamespace() - adapter.prepare_model(hf_model) # should not raise + adapter.prepare_model(hf_model) def test_recomputes_rotary_from_scratch_when_inv_freq_is_meta(self) -> None: """Baichuan2's inv_freq/cos_cached are plain attrs that land on meta under HF v5 meta-init; prepare_model must recompute real values regardless.""" adapter = self._adapter() head_dim = adapter.cfg.d_model // adapter.cfg.n_heads - # Meta-device rotary matching v2's plain-attribute shape + # Meta-device rotary matching v2's plain-attribute shape. rotary = SimpleNamespace( inv_freq=torch.empty(head_dim // 2, device="meta"), cos_cached=torch.empty(1, 1, 16, head_dim, device="meta"), @@ -502,7 +636,7 @@ def test_recomputes_rotary_from_scratch_when_inv_freq_is_meta(self) -> None: assert rotary.cos_cached.device.type == "cpu" assert rotary.sin_cached.device.type == "cpu" assert rotary.cos_cached.shape == (1, 1, 16, head_dim) - # Sanity: cos(0) == 1 and position 0 of each head_dim element equals 1. + # cos(0) == 1. assert torch.allclose( rotary.cos_cached[0, 0, 0, :], torch.ones(head_dim), @@ -569,7 +703,7 @@ class _FakeRotary(nn.Module): def __init__(self, head_dim: int, max_seq_len: int) -> None: super().__init__() self.max_seq_len_cached = max_seq_len - # Fill with position-dependent values so tests can verify indexing. + # Position-dependent values so tests can verify indexing. cos = ( torch.arange(max_seq_len, dtype=torch.float32)[:, None] .expand(max_seq_len, head_dim) @@ -629,8 +763,7 @@ def _wire_bridge( rotary = _FakeRotary(head_dim=head_dim, max_seq_len=32) fake_attn = _FakeAttention(rotary, cfg.d_model) bridge.set_original_component(fake_attn) - # `o` LinearBridge is normally wired by setup_components via component_mapping; - # wire it directly for unit tests that construct the bridge standalone. + # Wire `o` directly since standalone construction skips setup_components. bridge.o.set_original_component(fake_attn.o_proj) return bridge, rotary, head_dim @@ -653,7 +786,7 @@ def test_uses_position_ids_when_position_embeddings_absent( q, k, v, position_ids=position_ids, use_cache=True ) - # rotary_emb called once, with kv_seq_len=seq (no past) + # No past: rotary called once with kv_seq_len=seq. assert rotary.calls == [seq] assert attn_output.shape == (batch, seq, cfg.d_model) assert present is not None @@ -681,7 +814,7 @@ def test_preserves_explicit_position_embeddings(self, cfg: TransformerBridgeConf position_ids=torch.tensor([[0, 1, 2, 3]]), use_cache=True, ) - # Caller-supplied embeddings must win; rotary_emb must not be called. + # Explicit embeddings win; rotary must not be called. assert rotary.calls == [] def test_use_cache_false_returns_none_present(self, cfg: TransformerBridgeConfig) -> None: @@ -704,7 +837,7 @@ def test_concats_past_key_value_along_seq_dim(self, cfg: TransformerBridgeConfig q = torch.zeros(batch, seq, cfg.d_model) k = torch.zeros_like(q) v = torch.zeros_like(q) - # HF's Model.forward generates position_ids offset by past_len. + # HF generates position_ids offset by past_len. position_ids = torch.tensor([[past_len, past_len + 1]]) _, _, present = bridge._reconstruct_attention( @@ -720,7 +853,7 @@ def test_concats_past_key_value_along_seq_dim(self, cfg: TransformerBridgeConfig present_k, present_v = present assert present_k.shape == (batch, cfg.n_heads, past_len + seq, head_dim) assert present_v.shape == (batch, cfg.n_heads, past_len + seq, head_dim) - # First past_len slots must be the provided past, unchanged. + # Past slots must be preserved unchanged. assert torch.equal(present_k[:, :, :past_len, :], past_k) assert torch.equal(present_v[:, :, :past_len, :], past_v) @@ -738,9 +871,7 @@ def test_preflight_raises_clean_import_error( ) -> None: import transformer_lens.model_bridge.supported_architectures.baichuan as baichuan_mod - # Force the preflight path: make find_spec report bitsandbytes missing, - # and make get_class_from_dynamic_module surface the transformers-style - # "requires the following packages... bitsandbytes" error. + # Force preflight: bnb missing + transformers raises its bnb-mentioning error. monkeypatch.setattr(baichuan_mod.importlib.util, "find_spec", lambda name: None) def _raise_bnb(*_a: Any, **_k: Any) -> None: @@ -766,6 +897,5 @@ def _raise_generic(*_a: Any, **_k: Any) -> None: raise ValueError("some unrelated loader failure") monkeypatch.setattr(dmu, "get_class_from_dynamic_module", _raise_generic) - # Must not raise — the generic failure path is swallowed (remote load - # may legitimately fail for offline tests, e.g. no network access). + # Generic failures are swallowed (offline/no-network is legit). adapter.prepare_loading("baichuan-inc/Baichuan2-7B-Chat", {}) diff --git a/tests/unit/model_bridge/supported_architectures/test_codegen_adapter.py b/tests/unit/model_bridge/supported_architectures/test_codegen_adapter.py index efee81fc9..caf8c63d7 100644 --- a/tests/unit/model_bridge/supported_architectures/test_codegen_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_codegen_adapter.py @@ -1,12 +1,4 @@ -"""Unit tests for CodeGenArchitectureAdapter. - -Tests cover: -- Config attribute validation (all required attributes are set correctly) -- Component mapping structure (correct bridge types, no ln2) -- Weight conversion keys and structure -- split_qkv_matrix correctness (numerical test with known weights) -- Factory registration (CodeGenForCausalLM maps to the right adapter) -""" +"""Unit tests for CodeGenArchitectureAdapter.""" from types import SimpleNamespace from typing import Any @@ -16,12 +8,18 @@ import torch.nn as nn from transformer_lens.config import TransformerBridgeConfig +from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion +from transformer_lens.conversion_utils.param_processing_conversion import ( + ParamProcessingConversion, +) from transformer_lens.model_bridge.generalized_components import ( BlockBridge, CodeGenAttentionBridge, EmbeddingBridge, + LinearBridge, MLPBridge, NormalizationBridge, + ParallelBlockBridge, UnembeddingBridge, ) from transformer_lens.model_bridge.supported_architectures.codegen import ( @@ -55,12 +53,12 @@ def _make_cfg( ) -@pytest.fixture +@pytest.fixture(scope="class") def cfg() -> TransformerBridgeConfig: return _make_cfg() -@pytest.fixture +@pytest.fixture(scope="class") def adapter(cfg: TransformerBridgeConfig) -> CodeGenArchitectureAdapter: return CodeGenArchitectureAdapter(cfg) @@ -71,8 +69,6 @@ def adapter(cfg: TransformerBridgeConfig) -> CodeGenArchitectureAdapter: class TestCodeGenAdapterConfig: - """Tests that the adapter sets required config attributes correctly.""" - def test_normalization_type_is_ln(self, adapter: CodeGenArchitectureAdapter) -> None: assert adapter.cfg.normalization_type == "LN" @@ -98,8 +94,6 @@ def test_parallel_attn_mlp_is_true(self, adapter: CodeGenArchitectureAdapter) -> class TestCodeGenAdapterComponentMapping: - """Tests that component_mapping has the correct bridge types and structure.""" - def test_embed_is_embedding_bridge(self, adapter: CodeGenArchitectureAdapter) -> None: assert isinstance(adapter.component_mapping["embed"], EmbeddingBridge) @@ -133,9 +127,9 @@ def test_blocks_ln1_name(self, adapter: CodeGenArchitectureAdapter) -> None: assert blocks.submodules["ln1"].name == "ln_1" def test_no_ln2_in_blocks(self, adapter: CodeGenArchitectureAdapter) -> None: - """CodeGen uses parallel attn+MLP sharing ln_1 — there must be no ln2.""" + """Parallel attn+MLP shares ln_1; no ln2 exists.""" blocks = adapter.component_mapping["blocks"] - assert "ln2" not in blocks.submodules, "CodeGen parallel block must not have ln2" + assert "ln2" not in blocks.submodules def test_attn_is_codegen_attention_bridge(self, adapter: CodeGenArchitectureAdapter) -> None: blocks = adapter.component_mapping["blocks"] @@ -168,8 +162,6 @@ def test_mlp_out_name(self, adapter: CodeGenArchitectureAdapter) -> None: class TestCodeGenAdapterWeightConversions: - """Tests that weight_processing_conversions has the expected keys.""" - def test_q_weight_key_present(self, adapter: CodeGenArchitectureAdapter) -> None: assert "blocks.{i}.attn.q.weight" in adapter.weight_processing_conversions @@ -186,26 +178,129 @@ def test_exactly_four_conversion_keys(self, adapter: CodeGenArchitectureAdapter) assert len(adapter.weight_processing_conversions) == 4 +class TestCodeGenAdapterWeightConversionSemantics: + """Each Q/K/V/O wraps a RearrangeTensorConversion with the right pattern and n axis.""" + + @pytest.fixture(scope="class") + def adapter(self) -> CodeGenArchitectureAdapter: + return CodeGenArchitectureAdapter(_make_cfg()) + + @pytest.mark.parametrize("slot", ["q", "k", "v"]) + def test_qkv_uses_split_heads_pattern( + self, adapter: CodeGenArchitectureAdapter, slot: str + ) -> None: + conv = adapter.weight_processing_conversions[f"blocks.{{i}}.attn.{slot}.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.pattern == "(n h) m -> n m h" + assert conv.tensor_conversion.axes_lengths["n"] == adapter.cfg.n_heads + + def test_o_uses_merge_heads_pattern(self, adapter: CodeGenArchitectureAdapter) -> None: + conv = adapter.weight_processing_conversions["blocks.{i}.attn.o.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.pattern == "m (n h) -> n h m" + assert conv.tensor_conversion.axes_lengths["n"] == adapter.cfg.n_heads + + def test_n_kv_heads_on_cfg_does_not_change_kv_conversions(self) -> None: + # CodeGen is MHA-only: K/V pinned to n_heads regardless of n_key_value_heads. + cfg = _make_cfg() + cfg.n_key_value_heads = 1 # type: ignore[attr-defined] + adapter = CodeGenArchitectureAdapter(cfg) + for slot in ("k", "v"): + conv = adapter.weight_processing_conversions[f"blocks.{{i}}.attn.{slot}.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.axes_lengths["n"] == adapter.cfg.n_heads + + +class TestCodeGenAdapterComponentTypesExtras: + """Bridge-type assertions for joint-QKV, MLP submodules, and parallel block class.""" + + @pytest.fixture(scope="class") + def adapter(self) -> CodeGenArchitectureAdapter: + return CodeGenArchitectureAdapter(_make_cfg()) + + def test_blocks_is_parallel_block_bridge(self, adapter: CodeGenArchitectureAdapter) -> None: + # Parallel attn+MLP: must use ParallelBlockBridge, not sequential BlockBridge. + blocks = adapter.component_mapping["blocks"] + assert isinstance(blocks, ParallelBlockBridge) + + def test_attn_qkv_is_linear_bridge(self, adapter: CodeGenArchitectureAdapter) -> None: + attn = adapter.component_mapping["blocks"].submodules["attn"] + qkv = attn.submodules["qkv"] + assert isinstance(qkv, LinearBridge) + assert qkv.name == "qkv_proj" + + def test_attn_o_is_linear_bridge(self, adapter: CodeGenArchitectureAdapter) -> None: + attn = adapter.component_mapping["blocks"].submodules["attn"] + o = attn.submodules["o"] + assert isinstance(o, LinearBridge) + assert o.name == "out_proj" + + def test_mlp_in_is_linear_bridge(self, adapter: CodeGenArchitectureAdapter) -> None: + mlp = adapter.component_mapping["blocks"].submodules["mlp"] + assert isinstance(mlp.submodules["in"], LinearBridge) + + def test_mlp_out_is_linear_bridge(self, adapter: CodeGenArchitectureAdapter) -> None: + mlp = adapter.component_mapping["blocks"].submodules["mlp"] + assert isinstance(mlp.submodules["out"], LinearBridge) + + def test_no_gate_in_mlp(self, adapter: CodeGenArchitectureAdapter) -> None: + """Non-gated MLP: no 'gate' submodule.""" + mlp = adapter.component_mapping["blocks"].submodules["mlp"] + assert "gate" not in mlp.submodules + + +class TestCodeGenArchitectureGuards: + """RoPE-in-attention (no top-level rotary_emb), no learned pos, no Gemma offsets.""" + + @pytest.fixture(scope="class") + def adapter(self) -> CodeGenArchitectureAdapter: + return CodeGenArchitectureAdapter(_make_cfg()) + + def test_no_top_level_rotary_emb(self, adapter: CodeGenArchitectureAdapter) -> None: + # Rotary is applied inside attention forward; no standalone HF module to bind. + assert "rotary_emb" not in adapter.component_mapping + + def test_no_pos_embed_component(self, adapter: CodeGenArchitectureAdapter) -> None: + assert "pos_embed" not in adapter.component_mapping + + def test_no_norm_offset_conversions(self, adapter: CodeGenArchitectureAdapter) -> None: + # LN-only: no Gemma-style ln1/ln2 offsets. + for key in adapter.weight_processing_conversions: + assert "ln1.weight" not in key + assert "ln2.weight" not in key + assert "ln_final.weight" not in key + + def test_only_qkvo_conversion_keys(self, adapter: CodeGenArchitectureAdapter) -> None: + assert set(adapter.weight_processing_conversions.keys()) == { + "blocks.{i}.attn.q.weight", + "blocks.{i}.attn.k.weight", + "blocks.{i}.attn.v.weight", + "blocks.{i}.attn.o.weight", + } + + # --------------------------------------------------------------------------- # split_qkv_matrix numerical correctness tests # --------------------------------------------------------------------------- class TestCodeGenSplitQKVMatrix: - """Numerical tests verifying the mp_num=4 QKV split logic.""" + """Numerical tests for the mp_num=4 QKV split.""" def _make_adapter_with_dmodel(self, d_model: int, n_heads: int) -> CodeGenArchitectureAdapter: cfg = _make_cfg(d_model=d_model, n_heads=n_heads) return CodeGenArchitectureAdapter(cfg) def _make_attn_component(self, d_model: int) -> Any: - """Create a minimal attn component with a qkv_proj linear.""" + """Minimal attn with a qkv_proj linear.""" attn = SimpleNamespace() attn.qkv_proj = nn.Linear(d_model, d_model * 3, bias=False) return attn def test_returns_three_linear_modules(self) -> None: - """split_qkv_matrix must return exactly three nn.Linear modules.""" adapter = self._make_adapter_with_dmodel(64, 4) attn = self._make_attn_component(64) q, k, v = adapter.split_qkv_matrix(attn) @@ -214,7 +309,6 @@ def test_returns_three_linear_modules(self) -> None: assert isinstance(v, nn.Linear) def test_output_shapes_are_correct(self) -> None: - """Each of Q, K, V must have weight shape [n_embd, n_embd].""" d_model = 64 adapter = self._make_adapter_with_dmodel(d_model, 4) attn = self._make_attn_component(d_model) @@ -224,7 +318,6 @@ def test_output_shapes_are_correct(self) -> None: assert v.weight.shape == (d_model, d_model) def test_no_bias_on_outputs(self) -> None: - """The split linears must have no bias, matching qkv_proj.""" adapter = self._make_adapter_with_dmodel(64, 4) attn = self._make_attn_component(64) q, k, v = adapter.split_qkv_matrix(attn) @@ -233,23 +326,16 @@ def test_no_bias_on_outputs(self) -> None: assert v.bias is None def test_q_k_v_are_distinct(self) -> None: - """With a non-trivial weight, Q, K, V must differ from each other.""" adapter = self._make_adapter_with_dmodel(64, 4) attn = self._make_attn_component(64) - # Fill qkv_proj with distinct values per row nn.init.normal_(attn.qkv_proj.weight) q, k, v = adapter.split_qkv_matrix(attn) - # All three must differ - assert not torch.allclose(q.weight, k.weight), "Q and K weights must differ" - assert not torch.allclose(q.weight, v.weight), "Q and V weights must differ" - assert not torch.allclose(k.weight, v.weight), "K and V weights must differ" + assert not torch.allclose(q.weight, k.weight) + assert not torch.allclose(q.weight, v.weight) + assert not torch.allclose(k.weight, v.weight) def test_known_partition_ordering(self) -> None: - """Verify the mp_num=4 partition layout: within each partition [Q_part, V_part, K_part]. - - We construct a weight where partition index and slot index are embedded - in the values, then verify that Q, K, V extract the correct slices. - """ + """mp_num=4 layout within each partition is [Q_part, V_part, K_part].""" mp_num = 4 d_model = 64 n_heads = 4 @@ -258,18 +344,12 @@ def test_known_partition_ordering(self) -> None: adapter = self._make_adapter_with_dmodel(d_model, n_heads) attn = self._make_attn_component(d_model) - # Build a structured weight: rows are indexed 0..3*d_model-1. - # Reshape as [mp_num=4, 3, local_dim=16, d_model=64], set each slice - # to a unique constant so we can track which slot goes where. + # Tag each slot with a unique constant to track its destination. w = torch.zeros(mp_num, 3, local_dim, d_model) - # slot 0 = Q_part → fill with 1.0 - w[:, 0, :, :] = 1.0 - # slot 1 = V_part → fill with 2.0 - w[:, 1, :, :] = 2.0 - # slot 2 = K_part → fill with 3.0 - w[:, 2, :, :] = 3.0 - - # Flatten back to [3*d_model, d_model] as qkv_proj expects + w[:, 0, :, :] = 1.0 # Q_part + w[:, 1, :, :] = 2.0 # V_part + w[:, 2, :, :] = 3.0 # K_part + attn.qkv_proj.weight = nn.Parameter(w.reshape(3 * d_model, d_model)) q, k, v = adapter.split_qkv_matrix(attn) @@ -279,7 +359,6 @@ def test_known_partition_ordering(self) -> None: assert torch.all(v.weight == 2.0), "V should come from slot 1 (V_part)" def test_forward_output_shape_with_split(self) -> None: - """After split, Q/K/V linears should produce correct output shapes.""" d_model = 64 adapter = self._make_adapter_with_dmodel(d_model, 4) attn = self._make_attn_component(d_model) @@ -298,31 +377,18 @@ def test_forward_output_shape_with_split(self) -> None: class TestCodeGenFactoryRegistration: - """Tests that the factory maps CodeGenForCausalLM to the correct adapter. - - Note: Phase D (registration) is required for these tests to pass. They - are included here so that registration is verified as part of the Phase D - commit rather than needing a separate test file. - """ - def test_factory_returns_codegen_adapter(self) -> None: - """ArchitectureAdapterFactory must return a CodeGenArchitectureAdapter.""" from transformer_lens.factories.architecture_adapter_factory import ( ArchitectureAdapterFactory, ) cfg = _make_cfg() adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg) - assert isinstance( - adapter, CodeGenArchitectureAdapter - ), f"Expected CodeGenArchitectureAdapter, got {type(adapter).__name__}" + assert isinstance(adapter, CodeGenArchitectureAdapter) def test_factory_key_is_codegen_for_causal_lm(self) -> None: - """SUPPORTED_ARCHITECTURES must have a 'CodeGenForCausalLM' key.""" from transformer_lens.factories.architecture_adapter_factory import ( SUPPORTED_ARCHITECTURES, ) - assert ( - "CodeGenForCausalLM" in SUPPORTED_ARCHITECTURES - ), "CodeGenForCausalLM must be registered in SUPPORTED_ARCHITECTURES" + assert "CodeGenForCausalLM" in SUPPORTED_ARCHITECTURES diff --git a/tests/unit/model_bridge/supported_architectures/test_cohere_adapter.py b/tests/unit/model_bridge/supported_architectures/test_cohere_adapter.py index 865110bf0..eb687e2b7 100644 --- a/tests/unit/model_bridge/supported_architectures/test_cohere_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_cohere_adapter.py @@ -1,25 +1,20 @@ -"""Unit tests for CohereArchitectureAdapter — Phases A, B, C. - -Covers: -- All cfg.* attributes set correctly in __init__ -- logit_scale None-check behaviour -- RoPE theta extraction from rope_parameters dict -- GQA n_key_value_heads forwarded to cfg -- Factory registration (CohereForCausalLM maps to CohereArchitectureAdapter) -- weight_processing_conversions: GQA-aware Q/K/V/O rearrangements -- preprocess_weights: logit_scale folded into unembed.weight -""" +"""Unit tests for CohereArchitectureAdapter: cfg, components, weight conversions, preprocess.""" import pytest import torch from transformer_lens.config import TransformerBridgeConfig +from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion +from transformer_lens.conversion_utils.param_processing_conversion import ( + ParamProcessingConversion, +) from transformer_lens.model_bridge.generalized_components import ( BlockBridge, EmbeddingBridge, GatedMLPBridge, LinearBridge, NormalizationBridge, + ParallelBlockBridge, PositionEmbeddingsAttentionBridge, RotaryEmbeddingBridge, UnembeddingBridge, @@ -28,11 +23,6 @@ CohereArchitectureAdapter, ) -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - def _make_cfg( n_heads: int = 4, d_model: int = 64, @@ -44,7 +34,7 @@ def _make_cfg( logit_scale: float | None = 0.0625, rope_parameters: dict | None = None, ) -> TransformerBridgeConfig: - """Return a minimal TransformerBridgeConfig for Cohere adapter tests.""" + """Minimal TransformerBridgeConfig for Cohere adapter tests.""" cfg = TransformerBridgeConfig( d_model=d_model, d_head=d_model // n_heads, @@ -56,9 +46,7 @@ def _make_cfg( default_prepend_bos=True, architecture="CohereForCausalLM", ) - # Set Cohere-specific fields that sources/transformers.py would populate. - # logit_scale and rope_parameters are not declared on TransformerBridgeConfig; - # use setattr so mypy doesn't flag attr-defined errors in test helpers. + # logit_scale/rope_parameters are dynamic attrs on cfg (not declared on the dataclass). if n_key_value_heads is not None: cfg.n_key_value_heads = n_key_value_heads setattr(cfg, "logit_scale", logit_scale) @@ -67,37 +55,31 @@ def _make_cfg( return cfg -@pytest.fixture +@pytest.fixture(scope="class") def cfg() -> TransformerBridgeConfig: return _make_cfg() -@pytest.fixture +@pytest.fixture(scope="class") def adapter(cfg: TransformerBridgeConfig) -> CohereArchitectureAdapter: return CohereArchitectureAdapter(cfg) -# --------------------------------------------------------------------------- -# Config attribute tests -# --------------------------------------------------------------------------- - - class TestCohereAdapterConfig: - """Verify every cfg.* attribute set by CohereArchitectureAdapter.__init__.""" + """Adapter sets cfg.* attributes correctly.""" def test_normalization_type_is_ln(self, adapter: CohereArchitectureAdapter) -> None: assert adapter.cfg.normalization_type == "LN" def test_uses_rms_norm_is_false(self, adapter: CohereArchitectureAdapter) -> None: - # CohereLayerNorm subtracts the mean — NOT RMSNorm + # CohereLayerNorm subtracts the mean — NOT RMSNorm. assert adapter.cfg.uses_rms_norm is False def test_eps_attr_is_variance_epsilon(self, adapter: CohereArchitectureAdapter) -> None: - # CohereLayerNorm stores epsilon as self.variance_epsilon + # CohereLayerNorm stores epsilon as self.variance_epsilon. assert adapter.cfg.eps_attr == "variance_epsilon" def test_final_rms_is_false(self, adapter: CohereArchitectureAdapter) -> None: - # Final norm is also CohereLayerNorm, not RMSNorm assert adapter.cfg.final_rms is False def test_positional_embedding_type_is_rotary(self, adapter: CohereArchitectureAdapter) -> None: @@ -110,48 +92,35 @@ def test_attn_only_is_false(self, adapter: CohereArchitectureAdapter) -> None: assert adapter.cfg.attn_only is False def test_parallel_attn_mlp_is_true(self, adapter: CohereArchitectureAdapter) -> None: - # Single input_layernorm; attn and MLP run in parallel on same normed input + # Single input_layernorm; attn and MLP run in parallel on same normed input. assert adapter.cfg.parallel_attn_mlp is True def test_default_prepend_bos_is_true(self, adapter: CohereArchitectureAdapter) -> None: - # CohereTokenizerFast prepends BOS by default via add_special_tokens=True assert adapter.cfg.default_prepend_bos is True def test_n_key_value_heads_forwarded(self, adapter: CohereArchitectureAdapter) -> None: - # GQA: n_key_value_heads=2 from the test cfg should be on adapter.cfg assert adapter.cfg.n_key_value_heads == 2 def test_logit_scale_default(self, adapter: CohereArchitectureAdapter) -> None: - # Default logit_scale is 0.0625 (1/16) - # logit_scale is a Cohere-specific dynamic attribute on cfg assert getattr(adapter.cfg, "logit_scale") == pytest.approx(0.0625) def test_logit_scale_is_float(self, adapter: CohereArchitectureAdapter) -> None: assert isinstance(getattr(adapter.cfg, "logit_scale"), float) def test_rotary_base_extracted(self) -> None: - # rope_parameters dict with explicit rope_theta - # TransformerBridgeConfig stores rotary_base as int; 80000 == 80000.0 cfg = _make_cfg(rope_parameters={"rope_theta": 80000.0, "rope_type": "default"}) adapter = CohereArchitectureAdapter(cfg) assert adapter.cfg.rotary_base == 80000 def test_rotary_base_default_when_no_rope_parameters(self) -> None: - # When rope_parameters is absent, fall back via default_theta or 10000.0 - # TransformerBridgeConfig stores rotary_base as int - cfg = _make_cfg() # no rope_parameters key set + cfg = _make_cfg() adapter = CohereArchitectureAdapter(cfg) assert isinstance(adapter.cfg.rotary_base, int) assert adapter.cfg.rotary_base > 0 -# --------------------------------------------------------------------------- -# logit_scale None-check tests -# --------------------------------------------------------------------------- - - class TestCohereLogitScaleNoneCheck: - """Verify the explicit None-check for logit_scale (HF type is float | None).""" + """logit_scale=None falls back to default; explicit values are preserved (HF type is float | None).""" def test_none_logit_scale_falls_back_to_default(self) -> None: cfg = _make_cfg(logit_scale=None) @@ -169,13 +138,8 @@ def test_logit_scale_one_preserved(self) -> None: assert getattr(adapter.cfg, "logit_scale") == pytest.approx(1.0) -# --------------------------------------------------------------------------- -# Factory registration tests -# --------------------------------------------------------------------------- - - class TestCohereFactoryRegistration: - """Verify factory maps 'CohereForCausalLM' to CohereArchitectureAdapter.""" + """Factory maps 'CohereForCausalLM' to CohereArchitectureAdapter.""" def test_factory_returns_cohere_adapter(self) -> None: from transformer_lens.factories.architecture_adapter_factory import ( @@ -205,20 +169,8 @@ def test_factory_maps_to_correct_class(self) -> None: assert SUPPORTED_ARCHITECTURES["CohereForCausalLM"] is CohereArchitectureAdapter -# --------------------------------------------------------------------------- -# Component mapping tests (Phase B) -# --------------------------------------------------------------------------- - - class TestCohereAdapterComponentMapping: - """Verify component_mapping has the correct bridge types and HF module paths. - - Plan reference: Phase B — module paths table. - Block structure: Falcon parallel-attn pattern (ln1 only, no ln2). - Submodule shapes: Llama-style separate Q/K/V/O and SwiGLU gate/in/out. - """ - - # -- Top-level components -- + """component_mapping has the correct bridge types and HF module paths.""" def test_embed_is_embedding_bridge(self, adapter: CohereArchitectureAdapter) -> None: assert adapter.component_mapping is not None @@ -229,7 +181,7 @@ def test_embed_name(self, adapter: CohereArchitectureAdapter) -> None: assert adapter.component_mapping["embed"].name == "model.embed_tokens" def test_rotary_emb_is_rotary_bridge(self, adapter: CohereArchitectureAdapter) -> None: - # rotary_emb is top-level (not inside blocks), matching llama.py:75 / falcon.py:154 + # rotary_emb is top-level (not inside blocks), matching Llama/Falcon. assert adapter.component_mapping is not None assert isinstance(adapter.component_mapping["rotary_emb"], RotaryEmbeddingBridge) @@ -268,8 +220,6 @@ def test_unembed_name(self, adapter: CohereArchitectureAdapter) -> None: assert adapter.component_mapping is not None assert adapter.component_mapping["unembed"].name == "lm_head" - # -- Block submodules -- - def test_blocks_has_ln1(self, adapter: CohereArchitectureAdapter) -> None: assert adapter.component_mapping is not None blocks = adapter.component_mapping["blocks"] @@ -281,21 +231,18 @@ def test_blocks_ln1_is_normalization_bridge(self, adapter: CohereArchitectureAda assert isinstance(blocks.submodules["ln1"], NormalizationBridge) def test_blocks_ln1_name(self, adapter: CohereArchitectureAdapter) -> None: - # Cohere uses input_layernorm (same HF name as Llama, unlike Falcon's ln_attn) assert adapter.component_mapping is not None blocks = adapter.component_mapping["blocks"] assert blocks.submodules["ln1"].name == "input_layernorm" def test_no_ln2_in_blocks(self, adapter: CohereArchitectureAdapter) -> None: - # Parallel block: single norm feeds both attn and MLP — no post_attention_layernorm + # Parallel block: single norm feeds both attn and MLP — no post_attention_layernorm. assert adapter.component_mapping is not None blocks = adapter.component_mapping["blocks"] assert ( "ln2" not in blocks.submodules ), "Cohere parallel block must NOT have ln2 (no post_attention_layernorm)" - # -- Attention submodules -- - def test_attn_is_position_embeddings_attention_bridge( self, adapter: CohereArchitectureAdapter ) -> None: @@ -356,8 +303,6 @@ def test_attn_o_is_linear_bridge(self, adapter: CohereArchitectureAdapter) -> No assert isinstance(attn, PositionEmbeddingsAttentionBridge) assert isinstance(attn.submodules["o"], LinearBridge) - # -- MLP submodules -- - def test_mlp_is_gated_mlp_bridge(self, adapter: CohereArchitectureAdapter) -> None: assert adapter.component_mapping is not None blocks = adapter.component_mapping["blocks"] @@ -404,8 +349,6 @@ def test_mlp_out_is_linear_bridge(self, adapter: CohereArchitectureAdapter) -> N assert isinstance(mlp, GatedMLPBridge) assert isinstance(mlp.submodules["out"], LinearBridge) - # -- Full component_mapping key set -- - def test_all_expected_top_level_keys_present(self, adapter: CohereArchitectureAdapter) -> None: assert adapter.component_mapping is not None expected = {"embed", "rotary_emb", "blocks", "ln_final", "unembed"} @@ -415,13 +358,8 @@ def test_all_expected_top_level_keys_present(self, adapter: CohereArchitectureAd ), f"Unexpected top-level keys: {actual.symmetric_difference(expected)}" -# --------------------------------------------------------------------------- -# Weight processing conversions tests (Phase C) -# --------------------------------------------------------------------------- - - class TestCohereAdapterWeightConversions: - """Verify weight_processing_conversions has the expected GQA-aware Q/K/V/O keys.""" + """weight_processing_conversions has GQA-aware Q/K/V/O keys.""" def test_weight_processing_conversions_not_none( self, adapter: CohereArchitectureAdapter @@ -461,8 +399,7 @@ def test_exact_key_set(self, adapter: CohereArchitectureAdapter) -> None: assert set(adapter.weight_processing_conversions.keys()) == expected def test_gqa_adapter_kv_heads_in_conversions(self) -> None: - # GQA: K/V conversions must carry n=n_kv_heads (2), Q/O must carry n=n_heads (8). - # Verified by inspecting RearrangeTensorConversion.axes_lengths["n"]. + # K/V carry n=n_kv_heads; Q/O carry n=n_heads. from transformer_lens.conversion_utils.conversion_steps.rearrange_tensor_conversion import ( RearrangeTensorConversion, ) @@ -481,19 +418,14 @@ def _n(key: str) -> int: assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) return int(conv.tensor_conversion.axes_lengths["n"]) - assert _n("blocks.{i}.attn.q.weight") == 8 # full n_heads - assert _n("blocks.{i}.attn.k.weight") == 2 # n_kv_heads - assert _n("blocks.{i}.attn.v.weight") == 2 # n_kv_heads - assert _n("blocks.{i}.attn.o.weight") == 8 # full n_heads - - -# --------------------------------------------------------------------------- -# preprocess_weights tests (Phase C) -# --------------------------------------------------------------------------- + assert _n("blocks.{i}.attn.q.weight") == 8 + assert _n("blocks.{i}.attn.k.weight") == 2 + assert _n("blocks.{i}.attn.v.weight") == 2 + assert _n("blocks.{i}.attn.o.weight") == 8 class TestCoherePreprocessWeights: - """Verify preprocess_weights folds logit_scale into unembed.weight.""" + """preprocess_weights folds logit_scale into unembed.weight.""" def _make_state_dict(self, d_model: int = 64, d_vocab: int = 1000) -> dict[str, torch.Tensor]: """Minimal state dict with unembed.weight and unembed.bias.""" @@ -508,7 +440,6 @@ def test_unembed_weight_scaled_by_logit_scale(self) -> None: adapter = CohereArchitectureAdapter(cfg) sd = self._make_state_dict() sd = adapter.preprocess_weights(sd) - # All values in unembed.weight should be 1.0 * 0.5 = 0.5 assert torch.allclose(sd["unembed.weight"], torch.full_like(sd["unembed.weight"], 0.5)) def test_unembed_bias_scaled_by_logit_scale(self) -> None: @@ -520,20 +451,17 @@ def test_unembed_bias_scaled_by_logit_scale(self) -> None: assert torch.allclose(sd["unembed.bias"], torch.full_like(sd["unembed.bias"], 0.5)) def test_embed_weight_unchanged_when_tied(self) -> None: - # Simulate the tied-weight state: both keys reference the same storage. - # bridge.py lines 726-732 clone unembed.weight before calling preprocess_weights, - # so the fold must NOT corrupt embed.weight. This test fails if the fold is - # ever changed to an in-place op like mul_(). + # Bridge clones unembed.weight before calling preprocess_weights; the fold must NOT + # corrupt embed.weight (would fail if fold ever switched to an in-place op). cfg = _make_cfg(logit_scale=0.0625) adapter = CohereArchitectureAdapter(cfg) shared = torch.ones(1000, 64) sd: dict[str, torch.Tensor] = { "embed.weight": shared, - "unembed.weight": shared, # same tensor — genuinely tied + "unembed.weight": shared, } assert sd["embed.weight"].data_ptr() == sd["unembed.weight"].data_ptr() adapter.preprocess_weights(sd) - # embed.weight storage must be unscaled (all 1.0) assert torch.allclose(sd["embed.weight"], torch.ones_like(sd["embed.weight"])) def test_logit_scale_one_is_noop(self) -> None: @@ -545,7 +473,6 @@ def test_logit_scale_one_is_noop(self) -> None: assert torch.allclose(sd["unembed.weight"], original_unembed) def test_missing_unembed_bias_no_error(self) -> None: - # Guard: if unembed.bias is absent, no KeyError should be raised cfg = _make_cfg(logit_scale=0.0625) adapter = CohereArchitectureAdapter(cfg) sd = { @@ -558,8 +485,7 @@ def test_missing_unembed_bias_no_error(self) -> None: ) def test_default_logit_scale_applied(self) -> None: - # Default logit_scale is 0.0625; verify 1.0 input becomes 0.0625 - cfg = _make_cfg(logit_scale=None) # None triggers default 0.0625 + cfg = _make_cfg(logit_scale=None) adapter = CohereArchitectureAdapter(cfg) sd = self._make_state_dict() sd = adapter.preprocess_weights(sd) @@ -586,3 +512,81 @@ def test_returns_state_dict(self) -> None: sd = self._make_state_dict() result = adapter.preprocess_weights(sd) assert isinstance(result, dict) + + +class TestCohereWeightConversionSemantics: + """QKVO conversion entries use the expected types and patterns.""" + + def test_q_conversion_types(self, adapter: CohereArchitectureAdapter) -> None: + conv = adapter.weight_processing_conversions["blocks.{i}.attn.q.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + + def test_qkv_split_heads_pattern(self, adapter: CohereArchitectureAdapter) -> None: + for slot in ("q", "k", "v"): + conv = adapter.weight_processing_conversions[f"blocks.{{i}}.attn.{slot}.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.pattern == "(n h) m -> n m h" + + def test_o_merge_heads_pattern(self, adapter: CohereArchitectureAdapter) -> None: + conv = adapter.weight_processing_conversions["blocks.{i}.attn.o.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.pattern == "m (n h) -> n h m" + + +class TestCohereGQASupport: + """n_key_value_heads propagates to K/V conversions only.""" + + def test_no_gqa_falls_back_to_n_heads(self) -> None: + cfg = _make_cfg(n_heads=4, n_key_value_heads=None) + adapter = CohereArchitectureAdapter(cfg) + for slot in ("k", "v"): + conv = adapter.weight_processing_conversions[f"blocks.{{i}}.attn.{slot}.weight"] + assert conv.tensor_conversion.axes_lengths["n"] == adapter.cfg.n_heads + + def test_gqa_propagates_to_kv_conversions(self) -> None: + cfg = _make_cfg(n_heads=8, n_key_value_heads=2) + adapter = CohereArchitectureAdapter(cfg) + for slot in ("k", "v"): + conv = adapter.weight_processing_conversions[f"blocks.{{i}}.attn.{slot}.weight"] + assert conv.tensor_conversion.axes_lengths["n"] == 2 + + def test_gqa_does_not_change_q_or_o_conversions(self) -> None: + cfg = _make_cfg(n_heads=8, n_key_value_heads=2) + adapter = CohereArchitectureAdapter(cfg) + q_conv = adapter.weight_processing_conversions["blocks.{i}.attn.q.weight"] + o_conv = adapter.weight_processing_conversions["blocks.{i}.attn.o.weight"] + assert q_conv.tensor_conversion.axes_lengths["n"] == 8 + assert o_conv.tensor_conversion.axes_lengths["n"] == 8 + + +class TestCohereArchitectureGuards: + """Guards against drift toward neighbouring adapter patterns.""" + + def test_no_norm_offset_conversions(self, adapter: CohereArchitectureAdapter) -> None: + # Cohere is not Gemma — no +1 norm offset entries. + for key in adapter.weight_processing_conversions: + assert "ln1" not in key + assert "ln2" not in key + assert "ln_final" not in key + + def test_no_mlp_weight_conversions(self, adapter: CohereArchitectureAdapter) -> None: + for key in adapter.weight_processing_conversions: + assert "mlp" not in key + + def test_block_is_parallel_block_bridge(self, adapter: CohereArchitectureAdapter) -> None: + # Parallel attn+MLP: must be ParallelBlockBridge, NOT sequential BlockBridge. + blocks = adapter.component_mapping["blocks"] + assert isinstance(blocks, ParallelBlockBridge) + + def test_uses_rms_norm_false_guard(self, adapter: CohereArchitectureAdapter) -> None: + # CohereLayerNorm is true mean-subtracting LayerNorm — guard against borrowing Llama's RMS. + assert adapter.cfg.uses_rms_norm is False + assert adapter.cfg.normalization_type == "LN" + assert adapter.cfg.final_rms is False + + def test_block_has_no_ln2(self, adapter: CohereArchitectureAdapter) -> None: + blocks = adapter.component_mapping["blocks"] + assert "ln2" not in blocks.submodules diff --git a/tests/unit/model_bridge/supported_architectures/test_gemma3_adapter.py b/tests/unit/model_bridge/supported_architectures/test_gemma3_adapter.py new file mode 100644 index 000000000..a464fb1e7 --- /dev/null +++ b/tests/unit/model_bridge/supported_architectures/test_gemma3_adapter.py @@ -0,0 +1,322 @@ +"""Unit tests for Gemma3ArchitectureAdapter (bridge structural). + +Legacy `get_pretrained_model_config` tests live in test_gemma3_config.py. +""" + +import pytest + +from transformer_lens.config.TransformerBridgeConfig import TransformerBridgeConfig +from transformer_lens.conversion_utils.conversion_steps import ( + ArithmeticTensorConversion, + RearrangeTensorConversion, + TransposeTensorConversion, +) +from transformer_lens.conversion_utils.conversion_steps.arithmetic_tensor_conversion import ( + OperationTypes, +) +from transformer_lens.conversion_utils.param_processing_conversion import ( + ParamProcessingConversion, +) +from transformer_lens.factories.architecture_adapter_factory import ( + SUPPORTED_ARCHITECTURES, + ArchitectureAdapterFactory, +) +from transformer_lens.model_bridge.generalized_components import ( + BlockBridge, + EmbeddingBridge, + GatedMLPBridge, + LinearBridge, + RMSNormalizationBridge, + RotaryEmbeddingBridge, + UnembeddingBridge, +) +from transformer_lens.model_bridge.generalized_components.position_embeddings_attention import ( + PositionEmbeddingsAttentionBridge, +) +from transformer_lens.model_bridge.supported_architectures.gemma3 import ( + Gemma3ArchitectureAdapter, +) + + +def _make_gemma3_cfg(**overrides): + """TransformerBridgeConfig for Gemma3 270M (text-only).""" + defaults = dict( + d_model=640, + d_head=256, + n_heads=4, + n_layers=18, + n_ctx=8192, + d_vocab=262144, + architecture="Gemma3ForCausalLM", + ) + defaults.update(overrides) + return TransformerBridgeConfig(**defaults) + + +class TestGemma3AdapterRegistration: + """Gemma3ArchitectureAdapter registration.""" + + def test_architecture_in_supported_architectures(self): + assert "Gemma3ForCausalLM" in SUPPORTED_ARCHITECTURES + + def test_architecture_maps_to_correct_adapter(self): + assert SUPPORTED_ARCHITECTURES["Gemma3ForCausalLM"] is Gemma3ArchitectureAdapter + + def test_factory_selects_correct_adapter(self): + cfg = _make_gemma3_cfg() + adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg) + assert isinstance(adapter, Gemma3ArchitectureAdapter) + + +class TestGemma3AdapterConfig: + """Gemma3ArchitectureAdapter cfg attributes.""" + + @pytest.fixture(scope="class") + def adapter(self): + return Gemma3ArchitectureAdapter(_make_gemma3_cfg()) + + def test_gated_mlp(self, adapter): + assert adapter.cfg.gated_mlp is True + + def test_uses_rms_norm(self, adapter): + assert adapter.cfg.uses_rms_norm is True + + def test_normalization_type(self, adapter): + assert adapter.cfg.normalization_type == "RMS" + + def test_rmsnorm_uses_offset(self, adapter): + # Gemma uses (1 + weight); offset must be advertised on cfg. + assert adapter.cfg.rmsnorm_uses_offset is True + + def test_positional_embedding_type(self, adapter): + assert adapter.cfg.positional_embedding_type == "rotary" + + def test_attn_implementation_eager(self, adapter): + # eager required so output_attentions works (hook_attn_scores / hook_pattern). + assert adapter.cfg.attn_implementation == "eager" + + +class TestGemma3ComponentMappingPresence: + """Component slots must exist (deletion guard).""" + + @pytest.fixture(scope="class") + def adapter(self): + return Gemma3ArchitectureAdapter(_make_gemma3_cfg()) + + def test_has_top_level_components(self, adapter): + for name in ("embed", "rotary_emb", "blocks", "ln_final", "unembed"): + assert name in adapter.component_mapping + + def test_no_vision_components(self, adapter): + assert "vision_encoder" not in adapter.component_mapping + assert "vision_projector" not in adapter.component_mapping + + +class TestGemma3ComponentMappingPaths: + """HF module paths for each component slot (refactor-drift guard).""" + + @pytest.fixture(scope="class") + def adapter(self): + return Gemma3ArchitectureAdapter(_make_gemma3_cfg()) + + def test_embed_path(self, adapter): + assert adapter.component_mapping["embed"].name == "model.embed_tokens" + + def test_rotary_emb_path(self, adapter): + assert adapter.component_mapping["rotary_emb"].name == "model.rotary_emb" + + def test_blocks_path(self, adapter): + assert adapter.component_mapping["blocks"].name == "model.layers" + + def test_ln_final_path(self, adapter): + assert adapter.component_mapping["ln_final"].name == "model.norm" + + def test_unembed_path(self, adapter): + assert adapter.component_mapping["unembed"].name == "lm_head" + + +class TestGemma3ComponentTypes: + """Component bridge classes — guards against silent type substitution.""" + + @pytest.fixture(scope="class") + def adapter(self): + return Gemma3ArchitectureAdapter(_make_gemma3_cfg()) + + def test_embed_type(self, adapter): + assert isinstance(adapter.component_mapping["embed"], EmbeddingBridge) + + def test_rotary_emb_type(self, adapter): + assert isinstance(adapter.component_mapping["rotary_emb"], RotaryEmbeddingBridge) + + def test_blocks_type(self, adapter): + assert isinstance(adapter.component_mapping["blocks"], BlockBridge) + + def test_ln_final_type(self, adapter): + assert isinstance(adapter.component_mapping["ln_final"], RMSNormalizationBridge) + + def test_unembed_type(self, adapter): + assert isinstance(adapter.component_mapping["unembed"], UnembeddingBridge) + + +class TestGemma3BlockSubmodules: + """BlockBridge wires Gemma3 dual-norm submodules.""" + + @pytest.fixture(scope="class") + def blocks(self): + adapter = Gemma3ArchitectureAdapter(_make_gemma3_cfg()) + return adapter.component_mapping["blocks"] + + def test_block_has_required_submodules(self, blocks): + # Gemma3 has BOTH pre- and post-norms around attention and FFN. + for name in ("ln1", "ln1_post", "ln2", "ln2_post", "attn", "mlp"): + assert name in blocks.submodules, f"BlockBridge missing submodule '{name}'" + + def test_dual_normalization_pre_and_post(self, blocks): + for name in ("ln1", "ln1_post", "ln2", "ln2_post"): + sub = blocks.submodules[name] + assert isinstance(sub, RMSNormalizationBridge) + + def test_ln1_path(self, blocks): + assert blocks.submodules["ln1"].name == "input_layernorm" + + def test_ln1_post_path(self, blocks): + assert blocks.submodules["ln1_post"].name == "post_attention_layernorm" + + def test_ln2_path(self, blocks): + assert blocks.submodules["ln2"].name == "pre_feedforward_layernorm" + + def test_ln2_post_path(self, blocks): + assert blocks.submodules["ln2_post"].name == "post_feedforward_layernorm" + + def test_attn_is_position_embeddings_attention(self, blocks): + attn = blocks.submodules["attn"] + assert isinstance(attn, PositionEmbeddingsAttentionBridge) + assert attn.name == "self_attn" + + def test_attn_qkvo_submodule_paths(self, blocks): + attn = blocks.submodules["attn"] + for sub_name, expected_path in ( + ("q", "q_proj"), + ("k", "k_proj"), + ("v", "v_proj"), + ("o", "o_proj"), + ): + sub = attn.submodules[sub_name] + assert isinstance(sub, LinearBridge) + assert sub.name == expected_path + + def test_attn_has_qk_norm_submodules(self, blocks): + # Gemma3 specifically applies RMSNorm to Q and K inside attention. + attn = blocks.submodules["attn"] + for sub_name in ("q_norm", "k_norm"): + assert sub_name in attn.submodules + sub = attn.submodules[sub_name] + assert isinstance(sub, RMSNormalizationBridge) + assert sub.name == sub_name + + def test_mlp_is_gated(self, blocks): + mlp = blocks.submodules["mlp"] + assert isinstance(mlp, GatedMLPBridge) + assert mlp.name == "mlp" + + def test_mlp_submodule_paths(self, blocks): + mlp = blocks.submodules["mlp"] + for sub_name, expected_path in ( + ("gate", "gate_proj"), + ("in", "up_proj"), + ("out", "down_proj"), + ): + sub = mlp.submodules[sub_name] + assert isinstance(sub, LinearBridge) + assert sub.name == expected_path + + +class TestGemma3GQASupport: + """n_key_value_heads must propagate to K/V conversions only.""" + + def test_no_gqa_when_not_set(self): + # Unset n_key_value_heads leaves K/V axis-length n=None (no coercion to n_heads). + adapter = Gemma3ArchitectureAdapter(_make_gemma3_cfg()) + kv_conv = adapter.weight_processing_conversions["blocks.{i}.attn.k.weight"] + assert kv_conv.tensor_conversion.axes_lengths["n"] is None + q_conv = adapter.weight_processing_conversions["blocks.{i}.attn.q.weight"] + assert q_conv.tensor_conversion.axes_lengths["n"] == 4 + + def test_gqa_propagates_to_kv_conversions(self): + cfg = _make_gemma3_cfg(n_heads=8, n_key_value_heads=4) + adapter = Gemma3ArchitectureAdapter(cfg) + for slot in ("k", "v"): + conv = adapter.weight_processing_conversions[f"blocks.{{i}}.attn.{slot}.weight"] + assert conv.tensor_conversion.axes_lengths["n"] == 4 + + def test_gqa_does_not_change_q_or_o_conversions(self): + cfg = _make_gemma3_cfg(n_heads=8, n_key_value_heads=4) + adapter = Gemma3ArchitectureAdapter(cfg) + # Q and O always use n_heads, regardless of GQA grouping. + q_conv = adapter.weight_processing_conversions["blocks.{i}.attn.q.weight"] + o_conv = adapter.weight_processing_conversions["blocks.{i}.attn.o.weight"] + assert q_conv.tensor_conversion.axes_lengths["n"] == 8 + assert o_conv.tensor_conversion.axes_lengths["n"] == 8 + + +class TestGemma3WeightProcessingConversions: + """Conversion entries have the right semantics, not just presence.""" + + @pytest.fixture(scope="class") + def adapter(self): + return Gemma3ArchitectureAdapter(_make_gemma3_cfg()) + + def test_qkvo_conversion_classes_and_patterns(self, adapter): + for slot in ("q", "k", "v"): + conv = adapter.weight_processing_conversions[f"blocks.{{i}}.attn.{slot}.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.pattern == "(n h) m -> n m h" + o_conv = adapter.weight_processing_conversions["blocks.{i}.attn.o.weight"] + assert isinstance(o_conv.tensor_conversion, RearrangeTensorConversion) + assert o_conv.tensor_conversion.pattern == "m (n h) -> n h m" + + def test_norm_offset_keys_present(self, adapter): + # Gemma3 needs +1 offset preprocessing for every RMSNorm weight. + for key in ( + "blocks.{i}.ln1.weight", + "blocks.{i}.ln1_post.weight", + "blocks.{i}.ln2.weight", + "blocks.{i}.ln2_post.weight", + "ln_final.weight", + "blocks.{i}.attn.q_norm.weight", + "blocks.{i}.attn.k_norm.weight", + ): + assert key in adapter.weight_processing_conversions, f"missing {key}" + + def test_norm_offset_conversion_semantics(self, adapter): + # Each norm-weight conversion must be ADDITION-by-1.0 (Gemma's +1 trick). + for key in ( + "blocks.{i}.ln1.weight", + "blocks.{i}.ln1_post.weight", + "blocks.{i}.ln2.weight", + "blocks.{i}.ln2_post.weight", + "ln_final.weight", + "blocks.{i}.attn.q_norm.weight", + "blocks.{i}.attn.k_norm.weight", + ): + conv = adapter.weight_processing_conversions[key] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, ArithmeticTensorConversion) + assert conv.tensor_conversion.operation == OperationTypes.ADDITION + assert conv.tensor_conversion.value == 1.0 + + def test_mlp_uses_transpose_conversion(self, adapter): + for slot in ("gate", "in", "out"): + conv = adapter.weight_processing_conversions[f"blocks.{{i}}.mlp.{slot}.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, TransposeTensorConversion) + + def test_unembed_uses_transpose_conversion(self, adapter): + conv = adapter.weight_processing_conversions["unembed.weight"] + assert isinstance(conv.tensor_conversion, TransposeTensorConversion) + + def test_no_attention_bias_conversions(self, adapter): + # Gemma-3 has bias=None on q/k/v/o_proj — no bias conversion keys expected. + for key in adapter.weight_processing_conversions: + assert not key.endswith(".bias"), f"unexpected bias key {key}" diff --git a/tests/unit/test_gemma3_config.py b/tests/unit/model_bridge/supported_architectures/test_gemma3_config.py similarity index 72% rename from tests/unit/test_gemma3_config.py rename to tests/unit/model_bridge/supported_architectures/test_gemma3_config.py index 445b3c9ac..e00c579f4 100644 --- a/tests/unit/test_gemma3_config.py +++ b/tests/unit/model_bridge/supported_architectures/test_gemma3_config.py @@ -1,11 +1,7 @@ -""" -Unit tests for Gemma 3 and MedGemma model support. +"""Unit tests for Gemma 3 / MedGemma legacy `get_pretrained_model_config` lookup. -Tests cover: -1. Configuration generation for all Gemma 3 model variants -2. Weight conversion from HuggingFace format -3. Hybrid local/global attention configuration -4. Per-layer RoPE base support +Covers registration, config generation, hybrid attention, per-layer RoPE, +and the HookedTransformerConfig rotary_base_local field. """ from unittest import mock @@ -16,10 +12,6 @@ from transformer_lens.loading_from_pretrained import get_pretrained_model_config from transformer_lens.supported_models import OFFICIAL_MODEL_NAMES -# ============================================================================ -# Test Data -# ============================================================================ - GEMMA3_MODELS = [ "google/gemma-3-270m", "google/gemma-3-270m-it", @@ -89,13 +81,8 @@ } -# ============================================================================ -# Test: Model names in official list -# ============================================================================ - - class TestGemma3ModelRegistration: - """Test that all Gemma 3 and MedGemma models are registered in OFFICIAL_MODEL_NAMES.""" + """All Gemma 3 / MedGemma models are listed in OFFICIAL_MODEL_NAMES.""" @pytest.mark.parametrize("model_name", GEMMA3_MODELS) def test_gemma3_models_in_official_list(self, model_name: str): @@ -106,17 +93,11 @@ def test_medgemma_models_in_official_list(self, model_name: str): assert model_name in OFFICIAL_MODEL_NAMES, f"{model_name} should be in OFFICIAL_MODEL_NAMES" -# ============================================================================ -# Test: Configuration generation -# ============================================================================ - - class TestGemma3ConfigGeneration: - """Test that get_pretrained_model_config generates correct configs for Gemma 3.""" + """get_pretrained_model_config generates correct configs for Gemma 3.""" - @pytest.fixture + @pytest.fixture(scope="class") def mock_hf_config(self): - """Create a minimal mock HuggingFace config.""" config = mock.Mock() config.architectures = ["Gemma3ForCausalLM"] return config @@ -131,7 +112,6 @@ def mock_hf_config(self): ], ) def test_gemma3_small_model_config(self, model_name: str, size_key: str, mock_hf_config): - """Test configuration for small Gemma 3 models (270M, 1B).""" with mock.patch( "transformer_lens.loading_from_pretrained.AutoConfig.from_pretrained", return_value=mock_hf_config, @@ -156,7 +136,6 @@ def test_gemma3_small_model_config(self, model_name: str, size_key: str, mock_hf ], ) def test_gemma3_4b_model_config(self, model_name: str, size_key: str, mock_hf_config): - """Test configuration for 4B models (Gemma 3 and MedGemma).""" mock_hf_config.architectures = ["Gemma3ForConditionalGeneration"] with mock.patch( "transformer_lens.loading_from_pretrained.AutoConfig.from_pretrained", @@ -171,22 +150,16 @@ def test_gemma3_4b_model_config(self, model_name: str, size_key: str, mock_hf_co assert cfg.n_key_value_heads == expected["n_key_value_heads"] -# ============================================================================ -# Test: Hybrid attention configuration -# ============================================================================ - - class TestGemma3HybridAttention: - """Test hybrid local/global attention configuration (5:1 pattern).""" + """Hybrid local/global attention (5:1 pattern).""" - @pytest.fixture + @pytest.fixture(scope="class") def mock_hf_config(self): config = mock.Mock() config.architectures = ["Gemma3ForCausalLM"] return config def test_attn_types_pattern_270m(self, mock_hf_config): - """Test 5:1 local/global pattern for 270M (18 layers).""" with mock.patch( "transformer_lens.loading_from_pretrained.AutoConfig.from_pretrained", return_value=mock_hf_config, @@ -197,13 +170,11 @@ def test_attn_types_pattern_270m(self, mock_hf_config): assert cfg.attn_types is not None assert len(cfg.attn_types) == 18 - # Check 5:1 pattern: global at indices 5, 11, 17 for i, attn_type in enumerate(cfg.attn_types): expected = "global" if (i + 1) % 6 == 0 else "local" assert attn_type == expected, f"Layer {i}: expected {expected}, got {attn_type}" def test_attn_types_pattern_1b(self, mock_hf_config): - """Test 5:1 local/global pattern for 1B (26 layers).""" with mock.patch( "transformer_lens.loading_from_pretrained.AutoConfig.from_pretrained", return_value=mock_hf_config, @@ -213,14 +184,12 @@ def test_attn_types_pattern_1b(self, mock_hf_config): assert cfg.use_local_attn is True assert len(cfg.attn_types) == 26 - # Count global layers global_count = cfg.attn_types.count("global") local_count = cfg.attn_types.count("local") - assert global_count == 4 # 26 // 6 = 4 global layers + assert global_count == 4 assert local_count == 22 def test_window_size_small_models(self, mock_hf_config): - """Test that 270M/1B models use 512 token window.""" with mock.patch( "transformer_lens.loading_from_pretrained.AutoConfig.from_pretrained", return_value=mock_hf_config, @@ -229,7 +198,6 @@ def test_window_size_small_models(self, mock_hf_config): assert cfg.window_size == 512 def test_window_size_large_models(self, mock_hf_config): - """Test that 4B+ models use 1024 token window.""" mock_hf_config.architectures = ["Gemma3ForConditionalGeneration"] with mock.patch( "transformer_lens.loading_from_pretrained.AutoConfig.from_pretrained", @@ -239,107 +207,78 @@ def test_window_size_large_models(self, mock_hf_config): assert cfg.window_size == 1024 -# ============================================================================ -# Test: Per-layer RoPE base -# ============================================================================ - - class TestGemma3PerLayerRoPE: - """Test per-layer RoPE base configuration.""" + """Per-layer RoPE base configuration.""" - @pytest.fixture + @pytest.fixture(scope="class") def mock_hf_config(self): config = mock.Mock() config.architectures = ["Gemma3ForCausalLM"] return config def test_rotary_base_global(self, mock_hf_config): - """Test that global attention layers use 1M RoPE base.""" with mock.patch( "transformer_lens.loading_from_pretrained.AutoConfig.from_pretrained", return_value=mock_hf_config, ): cfg = get_pretrained_model_config("google/gemma-3-270m") - assert cfg.rotary_base == 1_000_000 def test_rotary_base_local(self, mock_hf_config): - """Test that local attention layers use 10K RoPE base.""" with mock.patch( "transformer_lens.loading_from_pretrained.AutoConfig.from_pretrained", return_value=mock_hf_config, ): cfg = get_pretrained_model_config("google/gemma-3-270m") - assert cfg.rotary_base_local == 10_000 -# ============================================================================ -# Test: Q/K Normalization -# ============================================================================ - - class TestGemma3QKNorm: - """Test Q/K normalization configuration.""" + """Q/K normalization configuration.""" - @pytest.fixture + @pytest.fixture(scope="class") def mock_hf_config(self): config = mock.Mock() config.architectures = ["Gemma3ForCausalLM"] return config def test_use_qk_norm_enabled(self, mock_hf_config): - """Test that Q/K normalization is enabled for all Gemma 3 models.""" with mock.patch( "transformer_lens.loading_from_pretrained.AutoConfig.from_pretrained", return_value=mock_hf_config, ): cfg = get_pretrained_model_config("google/gemma-3-270m") - assert cfg.use_qk_norm is True -# ============================================================================ -# Test: Normalization before and after -# ============================================================================ - - class TestGemma3Normalization: - """Test Gemma 2/3 style normalization (before and after blocks).""" + """Gemma 2/3 style normalization (before and after blocks).""" - @pytest.fixture + @pytest.fixture(scope="class") def mock_hf_config(self): config = mock.Mock() config.architectures = ["Gemma3ForCausalLM"] return config def test_normalization_before_and_after(self, mock_hf_config): - """Test that use_normalization_before_and_after is enabled.""" with mock.patch( "transformer_lens.loading_from_pretrained.AutoConfig.from_pretrained", return_value=mock_hf_config, ): cfg = get_pretrained_model_config("google/gemma-3-270m") - assert cfg.use_normalization_before_and_after is True -# ============================================================================ -# Test: Vocabulary size -# ============================================================================ - - class TestGemma3VocabSize: - """Test vocabulary size configuration.""" + """Vocabulary size configuration.""" - @pytest.fixture + @pytest.fixture(scope="class") def mock_hf_config(self): config = mock.Mock() config.architectures = ["Gemma3ForCausalLM"] return config def test_vocab_size_small_models(self, mock_hf_config): - """Test vocab size for 270M/1B models.""" with mock.patch( "transformer_lens.loading_from_pretrained.AutoConfig.from_pretrained", return_value=mock_hf_config, @@ -348,7 +287,6 @@ def test_vocab_size_small_models(self, mock_hf_config): assert cfg.d_vocab == 262144 def test_vocab_size_multimodal_models(self, mock_hf_config): - """Test vocab size for 4B+ multimodal models (262208).""" mock_hf_config.architectures = ["Gemma3ForConditionalGeneration"] with mock.patch( "transformer_lens.loading_from_pretrained.AutoConfig.from_pretrained", @@ -358,7 +296,6 @@ def test_vocab_size_multimodal_models(self, mock_hf_config): assert cfg.d_vocab == 262208 def test_vocab_size_medgemma_text_only(self, mock_hf_config): - """Test vocab size for MedGemma 27B text-only variant (262144).""" with mock.patch( "transformer_lens.loading_from_pretrained.AutoConfig.from_pretrained", return_value=mock_hf_config, @@ -367,22 +304,16 @@ def test_vocab_size_medgemma_text_only(self, mock_hf_config): assert cfg.d_vocab == 262144 -# ============================================================================ -# Test: Default context length -# ============================================================================ - - class TestGemma3ContextLength: - """Test default context length configuration.""" + """Default context length configuration.""" - @pytest.fixture + @pytest.fixture(scope="class") def mock_hf_config(self): config = mock.Mock() config.architectures = ["Gemma3ForCausalLM"] return config def test_default_context_length(self, mock_hf_config): - """Test that default n_ctx is 8192 (memory-safe default).""" with mock.patch( "transformer_lens.loading_from_pretrained.AutoConfig.from_pretrained", return_value=mock_hf_config, @@ -391,16 +322,10 @@ def test_default_context_length(self, mock_hf_config): assert cfg.n_ctx == 8192 -# ============================================================================ -# Test: HookedTransformerConfig with rotary_base_local -# ============================================================================ - - class TestHookedTransformerConfigRotaryBaseLocal: - """Test that HookedTransformerConfig supports rotary_base_local.""" + """HookedTransformerConfig supports rotary_base_local.""" def test_rotary_base_local_default_none(self): - """Test that rotary_base_local defaults to None.""" cfg = HookedTransformerConfig( d_model=128, d_head=32, @@ -412,7 +337,6 @@ def test_rotary_base_local_default_none(self): assert cfg.rotary_base_local is None def test_rotary_base_local_can_be_set(self): - """Test that rotary_base_local can be set to a custom value.""" cfg = HookedTransformerConfig( d_model=128, d_head=32, @@ -425,7 +349,6 @@ def test_rotary_base_local_can_be_set(self): assert cfg.rotary_base_local == 10000 def test_rotary_base_and_rotary_base_local_coexist(self): - """Test that both rotary_base and rotary_base_local can be set.""" cfg = HookedTransformerConfig( d_model=128, d_head=32, diff --git a/tests/unit/model_bridge/supported_architectures/test_gemma3_multimodal_adapter.py b/tests/unit/model_bridge/supported_architectures/test_gemma3_multimodal_adapter.py new file mode 100644 index 000000000..6d46c17d8 --- /dev/null +++ b/tests/unit/model_bridge/supported_architectures/test_gemma3_multimodal_adapter.py @@ -0,0 +1,347 @@ +"""Unit tests for Gemma3 multimodal architecture adapter registration.""" + +from types import SimpleNamespace + +import pytest + +from transformer_lens.config.TransformerBridgeConfig import TransformerBridgeConfig +from transformer_lens.conversion_utils.conversion_steps import ( + ArithmeticTensorConversion, + RearrangeTensorConversion, + TransposeTensorConversion, +) +from transformer_lens.conversion_utils.conversion_steps.arithmetic_tensor_conversion import ( + OperationTypes, +) +from transformer_lens.conversion_utils.param_processing_conversion import ( + ParamProcessingConversion, +) +from transformer_lens.factories.architecture_adapter_factory import ( + SUPPORTED_ARCHITECTURES, + ArchitectureAdapterFactory, +) +from transformer_lens.model_bridge.generalized_components import ( + BlockBridge, + CLIPVisionEncoderBridge, + EmbeddingBridge, + GatedMLPBridge, + LinearBridge, + RMSNormalizationBridge, + RotaryEmbeddingBridge, + SiglipVisionEncoderBridge, + UnembeddingBridge, + VisionProjectionBridge, +) +from transformer_lens.model_bridge.generalized_components.position_embeddings_attention import ( + PositionEmbeddingsAttentionBridge, +) +from transformer_lens.model_bridge.supported_architectures.gemma3_multimodal import ( + Gemma3MultimodalArchitectureAdapter, +) + + +def _make_gemma3_mm_cfg(with_vision_config: bool = True, **overrides): + """Create a TransformerBridgeConfig for Gemma3 4B multimodal.""" + defaults = dict( + d_model=2560, + d_head=256, + n_heads=8, + n_layers=34, + n_ctx=8192, + d_vocab=262208, + n_key_value_heads=4, + architecture="Gemma3ForConditionalGeneration", + ) + defaults.update(overrides) + cfg = TransformerBridgeConfig(**defaults) + if with_vision_config: + # Gemma3 multimodal pulls vision dims from cfg.vision_config (SigLIP). + cfg.vision_config = SimpleNamespace( + model_type="siglip_vision_model", + hidden_size=1152, + num_hidden_layers=27, + num_attention_heads=16, + ) + return cfg + + +class TestGemma3MultimodalRegistration: + """Test that Gemma3MultimodalArchitectureAdapter is properly registered.""" + + def test_architecture_in_supported_architectures(self): + assert "Gemma3ForConditionalGeneration" in SUPPORTED_ARCHITECTURES + + def test_architecture_maps_to_correct_adapter(self): + assert ( + SUPPORTED_ARCHITECTURES["Gemma3ForConditionalGeneration"] + is Gemma3MultimodalArchitectureAdapter + ) + + def test_factory_selects_correct_adapter(self): + cfg = _make_gemma3_mm_cfg() + adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg) + assert isinstance(adapter, Gemma3MultimodalArchitectureAdapter) + + +class TestGemma3MultimodalAdapterConfig: + """Test Gemma3MultimodalArchitectureAdapter configuration.""" + + @pytest.fixture(scope="class") + def adapter(self): + cfg = _make_gemma3_mm_cfg() + return Gemma3MultimodalArchitectureAdapter(cfg) + + def test_is_multimodal(self, adapter): + assert adapter.cfg.is_multimodal is True + + def test_gated_mlp(self, adapter): + assert adapter.cfg.gated_mlp is True + + def test_uses_rms_norm(self, adapter): + assert adapter.cfg.uses_rms_norm is True + + def test_normalization_type(self, adapter): + assert adapter.cfg.normalization_type == "RMS" + + def test_rmsnorm_uses_offset(self, adapter): + # Required to keep fold_ln from setting identity to 1.0 (Gemma's +1 trick). + assert adapter.cfg.rmsnorm_uses_offset is True + + def test_positional_embedding_type(self, adapter): + assert adapter.cfg.positional_embedding_type == "rotary" + + def test_attn_implementation_eager(self, adapter): + assert adapter.cfg.attn_implementation == "eager" + + def test_vision_config_extracted(self, adapter): + assert adapter.cfg.vision_hidden_size == 1152 + assert adapter.cfg.vision_num_layers == 27 + assert adapter.cfg.vision_num_heads == 16 + + def test_mm_tokens_per_image_passthrough_when_set(self): + # When cfg.mm_tokens_per_image is set, the adapter passes it through. + cfg = _make_gemma3_mm_cfg() + cfg.mm_tokens_per_image = 128 + adapter = Gemma3MultimodalArchitectureAdapter(cfg) + assert adapter.cfg.mm_tokens_per_image == 128 + + +class TestGemma3MultimodalComponentMappingPresence: + """Top-level component slots must exist (deletion guard).""" + + @pytest.fixture(scope="class") + def adapter(self): + return Gemma3MultimodalArchitectureAdapter(_make_gemma3_mm_cfg()) + + def test_has_vision_components(self, adapter): + assert "vision_encoder" in adapter.component_mapping + assert "vision_projector" in adapter.component_mapping + + def test_has_language_model_components(self, adapter): + for name in ("embed", "rotary_emb", "blocks", "ln_final", "unembed"): + assert name in adapter.component_mapping + + +class TestGemma3MultimodalComponentMappingPaths: + """HF module paths for each component slot (refactor-drift guard).""" + + @pytest.fixture(scope="class") + def adapter(self): + return Gemma3MultimodalArchitectureAdapter(_make_gemma3_mm_cfg()) + + def test_vision_encoder_path(self, adapter): + assert adapter.component_mapping["vision_encoder"].name == "model.vision_tower" + + def test_vision_projector_path(self, adapter): + assert adapter.component_mapping["vision_projector"].name == "model.multi_modal_projector" + + def test_embed_path(self, adapter): + assert adapter.component_mapping["embed"].name == "model.language_model.embed_tokens" + + def test_rotary_emb_path(self, adapter): + assert adapter.component_mapping["rotary_emb"].name == "model.language_model.rotary_emb" + + def test_blocks_path(self, adapter): + assert adapter.component_mapping["blocks"].name == "model.language_model.layers" + + def test_ln_final_path(self, adapter): + assert adapter.component_mapping["ln_final"].name == "model.language_model.norm" + + def test_unembed_path(self, adapter): + assert adapter.component_mapping["unembed"].name == "lm_head" + + +class TestGemma3MultimodalComponentTypes: + """Component bridge classes — guards against silent type substitution.""" + + @pytest.fixture(scope="class") + def adapter(self): + return Gemma3MultimodalArchitectureAdapter(_make_gemma3_mm_cfg()) + + def test_vision_encoder_is_siglip_bridge(self, adapter): + # Gemma3 multimodal hard-wires SigLIP — must NOT be CLIP. + assert isinstance( + adapter.component_mapping["vision_encoder"], SiglipVisionEncoderBridge + ) + assert not isinstance( + adapter.component_mapping["vision_encoder"], CLIPVisionEncoderBridge + ) + + def test_vision_projector_type(self, adapter): + assert isinstance(adapter.component_mapping["vision_projector"], VisionProjectionBridge) + + def test_embed_type(self, adapter): + assert isinstance(adapter.component_mapping["embed"], EmbeddingBridge) + + def test_rotary_emb_type(self, adapter): + assert isinstance(adapter.component_mapping["rotary_emb"], RotaryEmbeddingBridge) + + def test_blocks_type(self, adapter): + assert isinstance(adapter.component_mapping["blocks"], BlockBridge) + + def test_ln_final_type(self, adapter): + assert isinstance(adapter.component_mapping["ln_final"], RMSNormalizationBridge) + + def test_unembed_type(self, adapter): + assert isinstance(adapter.component_mapping["unembed"], UnembeddingBridge) + + +class TestGemma3MultimodalBlockSubmodules: + """The BlockBridge must wire Gemma3 dual-norm submodules in the language model.""" + + @pytest.fixture(scope="class") + def blocks(self): + adapter = Gemma3MultimodalArchitectureAdapter(_make_gemma3_mm_cfg()) + return adapter.component_mapping["blocks"] + + def test_block_has_required_submodules(self, blocks): + for name in ("ln1", "ln1_post", "ln2", "ln2_post", "attn", "mlp"): + assert name in blocks.submodules, f"BlockBridge missing submodule '{name}'" + + def test_dual_normalization_pre_and_post(self, blocks): + for name in ("ln1", "ln1_post", "ln2", "ln2_post"): + sub = blocks.submodules[name] + assert isinstance(sub, RMSNormalizationBridge) + + def test_ln_submodule_paths(self, blocks): + assert blocks.submodules["ln1"].name == "input_layernorm" + assert blocks.submodules["ln1_post"].name == "post_attention_layernorm" + assert blocks.submodules["ln2"].name == "pre_feedforward_layernorm" + assert blocks.submodules["ln2_post"].name == "post_feedforward_layernorm" + + def test_attn_is_position_embeddings_attention(self, blocks): + attn = blocks.submodules["attn"] + assert isinstance(attn, PositionEmbeddingsAttentionBridge) + assert attn.name == "self_attn" + + def test_attn_qkvo_submodule_paths(self, blocks): + attn = blocks.submodules["attn"] + for sub_name, expected_path in ( + ("q", "q_proj"), + ("k", "k_proj"), + ("v", "v_proj"), + ("o", "o_proj"), + ): + sub = attn.submodules[sub_name] + assert isinstance(sub, LinearBridge) + assert sub.name == expected_path + + def test_attn_has_qk_norm_submodules(self, blocks): + attn = blocks.submodules["attn"] + for sub_name in ("q_norm", "k_norm"): + assert sub_name in attn.submodules + sub = attn.submodules[sub_name] + assert isinstance(sub, RMSNormalizationBridge) + assert sub.name == sub_name + + def test_mlp_is_gated(self, blocks): + mlp = blocks.submodules["mlp"] + assert isinstance(mlp, GatedMLPBridge) + assert mlp.name == "mlp" + + def test_mlp_submodule_paths(self, blocks): + mlp = blocks.submodules["mlp"] + for sub_name, expected_path in ( + ("gate", "gate_proj"), + ("in", "up_proj"), + ("out", "down_proj"), + ): + sub = mlp.submodules[sub_name] + assert isinstance(sub, LinearBridge) + assert sub.name == expected_path + + +class TestGemma3MultimodalGQASupport: + """GQA variants — n_key_value_heads must propagate to K/V conversions only.""" + + def test_default_4b_has_gqa(self): + # The 4B-style fixture already has n_key_value_heads=4 and n_heads=8. + adapter = Gemma3MultimodalArchitectureAdapter(_make_gemma3_mm_cfg()) + for slot in ("k", "v"): + conv = adapter.weight_processing_conversions[f"blocks.{{i}}.attn.{slot}.weight"] + assert conv.tensor_conversion.axes_lengths["n"] == 4 + q_conv = adapter.weight_processing_conversions["blocks.{i}.attn.q.weight"] + o_conv = adapter.weight_processing_conversions["blocks.{i}.attn.o.weight"] + assert q_conv.tensor_conversion.axes_lengths["n"] == 8 + assert o_conv.tensor_conversion.axes_lengths["n"] == 8 + + def test_no_gqa_falls_back_to_n_heads(self): + # Multimodal adapter uses 'or self.cfg.n_heads' fallback — None coerces. + cfg = _make_gemma3_mm_cfg(n_heads=8, n_key_value_heads=None) + adapter = Gemma3MultimodalArchitectureAdapter(cfg) + for slot in ("k", "v"): + conv = adapter.weight_processing_conversions[f"blocks.{{i}}.attn.{slot}.weight"] + assert conv.tensor_conversion.axes_lengths["n"] == 8 + + +class TestGemma3MultimodalWeightProcessingConversions: + """Conversion entries are not just present — they have the right semantics.""" + + @pytest.fixture(scope="class") + def adapter(self): + return Gemma3MultimodalArchitectureAdapter(_make_gemma3_mm_cfg()) + + def test_qkvo_conversion_classes_and_patterns(self, adapter): + for slot in ("q", "k", "v"): + conv = adapter.weight_processing_conversions[f"blocks.{{i}}.attn.{slot}.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.pattern == "(n h) m -> n m h" + o_conv = adapter.weight_processing_conversions["blocks.{i}.attn.o.weight"] + assert o_conv.tensor_conversion.pattern == "m (n h) -> n h m" + + def test_norm_offset_keys_present(self, adapter): + for key in ( + "blocks.{i}.ln1.weight", + "blocks.{i}.ln1_post.weight", + "blocks.{i}.ln2.weight", + "blocks.{i}.ln2_post.weight", + "ln_final.weight", + "blocks.{i}.attn.q_norm.weight", + "blocks.{i}.attn.k_norm.weight", + ): + assert key in adapter.weight_processing_conversions, f"missing {key}" + + def test_norm_offset_conversion_semantics(self, adapter): + for key in ( + "blocks.{i}.ln1.weight", + "blocks.{i}.ln1_post.weight", + "blocks.{i}.ln2.weight", + "blocks.{i}.ln2_post.weight", + "ln_final.weight", + "blocks.{i}.attn.q_norm.weight", + "blocks.{i}.attn.k_norm.weight", + ): + conv = adapter.weight_processing_conversions[key] + assert isinstance(conv.tensor_conversion, ArithmeticTensorConversion) + assert conv.tensor_conversion.operation == OperationTypes.ADDITION + assert conv.tensor_conversion.value == 1.0 + + def test_mlp_uses_transpose_conversion(self, adapter): + for slot in ("gate", "in", "out"): + conv = adapter.weight_processing_conversions[f"blocks.{{i}}.mlp.{slot}.weight"] + assert isinstance(conv.tensor_conversion, TransposeTensorConversion) + + def test_unembed_uses_transpose_conversion(self, adapter): + conv = adapter.weight_processing_conversions["unembed.weight"] + assert isinstance(conv.tensor_conversion, TransposeTensorConversion) diff --git a/tests/unit/model_bridge/supported_architectures/test_gpt_bigcode_adapter.py b/tests/unit/model_bridge/supported_architectures/test_gpt_bigcode_adapter.py index a5f91d020..4ff80eb70 100644 --- a/tests/unit/model_bridge/supported_architectures/test_gpt_bigcode_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_gpt_bigcode_adapter.py @@ -1,15 +1,4 @@ -"""Unit tests for GPTBigCodeArchitectureAdapter. - -Tests cover: -- Config attribute validation -- Component mapping structure (correct bridge types and HF module paths) -- Weight conversion keys -- MQAQKVConversionRule (Q and K/V branches, revert, passthrough) -- _split_qkv_matrix correctness (shapes, bias, no-bias, value correctness) -- multi_query assertion in _split_qkv_matrix -- End-to-end hook shapes with a fake MQA attention module (no downloads) -- Factory registration -""" +"""Unit tests for GPTBigCodeArchitectureAdapter.""" from typing import Any @@ -18,6 +7,10 @@ import torch.nn as nn from transformer_lens.config import TransformerBridgeConfig +from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion +from transformer_lens.conversion_utils.param_processing_conversion import ( + ParamProcessingConversion, +) from transformer_lens.factories.architecture_adapter_factory import ( SUPPORTED_ARCHITECTURES, ArchitectureAdapterFactory, @@ -66,12 +59,12 @@ def _make_cfg( class FakeMQAAttention(nn.Module): - """Minimal GPTBigCodeAttention-like module for testing (no downloaded weights).""" + """Minimal GPTBigCodeAttention-like module (no downloaded weights).""" def __init__(self, d_model: int, d_head: int, multi_query: bool = True) -> None: super().__init__() self.multi_query = multi_query - # MQA: c_attn output = embed_dim + 2*head_dim + # MQA c_attn out = embed_dim + 2*head_dim; MHA = 3*embed_dim. out_features = d_model + 2 * d_head if multi_query else 3 * d_model self.c_attn = nn.Linear(d_model, out_features) self.c_proj = nn.Linear(d_model, d_model) @@ -80,12 +73,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # pragma: no cover return self.c_proj(x) -@pytest.fixture +@pytest.fixture(scope="class") def cfg() -> TransformerBridgeConfig: return _make_cfg() -@pytest.fixture +@pytest.fixture(scope="class") def adapter(cfg: TransformerBridgeConfig) -> GPTBigCodeArchitectureAdapter: return GPTBigCodeArchitectureAdapter(cfg) @@ -96,8 +89,6 @@ def adapter(cfg: TransformerBridgeConfig) -> GPTBigCodeArchitectureAdapter: class TestGPTBigCodeAdapterConfig: - """Verifies all required config attributes are set correctly.""" - def test_normalization_type_is_ln(self, adapter: GPTBigCodeArchitectureAdapter) -> None: assert adapter.cfg.normalization_type == "LN" @@ -125,8 +116,6 @@ def test_n_key_value_heads_is_one(self, adapter: GPTBigCodeArchitectureAdapter) class TestGPTBigCodeAdapterComponentMapping: - """Verifies component_mapping has the correct bridge types and HF paths.""" - def test_embed_is_embedding_bridge(self, adapter: GPTBigCodeArchitectureAdapter) -> None: assert isinstance(adapter.component_mapping["embed"], EmbeddingBridge) @@ -220,8 +209,6 @@ def test_unembed_name(self, adapter: GPTBigCodeArchitectureAdapter) -> None: class TestGPTBigCodeAdapterWeightConversions: - """Verifies weight_processing_conversions has expected keys.""" - def test_q_weight_key_present(self, adapter: GPTBigCodeArchitectureAdapter) -> None: assert "blocks.{i}.attn.q.weight" in adapter.weight_processing_conversions @@ -238,20 +225,117 @@ def test_exactly_four_conversion_keys(self, adapter: GPTBigCodeArchitectureAdapt assert len(adapter.weight_processing_conversions) == 4 +class TestGPTBigCodeWeightConversionSemantics: + """MQA pins n=1 on K/V; Q/O stay at n_heads.""" + + @pytest.fixture(scope="class") + def adapter(self) -> GPTBigCodeArchitectureAdapter: + return GPTBigCodeArchitectureAdapter(_make_cfg()) + + def test_q_conversion_type_and_pattern( + self, adapter: GPTBigCodeArchitectureAdapter + ) -> None: + conv = adapter.weight_processing_conversions["blocks.{i}.attn.q.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.pattern == "(n h) m -> n m h" + + def test_q_n_equals_n_heads(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + conv = adapter.weight_processing_conversions["blocks.{i}.attn.q.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.axes_lengths["n"] == adapter.cfg.n_heads + + @pytest.mark.parametrize("slot", ["k", "v"]) + def test_kv_uses_mqa_n_equals_one( + self, adapter: GPTBigCodeArchitectureAdapter, slot: str + ) -> None: + # MQA: K/V have exactly 1 KV head. + conv = adapter.weight_processing_conversions[f"blocks.{{i}}.attn.{slot}.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.pattern == "(n h) m -> n m h" + assert conv.tensor_conversion.axes_lengths["n"] == 1 + + def test_o_conversion_type_and_pattern( + self, adapter: GPTBigCodeArchitectureAdapter + ) -> None: + conv = adapter.weight_processing_conversions["blocks.{i}.attn.o.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.pattern == "m (n h) -> n h m" + assert conv.tensor_conversion.axes_lengths["n"] == adapter.cfg.n_heads + + +class TestGPTBigCodeCombinedQKVFlags: + """Combined-QKV flags the loader depends on.""" + + @pytest.fixture(scope="class") + def adapter(self) -> GPTBigCodeArchitectureAdapter: + return GPTBigCodeArchitectureAdapter(_make_cfg()) + + def test_uses_combined_qkv_true(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + assert adapter.uses_combined_qkv is True + + def test_split_attention_weights_true(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + assert adapter.cfg.split_attention_weights is True + + def test_default_cfg_declares_uses_split_attention( + self, adapter: GPTBigCodeArchitectureAdapter + ) -> None: + # default_cfg records uses_split_attention as descriptor only (base __init__ already ran). + assert adapter.default_cfg.get("uses_split_attention") is True + + +class TestGPTBigCodeArchitectureGuards: + """Learned-pos LayerNorm arch: no rotary, no Gemma offsets.""" + + @pytest.fixture(scope="class") + def adapter(self) -> GPTBigCodeArchitectureAdapter: + return GPTBigCodeArchitectureAdapter(_make_cfg()) + + def test_no_top_level_rotary_emb(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + assert "rotary_emb" not in adapter.component_mapping + + def test_has_pos_embed_component(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + assert "pos_embed" in adapter.component_mapping + + def test_no_norm_offset_conversions(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + for key in adapter.weight_processing_conversions: + assert "ln1.weight" not in key + assert "ln2.weight" not in key + assert "ln_final.weight" not in key + + def test_only_qkvo_conversion_keys(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + assert set(adapter.weight_processing_conversions.keys()) == { + "blocks.{i}.attn.q.weight", + "blocks.{i}.attn.k.weight", + "blocks.{i}.attn.v.weight", + "blocks.{i}.attn.o.weight", + } + + def test_uses_rms_norm_false(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + assert adapter.cfg.uses_rms_norm is False + + def test_eps_attr(self, adapter: GPTBigCodeArchitectureAdapter) -> None: + # GPT-2 family eps (not RMS variance_epsilon). + assert adapter.cfg.eps_attr == "layer_norm_epsilon" + + # --------------------------------------------------------------------------- # MQAQKVConversionRule tests # --------------------------------------------------------------------------- class TestMQAQKVConversionRule: - """Verifies the branching QKV activation rearrangement for MQA.""" + """Branching QKV activation rearrangement for MQA.""" N_HEADS = 4 D_HEAD = 16 D_MODEL = N_HEADS * D_HEAD # 64 BATCH, SEQ = 2, 8 - @pytest.fixture + @pytest.fixture(scope="class") def rule(self) -> MQAQKVConversionRule: return MQAQKVConversionRule(n_heads=self.N_HEADS, d_head=self.D_HEAD) @@ -268,7 +352,7 @@ def test_kv_shaped_input_gives_one_head_dimension(self, rule: MQAQKVConversionRu assert out.shape == (self.BATCH, self.SEQ, 1, self.D_HEAD) def test_4d_input_passes_through_unchanged(self, rule: MQAQKVConversionRule) -> None: - """4D input is already in heads format — return as-is.""" + """4D input is already in heads format; returned as-is.""" x = torch.randn(self.BATCH, self.SEQ, self.N_HEADS, self.D_HEAD) out = rule.handle_conversion(x) assert out.shape == x.shape @@ -307,26 +391,25 @@ def test_invalid_ndim_raises(self, rule: MQAQKVConversionRule) -> None: class TestGPTBigCodeMQASplitQKVMatrix: - """Numerical correctness tests for the MQA asymmetric QKV split.""" + """Numerical tests for the MQA asymmetric QKV split.""" N_HEADS = 4 D_MODEL = 64 D_HEAD = D_MODEL // N_HEADS # 16 BATCH, SEQ = 2, 8 - @pytest.fixture + @pytest.fixture(scope="class") def adapter(self) -> GPTBigCodeArchitectureAdapter: cfg = _make_cfg(n_heads=self.N_HEADS, d_model=self.D_MODEL) return GPTBigCodeArchitectureAdapter(cfg) - @pytest.fixture + @pytest.fixture(scope="class") def fake_attn(self) -> FakeMQAAttention: return FakeMQAAttention(self.D_MODEL, self.D_HEAD, multi_query=True) - @pytest.fixture + @pytest.fixture(scope="class") def fake_attn_nobias(self) -> FakeMQAAttention: attn = FakeMQAAttention(self.D_MODEL, self.D_HEAD, multi_query=True) - # Remove bias from c_attn attn.c_attn = nn.Linear(self.D_MODEL, self.D_MODEL + 2 * self.D_HEAD, bias=False) return attn @@ -388,11 +471,10 @@ def test_no_bias_case_all_none( def test_q_k_v_weights_are_distinct( self, adapter: GPTBigCodeArchitectureAdapter, fake_attn: FakeMQAAttention ) -> None: - """With non-trivial c_attn weight, Q/K/V must differ.""" nn.init.normal_(fake_attn.c_attn.weight) q, k, v = adapter._split_qkv_matrix(fake_attn) - # K and V have the same shape [d_head, d_model] so compare directly - assert not torch.allclose(k.weight, v.weight), "K and V weights must differ" + # K/V share shape [d_head, d_model]; only Q differs in shape, so compare K vs V. + assert not torch.allclose(k.weight, v.weight) def test_q_forward_output_shape( self, adapter: GPTBigCodeArchitectureAdapter, fake_attn: FakeMQAAttention @@ -418,7 +500,7 @@ def test_v_forward_output_shape( def test_weight_values_match_c_attn_rows( self, adapter: GPTBigCodeArchitectureAdapter, fake_attn: FakeMQAAttention ) -> None: - """Q/K/V weight rows must exactly match the corresponding rows of c_attn.weight.""" + """Q/K/V rows match the corresponding c_attn.weight rows exactly.""" nn.init.normal_(fake_attn.c_attn.weight) original_weight = fake_attn.c_attn.weight.detach() q, k, v = adapter._split_qkv_matrix(fake_attn) @@ -429,7 +511,7 @@ def test_weight_values_match_c_attn_rows( def test_multi_query_false_raises_assertion( self, adapter: GPTBigCodeArchitectureAdapter ) -> None: - """Adapter must raise AssertionError for multi_query=False checkpoints.""" + """multi_query=False checkpoints must raise (adapter is MQA-only).""" mha_attn = FakeMQAAttention(self.D_MODEL, self.D_HEAD, multi_query=False) with pytest.raises(AssertionError, match="multi_query=True"): adapter._split_qkv_matrix(mha_attn) @@ -441,11 +523,10 @@ def test_multi_query_false_raises_assertion( class TestGPTBigCodeHookShapes: - """End-to-end forward pass verifying hook_q/hook_k/hook_v shapes. + """End-to-end hook_q/hook_k/hook_v shapes via a fake MQA attention. - Uses a fake MQA attention nn.Module (no model downloads). Registers explicit - hooks on hook_out so that hook_conversion (MQAQKVConversionRule) fires and - the captured tensors reflect the converted shapes. + Explicit hooks on hook_out trigger hook_conversion (MQAQKVConversionRule) + so captured tensors reflect the converted shapes. """ N_HEADS = 4 @@ -453,14 +534,14 @@ class TestGPTBigCodeHookShapes: D_HEAD = D_MODEL // N_HEADS # 16 BATCH, SEQ = 2, 8 - @pytest.fixture + @pytest.fixture(scope="class") def adapter(self) -> GPTBigCodeArchitectureAdapter: cfg = _make_cfg(n_heads=self.N_HEADS, d_model=self.D_MODEL) return GPTBigCodeArchitectureAdapter(cfg) - @pytest.fixture + @pytest.fixture(scope="class") def wired_attn_bridge(self, adapter: GPTBigCodeArchitectureAdapter) -> JointQKVAttentionBridge: - """Return attn bridge wired to a fake MQA attention module.""" + """Attn bridge wired to a fake MQA attention module.""" fake_attn = FakeMQAAttention(self.D_MODEL, self.D_HEAD, multi_query=True) blocks = adapter.component_mapping["blocks"] attn_bridge: JointQKVAttentionBridge = blocks.submodules["attn"] # type: ignore[assignment] @@ -470,7 +551,7 @@ def wired_attn_bridge(self, adapter: GPTBigCodeArchitectureAdapter) -> JointQKVA def _run_and_capture( self, attn_bridge: JointQKVAttentionBridge ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Register hooks on q/k/v hook_out, run forward, return captured tensors.""" + """Hook q/k/v hook_out, run forward, return captured tensors.""" captured: dict[str, torch.Tensor] = {} def _capture(name: str) -> Any: @@ -518,8 +599,6 @@ def test_attn_output_shape(self, wired_attn_bridge: JointQKVAttentionBridge) -> class TestGPTBigCodeFactoryRegistration: - """Verifies the factory maps GPTBigCodeForCausalLM to the correct adapter.""" - def test_factory_key_present(self) -> None: assert "GPTBigCodeForCausalLM" in SUPPORTED_ARCHITECTURES diff --git a/tests/unit/model_bridge/supported_architectures/test_internlm2_adapter.py b/tests/unit/model_bridge/supported_architectures/test_internlm2_adapter.py index b6d3d061d..b7ce4c82d 100644 --- a/tests/unit/model_bridge/supported_architectures/test_internlm2_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_internlm2_adapter.py @@ -1,10 +1,4 @@ -"""Unit tests for InternLM2ArchitectureAdapter. - -Tests cover (one class per phase): -- Phase A: Config attributes, weight conversion keys/types, split_wqkv numerics, - preprocess_weights behaviour -- Phase D: Factory registration -""" +"""Unit tests for InternLM2ArchitectureAdapter: cfg, weight conversions, split_wqkv, preprocess, factory.""" from types import SimpleNamespace from typing import Any @@ -23,6 +17,7 @@ EmbeddingBridge, GatedMLPBridge, JointQKVPositionEmbeddingsAttentionBridge, + LinearBridge, RMSNormalizationBridge, UnembeddingBridge, ) @@ -30,11 +25,6 @@ InternLM2ArchitectureAdapter, ) -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - - def _make_cfg( n_heads: int = 8, n_key_value_heads: int = 2, @@ -57,12 +47,12 @@ def _make_cfg( ) -@pytest.fixture +@pytest.fixture(scope="class") def cfg() -> TransformerBridgeConfig: return _make_cfg() -@pytest.fixture +@pytest.fixture(scope="class") def adapter(cfg: TransformerBridgeConfig) -> InternLM2ArchitectureAdapter: return InternLM2ArchitectureAdapter(cfg) @@ -89,10 +79,7 @@ def _fill_interleaved( d_model: int, kv_group_vals: list[tuple[float, float, float]], ) -> None: - """Fill wqkv weight with per-kv-group constants for layout verification. - - kv_group_vals: list of (q_val, k_val, v_val) per kv-head group. - """ + """Fill wqkv weight with per-kv-group (q,k,v) constants for layout verification.""" n_kv_groups = n_heads // n_kv_heads gs = n_kv_groups + 2 w = torch.zeros(n_kv_heads, gs, head_dim, d_model) @@ -103,13 +90,8 @@ def _fill_interleaved( wqkv_linear.weight = nn.Parameter(w.reshape((n_heads + 2 * n_kv_heads) * head_dim, d_model)) -# --------------------------------------------------------------------------- -# Phase A — Config attribute tests -# --------------------------------------------------------------------------- - - class TestInternLM2AdapterConfig: - """Adapter must set all required config attributes.""" + """Adapter sets all required config attributes.""" def test_normalization_type(self, adapter: InternLM2ArchitectureAdapter) -> None: assert adapter.cfg.normalization_type == "RMS" @@ -136,29 +118,24 @@ def test_n_key_value_heads_propagated(self, adapter: InternLM2ArchitectureAdapte assert adapter.cfg.n_key_value_heads == 2 def test_supports_fold_ln_false(self, adapter: InternLM2ArchitectureAdapter) -> None: - # Must be False: fold_ln silently skips attn when wqkv is fused in bridge state dict. + # fold_ln silently skips attn when wqkv is fused in bridge state dict. assert adapter.supports_fold_ln is False -# --------------------------------------------------------------------------- -# Phase A — Component mapping structure tests -# --------------------------------------------------------------------------- - - class TestInternLM2AdapterComponentMapping: - """component_mapping must have correct bridge types and InternLM2-specific names.""" + """component_mapping has correct bridge types and InternLM2-specific names.""" def test_embed_is_embedding_bridge(self, adapter: InternLM2ArchitectureAdapter) -> None: assert adapter.component_mapping is not None assert isinstance(adapter.component_mapping["embed"], EmbeddingBridge) def test_embed_name(self, adapter: InternLM2ArchitectureAdapter) -> None: - # InternLM2 uses tok_embeddings, not embed_tokens + # InternLM2 uses tok_embeddings, not embed_tokens. assert adapter.component_mapping is not None assert adapter.component_mapping["embed"].name == "model.tok_embeddings" def test_no_top_level_rotary_emb(self, adapter: InternLM2ArchitectureAdapter) -> None: - # Per-layer rotary injected via setup_component_testing, not top-level mapping + # Per-layer rotary, not a top-level component. assert adapter.component_mapping is not None assert "rotary_emb" not in adapter.component_mapping @@ -185,7 +162,7 @@ def test_unembed_is_unembedding_bridge(self, adapter: InternLM2ArchitectureAdapt assert isinstance(adapter.component_mapping["unembed"], UnembeddingBridge) def test_unembed_name(self, adapter: InternLM2ArchitectureAdapter) -> None: - # InternLM2 uses 'output', not 'lm_head' + # InternLM2 uses 'output', not 'lm_head'. assert adapter.component_mapping is not None assert adapter.component_mapping["unembed"].name == "output" @@ -195,7 +172,7 @@ def test_ln1_is_rms_normalization_bridge(self, adapter: InternLM2ArchitectureAda assert isinstance(blocks.submodules["ln1"], RMSNormalizationBridge) def test_ln1_name(self, adapter: InternLM2ArchitectureAdapter) -> None: - # InternLM2 uses attention_norm, not input_layernorm + # InternLM2 uses attention_norm, not input_layernorm. assert adapter.component_mapping is not None blocks = adapter.component_mapping["blocks"] assert blocks.submodules["ln1"].name == "attention_norm" @@ -206,7 +183,7 @@ def test_ln2_is_rms_normalization_bridge(self, adapter: InternLM2ArchitectureAda assert isinstance(blocks.submodules["ln2"], RMSNormalizationBridge) def test_ln2_name(self, adapter: InternLM2ArchitectureAdapter) -> None: - # InternLM2 uses ffn_norm, not post_attention_layernorm + # InternLM2 uses ffn_norm, not post_attention_layernorm. assert adapter.component_mapping is not None blocks = adapter.component_mapping["blocks"] assert blocks.submodules["ln2"].name == "ffn_norm" @@ -219,7 +196,7 @@ def test_attn_is_joint_qkv_position_embeddings_attention_bridge( assert isinstance(blocks.submodules["attn"], JointQKVPositionEmbeddingsAttentionBridge) def test_attn_name(self, adapter: InternLM2ArchitectureAdapter) -> None: - # InternLM2 uses 'attention', not 'self_attn' + # InternLM2 uses 'attention', not 'self_attn'. assert adapter.component_mapping is not None blocks = adapter.component_mapping["blocks"] assert blocks.submodules["attn"].name == "attention" @@ -240,37 +217,29 @@ def test_mlp_is_gated_mlp_bridge(self, adapter: InternLM2ArchitectureAdapter) -> assert isinstance(blocks.submodules["mlp"], GatedMLPBridge) def test_mlp_name(self, adapter: InternLM2ArchitectureAdapter) -> None: - # InternLM2 uses 'feed_forward', not 'mlp' + # InternLM2 uses 'feed_forward', not 'mlp'. assert adapter.component_mapping is not None blocks = adapter.component_mapping["blocks"] assert blocks.submodules["mlp"].name == "feed_forward" def test_mlp_gate_submodule_name(self, adapter: InternLM2ArchitectureAdapter) -> None: - # w1 = gate projection assert adapter.component_mapping is not None blocks = adapter.component_mapping["blocks"] assert blocks.submodules["mlp"].submodules["gate"].name == "w1" def test_mlp_in_submodule_name(self, adapter: InternLM2ArchitectureAdapter) -> None: - # w3 = up/in projection assert adapter.component_mapping is not None blocks = adapter.component_mapping["blocks"] assert blocks.submodules["mlp"].submodules["in"].name == "w3" def test_mlp_out_submodule_name(self, adapter: InternLM2ArchitectureAdapter) -> None: - # w2 = down/out projection assert adapter.component_mapping is not None blocks = adapter.component_mapping["blocks"] assert blocks.submodules["mlp"].submodules["out"].name == "w2" -# --------------------------------------------------------------------------- -# Phase A — Weight conversion key and type tests -# --------------------------------------------------------------------------- - - class TestInternLM2AdapterWeightConversions: - """weight_processing_conversions must have correct keys, types, and rearrange patterns.""" + """weight_processing_conversions has correct keys, types, and rearrange patterns.""" def test_q_weight_key_present(self, adapter: InternLM2ArchitectureAdapter) -> None: assert adapter.weight_processing_conversions is not None @@ -289,7 +258,6 @@ def test_o_weight_key_present(self, adapter: InternLM2ArchitectureAdapter) -> No assert "blocks.{i}.attn.o.weight" in adapter.weight_processing_conversions def test_exactly_four_conversion_keys(self, adapter: InternLM2ArchitectureAdapter) -> None: - # No bias entries for the bias=False shipped config assert adapter.weight_processing_conversions is not None assert len(adapter.weight_processing_conversions) == 4 @@ -349,18 +317,13 @@ def test_o_rearrange_n_equals_n_heads(self, adapter: InternLM2ArchitectureAdapte assert conv.tensor_conversion.axes_lengths["n"] == adapter.cfg.n_heads def test_no_source_key_on_q(self, adapter: InternLM2ArchitectureAdapter) -> None: - # preprocess_weights writes split keys; no cross-key lookup needed at rearrange time + # preprocess_weights writes split keys; no cross-key lookup needed at rearrange time. assert adapter.weight_processing_conversions is not None conv = adapter.weight_processing_conversions["blocks.{i}.attn.q.weight"] assert isinstance(conv, ParamProcessingConversion) assert conv.source_key is None -# --------------------------------------------------------------------------- -# Phase A — _split_internlm2_wqkv numerical tests -# --------------------------------------------------------------------------- - - class TestInternLM2SplitWqkv: """Numerical correctness of the interleaved GQA split function.""" @@ -384,7 +347,6 @@ def test_returns_three_linears(self) -> None: assert isinstance(v, nn.Linear) def test_gqa_shapes(self) -> None: - # n_heads=8, n_kv_heads=2, head_dim=4, d_model=32 adapter = self._adapter(n_heads=8, n_kv_heads=2, d_model=32) attn = _make_attn_component(8, 2, 4, 32) q, k, v = adapter._split_internlm2_wqkv(attn) @@ -393,7 +355,7 @@ def test_gqa_shapes(self) -> None: assert v.weight.shape == (2 * 4, 32) def test_mha_shapes(self) -> None: - # MHA: n_heads == n_kv_heads → gs=3 (standard [Q|K|V]) + # MHA: n_heads == n_kv_heads → gs=3 (standard [Q|K|V]). adapter = self._adapter(n_heads=4, n_kv_heads=4, d_model=32) attn = _make_attn_component(4, 4, 8, 32) q, k, v = adapter._split_internlm2_wqkv(attn) @@ -402,11 +364,9 @@ def test_mha_shapes(self) -> None: assert v.weight.shape == (4 * 8, 32) def test_interleaved_layout_correctness(self) -> None: - # n_heads=4, n_kv_heads=2, head_dim=4, d_model=16 → gs=4 (2 q-groups + k + v) n_heads, n_kv_heads, head_dim, d_model = 4, 2, 4, 16 adapter = self._adapter(n_heads=n_heads, n_kv_heads=n_kv_heads, d_model=d_model) attn = _make_attn_component(n_heads, n_kv_heads, head_dim, d_model) - # kv-group 0: Q=1.0, K=2.0, V=3.0; kv-group 1: Q=4.0, K=5.0, V=6.0 _fill_interleaved( attn.wqkv, n_heads, @@ -417,17 +377,13 @@ def test_interleaved_layout_correctness(self) -> None: ) q, k, v = adapter._split_internlm2_wqkv(attn) - n_kv_groups = n_heads // n_kv_heads # 2 - # Q: rows 0..n_kv_groups*head_dim-1 come from kv-group 0 Q slots (1.0), - # rows n_kv_groups*head_dim..n_heads*head_dim-1 from kv-group 1 Q slots (4.0) - assert torch.all(q.weight[: n_kv_groups * head_dim] == 1.0), "Q group-0 rows should be 1.0" - assert torch.all(q.weight[n_kv_groups * head_dim :] == 4.0), "Q group-1 rows should be 4.0" - # K: row 0..head_dim-1 = kv-group 0 K (2.0), head_dim..2*head_dim-1 = kv-group 1 K (5.0) - assert torch.all(k.weight[:head_dim] == 2.0), "K group-0 rows should be 2.0" - assert torch.all(k.weight[head_dim:] == 5.0), "K group-1 rows should be 5.0" - # V analogous - assert torch.all(v.weight[:head_dim] == 3.0), "V group-0 rows should be 3.0" - assert torch.all(v.weight[head_dim:] == 6.0), "V group-1 rows should be 6.0" + n_kv_groups = n_heads // n_kv_heads + assert torch.all(q.weight[: n_kv_groups * head_dim] == 1.0) + assert torch.all(q.weight[n_kv_groups * head_dim :] == 4.0) + assert torch.all(k.weight[:head_dim] == 2.0) + assert torch.all(k.weight[head_dim:] == 5.0) + assert torch.all(v.weight[:head_dim] == 3.0) + assert torch.all(v.weight[head_dim:] == 6.0) def test_no_bias(self) -> None: adapter = self._adapter() @@ -450,19 +406,17 @@ def test_with_bias_shapes(self) -> None: assert v.bias.shape == (n_kv_heads * head_dim,) def test_with_bias_interleaved_values(self) -> None: - # Verify bias values follow the same interleaved layout as weights n_heads, n_kv_heads, head_dim, d_model = 4, 2, 4, 16 adapter = self._adapter(n_heads=n_heads, n_kv_heads=n_kv_heads, d_model=d_model) attn = _make_attn_component(n_heads, n_kv_heads, head_dim, d_model, has_bias=True) n_kv_groups = n_heads // n_kv_heads gs = n_kv_groups + 2 - # Bias: interleaved [q0_vals, q1_vals, k_val, v_val] per kv-head group b = torch.zeros((n_heads + 2 * n_kv_heads) * head_dim) b_grouped = b.reshape(n_kv_heads, gs, head_dim) - b_grouped[0, :n_kv_groups, :] = 1.0 # kv-group 0 Q bias - b_grouped[0, n_kv_groups, :] = 2.0 # kv-group 0 K bias - b_grouped[0, n_kv_groups + 1, :] = 3.0 # kv-group 0 V bias - b_grouped[1, :n_kv_groups, :] = 4.0 # kv-group 1 Q bias + b_grouped[0, :n_kv_groups, :] = 1.0 + b_grouped[0, n_kv_groups, :] = 2.0 + b_grouped[0, n_kv_groups + 1, :] = 3.0 + b_grouped[1, :n_kv_groups, :] = 4.0 b_grouped[1, n_kv_groups, :] = 5.0 b_grouped[1, n_kv_groups + 1, :] = 6.0 attn.wqkv.bias = nn.Parameter(b_grouped.reshape(-1)) @@ -486,13 +440,8 @@ def test_forward_output_shapes(self) -> None: assert v(x).shape == (2, 5, n_kv_heads * head_dim) -# --------------------------------------------------------------------------- -# Phase A — preprocess_weights tests -# --------------------------------------------------------------------------- - - class TestInternLM2PreprocessWeights: - """preprocess_weights must split fused wqkv and fold layer norms.""" + """preprocess_weights splits fused wqkv and folds layer norms.""" def _make_state_dict_with_fused_qkv( self, @@ -504,7 +453,7 @@ def _make_state_dict_with_fused_qkv( ln1_scale: float = 1.0, qkv_val: float = 1.0, ) -> dict[str, torch.Tensor]: - """Build a bridge-format state dict with fused qkv.weight for each layer.""" + """Bridge-format state dict with fused qkv.weight for each layer.""" n_heads = adapter.cfg.n_heads n_kv_groups = n_heads // n_kv_heads gs = n_kv_groups + 2 @@ -528,7 +477,7 @@ def test_fused_key_removed_and_split_keys_written(self) -> None: result = adapter.preprocess_weights(sd) - assert "blocks.0.attn.qkv.weight" not in result, "fused qkv key must be deleted" + assert "blocks.0.attn.qkv.weight" not in result assert "blocks.0.attn.q.weight" in result assert "blocks.0.attn.k.weight" in result assert "blocks.0.attn.v.weight" in result @@ -546,7 +495,7 @@ def test_split_q_shape(self) -> None: assert result["blocks.0.attn.v.weight"].shape == (2 * 8, 64) def test_ln1_fold_applied_to_q(self) -> None: - """After folding ln1 scale=2.0 into qkv (all 1.0), q/k/v weights should be 2.0.""" + """ln1 scale=2.0 folded into qkv=1.0 → q/k/v weights become 2.0.""" adapter = InternLM2ArchitectureAdapter( _make_cfg(n_heads=8, n_key_value_heads=2, d_model=64) ) @@ -575,7 +524,6 @@ def test_ln2_fold_applied_to_mlp_gate(self) -> None: adapter._fold_ln_requested = True n_kv_heads, head_dim, d_model = 2, 8, 64 sd = self._make_state_dict_with_fused_qkv(adapter, n_kv_heads, head_dim, d_model, 2) - # Override ln2 with scale=3.0 sd["blocks.0.ln2.weight"] = torch.full((d_model,), 3.0) result = adapter.preprocess_weights(sd) assert torch.all(result["blocks.0.mlp.gate.weight"] == 3.0) @@ -609,7 +557,6 @@ def test_no_fold_when_not_requested(self) -> None: adapter, n_kv_heads, head_dim, d_model, 2, ln1_scale=5.0 ) result = adapter.preprocess_weights(sd) - # Fused key must still be present; no splitting or scaling assert "blocks.0.attn.qkv.weight" in result assert "blocks.0.attn.q.weight" not in result @@ -618,16 +565,14 @@ def test_dtype_preserved(self) -> None: adapter._fold_ln_requested = True n_kv_heads, head_dim, d_model = 2, 8, 64 sd = self._make_state_dict_with_fused_qkv(adapter, n_kv_heads, head_dim, d_model, 1) - # Cast to bfloat16 sd = {k: v.to(torch.bfloat16) for k, v in sd.items()} result = adapter.preprocess_weights(sd) assert result["blocks.0.attn.q.weight"].dtype == torch.bfloat16 def test_bias_split_when_present(self) -> None: - """config.bias=True: fused bias must be split into q/k/v bias keys.""" - # Use consistent d_model/n_heads so head_dim = d_model // n_heads = 64 // 4 = 16 + """Fused bias must be split into q/k/v bias keys when config.bias=True.""" n_heads, n_kv_heads, d_model = 4, 2, 64 - head_dim = d_model // n_heads # 16 + head_dim = d_model // n_heads adapter = InternLM2ArchitectureAdapter( _make_cfg(n_heads=n_heads, n_key_value_heads=n_kv_heads, d_model=d_model) ) @@ -653,7 +598,7 @@ def test_bias_split_when_present(self) -> None: assert result["blocks.0.attn.v.bias"].shape == (n_kv_heads * head_dim,) def test_all_layers_processed(self) -> None: - """Verify that all n_layers are processed, not just layer 0.""" + """All n_layers are processed, not just layer 0.""" adapter = InternLM2ArchitectureAdapter(_make_cfg(n_layers=3)) adapter._fold_ln_requested = True n_kv_heads, head_dim, d_model = 2, 8, 64 @@ -664,13 +609,8 @@ def test_all_layers_processed(self) -> None: assert f"blocks.{i}.attn.q.weight" in result -# --------------------------------------------------------------------------- -# Phase D — Factory registration (will pass after Phase D implemented) -# --------------------------------------------------------------------------- - - class TestInternLM2FactoryRegistration: - """Factory must map InternLM2ForCausalLM to InternLM2ArchitectureAdapter.""" + """Factory maps InternLM2ForCausalLM to InternLM2ArchitectureAdapter.""" def test_factory_returns_internlm2_adapter(self) -> None: from transformer_lens.factories.architecture_adapter_factory import ( @@ -689,3 +629,119 @@ def test_factory_key_in_supported_architectures(self) -> None: ) assert "InternLM2ForCausalLM" in SUPPORTED_ARCHITECTURES + + def test_factory_maps_to_correct_class(self) -> None: + from transformer_lens.factories.architecture_adapter_factory import ( + SUPPORTED_ARCHITECTURES, + ) + + assert SUPPORTED_ARCHITECTURES["InternLM2ForCausalLM"] is InternLM2ArchitectureAdapter + + +class TestInternLM2ComponentMappingPresence: + """Component slots exist (deletion guard).""" + + def test_required_top_level_keys(self, adapter: InternLM2ArchitectureAdapter) -> None: + # No top-level rotary_emb (per-layer instead). + expected = {"embed", "blocks", "ln_final", "unembed"} + assert set(adapter.component_mapping.keys()) == expected + + def test_block_has_required_submodules(self, adapter: InternLM2ArchitectureAdapter) -> None: + blocks = adapter.component_mapping["blocks"] + for name in ("ln1", "ln2", "attn", "mlp"): + assert name in blocks.submodules, f"BlockBridge missing submodule '{name}'" + + +class TestInternLM2BlockLinearBridges: + """All attn/mlp submodule projections are LinearBridge instances.""" + + @pytest.fixture(scope="class") + def blocks(self, adapter: InternLM2ArchitectureAdapter) -> BlockBridge: + return adapter.component_mapping["blocks"] + + def test_attn_qkv_is_linear_bridge(self, blocks: BlockBridge) -> None: + attn = blocks.submodules["attn"] + assert isinstance(attn.submodules["qkv"], LinearBridge) + + def test_attn_o_is_linear_bridge(self, blocks: BlockBridge) -> None: + attn = blocks.submodules["attn"] + assert isinstance(attn.submodules["o"], LinearBridge) + + def test_attn_has_fused_qkv_path(self, blocks: BlockBridge) -> None: + # Fused HF projection at name="wqkv"; JointQKV* base adds virtual q/k/v splits on top. + attn = blocks.submodules["attn"] + assert "qkv" in attn.submodules + assert attn.submodules["qkv"].name == "wqkv" + + def test_mlp_gate_is_linear_bridge(self, blocks: BlockBridge) -> None: + mlp = blocks.submodules["mlp"] + assert isinstance(mlp.submodules["gate"], LinearBridge) + + def test_mlp_in_is_linear_bridge(self, blocks: BlockBridge) -> None: + mlp = blocks.submodules["mlp"] + assert isinstance(mlp.submodules["in"], LinearBridge) + + def test_mlp_out_is_linear_bridge(self, blocks: BlockBridge) -> None: + mlp = blocks.submodules["mlp"] + assert isinstance(mlp.submodules["out"], LinearBridge) + + +class TestInternLM2GQASupport: + """GQA propagation through weight_processing_conversions.""" + + def test_no_gqa_falls_back_to_n_heads(self) -> None: + cfg = _make_cfg() + cfg.n_key_value_heads = None + adapter = InternLM2ArchitectureAdapter(cfg) + for slot in ("k", "v"): + conv = adapter.weight_processing_conversions[f"blocks.{{i}}.attn.{slot}.weight"] + assert conv.tensor_conversion.axes_lengths["n"] == adapter.cfg.n_heads + + def test_gqa_propagates_to_kv_conversions(self) -> None: + cfg = _make_cfg(n_heads=8, n_key_value_heads=2) + adapter = InternLM2ArchitectureAdapter(cfg) + assert adapter.cfg.n_key_value_heads == 2 + for slot in ("k", "v"): + conv = adapter.weight_processing_conversions[f"blocks.{{i}}.attn.{slot}.weight"] + assert conv.tensor_conversion.axes_lengths["n"] == 2 + + def test_gqa_does_not_change_q_or_o_conversions(self) -> None: + cfg = _make_cfg(n_heads=8, n_key_value_heads=2) + adapter = InternLM2ArchitectureAdapter(cfg) + q_conv = adapter.weight_processing_conversions["blocks.{i}.attn.q.weight"] + o_conv = adapter.weight_processing_conversions["blocks.{i}.attn.o.weight"] + assert q_conv.tensor_conversion.axes_lengths["n"] == 8 + assert o_conv.tensor_conversion.axes_lengths["n"] == 8 + + +class TestInternLM2ArchitectureGuards: + """Guards against drift toward neighbouring adapter patterns.""" + + def test_no_norm_offset_conversions(self, adapter: InternLM2ArchitectureAdapter) -> None: + # InternLM2 is not Gemma — no +1 norm offset entries. + for key in adapter.weight_processing_conversions: + assert "ln1" not in key + assert "ln2" not in key + assert "ln_final" not in key + + def test_no_mlp_weight_conversions(self, adapter: InternLM2ArchitectureAdapter) -> None: + for key in adapter.weight_processing_conversions: + assert "mlp" not in key + + def test_no_top_level_rotary_emb(self, adapter: InternLM2ArchitectureAdapter) -> None: + assert "rotary_emb" not in adapter.component_mapping + + def test_block_uses_block_bridge_not_parallel( + self, adapter: InternLM2ArchitectureAdapter + ) -> None: + # Sequential, not parallel-attn-mlp — guard against borrowing Cohere's pattern. + from transformer_lens.model_bridge.generalized_components import ParallelBlockBridge + + blocks = adapter.component_mapping["blocks"] + assert not isinstance(blocks, ParallelBlockBridge) + assert isinstance(blocks, BlockBridge) + + def test_has_both_ln1_and_ln2(self, adapter: InternLM2ArchitectureAdapter) -> None: + blocks = adapter.component_mapping["blocks"] + assert "ln1" in blocks.submodules + assert "ln2" in blocks.submodules diff --git a/tests/unit/model_bridge/supported_architectures/test_llava_adapter.py b/tests/unit/model_bridge/supported_architectures/test_llava_adapter.py new file mode 100644 index 000000000..c797d416c --- /dev/null +++ b/tests/unit/model_bridge/supported_architectures/test_llava_adapter.py @@ -0,0 +1,328 @@ +"""Unit tests for LLava architecture adapter and configuration.""" + +from types import SimpleNamespace + +import pytest + +from transformer_lens.config.TransformerBridgeConfig import TransformerBridgeConfig +from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion +from transformer_lens.conversion_utils.param_processing_conversion import ( + ParamProcessingConversion, +) +from transformer_lens.factories.architecture_adapter_factory import ( + SUPPORTED_ARCHITECTURES, + ArchitectureAdapterFactory, +) +from transformer_lens.model_bridge.generalized_components import ( + BlockBridge, + CLIPVisionEncoderBridge, + EmbeddingBridge, + GatedMLPBridge, + LinearBridge, + RMSNormalizationBridge, + RotaryEmbeddingBridge, + SiglipVisionEncoderBridge, + UnembeddingBridge, + VisionProjectionBridge, +) +from transformer_lens.model_bridge.generalized_components.position_embeddings_attention import ( + PositionEmbeddingsAttentionBridge, +) +from transformer_lens.model_bridge.supported_architectures.llava import ( + LlavaArchitectureAdapter, +) + + +def _make_llava_cfg(vision_model_type: str = "clip_vision_model", **overrides): + """TransformerBridgeConfig for LLava 1.5 7B.""" + defaults = dict( + d_model=4096, + d_head=128, + n_heads=32, + n_layers=32, + n_ctx=4096, + d_vocab=32064, + architecture="LlavaForConditionalGeneration", + ) + defaults.update(overrides) + cfg = TransformerBridgeConfig(**defaults) + # vision_config matters for vision-bridge selection (CLIP vs SigLIP). + cfg.vision_config = SimpleNamespace( + model_type=vision_model_type, + hidden_size=1024, + num_hidden_layers=24, + num_attention_heads=16, + ) + return cfg + + +class TestLlavaRegistration: + """LlavaArchitectureAdapter is registered.""" + + def test_architecture_in_supported_architectures(self): + assert "LlavaForConditionalGeneration" in SUPPORTED_ARCHITECTURES + + def test_architecture_maps_to_correct_adapter(self): + assert SUPPORTED_ARCHITECTURES["LlavaForConditionalGeneration"] is LlavaArchitectureAdapter + + def test_factory_selects_correct_adapter(self): + cfg = _make_llava_cfg() + adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg) + assert isinstance(adapter, LlavaArchitectureAdapter) + + +class TestLlavaAdapterConfig: + """LlavaArchitectureAdapter configuration.""" + + @pytest.fixture(scope="class") + def adapter(self): + cfg = _make_llava_cfg() + return LlavaArchitectureAdapter(cfg) + + def test_is_multimodal(self, adapter): + assert adapter.cfg.is_multimodal is True + + def test_gated_mlp(self, adapter): + assert adapter.cfg.gated_mlp is True + + def test_uses_rms_norm(self, adapter): + assert adapter.cfg.uses_rms_norm is True + + def test_normalization_type(self, adapter): + assert adapter.cfg.normalization_type == "RMS" + + def test_positional_embedding_type(self, adapter): + assert adapter.cfg.positional_embedding_type == "rotary" + + def test_attn_implementation(self, adapter): + assert adapter.cfg.attn_implementation == "eager" + + def test_final_rms_is_true(self, adapter): + assert adapter.cfg.final_rms is True + + def test_attn_only_is_false(self, adapter): + assert adapter.cfg.attn_only is False + + def test_eps_attr(self, adapter): + assert adapter.cfg.eps_attr == "variance_epsilon" + + def test_vision_config_extracted(self, adapter): + assert adapter.cfg.vision_hidden_size == 1024 + assert adapter.cfg.vision_num_layers == 24 + assert adapter.cfg.vision_num_heads == 16 + + +class TestLlavaComponentMappingPresence: + """Required component slots exist.""" + + @pytest.fixture(scope="class") + def adapter(self): + return LlavaArchitectureAdapter(_make_llava_cfg()) + + def test_has_vision_encoder_component(self, adapter): + assert "vision_encoder" in adapter.component_mapping + + def test_has_vision_projector_component(self, adapter): + assert "vision_projector" in adapter.component_mapping + + def test_has_language_model_components(self, adapter): + for name in ("embed", "rotary_emb", "blocks", "ln_final", "unembed"): + assert name in adapter.component_mapping + + +class TestLlavaComponentMappingPaths: + """HF module path for each component slot.""" + + @pytest.fixture(scope="class") + def adapter(self): + return LlavaArchitectureAdapter(_make_llava_cfg()) + + def test_vision_encoder_path(self, adapter): + assert adapter.component_mapping["vision_encoder"].name == "model.vision_tower" + + def test_vision_projector_path(self, adapter): + assert adapter.component_mapping["vision_projector"].name == "model.multi_modal_projector" + + def test_embed_path(self, adapter): + assert adapter.component_mapping["embed"].name == "model.language_model.embed_tokens" + + def test_rotary_emb_path(self, adapter): + assert adapter.component_mapping["rotary_emb"].name == "model.language_model.rotary_emb" + + def test_blocks_path(self, adapter): + assert adapter.component_mapping["blocks"].name == "model.language_model.layers" + + def test_ln_final_path(self, adapter): + assert adapter.component_mapping["ln_final"].name == "model.language_model.norm" + + def test_unembed_path(self, adapter): + assert adapter.component_mapping["unembed"].name == "lm_head" + + +class TestLlavaComponentTypes: + """Component bridge classes — guards against silent type substitution.""" + + @pytest.fixture(scope="class") + def adapter(self): + return LlavaArchitectureAdapter(_make_llava_cfg()) + + def test_vision_encoder_is_clip_bridge(self, adapter): + # vision_model_type='clip_vision_model' must select CLIP, not SigLIP. + assert isinstance(adapter.component_mapping["vision_encoder"], CLIPVisionEncoderBridge) + + def test_vision_projector_type(self, adapter): + assert isinstance(adapter.component_mapping["vision_projector"], VisionProjectionBridge) + + def test_embed_type(self, adapter): + assert isinstance(adapter.component_mapping["embed"], EmbeddingBridge) + + def test_rotary_emb_type(self, adapter): + assert isinstance(adapter.component_mapping["rotary_emb"], RotaryEmbeddingBridge) + + def test_blocks_type(self, adapter): + assert isinstance(adapter.component_mapping["blocks"], BlockBridge) + + def test_ln_final_type(self, adapter): + assert isinstance(adapter.component_mapping["ln_final"], RMSNormalizationBridge) + + def test_unembed_type(self, adapter): + assert isinstance(adapter.component_mapping["unembed"], UnembeddingBridge) + + +class TestLlavaSiglipVisionVariant: + """SigLIP vision-tower variants select the SigLIP bridge.""" + + def test_siglip_selects_siglip_bridge(self): + adapter = LlavaArchitectureAdapter(_make_llava_cfg(vision_model_type="siglip_vision_model")) + assert isinstance( + adapter.component_mapping["vision_encoder"], SiglipVisionEncoderBridge + ) + assert not isinstance( + adapter.component_mapping["vision_encoder"], CLIPVisionEncoderBridge + ) + + def test_siglip_short_alias_selects_siglip_bridge(self): + adapter = LlavaArchitectureAdapter(_make_llava_cfg(vision_model_type="siglip")) + assert isinstance( + adapter.component_mapping["vision_encoder"], SiglipVisionEncoderBridge + ) + + +class TestLlavaBlockSubmodules: + """Language-model BlockBridge wires LLaMA-pattern submodules.""" + + @pytest.fixture(scope="class") + def blocks(self): + adapter = LlavaArchitectureAdapter(_make_llava_cfg()) + return adapter.component_mapping["blocks"] + + def test_block_has_required_submodules(self, blocks): + for name in ("ln1", "ln2", "attn", "mlp"): + assert name in blocks.submodules, f"BlockBridge missing submodule '{name}'" + + def test_ln1_is_rms_norm(self, blocks): + ln1 = blocks.submodules["ln1"] + assert isinstance(ln1, RMSNormalizationBridge) + assert ln1.name == "input_layernorm" + + def test_ln2_is_rms_norm(self, blocks): + ln2 = blocks.submodules["ln2"] + assert isinstance(ln2, RMSNormalizationBridge) + assert ln2.name == "post_attention_layernorm" + + def test_attn_is_position_embeddings_attention(self, blocks): + """LLaMA-style RoPE attention requires both mask and position embeddings.""" + attn = blocks.submodules["attn"] + assert isinstance(attn, PositionEmbeddingsAttentionBridge) + assert attn.name == "self_attn" + assert attn.requires_attention_mask is True + assert attn.requires_position_embeddings is True + + def test_attn_qkv_submodule_paths(self, blocks): + attn = blocks.submodules["attn"] + for sub_name, expected_path in ( + ("q", "q_proj"), + ("k", "k_proj"), + ("v", "v_proj"), + ("o", "o_proj"), + ): + sub = attn.submodules[sub_name] + assert isinstance(sub, LinearBridge) + assert sub.name == expected_path + + def test_mlp_is_gated(self, blocks): + mlp = blocks.submodules["mlp"] + assert isinstance(mlp, GatedMLPBridge) + assert mlp.name == "mlp" + + def test_mlp_submodule_paths(self, blocks): + mlp = blocks.submodules["mlp"] + for sub_name, expected_path in ( + ("gate", "gate_proj"), + ("in", "up_proj"), + ("out", "down_proj"), + ): + sub = mlp.submodules[sub_name] + assert isinstance(sub, LinearBridge) + assert sub.name == expected_path + + +class TestLlavaGQASupport: + """GQA: n_key_value_heads affects K/V conversions only.""" + + def test_no_gqa_when_not_set(self): + """Unset n_key_value_heads falls back to n_heads.""" + cfg = _make_llava_cfg() + adapter = LlavaArchitectureAdapter(cfg) + kv_conv = adapter.weight_processing_conversions["blocks.{i}.attn.k.weight"] + assert kv_conv.tensor_conversion.axes_lengths["n"] == 32 + + def test_gqa_propagates_to_kv_conversions(self): + cfg = _make_llava_cfg(n_key_value_heads=8) + adapter = LlavaArchitectureAdapter(cfg) + assert adapter.cfg.n_key_value_heads == 8 + for slot in ("k", "v"): + conv = adapter.weight_processing_conversions[f"blocks.{{i}}.attn.{slot}.weight"] + assert conv.tensor_conversion.axes_lengths["n"] == 8 + + def test_gqa_does_not_change_q_or_o_conversions(self): + """Q and O always follow n_heads; GQA only affects K/V.""" + cfg = _make_llava_cfg(n_key_value_heads=8) + adapter = LlavaArchitectureAdapter(cfg) + q_conv = adapter.weight_processing_conversions["blocks.{i}.attn.q.weight"] + o_conv = adapter.weight_processing_conversions["blocks.{i}.attn.o.weight"] + assert q_conv.tensor_conversion.axes_lengths["n"] == 32 + assert o_conv.tensor_conversion.axes_lengths["n"] == 32 + + +class TestLlavaWeightProcessingConversions: + """Rearrange semantics for QKVO conversion entries.""" + + @pytest.fixture(scope="class") + def adapter(self): + return LlavaArchitectureAdapter(_make_llava_cfg()) + + def test_all_qkvo_keys_exist(self, adapter): + for slot in ("q", "k", "v", "o"): + key = f"blocks.{{i}}.attn.{slot}.weight" + assert key in adapter.weight_processing_conversions + + def test_qkv_conversions_use_split_heads_pattern(self, adapter): + """'(n h) m -> n m h' splits (n_heads * d_head, d_model) into the bridge's (n, d_model, d_head).""" + for slot in ("q", "k", "v"): + conv = adapter.weight_processing_conversions[f"blocks.{{i}}.attn.{slot}.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.pattern == "(n h) m -> n m h" + + def test_o_conversion_uses_merge_heads_pattern(self, adapter): + """'m (n h) -> n h m' splits the trailing (n_heads*d_head) dim and moves n to the front.""" + conv = adapter.weight_processing_conversions["blocks.{i}.attn.o.weight"] + assert conv.tensor_conversion.pattern == "m (n h) -> n h m" + + def test_no_norm_offset_conversions(self, adapter): + """LLaMA-style RMSNorm — no +1 offset like Gemma.""" + for key in adapter.weight_processing_conversions: + assert "ln1" not in key + assert "ln2" not in key + assert "ln_final" not in key diff --git a/tests/unit/model_bridge/supported_architectures/test_mpt_adapter.py b/tests/unit/model_bridge/supported_architectures/test_mpt_adapter.py index 695a6e21b..acffe26a1 100644 --- a/tests/unit/model_bridge/supported_architectures/test_mpt_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_mpt_adapter.py @@ -1,15 +1,4 @@ -"""Unit tests for MPTArchitectureAdapter — Phase A (config + weight conversions), -Phase B-1 (component mapping + QKV split), and Phase D (factory registration). - -Tests cover: -- Config attribute validation (all required attributes set correctly) -- Weight conversion keys (four standard QKVO keys with .weight suffix) -- LayerNorm with bias=None wraps without error (MptBlock sets norm.bias = None) -- Component mapping keys (embed/blocks/ln_final/unembed; no pos_embed/rotary_emb) -- Block/attn/mlp submodule keys -- _split_mpt_qkv: output shapes and round-trip correctness -- Factory resolves MPTForCausalLM -> MPTArchitectureAdapter (no download) -""" +"""Unit tests for MPTArchitectureAdapter.""" import pytest import torch @@ -33,10 +22,7 @@ def _make_cfg( d_vocab: int = 256, n_ctx: int = 128, ) -> TransformerBridgeConfig: - """Return a minimal TransformerBridgeConfig for MPT adapter tests. - - Uses tiny dimensions — no HF Hub download required. - """ + """Minimal TransformerBridgeConfig for MPT adapter tests (no HF Hub download).""" return TransformerBridgeConfig( d_model=d_model, d_head=d_model // n_heads, @@ -50,12 +36,12 @@ def _make_cfg( ) -@pytest.fixture +@pytest.fixture(scope="class") def cfg() -> TransformerBridgeConfig: return _make_cfg() -@pytest.fixture +@pytest.fixture(scope="class") def adapter(cfg: TransformerBridgeConfig) -> MPTArchitectureAdapter: return MPTArchitectureAdapter(cfg) @@ -66,8 +52,6 @@ def adapter(cfg: TransformerBridgeConfig) -> MPTArchitectureAdapter: class TestMPTAdapterConfig: - """Verify all required config attributes are set correctly.""" - def test_normalization_type_is_ln(self, adapter: MPTArchitectureAdapter) -> None: assert adapter.cfg.normalization_type == "LN" @@ -93,8 +77,6 @@ def test_default_prepend_bos_is_false(self, adapter: MPTArchitectureAdapter) -> class TestMPTAdapterWeightConversions: - """Verify weight_processing_conversions has exactly the four QKVO keys.""" - def test_q_weight_key_present(self, adapter: MPTArchitectureAdapter) -> None: assert "blocks.{i}.attn.q.weight" in adapter.weight_processing_conversions @@ -108,12 +90,12 @@ def test_o_weight_key_present(self, adapter: MPTArchitectureAdapter) -> None: assert "blocks.{i}.attn.o.weight" in adapter.weight_processing_conversions def test_exactly_four_conversion_keys(self, adapter: MPTArchitectureAdapter) -> None: - # No MLP conversions — up_proj/down_proj use standard [out, in] layout. + # No MLP conversions: up_proj/down_proj use standard [out, in] layout. assert len(adapter.weight_processing_conversions) == 4 def test_no_mlp_conversion_keys(self, adapter: MPTArchitectureAdapter) -> None: keys = adapter.weight_processing_conversions - assert not any("mlp" in k for k in keys), "MLP weights need no special conversion" + assert not any("mlp" in k for k in keys) # --------------------------------------------------------------------------- @@ -122,21 +104,16 @@ def test_no_mlp_conversion_keys(self, adapter: MPTArchitectureAdapter) -> None: class TestMPTLayerNormBiasNone: - """Verify NormalizationBridge handles MPT's bias=None LayerNorm correctly.""" + """NormalizationBridge handles MPT's bias=None LayerNorm.""" def test_layernorm_bias_none_wraps_without_error(self, cfg: TransformerBridgeConfig) -> None: - """NormalizationBridge must accept and forward through a bias=None LayerNorm. - - MptBlock.__init__ explicitly sets norm_1.bias = None for backward compatibility - with Hub weights. This test front-loads any surprise from that pattern. - """ + """MptBlock.__init__ explicitly sets norm_1.bias = None for Hub-weight compat.""" from transformer_lens.model_bridge.generalized_components import ( NormalizationBridge, ) - # Replicate what MptBlock does: LayerNorm then strip bias ln = nn.LayerNorm(cfg.d_model, eps=1e-5) - ln.bias = None # exactly as MptBlock.__init__ does + ln.bias = None bridge = NormalizationBridge(name="norm_1", config=cfg) bridge.set_original_component(ln) @@ -145,9 +122,9 @@ def test_layernorm_bias_none_wraps_without_error(self, cfg: TransformerBridgeCon with torch.no_grad(): out = bridge(x) - assert out.shape == x.shape, "Output shape must match input shape" - assert not torch.isnan(out).any(), "Output must not contain NaN" - assert not torch.isinf(out).any(), "Output must not contain Inf" + assert out.shape == x.shape + assert not torch.isnan(out).any() + assert not torch.isinf(out).any() # --------------------------------------------------------------------------- @@ -156,14 +133,12 @@ def test_layernorm_bias_none_wraps_without_error(self, cfg: TransformerBridgeCon class TestMPTComponentMappingKeys: - """Verify top-level and nested component mapping keys are correct.""" - def test_top_level_keys_present(self, adapter: MPTArchitectureAdapter) -> None: keys = set(adapter.component_mapping.keys()) assert {"embed", "blocks", "ln_final", "unembed"} <= keys def test_no_pos_embed_key(self, adapter: MPTArchitectureAdapter) -> None: - # ALiBi has no learnable positional embedding module. + # ALiBi: no learnable positional embedding module. assert "pos_embed" not in adapter.component_mapping def test_no_rotary_emb_key(self, adapter: MPTArchitectureAdapter) -> None: @@ -177,7 +152,7 @@ def test_block_submodule_keys(self, adapter: MPTArchitectureAdapter) -> None: def test_attn_submodule_keys(self, adapter: MPTArchitectureAdapter) -> None: attn = adapter.component_mapping["blocks"].submodules["attn"] subkeys = set(attn.submodules.keys()) - # qkv and o are the projection submodules; q/k/v are created during split + # qkv/o are projection submodules; q/k/v are created during split. assert {"qkv", "o"} <= subkeys def test_mlp_submodule_keys(self, adapter: MPTArchitectureAdapter) -> None: @@ -192,15 +167,15 @@ def test_mlp_submodule_keys(self, adapter: MPTArchitectureAdapter) -> None: class TestMPTSplitQKV: - """Verify _split_mpt_qkv correctly decomposes Wqkv [3*d_model, d_model].""" + """_split_mpt_qkv decomposes Wqkv [3*d_model, d_model].""" def _make_fake_attn_component(self, d_model: int) -> object: - """Return a stub object with a Wqkv Linear attribute (no bias, row-concat layout).""" + """Stub with a Wqkv Linear (no bias, row-concat layout).""" class _FakeAttn(nn.Module): def __init__(self) -> None: super().__init__() - # Wqkv: [3*d_model, d_model] — MPT row-wise concat layout + # MPT layout: Wqkv [3*d_model, d_model] row-wise concat. self.Wqkv = nn.Linear(d_model, 3 * d_model, bias=False) return _FakeAttn() @@ -213,31 +188,22 @@ def test_split_returns_three_linears(self, adapter: MPTArchitectureAdapter) -> N assert all(isinstance(lin, nn.Linear) for lin in result) def test_split_output_shapes(self, adapter: MPTArchitectureAdapter) -> None: - """Each output linear must have weight shape [d_model, d_model].""" d_model = adapter.cfg.d_model fake_attn = self._make_fake_attn_component(d_model) q_lin, k_lin, v_lin = adapter._split_mpt_qkv(fake_attn) for lin in (q_lin, k_lin, v_lin): - assert lin.weight.shape == ( - d_model, - d_model, - ), f"Expected ({d_model}, {d_model}), got {lin.weight.shape}" + assert lin.weight.shape == (d_model, d_model) def test_split_roundtrip(self, adapter: MPTArchitectureAdapter) -> None: - """cat([q.weight, k.weight, v.weight], dim=0) must recover original Wqkv.weight. - - Uses batch_size=2 worth of distinct rows to surface any row/col transposition. - """ + """cat([q, k, v], dim=0) must recover original Wqkv (catches row/col transposition).""" d_model = adapter.cfg.d_model fake_attn = self._make_fake_attn_component(d_model) - original_w = fake_attn.Wqkv.weight.detach().clone() # [3*d_model, d_model] + original_w = fake_attn.Wqkv.weight.detach().clone() q_lin, k_lin, v_lin = adapter._split_mpt_qkv(fake_attn) recovered = torch.cat([q_lin.weight, k_lin.weight, v_lin.weight], dim=0) - assert torch.allclose( - recovered, original_w - ), "Round-trip failed: cat(Q,K,V) != original Wqkv" + assert torch.allclose(recovered, original_w) # --------------------------------------------------------------------------- @@ -246,13 +212,7 @@ def test_split_roundtrip(self, adapter: MPTArchitectureAdapter) -> None: class TestMPTFactoryRegistration: - """ArchitectureAdapterFactory must resolve MPTForCausalLM -> MPTArchitectureAdapter.""" - def test_factory_resolves_mpt_architecture(self) -> None: - """Factory returns an MPTArchitectureAdapter instance for MPTForCausalLM. - - Uses a fully programmatic config — no HF Hub download. - """ from transformer_lens.factories.architecture_adapter_factory import ( ArchitectureAdapterFactory, ) @@ -263,7 +223,6 @@ def test_factory_resolves_mpt_architecture(self) -> None: assert isinstance(adapter, MPTArchitectureAdapter) def test_factory_unknown_architecture_raises(self) -> None: - """Factory raises ValueError for an unregistered architecture key.""" from transformer_lens.factories.architecture_adapter_factory import ( ArchitectureAdapterFactory, ) @@ -274,10 +233,261 @@ def test_factory_unknown_architecture_raises(self) -> None: ArchitectureAdapterFactory.select_architecture_adapter(cfg) def test_mpt_in_supported_architectures_dict(self) -> None: - """MPTForCausalLM must appear in the SUPPORTED_ARCHITECTURES mapping.""" from transformer_lens.factories.architecture_adapter_factory import ( SUPPORTED_ARCHITECTURES, ) assert "MPTForCausalLM" in SUPPORTED_ARCHITECTURES assert SUPPORTED_ARCHITECTURES["MPTForCausalLM"] is MPTArchitectureAdapter + + +# --------------------------------------------------------------------------- +# Component-mapping HF paths +# --------------------------------------------------------------------------- + + +class TestMPTComponentMappingPaths: + """HF module paths per component slot (refactor-drift guard).""" + + def test_embed_path(self, adapter: MPTArchitectureAdapter) -> None: + assert adapter.component_mapping["embed"].name == "transformer.wte" + + def test_blocks_path(self, adapter: MPTArchitectureAdapter) -> None: + assert adapter.component_mapping["blocks"].name == "transformer.blocks" + + def test_ln_final_path(self, adapter: MPTArchitectureAdapter) -> None: + assert adapter.component_mapping["ln_final"].name == "transformer.norm_f" + + def test_unembed_path(self, adapter: MPTArchitectureAdapter) -> None: + assert adapter.component_mapping["unembed"].name == "lm_head" + + +# --------------------------------------------------------------------------- +# Component-mapping bridge types +# --------------------------------------------------------------------------- + + +class TestMPTComponentTypes: + """Component bridge classes (guards against silent type substitution).""" + + def test_embed_type(self, adapter: MPTArchitectureAdapter) -> None: + from transformer_lens.model_bridge.generalized_components import EmbeddingBridge + + assert isinstance(adapter.component_mapping["embed"], EmbeddingBridge) + + def test_blocks_type(self, adapter: MPTArchitectureAdapter) -> None: + from transformer_lens.model_bridge.generalized_components import BlockBridge + + assert isinstance(adapter.component_mapping["blocks"], BlockBridge) + + def test_ln_final_type(self, adapter: MPTArchitectureAdapter) -> None: + # MPT uses LayerNorm (bias=None), not RMSNorm. + from transformer_lens.model_bridge.generalized_components import ( + NormalizationBridge, + RMSNormalizationBridge, + ) + + ln_final = adapter.component_mapping["ln_final"] + assert isinstance(ln_final, NormalizationBridge) + assert not isinstance(ln_final, RMSNormalizationBridge) + + def test_unembed_type(self, adapter: MPTArchitectureAdapter) -> None: + from transformer_lens.model_bridge.generalized_components import UnembeddingBridge + + assert isinstance(adapter.component_mapping["unembed"], UnembeddingBridge) + + +# --------------------------------------------------------------------------- +# Block-submodule structure (types + HF paths) +# --------------------------------------------------------------------------- + + +class TestMPTBlockSubmoduleStructure: + """Each block submodule has the correct bridge type and HF path.""" + + def test_ln1_is_layernorm_at_norm_1(self, adapter: MPTArchitectureAdapter) -> None: + from transformer_lens.model_bridge.generalized_components import NormalizationBridge + + block = adapter.component_mapping["blocks"] + ln1 = block.submodules["ln1"] + assert isinstance(ln1, NormalizationBridge) + assert ln1.name == "norm_1" + + def test_ln2_is_layernorm_at_norm_2(self, adapter: MPTArchitectureAdapter) -> None: + from transformer_lens.model_bridge.generalized_components import NormalizationBridge + + block = adapter.component_mapping["blocks"] + ln2 = block.submodules["ln2"] + assert isinstance(ln2, NormalizationBridge) + assert ln2.name == "norm_2" + + def test_attn_is_mpt_alibi_attention_at_attn(self, adapter: MPTArchitectureAdapter) -> None: + from transformer_lens.model_bridge.generalized_components.mpt_alibi_attention import ( + MPTALiBiAttentionBridge, + ) + + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert isinstance(attn, MPTALiBiAttentionBridge) + assert attn.name == "attn" + + def test_attn_does_not_require_position_embeddings( + self, adapter: MPTArchitectureAdapter + ) -> None: + # ALiBi bakes position into the score bias: no rotary, no learned pos. + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert attn.requires_position_embeddings is False + + def test_attn_does_not_require_attention_mask( + self, adapter: MPTArchitectureAdapter + ) -> None: + # ALiBi bias slope IS the position-aware signal. + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert attn.requires_attention_mask is False + + def test_attn_qkv_submodule_is_joint(self, adapter: MPTArchitectureAdapter) -> None: + # MPT joint-QKV ("Wqkv") wires the joint Linear at the explicit "qkv" slot. + from transformer_lens.model_bridge.generalized_components import LinearBridge + + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert "qkv" in attn.submodules + qkv_sub = attn.submodules["qkv"] + assert isinstance(qkv_sub, LinearBridge) + assert qkv_sub.name == "Wqkv" + + def test_attn_split_qkv_callback_wired(self, adapter: MPTArchitectureAdapter) -> None: + # Bound methods are unwrapped on each access; compare via MethodType attrs. + from types import MethodType + + attn = adapter.component_mapping["blocks"].submodules["attn"] + callback = attn.split_qkv_matrix + assert isinstance(callback, MethodType) + assert callback.__func__ is MPTArchitectureAdapter._split_mpt_qkv + assert callback.__self__ is adapter + + def test_attn_o_submodule(self, adapter: MPTArchitectureAdapter) -> None: + from transformer_lens.model_bridge.generalized_components import LinearBridge + + attn = adapter.component_mapping["blocks"].submodules["attn"] + o_sub = attn.submodules["o"] + assert isinstance(o_sub, LinearBridge) + assert o_sub.name == "out_proj" + + def test_mlp_is_plain_mlp_at_ffn(self, adapter: MPTArchitectureAdapter) -> None: + # MPT MLP is non-gated. + from transformer_lens.model_bridge.generalized_components import ( + GatedMLPBridge, + MLPBridge, + ) + + mlp = adapter.component_mapping["blocks"].submodules["mlp"] + assert isinstance(mlp, MLPBridge) + assert not isinstance(mlp, GatedMLPBridge) + assert mlp.name == "ffn" + + def test_mlp_submodule_paths(self, adapter: MPTArchitectureAdapter) -> None: + from transformer_lens.model_bridge.generalized_components import LinearBridge + + mlp = adapter.component_mapping["blocks"].submodules["mlp"] + for sub_name, expected_path in (("in", "up_proj"), ("out", "down_proj")): + sub = mlp.submodules[sub_name] + assert isinstance(sub, LinearBridge) + assert sub.name == expected_path + + +# --------------------------------------------------------------------------- +# Weight conversion semantics (patterns + classes) +# --------------------------------------------------------------------------- + + +class TestMPTWeightConversionSemantics: + """Each weight conversion entry uses the expected class and pattern.""" + + def test_qkv_conversion_classes_and_patterns( + self, adapter: MPTArchitectureAdapter + ) -> None: + from transformer_lens.conversion_utils.conversion_steps import ( + RearrangeTensorConversion, + ) + from transformer_lens.conversion_utils.param_processing_conversion import ( + ParamProcessingConversion, + ) + + for slot in ("q", "k", "v"): + conv = adapter.weight_processing_conversions[f"blocks.{{i}}.attn.{slot}.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.pattern == "(n h) m -> n m h" + + def test_o_conversion_class_and_pattern(self, adapter: MPTArchitectureAdapter) -> None: + from transformer_lens.conversion_utils.conversion_steps import ( + RearrangeTensorConversion, + ) + + conv = adapter.weight_processing_conversions["blocks.{i}.attn.o.weight"] + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.pattern == "m (n h) -> n h m" + + def test_no_norm_offset_conversions(self, adapter: MPTArchitectureAdapter) -> None: + # Plain LayerNorm: no +1 trick like Gemma. + for key in adapter.weight_processing_conversions: + assert not key.startswith("blocks.{i}.ln") + assert key != "ln_final.weight" + + +# --------------------------------------------------------------------------- +# MQA / GQA propagation +# --------------------------------------------------------------------------- + + +class TestMPTMQASupport: + """n_key_value_heads must reach K/V conversions (MPT supports MQA).""" + + def test_no_mqa_when_not_set(self) -> None: + # Without n_key_value_heads, K/V default to n_heads. + adapter = MPTArchitectureAdapter(_make_cfg(n_heads=2)) + kv_conv = adapter.weight_processing_conversions["blocks.{i}.attn.k.weight"] + assert kv_conv.tensor_conversion.axes_lengths["n"] == 2 + + def test_mqa_propagates_to_kv_conversions(self) -> None: + # MQA: single KV head shared across all Q heads. + cfg = _make_cfg(n_heads=4) + cfg.n_key_value_heads = 1 + adapter = MPTArchitectureAdapter(cfg) + for slot in ("k", "v"): + conv = adapter.weight_processing_conversions[f"blocks.{{i}}.attn.{slot}.weight"] + assert conv.tensor_conversion.axes_lengths["n"] == 1 + + def test_mqa_does_not_change_q_or_o(self) -> None: + cfg = _make_cfg(n_heads=4) + cfg.n_key_value_heads = 1 + adapter = MPTArchitectureAdapter(cfg) + q_conv = adapter.weight_processing_conversions["blocks.{i}.attn.q.weight"] + o_conv = adapter.weight_processing_conversions["blocks.{i}.attn.o.weight"] + assert q_conv.tensor_conversion.axes_lengths["n"] == 4 + assert o_conv.tensor_conversion.axes_lengths["n"] == 4 + + +# --------------------------------------------------------------------------- +# Architecture-specific guards +# --------------------------------------------------------------------------- + + +class TestMPTArchitectureGuards: + """No rotary, no pos_embed (MPT uses ALiBi).""" + + def test_no_rotary_emb_in_component_mapping( + self, adapter: MPTArchitectureAdapter + ) -> None: + assert "rotary_emb" not in adapter.component_mapping + + def test_no_pos_embed_in_component_mapping( + self, adapter: MPTArchitectureAdapter + ) -> None: + assert "pos_embed" not in adapter.component_mapping + + def test_no_rotary_emb_in_attn_submodules( + self, adapter: MPTArchitectureAdapter + ) -> None: + # ALiBi bias is computed inside the attention bridge: no rotary submodule. + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert "rotary_emb" not in attn.submodules diff --git a/tests/unit/test_qwen3_5_adapter.py b/tests/unit/model_bridge/supported_architectures/test_qwen3_5_adapter.py similarity index 71% rename from tests/unit/test_qwen3_5_adapter.py rename to tests/unit/model_bridge/supported_architectures/test_qwen3_5_adapter.py index d1a4a7b6a..52e0ef3f8 100644 --- a/tests/unit/test_qwen3_5_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_qwen3_5_adapter.py @@ -1,14 +1,6 @@ -"""Unit tests for the Qwen3_5 architecture adapter (Phase A+B). +"""Unit tests for the Qwen3_5 architecture adapter. -Tests cover: -1. Registration: adapter importable, in SUPPORTED_ARCHITECTURES, in HF_SUPPORTED_ARCHITECTURES -2. Component mapping: correct bridge hierarchy with only universal submodules (no self_attn), - GatedMLPBridge with gate/in/out LinearBridge submodules -3. Config attributes: all cfg attributes set correctly -4. Weight conversions: preprocess_weights correctly slices q_proj.weight per-head -5. Integration: end-to-end tests with a tiny programmatically-constructed model - -Note: Qwen3_5 is supported only via TransformerBridge, not HookedTransformer. +Qwen3_5 is supported only via TransformerBridge, not HookedTransformer. """ import pytest @@ -26,20 +18,14 @@ except ImportError: _QWEN3_5_AVAILABLE = False -# ============================================================================ -# Test: Registration -# ============================================================================ - - @pytest.mark.skipif( not _QWEN3_5_AVAILABLE, reason="Qwen3_5TextConfig / Qwen3_5ForCausalLM not available in installed transformers", ) class TestQwen3_5Registration: - """Verify the adapter is properly registered in all lookup tables.""" + """Adapter is registered in all lookup tables.""" def test_adapter_importable(self): - """Qwen3_5ArchitectureAdapter must be importable.""" from transformer_lens.model_bridge.supported_architectures import ( Qwen3_5ArchitectureAdapter, ) @@ -47,15 +33,12 @@ def test_adapter_importable(self): assert Qwen3_5ArchitectureAdapter is not None def test_in_supported_architectures(self): - """Qwen3_5ForCausalLM must be in SUPPORTED_ARCHITECTURES.""" assert "Qwen3_5ForCausalLM" in SUPPORTED_ARCHITECTURES def test_in_hf_supported_architectures(self): - """Qwen3_5ForCausalLM must be in HF_SUPPORTED_ARCHITECTURES.""" assert "Qwen3_5ForCausalLM" in HF_SUPPORTED_ARCHITECTURES def test_adapter_class_correct(self): - """The adapter class must be Qwen3_5ArchitectureAdapter.""" from transformer_lens.model_bridge.supported_architectures import ( Qwen3_5ArchitectureAdapter, ) @@ -63,13 +46,8 @@ def test_adapter_class_correct(self): assert SUPPORTED_ARCHITECTURES["Qwen3_5ForCausalLM"] is Qwen3_5ArchitectureAdapter -# ============================================================================ -# Helpers: TransformerBridgeConfig for adapter instantiation -# ============================================================================ - - def _make_bridge_cfg(**overrides): - """Create a minimal TransformerBridgeConfig for Qwen3_5 adapter tests.""" + """Minimal TransformerBridgeConfig for Qwen3_5 adapter tests.""" from transformer_lens.config.TransformerBridgeConfig import TransformerBridgeConfig defaults = dict( @@ -86,23 +64,12 @@ def _make_bridge_cfg(**overrides): return TransformerBridgeConfig(**defaults) -# ============================================================================ -# Test: Component Mapping (Phase A+B) -# ============================================================================ - - @pytest.mark.skipif( not _QWEN3_5_AVAILABLE, reason="Qwen3_5TextConfig / Qwen3_5ForCausalLM not available in installed transformers", ) class TestQwen3_5ComponentMapping: - """Verify the component_mapping structure for Qwen3_5. - - The key invariant: self_attn is NOT mapped as a block submodule because - linear-attention layers lack self_attn. Only universally present submodules - (norms, dense MLP) are mapped. Unlike Qwen3Next, the MLP is a GatedMLPBridge - with enumerated gate/in/out LinearBridge submodules. - """ + """self_attn is not a block submodule (absent on linear-attn layers); dense GatedMLP only.""" @pytest.fixture def adapter(self): @@ -113,10 +80,7 @@ def adapter(self): cfg = _make_bridge_cfg() return Qwen3_5ArchitectureAdapter(cfg) - # ---- Top-level keys ---- - def test_component_mapping_keys(self, adapter): - """component_mapping must have exactly the expected top-level keys.""" assert set(adapter.component_mapping.keys()) == { "embed", "rotary_emb", @@ -125,47 +89,36 @@ def test_component_mapping_keys(self, adapter): "unembed", } - # ---- HF path names ---- - def test_embed_path(self, adapter): - """embed maps to model.embed_tokens.""" assert adapter.component_mapping["embed"].name == "model.embed_tokens" def test_rotary_emb_path(self, adapter): - """rotary_emb maps to model.rotary_emb.""" assert adapter.component_mapping["rotary_emb"].name == "model.rotary_emb" def test_blocks_path(self, adapter): - """blocks maps to model.layers.""" assert adapter.component_mapping["blocks"].name == "model.layers" def test_ln_final_path(self, adapter): - """ln_final maps to model.norm.""" assert adapter.component_mapping["ln_final"].name == "model.norm" def test_unembed_path(self, adapter): - """unembed maps to lm_head.""" assert adapter.component_mapping["unembed"].name == "lm_head" - # ---- Block submodules ---- - def test_block_submodules_keys(self, adapter): - """blocks submodules must contain ln1, ln2, mlp, and optional attn + linear_attn.""" submodules = adapter.component_mapping["blocks"].submodules assert set(submodules.keys()) == {"ln1", "ln2", "mlp", "attn", "linear_attn"} def test_attn_is_optional(self, adapter): - """attn must be marked optional (absent on linear-attention layers).""" + """attn is absent on linear-attention layers.""" submodules = adapter.component_mapping["blocks"].submodules assert submodules["attn"].optional is True def test_linear_attn_is_optional(self, adapter): - """linear_attn must be marked optional (absent on full-attention layers).""" + """linear_attn is absent on full-attention layers.""" submodules = adapter.component_mapping["blocks"].submodules assert submodules["linear_attn"].optional is True def test_linear_attn_bridge_type(self, adapter): - """linear_attn must be a GatedDeltaNetBridge.""" from transformer_lens.model_bridge.generalized_components.gated_delta_net import ( GatedDeltaNetBridge, ) @@ -174,51 +127,38 @@ def test_linear_attn_bridge_type(self, adapter): assert isinstance(submodules["linear_attn"], GatedDeltaNetBridge) def test_ln1_path(self, adapter): - """ln1 maps to input_layernorm.""" assert adapter.component_mapping["blocks"].submodules["ln1"].name == "input_layernorm" def test_ln2_path(self, adapter): - """ln2 maps to post_attention_layernorm.""" assert ( adapter.component_mapping["blocks"].submodules["ln2"].name == "post_attention_layernorm" ) def test_mlp_path(self, adapter): - """mlp maps to mlp.""" assert adapter.component_mapping["blocks"].submodules["mlp"].name == "mlp" - # ---- MLP submodules ---- - def test_mlp_submodule_keys(self, adapter): - """mlp submodules must be exactly {gate, in, out}.""" mlp = adapter.component_mapping["blocks"].submodules["mlp"] assert set(mlp.submodules.keys()) == {"gate", "in", "out"} def test_mlp_gate_path(self, adapter): - """mlp.gate maps to gate_proj.""" mlp = adapter.component_mapping["blocks"].submodules["mlp"] assert mlp.submodules["gate"].name == "gate_proj" def test_mlp_in_path(self, adapter): - """mlp.in maps to up_proj.""" mlp = adapter.component_mapping["blocks"].submodules["mlp"] assert mlp.submodules["in"].name == "up_proj" def test_mlp_out_path(self, adapter): - """mlp.out maps to down_proj.""" mlp = adapter.component_mapping["blocks"].submodules["mlp"] assert mlp.submodules["out"].name == "down_proj" - # ---- Bridge types ---- - def test_blocks_bridge_type(self, adapter): - """blocks uses BlockBridge.""" from transformer_lens.model_bridge.generalized_components import BlockBridge assert isinstance(adapter.component_mapping["blocks"], BlockBridge) def test_rotary_emb_bridge_type(self, adapter): - """rotary_emb uses RotaryEmbeddingBridge.""" from transformer_lens.model_bridge.generalized_components import ( RotaryEmbeddingBridge, ) @@ -226,7 +166,6 @@ def test_rotary_emb_bridge_type(self, adapter): assert isinstance(adapter.component_mapping["rotary_emb"], RotaryEmbeddingBridge) def test_ln1_bridge_type(self, adapter): - """ln1 uses RMSNormalizationBridge.""" from transformer_lens.model_bridge.generalized_components import ( RMSNormalizationBridge, ) @@ -235,7 +174,6 @@ def test_ln1_bridge_type(self, adapter): assert isinstance(ln1, RMSNormalizationBridge) def test_ln2_bridge_type(self, adapter): - """ln2 uses RMSNormalizationBridge.""" from transformer_lens.model_bridge.generalized_components import ( RMSNormalizationBridge, ) @@ -244,51 +182,40 @@ def test_ln2_bridge_type(self, adapter): assert isinstance(ln2, RMSNormalizationBridge) def test_mlp_bridge_type(self, adapter): - """mlp uses GatedMLPBridge (dense gated MLP, not MoE).""" from transformer_lens.model_bridge.generalized_components import GatedMLPBridge mlp = adapter.component_mapping["blocks"].submodules["mlp"] assert isinstance(mlp, GatedMLPBridge) def test_mlp_gate_bridge_type(self, adapter): - """mlp.gate uses LinearBridge.""" from transformer_lens.model_bridge.generalized_components import LinearBridge gate = adapter.component_mapping["blocks"].submodules["mlp"].submodules["gate"] assert isinstance(gate, LinearBridge) def test_mlp_in_bridge_type(self, adapter): - """mlp.in uses LinearBridge.""" from transformer_lens.model_bridge.generalized_components import LinearBridge up = adapter.component_mapping["blocks"].submodules["mlp"].submodules["in"] assert isinstance(up, LinearBridge) def test_mlp_out_bridge_type(self, adapter): - """mlp.out uses LinearBridge.""" from transformer_lens.model_bridge.generalized_components import LinearBridge down = adapter.component_mapping["blocks"].submodules["mlp"].submodules["out"] assert isinstance(down, LinearBridge) - # ---- weight_processing_conversions ---- - def test_weight_processing_conversions_empty(self, adapter): - """weight_processing_conversions is empty (no attention submodules mapped).""" + """No attention submodules mapped, so no conversions.""" assert adapter.weight_processing_conversions == {} -# ============================================================================ -# Test: Config Attributes (Phase A+B) -# ============================================================================ - - @pytest.mark.skipif( not _QWEN3_5_AVAILABLE, reason="Qwen3_5TextConfig / Qwen3_5ForCausalLM not available in installed transformers", ) class TestQwen3_5ConfigAttributes: - """Verify all cfg attributes are set correctly by the adapter.""" + """cfg attributes set by the adapter.""" @pytest.fixture def adapter(self): @@ -321,15 +248,14 @@ def test_default_prepend_bos(self, adapter): assert adapter.cfg.default_prepend_bos is False def test_supports_fold_ln_false(self, adapter): - """supports_fold_ln must be False: hybrid layers break fold_ln.""" + """Hybrid layers break fold_ln.""" assert adapter.supports_fold_ln is False def test_attn_implementation_eager(self, adapter): - """attn_implementation must be 'eager' for output_attentions support.""" + """Eager attention required for output_attentions support.""" assert adapter.cfg.attn_implementation == "eager" def test_n_key_value_heads_set_when_gqa(self): - """n_key_value_heads is set on cfg when the input config has it.""" from transformer_lens.model_bridge.supported_architectures.qwen3_5 import ( Qwen3_5ArchitectureAdapter, ) @@ -339,7 +265,6 @@ def test_n_key_value_heads_set_when_gqa(self): assert adapter.cfg.n_key_value_heads == 2 def test_n_key_value_heads_not_set_when_absent(self): - """n_key_value_heads is not set when the config doesn't have it.""" from transformer_lens.config.TransformerBridgeConfig import ( TransformerBridgeConfig, ) @@ -347,7 +272,6 @@ def test_n_key_value_heads_not_set_when_absent(self): Qwen3_5ArchitectureAdapter, ) - # Config without n_key_value_heads cfg = TransformerBridgeConfig( d_model=1024, d_head=256, @@ -358,7 +282,7 @@ def test_n_key_value_heads_not_set_when_absent(self): architecture="Qwen3_5ForCausalLM", ) adapter = Qwen3_5ArchitectureAdapter(cfg) - # n_key_value_heads should equal n_heads (standard MHA default) + # When unset, n_key_value_heads must default to n_heads (standard MHA). assert not ( hasattr(adapter.cfg, "n_key_value_heads") and adapter.cfg.n_key_value_heads is not None @@ -366,26 +290,12 @@ def test_n_key_value_heads_not_set_when_absent(self): ) -# ============================================================================ -# Test: preprocess_weights (Phase A+B) -# ============================================================================ - - @pytest.mark.skipif( not _QWEN3_5_AVAILABLE, reason="Qwen3_5TextConfig / Qwen3_5ForCausalLM not available in installed transformers", ) class TestQwen3_5PreprocessWeights: - """Verify preprocess_weights correctly slices q_proj.weight per-head. - - Background: In Qwen3_5, q_proj.weight has shape (n_heads * head_dim * 2, hidden_size) - where rows are organized as interleaved per-head pairs: - head_0_query (d_head rows), head_0_gate (d_head rows), - head_1_query (d_head rows), head_1_gate (d_head rows), ... - - A naive first-half slice would be wrong. The correct approach reshapes by - head and takes only the first d_head rows per head (the query half). - """ + """q_proj rows are interleaved per-head (query, gate, query, gate, ...) — naive first-half slice is wrong.""" N_HEADS = 4 D_HEAD = 8 @@ -401,12 +311,11 @@ def adapter(self): n_heads=self.N_HEADS, d_head=self.D_HEAD, d_model=self.HIDDEN_SIZE, - n_key_value_heads=self.N_HEADS, # MHA for simplicity + n_key_value_heads=self.N_HEADS, ) return Qwen3_5ArchitectureAdapter(cfg) def _make_q_proj_weight(self): - """Create a q_proj.weight tensor with distinct per-head-row values.""" import torch total_rows = self.N_HEADS * self.D_HEAD * 2 @@ -416,7 +325,6 @@ def _make_q_proj_weight(self): return w def test_q_proj_output_shape(self, adapter): - """preprocess_weights reduces q_proj rows from n_heads*d_head*2 to n_heads*d_head.""" import torch w = self._make_q_proj_weight() @@ -426,8 +334,6 @@ def test_q_proj_output_shape(self, adapter): assert out.shape == (self.N_HEADS * self.D_HEAD, self.HIDDEN_SIZE) def test_q_proj_selects_query_rows_not_naive_first_half(self, adapter): - """For each head i, output rows [i*d_head:(i+1)*d_head] == input rows - [i*d_head*2 : i*d_head*2 + d_head] (per-head interleaved layout).""" import torch w = self._make_q_proj_weight() @@ -446,7 +352,6 @@ def test_q_proj_selects_query_rows_not_naive_first_half(self, adapter): ) def test_naive_slice_would_be_wrong(self, adapter): - """Naive first-half slice gives different (wrong) results for n_heads > 1.""" import torch w = self._make_q_proj_weight() @@ -462,7 +367,6 @@ def test_naive_slice_would_be_wrong(self, adapter): ) def test_non_q_proj_weights_unchanged(self, adapter): - """k_proj, v_proj, and down_proj weights are NOT modified by preprocess_weights.""" import torch k_proj = torch.randn(self.N_HEADS * self.D_HEAD, self.HIDDEN_SIZE) @@ -476,7 +380,6 @@ def test_non_q_proj_weights_unchanged(self, adapter): assert torch.equal(result["model.layers.0.mlp.down_proj.weight"], down_proj) def test_multiple_layers_all_processed(self, adapter): - """q_proj.weight tensors across multiple layers are all sliced correctly.""" import torch w0 = self._make_q_proj_weight() @@ -491,11 +394,9 @@ def test_multiple_layers_all_processed(self, adapter): assert result["model.layers.3.self_attn.q_proj.weight"].shape == expected_shape def test_empty_state_dict_returns_empty(self, adapter): - """preprocess_weights with an empty state dict returns an empty dict.""" assert adapter.preprocess_weights({}) == {} def test_state_dict_without_q_proj_unchanged(self, adapter): - """A state dict with no q_proj keys is returned unmodified.""" import torch state_dict = {"model.embed_tokens.weight": torch.randn(100, self.HIDDEN_SIZE)} @@ -504,22 +405,152 @@ def test_state_dict_without_q_proj_unchanged(self, adapter): assert set(result.keys()) == original_keys def test_weight_processing_conversions_is_empty_dict(self, adapter): - """weight_processing_conversions is {} — q_proj slicing is in preprocess_weights.""" + """q_proj slicing happens in preprocess_weights, not as a conversion.""" assert adapter.weight_processing_conversions == {} -# ============================================================================ -# Test: Integration (Phase A+B) -# ============================================================================ +@pytest.mark.skipif( + not _QWEN3_5_AVAILABLE, + reason="Qwen3_5TextConfig / Qwen3_5ForCausalLM not available in installed transformers", +) +class TestQwen3_5ComponentTypes: + """Top-level bridge classes — guards against silent type substitution.""" + @pytest.fixture + def adapter(self): + from transformer_lens.model_bridge.supported_architectures.qwen3_5 import ( + Qwen3_5ArchitectureAdapter, + ) -def _make_tiny_hf_model(): - """Create a tiny Qwen3_5ForCausalLM for integration testing. + return Qwen3_5ArchitectureAdapter(_make_bridge_cfg()) - 8 layers: layers 3 and 7 are full-attention (full_attention_interval=4), - layers 0-2 and 4-6 are linear-attention (GatedDeltaNet). - Dense gated MLP on all layers. - """ + def test_embed_is_embedding_bridge(self, adapter): + from transformer_lens.model_bridge.generalized_components import EmbeddingBridge + + assert isinstance(adapter.component_mapping["embed"], EmbeddingBridge) + + def test_ln_final_is_rms_norm_bridge(self, adapter): + from transformer_lens.model_bridge.generalized_components import ( + RMSNormalizationBridge, + ) + + assert isinstance(adapter.component_mapping["ln_final"], RMSNormalizationBridge) + + def test_unembed_is_unembedding_bridge(self, adapter): + from transformer_lens.model_bridge.generalized_components import ( + UnembeddingBridge, + ) + + assert isinstance(adapter.component_mapping["unembed"], UnembeddingBridge) + + +@pytest.mark.skipif( + not _QWEN3_5_AVAILABLE, + reason="Qwen3_5TextConfig / Qwen3_5ForCausalLM not available in installed transformers", +) +class TestQwen3_5AttnSubmodules: + """Full-attention layers wire Qwen3-pattern submodules; gated q_proj half is pre-sliced.""" + + @pytest.fixture + def attn(self): + from transformer_lens.model_bridge.supported_architectures.qwen3_5 import ( + Qwen3_5ArchitectureAdapter, + ) + + adapter = Qwen3_5ArchitectureAdapter(_make_bridge_cfg()) + return adapter.component_mapping["blocks"].submodules["attn"] + + def test_attn_is_position_embeddings_attention(self, attn): + from transformer_lens.model_bridge.generalized_components.position_embeddings_attention import ( + PositionEmbeddingsAttentionBridge, + ) + + assert isinstance(attn, PositionEmbeddingsAttentionBridge) + + def test_attn_path(self, attn): + assert attn.name == "self_attn" + + def test_attn_qkvo_submodule_paths(self, attn): + from transformer_lens.model_bridge.generalized_components import LinearBridge + + for sub_name, expected_path in ( + ("q", "q_proj"), + ("k", "k_proj"), + ("v", "v_proj"), + ("o", "o_proj"), + ): + sub = attn.submodules[sub_name] + assert isinstance(sub, LinearBridge) + assert sub.name == expected_path + + def test_attn_q_norm_k_norm_present(self, attn): + """Qwen3 family uses per-head Q/K RMSNorm.""" + from transformer_lens.model_bridge.generalized_components import ( + RMSNormalizationBridge, + ) + + assert isinstance(attn.submodules["q_norm"], RMSNormalizationBridge) + assert isinstance(attn.submodules["k_norm"], RMSNormalizationBridge) + assert attn.submodules["q_norm"].name == "q_norm" + assert attn.submodules["k_norm"].name == "k_norm" + + +@pytest.mark.skipif( + not _QWEN3_5_AVAILABLE, + reason="Qwen3_5TextConfig / Qwen3_5ForCausalLM not available in installed transformers", +) +class TestQwen3_5HybridSpecifics: + """Qwen3.5-specific config invariants.""" + + @pytest.fixture + def adapter(self): + from transformer_lens.model_bridge.supported_architectures.qwen3_5 import ( + Qwen3_5ArchitectureAdapter, + ) + + return Qwen3_5ArchitectureAdapter(_make_bridge_cfg()) + + def test_gated_q_proj_flag_set(self, adapter): + """Flag drives preprocess_weights to slice the gated half of q_proj.""" + assert getattr(adapter.cfg, "gated_q_proj", False) is True + + +@pytest.mark.skipif( + not _QWEN3_5_AVAILABLE, + reason="Qwen3_5TextConfig / Qwen3_5ForCausalLM not available in installed transformers", +) +class TestQwen3_5ArchitectureGuards: + """Guards against drift from Qwen3 conventions.""" + + @pytest.fixture + def adapter(self): + from transformer_lens.model_bridge.supported_architectures.qwen3_5 import ( + Qwen3_5ArchitectureAdapter, + ) + + return Qwen3_5ArchitectureAdapter(_make_bridge_cfg()) + + def test_no_norm_offset_conversions(self, adapter): + """LLaMA-style RMSNorm — no +1 offset like Gemma.""" + for key in adapter.weight_processing_conversions: + assert "ln1" not in key + assert "ln2" not in key + assert "ln_final" not in key + + def test_mlp_is_gated_not_moe(self, adapter): + """Dense GatedMLP, not MoE (Qwen3Next has MoE).""" + from transformer_lens.model_bridge.generalized_components import ( + GatedMLPBridge, + MoEBridge, + ) + + mlp = adapter.component_mapping["blocks"].submodules["mlp"] + assert isinstance(mlp, GatedMLPBridge) + assert not isinstance(mlp, MoEBridge) + + +def _make_tiny_hf_model(): + """Tiny Qwen3_5ForCausalLM: 8 layers, full-attn at 3 and 7 (interval=4), GatedDeltaNet elsewhere.""" cfg = Qwen3_5TextConfig( hidden_size=128, num_hidden_layers=8, @@ -548,7 +579,7 @@ def _make_tiny_hf_model(): def _make_tiny_bridge(): - """Create a Qwen3_5 bridge from a tiny HF model.""" + """Build a Qwen3_5 bridge from a tiny HF model.""" from unittest.mock import MagicMock from transformer_lens.config.TransformerBridgeConfig import TransformerBridgeConfig @@ -578,15 +609,10 @@ def _make_tiny_bridge(): reason="Qwen3_5TextConfig / Qwen3_5ForCausalLM not available in installed transformers", ) class TestQwen3_5Integration: - """End-to-end integration tests using a tiny programmatic Qwen3_5 model. - - The linear attention layers run via the torch fallback path when - flash-linear-attention / causal-conv1d are not installed. - """ + """End-to-end tests; linear-attn falls back to torch when flash-linear-attention is absent.""" @pytest.fixture(scope="class") def bridge_and_model(self): - """Create a tiny bridge + HF model pair, shared across the class.""" return _make_tiny_bridge() @pytest.fixture(scope="class") @@ -600,21 +626,12 @@ def hf_model(self, bridge_and_model): return hf def test_bridge_creation(self, bridge): - """TransformerBridge construction from a tiny Qwen3_5 model must succeed.""" from transformer_lens.model_bridge import TransformerBridge assert isinstance(bridge, TransformerBridge) def test_hook_names_present(self, bridge): - """Key hook names must be present; blocks.0.attn.* must NOT be present. - - Verified: - - blocks.0.hook_resid_pre: linear-attention layer (layer 0) - - blocks.3.hook_resid_pre: first full-attention layer (layer 3) - - blocks.0.ln1.*: norm present on all layers (universal submodule) - - blocks.0.mlp.*: MLP present on all layers (universal submodule) - - blocks.0.attn.*: NOT present (self_attn absent on linear-attn layers) - """ + """blocks.0.attn.* must NOT appear — self_attn is absent on linear-attn layers.""" hook_keys = set(bridge.hook_dict.keys()) assert "blocks.0.hook_resid_pre" in hook_keys, "linear-attn layer must have hook_resid_pre" @@ -630,7 +647,6 @@ def test_hook_names_present(self, bridge): ), "blocks.0.attn hooks must NOT be present (hybrid architecture)" def test_forward_pass_consistency(self, bridge, hf_model): - """Bridge output logits must match HF model output logits within atol=1e-4.""" import torch tokens = torch.randint(0, 512, (1, 4)) @@ -646,7 +662,6 @@ def test_forward_pass_consistency(self, bridge, hf_model): ), f"Logit mismatch: max diff = {(hf_logits - bridge_logits).abs().max().item():.6f}" def test_hook_activation_shapes(self, bridge): - """A hook on blocks.0.mlp.hook_out must capture a (batch, seq, d_model) tensor.""" import torch captured: list[torch.Tensor] = [] diff --git a/tests/unit/model_bridge/supported_architectures/test_qwen3_moe_adapter.py b/tests/unit/model_bridge/supported_architectures/test_qwen3_moe_adapter.py new file mode 100644 index 000000000..055909f39 --- /dev/null +++ b/tests/unit/model_bridge/supported_architectures/test_qwen3_moe_adapter.py @@ -0,0 +1,358 @@ +"""Unit tests for the Qwen3MoeArchitectureAdapter — programmatic configs only, no downloads.""" + +import pytest + +from transformer_lens.config import TransformerBridgeConfig +from transformer_lens.conversion_utils.conversion_steps.rearrange_tensor_conversion import ( + RearrangeTensorConversion, +) +from transformer_lens.conversion_utils.param_processing_conversion import ( + ParamProcessingConversion, +) +from transformer_lens.factories.architecture_adapter_factory import ( + SUPPORTED_ARCHITECTURES, + ArchitectureAdapterFactory, +) +from transformer_lens.model_bridge.generalized_components import ( + BlockBridge, + EmbeddingBridge, + GatedMLPBridge, + LinearBridge, + MoEBridge, + PositionEmbeddingsAttentionBridge, + RMSNormalizationBridge, + RotaryEmbeddingBridge, + UnembeddingBridge, +) +from transformer_lens.model_bridge.supported_architectures.qwen3_moe import ( + Qwen3MoeArchitectureAdapter, +) + + +@pytest.fixture(scope="class") +def cfg() -> TransformerBridgeConfig: + return TransformerBridgeConfig( + d_model=64, + d_head=16, + n_layers=2, + n_ctx=128, + n_heads=4, + n_key_value_heads=2, + d_vocab=256, + architecture="Qwen3MoeForCausalLM", + ) + + +@pytest.fixture(scope="class") +def adapter(cfg: TransformerBridgeConfig) -> Qwen3MoeArchitectureAdapter: + return Qwen3MoeArchitectureAdapter(cfg) + + +class TestQwen3MoeAdapterConfig: + def test_normalization_type_is_rms(self, adapter: Qwen3MoeArchitectureAdapter) -> None: + assert adapter.cfg.normalization_type == "RMS" + + def test_positional_embedding_type_is_rotary( + self, adapter: Qwen3MoeArchitectureAdapter + ) -> None: + assert adapter.cfg.positional_embedding_type == "rotary" + + def test_final_rms_is_true(self, adapter: Qwen3MoeArchitectureAdapter) -> None: + """OLMoE sets final_rms=False; Qwen3MoE must not drift to that.""" + assert adapter.cfg.final_rms is True + + def test_gated_mlp_is_true(self, adapter: Qwen3MoeArchitectureAdapter) -> None: + assert adapter.cfg.gated_mlp is True + + def test_uses_rms_norm_is_true(self, adapter: Qwen3MoeArchitectureAdapter) -> None: + assert adapter.cfg.uses_rms_norm is True + + def test_attn_implementation_is_eager(self, adapter: Qwen3MoeArchitectureAdapter) -> None: + assert adapter.cfg.attn_implementation == "eager" + + def test_default_prepend_bos_is_false(self, adapter: Qwen3MoeArchitectureAdapter) -> None: + assert adapter.cfg.default_prepend_bos is False + + def test_n_kv_heads_propagated(self) -> None: + cfg = TransformerBridgeConfig( + d_model=64, + d_head=16, + n_layers=2, + n_ctx=128, + n_heads=4, + n_key_value_heads=2, + d_vocab=256, + architecture="Qwen3MoeForCausalLM", + ) + adapter = Qwen3MoeArchitectureAdapter(cfg) + assert adapter.cfg.n_key_value_heads == 2 + + +class TestQwen3MoeWeightConversions: + def test_has_qkvo_keys(self, adapter: Qwen3MoeArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + assert convs is not None + assert "blocks.{i}.attn.q.weight" in convs + assert "blocks.{i}.attn.k.weight" in convs + assert "blocks.{i}.attn.v.weight" in convs + assert "blocks.{i}.attn.o.weight" in convs + + def test_q_rearrange_uses_n_heads(self, adapter: Qwen3MoeArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + assert convs is not None + q_conv = convs["blocks.{i}.attn.q.weight"] + assert isinstance(q_conv, ParamProcessingConversion) + assert isinstance(q_conv.tensor_conversion, RearrangeTensorConversion) + axes = q_conv.tensor_conversion.axes_lengths + assert axes.get("n") == 4 + + def test_kv_rearrange_uses_n_kv_heads(self, adapter: Qwen3MoeArchitectureAdapter) -> None: + """GQA: K/V follow n_key_value_heads (2), not n_heads.""" + convs = adapter.weight_processing_conversions + assert convs is not None + k_conv = convs["blocks.{i}.attn.k.weight"] + v_conv = convs["blocks.{i}.attn.v.weight"] + assert isinstance(k_conv, ParamProcessingConversion) + assert isinstance(v_conv, ParamProcessingConversion) + assert isinstance(k_conv.tensor_conversion, RearrangeTensorConversion) + assert isinstance(v_conv.tensor_conversion, RearrangeTensorConversion) + assert k_conv.tensor_conversion.axes_lengths.get("n") == 2 + assert v_conv.tensor_conversion.axes_lengths.get("n") == 2 + + def test_o_rearrange_uses_n_heads(self, adapter: Qwen3MoeArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + assert convs is not None + o_conv = convs["blocks.{i}.attn.o.weight"] + assert isinstance(o_conv, ParamProcessingConversion) + assert isinstance(o_conv.tensor_conversion, RearrangeTensorConversion) + assert o_conv.tensor_conversion.axes_lengths.get("n") == 4 + + +class TestQwen3MoeComponentMapping: + def test_has_required_top_level_keys(self, adapter: Qwen3MoeArchitectureAdapter) -> None: + mapping = adapter.component_mapping + assert mapping is not None + for key in ("embed", "rotary_emb", "blocks", "ln_final", "unembed"): + assert key in mapping, f"Missing top-level key: {key!r}" + + def test_blocks_has_required_submodules(self, adapter: Qwen3MoeArchitectureAdapter) -> None: + mapping = adapter.component_mapping + assert mapping is not None + blocks = mapping["blocks"] + for key in ("ln1", "ln2", "attn", "mlp"): + assert key in blocks.submodules, f"Missing blocks submodule: {key!r}" + + def test_attn_has_all_submodules(self, adapter: Qwen3MoeArchitectureAdapter) -> None: + mapping = adapter.component_mapping + assert mapping is not None + attn = mapping["blocks"].submodules["attn"] + for key in ("q", "k", "v", "o", "q_norm", "k_norm"): + assert key in attn.submodules, f"Missing attn submodule: {key!r}" + + def test_ln1_ln2_are_rms_norm_bridges(self, adapter: Qwen3MoeArchitectureAdapter) -> None: + mapping = adapter.component_mapping + assert mapping is not None + subs = mapping["blocks"].submodules + assert isinstance(subs["ln1"], RMSNormalizationBridge) + assert isinstance(subs["ln2"], RMSNormalizationBridge) + + def test_mlp_is_moe_bridge(self, adapter: Qwen3MoeArchitectureAdapter) -> None: + mapping = adapter.component_mapping + assert mapping is not None + mlp = mapping["blocks"].submodules["mlp"] + assert isinstance(mlp, MoEBridge) + + def test_mlp_has_gate_submodule(self, adapter: Qwen3MoeArchitectureAdapter) -> None: + mapping = adapter.component_mapping + assert mapping is not None + mlp = mapping["blocks"].submodules["mlp"] + assert "gate" in mlp.submodules + + def test_q_norm_k_norm_are_rms_norm_bridges(self, adapter: Qwen3MoeArchitectureAdapter) -> None: + mapping = adapter.component_mapping + assert mapping is not None + attn_subs = mapping["blocks"].submodules["attn"].submodules + assert isinstance(attn_subs["q_norm"], RMSNormalizationBridge) + assert isinstance(attn_subs["k_norm"], RMSNormalizationBridge) + + def test_hf_module_paths(self, adapter: Qwen3MoeArchitectureAdapter) -> None: + mapping = adapter.component_mapping + assert mapping is not None + assert mapping["embed"].name == "model.embed_tokens" + assert mapping["ln_final"].name == "model.norm" + assert mapping["unembed"].name == "lm_head" + assert mapping["blocks"].name == "model.layers" + subs = mapping["blocks"].submodules + assert subs["ln1"].name == "input_layernorm" + assert subs["ln2"].name == "post_attention_layernorm" + assert subs["attn"].name == "self_attn" + assert subs["mlp"].name == "mlp" + + +class TestQwen3MoeFactoryRegistration: + def test_factory_lookup_returns_adapter_class(self) -> None: + assert SUPPORTED_ARCHITECTURES["Qwen3MoeForCausalLM"] is Qwen3MoeArchitectureAdapter + + def test_factory_selects_correct_adapter(self) -> None: + cfg = TransformerBridgeConfig( + d_model=64, + d_head=16, + n_layers=2, + n_ctx=128, + n_heads=4, + n_key_value_heads=2, + d_vocab=256, + architecture="Qwen3MoeForCausalLM", + ) + adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg) + assert isinstance(adapter, Qwen3MoeArchitectureAdapter) + + +class TestQwen3MoeComponentTypes: + """Top-level bridge classes — guards against silent type substitution.""" + + def test_embed_is_embedding_bridge(self, adapter: Qwen3MoeArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["embed"], EmbeddingBridge) + + def test_rotary_emb_is_rotary_bridge(self, adapter: Qwen3MoeArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["rotary_emb"], RotaryEmbeddingBridge) + + def test_blocks_is_block_bridge(self, adapter: Qwen3MoeArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["blocks"], BlockBridge) + + def test_ln_final_is_rms_norm_bridge(self, adapter: Qwen3MoeArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["ln_final"], RMSNormalizationBridge) + + def test_unembed_is_unembedding_bridge(self, adapter: Qwen3MoeArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["unembed"], UnembeddingBridge) + + +class TestQwen3MoeBlockSubmodules: + """BlockBridge submodule types and HF paths.""" + + def test_attn_is_position_embeddings_attention( + self, adapter: Qwen3MoeArchitectureAdapter + ) -> None: + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert isinstance(attn, PositionEmbeddingsAttentionBridge) + + def test_attn_requires_attention_mask_and_position_embeddings( + self, adapter: Qwen3MoeArchitectureAdapter + ) -> None: + """RoPE attention requires both an attention mask and position embeddings.""" + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert attn.requires_attention_mask is True + assert attn.requires_position_embeddings is True + + def test_attn_qkvo_submodule_types_and_paths( + self, adapter: Qwen3MoeArchitectureAdapter + ) -> None: + attn = adapter.component_mapping["blocks"].submodules["attn"] + for sub_name, expected_path in ( + ("q", "q_proj"), + ("k", "k_proj"), + ("v", "v_proj"), + ("o", "o_proj"), + ): + sub = attn.submodules[sub_name] + assert isinstance(sub, LinearBridge) + assert sub.name == expected_path + + def test_attn_q_norm_k_norm_paths(self, adapter: Qwen3MoeArchitectureAdapter) -> None: + """Per-head Q/K-norm RMSNorm.""" + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert attn.submodules["q_norm"].name == "q_norm" + assert attn.submodules["k_norm"].name == "k_norm" + + def test_mlp_gate_submodule_type(self, adapter: Qwen3MoeArchitectureAdapter) -> None: + """Router is a LinearBridge so the routing logits can be hooked.""" + mlp = adapter.component_mapping["blocks"].submodules["mlp"] + assert isinstance(mlp.submodules["gate"], LinearBridge) + + +class TestQwen3MoeWeightConversionPatterns: + """Rearrange patterns on weight conversions.""" + + def test_qkv_pattern_is_split_heads(self, adapter: Qwen3MoeArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + for slot in ("q", "k", "v"): + conv = convs[f"blocks.{{i}}.attn.{slot}.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.pattern == "(n h) m -> n m h" + + def test_o_pattern_is_merge_heads(self, adapter: Qwen3MoeArchitectureAdapter) -> None: + conv = adapter.weight_processing_conversions["blocks.{i}.attn.o.weight"] + assert conv.tensor_conversion.pattern == "m (n h) -> n h m" + + +class TestQwen3MoeGQA: + """GQA: K/V follow n_key_value_heads; Q/O always follow n_heads.""" + + def test_no_gqa_fallback_to_n_heads(self) -> None: + """Without n_key_value_heads, K/V fall back to n_heads.""" + cfg = TransformerBridgeConfig( + d_model=64, + d_head=16, + n_layers=2, + n_ctx=128, + n_heads=4, + d_vocab=256, + architecture="Qwen3MoeForCausalLM", + ) + adapter = Qwen3MoeArchitectureAdapter(cfg) + for slot in ("k", "v"): + conv = adapter.weight_processing_conversions[f"blocks.{{i}}.attn.{slot}.weight"] + assert conv.tensor_conversion.axes_lengths["n"] == 4 + + def test_gqa_does_not_affect_q_or_o(self, adapter: Qwen3MoeArchitectureAdapter) -> None: + q_conv = adapter.weight_processing_conversions["blocks.{i}.attn.q.weight"] + o_conv = adapter.weight_processing_conversions["blocks.{i}.attn.o.weight"] + assert q_conv.tensor_conversion.axes_lengths["n"] == 4 + assert o_conv.tensor_conversion.axes_lengths["n"] == 4 + + +class TestQwen3MoeMoEStructure: + """MoE structural invariants distinguishing Qwen3MoE from dense Qwen3.""" + + def test_mlp_is_moe_not_gated_mlp(self, adapter: Qwen3MoeArchitectureAdapter) -> None: + mlp = adapter.component_mapping["blocks"].submodules["mlp"] + assert isinstance(mlp, MoEBridge) + assert not isinstance(mlp, GatedMLPBridge) + + def test_mlp_has_only_gate_submodule(self, adapter: Qwen3MoeArchitectureAdapter) -> None: + """Experts are batched 3D tensors inside the MoE block — only the router is mapped.""" + mlp = adapter.component_mapping["blocks"].submodules["mlp"] + assert set(mlp.submodules.keys()) == {"gate"} + + +class TestQwen3MoeArchitectureGuards: + """Guards against drift from Qwen3 conventions.""" + + def test_no_norm_offset_conversions(self, adapter: Qwen3MoeArchitectureAdapter) -> None: + """LLaMA-style RMSNorm — no +1 offset like Gemma.""" + for key in adapter.weight_processing_conversions: + assert "ln1" not in key + assert "ln2" not in key + assert "ln_final" not in key + + def test_weight_conversions_are_only_qkvo( + self, adapter: Qwen3MoeArchitectureAdapter + ) -> None: + """Expert/gate weights pass through untouched.""" + assert set(adapter.weight_processing_conversions.keys()) == { + "blocks.{i}.attn.q.weight", + "blocks.{i}.attn.k.weight", + "blocks.{i}.attn.v.weight", + "blocks.{i}.attn.o.weight", + } + + def test_attn_is_not_optional(self, adapter: Qwen3MoeArchitectureAdapter) -> None: + """Non-hybrid: every layer has self_attn.""" + attn = adapter.component_mapping["blocks"].submodules["attn"] + assert getattr(attn, "optional", False) is False + + def test_no_linear_attn_submodule(self, adapter: Qwen3MoeArchitectureAdapter) -> None: + """Non-hybrid: no GatedDeltaNet linear-attention submodule.""" + submodules = adapter.component_mapping["blocks"].submodules + assert "linear_attn" not in submodules diff --git a/tests/unit/test_qwen3_next_adapter.py b/tests/unit/model_bridge/supported_architectures/test_qwen3_next_adapter.py similarity index 63% rename from tests/unit/test_qwen3_next_adapter.py rename to tests/unit/model_bridge/supported_architectures/test_qwen3_next_adapter.py index 516d7a8b5..9829ac6a6 100644 --- a/tests/unit/test_qwen3_next_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_qwen3_next_adapter.py @@ -1,15 +1,8 @@ -"""Unit tests for the Qwen3Next architecture adapter (Phases A through D). - -Tests cover: -1. Registration: adapter importable, in SUPPORTED_ARCHITECTURES, in HF_SUPPORTED_ARCHITECTURES -2. Component mapping: correct bridge hierarchy with only universal submodules (no self_attn) -3. Weight conversions: preprocess_weights correctly slices q_proj.weight per-head -4. Integration: end-to-end tests with a tiny programmatically-constructed model - -Note: Qwen3Next is supported only via TransformerBridge, not HookedTransformer. -No tests exercise convert_hf_model_config here — the TransformerBridge path -reads the HF config directly via the adapter and does not go through -transformer_lens.loading_from_pretrained. +"""Unit tests for the Qwen3Next architecture adapter. + +Qwen3Next is supported only via TransformerBridge, not HookedTransformer. +The bridge reads HF config directly via the adapter and bypasses +transformer_lens.loading_from_pretrained, so no convert_hf_model_config tests here. """ import pytest @@ -19,16 +12,10 @@ ) from transformer_lens.tools.model_registry import HF_SUPPORTED_ARCHITECTURES -# ============================================================================ -# Test: Registration -# ============================================================================ - - class TestQwen3NextRegistration: - """Verify the adapter is properly registered in all lookup tables.""" + """Adapter is registered in all lookup tables.""" def test_adapter_importable(self): - """Qwen3NextArchitectureAdapter must be importable.""" from transformer_lens.model_bridge.supported_architectures import ( Qwen3NextArchitectureAdapter, ) @@ -36,15 +23,12 @@ def test_adapter_importable(self): assert Qwen3NextArchitectureAdapter is not None def test_in_supported_architectures(self): - """Qwen3NextForCausalLM must be in SUPPORTED_ARCHITECTURES.""" assert "Qwen3NextForCausalLM" in SUPPORTED_ARCHITECTURES def test_in_hf_supported_architectures(self): - """Qwen3NextForCausalLM must be in HF_SUPPORTED_ARCHITECTURES.""" assert "Qwen3NextForCausalLM" in HF_SUPPORTED_ARCHITECTURES def test_adapter_class_correct(self): - """The adapter class must be Qwen3NextArchitectureAdapter.""" from transformer_lens.model_bridge.supported_architectures import ( Qwen3NextArchitectureAdapter, ) @@ -52,13 +36,8 @@ def test_adapter_class_correct(self): assert SUPPORTED_ARCHITECTURES["Qwen3NextForCausalLM"] is Qwen3NextArchitectureAdapter -# ============================================================================ -# Helpers: TransformerBridgeConfig for adapter instantiation -# ============================================================================ - - def _make_bridge_cfg(**overrides): - """Create a minimal TransformerBridgeConfig for Qwen3Next adapter tests.""" + """Minimal TransformerBridgeConfig for Qwen3Next adapter tests.""" from transformer_lens.config.TransformerBridgeConfig import TransformerBridgeConfig defaults = dict( @@ -75,19 +54,8 @@ def _make_bridge_cfg(**overrides): return TransformerBridgeConfig(**defaults) -# ============================================================================ -# Test: Component Mapping (Phase B) -# ============================================================================ - - class TestQwen3NextComponentMapping: - """Verify the component_mapping structure for Qwen3Next. - - The key invariant: self_attn is NOT mapped as a block submodule because - linear-attention layers lack self_attn, and get_remote_component raises - AttributeError for missing attributes (verified in architecture_adapter.py). - Only universally present submodules (norms, MLP) are mapped. - """ + """self_attn is not a block submodule (absent on linear-attn layers); only universal subs mapped.""" @pytest.fixture def adapter(self): @@ -98,10 +66,7 @@ def adapter(self): cfg = _make_bridge_cfg() return Qwen3NextArchitectureAdapter(cfg) - # ---- Top-level keys ---- - def test_component_mapping_keys(self, adapter): - """component_mapping must have exactly the expected top-level keys.""" assert set(adapter.component_mapping.keys()) == { "embed", "rotary_emb", @@ -110,47 +75,36 @@ def test_component_mapping_keys(self, adapter): "unembed", } - # ---- HF path names ---- - def test_embed_path(self, adapter): - """embed maps to model.embed_tokens.""" assert adapter.component_mapping["embed"].name == "model.embed_tokens" def test_rotary_emb_path(self, adapter): - """rotary_emb maps to model.rotary_emb.""" assert adapter.component_mapping["rotary_emb"].name == "model.rotary_emb" def test_blocks_path(self, adapter): - """blocks maps to model.layers.""" assert adapter.component_mapping["blocks"].name == "model.layers" def test_ln_final_path(self, adapter): - """ln_final maps to model.norm.""" assert adapter.component_mapping["ln_final"].name == "model.norm" def test_unembed_path(self, adapter): - """unembed maps to lm_head.""" assert adapter.component_mapping["unembed"].name == "lm_head" - # ---- Block submodules ---- - def test_block_submodules_keys(self, adapter): - """blocks submodules must contain ln1, ln2, mlp, and optional attn + linear_attn.""" submodules = adapter.component_mapping["blocks"].submodules assert set(submodules.keys()) == {"ln1", "ln2", "mlp", "attn", "linear_attn"} def test_attn_is_optional(self, adapter): - """attn must be marked optional (absent on linear-attention layers).""" + """attn is absent on linear-attention layers.""" submodules = adapter.component_mapping["blocks"].submodules assert submodules["attn"].optional is True def test_linear_attn_is_optional(self, adapter): - """linear_attn must be marked optional (absent on full-attention layers).""" + """linear_attn is absent on full-attention layers.""" submodules = adapter.component_mapping["blocks"].submodules assert submodules["linear_attn"].optional is True def test_linear_attn_bridge_type(self, adapter): - """linear_attn must be a GatedDeltaNetBridge.""" from transformer_lens.model_bridge.generalized_components.gated_delta_net import ( GatedDeltaNetBridge, ) @@ -159,45 +113,30 @@ def test_linear_attn_bridge_type(self, adapter): assert isinstance(submodules["linear_attn"], GatedDeltaNetBridge) def test_ln1_path(self, adapter): - """ln1 maps to input_layernorm.""" submodules = adapter.component_mapping["blocks"].submodules assert submodules["ln1"].name == "input_layernorm" def test_ln2_path(self, adapter): - """ln2 maps to post_attention_layernorm.""" submodules = adapter.component_mapping["blocks"].submodules assert submodules["ln2"].name == "post_attention_layernorm" def test_mlp_path(self, adapter): - """mlp maps to mlp.""" submodules = adapter.component_mapping["blocks"].submodules assert submodules["mlp"].name == "mlp" - # ---- MLP submodules ---- - def test_mlp_has_no_submodules(self, adapter): - """mlp is a MoEBridge with no enumerated submodules. - - Real Qwen3Next checkpoints use Qwen3NextSparseMoeBlock whose router - (`gate`) is a Qwen3NextTopKRouter rather than nn.Linear, and whose - experts are batched as 3D tensors inside Qwen3NextExperts. MoEBridge - wraps the whole block and delegates to HF's native forward, so no - internal submodules are mapped here. - """ + """Qwen3NextSparseMoeBlock has a non-Linear router and 3D batched experts; MoEBridge delegates to HF forward, so no internal subs are mapped.""" mlp = adapter.component_mapping["blocks"].submodules["mlp"] assert mlp.submodules == {} - # ---- Bridge types ---- - def test_mlp_bridge_type(self, adapter): - """mlp uses MoEBridge (sparse MoE on every real checkpoint).""" + """Every real checkpoint is sparse MoE.""" from transformer_lens.model_bridge.generalized_components import MoEBridge mlp = adapter.component_mapping["blocks"].submodules["mlp"] assert isinstance(mlp, MoEBridge) def test_ln1_bridge_type(self, adapter): - """ln1 uses RMSNormalizationBridge.""" from transformer_lens.model_bridge.generalized_components import ( RMSNormalizationBridge, ) @@ -206,7 +145,6 @@ def test_ln1_bridge_type(self, adapter): assert isinstance(ln1, RMSNormalizationBridge) def test_ln2_bridge_type(self, adapter): - """ln2 uses RMSNormalizationBridge.""" from transformer_lens.model_bridge.generalized_components import ( RMSNormalizationBridge, ) @@ -215,42 +153,24 @@ def test_ln2_bridge_type(self, adapter): assert isinstance(ln2, RMSNormalizationBridge) def test_blocks_bridge_type(self, adapter): - """blocks uses BlockBridge.""" from transformer_lens.model_bridge.generalized_components import BlockBridge assert isinstance(adapter.component_mapping["blocks"], BlockBridge) def test_rotary_emb_bridge_type(self, adapter): - """rotary_emb uses RotaryEmbeddingBridge.""" from transformer_lens.model_bridge.generalized_components import ( RotaryEmbeddingBridge, ) assert isinstance(adapter.component_mapping["rotary_emb"], RotaryEmbeddingBridge) - # ---- weight_processing_conversions ---- - def test_weight_processing_conversions_empty(self, adapter): - """weight_processing_conversions is empty (no attention submodules mapped).""" + """No attention submodules mapped, so no conversions.""" assert adapter.weight_processing_conversions == {} -# ============================================================================ -# Test: Weight Conversions (Phase C) -# ============================================================================ - - class TestQwen3NextWeightConversions: - """Verify preprocess_weights correctly slices q_proj.weight per-head. - - Background: In Qwen3Next, q_proj.weight has shape (n_heads * head_dim * 2, hidden_size) - where rows are organized as interleaved per-head pairs: - head_0_query (d_head rows), head_0_gate (d_head rows), - head_1_query (d_head rows), head_1_gate (d_head rows), ... - - A naive first-half slice would be wrong. The correct approach reshapes by - head and takes only the first d_head rows per head (the query half). - """ + """q_proj rows are interleaved per-head (query, gate, query, gate, ...) — naive first-half slice is wrong.""" N_HEADS = 4 D_HEAD = 8 @@ -266,17 +186,11 @@ def adapter(self): n_heads=self.N_HEADS, d_head=self.D_HEAD, d_model=self.HIDDEN_SIZE, - n_key_value_heads=self.N_HEADS, # MHA for simplicity + n_key_value_heads=self.N_HEADS, ) return Qwen3NextArchitectureAdapter(cfg) def _make_q_proj_weight(self): - """Create a q_proj.weight tensor with distinct per-head-row values. - - Shape: (n_heads * d_head * 2, hidden_size) - Each row is filled with a unique integer so we can verify which rows - were selected after slicing. - """ import torch total_rows = self.N_HEADS * self.D_HEAD * 2 @@ -286,7 +200,6 @@ def _make_q_proj_weight(self): return w def test_q_proj_output_shape(self, adapter): - """preprocess_weights reduces q_proj rows from n_heads*d_head*2 to n_heads*d_head.""" import torch w = self._make_q_proj_weight() @@ -298,12 +211,6 @@ def test_q_proj_output_shape(self, adapter): assert out.shape == (self.N_HEADS * self.D_HEAD, self.HIDDEN_SIZE) def test_q_proj_selects_query_rows_not_naive_first_half(self, adapter): - """For each head i, output rows [i*d_head : (i+1)*d_head] == input rows - [i*d_head*2 : i*d_head*2 + d_head]. - - This verifies the per-head reshape: a naive slice of the first half would - incorrectly include gate rows from later heads. - """ import torch w = self._make_q_proj_weight() @@ -314,7 +221,6 @@ def test_q_proj_selects_query_rows_not_naive_first_half(self, adapter): for head_idx in range(self.N_HEADS): out_rows = out[head_idx * self.D_HEAD : (head_idx + 1) * self.D_HEAD] - # Per-head interleaved layout: query rows for head i start at i*(d_head*2) expected_start = head_idx * self.D_HEAD * 2 expected_rows = w[expected_start : expected_start + self.D_HEAD] assert torch.equal(out_rows, expected_rows), ( @@ -324,11 +230,6 @@ def test_q_proj_selects_query_rows_not_naive_first_half(self, adapter): ) def test_naive_slice_would_be_wrong(self, adapter): - """Demonstrate that a naive first-half slice gives different (wrong) results. - - This documents the correctness invariant: the interleaved layout means - naive slicing includes gate rows from intermediate heads. - """ import torch w = self._make_q_proj_weight() @@ -337,10 +238,8 @@ def test_naive_slice_would_be_wrong(self, adapter): result = adapter.preprocess_weights(state_dict) correct_out = result["model.layers.0.self_attn.q_proj.weight"] - # Naive first half: just take the top n_heads*d_head rows naive_out = w[: self.N_HEADS * self.D_HEAD] - # They should differ (unless n_heads==1, where both produce the same result) if self.N_HEADS > 1: assert not torch.equal(correct_out, naive_out), ( "Naive first-half slice gave the same result as per-head slice — " @@ -348,7 +247,6 @@ def test_naive_slice_would_be_wrong(self, adapter): ) def test_non_q_proj_weights_unchanged(self, adapter): - """k_proj, v_proj, and down_proj weights are NOT modified by preprocess_weights.""" import torch k_proj = torch.randn(self.N_HEADS * self.D_HEAD, self.HIDDEN_SIZE) @@ -364,11 +262,10 @@ def test_non_q_proj_weights_unchanged(self, adapter): assert torch.equal(result["model.layers.0.mlp.down_proj.weight"], down_proj) def test_multiple_layers_all_processed(self, adapter): - """q_proj.weight tensors across multiple layers are all sliced correctly.""" import torch w0 = self._make_q_proj_weight() - w3 = self._make_q_proj_weight() * 2 # distinct values to catch cross-layer bugs + w3 = self._make_q_proj_weight() * 2 state_dict = { "model.layers.0.self_attn.q_proj.weight": w0, @@ -382,12 +279,10 @@ def test_multiple_layers_all_processed(self, adapter): assert result["model.layers.3.self_attn.q_proj.weight"].shape == expected_shape def test_empty_state_dict_returns_empty(self, adapter): - """preprocess_weights with an empty state dict returns an empty dict.""" result = adapter.preprocess_weights({}) assert result == {} def test_state_dict_without_q_proj_unchanged(self, adapter): - """A state dict with no q_proj keys is returned unmodified.""" import torch state_dict = { @@ -400,13 +295,164 @@ def test_state_dict_without_q_proj_unchanged(self, adapter): assert set(result.keys()) == original_keys def test_weight_processing_conversions_is_empty_dict(self, adapter): - """weight_processing_conversions is {} — q_proj slicing is done in preprocess_weights.""" + """q_proj slicing happens in preprocess_weights, not as a conversion.""" assert adapter.weight_processing_conversions == {} -# ============================================================================ -# Test: Integration (Phase D) -# ============================================================================ +class TestQwen3NextConfigAttributes: + """cfg attributes set by the adapter.""" + + @pytest.fixture + def adapter(self): + from transformer_lens.model_bridge.supported_architectures.qwen3_next import ( + Qwen3NextArchitectureAdapter, + ) + + return Qwen3NextArchitectureAdapter(_make_bridge_cfg()) + + def test_normalization_type(self, adapter): + assert adapter.cfg.normalization_type == "RMS" + + def test_positional_embedding_type(self, adapter): + assert adapter.cfg.positional_embedding_type == "rotary" + + def test_final_rms(self, adapter): + assert adapter.cfg.final_rms is True + + def test_gated_mlp(self, adapter): + assert adapter.cfg.gated_mlp is True + + def test_attn_only(self, adapter): + assert adapter.cfg.attn_only is False + + def test_uses_rms_norm(self, adapter): + assert adapter.cfg.uses_rms_norm is True + + def test_default_prepend_bos(self, adapter): + assert adapter.cfg.default_prepend_bos is False + + def test_attn_implementation_eager(self, adapter): + assert adapter.cfg.attn_implementation == "eager" + + def test_supports_fold_ln_false(self, adapter): + """Hybrid layers break fold_ln.""" + assert adapter.supports_fold_ln is False + + def test_gated_q_proj_flag_set(self, adapter): + """Flag drives preprocess_weights to slice the gated half of q_proj.""" + assert getattr(adapter.cfg, "gated_q_proj", False) is True + + def test_n_key_value_heads_propagates(self, adapter): + assert adapter.cfg.n_key_value_heads == 2 + + +class TestQwen3NextComponentTypes: + """Top-level bridge classes — guards against silent type substitution.""" + + @pytest.fixture + def adapter(self): + from transformer_lens.model_bridge.supported_architectures.qwen3_next import ( + Qwen3NextArchitectureAdapter, + ) + + return Qwen3NextArchitectureAdapter(_make_bridge_cfg()) + + def test_embed_is_embedding_bridge(self, adapter): + from transformer_lens.model_bridge.generalized_components import EmbeddingBridge + + assert isinstance(adapter.component_mapping["embed"], EmbeddingBridge) + + def test_ln_final_is_rms_norm_bridge(self, adapter): + from transformer_lens.model_bridge.generalized_components import ( + RMSNormalizationBridge, + ) + + assert isinstance(adapter.component_mapping["ln_final"], RMSNormalizationBridge) + + def test_unembed_is_unembedding_bridge(self, adapter): + from transformer_lens.model_bridge.generalized_components import ( + UnembeddingBridge, + ) + + assert isinstance(adapter.component_mapping["unembed"], UnembeddingBridge) + + +class TestQwen3NextAttnSubmodules: + """Full-attention layers wire Qwen3-pattern submodules; gated q_proj half is pre-sliced.""" + + @pytest.fixture + def attn(self): + from transformer_lens.model_bridge.supported_architectures.qwen3_next import ( + Qwen3NextArchitectureAdapter, + ) + + adapter = Qwen3NextArchitectureAdapter(_make_bridge_cfg()) + return adapter.component_mapping["blocks"].submodules["attn"] + + def test_attn_is_position_embeddings_attention(self, attn): + from transformer_lens.model_bridge.generalized_components.position_embeddings_attention import ( + PositionEmbeddingsAttentionBridge, + ) + + assert isinstance(attn, PositionEmbeddingsAttentionBridge) + + def test_attn_path(self, attn): + assert attn.name == "self_attn" + + def test_attn_qkvo_submodule_paths(self, attn): + from transformer_lens.model_bridge.generalized_components import LinearBridge + + for sub_name, expected_path in ( + ("q", "q_proj"), + ("k", "k_proj"), + ("v", "v_proj"), + ("o", "o_proj"), + ): + sub = attn.submodules[sub_name] + assert isinstance(sub, LinearBridge) + assert sub.name == expected_path + + def test_attn_q_norm_k_norm_present(self, attn): + """Qwen3 family uses per-head Q/K RMSNorm.""" + from transformer_lens.model_bridge.generalized_components import ( + RMSNormalizationBridge, + ) + + assert isinstance(attn.submodules["q_norm"], RMSNormalizationBridge) + assert isinstance(attn.submodules["k_norm"], RMSNormalizationBridge) + assert attn.submodules["q_norm"].name == "q_norm" + assert attn.submodules["k_norm"].name == "k_norm" + + +class TestQwen3NextArchitectureGuards: + """Guards against drift from Qwen3 conventions.""" + + @pytest.fixture + def adapter(self): + from transformer_lens.model_bridge.supported_architectures.qwen3_next import ( + Qwen3NextArchitectureAdapter, + ) + + return Qwen3NextArchitectureAdapter(_make_bridge_cfg()) + + def test_no_norm_offset_conversions(self, adapter): + """LLaMA-style RMSNorm — no +1 offset like Gemma.""" + for key in adapter.weight_processing_conversions: + assert "ln1" not in key + assert "ln2" not in key + assert "ln_final" not in key + + def test_mlp_is_moe_not_gated(self, adapter): + """MoE, not the dense GatedMLP of Qwen3/Qwen3.5.""" + from transformer_lens.model_bridge.generalized_components import ( + GatedMLPBridge, + MoEBridge, + ) + + mlp = adapter.component_mapping["blocks"].submodules["mlp"] + assert isinstance(mlp, MoEBridge) + assert not isinstance(mlp, GatedMLPBridge) + try: from transformers import Qwen3NextConfig, Qwen3NextForCausalLM @@ -417,18 +463,7 @@ def test_weight_processing_conversions_is_empty_dict(self, adapter): def _make_tiny_hf_model(): - """Create a tiny Qwen3Next model for integration testing. - - Uses num_experts=4 (sparse MoE) to exercise the real production code path. - Every real Qwen3Next checkpoint has mlp_only_layers=[] and - decoder_sparse_step=1, so every decoder layer uses Qwen3NextSparseMoeBlock. - Test fixtures must mirror this or the adapter's MoE wiring goes untested. - - Config details: - - 8 layers: layers 3 and 7 are full-attention (full_attention_interval=4) - - All other layers are linear_attention - - sparse MoE MLP on all layers (num_experts=4, num_experts_per_tok=2) - """ + """Tiny Qwen3Next model: 8 layers (full-attn at 3, 7), sparse MoE on every layer to exercise the MoE path.""" cfg = Qwen3NextConfig( hidden_size=128, num_hidden_layers=8, @@ -463,7 +498,7 @@ def _make_tiny_hf_model(): def _make_tiny_bridge(): - """Create a Qwen3Next bridge from a tiny HF model.""" + """Build a Qwen3Next bridge from a tiny HF model.""" from unittest.mock import MagicMock from transformer_lens.config.TransformerBridgeConfig import TransformerBridgeConfig @@ -493,16 +528,10 @@ def _make_tiny_bridge(): reason="Qwen3NextForCausalLM not available in installed transformers", ) class TestQwen3NextIntegration: - """End-to-end integration tests using a tiny programmatic Qwen3Next model. - - Tests use num_experts=4 (sparse MoE) to exercise the real production code - path. The linear attention layers run via the torch fallback path when - flash-linear-attention / causal-conv1d are not installed. - """ + """End-to-end tests; linear-attn falls back to torch when flash-linear-attention is absent.""" @pytest.fixture(scope="class") def bridge_and_model(self): - """Create a tiny bridge + HF model pair, shared across the class.""" return _make_tiny_bridge() @pytest.fixture(scope="class") @@ -516,46 +545,30 @@ def hf_model(self, bridge_and_model): return hf def test_bridge_creation(self, bridge): - """TransformerBridge construction from a tiny Qwen3Next model must succeed.""" from transformer_lens.model_bridge import TransformerBridge assert isinstance(bridge, TransformerBridge) def test_hook_names_present(self, bridge): - """Key hook names must be present in the bridge hook_dict. - - Verified hook names: - - blocks.0.hook_resid_pre: present on linear-attention layer (layer 0) - - blocks.3.hook_resid_pre: present on first full-attention layer (layer 3) - - blocks.0.ln1.*: norm is present on all layers (universal submodule) - - blocks.0.mlp.*: MLP is present on all layers (universal submodule) - - Also verifies that blocks.0.attn.* is NOT present — self_attn is only on - full-attention layers, so it is NOT mapped as a block submodule. - """ + """blocks.0.attn.* must NOT appear — self_attn is absent on linear-attn layers.""" hook_keys = set(bridge.hook_dict.keys()) - # Block-level residual hooks exist on all layers assert "blocks.0.hook_resid_pre" in hook_keys, "linear-attn layer must have hook_resid_pre" assert "blocks.3.hook_resid_pre" in hook_keys, "full-attn layer must have hook_resid_pre" - # Norm hooks present on all layers assert any( "blocks.0.ln1" in k for k in hook_keys ), "blocks.0.ln1 submodule hooks must be present" - # MLP hooks present on all layers assert any( "blocks.0.mlp" in k for k in hook_keys ), "blocks.0.mlp submodule hooks must be present" - # No attn bridge — self_attn is absent on linear-attention layers assert not any( "blocks.0.attn" in k for k in hook_keys ), "blocks.0.attn hooks must NOT be present (hybrid architecture)" def test_forward_pass_consistency(self, bridge, hf_model): - """Bridge output logits must match HF model output logits to within atol=1e-4.""" import torch tokens = torch.randint(0, 512, (1, 4)) @@ -571,7 +584,6 @@ def test_forward_pass_consistency(self, bridge, hf_model): ), f"Logit mismatch: max diff = {(hf_logits - bridge_logits).abs().max().item():.6f}" def test_hook_activation_shapes(self, bridge): - """A hook added on blocks.0.mlp.hook_out must capture a (batch, seq, d_model) tensor.""" import torch captured: list[torch.Tensor] = [] diff --git a/tests/unit/model_bridge/supported_architectures/test_xglm_adapter.py b/tests/unit/model_bridge/supported_architectures/test_xglm_adapter.py index 73b68dbb4..3476e4962 100644 --- a/tests/unit/model_bridge/supported_architectures/test_xglm_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_xglm_adapter.py @@ -1,12 +1,4 @@ -"""Unit tests for XGLMArchitectureAdapter. - -Tests cover: -- Config attribute validation (all required attributes set correctly) [Phase A] -- Weight conversion keys and structure [Phase A] -- Component mapping structure (correct bridge types and HF module paths) [Phase B] -- Embedding scale hook compatibility [Phase C] -- Factory registration (XGLMForCausalLM maps to the right adapter) [Phase D] -""" +"""Unit tests for XGLMArchitectureAdapter: cfg, components, weight conversions, hook compat, factory.""" import math from types import SimpleNamespace @@ -15,10 +7,15 @@ import torch from transformer_lens.config import TransformerBridgeConfig +from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion +from transformer_lens.conversion_utils.param_processing_conversion import ( + ParamProcessingConversion, +) from transformer_lens.model_bridge.generalized_components import ( AttentionBridge, BlockBridge, EmbeddingBridge, + LinearBridge, NormalizationBridge, SymbolicBridge, UnembeddingBridge, @@ -27,11 +24,6 @@ XGLMArchitectureAdapter, ) -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - - def _make_cfg( n_heads: int = 4, d_model: int = 64, @@ -40,7 +32,7 @@ def _make_cfg( d_vocab: int = 1000, n_ctx: int = 512, ) -> TransformerBridgeConfig: - """Return a minimal TransformerBridgeConfig for XGLM adapter tests.""" + """Minimal TransformerBridgeConfig for XGLM adapter tests.""" return TransformerBridgeConfig( d_model=d_model, d_head=d_model // n_heads, @@ -54,23 +46,18 @@ def _make_cfg( ) -@pytest.fixture +@pytest.fixture(scope="class") def cfg() -> TransformerBridgeConfig: return _make_cfg() -@pytest.fixture +@pytest.fixture(scope="class") def adapter(cfg: TransformerBridgeConfig) -> XGLMArchitectureAdapter: return XGLMArchitectureAdapter(cfg) -# --------------------------------------------------------------------------- -# Phase A: Config attribute tests -# --------------------------------------------------------------------------- - - class TestXGLMAdapterConfig: - """Adapter must set all required config attributes to the correct values.""" + """Adapter sets all required config attributes.""" def test_normalization_type_is_ln(self, adapter: XGLMArchitectureAdapter) -> None: assert adapter.cfg.normalization_type == "LN" @@ -91,13 +78,8 @@ def test_uses_rms_norm_is_false(self, adapter: XGLMArchitectureAdapter) -> None: assert adapter.cfg.uses_rms_norm is False -# --------------------------------------------------------------------------- -# Phase A: Weight processing conversion tests -# --------------------------------------------------------------------------- - - class TestXGLMAdapterWeightConversions: - """Adapter must define exactly the four standard QKVO weight conversions.""" + """Adapter defines exactly the four standard QKVO weight conversions.""" def test_q_weight_key_present(self, adapter: XGLMArchitectureAdapter) -> None: assert "blocks.{i}.attn.q.weight" in adapter.weight_processing_conversions @@ -115,13 +97,8 @@ def test_exactly_four_conversion_keys(self, adapter: XGLMArchitectureAdapter) -> assert len(adapter.weight_processing_conversions) == 4 -# --------------------------------------------------------------------------- -# Phase B: Component mapping structure tests -# --------------------------------------------------------------------------- - - class TestXGLMAdapterComponentMapping: - """Component mapping must have the correct bridge types and HF module paths.""" + """component_mapping has correct bridge types and HF module paths.""" def test_embed_is_embedding_bridge(self, adapter: XGLMArchitectureAdapter) -> None: assert isinstance(adapter.component_mapping["embed"], EmbeddingBridge) @@ -130,7 +107,7 @@ def test_embed_name(self, adapter: XGLMArchitectureAdapter) -> None: assert adapter.component_mapping["embed"].name == "model.embed_tokens" def test_no_pos_embed_in_mapping(self, adapter: XGLMArchitectureAdapter) -> None: - # Sinusoidal embeddings have no weights — no bridge entry expected + # Sinusoidal embeddings have no weights — no bridge entry. assert "pos_embed" not in adapter.component_mapping def test_blocks_is_block_bridge(self, adapter: XGLMArchitectureAdapter) -> None: @@ -188,7 +165,7 @@ def test_attn_v_name(self, adapter: XGLMArchitectureAdapter) -> None: assert attn.submodules["v"].name == "v_proj" def test_attn_o_name_is_out_proj(self, adapter: XGLMArchitectureAdapter) -> None: - # Critical: XGLM uses out_proj, not o_proj (scaffold error pattern) + # XGLM uses out_proj, not o_proj (common scaffold mistake). attn = adapter.component_mapping["blocks"].submodules["attn"] assert attn.submodules["o"].name == "out_proj" @@ -213,20 +190,15 @@ def test_mlp_out_name(self, adapter: XGLMArchitectureAdapter) -> None: assert mlp.submodules["out"].name == "fc2" -# --------------------------------------------------------------------------- -# Phase C: Embedding scale hook compatibility tests -# --------------------------------------------------------------------------- - - def _make_mock_bridge() -> SimpleNamespace: - """Return a minimal mock bridge with embed.hook_out for hook-compat tests.""" + """Minimal mock bridge with embed.hook_out for hook-compat tests.""" hook_out = SimpleNamespace(hook_conversion=None) embed = SimpleNamespace(hook_out=hook_out) return SimpleNamespace(embed=embed) class TestXGLMAdapterHookCompatibility: - """setup_hook_compatibility must attach a scale conversion to hook_embed.""" + """setup_hook_compatibility attaches a scale conversion to hook_embed.""" def test_sets_hook_conversion_on_embed_hook_out(self, adapter: XGLMArchitectureAdapter) -> None: bridge = _make_mock_bridge() @@ -234,17 +206,15 @@ def test_sets_hook_conversion_on_embed_hook_out(self, adapter: XGLMArchitectureA assert bridge.embed.hook_out.hook_conversion is not None def test_scales_by_sqrt_d_model(self, adapter: XGLMArchitectureAdapter) -> None: - # d_model=64, sqrt(64)=8 exactly bridge = _make_mock_bridge() adapter.setup_hook_compatibility(bridge) conv = bridge.embed.hook_out.hook_conversion x = torch.ones(2, 4, 64) result = conv.handle_conversion(x) - expected_scale = math.sqrt(64) # 8.0 + expected_scale = math.sqrt(64) assert torch.allclose(result, x * expected_scale, atol=1e-6) def test_revert_inverts_scale(self, adapter: XGLMArchitectureAdapter) -> None: - # round-trip: revert(handle_conversion(x)) == x; exact for sqrt(64)=8 bridge = _make_mock_bridge() adapter.setup_hook_compatibility(bridge) conv = bridge.embed.hook_out.hook_conversion @@ -252,23 +222,16 @@ def test_revert_inverts_scale(self, adapter: XGLMArchitectureAdapter) -> None: assert torch.allclose(conv.revert(conv.handle_conversion(x)), x, atol=1e-6) def test_no_error_when_embed_missing(self, adapter: XGLMArchitectureAdapter) -> None: - # Guard: if bridge lacks embed, setup_hook_compatibility should not raise - bridge = SimpleNamespace() # no embed attribute - adapter.setup_hook_compatibility(bridge) # must not raise + bridge = SimpleNamespace() + adapter.setup_hook_compatibility(bridge) def test_no_error_when_hook_out_missing(self, adapter: XGLMArchitectureAdapter) -> None: - # Guard: if embed lacks hook_out, no error expected - bridge = SimpleNamespace(embed=SimpleNamespace()) # embed but no hook_out - adapter.setup_hook_compatibility(bridge) # must not raise - - -# --------------------------------------------------------------------------- -# Phase D: Factory registration tests -# --------------------------------------------------------------------------- + bridge = SimpleNamespace(embed=SimpleNamespace()) + adapter.setup_hook_compatibility(bridge) class TestXGLMFactoryRegistration: - """XGLMForCausalLM must be registered in SUPPORTED_ARCHITECTURES and resolve correctly.""" + """XGLMForCausalLM is registered in SUPPORTED_ARCHITECTURES and resolves correctly.""" def test_factory_returns_xglm_adapter(self) -> None: from transformer_lens.factories.architecture_adapter_factory import ( @@ -285,3 +248,168 @@ def test_factory_key_is_xglm_for_causal_lm(self) -> None: ) assert "XGLMForCausalLM" in SUPPORTED_ARCHITECTURES + + def test_factory_maps_to_correct_class(self) -> None: + from transformer_lens.factories.architecture_adapter_factory import ( + SUPPORTED_ARCHITECTURES, + ) + + assert SUPPORTED_ARCHITECTURES["XGLMForCausalLM"] is XGLMArchitectureAdapter + + +class TestXGLMComponentMappingPresence: + """Component slots exist (deletion guard).""" + + def test_has_embed(self, adapter: XGLMArchitectureAdapter) -> None: + assert "embed" in adapter.component_mapping + + def test_has_blocks(self, adapter: XGLMArchitectureAdapter) -> None: + assert "blocks" in adapter.component_mapping + + def test_has_ln_final(self, adapter: XGLMArchitectureAdapter) -> None: + assert "ln_final" in adapter.component_mapping + + def test_has_unembed(self, adapter: XGLMArchitectureAdapter) -> None: + assert "unembed" in adapter.component_mapping + + def test_all_expected_top_level_keys_present( + self, adapter: XGLMArchitectureAdapter + ) -> None: + # No top-level rotary_emb (sinusoidal) and no pos_embed (non-persistent). + expected = {"embed", "blocks", "ln_final", "unembed"} + assert set(adapter.component_mapping.keys()) == expected + + +class TestXGLMBlockSubmodules: + """Decoder BlockBridge wires XGLM-pattern submodules.""" + + @pytest.fixture(scope="class") + def blocks(self, adapter: XGLMArchitectureAdapter) -> BlockBridge: + return adapter.component_mapping["blocks"] + + def test_block_has_required_submodules(self, blocks: BlockBridge) -> None: + for name in ("ln1", "ln2", "attn", "mlp"): + assert name in blocks.submodules, f"BlockBridge missing submodule '{name}'" + + def test_ln1_is_normalization_bridge(self, blocks: BlockBridge) -> None: + ln1 = blocks.submodules["ln1"] + assert isinstance(ln1, NormalizationBridge) + assert ln1.name == "self_attn_layer_norm" + + def test_ln2_is_normalization_bridge(self, blocks: BlockBridge) -> None: + ln2 = blocks.submodules["ln2"] + assert isinstance(ln2, NormalizationBridge) + assert ln2.name == "final_layer_norm" + + def test_attn_is_attention_bridge(self, blocks: BlockBridge) -> None: + attn = blocks.submodules["attn"] + assert isinstance(attn, AttentionBridge) + assert attn.name == "self_attn" + # 4-D mask, no position embeddings (sinusoidal added pre-block). + assert attn.requires_attention_mask is True + assert attn.attention_mask_4d is True + + def test_attn_qkvo_submodules_are_linear_bridges(self, blocks: BlockBridge) -> None: + attn = blocks.submodules["attn"] + for sub_name, expected_path in ( + ("q", "q_proj"), + ("k", "k_proj"), + ("v", "v_proj"), + ("o", "out_proj"), + ): + sub = attn.submodules[sub_name] + assert isinstance(sub, LinearBridge), f"attn.{sub_name} must be LinearBridge" + assert sub.name == expected_path + + def test_mlp_is_symbolic_bridge(self, blocks: BlockBridge) -> None: + # fc1/fc2 live directly on the decoder layer — SymbolicBridge holds the TL shape. + mlp = blocks.submodules["mlp"] + assert isinstance(mlp, SymbolicBridge) + + def test_mlp_submodules_are_linear_bridges(self, blocks: BlockBridge) -> None: + mlp = blocks.submodules["mlp"] + for sub_name, expected_path in (("in", "fc1"), ("out", "fc2")): + sub = mlp.submodules[sub_name] + assert isinstance(sub, LinearBridge), f"mlp.{sub_name} must be LinearBridge" + assert sub.name == expected_path + + def test_mlp_has_no_gate(self, blocks: BlockBridge) -> None: + # Standard 2-layer MLP (fc1 -> gelu -> fc2), NOT gated. + mlp = blocks.submodules["mlp"] + assert "gate" not in mlp.submodules + + +class TestXGLMComponentTypes: + """Component bridge classes — guard against silent type substitution.""" + + def test_embed_type(self, adapter: XGLMArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["embed"], EmbeddingBridge) + + def test_blocks_type(self, adapter: XGLMArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["blocks"], BlockBridge) + + def test_ln_final_type(self, adapter: XGLMArchitectureAdapter) -> None: + # XGLM uses LayerNorm (not RMS). + assert isinstance( + adapter.component_mapping["ln_final"], NormalizationBridge + ) + + def test_unembed_type(self, adapter: XGLMArchitectureAdapter) -> None: + assert isinstance(adapter.component_mapping["unembed"], UnembeddingBridge) + + +class TestXGLMWeightConversionSemantics: + """QKVO conversion entries use the expected types and patterns.""" + + def test_q_conversion_type(self, adapter: XGLMArchitectureAdapter) -> None: + conv = adapter.weight_processing_conversions["blocks.{i}.attn.q.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + + def test_qkv_split_heads_pattern(self, adapter: XGLMArchitectureAdapter) -> None: + for slot in ("q", "k", "v"): + conv = adapter.weight_processing_conversions[f"blocks.{{i}}.attn.{slot}.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.pattern == "(n h) m -> n m h" + + def test_o_merge_heads_pattern(self, adapter: XGLMArchitectureAdapter) -> None: + conv = adapter.weight_processing_conversions["blocks.{i}.attn.o.weight"] + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.pattern == "m (n h) -> n h m" + + def test_qkvo_n_axis_equals_n_heads(self, adapter: XGLMArchitectureAdapter) -> None: + # MHA: K/V share n_heads with Q/O (no GQA on XGLM). + for slot in ("q", "k", "v", "o"): + conv = adapter.weight_processing_conversions[f"blocks.{{i}}.attn.{slot}.weight"] + assert conv.tensor_conversion.axes_lengths["n"] == adapter.cfg.n_heads + + +class TestXGLMArchitectureGuards: + """Guards against drift toward neighbouring adapter patterns.""" + + def test_no_gqa_setting(self, adapter: XGLMArchitectureAdapter) -> None: + # All published XGLM sizes are MHA. + assert getattr(adapter.cfg, "n_key_value_heads", None) is None + + def test_no_norm_offset_conversions(self, adapter: XGLMArchitectureAdapter) -> None: + # XGLM is not Gemma — no +1 norm offset entries. + for key in adapter.weight_processing_conversions: + assert "ln1" not in key + assert "ln2" not in key + assert "ln_final" not in key + + def test_no_mlp_weight_conversions(self, adapter: XGLMArchitectureAdapter) -> None: + for key in adapter.weight_processing_conversions: + assert "mlp" not in key + + def test_center_writing_weights_disabled(self, adapter: XGLMArchitectureAdapter) -> None: + # Sinusoidal pos_embed has no params → cannot center pos_embed. + assert adapter.supports_center_writing_weights is False + + def test_no_rotary_in_blocks(self, adapter: XGLMArchitectureAdapter) -> None: + blocks = adapter.component_mapping["blocks"] + assert "rotary_emb" not in blocks.submodules + + def test_no_top_level_rotary_emb(self, adapter: XGLMArchitectureAdapter) -> None: + assert "rotary_emb" not in adapter.component_mapping diff --git a/tests/unit/model_bridge/test_qwen3_moe_adapter.py b/tests/unit/model_bridge/test_qwen3_moe_adapter.py deleted file mode 100644 index af6a0155c..000000000 --- a/tests/unit/model_bridge/test_qwen3_moe_adapter.py +++ /dev/null @@ -1,194 +0,0 @@ -"""Unit tests for the Qwen3MoeArchitectureAdapter. - -All tests use programmatic TransformerBridgeConfig instances — no network access -or model downloads. -""" - -import pytest - -from transformer_lens.config import TransformerBridgeConfig -from transformer_lens.conversion_utils.conversion_steps.rearrange_tensor_conversion import ( - RearrangeTensorConversion, -) -from transformer_lens.conversion_utils.param_processing_conversion import ( - ParamProcessingConversion, -) -from transformer_lens.factories.architecture_adapter_factory import ( - SUPPORTED_ARCHITECTURES, -) -from transformer_lens.model_bridge.generalized_components import ( - MoEBridge, - RMSNormalizationBridge, -) -from transformer_lens.model_bridge.supported_architectures.qwen3_moe import ( - Qwen3MoeArchitectureAdapter, -) - - -@pytest.fixture -def cfg() -> TransformerBridgeConfig: - return TransformerBridgeConfig( - d_model=64, - d_head=16, - n_layers=2, - n_ctx=128, - n_heads=4, - n_key_value_heads=2, - d_vocab=256, - architecture="Qwen3MoeForCausalLM", - ) - - -@pytest.fixture -def adapter(cfg: TransformerBridgeConfig) -> Qwen3MoeArchitectureAdapter: - return Qwen3MoeArchitectureAdapter(cfg) - - -class TestQwen3MoeAdapterConfig: - def test_normalization_type_is_rms(self, adapter: Qwen3MoeArchitectureAdapter) -> None: - assert adapter.cfg.normalization_type == "RMS" - - def test_positional_embedding_type_is_rotary( - self, adapter: Qwen3MoeArchitectureAdapter - ) -> None: - assert adapter.cfg.positional_embedding_type == "rotary" - - def test_final_rms_is_true(self, adapter: Qwen3MoeArchitectureAdapter) -> None: - """Qwen3MoE uses final_rms=True; OLMoE uses False.""" - assert adapter.cfg.final_rms is True - - def test_gated_mlp_is_true(self, adapter: Qwen3MoeArchitectureAdapter) -> None: - assert adapter.cfg.gated_mlp is True - - def test_uses_rms_norm_is_true(self, adapter: Qwen3MoeArchitectureAdapter) -> None: - assert adapter.cfg.uses_rms_norm is True - - def test_attn_implementation_is_eager(self, adapter: Qwen3MoeArchitectureAdapter) -> None: - assert adapter.cfg.attn_implementation == "eager" - - def test_default_prepend_bos_is_false(self, adapter: Qwen3MoeArchitectureAdapter) -> None: - assert adapter.cfg.default_prepend_bos is False - - def test_n_kv_heads_propagated(self) -> None: - """n_key_value_heads from the loaded config is preserved.""" - cfg = TransformerBridgeConfig( - d_model=64, - d_head=16, - n_layers=2, - n_ctx=128, - n_heads=4, - n_key_value_heads=2, - d_vocab=256, - architecture="Qwen3MoeForCausalLM", - ) - adapter = Qwen3MoeArchitectureAdapter(cfg) - assert adapter.cfg.n_key_value_heads == 2 - - -class TestQwen3MoeWeightConversions: - def test_has_qkvo_keys(self, adapter: Qwen3MoeArchitectureAdapter) -> None: - convs = adapter.weight_processing_conversions - assert convs is not None - assert "blocks.{i}.attn.q.weight" in convs - assert "blocks.{i}.attn.k.weight" in convs - assert "blocks.{i}.attn.v.weight" in convs - assert "blocks.{i}.attn.o.weight" in convs - - def test_q_rearrange_uses_n_heads(self, adapter: Qwen3MoeArchitectureAdapter) -> None: - """Q rearrange uses n_heads (4).""" - convs = adapter.weight_processing_conversions - assert convs is not None - q_conv = convs["blocks.{i}.attn.q.weight"] - assert isinstance(q_conv, ParamProcessingConversion) - assert isinstance(q_conv.tensor_conversion, RearrangeTensorConversion) - axes = q_conv.tensor_conversion.axes_lengths - assert axes.get("n") == 4 - - def test_kv_rearrange_uses_n_kv_heads(self, adapter: Qwen3MoeArchitectureAdapter) -> None: - """K/V rearrange uses n_key_value_heads (2) for GQA.""" - convs = adapter.weight_processing_conversions - assert convs is not None - k_conv = convs["blocks.{i}.attn.k.weight"] - v_conv = convs["blocks.{i}.attn.v.weight"] - assert isinstance(k_conv, ParamProcessingConversion) - assert isinstance(v_conv, ParamProcessingConversion) - assert isinstance(k_conv.tensor_conversion, RearrangeTensorConversion) - assert isinstance(v_conv.tensor_conversion, RearrangeTensorConversion) - assert k_conv.tensor_conversion.axes_lengths.get("n") == 2 - assert v_conv.tensor_conversion.axes_lengths.get("n") == 2 - - def test_o_rearrange_uses_n_heads(self, adapter: Qwen3MoeArchitectureAdapter) -> None: - """O rearrange uses n_heads (4).""" - convs = adapter.weight_processing_conversions - assert convs is not None - o_conv = convs["blocks.{i}.attn.o.weight"] - assert isinstance(o_conv, ParamProcessingConversion) - assert isinstance(o_conv.tensor_conversion, RearrangeTensorConversion) - assert o_conv.tensor_conversion.axes_lengths.get("n") == 4 - - -class TestQwen3MoeComponentMapping: - def test_has_required_top_level_keys(self, adapter: Qwen3MoeArchitectureAdapter) -> None: - mapping = adapter.component_mapping - assert mapping is not None - for key in ("embed", "rotary_emb", "blocks", "ln_final", "unembed"): - assert key in mapping, f"Missing top-level key: {key!r}" - - def test_blocks_has_required_submodules(self, adapter: Qwen3MoeArchitectureAdapter) -> None: - mapping = adapter.component_mapping - assert mapping is not None - blocks = mapping["blocks"] - for key in ("ln1", "ln2", "attn", "mlp"): - assert key in blocks.submodules, f"Missing blocks submodule: {key!r}" - - def test_attn_has_all_submodules(self, adapter: Qwen3MoeArchitectureAdapter) -> None: - mapping = adapter.component_mapping - assert mapping is not None - attn = mapping["blocks"].submodules["attn"] - for key in ("q", "k", "v", "o", "q_norm", "k_norm"): - assert key in attn.submodules, f"Missing attn submodule: {key!r}" - - def test_ln1_ln2_are_rms_norm_bridges(self, adapter: Qwen3MoeArchitectureAdapter) -> None: - mapping = adapter.component_mapping - assert mapping is not None - subs = mapping["blocks"].submodules - assert isinstance(subs["ln1"], RMSNormalizationBridge) - assert isinstance(subs["ln2"], RMSNormalizationBridge) - - def test_mlp_is_moe_bridge(self, adapter: Qwen3MoeArchitectureAdapter) -> None: - mapping = adapter.component_mapping - assert mapping is not None - mlp = mapping["blocks"].submodules["mlp"] - assert isinstance(mlp, MoEBridge) - - def test_mlp_has_gate_submodule(self, adapter: Qwen3MoeArchitectureAdapter) -> None: - mapping = adapter.component_mapping - assert mapping is not None - mlp = mapping["blocks"].submodules["mlp"] - assert "gate" in mlp.submodules - - def test_q_norm_k_norm_are_rms_norm_bridges(self, adapter: Qwen3MoeArchitectureAdapter) -> None: - mapping = adapter.component_mapping - assert mapping is not None - attn_subs = mapping["blocks"].submodules["attn"].submodules - assert isinstance(attn_subs["q_norm"], RMSNormalizationBridge) - assert isinstance(attn_subs["k_norm"], RMSNormalizationBridge) - - def test_hf_module_paths(self, adapter: Qwen3MoeArchitectureAdapter) -> None: - """HF module path names are mapped correctly.""" - mapping = adapter.component_mapping - assert mapping is not None - assert mapping["embed"].name == "model.embed_tokens" - assert mapping["ln_final"].name == "model.norm" - assert mapping["unembed"].name == "lm_head" - assert mapping["blocks"].name == "model.layers" - subs = mapping["blocks"].submodules - assert subs["ln1"].name == "input_layernorm" - assert subs["ln2"].name == "post_attention_layernorm" - assert subs["attn"].name == "self_attn" - assert subs["mlp"].name == "mlp" - - -class TestQwen3MoeFactoryRegistration: - def test_factory_lookup_returns_adapter_class(self) -> None: - assert SUPPORTED_ARCHITECTURES["Qwen3MoeForCausalLM"] is Qwen3MoeArchitectureAdapter diff --git a/tests/unit/test_gemma3_multimodal_adapter.py b/tests/unit/test_gemma3_multimodal_adapter.py deleted file mode 100644 index 5f8afbf59..000000000 --- a/tests/unit/test_gemma3_multimodal_adapter.py +++ /dev/null @@ -1,106 +0,0 @@ -"""Unit tests for Gemma3 multimodal architecture adapter registration.""" - -import pytest - -from transformer_lens.config.TransformerBridgeConfig import TransformerBridgeConfig -from transformer_lens.factories.architecture_adapter_factory import ( - SUPPORTED_ARCHITECTURES, - ArchitectureAdapterFactory, -) -from transformer_lens.model_bridge.supported_architectures.gemma3_multimodal import ( - Gemma3MultimodalArchitectureAdapter, -) - - -def _make_gemma3_mm_cfg(**overrides): - """Create a TransformerBridgeConfig for Gemma3 4B multimodal.""" - defaults = dict( - d_model=2560, - d_head=256, - n_heads=8, - n_layers=34, - n_ctx=8192, - d_vocab=262208, - n_key_value_heads=4, - architecture="Gemma3ForConditionalGeneration", - ) - defaults.update(overrides) - return TransformerBridgeConfig(**defaults) - - -class TestGemma3MultimodalRegistration: - """Test that Gemma3MultimodalArchitectureAdapter is properly registered.""" - - def test_architecture_in_supported_architectures(self): - assert "Gemma3ForConditionalGeneration" in SUPPORTED_ARCHITECTURES - - def test_architecture_maps_to_correct_adapter(self): - assert ( - SUPPORTED_ARCHITECTURES["Gemma3ForConditionalGeneration"] - is Gemma3MultimodalArchitectureAdapter - ) - - def test_factory_selects_correct_adapter(self): - cfg = _make_gemma3_mm_cfg() - adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg) - assert isinstance(adapter, Gemma3MultimodalArchitectureAdapter) - - -class TestGemma3MultimodalAdapterConfig: - """Test Gemma3MultimodalArchitectureAdapter configuration.""" - - @pytest.fixture - def adapter(self): - cfg = _make_gemma3_mm_cfg() - return Gemma3MultimodalArchitectureAdapter(cfg) - - def test_is_multimodal(self, adapter): - assert adapter.cfg.is_multimodal is True - - def test_gated_mlp(self, adapter): - assert adapter.cfg.gated_mlp is True - - def test_uses_rms_norm(self, adapter): - assert adapter.cfg.uses_rms_norm is True - - def test_normalization_type(self, adapter): - assert adapter.cfg.normalization_type == "RMS" - - def test_positional_embedding_type(self, adapter): - assert adapter.cfg.positional_embedding_type == "rotary" - - def test_has_vision_encoder_component(self, adapter): - assert "vision_encoder" in adapter.component_mapping - - def test_has_vision_projector_component(self, adapter): - assert "vision_projector" in adapter.component_mapping - - def test_has_language_model_components(self, adapter): - assert "embed" in adapter.component_mapping - assert "blocks" in adapter.component_mapping - assert "ln_final" in adapter.component_mapping - assert "unembed" in adapter.component_mapping - - def test_vision_encoder_path(self, adapter): - assert adapter.component_mapping["vision_encoder"].name == "model.vision_tower" - - def test_vision_projector_path(self, adapter): - assert adapter.component_mapping["vision_projector"].name == "model.multi_modal_projector" - - def test_embed_path(self, adapter): - assert adapter.component_mapping["embed"].name == "model.language_model.embed_tokens" - - def test_blocks_path(self, adapter): - assert adapter.component_mapping["blocks"].name == "model.language_model.layers" - - def test_ln_final_path(self, adapter): - assert adapter.component_mapping["ln_final"].name == "model.language_model.norm" - - def test_unembed_path(self, adapter): - assert adapter.component_mapping["unembed"].name == "lm_head" - - def test_weight_processing_conversions_exist(self, adapter): - assert "blocks.{i}.attn.q.weight" in adapter.weight_processing_conversions - assert "blocks.{i}.attn.k.weight" in adapter.weight_processing_conversions - assert "blocks.{i}.attn.v.weight" in adapter.weight_processing_conversions - assert "blocks.{i}.attn.o.weight" in adapter.weight_processing_conversions diff --git a/tests/unit/test_llava_config.py b/tests/unit/test_llava_config.py deleted file mode 100644 index 843ae5472..000000000 --- a/tests/unit/test_llava_config.py +++ /dev/null @@ -1,116 +0,0 @@ -"""Unit tests for LLava architecture adapter and configuration.""" - -import pytest - -from transformer_lens.config.TransformerBridgeConfig import TransformerBridgeConfig -from transformer_lens.factories.architecture_adapter_factory import ( - SUPPORTED_ARCHITECTURES, - ArchitectureAdapterFactory, -) -from transformer_lens.model_bridge.supported_architectures.llava import ( - LlavaArchitectureAdapter, -) - - -def _make_llava_cfg(**overrides): - """Create a TransformerBridgeConfig for LLava 1.5 7B.""" - defaults = dict( - d_model=4096, - d_head=128, - n_heads=32, - n_layers=32, - n_ctx=4096, - d_vocab=32064, - architecture="LlavaForConditionalGeneration", - ) - defaults.update(overrides) - return TransformerBridgeConfig(**defaults) - - -class TestLlavaRegistration: - """Test that LlavaArchitectureAdapter is properly registered.""" - - def test_architecture_in_supported_architectures(self): - assert "LlavaForConditionalGeneration" in SUPPORTED_ARCHITECTURES - - def test_architecture_maps_to_correct_adapter(self): - assert SUPPORTED_ARCHITECTURES["LlavaForConditionalGeneration"] is LlavaArchitectureAdapter - - def test_factory_selects_correct_adapter(self): - cfg = _make_llava_cfg() - adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg) - assert isinstance(adapter, LlavaArchitectureAdapter) - - -class TestLlavaAdapterConfig: - """Test LlavaArchitectureAdapter configuration.""" - - @pytest.fixture - def adapter(self): - cfg = _make_llava_cfg() - return LlavaArchitectureAdapter(cfg) - - def test_is_multimodal(self, adapter): - assert adapter.cfg.is_multimodal is True - - def test_gated_mlp(self, adapter): - assert adapter.cfg.gated_mlp is True - - def test_uses_rms_norm(self, adapter): - assert adapter.cfg.uses_rms_norm is True - - def test_normalization_type(self, adapter): - assert adapter.cfg.normalization_type == "RMS" - - def test_positional_embedding_type(self, adapter): - assert adapter.cfg.positional_embedding_type == "rotary" - - def test_attn_implementation(self, adapter): - assert adapter.cfg.attn_implementation == "eager" - - def test_has_vision_encoder_component(self, adapter): - assert "vision_encoder" in adapter.component_mapping - - def test_has_vision_projector_component(self, adapter): - assert "vision_projector" in adapter.component_mapping - - def test_has_language_model_components(self, adapter): - assert "embed" in adapter.component_mapping - assert "rotary_emb" in adapter.component_mapping - assert "blocks" in adapter.component_mapping - assert "ln_final" in adapter.component_mapping - assert "unembed" in adapter.component_mapping - - def test_vision_encoder_path(self, adapter): - assert adapter.component_mapping["vision_encoder"].name == "model.vision_tower" - - def test_vision_projector_path(self, adapter): - assert adapter.component_mapping["vision_projector"].name == "model.multi_modal_projector" - - def test_embed_path(self, adapter): - assert adapter.component_mapping["embed"].name == "model.language_model.embed_tokens" - - def test_rotary_emb_path(self, adapter): - assert adapter.component_mapping["rotary_emb"].name == "model.language_model.rotary_emb" - - def test_blocks_path(self, adapter): - assert adapter.component_mapping["blocks"].name == "model.language_model.layers" - - def test_ln_final_path(self, adapter): - assert adapter.component_mapping["ln_final"].name == "model.language_model.norm" - - def test_unembed_path(self, adapter): - assert adapter.component_mapping["unembed"].name == "lm_head" - - def test_weight_processing_conversions_exist(self, adapter): - assert "blocks.{i}.attn.q.weight" in adapter.weight_processing_conversions - assert "blocks.{i}.attn.k.weight" in adapter.weight_processing_conversions - assert "blocks.{i}.attn.v.weight" in adapter.weight_processing_conversions - assert "blocks.{i}.attn.o.weight" in adapter.weight_processing_conversions - - def test_no_norm_offset_conversions(self, adapter): - """LLava (LLaMA-based) should NOT have +1 norm offset like Gemma.""" - for key in adapter.weight_processing_conversions: - assert "ln1" not in key - assert "ln2" not in key - assert "ln_final" not in key From fcb0e36599a8071b4fed1954af41917c388e5e9e Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Tue, 12 May 2026 15:34:50 -0500 Subject: [PATCH 10/10] format cleanup --- .../test_cohere_adapter.py | 1 + .../test_gemma3_multimodal_adapter.py | 8 ++--- .../test_gpt_bigcode_adapter.py | 8 ++--- .../test_internlm2_adapter.py | 5 ++- .../test_llava_adapter.py | 12 ++----- .../test_mpt_adapter.py | 32 ++++++++----------- .../test_qwen3_5_adapter.py | 1 + .../test_qwen3_moe_adapter.py | 4 +-- .../test_qwen3_next_adapter.py | 1 + .../test_xglm_adapter.py | 9 ++---- 10 files changed, 32 insertions(+), 49 deletions(-) diff --git a/tests/unit/model_bridge/supported_architectures/test_cohere_adapter.py b/tests/unit/model_bridge/supported_architectures/test_cohere_adapter.py index eb687e2b7..5cf0f8ed5 100644 --- a/tests/unit/model_bridge/supported_architectures/test_cohere_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_cohere_adapter.py @@ -23,6 +23,7 @@ CohereArchitectureAdapter, ) + def _make_cfg( n_heads: int = 4, d_model: int = 64, diff --git a/tests/unit/model_bridge/supported_architectures/test_gemma3_multimodal_adapter.py b/tests/unit/model_bridge/supported_architectures/test_gemma3_multimodal_adapter.py index 6d46c17d8..2edb2ca9d 100644 --- a/tests/unit/model_bridge/supported_architectures/test_gemma3_multimodal_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_gemma3_multimodal_adapter.py @@ -180,12 +180,8 @@ def adapter(self): def test_vision_encoder_is_siglip_bridge(self, adapter): # Gemma3 multimodal hard-wires SigLIP — must NOT be CLIP. - assert isinstance( - adapter.component_mapping["vision_encoder"], SiglipVisionEncoderBridge - ) - assert not isinstance( - adapter.component_mapping["vision_encoder"], CLIPVisionEncoderBridge - ) + assert isinstance(adapter.component_mapping["vision_encoder"], SiglipVisionEncoderBridge) + assert not isinstance(adapter.component_mapping["vision_encoder"], CLIPVisionEncoderBridge) def test_vision_projector_type(self, adapter): assert isinstance(adapter.component_mapping["vision_projector"], VisionProjectionBridge) diff --git a/tests/unit/model_bridge/supported_architectures/test_gpt_bigcode_adapter.py b/tests/unit/model_bridge/supported_architectures/test_gpt_bigcode_adapter.py index 4ff80eb70..9e9089cc4 100644 --- a/tests/unit/model_bridge/supported_architectures/test_gpt_bigcode_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_gpt_bigcode_adapter.py @@ -232,9 +232,7 @@ class TestGPTBigCodeWeightConversionSemantics: def adapter(self) -> GPTBigCodeArchitectureAdapter: return GPTBigCodeArchitectureAdapter(_make_cfg()) - def test_q_conversion_type_and_pattern( - self, adapter: GPTBigCodeArchitectureAdapter - ) -> None: + def test_q_conversion_type_and_pattern(self, adapter: GPTBigCodeArchitectureAdapter) -> None: conv = adapter.weight_processing_conversions["blocks.{i}.attn.q.weight"] assert isinstance(conv, ParamProcessingConversion) assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) @@ -257,9 +255,7 @@ def test_kv_uses_mqa_n_equals_one( assert conv.tensor_conversion.pattern == "(n h) m -> n m h" assert conv.tensor_conversion.axes_lengths["n"] == 1 - def test_o_conversion_type_and_pattern( - self, adapter: GPTBigCodeArchitectureAdapter - ) -> None: + def test_o_conversion_type_and_pattern(self, adapter: GPTBigCodeArchitectureAdapter) -> None: conv = adapter.weight_processing_conversions["blocks.{i}.attn.o.weight"] assert isinstance(conv, ParamProcessingConversion) assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) diff --git a/tests/unit/model_bridge/supported_architectures/test_internlm2_adapter.py b/tests/unit/model_bridge/supported_architectures/test_internlm2_adapter.py index b7ce4c82d..a3db51074 100644 --- a/tests/unit/model_bridge/supported_architectures/test_internlm2_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_internlm2_adapter.py @@ -25,6 +25,7 @@ InternLM2ArchitectureAdapter, ) + def _make_cfg( n_heads: int = 8, n_key_value_heads: int = 2, @@ -735,7 +736,9 @@ def test_block_uses_block_bridge_not_parallel( self, adapter: InternLM2ArchitectureAdapter ) -> None: # Sequential, not parallel-attn-mlp — guard against borrowing Cohere's pattern. - from transformer_lens.model_bridge.generalized_components import ParallelBlockBridge + from transformer_lens.model_bridge.generalized_components import ( + ParallelBlockBridge, + ) blocks = adapter.component_mapping["blocks"] assert not isinstance(blocks, ParallelBlockBridge) diff --git a/tests/unit/model_bridge/supported_architectures/test_llava_adapter.py b/tests/unit/model_bridge/supported_architectures/test_llava_adapter.py index c797d416c..f0f1397fd 100644 --- a/tests/unit/model_bridge/supported_architectures/test_llava_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_llava_adapter.py @@ -194,18 +194,12 @@ class TestLlavaSiglipVisionVariant: def test_siglip_selects_siglip_bridge(self): adapter = LlavaArchitectureAdapter(_make_llava_cfg(vision_model_type="siglip_vision_model")) - assert isinstance( - adapter.component_mapping["vision_encoder"], SiglipVisionEncoderBridge - ) - assert not isinstance( - adapter.component_mapping["vision_encoder"], CLIPVisionEncoderBridge - ) + assert isinstance(adapter.component_mapping["vision_encoder"], SiglipVisionEncoderBridge) + assert not isinstance(adapter.component_mapping["vision_encoder"], CLIPVisionEncoderBridge) def test_siglip_short_alias_selects_siglip_bridge(self): adapter = LlavaArchitectureAdapter(_make_llava_cfg(vision_model_type="siglip")) - assert isinstance( - adapter.component_mapping["vision_encoder"], SiglipVisionEncoderBridge - ) + assert isinstance(adapter.component_mapping["vision_encoder"], SiglipVisionEncoderBridge) class TestLlavaBlockSubmodules: diff --git a/tests/unit/model_bridge/supported_architectures/test_mpt_adapter.py b/tests/unit/model_bridge/supported_architectures/test_mpt_adapter.py index acffe26a1..0c848f3d0 100644 --- a/tests/unit/model_bridge/supported_architectures/test_mpt_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_mpt_adapter.py @@ -292,7 +292,9 @@ def test_ln_final_type(self, adapter: MPTArchitectureAdapter) -> None: assert not isinstance(ln_final, RMSNormalizationBridge) def test_unembed_type(self, adapter: MPTArchitectureAdapter) -> None: - from transformer_lens.model_bridge.generalized_components import UnembeddingBridge + from transformer_lens.model_bridge.generalized_components import ( + UnembeddingBridge, + ) assert isinstance(adapter.component_mapping["unembed"], UnembeddingBridge) @@ -306,7 +308,9 @@ class TestMPTBlockSubmoduleStructure: """Each block submodule has the correct bridge type and HF path.""" def test_ln1_is_layernorm_at_norm_1(self, adapter: MPTArchitectureAdapter) -> None: - from transformer_lens.model_bridge.generalized_components import NormalizationBridge + from transformer_lens.model_bridge.generalized_components import ( + NormalizationBridge, + ) block = adapter.component_mapping["blocks"] ln1 = block.submodules["ln1"] @@ -314,7 +318,9 @@ def test_ln1_is_layernorm_at_norm_1(self, adapter: MPTArchitectureAdapter) -> No assert ln1.name == "norm_1" def test_ln2_is_layernorm_at_norm_2(self, adapter: MPTArchitectureAdapter) -> None: - from transformer_lens.model_bridge.generalized_components import NormalizationBridge + from transformer_lens.model_bridge.generalized_components import ( + NormalizationBridge, + ) block = adapter.component_mapping["blocks"] ln2 = block.submodules["ln2"] @@ -337,9 +343,7 @@ def test_attn_does_not_require_position_embeddings( attn = adapter.component_mapping["blocks"].submodules["attn"] assert attn.requires_position_embeddings is False - def test_attn_does_not_require_attention_mask( - self, adapter: MPTArchitectureAdapter - ) -> None: + def test_attn_does_not_require_attention_mask(self, adapter: MPTArchitectureAdapter) -> None: # ALiBi bias slope IS the position-aware signal. attn = adapter.component_mapping["blocks"].submodules["attn"] assert attn.requires_attention_mask is False @@ -402,9 +406,7 @@ def test_mlp_submodule_paths(self, adapter: MPTArchitectureAdapter) -> None: class TestMPTWeightConversionSemantics: """Each weight conversion entry uses the expected class and pattern.""" - def test_qkv_conversion_classes_and_patterns( - self, adapter: MPTArchitectureAdapter - ) -> None: + def test_qkv_conversion_classes_and_patterns(self, adapter: MPTArchitectureAdapter) -> None: from transformer_lens.conversion_utils.conversion_steps import ( RearrangeTensorConversion, ) @@ -475,19 +477,13 @@ def test_mqa_does_not_change_q_or_o(self) -> None: class TestMPTArchitectureGuards: """No rotary, no pos_embed (MPT uses ALiBi).""" - def test_no_rotary_emb_in_component_mapping( - self, adapter: MPTArchitectureAdapter - ) -> None: + def test_no_rotary_emb_in_component_mapping(self, adapter: MPTArchitectureAdapter) -> None: assert "rotary_emb" not in adapter.component_mapping - def test_no_pos_embed_in_component_mapping( - self, adapter: MPTArchitectureAdapter - ) -> None: + def test_no_pos_embed_in_component_mapping(self, adapter: MPTArchitectureAdapter) -> None: assert "pos_embed" not in adapter.component_mapping - def test_no_rotary_emb_in_attn_submodules( - self, adapter: MPTArchitectureAdapter - ) -> None: + def test_no_rotary_emb_in_attn_submodules(self, adapter: MPTArchitectureAdapter) -> None: # ALiBi bias is computed inside the attention bridge: no rotary submodule. attn = adapter.component_mapping["blocks"].submodules["attn"] assert "rotary_emb" not in attn.submodules diff --git a/tests/unit/model_bridge/supported_architectures/test_qwen3_5_adapter.py b/tests/unit/model_bridge/supported_architectures/test_qwen3_5_adapter.py index 52e0ef3f8..c5781807b 100644 --- a/tests/unit/model_bridge/supported_architectures/test_qwen3_5_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_qwen3_5_adapter.py @@ -18,6 +18,7 @@ except ImportError: _QWEN3_5_AVAILABLE = False + @pytest.mark.skipif( not _QWEN3_5_AVAILABLE, reason="Qwen3_5TextConfig / Qwen3_5ForCausalLM not available in installed transformers", diff --git a/tests/unit/model_bridge/supported_architectures/test_qwen3_moe_adapter.py b/tests/unit/model_bridge/supported_architectures/test_qwen3_moe_adapter.py index 055909f39..771b8bdde 100644 --- a/tests/unit/model_bridge/supported_architectures/test_qwen3_moe_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_qwen3_moe_adapter.py @@ -336,9 +336,7 @@ def test_no_norm_offset_conversions(self, adapter: Qwen3MoeArchitectureAdapter) assert "ln2" not in key assert "ln_final" not in key - def test_weight_conversions_are_only_qkvo( - self, adapter: Qwen3MoeArchitectureAdapter - ) -> None: + def test_weight_conversions_are_only_qkvo(self, adapter: Qwen3MoeArchitectureAdapter) -> None: """Expert/gate weights pass through untouched.""" assert set(adapter.weight_processing_conversions.keys()) == { "blocks.{i}.attn.q.weight", diff --git a/tests/unit/model_bridge/supported_architectures/test_qwen3_next_adapter.py b/tests/unit/model_bridge/supported_architectures/test_qwen3_next_adapter.py index 9829ac6a6..f3b0e945b 100644 --- a/tests/unit/model_bridge/supported_architectures/test_qwen3_next_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_qwen3_next_adapter.py @@ -12,6 +12,7 @@ ) from transformer_lens.tools.model_registry import HF_SUPPORTED_ARCHITECTURES + class TestQwen3NextRegistration: """Adapter is registered in all lookup tables.""" diff --git a/tests/unit/model_bridge/supported_architectures/test_xglm_adapter.py b/tests/unit/model_bridge/supported_architectures/test_xglm_adapter.py index 3476e4962..62b1d3733 100644 --- a/tests/unit/model_bridge/supported_architectures/test_xglm_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_xglm_adapter.py @@ -24,6 +24,7 @@ XGLMArchitectureAdapter, ) + def _make_cfg( n_heads: int = 4, d_model: int = 64, @@ -272,9 +273,7 @@ def test_has_ln_final(self, adapter: XGLMArchitectureAdapter) -> None: def test_has_unembed(self, adapter: XGLMArchitectureAdapter) -> None: assert "unembed" in adapter.component_mapping - def test_all_expected_top_level_keys_present( - self, adapter: XGLMArchitectureAdapter - ) -> None: + def test_all_expected_top_level_keys_present(self, adapter: XGLMArchitectureAdapter) -> None: # No top-level rotary_emb (sinusoidal) and no pos_embed (non-persistent). expected = {"embed", "blocks", "ln_final", "unembed"} assert set(adapter.component_mapping.keys()) == expected @@ -350,9 +349,7 @@ def test_blocks_type(self, adapter: XGLMArchitectureAdapter) -> None: def test_ln_final_type(self, adapter: XGLMArchitectureAdapter) -> None: # XGLM uses LayerNorm (not RMS). - assert isinstance( - adapter.component_mapping["ln_final"], NormalizationBridge - ) + assert isinstance(adapter.component_mapping["ln_final"], NormalizationBridge) def test_unembed_type(self, adapter: XGLMArchitectureAdapter) -> None: assert isinstance(adapter.component_mapping["unembed"], UnembeddingBridge)