Skip to content

Commit df3df1c

Browse files
committed
refactor(plugins): make PluginRegistry private with init_plugin invocation
- Rename PluginRegistry to _PluginRegistry (private implementation detail) - Store agent reference in __init__ instead of passing to add_plugin - add_plugin now calls init_plugin synchronously - Add add_plugin_async for async init_plugin implementations - Remove get_plugin, has_plugin, list_plugins (not needed for now) - Update tests for new API and sync/async behavior
1 parent 0a256b8 commit df3df1c

3 files changed

Lines changed: 123 additions & 89 deletions

File tree

src/strands/plugins/__init__.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
66
Example Usage:
77
```python
8-
from strands.plugins import Plugin, PluginRegistry
8+
from strands.plugins import Plugin
99
1010
class LoggingPlugin:
1111
name = "logging"
@@ -15,18 +15,11 @@ def init_plugin(self, agent: Agent) -> None:
1515
1616
def on_model_call(self, event: BeforeModelCallEvent) -> None:
1717
print(f"Model called for {event.agent.name}")
18-
19-
# Use with registry
20-
registry = PluginRegistry()
21-
plugin = LoggingPlugin()
22-
registry.add_plugin(plugin, agent)
2318
```
2419
"""
2520

2621
from .plugin import Plugin
27-
from .registry import PluginRegistry
2822

2923
__all__ = [
3024
"Plugin",
31-
"PluginRegistry",
3225
]

src/strands/plugins/registry.py

Lines changed: 43 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
"""Plugin registry for managing plugins attached to an agent.
22
3-
This module provides the PluginRegistry class for tracking and managing
3+
This module provides the _PluginRegistry class for tracking and managing
44
plugins that have been initialized with an agent instance.
55
"""
66

7+
import inspect
78
import logging
89
from typing import TYPE_CHECKING
910

@@ -15,15 +16,15 @@
1516
logger = logging.getLogger(__name__)
1617

1718

18-
class PluginRegistry:
19+
class _PluginRegistry:
1920
"""Registry for managing plugins attached to an agent.
2021
21-
The PluginRegistry tracks plugins that have been initialized with an agent,
22-
providing methods to add, retrieve, and check for plugins by name.
22+
The _PluginRegistry tracks plugins that have been initialized with an agent,
23+
providing methods to add plugins and invoke their initialization.
2324
2425
Example:
2526
```python
26-
registry = PluginRegistry()
27+
registry = _PluginRegistry(agent)
2728
2829
class MyPlugin:
2930
name = "my-plugin"
@@ -32,60 +33,64 @@ def init_plugin(self, agent: Agent) -> None:
3233
pass
3334
3435
plugin = MyPlugin()
35-
registry.add_plugin(plugin, agent)
36-
37-
# Check if plugin is registered
38-
if registry.has_plugin("my-plugin"):
39-
retrieved = registry.get_plugin("my-plugin")
36+
registry.add_plugin(plugin)
4037
```
4138
"""
4239

43-
def __init__(self) -> None:
44-
"""Initialize an empty plugin registry."""
40+
def __init__(self, agent: "Agent") -> None:
41+
"""Initialize a plugin registry with an agent reference.
42+
43+
Args:
44+
agent: The agent instance that plugins will be initialized with.
45+
"""
46+
self._agent = agent
4547
self._plugins: dict[str, Plugin] = {}
4648

47-
def add_plugin(self, plugin: Plugin, agent: "Agent") -> None:
48-
"""Add and initialize a plugin with the given agent.
49+
def add_plugin(self, plugin: Plugin) -> None:
50+
"""Add and initialize a plugin with the agent.
51+
52+
This method registers the plugin and calls its init_plugin method
53+
synchronously. For async init_plugin implementations, use add_plugin_async.
4954
5055
Args:
5156
plugin: The plugin to add and initialize.
52-
agent: The agent instance to initialize the plugin with.
5357
5458
Raises:
5559
ValueError: If a plugin with the same name is already registered.
60+
RuntimeError: If the plugin's init_plugin is async (use add_plugin_async instead).
5661
"""
5762
if plugin.name in self._plugins:
5863
raise ValueError(f"plugin_name=<{plugin.name}> | plugin already registered")
5964

60-
logger.debug("plugin_name=<%s> | registering plugin", plugin.name)
61-
self._plugins[plugin.name] = plugin
62-
63-
def get_plugin(self, name: str) -> Plugin | None:
64-
"""Get a plugin by name.
65+
if inspect.iscoroutinefunction(plugin.init_plugin):
66+
raise RuntimeError(
67+
f"plugin_name=<{plugin.name}> | plugin has async init_plugin, use add_plugin_async instead"
68+
)
6569

66-
Args:
67-
name: The name of the plugin to retrieve.
70+
logger.debug("plugin_name=<%s> | registering and initializing plugin", plugin.name)
71+
self._plugins[plugin.name] = plugin
72+
plugin.init_plugin(self._agent)
6873

69-
Returns:
70-
The plugin if found, None otherwise.
71-
"""
72-
return self._plugins.get(name)
74+
async def add_plugin_async(self, plugin: Plugin) -> None:
75+
"""Add and initialize a plugin with the agent asynchronously.
7376
74-
def has_plugin(self, name: str) -> bool:
75-
"""Check if a plugin with the given name is registered.
77+
This method registers the plugin and calls its init_plugin method,
78+
supporting both sync and async implementations.
7679
7780
Args:
78-
name: The name of the plugin to check.
81+
plugin: The plugin to add and initialize.
7982
80-
Returns:
81-
True if the plugin is registered, False otherwise.
83+
Raises:
84+
ValueError: If a plugin with the same name is already registered.
8285
"""
83-
return name in self._plugins
86+
if plugin.name in self._plugins:
87+
raise ValueError(f"plugin_name=<{plugin.name}> | plugin already registered")
88+
89+
logger.debug("plugin_name=<%s> | registering and initializing plugin", plugin.name)
90+
self._plugins[plugin.name] = plugin
8491

85-
def list_plugins(self) -> list[str]:
86-
"""Get a list of all registered plugin names.
92+
if inspect.iscoroutinefunction(plugin.init_plugin):
93+
await plugin.init_plugin(self._agent)
94+
else:
95+
plugin.init_plugin(self._agent)
8796

88-
Returns:
89-
A list of plugin names in registration order.
90-
"""
91-
return list(self._plugins.keys())

tests/strands/plugins/test_plugins.py

Lines changed: 79 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
import pytest
66

7-
from strands.plugins import Plugin, PluginRegistry
7+
from strands.plugins import Plugin
8+
from strands.plugins.registry import _PluginRegistry
89

910
# Plugin Protocol Tests
1011

@@ -118,13 +119,7 @@ def init_plugin(self, agent):
118119
assert plugin.name == "property-plugin"
119120

120121

121-
# PluginRegistry Tests
122-
123-
124-
@pytest.fixture
125-
def registry():
126-
"""Create a fresh PluginRegistry for each test."""
127-
return PluginRegistry()
122+
# _PluginRegistry Tests
128123

129124

130125
@pytest.fixture
@@ -133,20 +128,30 @@ def mock_agent():
133128
return unittest.mock.Mock()
134129

135130

136-
def test_plugin_registry_add_plugin(registry, mock_agent):
137-
"""Test adding a plugin to the registry."""
131+
@pytest.fixture
132+
def registry(mock_agent):
133+
"""Create a fresh _PluginRegistry for each test."""
134+
return _PluginRegistry(mock_agent)
135+
136+
137+
def test_plugin_registry_add_plugin_calls_init_plugin(registry, mock_agent):
138+
"""Test adding a plugin calls its init_plugin method."""
138139

139140
class TestPlugin:
140141
name = "test-plugin"
141142

143+
def __init__(self):
144+
self.initialized = False
145+
142146
def init_plugin(self, agent):
143-
pass
147+
self.initialized = True
148+
agent.plugin_initialized = True
144149

145150
plugin = TestPlugin()
146-
registry.add_plugin(plugin, mock_agent)
151+
registry.add_plugin(plugin)
147152

148-
assert registry.has_plugin("test-plugin")
149-
assert registry.get_plugin("test-plugin") is plugin
153+
assert plugin.initialized
154+
assert mock_agent.plugin_initialized
150155

151156

152157
def test_plugin_registry_add_duplicate_raises_error(registry, mock_agent):
@@ -161,51 +166,82 @@ def init_plugin(self, agent):
161166
plugin1 = TestPlugin()
162167
plugin2 = TestPlugin()
163168

164-
registry.add_plugin(plugin1, mock_agent)
169+
registry.add_plugin(plugin1)
165170

166171
with pytest.raises(ValueError, match="plugin_name=<test-plugin> | plugin already registered"):
167-
registry.add_plugin(plugin2, mock_agent)
172+
registry.add_plugin(plugin2)
168173

169174

170-
def test_plugin_registry_get_plugin_not_found(registry):
171-
"""Test getting a plugin that doesn't exist returns None."""
172-
assert registry.get_plugin("nonexistent") is None
175+
def test_plugin_registry_add_plugin_async_raises_runtime_error(registry):
176+
"""Test that add_plugin raises RuntimeError for async plugins."""
173177

178+
class AsyncPlugin:
179+
name = "async-plugin"
174180

175-
def test_plugin_registry_has_plugin_false(registry):
176-
"""Test has_plugin returns False for unregistered plugins."""
177-
assert not registry.has_plugin("nonexistent")
181+
async def init_plugin(self, agent):
182+
pass
178183

184+
plugin = AsyncPlugin()
179185

180-
def test_plugin_registry_list_plugins(registry, mock_agent):
181-
"""Test listing all registered plugins."""
186+
with pytest.raises(RuntimeError, match="use add_plugin_async instead"):
187+
registry.add_plugin(plugin)
182188

183-
class Plugin1:
184-
name = "plugin-1"
185189

186-
def init_plugin(self, agent):
187-
pass
190+
@pytest.mark.asyncio
191+
async def test_plugin_registry_add_plugin_async_with_sync_plugin(mock_agent):
192+
"""Test add_plugin_async works with sync plugins."""
193+
registry = _PluginRegistry(mock_agent)
188194

189-
class Plugin2:
190-
name = "plugin-2"
195+
class SyncPlugin:
196+
name = "sync-plugin"
197+
198+
def __init__(self):
199+
self.initialized = False
191200

192201
def init_plugin(self, agent):
193-
pass
202+
self.initialized = True
194203

195-
class Plugin3:
196-
name = "plugin-3"
204+
plugin = SyncPlugin()
205+
await registry.add_plugin_async(plugin)
197206

198-
def init_plugin(self, agent):
199-
pass
207+
assert plugin.initialized
200208

201-
registry.add_plugin(Plugin1(), mock_agent)
202-
registry.add_plugin(Plugin2(), mock_agent)
203-
registry.add_plugin(Plugin3(), mock_agent)
204209

205-
plugin_names = registry.list_plugins()
206-
assert plugin_names == ["plugin-1", "plugin-2", "plugin-3"]
210+
@pytest.mark.asyncio
211+
async def test_plugin_registry_add_plugin_async_with_async_plugin(mock_agent):
212+
"""Test add_plugin_async works with async plugins."""
213+
registry = _PluginRegistry(mock_agent)
214+
215+
class AsyncPlugin:
216+
name = "async-plugin"
217+
218+
def __init__(self):
219+
self.initialized = False
220+
221+
async def init_plugin(self, agent):
222+
self.initialized = True
223+
224+
plugin = AsyncPlugin()
225+
await registry.add_plugin_async(plugin)
207226

227+
assert plugin.initialized
228+
229+
230+
@pytest.mark.asyncio
231+
async def test_plugin_registry_add_plugin_async_duplicate_raises_error(mock_agent):
232+
"""Test that add_plugin_async raises error for duplicate plugins."""
233+
registry = _PluginRegistry(mock_agent)
208234

209-
def test_plugin_registry_list_plugins_empty(registry):
210-
"""Test listing plugins when registry is empty."""
211-
assert registry.list_plugins() == []
235+
class TestPlugin:
236+
name = "test-plugin"
237+
238+
async def init_plugin(self, agent):
239+
pass
240+
241+
plugin1 = TestPlugin()
242+
plugin2 = TestPlugin()
243+
244+
await registry.add_plugin_async(plugin1)
245+
246+
with pytest.raises(ValueError, match="plugin_name=<test-plugin> | plugin already registered"):
247+
await registry.add_plugin_async(plugin2)

0 commit comments

Comments
 (0)