Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions demos/Grokking_Demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,8 @@
"source": [
"import transformer_lens\n",
"import transformer_lens.utilities as utils\n",
"from transformer_lens.hook_points import (\n",
" HookedRootModule,\n",
" HookPoint,\n",
") # Hooking utilities\n",
"from transformer_lens.hook_points import HookPoint\n",
"from transformer_lens.HookedRootModule import HookedRootModule # Hooking utilities\n",
"from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache\n",
"\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion demos/Main_Demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1738,7 +1738,7 @@
"metadata": {},
"outputs": [],
"source": [
"from transformer_lens.hook_points import HookedRootModule, HookPoint\n",
"from transformer_lens.hook_points import HookPoint\nfrom transformer_lens.HookedRootModule import HookedRootModule\n",
"\n",
"\n",
"class SquareThenAdd(nn.Module):\n",
Expand Down
10 changes: 4 additions & 6 deletions demos/Othello_GPT.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,8 @@
"source": [
"import transformer_lens\n",
"import transformer_lens.utilities as utils\n",
"from transformer_lens.hook_points import (\n",
" HookedRootModule,\n",
" HookPoint,\n",
") # Hooking utilities\n",
"from transformer_lens.hook_points import HookPoint\n",
"from transformer_lens.HookedRootModule import HookedRootModule # Hooking utilities\n",
"from transformer_lens import (\n",
" HookedTransformer,\n",
" HookedTransformerConfig,\n",
Expand Down Expand Up @@ -595,7 +593,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "transformer-lens",
"language": "python",
"name": "python3"
},
Expand All @@ -609,7 +607,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.15"
"version": "3.12.12"
},
"orig_nbformat": 4
},
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/components/mlps/test_gpt_oss_moe_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
GptOssExpert,
GptOssMoE,
)
from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig


def make_moe_cfg():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pytest

from transformer_lens.config.TransformerBridgeConfig import TransformerBridgeConfig
from transformer_lens.config.transformer_bridge_config import TransformerBridgeConfig
from transformer_lens.conversion_utils.conversion_steps import (
ArithmeticTensorConversion,
RearrangeTensorConversion,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import pytest

from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig
from transformer_lens.loading_from_pretrained import get_pretrained_model_config
from transformer_lens.supported_models import OFFICIAL_MODEL_NAMES

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest

from transformer_lens.config.TransformerBridgeConfig import TransformerBridgeConfig
from transformer_lens.config.transformer_bridge_config import TransformerBridgeConfig
from transformer_lens.conversion_utils.conversion_steps import (
ArithmeticTensorConversion,
RearrangeTensorConversion,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest

from transformer_lens.config.TransformerBridgeConfig import TransformerBridgeConfig
from transformer_lens.config.transformer_bridge_config import TransformerBridgeConfig
from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
from transformer_lens.conversion_utils.param_processing_conversion import (
ParamProcessingConversion,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ def test_adapter_class_correct(self):

def _make_bridge_cfg(**overrides):
"""Minimal TransformerBridgeConfig for Qwen3_5 adapter tests."""
from transformer_lens.config.TransformerBridgeConfig import TransformerBridgeConfig
from transformer_lens.config.transformer_bridge_config import (
TransformerBridgeConfig,
)

defaults = dict(
d_model=1024,
Expand Down Expand Up @@ -266,7 +268,7 @@ def test_n_key_value_heads_set_when_gqa(self):
assert adapter.cfg.n_key_value_heads == 2

def test_n_key_value_heads_not_set_when_absent(self):
from transformer_lens.config.TransformerBridgeConfig import (
from transformer_lens.config.transformer_bridge_config import (
TransformerBridgeConfig,
)
from transformer_lens.model_bridge.supported_architectures.qwen3_5 import (
Expand Down Expand Up @@ -583,7 +585,9 @@ def _make_tiny_bridge():
"""Build a Qwen3_5 bridge from a tiny HF model."""
from unittest.mock import MagicMock

from transformer_lens.config.TransformerBridgeConfig import TransformerBridgeConfig
from transformer_lens.config.transformer_bridge_config import (
TransformerBridgeConfig,
)
from transformer_lens.model_bridge import TransformerBridge
from transformer_lens.model_bridge.supported_architectures.qwen3_5 import (
Qwen3_5ArchitectureAdapter,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ def test_adapter_class_correct(self):

def _make_bridge_cfg(**overrides):
"""Minimal TransformerBridgeConfig for Qwen3Next adapter tests."""
from transformer_lens.config.TransformerBridgeConfig import TransformerBridgeConfig
from transformer_lens.config.transformer_bridge_config import (
TransformerBridgeConfig,
)

defaults = dict(
d_model=2048,
Expand Down Expand Up @@ -502,7 +504,9 @@ def _make_tiny_bridge():
"""Build a Qwen3Next bridge from a tiny HF model."""
from unittest.mock import MagicMock

from transformer_lens.config.TransformerBridgeConfig import TransformerBridgeConfig
from transformer_lens.config.transformer_bridge_config import (
TransformerBridgeConfig,
)
from transformer_lens.model_bridge import TransformerBridge
from transformer_lens.model_bridge.supported_architectures.qwen3_next import (
Qwen3NextArchitectureAdapter,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from transformer_lens import HookedTransformer
from transformer_lens import utilities as utils
from transformer_lens.config.TransformerBridgeConfig import TransformerBridgeConfig
from transformer_lens.config.transformer_bridge_config import TransformerBridgeConfig
from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
from transformer_lens.weight_processing import ProcessWeights

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/pretrained_weight_conversions/test_apertus.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch

from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig
from transformer_lens.pretrained.weight_conversions.apertus import (
convert_apertus_weights,
)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/pretrained_weight_conversions/test_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import torch
import torch.nn as nn

from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig
from transformer_lens.pretrained.weight_conversions.gemma import convert_gemma_weights


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
import torch

from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig
from transformer_lens.pretrained.weight_conversions.hubert import convert_hubert_weights


Expand Down
2 changes: 1 addition & 1 deletion tests/unit/pretrained_weight_conversions/test_olmo3.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import torch
import torch.nn as nn

from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig
from transformer_lens.pretrained.weight_conversions.olmo3 import convert_olmo3_weights


Expand Down
2 changes: 1 addition & 1 deletion tests/unit/pretrained_weight_conversions/test_openai.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch

from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig
from transformer_lens.pretrained.weight_conversions.openai import (
convert_gpt_oss_weights,
)
Expand Down
7 changes: 2 additions & 5 deletions tests/unit/test_hook_introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@

from unittest import mock

from transformer_lens.hook_points import (
HookedRootModule,
HookIntrospectionMixin,
HookPoint,
)
from transformer_lens.hook_points import HookIntrospectionMixin, HookPoint
from transformer_lens.HookedRootModule import HookedRootModule


class _ToyModel(HookedRootModule):
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_hooked_root_module.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from unittest.mock import Mock

from transformer_lens.hook_points import HookedRootModule, HookPoint
from transformer_lens.hook_points import HookPoint
from transformer_lens.HookedRootModule import HookedRootModule

MODEL_NAME = "solu-2l"

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_key_value_cache_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from transformer_lens.cache.key_value_cache_entry import (
TransformerLensKeyValueCacheEntry,
)
from transformer_lens.config.TransformerLensConfig import TransformerLensConfig
from transformer_lens.config.transformer_lens_config import TransformerLensConfig


def _make_cfg(dtype: torch.dtype, n_heads: int = 4, d_head: int = 8, n_key_value_heads=None):
Expand Down
8 changes: 6 additions & 2 deletions tests/unit/test_optional_submodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,9 @@ def test_succeeds_on_universal_submodule(self):

class TestRefactorFactoredAttnHybrid:
def test_skips_missing_attn_layers(self):
from transformer_lens.config.TransformerLensConfig import TransformerLensConfig
from transformer_lens.config.transformer_lens_config import (
TransformerLensConfig,
)
from transformer_lens.weight_processing import ProcessWeights

cfg = TransformerLensConfig(
Expand Down Expand Up @@ -255,7 +257,9 @@ def test_skips_missing_attn_layers(self):
assert "blocks.3.attn.W_Q" not in result

def test_raises_on_partial_attn_keys(self):
from transformer_lens.config.TransformerLensConfig import TransformerLensConfig
from transformer_lens.config.transformer_lens_config import (
TransformerLensConfig,
)
from transformer_lens.weight_processing import ProcessWeights

cfg = TransformerLensConfig(
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_weight_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import pytest
import torch

from transformer_lens.config.TransformerLensConfig import TransformerLensConfig
from transformer_lens.config.transformer_lens_config import TransformerLensConfig
from transformer_lens.weight_processing import ProcessWeights

# from typing import Dict # Unused import
Expand Down
87 changes: 87 additions & 0 deletions tests/unit/utilities/test_logits_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""Unit tests for transformer_lens.utilities.logits_utils."""

import pandas as pd
import pytest
import torch

from transformer_lens.utilities.logits_utils import logits_to_df


class _StubTokenizer:
"""Minimal tokenizer surface used by logits_to_df (decode of single ids)."""

def __init__(self, vocab: list[str]):
self._vocab = vocab

def decode(self, ids: list[int]) -> str:
return "".join(self._vocab[i] for i in ids)


@pytest.fixture(scope="module")
def logits() -> torch.Tensor:
return torch.tensor([1.0, 3.0, 2.0, 0.5])


class TestLogitsToDf:
def test_returns_dataframe(self, logits: torch.Tensor):
df = logits_to_df(logits)
assert isinstance(df, pd.DataFrame)

def test_columns_no_tokenizer(self, logits: torch.Tensor):
df = logits_to_df(logits)
assert list(df.columns) == ["token_index", "logit", "log_prob", "probability"]

def test_columns_with_tokenizer(self, logits: torch.Tensor):
tok = _StubTokenizer(["a", "b", "c", "d"])
df = logits_to_df(logits, tokenizer=tok)
assert list(df.columns) == [
"token_index",
"token_string",
"logit",
"log_prob",
"probability",
]

def test_sorted_by_descending_probability(self, logits: torch.Tensor):
df = logits_to_df(logits)
probs = df["probability"].tolist()
assert probs == sorted(probs, reverse=True)

def test_token_indices_match_logit_argsort(self, logits: torch.Tensor):
df = logits_to_df(logits)
# logits = [1.0, 3.0, 2.0, 0.5] -> argsort desc = [1, 2, 0, 3]
assert df["token_index"].tolist() == [1, 2, 0, 3]

def test_top_k_truncates(self, logits: torch.Tensor):
df = logits_to_df(logits, top_k=2)
assert len(df) == 2
assert df["token_index"].tolist() == [1, 2]

def test_token_string_decoded(self, logits: torch.Tensor):
tok = _StubTokenizer(["a", "b", "c", "d"])
df = logits_to_df(logits, tokenizer=tok, top_k=2)
assert df["token_string"].tolist() == ["b", "c"]

def test_log_prob_and_probability_consistent(self, logits: torch.Tensor):
df = logits_to_df(logits)
assert torch.allclose(
torch.tensor(df["log_prob"].tolist()).exp(),
torch.tensor(df["probability"].tolist()),
atol=1e-6,
)

def test_probabilities_sum_to_one(self, logits: torch.Tensor):
df = logits_to_df(logits)
assert df["probability"].sum() == pytest.approx(1.0, abs=1e-6)

def test_logit_column_preserves_input_values(self, logits: torch.Tensor):
df = logits_to_df(logits)
# Order is argsort-desc; just check membership and float-equality.
assert sorted(df["logit"].tolist()) == sorted(logits.tolist())

def test_rejects_non_1d_input(self):
# Shape constraint enforced by jaxtyping/beartype on Float[Tensor, "d_vocab"].
from beartype.roar import BeartypeCallHintParamViolation

with pytest.raises(BeartypeCallHintParamViolation):
logits_to_df(torch.zeros(3, 4))
Loading
Loading