Skip to content
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# External Architecture Adapter Registration

TransformerLens supports loading custom architecture adapters from **external packages** — no fork required. You can write your adapter, register it, and use it with `boot_transformers()` without modifying TransformerLens source code.

## Two ways to register

### 1. Runtime registration

Call `register_adapter()` in your startup code:

```python
from transformer_lens.factories.architecture_adapter_factory import (
ArchitectureAdapterFactory,
)

ArchitectureAdapterFactory.register_adapter(
"MyModelForCausalLM",
MyArchitectureAdapter,
)

# Now this works:
bridge = TransformerBridge.boot_transformers("my-org/my-model")
```

> **Important:** The architecture name you register (e.g. `"MyModelForCausalLM"`) must match the `architectures` field in the model's HuggingFace `config.json`. TransformerLens reads this field to look up the adapter.

### 2. Entry-point registration (recommended for packages)

Declare your adapter in your package's `pyproject.toml`:

```toml
[project.entry-points."transformer_lens.architectures"]
"MyModelForCausalLM" = "my_package.adapters:MyArchitectureAdapter"
```

TransformerLens discovers these automatically on first use. Users just install your package alongside TransformerLens and `boot_transformers()` finds it.

## How it works

When `boot_transformers()` is called:

1. It reads the model's HuggingFace `config.json` to extract the `architectures` field. This field lists the architecture class name (e.g. `"MyModelForCausalLM"`). **This is the name you must use in your registration.**
2. `select_architecture_adapter()` checks the registry for that architecture name.
3. On first call, `discover_entry_points()` scans all installed packages for adapters declared via the `transformer_lens.architectures` entry-point group.
4. The matching adapter class is instantiated and used to build the bridge.

## Writing an adapter

Follow the [Architecture Adapter Creation Guide](adapter-creation-guide.md) to build your adapter class. Once written, use either registration method above to plug it into TransformerLens.

## Example package layout

```
my_transformer_plugin/
├── pyproject.toml # declares the entry point
└── my_transformer_plugin/
├── __init__.py
└── adapters.py # contains MyArchitectureAdapter
```

**pyproject.toml:**

```toml
[project]
name = "my-transformer-plugin"
version = "0.1.0"
requires-python = ">=3.10"
dependencies = ["transformer-lens>=3.0"]

[project.entry-points."transformer_lens.architectures"]
"MyModelForCausalLM" = "my_transformer_plugin.adapters:MyArchitectureAdapter"
```

**adapters.py:**

```python
from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
from transformer_lens.model_bridge.generalized_components import (
BlockBridge,
EmbeddingBridge,
# ... import the bridge components you need
)

class MyArchitectureAdapter(ArchitectureAdapter):
def __init__(self, cfg):
super().__init__(cfg)
# Set config, weight processing, component mapping
# See the Adapter Creation Guide for details
```
1 change: 1 addition & 0 deletions docs/source/content/contributing.md
Original file line number Diff line number Diff line change
Expand Up @@ -178,4 +178,5 @@ Adapters live in `transformer_lens/model_bridge/supported_architectures/<model_n

adapter_development/adapter-creation-guide
adapter_development/hf-model-analysis-guide
adapter_development/external-adapter-registration
```
126 changes: 126 additions & 0 deletions tests/unit/model_bridge/test_architecture_adapter_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
"""Unit tests for ArchitectureAdapterFactory — external registration and entry-point discovery."""

import pytest

from tests.mocks.architecture_adapter import MockArchitectureAdapter
from transformer_lens.config import TransformerBridgeConfig
from transformer_lens.factories.architecture_adapter_factory import (
SUPPORTED_ARCHITECTURES,
ArchitectureAdapterFactory,
)
from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter


class OtherMockArchitectureAdapter(ArchitectureAdapter):
"""A second mock adapter class used to verify overwrite behaviour."""

def __init__(self, cfg=None):
if cfg is None:
cfg = TransformerBridgeConfig(
d_model=512,
d_head=64,
n_layers=2,
n_ctx=1024,
d_vocab=1000,
d_mlp=2048,
default_prepend_bos=True,
architecture="GPT2LMHeadModel",
)
super().__init__(cfg)


@pytest.fixture(autouse=True)
def _isolate_factory_state():
"""Save and restore factory state so tests don't leak into each other."""
saved_adapters = dict(ArchitectureAdapterFactory._adapters)
saved_discovered = ArchitectureAdapterFactory._entry_points_discovered
yield
ArchitectureAdapterFactory._adapters = saved_adapters
ArchitectureAdapterFactory._entry_points_discovered = saved_discovered


def _make_cfg(**overrides) -> TransformerBridgeConfig:
defaults = dict(
d_model=64,
d_head=16,
n_layers=2,
n_ctx=64,
n_heads=4,
d_vocab=100,
d_mlp=256,
default_prepend_bos=True,
)
defaults.update(overrides)
return TransformerBridgeConfig(**defaults)


class TestSupportedArchitectures:
"""Verify all existing hardcoded entries in SUPPORTED_ARCHITECTURES."""

def test_has_common_architectures(self):
common = [
"GPT2LMHeadModel",
"LlamaForCausalLM",
"MistralForCausalLM",
"Gemma2ForCausalLM",
"Qwen2ForCausalLM",
"BloomForCausalLM",
"FalconForCausalLM",
]
for arch in common:
assert arch in SUPPORTED_ARCHITECTURES, f"Missing: {arch}"


class TestRegisterAdapter:
"""Verify runtime adapter registration."""

def test_register_adds_to_adapters(self):
key = "TestMockForCausalLM"
ArchitectureAdapterFactory.register_adapter(key, MockArchitectureAdapter)
assert key in ArchitectureAdapterFactory._adapters

def test_register_overwrites_existing(self):
key = "TestOverwriteForCausalLM"
ArchitectureAdapterFactory.register_adapter(key, MockArchitectureAdapter)
assert ArchitectureAdapterFactory._adapters[key] is MockArchitectureAdapter
ArchitectureAdapterFactory.register_adapter(key, OtherMockArchitectureAdapter)
assert ArchitectureAdapterFactory._adapters[key] is OtherMockArchitectureAdapter

def test_select_returns_registered_adapter(self):
key = "TestSelectForCausalLM"
ArchitectureAdapterFactory.register_adapter(key, MockArchitectureAdapter)
cfg = _make_cfg(architecture=key)
adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg)
assert isinstance(adapter, MockArchitectureAdapter)


class TestSelectErrors:
"""Verify error handling in select_architecture_adapter."""

def test_unknown_architecture_raises(self):
cfg = _make_cfg(architecture="NonExistentForCausalLM")
with pytest.raises(ValueError, match="Unsupported architecture"):
ArchitectureAdapterFactory.select_architecture_adapter(cfg)

def test_none_architecture_raises(self):
cfg = _make_cfg(architecture=None)
with pytest.raises(ValueError, match="must have architecture set"):
ArchitectureAdapterFactory.select_architecture_adapter(cfg)


class TestDiscoverEntryPoints:
"""Verify entry-point discovery behavior."""

def test_discover_is_idempotent(self):
ArchitectureAdapterFactory._entry_points_discovered = False
ArchitectureAdapterFactory.discover_entry_points()
first_run = ArchitectureAdapterFactory._entry_points_discovered
ArchitectureAdapterFactory.discover_entry_points()
assert ArchitectureAdapterFactory._entry_points_discovered is first_run is True

def test_discover_does_not_remove_existing(self):
key = "TestPreserveForCausalLM"
ArchitectureAdapterFactory.register_adapter(key, MockArchitectureAdapter)
ArchitectureAdapterFactory._entry_points_discovered = False
ArchitectureAdapterFactory.discover_entry_points()
assert key in ArchitectureAdapterFactory._adapters
Comment on lines +74 to +126
71 changes: 67 additions & 4 deletions transformer_lens/factories/architecture_adapter_factory.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
"""Architecture adapter factory.

This module provides a factory for creating architecture adapters.
This module provides a factory for creating architecture adapters, including
support for external registration and entry-point discovery.
"""

import warnings
from importlib.metadata import entry_points

from transformer_lens.config import TransformerBridgeConfig
from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
from transformer_lens.model_bridge.supported_architectures import (
Expand Down Expand Up @@ -126,9 +130,68 @@


class ArchitectureAdapterFactory:
"""Factory for creating architecture adapters."""
"""Factory for creating architecture adapters.

Supports external registration via `register_adapter()` and automatic
discovery of adapters from installed packages via entry points.
"""

_adapters = SUPPORTED_ARCHITECTURES
_adapters = dict(SUPPORTED_ARCHITECTURES)
_entry_points_discovered = False

@classmethod
def register_adapter(
cls, architecture_name: str, adapter_class: type["ArchitectureAdapter"]
) -> None:
"""Register a custom architecture adapter at runtime.

This allows users to add their own architecture adapters without
modifying TransformerLens source code.

Args:
architecture_name: The HuggingFace architecture class name
(e.g. ``"Qwen3ForCausalLM"``).
adapter_class: The adapter class to register.

Example:
>>> from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
>>> from transformer_lens.factories.architecture_adapter_factory import ArchitectureAdapterFactory
>>> class MyAdapter(ArchitectureAdapter):
... def __init__(self, cfg):
... super().__init__(cfg)
>>> ArchitectureAdapterFactory.register_adapter("MyModelForCausalLM", MyAdapter)
>>> "MyModelForCausalLM" in ArchitectureAdapterFactory._adapters
True
"""
cls._adapters[architecture_name] = adapter_class

@classmethod
def discover_entry_points(cls) -> None:
"""Discover and register architecture adapters from installed packages.

Packages can declare adapters in their ``pyproject.toml``:
```toml
[project.entry-points."transformer_lens.architectures"]
"MyModelForCausalLM" = "my_package.adapters:MyArchitectureAdapter"
```
"""
if cls._entry_points_discovered:
return
try:
eps = entry_points(group="transformer_lens.architectures")
except Exception as e:
warnings.warn(
f"Failed to discover entry points: {e}. " f"External adapters may not be available."
)
else:
for ep in eps:
try:
cls._adapters[ep.name] = ep.load()
except Exception as e:
warnings.warn(
f"Failed to load entry point '{ep.name}': {e}. " f"Skipping this adapter."
)
cls._entry_points_discovered = True

@classmethod
def select_architecture_adapter(cls, cfg: TransformerBridgeConfig) -> ArchitectureAdapter:
Expand All @@ -143,11 +206,11 @@ def select_architecture_adapter(cls, cfg: TransformerBridgeConfig) -> Architectu
Raises:
ValueError: If no adapter is found for the given config.
"""
cls.discover_entry_points()
if cfg.architecture is not None:
if cfg.architecture in cls._adapters:
return cls._adapters[cfg.architecture](cfg)
else:
raise ValueError(f"Unsupported architecture: {cfg.architecture}")

# If architecture is None, this is an error since TransformerBridgeConfig should always have it set
raise ValueError(f"TransformerBridgeConfig must have architecture set, got: {cfg}")
Loading