diff --git a/demos/Grokking_Demo.ipynb b/demos/Grokking_Demo.ipynb index 09af51e9a..2fab2d43d 100644 --- a/demos/Grokking_Demo.ipynb +++ b/demos/Grokking_Demo.ipynb @@ -139,10 +139,8 @@ "source": [ "import transformer_lens\n", "import transformer_lens.utilities as utils\n", - "from transformer_lens.hook_points import (\n", - " HookedRootModule,\n", - " HookPoint,\n", - ") # Hooking utilities\n", + "from transformer_lens.hook_points import HookPoint\n", + "from transformer_lens.HookedRootModule import HookedRootModule # Hooking utilities\n", "from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache\n", "\n", "\n", diff --git a/demos/Main_Demo.ipynb b/demos/Main_Demo.ipynb index 3c4e72dc5..70ffabbea 100644 --- a/demos/Main_Demo.ipynb +++ b/demos/Main_Demo.ipynb @@ -1738,7 +1738,7 @@ "metadata": {}, "outputs": [], "source": [ - "from transformer_lens.hook_points import HookedRootModule, HookPoint\n", + "from transformer_lens.hook_points import HookPoint\nfrom transformer_lens.HookedRootModule import HookedRootModule\n", "\n", "\n", "class SquareThenAdd(nn.Module):\n", diff --git a/demos/Othello_GPT.ipynb b/demos/Othello_GPT.ipynb index 40469f357..ea087d7df 100644 --- a/demos/Othello_GPT.ipynb +++ b/demos/Othello_GPT.ipynb @@ -181,10 +181,8 @@ "source": [ "import transformer_lens\n", "import transformer_lens.utilities as utils\n", - "from transformer_lens.hook_points import (\n", - " HookedRootModule,\n", - " HookPoint,\n", - ") # Hooking utilities\n", + "from transformer_lens.hook_points import HookPoint\n", + "from transformer_lens.HookedRootModule import HookedRootModule # Hooking utilities\n", "from transformer_lens import (\n", " HookedTransformer,\n", " HookedTransformerConfig,\n", @@ -595,7 +593,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "transformer-lens", "language": "python", "name": "python3" }, @@ -609,7 +607,7 @@ "name": "python", "nbconvert_exporter": "python", 
"pygments_lexer": "ipython3", - "version": "3.10.15" + "version": "3.12.12" }, "orig_nbformat": 4 }, diff --git a/tests/unit/components/mlps/test_gpt_oss_moe_component.py b/tests/unit/components/mlps/test_gpt_oss_moe_component.py index 2c8e56a6a..7a3eb03c8 100644 --- a/tests/unit/components/mlps/test_gpt_oss_moe_component.py +++ b/tests/unit/components/mlps/test_gpt_oss_moe_component.py @@ -7,7 +7,7 @@ GptOssExpert, GptOssMoE, ) -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig def make_moe_cfg(): diff --git a/tests/unit/model_bridge/supported_architectures/test_gemma3_adapter.py b/tests/unit/model_bridge/supported_architectures/test_gemma3_adapter.py index a464fb1e7..b49f4114a 100644 --- a/tests/unit/model_bridge/supported_architectures/test_gemma3_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_gemma3_adapter.py @@ -5,7 +5,7 @@ import pytest -from transformer_lens.config.TransformerBridgeConfig import TransformerBridgeConfig +from transformer_lens.config.transformer_bridge_config import TransformerBridgeConfig from transformer_lens.conversion_utils.conversion_steps import ( ArithmeticTensorConversion, RearrangeTensorConversion, diff --git a/tests/unit/model_bridge/supported_architectures/test_gemma3_config.py b/tests/unit/model_bridge/supported_architectures/test_gemma3_config.py index e00c579f4..26751f5c3 100644 --- a/tests/unit/model_bridge/supported_architectures/test_gemma3_config.py +++ b/tests/unit/model_bridge/supported_architectures/test_gemma3_config.py @@ -8,7 +8,7 @@ import pytest -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig from transformer_lens.loading_from_pretrained import get_pretrained_model_config from transformer_lens.supported_models import OFFICIAL_MODEL_NAMES diff --git 
a/tests/unit/model_bridge/supported_architectures/test_gemma3_multimodal_adapter.py b/tests/unit/model_bridge/supported_architectures/test_gemma3_multimodal_adapter.py index 2edb2ca9d..6a1172dd7 100644 --- a/tests/unit/model_bridge/supported_architectures/test_gemma3_multimodal_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_gemma3_multimodal_adapter.py @@ -4,7 +4,7 @@ import pytest -from transformer_lens.config.TransformerBridgeConfig import TransformerBridgeConfig +from transformer_lens.config.transformer_bridge_config import TransformerBridgeConfig from transformer_lens.conversion_utils.conversion_steps import ( ArithmeticTensorConversion, RearrangeTensorConversion, diff --git a/tests/unit/model_bridge/supported_architectures/test_llava_adapter.py b/tests/unit/model_bridge/supported_architectures/test_llava_adapter.py index f0f1397fd..8e99a0385 100644 --- a/tests/unit/model_bridge/supported_architectures/test_llava_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_llava_adapter.py @@ -4,7 +4,7 @@ import pytest -from transformer_lens.config.TransformerBridgeConfig import TransformerBridgeConfig +from transformer_lens.config.transformer_bridge_config import TransformerBridgeConfig from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion from transformer_lens.conversion_utils.param_processing_conversion import ( ParamProcessingConversion, diff --git a/tests/unit/model_bridge/supported_architectures/test_qwen3_5_adapter.py b/tests/unit/model_bridge/supported_architectures/test_qwen3_5_adapter.py index c5781807b..f34fc0011 100644 --- a/tests/unit/model_bridge/supported_architectures/test_qwen3_5_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_qwen3_5_adapter.py @@ -49,7 +49,9 @@ def test_adapter_class_correct(self): def _make_bridge_cfg(**overrides): """Minimal TransformerBridgeConfig for Qwen3_5 adapter tests.""" - from transformer_lens.config.TransformerBridgeConfig 
import TransformerBridgeConfig + from transformer_lens.config.transformer_bridge_config import ( + TransformerBridgeConfig, + ) defaults = dict( d_model=1024, @@ -266,7 +268,7 @@ def test_n_key_value_heads_set_when_gqa(self): assert adapter.cfg.n_key_value_heads == 2 def test_n_key_value_heads_not_set_when_absent(self): - from transformer_lens.config.TransformerBridgeConfig import ( + from transformer_lens.config.transformer_bridge_config import ( TransformerBridgeConfig, ) from transformer_lens.model_bridge.supported_architectures.qwen3_5 import ( @@ -583,7 +585,9 @@ def _make_tiny_bridge(): """Build a Qwen3_5 bridge from a tiny HF model.""" from unittest.mock import MagicMock - from transformer_lens.config.TransformerBridgeConfig import TransformerBridgeConfig + from transformer_lens.config.transformer_bridge_config import ( + TransformerBridgeConfig, + ) from transformer_lens.model_bridge import TransformerBridge from transformer_lens.model_bridge.supported_architectures.qwen3_5 import ( Qwen3_5ArchitectureAdapter, diff --git a/tests/unit/model_bridge/supported_architectures/test_qwen3_next_adapter.py b/tests/unit/model_bridge/supported_architectures/test_qwen3_next_adapter.py index f3b0e945b..3e0cc10e6 100644 --- a/tests/unit/model_bridge/supported_architectures/test_qwen3_next_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_qwen3_next_adapter.py @@ -39,7 +39,9 @@ def test_adapter_class_correct(self): def _make_bridge_cfg(**overrides): """Minimal TransformerBridgeConfig for Qwen3Next adapter tests.""" - from transformer_lens.config.TransformerBridgeConfig import TransformerBridgeConfig + from transformer_lens.config.transformer_bridge_config import ( + TransformerBridgeConfig, + ) defaults = dict( d_model=2048, @@ -502,7 +504,9 @@ def _make_tiny_bridge(): """Build a Qwen3Next bridge from a tiny HF model.""" from unittest.mock import MagicMock - from transformer_lens.config.TransformerBridgeConfig import TransformerBridgeConfig + from 
transformer_lens.config.transformer_bridge_config import ( + TransformerBridgeConfig, + ) from transformer_lens.model_bridge import TransformerBridge from transformer_lens.model_bridge.supported_architectures.qwen3_next import ( Qwen3NextArchitectureAdapter, diff --git a/tests/unit/model_bridge/test_weight_processing_adapter_paths.py b/tests/unit/model_bridge/test_weight_processing_adapter_paths.py index 998a1604b..34cd5bce4 100644 --- a/tests/unit/model_bridge/test_weight_processing_adapter_paths.py +++ b/tests/unit/model_bridge/test_weight_processing_adapter_paths.py @@ -10,7 +10,7 @@ from transformer_lens import HookedTransformer from transformer_lens import utilities as utils -from transformer_lens.config.TransformerBridgeConfig import TransformerBridgeConfig +from transformer_lens.config.transformer_bridge_config import TransformerBridgeConfig from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter from transformer_lens.weight_processing import ProcessWeights diff --git a/tests/unit/pretrained_weight_conversions/test_apertus.py b/tests/unit/pretrained_weight_conversions/test_apertus.py index 68ba5089a..1f765e092 100644 --- a/tests/unit/pretrained_weight_conversions/test_apertus.py +++ b/tests/unit/pretrained_weight_conversions/test_apertus.py @@ -4,7 +4,7 @@ import torch -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig from transformer_lens.pretrained.weight_conversions.apertus import ( convert_apertus_weights, ) diff --git a/tests/unit/pretrained_weight_conversions/test_gemma.py b/tests/unit/pretrained_weight_conversions/test_gemma.py index 5e56f615f..223bae52f 100644 --- a/tests/unit/pretrained_weight_conversions/test_gemma.py +++ b/tests/unit/pretrained_weight_conversions/test_gemma.py @@ -11,7 +11,7 @@ import torch import torch.nn as nn -from transformer_lens.config.HookedTransformerConfig import 
HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig from transformer_lens.pretrained.weight_conversions.gemma import convert_gemma_weights diff --git a/tests/unit/pretrained_weight_conversions/test_hubert_weights.py b/tests/unit/pretrained_weight_conversions/test_hubert_weights.py index 24b07444c..28343567c 100644 --- a/tests/unit/pretrained_weight_conversions/test_hubert_weights.py +++ b/tests/unit/pretrained_weight_conversions/test_hubert_weights.py @@ -5,7 +5,7 @@ import pytest import torch -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig from transformer_lens.pretrained.weight_conversions.hubert import convert_hubert_weights diff --git a/tests/unit/pretrained_weight_conversions/test_olmo3.py b/tests/unit/pretrained_weight_conversions/test_olmo3.py index 6548f3d99..370b494ba 100644 --- a/tests/unit/pretrained_weight_conversions/test_olmo3.py +++ b/tests/unit/pretrained_weight_conversions/test_olmo3.py @@ -12,7 +12,7 @@ import torch import torch.nn as nn -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig from transformer_lens.pretrained.weight_conversions.olmo3 import convert_olmo3_weights diff --git a/tests/unit/pretrained_weight_conversions/test_openai.py b/tests/unit/pretrained_weight_conversions/test_openai.py index 3cd515ead..bdaf9a2e5 100644 --- a/tests/unit/pretrained_weight_conversions/test_openai.py +++ b/tests/unit/pretrained_weight_conversions/test_openai.py @@ -1,6 +1,6 @@ import torch -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig from transformer_lens.pretrained.weight_conversions.openai import ( convert_gpt_oss_weights, ) diff 
--git a/tests/unit/test_hook_introspection.py b/tests/unit/test_hook_introspection.py index 620195094..592813094 100644 --- a/tests/unit/test_hook_introspection.py +++ b/tests/unit/test_hook_introspection.py @@ -2,11 +2,8 @@ from unittest import mock -from transformer_lens.hook_points import ( - HookedRootModule, - HookIntrospectionMixin, - HookPoint, -) +from transformer_lens.hook_points import HookIntrospectionMixin, HookPoint +from transformer_lens.HookedRootModule import HookedRootModule class _ToyModel(HookedRootModule): diff --git a/tests/unit/test_hooked_root_module.py b/tests/unit/test_hooked_root_module.py index da1ca2d97..6ec3d9f9f 100644 --- a/tests/unit/test_hooked_root_module.py +++ b/tests/unit/test_hooked_root_module.py @@ -1,6 +1,7 @@ from unittest.mock import Mock -from transformer_lens.hook_points import HookedRootModule, HookPoint +from transformer_lens.hook_points import HookPoint +from transformer_lens.HookedRootModule import HookedRootModule MODEL_NAME = "solu-2l" diff --git a/tests/unit/test_key_value_cache_entry.py b/tests/unit/test_key_value_cache_entry.py index b7806454a..e4b300017 100644 --- a/tests/unit/test_key_value_cache_entry.py +++ b/tests/unit/test_key_value_cache_entry.py @@ -20,7 +20,7 @@ from transformer_lens.cache.key_value_cache_entry import ( TransformerLensKeyValueCacheEntry, ) -from transformer_lens.config.TransformerLensConfig import TransformerLensConfig +from transformer_lens.config.transformer_lens_config import TransformerLensConfig def _make_cfg(dtype: torch.dtype, n_heads: int = 4, d_head: int = 8, n_key_value_heads=None): diff --git a/tests/unit/test_optional_submodule.py b/tests/unit/test_optional_submodule.py index 168ad0ce1..f71886987 100644 --- a/tests/unit/test_optional_submodule.py +++ b/tests/unit/test_optional_submodule.py @@ -226,7 +226,9 @@ def test_succeeds_on_universal_submodule(self): class TestRefactorFactoredAttnHybrid: def test_skips_missing_attn_layers(self): - from 
transformer_lens.config.TransformerLensConfig import TransformerLensConfig + from transformer_lens.config.transformer_lens_config import ( + TransformerLensConfig, + ) from transformer_lens.weight_processing import ProcessWeights cfg = TransformerLensConfig( @@ -255,7 +257,9 @@ def test_skips_missing_attn_layers(self): assert "blocks.3.attn.W_Q" not in result def test_raises_on_partial_attn_keys(self): - from transformer_lens.config.TransformerLensConfig import TransformerLensConfig + from transformer_lens.config.transformer_lens_config import ( + TransformerLensConfig, + ) from transformer_lens.weight_processing import ProcessWeights cfg = TransformerLensConfig( diff --git a/tests/unit/test_weight_processing.py b/tests/unit/test_weight_processing.py index d8f12a2ae..415e7657b 100644 --- a/tests/unit/test_weight_processing.py +++ b/tests/unit/test_weight_processing.py @@ -11,7 +11,7 @@ import pytest import torch -from transformer_lens.config.TransformerLensConfig import TransformerLensConfig +from transformer_lens.config.transformer_lens_config import TransformerLensConfig from transformer_lens.weight_processing import ProcessWeights # from typing import Dict # Unused import diff --git a/tests/unit/utilities/test_logits_utils.py b/tests/unit/utilities/test_logits_utils.py new file mode 100644 index 000000000..b5a8c2602 --- /dev/null +++ b/tests/unit/utilities/test_logits_utils.py @@ -0,0 +1,87 @@ +"""Unit tests for transformer_lens.utilities.logits_utils.""" + +import pandas as pd +import pytest +import torch + +from transformer_lens.utilities.logits_utils import logits_to_df + + +class _StubTokenizer: + """Minimal tokenizer surface used by logits_to_df (decode of single ids).""" + + def __init__(self, vocab: list[str]): + self._vocab = vocab + + def decode(self, ids: list[int]) -> str: + return "".join(self._vocab[i] for i in ids) + + +@pytest.fixture(scope="module") +def logits() -> torch.Tensor: + return torch.tensor([1.0, 3.0, 2.0, 0.5]) + + +class 
TestLogitsToDf: + def test_returns_dataframe(self, logits: torch.Tensor): + df = logits_to_df(logits) + assert isinstance(df, pd.DataFrame) + + def test_columns_no_tokenizer(self, logits: torch.Tensor): + df = logits_to_df(logits) + assert list(df.columns) == ["token_index", "logit", "log_prob", "probability"] + + def test_columns_with_tokenizer(self, logits: torch.Tensor): + tok = _StubTokenizer(["a", "b", "c", "d"]) + df = logits_to_df(logits, tokenizer=tok) + assert list(df.columns) == [ + "token_index", + "token_string", + "logit", + "log_prob", + "probability", + ] + + def test_sorted_by_descending_probability(self, logits: torch.Tensor): + df = logits_to_df(logits) + probs = df["probability"].tolist() + assert probs == sorted(probs, reverse=True) + + def test_token_indices_match_logit_argsort(self, logits: torch.Tensor): + df = logits_to_df(logits) + # logits = [1.0, 3.0, 2.0, 0.5] -> argsort desc = [1, 2, 0, 3] + assert df["token_index"].tolist() == [1, 2, 0, 3] + + def test_top_k_truncates(self, logits: torch.Tensor): + df = logits_to_df(logits, top_k=2) + assert len(df) == 2 + assert df["token_index"].tolist() == [1, 2] + + def test_token_string_decoded(self, logits: torch.Tensor): + tok = _StubTokenizer(["a", "b", "c", "d"]) + df = logits_to_df(logits, tokenizer=tok, top_k=2) + assert df["token_string"].tolist() == ["b", "c"] + + def test_log_prob_and_probability_consistent(self, logits: torch.Tensor): + df = logits_to_df(logits) + assert torch.allclose( + torch.tensor(df["log_prob"].tolist()).exp(), + torch.tensor(df["probability"].tolist()), + atol=1e-6, + ) + + def test_probabilities_sum_to_one(self, logits: torch.Tensor): + df = logits_to_df(logits) + assert df["probability"].sum() == pytest.approx(1.0, abs=1e-6) + + def test_logit_column_preserves_input_values(self, logits: torch.Tensor): + df = logits_to_df(logits) + # Order is argsort-desc; just check membership and float-equality. 
+ assert sorted(df["logit"].tolist()) == sorted(logits.tolist()) + + def test_rejects_non_1d_input(self): + # Shape constraint enforced by jaxtyping/beartype on Float[Tensor, "d_vocab"]. + from beartype.roar import BeartypeCallHintParamViolation + + with pytest.raises(BeartypeCallHintParamViolation): + logits_to_df(torch.zeros(3, 4)) diff --git a/transformer_lens/ActivationCache.py b/transformer_lens/ActivationCache.py index 6a0d68e63..946c567c8 100644 --- a/transformer_lens/ActivationCache.py +++ b/transformer_lens/ActivationCache.py @@ -15,7 +15,17 @@ class first, including the examples, and then skimming the available methods. Yo import logging import warnings -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterator, + List, + Optional, + Tuple, + Union, + cast, +) import einops import numpy as np @@ -26,6 +36,10 @@ class first, including the examples, and then skimming the available methods. Yo import transformer_lens.utilities as utils from transformer_lens.utilities import Slice, SliceInput, warn_if_mps +if TYPE_CHECKING: + from transformer_lens.components import TransformerBlock + from transformer_lens.HookedTransformer import HookedTransformer + def _normalize_projection_to_2d( project: Optional[torch.Tensor], @@ -124,9 +138,15 @@ class ActivationCache: Whether the activations have a batch dimension. """ - def __init__(self, cache_dict: Dict[str, torch.Tensor], model, has_batch_dim: bool = True): + def __init__( + self, + cache_dict: Dict[str, torch.Tensor], + model: Any, + has_batch_dim: bool = True, + ): self.cache_dict = cache_dict - self.model = model + # Helper methods require HT-internal structure; bridge users only use cache_dict. 
+ self.model = cast("HookedTransformer", model) self.has_batch_dim = has_batch_dim self.has_embed = "hook_embed" in self.cache_dict self.has_pos_embed = "hook_pos_embed" in self.cache_dict @@ -715,6 +735,9 @@ def compute_head_results( residual stream from that head. attn_out for a layer is the sum of head results plus b_O. Intended use is to enable use_attn_results when running and caching the model, but this can be useful if you forget. + + Works for both HookedTransformer and TransformerBridge — bridge exposes + ``blocks[i].attn.W_O`` via its component-mapping compatibility shim. """ # Return if valid 4D results exist; replace stale 3D Bridge entries if needed first_key = "blocks.0.attn.hook_result" @@ -739,7 +762,9 @@ def compute_head_results( ) # Element-wise multiplication of z and W_O (with shape [head_index, d_head, d_model]) - result = z * self.model.blocks[layer].attn.W_O + # nn.ModuleList[T][i] is typed Tensor|Module upstream; cast restores T. + block = cast("TransformerBlock", self.model.blocks[layer]) + result = z * block.attn.W_O # Sum over d_head to get the contribution of each head to the residual stream self.cache_dict[f"blocks.{layer}.attn.hook_result"] = result.sum(dim=-2) @@ -895,7 +920,9 @@ def get_neuron_results( pos_slice = Slice(pos_slice) neuron_acts = self[("post", layer, "mlp")] - W_out = self.model.blocks[layer].mlp.W_out + # ModuleList[T] indexing is typed `Tensor | Module` upstream; cast restores T. + block = cast("TransformerBlock", self.model.blocks[layer]) + W_out = block.mlp.W_out if pos_slice is not None: # Note - order is important, as Slice.apply *may* collapse a dimension, so this ensures # that position dimension is -2 when we apply position slice @@ -952,6 +979,7 @@ def _stack_neuron_results_apply_ln_projected( ``LN_s(a_n * W_out_n) @ p = (a_n / s) * (W_out_n @ p - mean(W_out_n) * sum_p)`` RMS models drop the ``mean(W_out_n) * sum_p`` term (no centering). 
Always uses the ln1 scale (mlp_input=False) since ``stack_neuron_results`` doesn't expose mlp_input. + """ scale = self._get_cached_ln_scale(layer, mlp_input=False, pos_slice=pos_slice) @@ -960,7 +988,9 @@ def _stack_neuron_results_apply_ln_projected( components: list = [] for l in range(layer): - W_out_l = self.model.blocks[l].mlp.W_out # [d_mlp, d_model] + # nn.ModuleList[T][i] is typed Tensor|Module upstream; cast restores T. + block = cast("TransformerBlock", self.model.blocks[l]) + W_out_l = block.mlp.W_out # [d_mlp, d_model] W_out_l_sliced = neuron_slice.apply(W_out_l, dim=0) W_proj_l = W_out_l_sliced @ project_2d # [d_mlp, n_outs] if apply_centering: @@ -1037,8 +1067,10 @@ def stack_neuron_results( project_2d, squeeze_projected = _normalize_projection_to_2d(project_output_onto) + d_mlp = self.model.cfg.d_mlp + assert d_mlp is not None, "model.cfg.d_mlp must be set" neuron_labels: Union[torch.Tensor, np.ndarray] = neuron_slice.apply( - torch.arange(self.model.cfg.d_mlp), dim=0 + torch.arange(d_mlp), dim=0 ) if isinstance(neuron_labels, int): neuron_labels = np.array([neuron_labels]) diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py index c76f9c7b7..472425ec5 100644 --- a/transformer_lens/HookedAudioEncoder.py +++ b/transformer_lens/HookedAudioEncoder.py @@ -20,9 +20,9 @@ from transformer_lens import loading_from_pretrained as loading from transformer_lens.ActivationCache import ActivationCache from transformer_lens.components import MLP, Attention, BertBlock -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig from transformer_lens.FactoredMatrix import FactoredMatrix -from transformer_lens.hook_points import HookedRootModule +from transformer_lens.HookedRootModule import HookedRootModule from transformer_lens.utilities import devices T = TypeVar("T", bound="HookedAudioEncoder") diff --git 
a/transformer_lens/HookedEncoder.py b/transformer_lens/HookedEncoder.py index 4c239f3d8..f71a9c75a 100644 --- a/transformer_lens/HookedEncoder.py +++ b/transformer_lens/HookedEncoder.py @@ -29,9 +29,10 @@ Unembed, ) from transformer_lens.components.mlps.gated_mlp import GatedMLP -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig from transformer_lens.FactoredMatrix import FactoredMatrix -from transformer_lens.hook_points import HookedRootModule, HookPoint +from transformer_lens.hook_points import HookPoint +from transformer_lens.HookedRootModule import HookedRootModule from transformer_lens.utilities import devices T = TypeVar("T", bound="HookedEncoder") diff --git a/transformer_lens/HookedEncoderDecoder.py b/transformer_lens/HookedEncoderDecoder.py index e683d2f91..7bc56df9e 100644 --- a/transformer_lens/HookedEncoderDecoder.py +++ b/transformer_lens/HookedEncoderDecoder.py @@ -34,9 +34,10 @@ import transformer_lens.loading_from_pretrained as loading from transformer_lens.ActivationCache import ActivationCache from transformer_lens.components import MLP, Embed, GatedMLP, RMSNorm, T5Block, Unembed -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig from transformer_lens.FactoredMatrix import FactoredMatrix -from transformer_lens.hook_points import HookedRootModule, HookPoint +from transformer_lens.hook_points import HookPoint +from transformer_lens.HookedRootModule import HookedRootModule from transformer_lens.utilities import sample_logits, warn_if_mps from transformer_lens.utilities.multi_gpu import get_device_for_block_index diff --git a/transformer_lens/HookedRootModule.py b/transformer_lens/HookedRootModule.py new file mode 100644 index 000000000..262886057 --- /dev/null +++ b/transformer_lens/HookedRootModule.py @@ -0,0 
+1,560 @@ +"""HookedRootModule. + +Base class extending :class:`torch.nn.Module` with hook-based introspection +utilities used by :class:`HookedTransformer` and friends. Lives in its own +module so that downstream code (e.g. :class:`ActivationCache`) can type-hint +against it without the broader ``hook_points`` import surface. +""" + +from __future__ import annotations + +import logging +from collections.abc import Callable, Iterable +from contextlib import contextmanager +from functools import partial +from typing import Any, Literal, Optional, Union, cast + +import torch +import torch.nn as nn +from torch import Tensor + +from transformer_lens.hook_points import ( + DeviceType, + HookFunction, + HookIntrospectionMixin, + HookPoint, + NamesFilter, +) +from transformer_lens.utilities import Slice, SliceInput, warn_if_mps + + +class HookedRootModule(HookIntrospectionMixin, nn.Module): + """A class building on nn.Module to interface nicely with HookPoints. + + Adds various nice utilities, most notably run_with_hooks to run the model with temporary hooks, + and run_with_cache to run the model on some input and return a cache of all activations. + + Notes: + + The main footgun with PyTorch hooking is that hooks are GLOBAL state. If you add a hook to the + module, and then run it a bunch of times, the hooks persist. If you debug a broken hook and add + the fixed version, the broken one is still there. To solve this, run_with_hooks will remove + hooks at the end by default, and I recommend using the API of this and run_with_cache. If you + want to add hooks into global state, I recommend being intentional about this, and I recommend + using reset_hooks liberally in your code to remove any accidentally remaining global state. + + The main time this goes wrong is when you want to use backward hooks (to cache or intervene on + gradients). 
In this case, you need to keep the hooks around as global state until you've run + loss.backward() (and so need to disable the reset_hooks_end flag on run_with_hooks) + """ + + name: Optional[str] + mod_dict: dict[str, nn.Module] + hook_dict: dict[str, HookPoint] + + def __init__(self, *args: Any): + super().__init__() + self.is_caching = False + self.context_level = 0 + + def setup(self): + """ + Sets up model. + + This function must be called in the model's `__init__` method AFTER defining all layers. It + adds a parameter to each module containing its name, and builds a dictionary mapping module + names to the module instances. It also initializes a hook dictionary for modules of type + "HookPoint". + """ + self.mod_dict = {} + self.hook_dict = {} + for name, module in self.named_modules(): + if name == "": + continue + module.name = name + self.mod_dict[name] = module + # TODO: is the bottom line the same as "if "HookPoint" in str(type(module)):" + if isinstance(module, HookPoint): + self.hook_dict[name] = module + + def hook_points(self): + return self.hook_dict.values() + + def remove_all_hook_fns( + self, + direction: Literal["fwd", "bwd", "both"] = "both", + including_permanent: bool = False, + level: Optional[int] = None, + ): + for hp in self.hook_points(): + hp.remove_hooks(direction, including_permanent=including_permanent, level=level) + + def clear_contexts(self): + for hp in self.hook_points(): + hp.clear_context() + + def reset_hooks( + self, + clear_contexts: bool = True, + direction: Literal["fwd", "bwd", "both"] = "both", + including_permanent: bool = False, + level: Optional[int] = None, + ): + if clear_contexts: + self.clear_contexts() + self.remove_all_hook_fns(direction, including_permanent, level=level) + self.is_caching = False + + def check_and_add_hook( + self, + hook_point: HookPoint, + hook_point_name: str, + hook: HookFunction, + dir: Literal["fwd", "bwd"] = "fwd", + is_permanent: bool = False, + level: Optional[int] = None, + prepend: 
bool = False, + ) -> None: + """Runs checks on the hook, and then adds it to the hook point""" + + self.check_hooks_to_add( + hook_point, + hook_point_name, + hook, + dir=dir, + is_permanent=is_permanent, + prepend=prepend, + ) + hook_point.add_hook(hook, dir=dir, is_permanent=is_permanent, level=level, prepend=prepend) + + def check_hooks_to_add( + self, + hook_point: HookPoint, + hook_point_name: str, + hook: HookFunction, + dir: Literal["fwd", "bwd"] = "fwd", + is_permanent: bool = False, + prepend: bool = False, + ) -> None: + """Override this function to add checks on which hooks should be added""" + pass + + def add_hook( + self, + name: Union[str, Callable[[str], bool]], + hook: HookFunction, + dir: Literal["fwd", "bwd"] = "fwd", + is_permanent: bool = False, + level: Optional[int] = None, + prepend: bool = False, + ) -> None: + if isinstance(name, str): + hook_point = self.mod_dict[name] + assert isinstance( + hook_point, HookPoint + ) # TODO does adding assert meaningfully slow down performance? I've added them for type checking purposes. 
+ self.check_and_add_hook( + hook_point, + name, + hook, + dir=dir, + is_permanent=is_permanent, + level=level, + prepend=prepend, + ) + else: + # Otherwise, name is a Boolean function on names + for hook_point_name, hp in self.hook_dict.items(): + if name(hook_point_name): + self.check_and_add_hook( + hp, + hook_point_name, + hook, + dir=dir, + is_permanent=is_permanent, + level=level, + prepend=prepend, + ) + + def add_perma_hook( + self, + name: Union[str, Callable[[str], bool]], + hook: HookFunction, + dir: Literal["fwd", "bwd"] = "fwd", + ) -> None: + self.add_hook(name, hook, dir=dir, is_permanent=True) + + def _enable_hook_with_name(self, name: str, hook: Callable, dir: Literal["fwd", "bwd"]): + """This function takes a key for the mod_dict and enables the related hook for that module + + Args: + name (str): The module name + hook (Callable): The hook to add + dir (Literal["fwd", "bwd"]): The direction for the hook + """ + hook_point_module = self.mod_dict[name] + if not hasattr(hook_point_module, "add_hook"): + raise TypeError(f"Expected a module with add_hook, got {type(hook_point_module)}") + if isinstance(hook_point_module, torch.Tensor): + raise TypeError( + "Module set as Tensor for some reason!" 
+ ) # mypy seems to think these could be tensors after a torch update no idea why, or if this is possible + module_with_hook = cast(HookPoint, hook_point_module) + module_with_hook.add_hook(hook, dir=dir, level=self.context_level) + + def _enable_hooks_for_points( + self, + hook_points: Iterable[tuple[str, HookPoint]], + enabled: Callable, + hook: Callable, + dir: Literal["fwd", "bwd"], + ): + """Enables hooks for a list of points + + Args: + hook_points (Dict[str, HookPoint]): The hook points + enabled (Callable): _description_ + hook (Callable): _description_ + dir (Literal["fwd", "bwd"]): _description_ + """ + for hook_name, hook_point in hook_points: + if enabled(hook_name): + hook_point.add_hook(hook, dir=dir, level=self.context_level) + + def _enable_hook(self, name: Union[str, Callable], hook: Callable, dir: Literal["fwd", "bwd"]): + """Enables an individual hook on a hook point + + Args: + name (str): The name of the hook + hook (Callable): The actual hook + dir (Literal["fwd", "bwd"], optional): The direction of the hook. Defaults to "fwd". + """ + if isinstance(name, str): + self._enable_hook_with_name(name=name, hook=hook, dir=dir) + else: + self._enable_hooks_for_points( + hook_points=self.hook_dict.items(), enabled=name, hook=hook, dir=dir + ) + + @contextmanager + def hooks( + self, + fwd_hooks: list[tuple[Union[str, Callable], Callable]] = [], + bwd_hooks: list[tuple[Union[str, Callable], Callable]] = [], + reset_hooks_end: bool = True, + clear_contexts: bool = False, + ): + """ + A context manager for adding temporary hooks to the model. + + Args: + fwd_hooks: List[Tuple[name, hook]], where name is either the name of a hook point or a + Boolean function on hook names and hook is the function to add to that hook point. + bwd_hooks: Same as fwd_hooks, but for the backward pass. + reset_hooks_end (bool): If True, removes all hooks added by this context manager when the context manager exits. 
+ clear_contexts (bool): If True, clears hook contexts whenever hooks are reset. + + Example: + + .. code-block:: python + + with model.hooks(fwd_hooks=my_hooks): + hooked_loss = model(text, return_type="loss") + """ + try: + self.context_level += 1 + + for name, hook in fwd_hooks: + self._enable_hook(name=name, hook=hook, dir="fwd") + for name, hook in bwd_hooks: + self._enable_hook(name=name, hook=hook, dir="bwd") + yield self + finally: + if reset_hooks_end: + self.reset_hooks( + clear_contexts, including_permanent=False, level=self.context_level + ) + self.context_level -= 1 + + def run_with_hooks( + self, + *model_args: Any, # TODO: unsure about whether or not this Any typing is correct or not; may need to be replaced with something more specific? + fwd_hooks: list[tuple[Union[str, Callable], Callable]] = [], + bwd_hooks: list[tuple[Union[str, Callable], Callable]] = [], + reset_hooks_end: bool = True, + clear_contexts: bool = False, + **model_kwargs: Any, + ): + """ + Runs the model with specified forward and backward hooks. + + Args: + fwd_hooks (List[Tuple[Union[str, Callable], Callable]]): A list of (name, hook), where name is + either the name of a hook point or a boolean function on hook names, and hook is the + function to add to that hook point. Hooks with names that evaluate to True are added + respectively. + bwd_hooks (List[Tuple[Union[str, Callable], Callable]]): Same as fwd_hooks, but for the + backward pass. + reset_hooks_end (bool): If True, all hooks are removed at the end, including those added + during this run. Default is True. + clear_contexts (bool): If True, clears hook contexts whenever hooks are reset. Default is + False. + *model_args: Positional arguments for the model. + **model_kwargs: Keyword arguments for the model's forward function. See your related + models forward pass for details as to what sort of arguments you can pass through. 
+ + Note: + If you want to use backward hooks, set `reset_hooks_end` to False, so the backward hooks + remain active. This function only runs a forward pass. + """ + if len(bwd_hooks) > 0 and reset_hooks_end: + logging.warning( + "WARNING: Hooks will be reset at the end of run_with_hooks. This removes the backward hooks before a backward pass can occur." + ) + + with self.hooks(fwd_hooks, bwd_hooks, reset_hooks_end, clear_contexts) as hooked_model: + return hooked_model.forward(*model_args, **model_kwargs) + + def add_caching_hooks( + self, + names_filter: NamesFilter = None, + incl_bwd: bool = False, + device: DeviceType = None, # TODO: unsure about whether or not this device typing is correct or not? + remove_batch_dim: bool = False, + cache: Optional[dict] = None, + ) -> dict: + """Adds hooks to the model to cache activations. Note: It does NOT actually run the model to get activations, that must be done separately. + + Args: + names_filter (NamesFilter, optional): Which activations to cache. Can be a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True. + incl_bwd (bool, optional): Whether to also do backwards hooks. Defaults to False. + device (_type_, optional): The device to store on. Defaults to same device as model. + remove_batch_dim (bool, optional): Whether to remove the batch dimension (only works for batch_size==1). Defaults to False. + cache (Optional[dict], optional): The cache to store activations in, a new dict is created by default. Defaults to None. + + Returns: + cache (dict): The cache where activations will be stored. 
+ """ + if device is not None: + warn_if_mps(device) + if cache is None: + cache = {} + + if names_filter is None: + names_filter = lambda name: True + elif isinstance(names_filter, str): + filter_str = names_filter + names_filter = lambda name: name == filter_str + elif isinstance(names_filter, list): + filter_list = names_filter + names_filter = lambda name: name in filter_list + + assert callable(names_filter), "names_filter must be a callable" + + self.is_caching = True + + def save_hook(tensor: Tensor, hook: HookPoint, is_backward: bool): + assert hook.name is not None + hook_name = hook.name + if is_backward: + hook_name += "_grad" + if remove_batch_dim: + cache[hook_name] = tensor.detach().to(device)[0] + else: + cache[hook_name] = tensor.detach().to(device) + + for name, hp in self.hook_dict.items(): + if names_filter(name): + hp.add_hook(partial(save_hook, is_backward=False), "fwd") + if incl_bwd: + hp.add_hook(partial(save_hook, is_backward=True), "bwd") + return cache + + def run_with_cache( + self, + *model_args: Any, + names_filter: NamesFilter = None, + device: DeviceType = None, + remove_batch_dim: bool = False, + incl_bwd: bool = False, + reset_hooks_end: bool = True, + clear_contexts: bool = False, + pos_slice: Optional[Union[Slice, SliceInput]] = None, + **model_kwargs: Any, + ): + """ + Runs the model and returns the model output and a Cache object. + + Args: + *model_args: Positional arguments for the model. + names_filter (NamesFilter, optional): A filter for which activations to cache. Accepts None, str, + list of str, or a function that takes a string and returns a bool. Defaults to None, which + means cache everything. + device (str or torch.Device, optional): The device to cache activations on. Defaults to the + model device. WARNING: Setting a different device than the one used by the model leads to + significant performance degradation. + remove_batch_dim (bool, optional): If True, removes the batch dimension when caching. 
Only + makes sense with batch_size=1 inputs. Defaults to False. + incl_bwd (bool, optional): If True, calls backward on the model output and caches gradients + as well. Assumes that the model outputs a scalar (e.g., return_type="loss"). Custom loss + functions are not supported. Defaults to False. + reset_hooks_end (bool, optional): If True, removes all hooks added by this function at the + end of the run. Defaults to True. + clear_contexts (bool, optional): If True, clears hook contexts whenever hooks are reset. + Defaults to False. + pos_slice: + The slice to apply to the cache output. Defaults to None, do nothing. + **model_kwargs: Keyword arguments for the model's forward function. See your related + models forward pass for details as to what sort of arguments you can pass through. + + Returns: + tuple: A tuple containing the model output and a Cache object. + + """ + + pos_slice = Slice.unwrap(pos_slice) + + cache_dict, fwd, bwd = self.get_caching_hooks( + names_filter, + incl_bwd, + device, + remove_batch_dim=remove_batch_dim, + pos_slice=pos_slice, + ) + + with self.hooks( + fwd_hooks=fwd, + bwd_hooks=bwd, + reset_hooks_end=reset_hooks_end, + clear_contexts=clear_contexts, + ): + model_out = self(*model_args, **model_kwargs) + if incl_bwd: + model_out.backward() + + return model_out, cache_dict + + def get_caching_hooks( + self, + names_filter: NamesFilter = None, + incl_bwd: bool = False, + device: DeviceType = None, + remove_batch_dim: bool = False, + cache: Optional[dict] = None, + pos_slice: Optional[Union[Slice, SliceInput]] = None, + ) -> tuple[dict, list, list]: + """Creates hooks to cache activations. Note: It does not add the hooks to the model. + + Args: + names_filter (NamesFilter, optional): Which activations to cache. Can be a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True. + incl_bwd (bool, optional): Whether to also do backwards hooks. Defaults to False. 
+ device (_type_, optional): The device to store on. Keeps on the same device as the layer if None. + remove_batch_dim (bool, optional): Whether to remove the batch dimension (only works for batch_size==1). Defaults to False. + cache (Optional[dict], optional): The cache to store activations in, a new dict is created by default. Defaults to None. + + Returns: + cache (dict): The cache where activations will be stored. + fwd_hooks (list): The forward hooks. + bwd_hooks (list): The backward hooks. Empty if incl_bwd is False. + """ + if device is not None: + warn_if_mps(device) + if cache is None: + cache = {} + + pos_slice = Slice.unwrap(pos_slice) + + if names_filter is None: + names_filter = lambda name: True + elif isinstance(names_filter, str): + filter_str = names_filter + names_filter = lambda name: name == filter_str + elif isinstance(names_filter, list): + filter_list = names_filter + names_filter = lambda name: name in filter_list + elif callable(names_filter): + names_filter = names_filter + else: + raise ValueError("names_filter must be a string, list of strings, or function") + assert callable(names_filter) # Callable[[str], bool] + + self.is_caching = True + + def save_hook(tensor: Tensor, hook: HookPoint, is_backward: bool = False): + # for attention heads the pos dimension is the third from last + if hook.name is None: + raise RuntimeError("Hook should have been provided a name") + + hook_name = hook.name + if is_backward: + hook_name += "_grad" + resid_stream = tensor.detach().to(device) + if remove_batch_dim: + resid_stream = resid_stream[0] + + if ( + hook.name.endswith("hook_q") + or hook.name.endswith("hook_k") + or hook.name.endswith("hook_v") + or hook.name.endswith("hook_z") + or hook.name.endswith("hook_result") + ): + pos_dim = -3 + else: + # for all other components the pos dimension is the second from last + # including the attn scores where the dest token is the second from last + pos_dim = -2 + + if ( + tensor.dim() >= -pos_dim + ): # 
check if the residual stream has a pos dimension before trying to slice + resid_stream = pos_slice.apply(resid_stream, dim=pos_dim) + cache[hook_name] = resid_stream + + fwd_hooks = [] + bwd_hooks = [] + for name, _ in self.hook_dict.items(): + if names_filter(name): + fwd_hooks.append((name, partial(save_hook, is_backward=False))) + if incl_bwd: + bwd_hooks.append((name, partial(save_hook, is_backward=True))) + + return cache, fwd_hooks, bwd_hooks + + def cache_all( + self, + cache: Optional[dict], + incl_bwd: bool = False, + device: DeviceType = None, + remove_batch_dim: bool = False, + ): + logging.warning( + "cache_all is deprecated and will eventually be removed, use add_caching_hooks or run_with_cache" + ) + self.add_caching_hooks( + names_filter=lambda name: True, + cache=cache, + incl_bwd=incl_bwd, + device=device, + remove_batch_dim=remove_batch_dim, + ) + + def cache_some( + self, + cache: Optional[dict], + names: Callable[[str], bool], + incl_bwd: bool = False, + device: DeviceType = None, + remove_batch_dim: bool = False, + ): + """Cache a list of hook provided by names, Boolean function on names""" + logging.warning( + "cache_some is deprecated and will eventually be removed, use add_caching_hooks or run_with_cache" + ) + self.add_caching_hooks( + names_filter=names, + cache=cache, + incl_bwd=incl_bwd, + device=device, + remove_batch_dim=remove_batch_dim, + ) diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 103e532bf..9baf479e5 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -59,9 +59,10 @@ ) from transformer_lens.components.mlps.gated_mlp import GatedMLP from transformer_lens.components.mlps.mlp import MLP -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig from transformer_lens.FactoredMatrix import FactoredMatrix -from 
transformer_lens.hook_points import HookedRootModule, HookPoint +from transformer_lens.hook_points import HookPoint +from transformer_lens.HookedRootModule import HookedRootModule from transformer_lens.loading_from_pretrained import NON_HF_HOSTED_MODEL_NAMES from transformer_lens.utilities import ( USE_DEFAULT_VALUE, diff --git a/transformer_lens/cache/key_value_cache.py b/transformer_lens/cache/key_value_cache.py index a600afacd..3ccf81af9 100644 --- a/transformer_lens/cache/key_value_cache.py +++ b/transformer_lens/cache/key_value_cache.py @@ -10,13 +10,15 @@ import torch from jaxtyping import Int -from transformer_lens.config.TransformerLensConfig import TransformerLensConfig +from transformer_lens.config.transformer_lens_config import TransformerLensConfig from transformer_lens.utilities.multi_gpu import get_device_for_block_index from .key_value_cache_entry import TransformerLensKeyValueCacheEntry if TYPE_CHECKING: - from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig + from transformer_lens.config.hooked_transformer_config import ( + HookedTransformerConfig, + ) @dataclass diff --git a/transformer_lens/cache/key_value_cache_entry.py b/transformer_lens/cache/key_value_cache_entry.py index b8c8c57a8..f2f04e53c 100644 --- a/transformer_lens/cache/key_value_cache_entry.py +++ b/transformer_lens/cache/key_value_cache_entry.py @@ -10,7 +10,7 @@ import torch from jaxtyping import Float -from transformer_lens.config.TransformerLensConfig import TransformerLensConfig +from transformer_lens.config.transformer_lens_config import TransformerLensConfig @dataclass diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py index 8cd51ea91..f2dd0338b 100644 --- a/transformer_lens/components/abstract_attention.py +++ b/transformer_lens/components/abstract_attention.py @@ -15,7 +15,7 @@ TransformerLensKeyValueCacheEntry, ) from transformer_lens.components.rms_norm import RMSNorm -from 
transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig from transformer_lens.FactoredMatrix import FactoredMatrix from transformer_lens.hook_points import HookPoint from transformer_lens.utilities import get_offset_position_ids diff --git a/transformer_lens/components/attention.py b/transformer_lens/components/attention.py index 604f3261d..cfce395d4 100644 --- a/transformer_lens/components/attention.py +++ b/transformer_lens/components/attention.py @@ -10,7 +10,7 @@ from transformers.utils import is_bitsandbytes_available from transformer_lens.components import AbstractAttention -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig if is_bitsandbytes_available(): from bitsandbytes.nn.modules import Params4bit diff --git a/transformer_lens/components/bert_block.py b/transformer_lens/components/bert_block.py index 8eb0e45ac..fec0d052d 100644 --- a/transformer_lens/components/bert_block.py +++ b/transformer_lens/components/bert_block.py @@ -10,7 +10,7 @@ from jaxtyping import Float from transformer_lens.components import Attention, LayerNorm -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig from transformer_lens.factories.mlp_factory import MLPFactory from transformer_lens.hook_points import HookPoint from transformer_lens.utilities import repeat_along_head_dimension diff --git a/transformer_lens/components/bert_embed.py b/transformer_lens/components/bert_embed.py index a74aee351..e0648a7f0 100644 --- a/transformer_lens/components/bert_embed.py +++ b/transformer_lens/components/bert_embed.py @@ -11,7 +11,7 @@ from jaxtyping import Float, Int from transformer_lens.components import Embed, LayerNorm, PosEmbed, TokenTypeEmbed 
-from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig from transformer_lens.hook_points import HookPoint diff --git a/transformer_lens/components/bert_mlm_head.py b/transformer_lens/components/bert_mlm_head.py index 413709cda..b229a20b6 100644 --- a/transformer_lens/components/bert_mlm_head.py +++ b/transformer_lens/components/bert_mlm_head.py @@ -10,7 +10,7 @@ from jaxtyping import Float from transformer_lens.components import LayerNorm -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig class BertMLMHead(nn.Module): diff --git a/transformer_lens/components/bert_nsp_head.py b/transformer_lens/components/bert_nsp_head.py index 8757c862e..937adc36d 100644 --- a/transformer_lens/components/bert_nsp_head.py +++ b/transformer_lens/components/bert_nsp_head.py @@ -9,7 +9,7 @@ import torch.nn as nn from jaxtyping import Float -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig from transformer_lens.hook_points import HookPoint diff --git a/transformer_lens/components/bert_pooler.py b/transformer_lens/components/bert_pooler.py index 873f15428..784d9598d 100644 --- a/transformer_lens/components/bert_pooler.py +++ b/transformer_lens/components/bert_pooler.py @@ -9,7 +9,7 @@ import torch.nn as nn from jaxtyping import Float -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig from transformer_lens.hook_points import HookPoint diff --git a/transformer_lens/components/embed.py b/transformer_lens/components/embed.py index cd8d046d8..addfff8be 100644 --- a/transformer_lens/components/embed.py +++ 
b/transformer_lens/components/embed.py @@ -10,7 +10,7 @@ from jaxtyping import Float, Int from transformer_lens.components import LayerNorm -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig # Embed & Unembed diff --git a/transformer_lens/components/grouped_query_attention.py b/transformer_lens/components/grouped_query_attention.py index e59cda644..d9fbbf8a8 100644 --- a/transformer_lens/components/grouped_query_attention.py +++ b/transformer_lens/components/grouped_query_attention.py @@ -6,7 +6,7 @@ from transformer_lens.components import AbstractAttention from transformer_lens.components.rms_norm import RMSNorm -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig from transformer_lens.utilities.attention import complex_attn_linear, simple_attn_linear diff --git a/transformer_lens/components/layer_norm.py b/transformer_lens/components/layer_norm.py index a43465fe3..061432bfc 100644 --- a/transformer_lens/components/layer_norm.py +++ b/transformer_lens/components/layer_norm.py @@ -9,7 +9,7 @@ import torch.nn as nn from jaxtyping import Float -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig from transformer_lens.hook_points import HookPoint diff --git a/transformer_lens/components/layer_norm_pre.py b/transformer_lens/components/layer_norm_pre.py index fc9f40a19..e54890da8 100644 --- a/transformer_lens/components/layer_norm_pre.py +++ b/transformer_lens/components/layer_norm_pre.py @@ -9,7 +9,7 @@ import torch.nn as nn from jaxtyping import Float -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import 
HookedTransformerConfig from transformer_lens.hook_points import HookPoint diff --git a/transformer_lens/components/mlps/can_be_used_as_mlp.py b/transformer_lens/components/mlps/can_be_used_as_mlp.py index e0d75c104..9c031e6d8 100644 --- a/transformer_lens/components/mlps/can_be_used_as_mlp.py +++ b/transformer_lens/components/mlps/can_be_used_as_mlp.py @@ -13,7 +13,7 @@ from transformer_lens.components.layer_norm import LayerNorm from transformer_lens.components.layer_norm_pre import LayerNormPre -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig from transformer_lens.factories.activation_function_factory import ( ActivationFunctionFactory, ) @@ -37,6 +37,11 @@ class CanBeUsedAsMLP(nn.Module): # The layer norm component if the activation function is a layer norm ln: Optional[nn.Module] + # MLP weight matrices (Parameter on subclasses; declared here so callers like + # ActivationCache.get_neuron_results get a typed Tensor instead of nn.Module). 
+ W_in: torch.Tensor + W_out: torch.Tensor + def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): """The base init for all MLP like components diff --git a/transformer_lens/components/mlps/gated_mlp.py b/transformer_lens/components/mlps/gated_mlp.py index 20ff42209..b19f6e3d5 100644 --- a/transformer_lens/components/mlps/gated_mlp.py +++ b/transformer_lens/components/mlps/gated_mlp.py @@ -12,7 +12,7 @@ from transformers.utils import is_bitsandbytes_available from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig from transformer_lens.hook_points import HookPoint if is_bitsandbytes_available(): diff --git a/transformer_lens/components/mlps/gated_mlp_4bit.py b/transformer_lens/components/mlps/gated_mlp_4bit.py index a2a1b0451..1eb489f1a 100644 --- a/transformer_lens/components/mlps/gated_mlp_4bit.py +++ b/transformer_lens/components/mlps/gated_mlp_4bit.py @@ -11,7 +11,7 @@ from transformers.utils import is_bitsandbytes_available from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig from transformer_lens.hook_points import HookPoint if is_bitsandbytes_available(): @@ -30,6 +30,12 @@ class GatedMLP4Bit(CanBeUsedAsMLP): In one equation, mlp_out = (Gelu(x @ W_gate) * (x @ W_in) + b_in) @ W_out + b_out """ + # Narrow base-class W_in/W_out (declared as torch.Tensor) to bnb's Params4bit + # so .quant_state attribute access type-checks. 
+ W_in: "Params4bit" + W_gate: "Params4bit" + W_out: "Params4bit" + def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): super().__init__(cfg) self.select_activation_function() diff --git a/transformer_lens/components/mlps/gpt_oss_moe.py b/transformer_lens/components/mlps/gpt_oss_moe.py index 56ccebf83..377f77fb6 100644 --- a/transformer_lens/components/mlps/gpt_oss_moe.py +++ b/transformer_lens/components/mlps/gpt_oss_moe.py @@ -15,7 +15,7 @@ from jaxtyping import Float from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig from transformer_lens.hook_points import HookPoint GPT_OSS_ALPHA = 1.702 diff --git a/transformer_lens/components/mlps/mlp.py b/transformer_lens/components/mlps/mlp.py index 23772c5a0..9a95f8305 100644 --- a/transformer_lens/components/mlps/mlp.py +++ b/transformer_lens/components/mlps/mlp.py @@ -10,7 +10,7 @@ from jaxtyping import Float from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig from transformer_lens.hook_points import HookPoint from transformer_lens.utilities.addmm import batch_addmm diff --git a/transformer_lens/components/mlps/moe.py b/transformer_lens/components/mlps/moe.py index 305cdcf2b..63c61a6e9 100644 --- a/transformer_lens/components/mlps/moe.py +++ b/transformer_lens/components/mlps/moe.py @@ -6,7 +6,7 @@ from jaxtyping import Float from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig from 
transformer_lens.factories.activation_function_factory import ( ActivationFunctionFactory, ) diff --git a/transformer_lens/components/pos_embed.py b/transformer_lens/components/pos_embed.py index e6319e1b3..d05d28400 100644 --- a/transformer_lens/components/pos_embed.py +++ b/transformer_lens/components/pos_embed.py @@ -10,7 +10,7 @@ import torch.nn as nn from jaxtyping import Float, Int -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig from transformer_lens.utilities import get_offset_position_ids diff --git a/transformer_lens/components/rms_norm.py b/transformer_lens/components/rms_norm.py index 9d7fc0890..dd4238ab4 100644 --- a/transformer_lens/components/rms_norm.py +++ b/transformer_lens/components/rms_norm.py @@ -9,7 +9,7 @@ import torch.nn as nn from jaxtyping import Float -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig from transformer_lens.hook_points import HookPoint # RMSNorm operates on the last dimension and supports both 2D and 3D inputs. 
diff --git a/transformer_lens/components/rms_norm_pre.py b/transformer_lens/components/rms_norm_pre.py index 149f06845..8742f2444 100644 --- a/transformer_lens/components/rms_norm_pre.py +++ b/transformer_lens/components/rms_norm_pre.py @@ -9,7 +9,7 @@ import torch.nn as nn from jaxtyping import Float -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig from transformer_lens.hook_points import HookPoint diff --git a/transformer_lens/components/t5_attention.py b/transformer_lens/components/t5_attention.py index d15f24bed..0c4a9ffd6 100644 --- a/transformer_lens/components/t5_attention.py +++ b/transformer_lens/components/t5_attention.py @@ -6,7 +6,7 @@ from jaxtyping import Float, Int from transformer_lens.components.abstract_attention import AbstractAttention -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig from transformer_lens.hook_points import HookPoint diff --git a/transformer_lens/components/t5_block.py b/transformer_lens/components/t5_block.py index e83fc7b7f..88d5467e0 100644 --- a/transformer_lens/components/t5_block.py +++ b/transformer_lens/components/t5_block.py @@ -8,7 +8,7 @@ TransformerLensKeyValueCacheEntry, ) from transformer_lens.components import RMSNorm, T5Attention -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig from transformer_lens.factories.mlp_factory import MLPFactory from transformer_lens.hook_points import HookPoint from transformer_lens.utilities import repeat_along_head_dimension diff --git a/transformer_lens/components/token_typed_embed.py b/transformer_lens/components/token_typed_embed.py index 0c6ed3db0..1663a58d9 100644 --- 
a/transformer_lens/components/token_typed_embed.py +++ b/transformer_lens/components/token_typed_embed.py @@ -9,7 +9,7 @@ import torch.nn as nn from jaxtyping import Int -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig class TokenTypeEmbed(nn.Module): diff --git a/transformer_lens/components/transformer_block.py b/transformer_lens/components/transformer_block.py index deff95b85..92ac7484a 100644 --- a/transformer_lens/components/transformer_block.py +++ b/transformer_lens/components/transformer_block.py @@ -21,7 +21,7 @@ RMSNormPre, ) from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig from transformer_lens.factories.mlp_factory import MLPFactory from transformer_lens.hook_points import HookPoint from transformer_lens.utilities import repeat_along_head_dimension diff --git a/transformer_lens/components/unembed.py b/transformer_lens/components/unembed.py index 4808af4e9..4ffa3df49 100644 --- a/transformer_lens/components/unembed.py +++ b/transformer_lens/components/unembed.py @@ -10,7 +10,7 @@ import torch.nn.functional as F from jaxtyping import Float -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig from transformer_lens.hook_points import HookPoint diff --git a/transformer_lens/config/__init__.py b/transformer_lens/config/__init__.py index 484691fb0..3d6142ebb 100644 --- a/transformer_lens/config/__init__.py +++ b/transformer_lens/config/__init__.py @@ -1,7 +1,7 @@ """Configuration classes for TransformerLens.""" -from .HookedTransformerConfig import HookedTransformerConfig -from .TransformerBridgeConfig import 
TransformerBridgeConfig -from .TransformerLensConfig import TransformerLensConfig +from .hooked_transformer_config import HookedTransformerConfig +from .transformer_bridge_config import TransformerBridgeConfig +from .transformer_lens_config import TransformerLensConfig __all__ = ["HookedTransformerConfig", "TransformerBridgeConfig", "TransformerLensConfig"] diff --git a/transformer_lens/config/HookedTransformerConfig.py b/transformer_lens/config/hooked_transformer_config.py similarity index 99% rename from transformer_lens/config/HookedTransformerConfig.py rename to transformer_lens/config/hooked_transformer_config.py index 6e4b95150..57ed34745 100644 --- a/transformer_lens/config/HookedTransformerConfig.py +++ b/transformer_lens/config/hooked_transformer_config.py @@ -17,7 +17,7 @@ from transformer_lens.utilities.activation_functions import SUPPORTED_ACTIVATIONS from transformer_lens.utilities.devices import get_device -from .TransformerLensConfig import TransformerLensConfig +from .transformer_lens_config import TransformerLensConfig @dataclass diff --git a/transformer_lens/config/TransformerBridgeConfig.py b/transformer_lens/config/transformer_bridge_config.py similarity index 99% rename from transformer_lens/config/TransformerBridgeConfig.py rename to transformer_lens/config/transformer_bridge_config.py index f83210566..44ea7f055 100644 --- a/transformer_lens/config/TransformerBridgeConfig.py +++ b/transformer_lens/config/transformer_bridge_config.py @@ -4,7 +4,7 @@ import torch -from .TransformerLensConfig import TransformerLensConfig +from .transformer_lens_config import TransformerLensConfig class TransformerBridgeConfig(TransformerLensConfig): diff --git a/transformer_lens/config/TransformerLensConfig.py b/transformer_lens/config/transformer_lens_config.py similarity index 100% rename from transformer_lens/config/TransformerLensConfig.py rename to transformer_lens/config/transformer_lens_config.py diff --git 
a/transformer_lens/factories/activation_function_factory.py b/transformer_lens/factories/activation_function_factory.py index e5f595a60..511134d98 100644 --- a/transformer_lens/factories/activation_function_factory.py +++ b/transformer_lens/factories/activation_function_factory.py @@ -3,7 +3,7 @@ Centralized location for selection supported activation functions throughout TransformerLens """ -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig from transformer_lens.utilities.activation_functions import ( SUPPORTED_ACTIVATIONS, XIELU, diff --git a/transformer_lens/factories/mlp_factory.py b/transformer_lens/factories/mlp_factory.py index 858b058a6..6947ccf43 100644 --- a/transformer_lens/factories/mlp_factory.py +++ b/transformer_lens/factories/mlp_factory.py @@ -9,7 +9,7 @@ from transformer_lens.components.mlps.gpt_oss_moe import GptOssMoE from transformer_lens.components.mlps.mlp import MLP from transformer_lens.components.mlps.moe import MoE -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig class MLPFactory: diff --git a/transformer_lens/hook_points.py b/transformer_lens/hook_points.py index 7ec1ed7c1..9fdb97d44 100644 --- a/transformer_lens/hook_points.py +++ b/transformer_lens/hook_points.py @@ -5,21 +5,17 @@ Helpers to access activations in models. 
""" -import logging -from collections.abc import Callable, Iterable, Sequence -from contextlib import contextmanager +from collections.abc import Callable, Sequence from dataclasses import dataclass from functools import partial from typing import ( Any, Callable, - Iterable, Literal, Optional, Protocol, Sequence, Union, - cast, runtime_checkable, ) @@ -32,7 +28,6 @@ from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import ( BaseTensorConversion, ) -from transformer_lens.utilities import Slice, SliceInput, warn_if_mps @dataclass @@ -388,7 +383,7 @@ def layer(self): # %% class HookIntrospectionMixin: - """``list_hooks()`` mixin for any class exposing a ``hook_dict``. + """``list_hooks()`` mixin for any class exposing a ``hook_dict``. Accessed via ``getattr`` so subclasses can provide ``hook_dict`` as either an instance attribute (``HookedRootModule``) or a ``@property`` (``TransformerBridge``). @@ -435,536 +430,8 @@ def list_hooks( return out -class HookedRootModule(HookIntrospectionMixin, nn.Module): - """A class building on nn.Module to interface nicely with HookPoints. - - Adds various nice utilities, most notably run_with_hooks to run the model with temporary hooks, - and run_with_cache to run the model on some input and return a cache of all activations. - - Notes: - - The main footgun with PyTorch hooking is that hooks are GLOBAL state. If you add a hook to the - module, and then run it a bunch of times, the hooks persist. If you debug a broken hook and add - the fixed version, the broken one is still there. To solve this, run_with_hooks will remove - hooks at the end by default, and I recommend using the API of this and run_with_cache. If you - want to add hooks into global state, I recommend being intentional about this, and I recommend - using reset_hooks liberally in your code to remove any accidentally remaining global state. 
- - The main time this goes wrong is when you want to use backward hooks (to cache or intervene on - gradients). In this case, you need to keep the hooks around as global state until you've run - loss.backward() (and so need to disable the reset_hooks_end flag on run_with_hooks) - """ - - name: Optional[str] - mod_dict: dict[str, nn.Module] - hook_dict: dict[str, HookPoint] - - def __init__(self, *args: Any): - super().__init__() - self.is_caching = False - self.context_level = 0 - - def setup(self): - """ - Sets up model. - - This function must be called in the model's `__init__` method AFTER defining all layers. It - adds a parameter to each module containing its name, and builds a dictionary mapping module - names to the module instances. It also initializes a hook dictionary for modules of type - "HookPoint". - """ - self.mod_dict = {} - self.hook_dict = {} - for name, module in self.named_modules(): - if name == "": - continue - module.name = name - self.mod_dict[name] = module - # TODO: is the bottom line the same as "if "HookPoint" in str(type(module)):" - if isinstance(module, HookPoint): - self.hook_dict[name] = module - - def hook_points(self): - return self.hook_dict.values() - - def remove_all_hook_fns( - self, - direction: Literal["fwd", "bwd", "both"] = "both", - including_permanent: bool = False, - level: Optional[int] = None, - ): - for hp in self.hook_points(): - hp.remove_hooks(direction, including_permanent=including_permanent, level=level) - - def clear_contexts(self): - for hp in self.hook_points(): - hp.clear_context() - - def reset_hooks( - self, - clear_contexts: bool = True, - direction: Literal["fwd", "bwd", "both"] = "both", - including_permanent: bool = False, - level: Optional[int] = None, - ): - if clear_contexts: - self.clear_contexts() - self.remove_all_hook_fns(direction, including_permanent, level=level) - self.is_caching = False - - def check_and_add_hook( - self, - hook_point: HookPoint, - hook_point_name: str, - hook: 
HookFunction, - dir: Literal["fwd", "bwd"] = "fwd", - is_permanent: bool = False, - level: Optional[int] = None, - prepend: bool = False, - ) -> None: - """Runs checks on the hook, and then adds it to the hook point""" - - self.check_hooks_to_add( - hook_point, - hook_point_name, - hook, - dir=dir, - is_permanent=is_permanent, - prepend=prepend, - ) - hook_point.add_hook(hook, dir=dir, is_permanent=is_permanent, level=level, prepend=prepend) - - def check_hooks_to_add( - self, - hook_point: HookPoint, - hook_point_name: str, - hook: HookFunction, - dir: Literal["fwd", "bwd"] = "fwd", - is_permanent: bool = False, - prepend: bool = False, - ) -> None: - """Override this function to add checks on which hooks should be added""" - pass - - def add_hook( - self, - name: Union[str, Callable[[str], bool]], - hook: HookFunction, - dir: Literal["fwd", "bwd"] = "fwd", - is_permanent: bool = False, - level: Optional[int] = None, - prepend: bool = False, - ) -> None: - if isinstance(name, str): - hook_point = self.mod_dict[name] - assert isinstance( - hook_point, HookPoint - ) # TODO does adding assert meaningfully slow down performance? I've added them for type checking purposes. 
- self.check_and_add_hook( - hook_point, - name, - hook, - dir=dir, - is_permanent=is_permanent, - level=level, - prepend=prepend, - ) - else: - # Otherwise, name is a Boolean function on names - for hook_point_name, hp in self.hook_dict.items(): - if name(hook_point_name): - self.check_and_add_hook( - hp, - hook_point_name, - hook, - dir=dir, - is_permanent=is_permanent, - level=level, - prepend=prepend, - ) - - def add_perma_hook( - self, - name: Union[str, Callable[[str], bool]], - hook: HookFunction, - dir: Literal["fwd", "bwd"] = "fwd", - ) -> None: - self.add_hook(name, hook, dir=dir, is_permanent=True) - - def _enable_hook_with_name(self, name: str, hook: Callable, dir: Literal["fwd", "bwd"]): - """This function takes a key for the mod_dict and enables the related hook for that module - - Args: - name (str): The module name - hook (Callable): The hook to add - dir (Literal["fwd", "bwd"]): The direction for the hook - """ - hook_point_module = self.mod_dict[name] - if not hasattr(hook_point_module, "add_hook"): - raise TypeError(f"Expected a module with add_hook, got {type(hook_point_module)}") - if isinstance(hook_point_module, torch.Tensor): - raise TypeError( - "Module set as Tensor for some reason!" 
- ) # mypy seems to think these could be tensors after a torch update no idea why, or if this is possible - module_with_hook = cast(HookPoint, hook_point_module) - module_with_hook.add_hook(hook, dir=dir, level=self.context_level) - - def _enable_hooks_for_points( - self, - hook_points: Iterable[tuple[str, HookPoint]], - enabled: Callable, - hook: Callable, - dir: Literal["fwd", "bwd"], - ): - """Enables hooks for a list of points - - Args: - hook_points (Dict[str, HookPoint]): The hook points - enabled (Callable): _description_ - hook (Callable): _description_ - dir (Literal["fwd", "bwd"]): _description_ - """ - for hook_name, hook_point in hook_points: - if enabled(hook_name): - hook_point.add_hook(hook, dir=dir, level=self.context_level) - - def _enable_hook(self, name: Union[str, Callable], hook: Callable, dir: Literal["fwd", "bwd"]): - """Enables an individual hook on a hook point - - Args: - name (str): The name of the hook - hook (Callable): The actual hook - dir (Literal["fwd", "bwd"], optional): The direction of the hook. Defaults to "fwd". - """ - if isinstance(name, str): - self._enable_hook_with_name(name=name, hook=hook, dir=dir) - else: - self._enable_hooks_for_points( - hook_points=self.hook_dict.items(), enabled=name, hook=hook, dir=dir - ) - - @contextmanager - def hooks( - self, - fwd_hooks: list[tuple[Union[str, Callable], Callable]] = [], - bwd_hooks: list[tuple[Union[str, Callable], Callable]] = [], - reset_hooks_end: bool = True, - clear_contexts: bool = False, - ): - """ - A context manager for adding temporary hooks to the model. - - Args: - fwd_hooks: List[Tuple[name, hook]], where name is either the name of a hook point or a - Boolean function on hook names and hook is the function to add to that hook point. - bwd_hooks: Same as fwd_hooks, but for the backward pass. - reset_hooks_end (bool): If True, removes all hooks added by this context manager when the context manager exits. 
- clear_contexts (bool): If True, clears hook contexts whenever hooks are reset. - - Example: - - .. code-block:: python - - with model.hooks(fwd_hooks=my_hooks): - hooked_loss = model(text, return_type="loss") - """ - try: - self.context_level += 1 - - for name, hook in fwd_hooks: - self._enable_hook(name=name, hook=hook, dir="fwd") - for name, hook in bwd_hooks: - self._enable_hook(name=name, hook=hook, dir="bwd") - yield self - finally: - if reset_hooks_end: - self.reset_hooks( - clear_contexts, including_permanent=False, level=self.context_level - ) - self.context_level -= 1 - - def run_with_hooks( - self, - *model_args: Any, # TODO: unsure about whether or not this Any typing is correct or not; may need to be replaced with something more specific? - fwd_hooks: list[tuple[Union[str, Callable], Callable]] = [], - bwd_hooks: list[tuple[Union[str, Callable], Callable]] = [], - reset_hooks_end: bool = True, - clear_contexts: bool = False, - **model_kwargs: Any, - ): - """ - Runs the model with specified forward and backward hooks. - - Args: - fwd_hooks (List[Tuple[Union[str, Callable], Callable]]): A list of (name, hook), where name is - either the name of a hook point or a boolean function on hook names, and hook is the - function to add to that hook point. Hooks with names that evaluate to True are added - respectively. - bwd_hooks (List[Tuple[Union[str, Callable], Callable]]): Same as fwd_hooks, but for the - backward pass. - reset_hooks_end (bool): If True, all hooks are removed at the end, including those added - during this run. Default is True. - clear_contexts (bool): If True, clears hook contexts whenever hooks are reset. Default is - False. - *model_args: Positional arguments for the model. - **model_kwargs: Keyword arguments for the model's forward function. See your related - models forward pass for details as to what sort of arguments you can pass through. 
- - Note: - If you want to use backward hooks, set `reset_hooks_end` to False, so the backward hooks - remain active. This function only runs a forward pass. - """ - if len(bwd_hooks) > 0 and reset_hooks_end: - logging.warning( - "WARNING: Hooks will be reset at the end of run_with_hooks. This removes the backward hooks before a backward pass can occur." - ) - - with self.hooks(fwd_hooks, bwd_hooks, reset_hooks_end, clear_contexts) as hooked_model: - return hooked_model.forward(*model_args, **model_kwargs) - - def add_caching_hooks( - self, - names_filter: NamesFilter = None, - incl_bwd: bool = False, - device: DeviceType = None, # TODO: unsure about whether or not this device typing is correct or not? - remove_batch_dim: bool = False, - cache: Optional[dict] = None, - ) -> dict: - """Adds hooks to the model to cache activations. Note: It does NOT actually run the model to get activations, that must be done separately. - - Args: - names_filter (NamesFilter, optional): Which activations to cache. Can be a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True. - incl_bwd (bool, optional): Whether to also do backwards hooks. Defaults to False. - device (_type_, optional): The device to store on. Defaults to same device as model. - remove_batch_dim (bool, optional): Whether to remove the batch dimension (only works for batch_size==1). Defaults to False. - cache (Optional[dict], optional): The cache to store activations in, a new dict is created by default. Defaults to None. - - Returns: - cache (dict): The cache where activations will be stored. 
- """ - if device is not None: - warn_if_mps(device) - if cache is None: - cache = {} - - if names_filter is None: - names_filter = lambda name: True - elif isinstance(names_filter, str): - filter_str = names_filter - names_filter = lambda name: name == filter_str - elif isinstance(names_filter, list): - filter_list = names_filter - names_filter = lambda name: name in filter_list - - assert callable(names_filter), "names_filter must be a callable" - - self.is_caching = True - - def save_hook(tensor: Tensor, hook: HookPoint, is_backward: bool): - assert hook.name is not None - hook_name = hook.name - if is_backward: - hook_name += "_grad" - if remove_batch_dim: - cache[hook_name] = tensor.detach().to(device)[0] - else: - cache[hook_name] = tensor.detach().to(device) - - for name, hp in self.hook_dict.items(): - if names_filter(name): - hp.add_hook(partial(save_hook, is_backward=False), "fwd") - if incl_bwd: - hp.add_hook(partial(save_hook, is_backward=True), "bwd") - return cache - - def run_with_cache( - self, - *model_args: Any, - names_filter: NamesFilter = None, - device: DeviceType = None, - remove_batch_dim: bool = False, - incl_bwd: bool = False, - reset_hooks_end: bool = True, - clear_contexts: bool = False, - pos_slice: Optional[Union[Slice, SliceInput]] = None, - **model_kwargs: Any, - ): - """ - Runs the model and returns the model output and a Cache object. - - Args: - *model_args: Positional arguments for the model. - names_filter (NamesFilter, optional): A filter for which activations to cache. Accepts None, str, - list of str, or a function that takes a string and returns a bool. Defaults to None, which - means cache everything. - device (str or torch.Device, optional): The device to cache activations on. Defaults to the - model device. WARNING: Setting a different device than the one used by the model leads to - significant performance degradation. - remove_batch_dim (bool, optional): If True, removes the batch dimension when caching. 
Only - makes sense with batch_size=1 inputs. Defaults to False. - incl_bwd (bool, optional): If True, calls backward on the model output and caches gradients - as well. Assumes that the model outputs a scalar (e.g., return_type="loss"). Custom loss - functions are not supported. Defaults to False. - reset_hooks_end (bool, optional): If True, removes all hooks added by this function at the - end of the run. Defaults to True. - clear_contexts (bool, optional): If True, clears hook contexts whenever hooks are reset. - Defaults to False. - pos_slice: - The slice to apply to the cache output. Defaults to None, do nothing. - **model_kwargs: Keyword arguments for the model's forward function. See your related - models forward pass for details as to what sort of arguments you can pass through. - - Returns: - tuple: A tuple containing the model output and a Cache object. - - """ - - pos_slice = Slice.unwrap(pos_slice) - - cache_dict, fwd, bwd = self.get_caching_hooks( - names_filter, - incl_bwd, - device, - remove_batch_dim=remove_batch_dim, - pos_slice=pos_slice, - ) - - with self.hooks( - fwd_hooks=fwd, - bwd_hooks=bwd, - reset_hooks_end=reset_hooks_end, - clear_contexts=clear_contexts, - ): - model_out = self(*model_args, **model_kwargs) - if incl_bwd: - model_out.backward() - - return model_out, cache_dict - - def get_caching_hooks( - self, - names_filter: NamesFilter = None, - incl_bwd: bool = False, - device: DeviceType = None, - remove_batch_dim: bool = False, - cache: Optional[dict] = None, - pos_slice: Optional[Union[Slice, SliceInput]] = None, - ) -> tuple[dict, list, list]: - """Creates hooks to cache activations. Note: It does not add the hooks to the model. - - Args: - names_filter (NamesFilter, optional): Which activations to cache. Can be a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True. - incl_bwd (bool, optional): Whether to also do backwards hooks. Defaults to False. 
- device (_type_, optional): The device to store on. Keeps on the same device as the layer if None. - remove_batch_dim (bool, optional): Whether to remove the batch dimension (only works for batch_size==1). Defaults to False. - cache (Optional[dict], optional): The cache to store activations in, a new dict is created by default. Defaults to None. - - Returns: - cache (dict): The cache where activations will be stored. - fwd_hooks (list): The forward hooks. - bwd_hooks (list): The backward hooks. Empty if incl_bwd is False. - """ - if device is not None: - warn_if_mps(device) - if cache is None: - cache = {} - - pos_slice = Slice.unwrap(pos_slice) - - if names_filter is None: - names_filter = lambda name: True - elif isinstance(names_filter, str): - filter_str = names_filter - names_filter = lambda name: name == filter_str - elif isinstance(names_filter, list): - filter_list = names_filter - names_filter = lambda name: name in filter_list - elif callable(names_filter): - names_filter = names_filter - else: - raise ValueError("names_filter must be a string, list of strings, or function") - assert callable(names_filter) # Callable[[str], bool] - - self.is_caching = True - - def save_hook(tensor: Tensor, hook: HookPoint, is_backward: bool = False): - # for attention heads the pos dimension is the third from last - if hook.name is None: - raise RuntimeError("Hook should have been provided a name") - - hook_name = hook.name - if is_backward: - hook_name += "_grad" - resid_stream = tensor.detach().to(device) - if remove_batch_dim: - resid_stream = resid_stream[0] - - if ( - hook.name.endswith("hook_q") - or hook.name.endswith("hook_k") - or hook.name.endswith("hook_v") - or hook.name.endswith("hook_z") - or hook.name.endswith("hook_result") - ): - pos_dim = -3 - else: - # for all other components the pos dimension is the second from last - # including the attn scores where the dest token is the second from last - pos_dim = -2 - - if ( - tensor.dim() >= -pos_dim - ): # 
check if the residual stream has a pos dimension before trying to slice - resid_stream = pos_slice.apply(resid_stream, dim=pos_dim) - cache[hook_name] = resid_stream - - fwd_hooks = [] - bwd_hooks = [] - for name, _ in self.hook_dict.items(): - if names_filter(name): - fwd_hooks.append((name, partial(save_hook, is_backward=False))) - if incl_bwd: - bwd_hooks.append((name, partial(save_hook, is_backward=True))) - - return cache, fwd_hooks, bwd_hooks - - def cache_all( - self, - cache: Optional[dict], - incl_bwd: bool = False, - device: DeviceType = None, - remove_batch_dim: bool = False, - ): - logging.warning( - "cache_all is deprecated and will eventually be removed, use add_caching_hooks or run_with_cache" - ) - self.add_caching_hooks( - names_filter=lambda name: True, - cache=cache, - incl_bwd=incl_bwd, - device=device, - remove_batch_dim=remove_batch_dim, - ) - - def cache_some( - self, - cache: Optional[dict], - names: Callable[[str], bool], - incl_bwd: bool = False, - device: DeviceType = None, - remove_batch_dim: bool = False, - ): - """Cache a list of hook provided by names, Boolean function on names""" - logging.warning( - "cache_some is deprecated and will eventually be removed, use add_caching_hooks or run_with_cache" - ) - self.add_caching_hooks( - names_filter=names, - cache=cache, - incl_bwd=incl_bwd, - device=device, - remove_batch_dim=remove_batch_dim, - ) +# HookedRootModule moved to transformer_lens.HookedRootModule (3.0). Import it from +# its dedicated module — there is no re-export here. 
# %% diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index c1ab50538..a916864b5 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -25,7 +25,7 @@ ) import transformer_lens.utilities as utils -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig from transformer_lens.pretrained.weight_conversions import ( convert_apertus_weights, convert_bert_weights, diff --git a/transformer_lens/pretrained/weight_conversions/apertus.py b/transformer_lens/pretrained/weight_conversions/apertus.py index 28b03e9ee..304a49da1 100644 --- a/transformer_lens/pretrained/weight_conversions/apertus.py +++ b/transformer_lens/pretrained/weight_conversions/apertus.py @@ -11,7 +11,7 @@ import einops import torch -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig logger = logging.getLogger(__name__) diff --git a/transformer_lens/pretrained/weight_conversions/bert.py b/transformer_lens/pretrained/weight_conversions/bert.py index 4a072e995..38f4843ae 100644 --- a/transformer_lens/pretrained/weight_conversions/bert.py +++ b/transformer_lens/pretrained/weight_conversions/bert.py @@ -1,6 +1,6 @@ import einops -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig def convert_bert_weights(bert, cfg: HookedTransformerConfig): diff --git a/transformer_lens/pretrained/weight_conversions/bloom.py b/transformer_lens/pretrained/weight_conversions/bloom.py index 23b7d54c8..fe3e41278 100644 --- a/transformer_lens/pretrained/weight_conversions/bloom.py +++ b/transformer_lens/pretrained/weight_conversions/bloom.py @@ -1,6 +1,6 @@ import einops -from 
transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig def convert_bloom_weights(bloom, cfg: HookedTransformerConfig): diff --git a/transformer_lens/pretrained/weight_conversions/coder.py b/transformer_lens/pretrained/weight_conversions/coder.py index b6aeb930b..1264ca9e1 100644 --- a/transformer_lens/pretrained/weight_conversions/coder.py +++ b/transformer_lens/pretrained/weight_conversions/coder.py @@ -1,7 +1,7 @@ import einops import torch -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig def convert_coder_weights(model, cfg: HookedTransformerConfig): diff --git a/transformer_lens/pretrained/weight_conversions/gemma.py b/transformer_lens/pretrained/weight_conversions/gemma.py index baa7705fd..78717b2db 100644 --- a/transformer_lens/pretrained/weight_conversions/gemma.py +++ b/transformer_lens/pretrained/weight_conversions/gemma.py @@ -1,7 +1,7 @@ import einops import torch -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig def convert_gemma_weights(gemma, cfg: HookedTransformerConfig): diff --git a/transformer_lens/pretrained/weight_conversions/gpt2.py b/transformer_lens/pretrained/weight_conversions/gpt2.py index 4ebe030d9..0c9e62377 100644 --- a/transformer_lens/pretrained/weight_conversions/gpt2.py +++ b/transformer_lens/pretrained/weight_conversions/gpt2.py @@ -1,7 +1,7 @@ import einops import torch -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig def convert_gpt2_weights(gpt2, cfg: HookedTransformerConfig): diff --git a/transformer_lens/pretrained/weight_conversions/gptj.py 
b/transformer_lens/pretrained/weight_conversions/gptj.py index 51041779d..cead3f7c7 100644 --- a/transformer_lens/pretrained/weight_conversions/gptj.py +++ b/transformer_lens/pretrained/weight_conversions/gptj.py @@ -1,7 +1,7 @@ import einops import torch -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig def convert_gptj_weights(gptj, cfg: HookedTransformerConfig): diff --git a/transformer_lens/pretrained/weight_conversions/hubert.py b/transformer_lens/pretrained/weight_conversions/hubert.py index 605391f63..54e0e8567 100644 --- a/transformer_lens/pretrained/weight_conversions/hubert.py +++ b/transformer_lens/pretrained/weight_conversions/hubert.py @@ -1,6 +1,6 @@ import einops -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig def convert_hubert_weights(hf_model, cfg: HookedTransformerConfig): diff --git a/transformer_lens/pretrained/weight_conversions/llama.py b/transformer_lens/pretrained/weight_conversions/llama.py index d299b48e1..6c1dd0b61 100644 --- a/transformer_lens/pretrained/weight_conversions/llama.py +++ b/transformer_lens/pretrained/weight_conversions/llama.py @@ -3,7 +3,7 @@ import einops import torch -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig def convert_llama_weights(llama, cfg: HookedTransformerConfig): diff --git a/transformer_lens/pretrained/weight_conversions/mingpt.py b/transformer_lens/pretrained/weight_conversions/mingpt.py index a3713cf68..5b16a84dd 100644 --- a/transformer_lens/pretrained/weight_conversions/mingpt.py +++ b/transformer_lens/pretrained/weight_conversions/mingpt.py @@ -1,6 +1,6 @@ import einops -from transformer_lens.config.HookedTransformerConfig import 
HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig def convert_mingpt_weights(old_state_dict, cfg: HookedTransformerConfig): diff --git a/transformer_lens/pretrained/weight_conversions/mistral.py b/transformer_lens/pretrained/weight_conversions/mistral.py index a70fad68c..6d443b198 100644 --- a/transformer_lens/pretrained/weight_conversions/mistral.py +++ b/transformer_lens/pretrained/weight_conversions/mistral.py @@ -1,7 +1,7 @@ import einops import torch -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig def convert_mistral_weights(mistral, cfg: HookedTransformerConfig): diff --git a/transformer_lens/pretrained/weight_conversions/mixtral.py b/transformer_lens/pretrained/weight_conversions/mixtral.py index 00903b738..2dfe05888 100644 --- a/transformer_lens/pretrained/weight_conversions/mixtral.py +++ b/transformer_lens/pretrained/weight_conversions/mixtral.py @@ -1,7 +1,7 @@ import einops import torch -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig def convert_mixtral_weights(mixtral, cfg: HookedTransformerConfig): diff --git a/transformer_lens/pretrained/weight_conversions/nanogpt.py b/transformer_lens/pretrained/weight_conversions/nanogpt.py index 235575861..103276329 100644 --- a/transformer_lens/pretrained/weight_conversions/nanogpt.py +++ b/transformer_lens/pretrained/weight_conversions/nanogpt.py @@ -1,7 +1,7 @@ import einops import torch -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig def convert_nanogpt_weights(old_state_dict, cfg: HookedTransformerConfig): diff --git a/transformer_lens/pretrained/weight_conversions/neel_solu_old.py 
b/transformer_lens/pretrained/weight_conversions/neel_solu_old.py index 995686faf..383f81a9e 100644 --- a/transformer_lens/pretrained/weight_conversions/neel_solu_old.py +++ b/transformer_lens/pretrained/weight_conversions/neel_solu_old.py @@ -1,4 +1,4 @@ -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig def convert_neel_solu_old_weights(state_dict: dict, cfg: HookedTransformerConfig): diff --git a/transformer_lens/pretrained/weight_conversions/neo.py b/transformer_lens/pretrained/weight_conversions/neo.py index 13b4f6b71..76febea48 100644 --- a/transformer_lens/pretrained/weight_conversions/neo.py +++ b/transformer_lens/pretrained/weight_conversions/neo.py @@ -1,7 +1,7 @@ import einops import torch -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig def convert_neo_weights(neo, cfg: HookedTransformerConfig): diff --git a/transformer_lens/pretrained/weight_conversions/neox.py b/transformer_lens/pretrained/weight_conversions/neox.py index 9d4fda6e3..ff84b5b0d 100644 --- a/transformer_lens/pretrained/weight_conversions/neox.py +++ b/transformer_lens/pretrained/weight_conversions/neox.py @@ -1,7 +1,7 @@ import einops import torch -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig def convert_neox_weights(neox, cfg: HookedTransformerConfig): diff --git a/transformer_lens/pretrained/weight_conversions/olmo.py b/transformer_lens/pretrained/weight_conversions/olmo.py index afbbb6263..81ccfba6b 100644 --- a/transformer_lens/pretrained/weight_conversions/olmo.py +++ b/transformer_lens/pretrained/weight_conversions/olmo.py @@ -1,7 +1,7 @@ import einops import torch -from transformer_lens.config.HookedTransformerConfig 
import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig def convert_olmo_weights(olmo, cfg: HookedTransformerConfig): diff --git a/transformer_lens/pretrained/weight_conversions/olmo2.py b/transformer_lens/pretrained/weight_conversions/olmo2.py index 4b441dbe6..9057f2f34 100644 --- a/transformer_lens/pretrained/weight_conversions/olmo2.py +++ b/transformer_lens/pretrained/weight_conversions/olmo2.py @@ -2,7 +2,7 @@ import torch from transformers.models.olmo2.modeling_olmo2 import Olmo2DecoderLayer -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig def convert_olmo2_weights(olmo2, cfg: HookedTransformerConfig): diff --git a/transformer_lens/pretrained/weight_conversions/olmo3.py b/transformer_lens/pretrained/weight_conversions/olmo3.py index 33d25b6f1..1b89a1d2d 100644 --- a/transformer_lens/pretrained/weight_conversions/olmo3.py +++ b/transformer_lens/pretrained/weight_conversions/olmo3.py @@ -15,7 +15,7 @@ import einops import torch -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig def convert_olmo3_weights(olmo3, cfg: HookedTransformerConfig): diff --git a/transformer_lens/pretrained/weight_conversions/olmoe.py b/transformer_lens/pretrained/weight_conversions/olmoe.py index 16efe3dc3..a38ea758f 100644 --- a/transformer_lens/pretrained/weight_conversions/olmoe.py +++ b/transformer_lens/pretrained/weight_conversions/olmoe.py @@ -1,7 +1,7 @@ import einops import torch -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig def convert_olmoe_weights(olmoe, cfg: HookedTransformerConfig): diff --git 
a/transformer_lens/pretrained/weight_conversions/openai.py b/transformer_lens/pretrained/weight_conversions/openai.py index d440f15d3..dfe0b7d3f 100644 --- a/transformer_lens/pretrained/weight_conversions/openai.py +++ b/transformer_lens/pretrained/weight_conversions/openai.py @@ -10,7 +10,7 @@ import einops import torch -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig def convert_gpt_oss_weights(gpt_oss, cfg: HookedTransformerConfig): diff --git a/transformer_lens/pretrained/weight_conversions/opt.py b/transformer_lens/pretrained/weight_conversions/opt.py index 2c197fbce..109414f3c 100644 --- a/transformer_lens/pretrained/weight_conversions/opt.py +++ b/transformer_lens/pretrained/weight_conversions/opt.py @@ -1,7 +1,7 @@ import einops import torch -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig def convert_opt_weights(opt, cfg: HookedTransformerConfig): diff --git a/transformer_lens/pretrained/weight_conversions/phi.py b/transformer_lens/pretrained/weight_conversions/phi.py index 88e52d845..6c7cf8bd3 100644 --- a/transformer_lens/pretrained/weight_conversions/phi.py +++ b/transformer_lens/pretrained/weight_conversions/phi.py @@ -1,6 +1,6 @@ import einops -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig def convert_phi_weights(phi, cfg: HookedTransformerConfig): diff --git a/transformer_lens/pretrained/weight_conversions/phi3.py b/transformer_lens/pretrained/weight_conversions/phi3.py index 4f6351f41..6c01c00bd 100644 --- a/transformer_lens/pretrained/weight_conversions/phi3.py +++ b/transformer_lens/pretrained/weight_conversions/phi3.py @@ -3,7 +3,7 @@ import einops import torch -from 
transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig def convert_phi3_weights(phi, cfg: HookedTransformerConfig): diff --git a/transformer_lens/pretrained/weight_conversions/qwen.py b/transformer_lens/pretrained/weight_conversions/qwen.py index 187a2b7f6..276d25af5 100644 --- a/transformer_lens/pretrained/weight_conversions/qwen.py +++ b/transformer_lens/pretrained/weight_conversions/qwen.py @@ -1,7 +1,7 @@ import einops import torch -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig def convert_qwen_weights(qwen, cfg: HookedTransformerConfig): diff --git a/transformer_lens/pretrained/weight_conversions/qwen2.py b/transformer_lens/pretrained/weight_conversions/qwen2.py index 0e9021e27..94a5ced5c 100644 --- a/transformer_lens/pretrained/weight_conversions/qwen2.py +++ b/transformer_lens/pretrained/weight_conversions/qwen2.py @@ -1,7 +1,7 @@ import einops import torch -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig def convert_qwen2_weights(qwen, cfg: HookedTransformerConfig): diff --git a/transformer_lens/pretrained/weight_conversions/qwen3.py b/transformer_lens/pretrained/weight_conversions/qwen3.py index ad6a0b8af..c2848213c 100644 --- a/transformer_lens/pretrained/weight_conversions/qwen3.py +++ b/transformer_lens/pretrained/weight_conversions/qwen3.py @@ -3,7 +3,7 @@ import einops import torch -from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig def convert_qwen3_weights(qwen: Any, cfg: HookedTransformerConfig): diff --git a/transformer_lens/pretrained/weight_conversions/t5.py 
def logits_to_df(
    logits: Float[torch.Tensor, "d_vocab"],
    tokenizer: Optional[Any] = None,
    top_k: Optional[int] = None,
) -> Any:  # pandas.DataFrame; left as Any so beartype doesn't resolve a lazy import at runtime.
    """Convert a 1-D logit vector into a sortable DataFrame for inspection.

    Returns a frame with columns ``token_index``, ``token_string`` (when
    ``tokenizer`` is given), ``logit``, ``log_prob``, ``probability``, sorted by
    descending probability. Tokens whose log-probabilities are exactly equal
    keep ascending ``token_index`` order. ``top_k`` truncates to the
    highest-probability rows.

    Args:
        logits: 1-D tensor of shape [d_vocab]; raw model logits for one position.
        tokenizer: Optional tokenizer exposing ``decode(list_of_ids)``, used to
            materialise ``token_string``; when ``None``, the column is omitted.
        top_k: Optional cap on the number of returned rows. Must be
            non-negative; ``0`` yields an empty frame.

    Returns:
        A ``pandas.DataFrame`` with one row per retained token.

    Raises:
        ValueError: If ``logits`` is not 1-D or ``top_k`` is negative.
    """
    if logits.ndim != 1:
        raise ValueError(
            f"logits_to_df expects a 1-D [d_vocab] tensor, got shape {tuple(logits.shape)}"
        )
    if top_k is not None and top_k < 0:
        # A negative slice bound would silently drop rows from the tail instead.
        raise ValueError(f"top_k must be non-negative, got {top_k}")

    # Lazy import — keeps `import transformer_lens` free of pandas's
    # warnings unless logits_to_df is actually called.
    import pandas as pd

    # Detach once so the softmax below never records an autograd graph.
    raw_logits = logits.detach()
    log_probs = torch.log_softmax(raw_logits.float(), dim=-1)
    probs = log_probs.exp()

    # Rank on log-probabilities, not probabilities: `exp` underflows to 0.0 in
    # float32 for very unlikely tokens, which would tie them and lose their
    # relative order. The stable sort makes exact ties deterministic.
    order = torch.sort(log_probs, descending=True, stable=True).indices
    if top_k is not None:
        order = order[:top_k]

    indices = order.cpu().tolist()
    data: dict = {"token_index": indices}
    if tokenizer is not None:
        data["token_string"] = [tokenizer.decode([i]) for i in indices]
    # The logit column is read from the input in its original dtype.
    data["logit"] = raw_logits[order].cpu().tolist()
    data["log_prob"] = log_probs[order].cpu().tolist()
    data["probability"] = probs[order].cpu().tolist()
    return pd.DataFrame(data)
transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter from transformer_lens.utilities import filter_dict_by_prefix