Skip to content
125 changes: 125 additions & 0 deletions tests/integration/model_bridge/test_bridge_vs_hf_eager_parity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
"""Asserts ``TransformerBridge`` reproduces ``AutoModelForCausalLM`` eager-attention logits.

Issue #385 reported drift between bridge and HF for rotary models like Pythia. The drift
was an attention-implementation mismatch — bridge always uses eager, default HF loads use
SDPA, which reorders ops in a fused kernel. Bridge vs. HF *eager* agrees to within fp32 noise.
"""

from typing import Callable

import pytest
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from transformer_lens.model_bridge import TransformerBridge

# Model id passed to the tokenizer, bridge, and HF-eager fixtures in this module.
MODEL_NAME = "EleutherAI/pythia-70m"

# Op-reorder noise floor for fp32 transformer forward passes. We currently
# measure 0.0 on this model, but allow a small epsilon so harmless refactors
# (intermediate allocations, equivalent op reorderings) don't break the test.
# The tests compare the maximum absolute elementwise difference against this
# value with a strict ``<``.
FP32_NOISE_TOL = 1e-5


@pytest.fixture(scope="module")
def tokenizer():
    """Load the ``MODEL_NAME`` tokenizer once and share it across this module's tests."""
    shared_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    return shared_tokenizer


@pytest.fixture(scope="module")
def bridge():
    """Boot a module-scoped ``TransformerBridge`` for ``MODEL_NAME`` on CPU in float32."""
    boot_kwargs = {"device": "cpu", "dtype": torch.float32}
    return TransformerBridge.boot_transformers(MODEL_NAME, **boot_kwargs)


@pytest.fixture(scope="module")
def hf_eager():
    """Reference HF causal LM, loaded separately from the instance the bridge wraps.

    Loaded in float32 with ``attn_implementation="eager"`` and switched to eval
    mode before being handed to the tests.
    """
    eager_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        attn_implementation="eager",
        torch_dtype=torch.float32,
    )
    return eager_model.eval()


@pytest.fixture
def tokenize(tokenizer) -> Callable[[str], torch.Tensor]:
    """Provide a helper that maps a prompt string to the tokenizer's ``input_ids`` tensor."""
    # ``return_tensors="pt"`` asks the tokenizer for PyTorch tensors.
    return lambda prompt: tokenizer(prompt, return_tensors="pt").input_ids


@pytest.mark.parametrize("prompt", ["Hello, world!", "The quick brown fox jumps"])
def test_bridge_logits_match_hf_eager(bridge, hf_eager, tokenize, prompt):
    """Bridge logits must agree with HF eager-attention logits within ``FP32_NOISE_TOL``."""
    input_ids = tokenize(prompt)
    with torch.inference_mode():
        logits_from_bridge = bridge(input_ids)
        logits_from_hf = hf_eager(input_ids).logits
    # Worst-case elementwise disagreement between the two implementations.
    elementwise_gap = torch.abs(logits_from_bridge - logits_from_hf)
    max_diff = torch.max(elementwise_gap).item()
    assert max_diff < FP32_NOISE_TOL, (
        f"{MODEL_NAME!r} bridge vs HF eager drift={max_diff:.2e} on {prompt!r} "
        f"exceeds fp32-noise tolerance {FP32_NOISE_TOL:.0e} — bridge's "
        f"_reconstruct_attention may have regressed (see issue #385)."
    )


def test_bridge_residual_stream_matches_hf_eager(bridge, hf_eager, tokenize):
    """Compare every layer's output between HF eager and the bridge.

    Checking only the final logits can miss errors that cancel each other out;
    a per-layer comparison catches them at the layer where they first appear.
    """
    tokens = tokenize("Hello, world!")
    hf_layers = hf_eager.gpt_neox.layers
    n_layers = len(hf_layers)

    # --- HF side: record each layer module's output through forward hooks. ---
    hf_resid: dict[int, torch.Tensor] = {}

    def _capture_hf(layer_idx):
        def _hook(_module, _inputs, output):
            # A layer may return a tuple; keep its first element
            # (presumably the hidden state) or the bare output otherwise.
            hidden = output[0] if isinstance(output, tuple) else output
            hf_resid[layer_idx] = hidden.detach()

        return _hook

    hook_handles = [
        hf_layer.register_forward_hook(_capture_hf(layer_idx))
        for layer_idx, hf_layer in enumerate(hf_layers)
    ]
    try:
        with torch.inference_mode():
            hf_eager(tokens)
    finally:
        # Always detach the hooks: ``hf_eager`` is a module-scoped fixture
        # shared with other tests.
        for handle in hook_handles:
            handle.remove()

    # --- Bridge side: record ``hook_resid_post`` for every block. ---
    bridge_resid: dict[int, torch.Tensor] = {}

    def _capture_bridge(layer_idx):
        # The second parameter keeps the name ``hook`` in case the bridge
        # passes it by keyword.
        def _hook(value, hook):
            bridge_resid[layer_idx] = value.detach()

        return _hook

    fwd_hooks = [
        (f"blocks.{layer_idx}.hook_resid_post", _capture_bridge(layer_idx))
        for layer_idx in range(n_layers)
    ]
    with torch.inference_mode():
        bridge.run_with_hooks(tokens, fwd_hooks=fwd_hooks)

    # --- Layer-by-layer comparison of the captured tensors. ---
    for i in range(n_layers):
        d = torch.max(torch.abs(hf_resid[i] - bridge_resid[i])).item()
        assert d < FP32_NOISE_TOL, (
            f"layer {i} residual drift={d:.2e} exceeds fp32-noise tolerance "
            f"{FP32_NOISE_TOL:.0e} — bridge layer output diverges from HF eager."
        )


def test_bridge_attention_reconstruction_actually_runs(bridge, tokenize):
    """Guard against tautology: prove bridge's custom attention path executes.

    If a future refactor made the bridge delegate to HF directly, the previous
    parity tests would pass trivially. This one fails fast in that case by
    asserting bridge-specific hooks fire during forward.

    The forward pass runs under ``torch.inference_mode()``, matching the other
    forward passes in this module: only whether the hook fires is checked, so
    no autograd state needs to be tracked.
    """
    tokens = tokenize("Hello, world!")
    attn_scores_fired: list[bool] = []

    def _record_fire(_value, hook):
        # Returns None, so the hooked value is left as-is.
        attn_scores_fired.append(True)

    with torch.inference_mode():
        bridge.run_with_hooks(
            tokens,
            fwd_hooks=[
                ("blocks.0.attn.hook_attn_scores", _record_fire),
            ],
        )
    assert attn_scores_fired, (
        "blocks.0.attn.hook_attn_scores did not fire — bridge no longer runs its "
        "own attention reconstruction, making the parity tests tautological."
    )
59 changes: 41 additions & 18 deletions tests/unit/factored_matrix/test_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,19 +78,52 @@ def test_transpose_property(self, factored_matrices):

def test_svd_property(self, factored_matrices):
    """Check that the ``svd()`` factors rebuild ``AB`` and that U and V are orthonormal.

    For each matrix in ``factored_matrices``:
      * ``U @ diag(S) @ V.T`` must match ``AB`` (atol 1e-5);
      * ``U.T @ U`` and ``V.T @ V`` must match the identity (atol 1e-5).
    """
    for factored_matrix in factored_matrices:
        U, S, V = factored_matrix.svd()
        # The third factor is used as V (transposed in the product), not Vh.
        assert torch.allclose(factored_matrix.AB, U @ torch.diag_embed(S) @ V.T, atol=1e-5)
        # test that U and V are unitary
        assert torch.allclose(U.T @ U, torch.eye(U.shape[-1]), atol=1e-5)
        assert torch.allclose(V.T @ V, torch.eye(V.shape[-1]), atol=1e-5)

def test_svd_property_leading_ones(self, factored_matrices_leading_ones):
    """Same SVD checks as ``test_svd_property``, using ``.mT`` for the transposes.

    For each matrix in ``factored_matrices_leading_ones``:
      * ``U @ diag(S) @ V.mT`` must match ``AB`` (atol 1e-5);
      * ``U.mT @ U`` and ``V.mT @ V`` must match the identity (atol 1e-5).
    """
    for factored_matrix in factored_matrices_leading_ones:
        U, S, V = factored_matrix.svd()
        # ``.mT`` swaps only the last two dims, so any leading dims are preserved.
        assert torch.allclose(factored_matrix.AB, U @ torch.diag_embed(S) @ V.mT, atol=1e-5)
        # test that U and V are unitary
        assert torch.allclose(U.mT @ U, torch.eye(U.shape[-1]), atol=1e-5)
        assert torch.allclose(V.mT @ V, torch.eye(V.shape[-1]), atol=1e-5)

def test_V_and_Vh_alias_match(self, factored_matrices):
    """Reading ``Vh`` must emit a ``DeprecationWarning`` and equal ``V``."""
    import warnings

    for fm in factored_matrices:
        with warnings.catch_warnings(record=True) as recorded:
            # Record every warning, even ones already emitted earlier.
            warnings.simplefilter("always")
            alias_value = fm.Vh
        seen_categories = [entry.category for entry in recorded]
        assert any(issubclass(cat, DeprecationWarning) for cat in seen_categories)
        assert torch.equal(alias_value, fm.V)

def test_svd_caches_per_instance(self):
    """Two ``svd()`` calls on one instance must hand back the very same tensor objects."""
    matrix = FactoredMatrix(randn(4, 3), randn(3, 4))
    U_a, S_a, V_a = matrix.svd()
    U_b, S_b, V_b = matrix.svd()
    # Identity (``is``), not equality: a recomputed result would still be
    # equal but would not be the same object.
    for cached, repeated in ((U_a, U_b), (S_a, S_b), (V_a, V_b)):
        assert cached is repeated

def test_svd_does_not_prevent_gc(self):
    """After ``svd()`` has run, dropping the last name for the instance must let it be collected."""
    import gc
    import weakref

    matrix = FactoredMatrix(randn(4, 3), randn(3, 4))
    svd_result = matrix.svd()  # populate the cache; the result stays bound
    tracker = weakref.ref(matrix)
    del matrix
    gc.collect()
    still_alive = tracker() is not None
    assert (
        not still_alive
    ), "FactoredMatrix instance survived deletion — svd cache is leaking references."

def test_eigenvalues_property(self, factored_matrices):
for factored_matrix in factored_matrices:
Expand Down Expand Up @@ -141,16 +174,6 @@ def test_collapse_l(self, factored_matrices):
expected = factored_matrix.S[..., :, None] * utils.transpose(factored_matrix.V)
assert torch.allclose(result, expected)

def test_V_and_Vh_alias_match(self, factored_matrices):
    """Reading ``Vh`` must emit a ``DeprecationWarning`` and equal ``V``."""
    import warnings

    for fm in factored_matrices:
        with warnings.catch_warnings(record=True) as recorded:
            # Record every warning, even ones already emitted earlier.
            warnings.simplefilter("always")
            alias_value = fm.Vh
        seen_categories = [entry.category for entry in recorded]
        assert any(issubclass(cat, DeprecationWarning) for cat in seen_categories)
        assert torch.equal(alias_value, fm.V)

def test_collapse_r(self, factored_matrices):
for factored_matrix in factored_matrices:
result = factored_matrix.collapse_r()
Expand Down
Loading
Loading