diff --git a/docs/source/content/adapter_development/external-adapter-registration.md b/docs/source/content/adapter_development/external-adapter-registration.md new file mode 100644 index 000000000..9eca643c9 --- /dev/null +++ b/docs/source/content/adapter_development/external-adapter-registration.md @@ -0,0 +1,89 @@ +# External Architecture Adapter Registration + +TransformerLens supports loading custom architecture adapters from **external packages** — no fork required. You can write your adapter, register it, and use it with `boot_transformers()` without modifying TransformerLens source code. + +## Two ways to register + +### 1. Runtime registration + +Call `register_adapter()` in your startup code: + +```python +from transformer_lens.factories.architecture_adapter_factory import ( + ArchitectureAdapterFactory, +) + +ArchitectureAdapterFactory.register_adapter( + "MyModelForCausalLM", + MyArchitectureAdapter, +) + +# Now this works: +bridge = TransformerBridge.boot_transformers("my-org/my-model") +``` + +> **Important:** The architecture name you register (e.g. `"MyModelForCausalLM"`) must match the `architectures` field in the model's HuggingFace `config.json`. TransformerLens reads this field to look up the adapter. + +### 2. Entry-point registration (recommended for packages) + +Declare your adapter in your package's `pyproject.toml`: + +```toml +[project.entry-points."transformer_lens.architectures"] +"MyModelForCausalLM" = "my_package.adapters:MyArchitectureAdapter" +``` + +TransformerLens discovers these automatically on first use. Users just install your package alongside TransformerLens and `boot_transformers()` finds it. + +## How it works + +When `boot_transformers()` is called: + +1. It reads the model's HuggingFace `config.json` to extract the `architectures` field. This field lists the architecture class name (e.g. `"MyModelForCausalLM"`). **This is the name you must use in your registration.** +2. 
`select_architecture_adapter()` first calls `discover_entry_points()`. On the first call only, this scans all installed packages for adapters declared via the `transformer_lens.architectures` entry-point group and adds them to the registry. +3. The registry is then checked for that architecture name; if nothing is registered under it, a `ValueError` (`Unsupported architecture: ...`) is raised. +4. The matching adapter class is instantiated and used to build the bridge. + +## Writing an adapter + +Follow the [Architecture Adapter Creation Guide](adapter-creation-guide.md) to build your adapter class. Once written, use either registration method above to plug it into TransformerLens. + +## Example package layout + +``` +my_transformer_plugin/ +├── pyproject.toml # declares the entry point +└── my_transformer_plugin/ + ├── __init__.py + └── adapters.py # contains MyArchitectureAdapter +``` + +**pyproject.toml:** + +```toml +[project] +name = "my-transformer-plugin" +version = "0.1.0" +requires-python = ">=3.10" +dependencies = ["transformer-lens>=3.0"] + +[project.entry-points."transformer_lens.architectures"] +"MyModelForCausalLM" = "my_transformer_plugin.adapters:MyArchitectureAdapter" +``` + +**adapters.py:** + +```python +from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter +from transformer_lens.model_bridge.generalized_components import ( + BlockBridge, + EmbeddingBridge, + # ... 
import the bridge components you need +) + +class MyArchitectureAdapter(ArchitectureAdapter): + def __init__(self, cfg): + super().__init__(cfg) + # Set config, weight processing, component mapping + # See the Adapter Creation Guide for details +``` diff --git a/docs/source/content/contributing.md b/docs/source/content/contributing.md index 370a87361..87a28f6b0 100644 --- a/docs/source/content/contributing.md +++ b/docs/source/content/contributing.md @@ -178,4 +178,5 @@ Adapters live in `transformer_lens/model_bridge/supported_architectures/ TransformerBridgeConfig: + defaults = dict( + d_model=64, + d_head=16, + n_layers=2, + n_ctx=64, + n_heads=4, + d_vocab=100, + d_mlp=256, + default_prepend_bos=True, + ) + defaults.update(overrides) + return TransformerBridgeConfig(**defaults) + + +class TestSupportedArchitectures: + """Verify all existing hardcoded entries in SUPPORTED_ARCHITECTURES.""" + + def test_has_common_architectures(self): + common = [ + "GPT2LMHeadModel", + "LlamaForCausalLM", + "MistralForCausalLM", + "Gemma2ForCausalLM", + "Qwen2ForCausalLM", + "BloomForCausalLM", + "FalconForCausalLM", + ] + for arch in common: + assert arch in SUPPORTED_ARCHITECTURES, f"Missing: {arch}" + + +class TestRegisterAdapter: + """Verify runtime adapter registration.""" + + def test_register_adds_to_adapters(self): + key = "TestMockForCausalLM" + ArchitectureAdapterFactory.register_adapter(key, MockArchitectureAdapter) + assert key in ArchitectureAdapterFactory._adapters + + def test_register_overwrites_existing(self): + key = "TestOverwriteForCausalLM" + ArchitectureAdapterFactory.register_adapter(key, MockArchitectureAdapter) + assert ArchitectureAdapterFactory._adapters[key] is MockArchitectureAdapter + ArchitectureAdapterFactory.register_adapter(key, OtherMockArchitectureAdapter) + assert ArchitectureAdapterFactory._adapters[key] is OtherMockArchitectureAdapter + + def test_select_returns_registered_adapter(self): + key = "TestSelectForCausalLM" + 
ArchitectureAdapterFactory.register_adapter(key, MockArchitectureAdapter)
+        cfg = _make_cfg(architecture=key)
+        adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg)
+        assert isinstance(adapter, MockArchitectureAdapter)
+
+
+class TestSelectErrors:
+    """Verify error handling in select_architecture_adapter."""
+
+    def test_unknown_architecture_raises(self):
+        cfg = _make_cfg(architecture="NonExistentForCausalLM")
+        with pytest.raises(ValueError, match="Unsupported architecture"):
+            ArchitectureAdapterFactory.select_architecture_adapter(cfg)
+
+    def test_none_architecture_raises(self):
+        cfg = _make_cfg(architecture=None)
+        with pytest.raises(ValueError, match="must have architecture set"):
+            ArchitectureAdapterFactory.select_architecture_adapter(cfg)
+
+
+class TestDiscoverEntryPoints:
+    """Verify entry-point discovery behavior.
+
+    Every test patches the factory's class-level state through ``monkeypatch``
+    so the discovery flag and the registry are restored afterwards and cannot
+    leak into other tests.
+    """
+
+    def test_discover_is_idempotent(self, monkeypatch):
+        import transformer_lens.factories.architecture_adapter_factory as factory_module
+
+        scans = []
+
+        def fake_entry_points(**kwargs):
+            # Record each scan so the test can prove the second call is a no-op.
+            scans.append(kwargs)
+            return []
+
+        monkeypatch.setattr(factory_module, "entry_points", fake_entry_points)
+        monkeypatch.setattr(ArchitectureAdapterFactory, "_entry_points_discovered", False)
+
+        ArchitectureAdapterFactory.discover_entry_points()
+        assert ArchitectureAdapterFactory._entry_points_discovered is True
+        ArchitectureAdapterFactory.discover_entry_points()
+        assert ArchitectureAdapterFactory._entry_points_discovered is True
+        # Installed packages must be scanned exactly once across both calls.
+        assert scans == [{"group": "transformer_lens.architectures"}]
+
+    def test_discover_does_not_remove_existing(self, monkeypatch):
+        key = "TestPreserveForCausalLM"
+        # Work on a copy of the registry so the test key never reaches other tests.
+        monkeypatch.setattr(
+            ArchitectureAdapterFactory, "_adapters", dict(ArchitectureAdapterFactory._adapters)
+        )
+        monkeypatch.setattr(ArchitectureAdapterFactory, "_entry_points_discovered", False)
+        ArchitectureAdapterFactory.register_adapter(key, MockArchitectureAdapter)
+        ArchitectureAdapterFactory.discover_entry_points()
+        assert key in ArchitectureAdapterFactory._adapters
diff --git a/transformer_lens/factories/architecture_adapter_factory.py b/transformer_lens/factories/architecture_adapter_factory.py
index 02129b856..0aafb43ac 100644
--- a/transformer_lens/factories/architecture_adapter_factory.py
+++ b/transformer_lens/factories/architecture_adapter_factory.py
@@ -1,8 +1,12 @@
 """Architecture adapter factory.
 
-This module provides a factory for creating architecture adapters.
+This module provides a factory for creating architecture adapters, including
+support for external registration and entry-point discovery.
 """
 
+import warnings
+from importlib.metadata import entry_points
+
 from transformer_lens.config import TransformerBridgeConfig
 from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
 from transformer_lens.model_bridge.supported_architectures import (
@@ -126,9 +130,83 @@
 
 
 class ArchitectureAdapterFactory:
-    """Factory for creating architecture adapters."""
+    """Factory for creating architecture adapters.
+
+    Supports external registration via `register_adapter()` and automatic
+    discovery of adapters from installed packages via entry points.
+    """
 
-    _adapters = SUPPORTED_ARCHITECTURES
+    _adapters = dict(SUPPORTED_ARCHITECTURES)
+    _entry_points_discovered = False
+    # Names registered explicitly through register_adapter(). Entry-point
+    # discovery is lazy, so without this record it could silently replace an
+    # adapter the caller registered before the first lookup.
+    _runtime_registered: set[str] = set()
+
+    @classmethod
+    def register_adapter(
+        cls, architecture_name: str, adapter_class: type["ArchitectureAdapter"]
+    ) -> None:
+        """Register a custom architecture adapter at runtime.
+
+        This allows users to add their own architecture adapters without
+        modifying TransformerLens source code. Registering a name that already
+        exists replaces the previous adapter, and a runtime registration is
+        never replaced by entry-point discovery.
+
+        Args:
+            architecture_name: The HuggingFace architecture class name
+                (e.g. ``"Qwen3ForCausalLM"``).
+            adapter_class: The adapter class to register.
+
+        Example:
+            >>> from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
+            >>> from transformer_lens.factories.architecture_adapter_factory import ArchitectureAdapterFactory
+            >>> class MyAdapter(ArchitectureAdapter):
+            ...     def __init__(self, cfg):
+            ...         super().__init__(cfg)
+            >>> ArchitectureAdapterFactory.register_adapter("MyModelForCausalLM", MyAdapter)
+            >>> "MyModelForCausalLM" in ArchitectureAdapterFactory._adapters
+            True
+        """
+        cls._adapters[architecture_name] = adapter_class
+        cls._runtime_registered.add(architecture_name)
+
+    @classmethod
+    def discover_entry_points(cls) -> None:
+        """Discover and register architecture adapters from installed packages.
+
+        Packages can declare adapters in their ``pyproject.toml``:
+        ```toml
+        [project.entry-points."transformer_lens.architectures"]
+        "MyModelForCausalLM" = "my_package.adapters:MyArchitectureAdapter"
+        ```
+
+        Discovery runs only on the first call; later calls return immediately.
+        It is best-effort: a failure to enumerate or load an entry point emits a
+        warning instead of raising. Names registered through
+        `register_adapter()` are left untouched.
+        """
+        if cls._entry_points_discovered:
+            return
+        try:
+            eps = entry_points(group="transformer_lens.architectures")
+        except Exception as e:
+            warnings.warn(
+                f"Failed to discover entry points: {e}. " f"External adapters may not be available."
+            )
+        else:
+            for ep in eps:
+                if ep.name in cls._runtime_registered:
+                    # Explicit registration wins over implicit discovery.
+                    continue
+                try:
+                    cls._adapters[ep.name] = ep.load()
+                except Exception as e:
+                    warnings.warn(
+                        f"Failed to load entry point '{ep.name}': {e}. " f"Skipping this adapter."
+                    )
+        cls._entry_points_discovered = True
 
     @classmethod
     def select_architecture_adapter(cls, cfg: TransformerBridgeConfig) -> ArchitectureAdapter:
@@ -143,11 +221,11 @@ def select_architecture_adapter(cls, cfg: TransformerBridgeConfig) -> Architectu
         Raises:
             ValueError: If no adapter is found for the given config.
         """
+        cls.discover_entry_points()
         if cfg.architecture is not None:
             if cfg.architecture in cls._adapters:
                 return cls._adapters[cfg.architecture](cfg)
             else:
                 raise ValueError(f"Unsupported architecture: {cfg.architecture}")
 
-        # If architecture is None, this is an error since TransformerBridgeConfig should always have it set
         raise ValueError(f"TransformerBridgeConfig must have architecture set, got: {cfg}")