From 6af217c0514541cd4cf0f94793e2f0c03122c5b7 Mon Sep 17 00:00:00 2001 From: huseyincavusbi Date: Tue, 12 May 2026 08:56:41 +0200 Subject: [PATCH 1/7] feat: External architecture adapter registration and entry-point discovery --- .../factories/architecture_adapter_factory.py | 57 ++++++++++++++++++- 1 file changed, 54 insertions(+), 3 deletions(-) diff --git a/transformer_lens/factories/architecture_adapter_factory.py b/transformer_lens/factories/architecture_adapter_factory.py index 02129b856..f13d399af 100644 --- a/transformer_lens/factories/architecture_adapter_factory.py +++ b/transformer_lens/factories/architecture_adapter_factory.py @@ -1,8 +1,11 @@ """Architecture adapter factory. -This module provides a factory for creating architecture adapters. +This module provides a factory for creating architecture adapters, including +support for external registration and entry-point discovery. """ +from importlib.metadata import entry_points + from transformer_lens.config import TransformerBridgeConfig from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter from transformer_lens.model_bridge.supported_architectures import ( @@ -126,9 +129,57 @@ class ArchitectureAdapterFactory: - """Factory for creating architecture adapters.""" + """Factory for creating architecture adapters. + + Supports external registration via `register_adapter()` and automatic + discovery of adapters from installed packages via entry points. + """ _adapters = SUPPORTED_ARCHITECTURES + _entry_points_discovered = False + + @classmethod + def register_adapter( + cls, architecture_name: str, adapter_class: type["ArchitectureAdapter"] + ) -> None: + """Register a custom architecture adapter at runtime. + + This allows users to add their own architecture adapters without + modifying TransformerLens source code. + + Args: + architecture_name: The HuggingFace architecture class name + (e.g. ``"Qwen3ForCausalLM"``). + adapter_class: The adapter class to register. + + Example: + >>> from transformer_lens.factories import ArchitectureAdapterFactory + >>> ArchitectureAdapterFactory.register_adapter( + ... "MyModelForCausalLM", + ... MyArchitectureAdapter, + ... ) + """ + cls._adapters[architecture_name] = adapter_class + + @classmethod + def discover_entry_points(cls) -> None: + """Discover and register architecture adapters from installed packages. + + Packages can declare adapters in their ``pyproject.toml``: + ```toml + [project.entry-points."transformer_lens.architectures"] + "MyModelForCausalLM" = "my_package.adapters:MyArchitectureAdapter" + ``` + """ + if cls._entry_points_discovered: + return + try: + eps = entry_points(group="transformer_lens.architectures") + for ep in eps: + cls._adapters[ep.name] = ep.load() + except Exception: + pass + cls._entry_points_discovered = True @classmethod def select_architecture_adapter(cls, cfg: TransformerBridgeConfig) -> ArchitectureAdapter: @@ -143,11 +194,11 @@ def select_architecture_adapter(cls, cfg: TransformerBridgeConfig) -> Architectu Raises: ValueError: If no adapter is found for the given config. """ + cls.discover_entry_points() if cfg.architecture is not None: if cfg.architecture in cls._adapters: return cls._adapters[cfg.architecture](cfg) else: raise ValueError(f"Unsupported architecture: {cfg.architecture}") - # If architecture is None, this is an error since TransformerBridgeConfig should always have it set raise ValueError(f"TransformerBridgeConfig must have architecture set, got: {cfg}") From a68a198a4715f1ec9f43d2502757f7c9b4097264 Mon Sep 17 00:00:00 2001 From: huseyincavusbi Date: Tue, 12 May 2026 09:07:24 +0200 Subject: [PATCH 2/7] docs: External adapter registration guide with examples --- .../external-adapter-registration.md | 87 +++++++++++++++++++ docs/source/content/contributing.md | 1 + 2 files changed, 88 insertions(+) create mode 100644 docs/source/content/adapter_development/external-adapter-registration.md diff --git a/docs/source/content/adapter_development/external-adapter-registration.md b/docs/source/content/adapter_development/external-adapter-registration.md new file mode 100644 index 000000000..2ce2aec57 --- /dev/null +++ b/docs/source/content/adapter_development/external-adapter-registration.md @@ -0,0 +1,87 @@ +# External Architecture Adapter Registration + +TransformerLens supports loading custom architecture adapters from **external packages** — no fork required. You can write your adapter, register it, and use it with `boot_transformers()` without modifying TransformerLens source code. + +## Two ways to register + +### 1. Runtime registration + +Call `register_adapter()` in your startup code: + +```python +from transformer_lens.factories.architecture_adapter_factory import ( + ArchitectureAdapterFactory, +) + +ArchitectureAdapterFactory.register_adapter( + "MyModelForCausalLM", + MyArchitectureAdapter, +) + +# Now this works: +bridge = TransformerBridge.boot_transformers("my-org/my-model") +``` + +### 2. Entry-point registration (recommended for packages) + +Declare your adapter in your package's `pyproject.toml`: + +```toml +[project.entry-points."transformer_lens.architectures"] +"MyModelForCausalLM" = "my_package.adapters:MyArchitectureAdapter" +``` + +TransformerLens discovers these automatically on first use. Users just install your package alongside TransformerLens and `boot_transformers()` finds it. + +## How it works + +When `boot_transformers()` is called: + +1. It reads the model's HuggingFace config to determine the architecture class name (e.g. `"MyModelForCausalLM"`). +2. `select_architecture_adapter()` checks the registry. +3. On first call, `discover_entry_points()` scans all installed packages for adapters declared via the `transformer_lens.architectures` entry-point group. +4. The matching adapter class is instantiated and used to build the bridge. + +## Writing an adapter + +Follow the [Architecture Adapter Creation Guide](adapter-creation-guide.md) to build your adapter class. Once written, use either registration method above to plug it into TransformerLens. + +## Example package layout + +``` +my_transformer_plugin/ +├── pyproject.toml # declares the entry point +└── my_transformer_plugin/ + ├── __init__.py + └── adapters.py # contains MyArchitectureAdapter +``` + +**pyproject.toml:** + +```toml +[project] +name = "my-transformer-plugin" +version = "0.1.0" +requires-python = ">=3.10" +dependencies = ["transformer-lens>=3.0"] + +[project.entry-points."transformer_lens.architectures"] +"MyModelForCausalLM" = "my_transformer_plugin.adapters:MyArchitectureAdapter" +``` + +**adapters.py:** + +```python +from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter +from transformer_lens.model_bridge.generalized_components import ( + BlockBridge, + EmbeddingBridge, + # ... import the bridge components you need +) + +class MyArchitectureAdapter(ArchitectureAdapter): + def __init__(self, cfg): + super().__init__(cfg) + # Set config, weight processing, component mapping + # See the Adapter Creation Guide for details +``` diff --git a/docs/source/content/contributing.md b/docs/source/content/contributing.md index 370a87361..87a28f6b0 100644 --- a/docs/source/content/contributing.md +++ b/docs/source/content/contributing.md @@ -178,4 +178,5 @@ Adapters live in `transformer_lens/model_bridge/supported_architectures/ Date: Tue, 12 May 2026 09:22:20 +0200 Subject: [PATCH 3/7] test: ArchitectureAdapterFactory registration, selection, and entry-point tests --- .../test_architecture_adapter_factory.py | 97 +++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100644 tests/unit/model_bridge/test_architecture_adapter_factory.py diff --git a/tests/unit/model_bridge/test_architecture_adapter_factory.py b/tests/unit/model_bridge/test_architecture_adapter_factory.py new file mode 100644 index 000000000..63eff18ab --- /dev/null +++ b/tests/unit/model_bridge/test_architecture_adapter_factory.py @@ -0,0 +1,97 @@ +"""Unit tests for ArchitectureAdapterFactory — external registration and entry-point discovery.""" + +import pytest + +from transformer_lens.config import TransformerBridgeConfig +from transformer_lens.factories.architecture_adapter_factory import ( + SUPPORTED_ARCHITECTURES, + ArchitectureAdapterFactory, +) +from tests.mocks.architecture_adapter import MockArchitectureAdapter + + +def _make_cfg(**overrides) -> TransformerBridgeConfig: + defaults = dict( + d_model=64, + d_head=16, + n_layers=2, + n_ctx=64, + n_heads=4, + d_vocab=100, + d_mlp=256, + default_prepend_bos=True, + ) + defaults.update(overrides) + return TransformerBridgeConfig(**defaults) + + +class TestSupportedArchitectures: + """Verify all existing hardcoded entries in SUPPORTED_ARCHITECTURES.""" + + def test_has_common_architectures(self): + common = [ + "GPT2LMHeadModel", + "LlamaForCausalLM", + "MistralForCausalLM", + "Gemma2ForCausalLM", + "Qwen2ForCausalLM", + "BloomForCausalLM", + "FalconForCausalLM", + ] + for arch in common: + assert arch in SUPPORTED_ARCHITECTURES, f"Missing: {arch}" + + +class TestRegisterAdapter: + """Verify runtime adapter registration.""" + + def test_register_adds_to_adapters(self): + key = "TestMockForCausalLM" + ArchitectureAdapterFactory.register_adapter(key, MockArchitectureAdapter) + assert key in ArchitectureAdapterFactory._adapters + + def test_register_overwrites_existing(self): + key = "TestOverwriteForCausalLM" + ArchitectureAdapterFactory.register_adapter(key, MockArchitectureAdapter) + first = ArchitectureAdapterFactory._adapters[key] + ArchitectureAdapterFactory.register_adapter(key, MockArchitectureAdapter) + assert ArchitectureAdapterFactory._adapters[key] is first + + def test_select_returns_registered_adapter(self): + key = "TestSelectForCausalLM" + ArchitectureAdapterFactory.register_adapter(key, MockArchitectureAdapter) + cfg = _make_cfg(architecture=key) + adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg) + assert isinstance(adapter, MockArchitectureAdapter) + + +class TestSelectErrors: + """Verify error handling in select_architecture_adapter.""" + + def test_unknown_architecture_raises(self): + cfg = _make_cfg(architecture="NonExistentForCausalLM") + with pytest.raises(ValueError, match="Unsupported architecture"): + ArchitectureAdapterFactory.select_architecture_adapter(cfg) + + def test_none_architecture_raises(self): + cfg = _make_cfg(architecture=None) + with pytest.raises(ValueError, match="must have architecture set"): + ArchitectureAdapterFactory.select_architecture_adapter(cfg) + + +class TestDiscoverEntryPoints: + """Verify entry-point discovery behavior.""" + + def test_discover_is_idempotent(self): + ArchitectureAdapterFactory._entry_points_discovered = False + ArchitectureAdapterFactory.discover_entry_points() + first_run = ArchitectureAdapterFactory._entry_points_discovered + ArchitectureAdapterFactory.discover_entry_points() + assert ArchitectureAdapterFactory._entry_points_discovered is first_run is True + + def test_discover_does_not_remove_existing(self): + key = "TestPreserveForCausalLM" + ArchitectureAdapterFactory.register_adapter(key, MockArchitectureAdapter) + ArchitectureAdapterFactory._entry_points_discovered = False + ArchitectureAdapterFactory.discover_entry_points() + assert key in ArchitectureAdapterFactory._adapters From 93c37f71c0bf00d40f5b2cf7486d1389bd0afe66 Mon Sep 17 00:00:00 2001 From: huseyincavusbi Date: Sat, 16 May 2026 13:08:44 +0200 Subject: [PATCH 4/7] fix: Clarify matching requirement and add warnings for failed discovery --- .../adapter_development/external-adapter-registration.md | 6 ++++-- .../unit/model_bridge/test_architecture_adapter_factory.py | 2 +- transformer_lens/factories/architecture_adapter_factory.py | 7 +++++-- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/docs/source/content/adapter_development/external-adapter-registration.md b/docs/source/content/adapter_development/external-adapter-registration.md index 2ce2aec57..9eca643c9 100644 --- a/docs/source/content/adapter_development/external-adapter-registration.md +++ b/docs/source/content/adapter_development/external-adapter-registration.md @@ -22,6 +22,8 @@ ArchitectureAdapterFactory.register_adapter( bridge = TransformerBridge.boot_transformers("my-org/my-model") ``` +> **Important:** The architecture name you register (e.g. `"MyModelForCausalLM"`) must match the `architectures` field in the model's HuggingFace `config.json`. TransformerLens reads this field to look up the adapter. + ### 2. Entry-point registration (recommended for packages) Declare your adapter in your package's `pyproject.toml`: @@ -37,8 +39,8 @@ TransformerLens discovers these automatically on first use. Users just install y When `boot_transformers()` is called: -1. It reads the model's HuggingFace config to determine the architecture class name (e.g. `"MyModelForCausalLM"`). -2. `select_architecture_adapter()` checks the registry. +1. It reads the model's HuggingFace `config.json` to extract the `architectures` field. This field lists the architecture class name (e.g. `"MyModelForCausalLM"`). **This is the name you must use in your registration.** +2. `select_architecture_adapter()` checks the registry for that architecture name. 3. On first call, `discover_entry_points()` scans all installed packages for adapters declared via the `transformer_lens.architectures` entry-point group. 4. The matching adapter class is instantiated and used to build the bridge. diff --git a/tests/unit/model_bridge/test_architecture_adapter_factory.py b/tests/unit/model_bridge/test_architecture_adapter_factory.py index 63eff18ab..431c87764 100644 --- a/tests/unit/model_bridge/test_architecture_adapter_factory.py +++ b/tests/unit/model_bridge/test_architecture_adapter_factory.py @@ -2,12 +2,12 @@ import pytest +from tests.mocks.architecture_adapter import MockArchitectureAdapter from transformer_lens.config import TransformerBridgeConfig from transformer_lens.factories.architecture_adapter_factory import ( SUPPORTED_ARCHITECTURES, ArchitectureAdapterFactory, ) -from tests.mocks.architecture_adapter import MockArchitectureAdapter def _make_cfg(**overrides) -> TransformerBridgeConfig: diff --git a/transformer_lens/factories/architecture_adapter_factory.py b/transformer_lens/factories/architecture_adapter_factory.py index f13d399af..2264b05f9 100644 --- a/transformer_lens/factories/architecture_adapter_factory.py +++ b/transformer_lens/factories/architecture_adapter_factory.py @@ -4,6 +4,7 @@ support for external registration and entry-point discovery. """ +import warnings from importlib.metadata import entry_points from transformer_lens.config import TransformerBridgeConfig @@ -177,8 +178,10 @@ def discover_entry_points(cls) -> None: eps = entry_points(group="transformer_lens.architectures") for ep in eps: cls._adapters[ep.name] = ep.load() - except Exception: - pass + except Exception as e: + warnings.warn( + f"Failed to discover entry points: {e}. " f"External adapters may not be available." + ) cls._entry_points_discovered = True @classmethod From 2ef316102988fba41b01b0033c17e7e77e0a3ef7 Mon Sep 17 00:00:00 2001 From: huseyincavusbi Date: Sat, 16 May 2026 13:14:55 +0200 Subject: [PATCH 5/7] fix: Make register_adapter doctest self-contained with inline adapter class --- .../factories/architecture_adapter_factory.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/transformer_lens/factories/architecture_adapter_factory.py b/transformer_lens/factories/architecture_adapter_factory.py index 2264b05f9..6f4c26aff 100644 --- a/transformer_lens/factories/architecture_adapter_factory.py +++ b/transformer_lens/factories/architecture_adapter_factory.py @@ -154,11 +154,14 @@ def register_adapter( adapter_class: The adapter class to register. Example: - >>> from transformer_lens.factories import ArchitectureAdapterFactory - >>> ArchitectureAdapterFactory.register_adapter( - ... "MyModelForCausalLM", - ... MyArchitectureAdapter, - ... ) + >>> from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter + >>> from transformer_lens.factories.architecture_adapter_factory import ArchitectureAdapterFactory + >>> class MyAdapter(ArchitectureAdapter): + ... def __init__(self, cfg): + ... super().__init__(cfg) + >>> ArchitectureAdapterFactory.register_adapter("MyModelForCausalLM", MyAdapter) + >>> "MyModelForCausalLM" in ArchitectureAdapterFactory._adapters + True """ cls._adapters[architecture_name] = adapter_class From 856c1f224e98d01951c86b4b6d7046486bb98df6 Mon Sep 17 00:00:00 2001 From: huseyincavusbi Date: Sat, 16 May 2026 13:20:02 +0200 Subject: [PATCH 6/7] fix: Address Copilot review - alias, loop isolation, test state leak --- .../test_architecture_adapter_factory.py | 35 +++++++++++++++++-- .../factories/architecture_adapter_factory.py | 16 ++++++--- 2 files changed, 44 insertions(+), 7 deletions(-) diff --git a/tests/unit/model_bridge/test_architecture_adapter_factory.py b/tests/unit/model_bridge/test_architecture_adapter_factory.py index 431c87764..c83504637 100644 --- a/tests/unit/model_bridge/test_architecture_adapter_factory.py +++ b/tests/unit/model_bridge/test_architecture_adapter_factory.py @@ -8,6 +8,35 @@ SUPPORTED_ARCHITECTURES, ArchitectureAdapterFactory, ) +from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter + + +class OtherMockArchitectureAdapter(ArchitectureAdapter): + """A second mock adapter class used to verify overwrite behaviour.""" + + def __init__(self, cfg=None): + if cfg is None: + cfg = TransformerBridgeConfig( + d_model=512, + d_head=64, + n_layers=2, + n_ctx=1024, + d_vocab=1000, + d_mlp=2048, + default_prepend_bos=True, + architecture="GPT2LMHeadModel", + ) + super().__init__(cfg) + + +@pytest.fixture(autouse=True) +def _isolate_factory_state(): + """Save and restore factory state so tests don't leak into each other.""" + saved_adapters = dict(ArchitectureAdapterFactory._adapters) + saved_discovered = ArchitectureAdapterFactory._entry_points_discovered + yield + ArchitectureAdapterFactory._adapters = saved_adapters + ArchitectureAdapterFactory._entry_points_discovered = saved_discovered def _make_cfg(**overrides) -> TransformerBridgeConfig: @@ -53,9 +82,9 @@ def test_register_adds_to_adapters(self): def test_register_overwrites_existing(self): key = "TestOverwriteForCausalLM" ArchitectureAdapterFactory.register_adapter(key, MockArchitectureAdapter) - first = ArchitectureAdapterFactory._adapters[key] - ArchitectureAdapterFactory.register_adapter(key, MockArchitectureAdapter) - assert ArchitectureAdapterFactory._adapters[key] is first + assert ArchitectureAdapterFactory._adapters[key] is MockArchitectureAdapter + ArchitectureAdapterFactory.register_adapter(key, OtherMockArchitectureAdapter) + assert ArchitectureAdapterFactory._adapters[key] is OtherMockArchitectureAdapter def test_select_returns_registered_adapter(self): key = "TestSelectForCausalLM" diff --git a/transformer_lens/factories/architecture_adapter_factory.py b/transformer_lens/factories/architecture_adapter_factory.py index 6f4c26aff..ef387b2c6 100644 --- a/transformer_lens/factories/architecture_adapter_factory.py +++ b/transformer_lens/factories/architecture_adapter_factory.py @@ -136,7 +136,7 @@ class ArchitectureAdapterFactory: discovery of adapters from installed packages via entry points. """ - _adapters = SUPPORTED_ARCHITECTURES + _adapters = dict(SUPPORTED_ARCHITECTURES) _entry_points_discovered = False @classmethod @@ -179,12 +179,20 @@ def discover_entry_points(cls) -> None: return try: eps = entry_points(group="transformer_lens.architectures") - for ep in eps: - cls._adapters[ep.name] = ep.load() except Exception as e: warnings.warn( - f"Failed to discover entry points: {e}. " f"External adapters may not be available." + f"Failed to discover entry points: {e}. " + f"External adapters may not be available." ) + else: + for ep in eps: + try: + cls._adapters[ep.name] = ep.load() + except Exception as e: + warnings.warn( + f"Failed to load entry point '{ep.name}': {e}. " + f"Skipping this adapter." + ) cls._entry_points_discovered = True @classmethod From 12fd1fab4bd5f2fee83ed4f4eef4b876dd6359af Mon Sep 17 00:00:00 2001 From: huseyincavusbi Date: Sat, 16 May 2026 13:29:01 +0200 Subject: [PATCH 7/7] fix: Apply black formatting to factory file --- transformer_lens/factories/architecture_adapter_factory.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/transformer_lens/factories/architecture_adapter_factory.py b/transformer_lens/factories/architecture_adapter_factory.py index ef387b2c6..0aafb43ac 100644 --- a/transformer_lens/factories/architecture_adapter_factory.py +++ b/transformer_lens/factories/architecture_adapter_factory.py @@ -181,8 +181,7 @@ def discover_entry_points(cls) -> None: eps = entry_points(group="transformer_lens.architectures") except Exception as e: warnings.warn( - f"Failed to discover entry points: {e}. " - f"External adapters may not be available." + f"Failed to discover entry points: {e}. " f"External adapters may not be available." ) else: for ep in eps: @@ -190,8 +189,7 @@ def discover_entry_points(cls) -> None: cls._adapters[ep.name] = ep.load() except Exception as e: warnings.warn( - f"Failed to load entry point '{ep.name}': {e}. " - f"Skipping this adapter." + f"Failed to load entry point '{ep.name}': {e}. " f"Skipping this adapter." ) cls._entry_points_discovered = True