From 2cbd61771aca20bbd3bca807fd89bef5e668827f Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 2 Jul 2026 14:42:21 -0700 Subject: [PATCH 01/11] Migrate InitializerRegistry onto unified Registry base Move InitializerRegistry from the transitional BaseClassRegistry to the unified Registry base, mirroring the ScenarioRegistry migration (PR #2115). The registry now lives in pyrit/registry/components/ and uses the Registry class catalog (register_class / create_instance / get_class_names / get_all_registered_class_metadata) instead of the old ClassEntry storage. - Add create_and_configure lifecycle method (parallel to scenario's create_and_initialize_async); route configuration_loader through it. - Override _discover for the filesystem scan; _identifier_type returns None. - Inline unregister_and_cleanup since Registry has no unregister. - Update backend initializer_service and configuration_loader consumers. - Migrate tests and the class-registry doc (.py + .ipynb) to the new API. Co-authored-by: Copilot App <223556219+Copilot@users.noreply.github.com> --- doc/code/registry/1_class_registry.ipynb | 10 +- doc/code/registry/1_class_registry.py | 10 +- pyrit/backend/services/initializer_service.py | 4 +- pyrit/registry/__init__.py | 4 +- pyrit/registry/class_registries/__init__.py | 11 +- pyrit/registry/components/__init__.py | 6 + .../initializer_registry.py | 124 +++++++++++++----- pyrit/setup/configuration_loader.py | 16 +-- .../unit/backend/test_initializer_service.py | 18 +-- .../registry/test_initializer_registry.py | 12 +- tests/unit/setup/test_configuration_loader.py | 15 +-- .../test_scenario_techniques_initializer.py | 2 +- 12 files changed, 139 insertions(+), 93 deletions(-) rename pyrit/registry/{class_registries => components}/initializer_registry.py (72%) diff --git a/doc/code/registry/1_class_registry.ipynb b/doc/code/registry/1_class_registry.ipynb index 0eb278b032..d96e84c3b0 100644 --- a/doc/code/registry/1_class_registry.ipynb +++ b/doc/code/registry/1_class_registry.ipynb @@ -7,7 +7,7 @@ "source": [ "# Listing Available Classes\n", "\n", - "Use `get_names()` to see what's available, or `list_metadata()` for detailed information." + "Use `get_class_names()` to see what's available, or `get_all_registered_class_metadata()` for detailed information." ] }, { @@ -36,11 +36,11 @@ "registry = ScenarioRegistry.get_registry_singleton()\n", "\n", "# Get all registered names\n", - "names = registry.get_names()\n", + "names = registry.get_class_names()\n", "print(f\"Available scenarios: {names[:5]}...\") # Show first 5\n", "\n", "# Get detailed metadata\n", - "metadata = registry.list_metadata()\n", + "metadata = registry.get_all_registered_class_metadata()\n", "for item in metadata[:2]: # Show first 2\n", " print(f\"\\n{item.class_name}:\")\n", " print(f\" Description: {item.class_description[:80]}...\")" @@ -226,11 +226,11 @@ "initializer_registry = InitializerRegistry.get_registry_singleton()\n", "\n", "# Get all registered names\n", - "initializer_names = initializer_registry.get_names()\n", + "initializer_names = initializer_registry.get_class_names()\n", "print(f\"Available initializers: {initializer_names[:5]}...\") # Show first 5\n", "\n", "# Get detailed metadata\n", - "for init_item in initializer_registry.list_metadata()[:2]: # Show first 2\n", + "for init_item in initializer_registry.get_all_registered_class_metadata()[:2]: # Show first 2\n", " print(f\"\\n{init_item.registry_name}:\")\n", " print(f\" Class: {init_item.class_name}\")\n", " print(f\" Description: {init_item.class_description[:80]}...\")" diff --git a/doc/code/registry/1_class_registry.py b/doc/code/registry/1_class_registry.py index ecfd82a89d..0c6d1e6d3d 100644 --- a/doc/code/registry/1_class_registry.py +++ b/doc/code/registry/1_class_registry.py @@ -11,7 +11,7 @@ # %% [markdown] # # Listing Available Classes # -# Use `get_names()` to see what's available, or `list_metadata()` for detailed information. +# Use `get_class_names()` to see what's available, or `get_all_registered_class_metadata()` for detailed information. # %% from pyrit.registry import ScenarioRegistry @@ -19,11 +19,11 @@ registry = ScenarioRegistry.get_registry_singleton() # Get all registered names -names = registry.get_names() +names = registry.get_class_names() print(f"Available scenarios: {names[:5]}...") # Show first 5 # Get detailed metadata -metadata = registry.list_metadata() +metadata = registry.get_all_registered_class_metadata() for item in metadata[:2]: # Show first 2 print(f"\n{item.class_name}:") print(f" Description: {item.class_description[:80]}...") @@ -92,11 +92,11 @@ initializer_registry = InitializerRegistry.get_registry_singleton() # Get all registered names -initializer_names = initializer_registry.get_names() +initializer_names = initializer_registry.get_class_names() print(f"Available initializers: {initializer_names[:5]}...") # Show first 5 # Get detailed metadata -for init_item in initializer_registry.list_metadata()[:2]: # Show first 2 +for init_item in initializer_registry.get_all_registered_class_metadata()[:2]: # Show first 2 print(f"\n{init_item.registry_name}:") print(f" Class: {init_item.class_name}") print(f" Description: {init_item.class_description[:80]}...") diff --git a/pyrit/backend/services/initializer_service.py b/pyrit/backend/services/initializer_service.py index 8e5f86dd54..e8476a076e 100644 --- a/pyrit/backend/services/initializer_service.py +++ b/pyrit/backend/services/initializer_service.py @@ -69,7 +69,7 @@ async def list_initializers_async( Returns: ListRegisteredInitializersResponse with paginated initializer summaries. """ - all_metadata = self._registry.list_metadata() + all_metadata = self._registry.get_all_registered_class_metadata() all_summaries = [_metadata_to_registered_initializer(m) for m in all_metadata] page, has_more = self._paginate(items=all_summaries, cursor=cursor, limit=limit) @@ -90,7 +90,7 @@ async def get_initializer_async(self, *, initializer_name: str) -> RegisteredIni Returns: RegisteredInitializer if found, None otherwise. """ - all_metadata = self._registry.list_metadata() + all_metadata = self._registry.get_all_registered_class_metadata() for metadata in all_metadata: if metadata.registry_name == initializer_name: return _metadata_to_registered_initializer(metadata) diff --git a/pyrit/registry/__init__.py b/pyrit/registry/__init__.py index 906b23ba7b..1696936afd 100644 --- a/pyrit/registry/__init__.py +++ b/pyrit/registry/__init__.py @@ -7,14 +7,14 @@ from pyrit.registry.class_registries import ( BaseClassRegistry, ClassEntry, - InitializerMetadata, - InitializerRegistry, ) from pyrit.registry.components import ( AttackTechniqueMetadata, AttackTechniqueRegistry, ConverterMetadata, ConverterRegistry, + InitializerMetadata, + InitializerRegistry, ScenarioMetadata, ScenarioRegistry, ScorerMetadata, diff --git a/pyrit/registry/class_registries/__init__.py b/pyrit/registry/class_registries/__init__.py index 1314305b96..1156520842 100644 --- a/pyrit/registry/class_registries/__init__.py +++ b/pyrit/registry/class_registries/__init__.py @@ -4,8 +4,9 @@ """ Class registries package. -This package contains registries that store classes (type[T]) which can be -instantiated on demand. Examples include ScenarioRegistry and InitializerRegistry. +This package contains the transitional ``BaseClassRegistry`` base that predates +the unified ``Registry``. It survives only until the remaining domains migrate; +new registries should extend ``pyrit.registry.registry.Registry`` instead. For registries that store pre-configured instances, see object_registries/. """ @@ -14,14 +15,8 @@ BaseClassRegistry, ClassEntry, ) -from pyrit.registry.class_registries.initializer_registry import ( - InitializerMetadata, - InitializerRegistry, -) __all__ = [ "BaseClassRegistry", "ClassEntry", - "InitializerRegistry", - "InitializerMetadata", ] diff --git a/pyrit/registry/components/__init__.py b/pyrit/registry/components/__init__.py index 3f4e98f3f2..b767860053 100644 --- a/pyrit/registry/components/__init__.py +++ b/pyrit/registry/components/__init__.py @@ -22,6 +22,10 @@ ConverterMetadata, ConverterRegistry, ) +from pyrit.registry.components.initializer_registry import ( + InitializerMetadata, + InitializerRegistry, +) from pyrit.registry.components.scenario_registry import ( ScenarioMetadata, ScenarioRegistry, @@ -40,6 +44,8 @@ "AttackTechniqueMetadata", "ConverterRegistry", "ConverterMetadata", + "InitializerRegistry", + "InitializerMetadata", "ScorerRegistry", "ScorerMetadata", "ScenarioRegistry", diff --git a/pyrit/registry/class_registries/initializer_registry.py b/pyrit/registry/components/initializer_registry.py similarity index 72% rename from pyrit/registry/class_registries/initializer_registry.py rename to pyrit/registry/components/initializer_registry.py index 6d3a9f27cc..9b6117131d 100644 --- a/pyrit/registry/class_registries/initializer_registry.py +++ b/pyrit/registry/components/initializer_registry.py @@ -4,8 +4,15 @@ """ Initializer registry for discovering and cataloging PyRIT initializers. -This module provides a unified registry for discovering all available -PyRITInitializer subclasses from the pyrit/setup/initializers directory structure. +A ``Registry`` for ``PyRITInitializer`` classes that discovers all available +subclasses from the ``pyrit/setup/initializers`` directory structure and from +uploaded custom scripts. Like ``ScenarioRegistry`` it is a class-only unified +``Registry``: it owns a validated class catalog and builds instances via +``create_instance``. Unlike the component registries it does not hold instances +(no ``.instances`` property) and has no ``ComponentIdentifier`` — its declared, +YAML-style inputs live on ``PyRITInitializer.supported_parameters`` and are +applied post-construction via ``set_params_from_args``. Because discovery is a +filesystem scan rather than a package import, ``_discover`` is overridden. """ from __future__ import annotations @@ -15,15 +22,12 @@ import logging from dataclasses import dataclass, field from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from pyrit.models import class_name_to_snake_case, validate_registry_name from pyrit.registry.base import ClassRegistryEntry -from pyrit.registry.class_registries.base_class_registry import ( - BaseClassRegistry, - ClassEntry, -) from pyrit.registry.discovery import discover_in_directory +from pyrit.registry.registry import Registry # Compute PYRIT_PATH directly to avoid importing pyrit package # (which triggers heavy imports from __init__.py) @@ -31,6 +35,8 @@ if TYPE_CHECKING: from pyrit.models import Parameter + from pyrit.models.identifiers.component_identifier import \ + ComponentIdentifier from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer logger = logging.getLogger(__name__) @@ -51,15 +57,16 @@ class InitializerMetadata(ClassRegistryEntry): supported_parameters: tuple[Parameter, ...] = field(kw_only=True, default=()) -class InitializerRegistry(BaseClassRegistry["PyRITInitializer", InitializerMetadata]): +class InitializerRegistry(Registry["PyRITInitializer", InitializerMetadata]): """ Registry for discovering and managing available initializers. - This class discovers all PyRITInitializer subclasses from the - pyrit/setup/initializers directory structure. - - Initializers are identified by their filename (e.g., "objective_target", "simple"). - The directory structure is used for organization but not exposed to users. + Discovers all ``PyRITInitializer`` subclasses from the + ``pyrit/setup/initializers`` directory structure via a filesystem scan (so + ``_discover`` is overridden rather than supplying ``_base_type`` / + ``_discovery_package``). Initializers are identified by their suffix-stripped + snake_case class name (e.g., ``"objective_target"``, ``"simple"``); the + directory structure is used for organization but not exposed to users. """ def __init__(self, *, discovery_path: Path | None = None, lazy_discovery: bool = False) -> None: @@ -86,6 +93,14 @@ def __init__(self, *, discovery_path: Path | None = None, lazy_discovery: bool = self._builtin_names: set[str] = set() super().__init__(lazy_discovery=lazy_discovery) + def _metadata_class(self) -> type[InitializerMetadata]: + """Return the concrete metadata dataclass this registry builds.""" + return InitializerMetadata + + def _identifier_type(self) -> type[ComponentIdentifier] | None: + """Return ``None`` since initializers have no ``ComponentIdentifier``; declared params are their contract.""" + return None + def is_builtin(self, name: str) -> bool: """Return True if *name* was registered during built-in discovery.""" self._ensure_discovered() @@ -166,19 +181,18 @@ def _register_initializer( """ try: # Convert class name to snake_case for registry name - registry_name = class_name_to_snake_case(initializer_class.__name__, suffix="Initializer") + registry_name = self._get_registry_name(initializer_class) # Check for registry key collision - if registry_name in self._class_entries: + if registry_name in self._classes: logger.warning( f"Initializer registry name collision: '{registry_name}' " f"conflicts with an already-registered initializer. Original " - f"initializer is kept: {self._class_entries[registry_name].registered_class.__name__}" + f"initializer is kept: {self._classes[registry_name].__name__}" ) return - entry = ClassEntry(registered_class=initializer_class) - self._class_entries[registry_name] = entry + self.register_class(initializer_class, name=registry_name) if builtin: self._builtin_names.add(registry_name) logger.debug(f"Registered initializer: {registry_name} ({initializer_class.__name__})") @@ -186,26 +200,39 @@ def _register_initializer( except Exception as e: logger.warning(f"Failed to register initializer {initializer_class.__name__}: {e}") - def _build_metadata(self, name: str, entry: ClassEntry[PyRITInitializer]) -> InitializerMetadata: + def _get_registry_name(self, cls: type[PyRITInitializer]) -> str: + """ + Key initializers by their suffix-stripped snake_case class name. + + Args: + cls (type[PyRITInitializer]): The initializer class. + + Returns: + str: The registry name (e.g. ``"objective_target"``). + """ + return class_name_to_snake_case(cls.__name__, suffix="Initializer") + + def _build_metadata(self, name: str, cls: type[PyRITInitializer]) -> InitializerMetadata: """ Build metadata for an initializer class. + Instantiates the initializer with no arguments and reads its + ``required_env_vars`` / ``supported_parameters`` off the instance. + Args: name: The registry name of the initializer. - entry: The ClassEntry containing the initializer class. + cls: The initializer class to describe. Returns: InitializerMetadata describing the initializer class. """ - initializer_class = entry.registered_class - - description = entry.get_description(fallback="No description available") + description = ClassRegistryEntry.description_from_docstring(cls, fallback="No description available") try: - instance = initializer_class() + instance = cls() return InitializerMetadata( - class_name=initializer_class.__name__, - class_module=initializer_class.__module__, + class_name=cls.__name__, + class_module=cls.__module__, class_description=description, registry_name=name, required_env_vars=tuple(instance.required_env_vars), @@ -214,13 +241,42 @@ def _build_metadata(self, name: str, entry: ClassEntry[PyRITInitializer]) -> Ini except Exception as e: logger.warning(f"Failed to get metadata for {name}: {e}") return InitializerMetadata( - class_name=initializer_class.__name__, - class_module=initializer_class.__module__, + class_name=cls.__name__, + class_module=cls.__module__, class_description="Error loading initializer metadata", registry_name=name, required_env_vars=(), ) + def create_and_configure(self, name: str, *, args: dict[str, Any] | None = None) -> PyRITInitializer: + """ + Build and parameterize an initializer in one call. + + Mirrors ``ScenarioRegistry.create_and_initialize_async``: the registry — + not the caller — owns the build → set-params → validate lifecycle. Unlike + scenarios, ``initialize_async`` is invoked later by the PyRIT init flow, so + this returns a *configured, not-yet-initialized* instance. + + Args: + name (str): The registry name of the initializer (e.g. ``"objective_target"``). + args (dict[str, Any] | None): Declared parameters to set before + initialization. Coerced to ``self.params`` via + ``set_params_from_args`` and validated against + ``supported_parameters``. Defaults to no parameters. + + Returns: + PyRITInitializer: The configured initializer, ready for ``initialize_async``. + + Raises: + KeyError: If the name is not registered. + ValueError: If the configured parameters are invalid. + """ + instance = self.create_instance(name) + if args: + instance.set_params_from_args(args=args) + instance._validate_params(params=instance.params) + return instance + def register_from_content(self, *, name: str, script_content: str) -> str: """ Register an initializer from uploaded Python source code. @@ -251,7 +307,7 @@ def register_from_content(self, *, name: str, script_content: str) -> str: validate_registry_name(name) - if name in self._class_entries: + if name in self._classes: raise ValueError(f"Initializer '{name}' is already registered. Unregister it first to replace it.") # Deferred: importing pyrit.setup triggers heavy __init__.py chain @@ -295,9 +351,7 @@ def register_from_content(self, *, name: str, script_content: str) -> str: script_path.unlink(missing_ok=True) raise ValueError(f"Failed to load initializer script '{name}': {e}") from e - entry = ClassEntry(registered_class=discovered) - self._class_entries[name] = entry - self._metadata_cache = None + self.register_class(discovered, name=name) logger.info(f"Registered custom initializer: {name} ({discovered.__name__})") return name @@ -319,7 +373,11 @@ def unregister_and_cleanup(self, name: str) -> None: self._ensure_discovered() if name in self._builtin_names: raise ValueError(f"Cannot remove built-in initializer '{name}'.") - self.unregister(name) + if name not in self._classes: + available = ", ".join(self.get_class_names()) + raise KeyError(f"'{name}' not found in registry. Available: {available}") + del self._classes[name] + self._metadata_cache = None script_path = self._get_custom_scripts_dir() / f"{name}.py" script_path.unlink(missing_ok=True) diff --git a/pyrit/setup/configuration_loader.py b/pyrit/setup/configuration_loader.py index a47f874315..e0e6b7bc36 100644 --- a/pyrit/setup/configuration_loader.py +++ b/pyrit/setup/configuration_loader.py @@ -478,19 +478,13 @@ def resolve_initializers(self) -> Sequence["PyRITInitializer"]: logging.getLogger(__name__).info("Running %d initializer(s)...", len(self._initializer_configs)) for config in self._initializer_configs: - initializer_class = registry.get_class(config.name) - if initializer_class is None: - available = ", ".join(sorted(registry.get_names())) + try: + instance = registry.create_and_configure(config.name, args=config.args) + except KeyError as exc: + available = ", ".join(sorted(registry.get_class_names())) raise ValueError( f"Initializer '{config.name}' not found in registry.\nAvailable initializers: {available}" - ) - - # Instantiate and set params if provided - instance = initializer_class() - if config.args: - instance.set_params_from_args(args=config.args) - # Validate params early against supported_parameters to fail fast - instance._validate_params(params=instance.params) + ) from exc resolved.append(instance) diff --git a/tests/unit/backend/test_initializer_service.py b/tests/unit/backend/test_initializer_service.py index c4ffd7a589..55d631a1a9 100644 --- a/tests/unit/backend/test_initializer_service.py +++ b/tests/unit/backend/test_initializer_service.py @@ -77,7 +77,7 @@ async def test_list_initializers_returns_empty_when_no_initializers(self) -> Non with patch.object(InitializerService, "__init__", lambda self: None): service = InitializerService() service._registry = MagicMock() - service._registry.list_metadata.return_value = [] + service._registry.get_all_registered_class_metadata.return_value = [] result = await service.list_initializers_async() @@ -90,7 +90,7 @@ async def test_list_initializers_returns_initializers_from_registry(self) -> Non with patch.object(InitializerService, "__init__", lambda self: None): service = InitializerService() service._registry = MagicMock() - service._registry.list_metadata.return_value = [metadata] + service._registry.get_all_registered_class_metadata.return_value = [metadata] result = await service.list_initializers_async() @@ -111,7 +111,7 @@ async def test_list_initializers_paginates_with_limit(self) -> None: with patch.object(InitializerService, "__init__", lambda self: None): service = InitializerService() service._registry = MagicMock() - service._registry.list_metadata.return_value = metadata_list + service._registry.get_all_registered_class_metadata.return_value = metadata_list result = await service.list_initializers_async(limit=3) @@ -125,7 +125,7 @@ async def test_list_initializers_paginates_with_cursor(self) -> None: with patch.object(InitializerService, "__init__", lambda self: None): service = InitializerService() service._registry = MagicMock() - service._registry.list_metadata.return_value = metadata_list + service._registry.get_all_registered_class_metadata.return_value = metadata_list result = await service.list_initializers_async(limit=2, cursor="init_1") @@ -140,7 +140,7 @@ async def test_list_initializers_last_page_has_more_false(self) -> None: with patch.object(InitializerService, "__init__", lambda self: None): service = InitializerService() service._registry = MagicMock() - service._registry.list_metadata.return_value = metadata_list + service._registry.get_all_registered_class_metadata.return_value = metadata_list result = await service.list_initializers_async(limit=5) @@ -154,7 +154,7 @@ async def test_list_initializers_with_no_env_vars(self) -> None: with patch.object(InitializerService, "__init__", lambda self: None): service = InitializerService() service._registry = MagicMock() - service._registry.list_metadata.return_value = [metadata] + service._registry.get_all_registered_class_metadata.return_value = [metadata] result = await service.list_initializers_async() @@ -171,7 +171,7 @@ async def test_get_initializer_returns_matching_initializer(self) -> None: with patch.object(InitializerService, "__init__", lambda self: None): service = InitializerService() service._registry = MagicMock() - service._registry.list_metadata.return_value = [metadata] + service._registry.get_all_registered_class_metadata.return_value = [metadata] result = await service.get_initializer_async(initializer_name="target") @@ -182,7 +182,7 @@ async def test_get_initializer_returns_none_for_missing(self) -> None: with patch.object(InitializerService, "__init__", lambda self: None): service = InitializerService() service._registry = MagicMock() - service._registry.list_metadata.return_value = [] + service._registry.get_all_registered_class_metadata.return_value = [] result = await service.get_initializer_async(initializer_name="nonexistent") @@ -316,7 +316,7 @@ async def test_register_initializer_calls_registry(self) -> None: service = InitializerService() mock_registry = MagicMock() mock_registry.register_from_content.return_value = "my_custom" - mock_registry.list_metadata.return_value = [ + mock_registry.get_all_registered_class_metadata.return_value = [ _make_initializer_metadata(registry_name="my_custom", class_name="MyCustomInitializer") ] service._registry = mock_registry diff --git a/tests/unit/registry/test_initializer_registry.py b/tests/unit/registry/test_initializer_registry.py index 6396d5caad..02eb1eca32 100644 --- a/tests/unit/registry/test_initializer_registry.py +++ b/tests/unit/registry/test_initializer_registry.py @@ -7,8 +7,7 @@ import pytest -from pyrit.registry.class_registries.base_class_registry import ClassEntry -from pyrit.registry.class_registries.initializer_registry import ( +from pyrit.registry.components.initializer_registry import ( PYRIT_PATH, InitializerRegistry, ) @@ -47,8 +46,7 @@ async def initialize_async(self) -> None: pass registry = InitializerRegistry(lazy_discovery=True) - entry = ClassEntry(registered_class=FakeInitializer) - metadata = registry._build_metadata("fake", entry) + metadata = registry._build_metadata("fake", FakeInitializer) assert metadata.class_description == "A fake initializer for testing." assert metadata.class_name == "FakeInitializer" @@ -182,8 +180,7 @@ class BuiltinInit(PyRITInitializer): async def initialize_async(self) -> None: pass - entry = ClassEntry(registered_class=BuiltinInit) - lazy_registry._class_entries["builtin_test"] = entry + lazy_registry._classes["builtin_test"] = BuiltinInit lazy_registry._builtin_names.add("builtin_test") with pytest.raises(ValueError, match="Cannot remove built-in"): @@ -199,8 +196,7 @@ class FakeInit(PyRITInitializer): async def initialize_async(self) -> None: pass - entry = ClassEntry(registered_class=FakeInit) - lazy_registry._class_entries["fake"] = entry + lazy_registry._classes["fake"] = FakeInit lazy_registry._builtin_names.add("fake") assert lazy_registry.is_builtin("fake") is True diff --git a/tests/unit/setup/test_configuration_loader.py b/tests/unit/setup/test_configuration_loader.py index dab15726c8..ee9b04693c 100644 --- a/tests/unit/setup/test_configuration_loader.py +++ b/tests/unit/setup/test_configuration_loader.py @@ -360,11 +360,9 @@ async def test_initialize_pyrit_async_with_initializers(self, mock_registry_cls, mock_registry = mock.MagicMock() mock_registry_cls.return_value = mock_registry - # Mock an initializer class - mock_initializer_class = mock.MagicMock() + # Mock the configured initializer instance produced by the registry mock_initializer_instance = mock.MagicMock() - mock_initializer_class.return_value = mock_initializer_instance - mock_registry.get_class.return_value = mock_initializer_class + mock_registry.create_and_configure.return_value = mock_initializer_instance config = ConfigurationLoader( memory_db_type="in_memory", @@ -372,9 +370,8 @@ async def test_initialize_pyrit_async_with_initializers(self, mock_registry_cls, ) await config.initialize_pyrit_async() - # Verify registry was used to resolve initializer - mock_registry.get_class.assert_called_once_with("simple") - mock_initializer_class.assert_called_once_with() + # Verify registry was used to resolve and configure the initializer + mock_registry.create_and_configure.assert_called_once_with("simple", args=None) # Verify initialize was called with resolved initializers mock_init.assert_called_once() @@ -386,8 +383,8 @@ async def test_initialize_pyrit_async_unknown_initializer_raises_error(self, moc """Test that unknown initializer name raises ValueError.""" mock_registry = mock.MagicMock() mock_registry_cls.return_value = mock_registry - mock_registry.get_class.return_value = None - mock_registry.get_names.return_value = ["simple", "airt"] + mock_registry.create_and_configure.side_effect = KeyError("unknown_initializer") + mock_registry.get_class_names.return_value = ["simple", "airt"] config = ConfigurationLoader( memory_db_type="in_memory", diff --git a/tests/unit/setup/test_scenario_techniques_initializer.py b/tests/unit/setup/test_scenario_techniques_initializer.py index 878f990437..27a0168622 100644 --- a/tests/unit/setup/test_scenario_techniques_initializer.py +++ b/tests/unit/setup/test_scenario_techniques_initializer.py @@ -330,5 +330,5 @@ def test_initializer_is_discovered(self): from pyrit.registry import InitializerRegistry registry = InitializerRegistry() - names = set(registry.get_names()) + names = set(registry.get_class_names()) assert "scenario_technique" in names From 7429ce307be2d3b1653f9d583d3c953d256ca26f Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 2 Jul 2026 16:48:07 -0700 Subject: [PATCH 02/11] Route scenario_run_service initializers through create_and_configure _run_initializers_async fetched the class via get_class() and did the instantiate + set_params_from_args dance inline, diverging from how the same service builds scenarios (create_and_initialize_async). Route it through the registry-owned create_and_configure lifecycle instead, which also adds the parameter validation it was skipping. Co-authored-by: Copilot App <223556219+Copilot@users.noreply.github.com> --- pyrit/backend/services/scenario_run_service.py | 6 ++---- tests/unit/backend/test_scenario_run_service.py | 6 +++--- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/pyrit/backend/services/scenario_run_service.py b/pyrit/backend/services/scenario_run_service.py index b770ac9722..2769ea4d18 100644 --- a/pyrit/backend/services/scenario_run_service.py +++ b/pyrit/backend/services/scenario_run_service.py @@ -227,13 +227,11 @@ async def _run_initializers_async(self, *, request: RunScenarioRequest) -> None: initializer_registry = InitializerRegistry.get_registry_singleton() for initializer_name in request.initializers: + args = (request.initializer_args or {}).get(initializer_name) try: - initializer_class = initializer_registry.get_class(initializer_name) + instance = initializer_registry.create_and_configure(initializer_name, args=args) except KeyError as e: raise ValueError(f"Initializer not found: {e}") from None - instance = initializer_class() - if request.initializer_args and initializer_name in request.initializer_args: - instance.set_params_from_args(args=request.initializer_args[initializer_name]) await instance.initialize_async() def _resolve_target(self, *, request: RunScenarioRequest) -> "PromptTarget": diff --git a/tests/unit/backend/test_scenario_run_service.py b/tests/unit/backend/test_scenario_run_service.py index 89180bfd80..98f27517d3 100644 --- a/tests/unit/backend/test_scenario_run_service.py +++ b/tests/unit/backend/test_scenario_run_service.py @@ -139,7 +139,7 @@ def mock_all_registries(mock_memory): mock_tr.instances.get_names.return_value = ["my_target"] mock_ir = MagicMock() - mock_ir.get_class.return_value = MagicMock(return_value=MagicMock(initialize_async=AsyncMock())) + mock_ir.create_and_configure.return_value = MagicMock(initialize_async=AsyncMock()) # By default, return a matching DB result for get_run / list_runs queries db_result = _make_db_scenario_result() @@ -215,7 +215,7 @@ async def test_start_run_invalid_initializer_raises_value_error(self, mock_memor mock_sr.get_class.return_value = MagicMock() mock_ir = MagicMock() - mock_ir.get_class.side_effect = KeyError("'bad_init' not found") + mock_ir.create_and_configure.side_effect = KeyError("'bad_init' not found") with ( patch(f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton", return_value=mock_sr), @@ -468,7 +468,7 @@ async def test_start_run_runs_initializers(self, mock_all_registries) -> None: """Test that initializers are run during start_run_async.""" service = ScenarioRunService() mock_ir = mock_all_registries["initializer_registry"] - mock_init_instance = mock_ir.get_class.return_value.return_value + mock_init_instance = mock_ir.create_and_configure.return_value response = await service.start_run_async( request=_make_request(initializers=["target", "load_default_datasets"]) From 511a2d5bd84833c95a0fa368d64065054f4948fa Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 2 Jul 2026 16:58:38 -0700 Subject: [PATCH 03/11] Make InitializerRegistry own external initializer script loading Move the external .py script-loading logic out of setup/initialization.py and into InitializerRegistry.create_from_script_paths, so the registry is the single owner of turning scripts into initializer instances. Share the module-import and module-defined-subclass discovery helpers with register_from_content. Co-authored-by: Copilot App <223556219+Copilot@users.noreply.github.com> --- .../components/initializer_registry.py | 137 +++++++++++++++--- pyrit/setup/initialization.py | 108 ++------------ .../registry/test_initializer_registry.py | 62 +++++++- tests/unit/setup/test_initialization.py | 22 +-- 4 files changed, 194 insertions(+), 135 deletions(-) diff --git a/pyrit/registry/components/initializer_registry.py b/pyrit/registry/components/initializer_registry.py index 9b6117131d..e57a6e8483 100644 --- a/pyrit/registry/components/initializer_registry.py +++ b/pyrit/registry/components/initializer_registry.py @@ -34,9 +34,11 @@ PYRIT_PATH = Path(__file__).parent.parent.parent.resolve() if TYPE_CHECKING: + from collections.abc import Sequence + from types import ModuleType + from pyrit.models import Parameter - from pyrit.models.identifiers.component_identifier import \ - ComponentIdentifier + from pyrit.models.identifiers.component_identifier import ComponentIdentifier from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer logger = logging.getLogger(__name__) @@ -277,6 +279,113 @@ def create_and_configure(self, name: str, *, args: dict[str, Any] | None = None) instance._validate_params(params=instance.params) return instance + @staticmethod + def _load_module_from_path(*, file_path: Path, module_name: str) -> ModuleType: + """ + Import a Python file as an anonymous module. + + Args: + file_path: Path to the ``.py`` file to import. + module_name: The synthetic module name to load it under. + + Returns: + ModuleType: The executed module. + + Raises: + ValueError: If an import spec could not be created for the file. + """ + spec = importlib.util.spec_from_file_location(module_name, file_path) + if not spec or not spec.loader: + raise ValueError(f"Could not load initializer script: {file_path}") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + @staticmethod + def _module_defined_initializers(*, module: ModuleType, base_class: type) -> list[type]: + """ + Find concrete ``PyRITInitializer`` subclasses defined in *module*. + + Only classes whose ``__module__`` is *module* are returned, so classes + merely imported into the script are ignored. + + Args: + module: The imported module to scan. + base_class: The ``PyRITInitializer`` base class. + + Returns: + list[type]: Concrete initializer classes defined in the module. + """ + return [ + attr + for attr_name in dir(module) + if ( + inspect.isclass(attr := getattr(module, attr_name)) + and issubclass(attr, base_class) + and attr is not base_class + and not inspect.isabstract(attr) + and attr.__module__ == module.__name__ + ) + ] + + def create_from_script_paths(self, *, script_paths: Sequence[str | Path]) -> list[PyRITInitializer]: + """ + Load initializer instances from external Python script files. + + The registry owns turning script files into initializers: each ``.py`` + file is imported and every ``PyRITInitializer`` subclass *defined in that + file* (imported ones are ignored) is instantiated. Instances are returned + in load order, ready for the caller to validate and initialize; they are + not added to the class catalog. + + Args: + script_paths (Sequence[str | Path]): Python (.py) file paths to load + initializers from. Relative paths resolve against the current + working directory. + + Returns: + list[PyRITInitializer]: Instantiated initializers, in load order. + + Raises: + FileNotFoundError: If a script path does not exist. + ValueError: If a path is not a ``.py`` file or defines no initializer. + """ + from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer + + resolved = self.resolve_script_paths(script_paths=[str(p) for p in script_paths]) + + instances: list[PyRITInitializer] = [] + for script_path in resolved: + if script_path.suffix != ".py": + raise ValueError(f"Initialization script must be a Python file (.py): {script_path}") + + logger.info(f"Loading initializers from script: {script_path}") + try: + module = self._load_module_from_path( + file_path=script_path, module_name=f"init_script_{script_path.stem}" + ) + file_instances: list[PyRITInitializer] = [] + for init_cls in self._module_defined_initializers(module=module, base_class=PyRITInitializer): + try: + file_instances.append(init_cls()) + logger.debug(f"Found and instantiated {init_cls.__name__} in {script_path.name}") + except Exception as e: + logger.warning(f"Could not instantiate {init_cls.__name__} from {script_path.name}: {e}") + + if not file_instances: + raise ValueError( + f"Initialization script {script_path} must contain at least one PyRITInitializer subclass. " + f"Define a class that inherits from PyRITInitializer." + ) + + instances.extend(file_instances) + logger.debug(f"Loaded {len(file_instances)} initializer(s) from {script_path.name}") + except Exception as e: + logger.error(f"Error loading initializers from script {script_path}: {e}") + raise + + return instances + def register_from_content(self, *, name: str, script_content: str) -> str: """ Register an initializer from uploaded Python source code. @@ -322,28 +431,12 @@ def register_from_content(self, *, name: str, script_content: str) -> str: raise ValueError(f"Failed to write initializer script: {e}") from e try: - spec = importlib.util.spec_from_file_location(f"custom_initializer.{name}", script_path) - if not spec or not spec.loader: - raise ValueError(f"Could not load initializer script for '{name}'") - - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - - discovered: type[PyRITInitializer] | None = None - for attr_name in dir(module): - attr = getattr(module, attr_name) - if ( - inspect.isclass(attr) - and issubclass(attr, PyRITInitializer) - and attr is not PyRITInitializer - and not inspect.isabstract(attr) - and attr.__module__ == module.__name__ - ): - discovered = attr - break + module = self._load_module_from_path(file_path=script_path, module_name=f"custom_initializer.{name}") - if discovered is None: + discovered_classes = self._module_defined_initializers(module=module, base_class=PyRITInitializer) + if not discovered_classes: raise ValueError(f"Uploaded script for '{name}' does not contain a concrete PyRITInitializer subclass.") + discovered = discovered_classes[0] except ValueError: script_path.unlink(missing_ok=True) raise diff --git a/pyrit/setup/initialization.py b/pyrit/setup/initialization.py index 8679a828b9..ec913ec1cd 100644 --- a/pyrit/setup/initialization.py +++ b/pyrit/setup/initialization.py @@ -10,12 +10,7 @@ from pyrit.common import path from pyrit.common.apply_defaults import reset_default_values -from pyrit.memory import ( - AzureSQLMemory, - CentralMemory, - MemoryInterface, - SQLiteMemory, -) +from pyrit.memory import AzureSQLMemory, CentralMemory, MemoryInterface, SQLiteMemory if TYPE_CHECKING: from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer @@ -96,95 +91,6 @@ def _print_msg(message: str, quiet: bool, log: bool) -> None: logger.info(message) -def _load_initializers_from_scripts(*, script_paths: Sequence[str | pathlib.Path]) -> Sequence["PyRITInitializer"]: - """ - Load PyRITInitializer instances from external Python files. - - Each script file should contain one or more PyRITInitializer classes. All classes - that inherit from PyRITInitializer will be automatically discovered and instantiated. - - Args: - script_paths (Sequence[str | pathlib.Path]): Sequence of file paths to Python scripts to load. - - Returns: - Sequence[PyRITInitializer]: List of PyRITInitializer instances loaded from the scripts. - - Raises: - FileNotFoundError: If a script path does not exist. - ValueError: If a script path is not a Python file or doesn't contain valid initializers. - - Example: - Script content should be a subclass of PyRITInitializer e.g. like SimpleInitializer - """ - # Import here to avoid circular imports - from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer - - loaded_initializers = [] - - for script_path in script_paths: - # Convert to Path object if string - script = pathlib.Path(script_path) - - # Validate the script exists - if not script.exists(): - raise FileNotFoundError(f"Initialization script not found: {script}") - - # Validate it's a Python file - if script.suffix != ".py": - raise ValueError(f"Initialization script must be a Python file (.py): {script}") - - logger.info(f"Loading initializers from script: {script}") - - # Load the script as a module - try: - import importlib.util - - spec = importlib.util.spec_from_file_location(f"init_script_{script.stem}", script) - if spec is None or spec.loader is None: - raise ValueError(f"Could not load initialization script: {script}") - - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - - # Auto-discover PyRITInitializer subclasses in the module - script_initializers = [] - - # Look for all PyRITInitializer subclasses defined in the module - for name in dir(module): - obj = getattr(module, name) - # Check if it's a class, is a subclass of PyRITInitializer, - # and is not the base class itself - if ( - isinstance(obj, type) - and issubclass(obj, PyRITInitializer) - and obj is not PyRITInitializer - and obj.__module__ == module.__name__ - ): - try: - # Instantiate the initializer class - initializer = obj() - script_initializers.append(initializer) - logger.debug(f"Found and instantiated {name} in {script.name}") - except Exception as e: - logger.warning(f"Could not instantiate {name} from {script.name}: {e}") - # Continue to try other classes rather than failing completely - - if not script_initializers: - raise ValueError( - f"Initialization script {script} must contain at least one PyRITInitializer subclass. " - f"Define a class that inherits from PyRITInitializer." - ) - - loaded_initializers.extend(script_initializers) - logger.debug(f"Loaded {len(script_initializers)} initializer(s) from {script.name}") - - except Exception as e: - logger.error(f"Error loading initializers from script {script}: {e}") - raise - - return loaded_initializers - - def _parse_akv_secret_url(secret_url: str) -> tuple[str, str, str | None]: """ Parse an AKV secret URL into vault URL, secret name, and optional version. @@ -306,8 +212,8 @@ async def initialize_pyrit_async( memory_db_type (MemoryDatabaseType): The MemoryDatabaseType string literal which indicates the memory instance to use for central memory. Options include "InMemory", "SQLite", and "AzureSQL". initialization_scripts (Sequence[str | pathlib.Path] | None): Optional sequence of Python script paths - that contain PyRITInitializer classes. Each script must define either a get_initializers() function - or an 'initializers' variable that returns/contains a list of PyRITInitializer instances. + that define PyRITInitializer subclasses. Every initializer subclass defined in each file is + loaded and executed. Loading is handled by the InitializerRegistry. initializers (Sequence[PyRITInitializer] | None): Optional sequence of PyRITInitializer instances to execute directly. These provide type-safe, validated configuration with clear documentation. env_files (Sequence[pathlib.Path] | None): Optional sequence of environment file paths to load @@ -356,9 +262,13 @@ async def initialize_pyrit_async( # Combine directly provided initializers with those loaded from scripts all_initializers = list(initializers) if initializers else [] - # Load additional initializers from scripts + # Load additional initializers from scripts — the registry owns turning + # external script files into initializer instances. if initialization_scripts: - script_initializers = _load_initializers_from_scripts(script_paths=initialization_scripts) + from pyrit.registry import InitializerRegistry + + registry = InitializerRegistry.get_registry_singleton() + script_initializers = registry.create_from_script_paths(script_paths=initialization_scripts) all_initializers.extend(script_initializers) # Execute all initializers in order diff --git a/tests/unit/registry/test_initializer_registry.py b/tests/unit/registry/test_initializer_registry.py index 02eb1eca32..a9ba35d961 100644 --- a/tests/unit/registry/test_initializer_registry.py +++ b/tests/unit/registry/test_initializer_registry.py @@ -7,10 +7,7 @@ import pytest -from pyrit.registry.components.initializer_registry import ( - PYRIT_PATH, - InitializerRegistry, -) +from pyrit.registry.components.initializer_registry import PYRIT_PATH, InitializerRegistry from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer @@ -209,3 +206,60 @@ def test_is_builtin_returns_false_for_custom_initializers(lazy_registry): lazy_registry.register_from_content(name="custom", script_content=_VALID_SCRIPT) assert lazy_registry.is_builtin("custom") is False + + +# ============================================================================ +# create_from_script_paths Tests +# ============================================================================ + + +def _write_initializer_script(directory: Path, filename: str, *class_names: str) -> Path: + """Write a script defining one or more PyRITInitializer subclasses.""" + body = "from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer\n\n" + for class_name in class_names: + body += ( + f"class {class_name}(PyRITInitializer):\n" + f" async def initialize_async(self) -> None:\n" + f" pass\n\n" + ) + script_path = directory / filename + script_path.write_text(body) + return script_path + + +def test_create_from_script_paths_loads_multiple_classes(lazy_registry): + """Test that all initializer subclasses defined in a file are instantiated.""" + with tempfile.TemporaryDirectory() as temp_dir: + script_path = _write_initializer_script(Path(temp_dir), "multi.py", "FirstInit", "SecondInit") + + instances = lazy_registry.create_from_script_paths(script_paths=[script_path]) + + assert {type(i).__name__ for i in instances} == {"FirstInit", "SecondInit"} + # Loading does not add the classes to the catalog. + assert lazy_registry.get_class_names() == [] + + +def test_create_from_script_paths_rejects_non_python_file(lazy_registry): + """Test that a non-.py path raises ValueError.""" + with tempfile.TemporaryDirectory() as temp_dir: + bad_path = Path(temp_dir) / "not_python.txt" + bad_path.write_text("hello") + + with pytest.raises(ValueError, match="must be a Python file"): + lazy_registry.create_from_script_paths(script_paths=[bad_path]) + + +def test_create_from_script_paths_no_subclass_raises_value_error(lazy_registry): + """Test that a file defining no initializer subclass raises ValueError.""" + with tempfile.TemporaryDirectory() as temp_dir: + empty_path = Path(temp_dir) / "empty.py" + empty_path.write_text("x = 1\n") + + with pytest.raises(ValueError, match="must contain at least one"): + lazy_registry.create_from_script_paths(script_paths=[empty_path]) + + +def test_create_from_script_paths_missing_file_raises(lazy_registry): + """Test that a missing script path raises FileNotFoundError.""" + with pytest.raises(FileNotFoundError): + lazy_registry.create_from_script_paths(script_paths=["definitely_missing_script.py"]) diff --git a/tests/unit/setup/test_initialization.py b/tests/unit/setup/test_initialization.py index c7ba4d8ca1..b919df4338 100644 --- a/tests/unit/setup/test_initialization.py +++ b/tests/unit/setup/test_initialization.py @@ -12,17 +12,13 @@ from pyrit.common.apply_defaults import reset_default_values from pyrit.common.singleton import Singleton +from pyrit.registry import InitializerRegistry from pyrit.setup import IN_MEMORY, initialize_pyrit_async -from pyrit.setup.initialization import ( - _load_env_from_akv_async, - _load_environment_files, - _load_initializers_from_scripts, - _parse_akv_secret_url, -) +from pyrit.setup.initialization import _load_env_from_akv_async, _load_environment_files, _parse_akv_secret_url class TestLoadInitializersFromScripts: - """Tests for _load_initializers_from_scripts function.""" + """Tests for InitializerRegistry.create_from_script_paths.""" def test_load_initializer_from_script(self): """Test loading an initializer from a Python script.""" @@ -47,7 +43,9 @@ async def initialize_async(self) -> None: script_path = f.name try: - initializers = _load_initializers_from_scripts(script_paths=[script_path]) + initializers = InitializerRegistry.get_registry_singleton().create_from_script_paths( + script_paths=[script_path] + ) assert len(initializers) == 1 assert initializers[0].name == "Test Initializer" finally: @@ -56,7 +54,9 @@ async def initialize_async(self) -> None: def test_script_not_found_raises_error(self): """Test that FileNotFoundError is raised for non-existent script.""" with pytest.raises(FileNotFoundError): - _load_initializers_from_scripts(script_paths=["nonexistent_script.py"]) + InitializerRegistry.get_registry_singleton().create_from_script_paths( + script_paths=["nonexistent_script.py"] + ) def test_ignores_imported_initializer_classes(self): """Test that imported initializer classes are not instantiated from the script.""" @@ -106,7 +106,9 @@ async def initialize_async(self) -> None: """ ) - initializers = _load_initializers_from_scripts(script_paths=[script_path]) + initializers = InitializerRegistry.get_registry_singleton().create_from_script_paths( + script_paths=[script_path] + ) assert len(initializers) == 1 assert initializers[0].name == "Local" From 3158f89e63fbd976cd274c5d29e19b4594549447 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 2 Jul 2026 17:13:28 -0700 Subject: [PATCH 04/11] Standardize create_and_configure param to initializer_params Rename the create_and_configure argument from 'args' to 'initializer_params' so it parallels ScenarioRegistry.create_and_initialize_async's 'scenario_params'. The method name intentionally stays create_and_configure (it stops before initialize_async, which the PyRIT init flow runs in order). Co-authored-by: Copilot App <223556219+Copilot@users.noreply.github.com> --- .../backend/services/scenario_run_service.py | 6 ++++-- .../components/initializer_registry.py | 21 +++++++++++-------- pyrit/setup/configuration_loader.py | 2 +- tests/unit/setup/test_configuration_loader.py | 2 +- 4 files changed, 18 insertions(+), 13 deletions(-) diff --git a/pyrit/backend/services/scenario_run_service.py b/pyrit/backend/services/scenario_run_service.py index 2769ea4d18..87aa2419e3 100644 --- a/pyrit/backend/services/scenario_run_service.py +++ b/pyrit/backend/services/scenario_run_service.py @@ -227,9 +227,11 @@ async def _run_initializers_async(self, *, request: RunScenarioRequest) -> None: initializer_registry = InitializerRegistry.get_registry_singleton() for initializer_name in request.initializers: - args = (request.initializer_args or {}).get(initializer_name) + initializer_params = (request.initializer_args or {}).get(initializer_name) try: - instance = initializer_registry.create_and_configure(initializer_name, args=args) + instance = initializer_registry.create_and_configure( + initializer_name, initializer_params=initializer_params + ) except KeyError as e: raise ValueError(f"Initializer not found: {e}") from None await instance.initialize_async() diff --git a/pyrit/registry/components/initializer_registry.py b/pyrit/registry/components/initializer_registry.py index e57a6e8483..4bfaaebb3e 100644 --- a/pyrit/registry/components/initializer_registry.py +++ b/pyrit/registry/components/initializer_registry.py @@ -250,19 +250,22 @@ def _build_metadata(self, name: str, cls: type[PyRITInitializer]) -> Initializer required_env_vars=(), ) - def create_and_configure(self, name: str, *, args: dict[str, Any] | None = None) -> PyRITInitializer: + def create_and_configure( + self, name: str, *, initializer_params: dict[str, Any] | None = None + ) -> PyRITInitializer: """ Build and parameterize an initializer in one call. - Mirrors ``ScenarioRegistry.create_and_initialize_async``: the registry — - not the caller — owns the build → set-params → validate lifecycle. Unlike - scenarios, ``initialize_async`` is invoked later by the PyRIT init flow, so - this returns a *configured, not-yet-initialized* instance. + Parallels ``ScenarioRegistry.create_and_initialize_async`` (which takes + ``scenario_params``): the registry — not the caller — owns the + build → set-params → validate lifecycle. Unlike scenarios, + ``initialize_async`` is invoked later by the PyRIT init flow, so this stops + at ``configure`` and returns a *configured, not-yet-initialized* instance. Args: name (str): The registry name of the initializer (e.g. ``"objective_target"``). - args (dict[str, Any] | None): Declared parameters to set before - initialization. Coerced to ``self.params`` via + initializer_params (dict[str, Any] | None): Declared parameters to set + before initialization. Coerced to ``self.params`` via ``set_params_from_args`` and validated against ``supported_parameters``. Defaults to no parameters. @@ -274,8 +277,8 @@ def create_and_configure(self, name: str, *, args: dict[str, Any] | None = None) ValueError: If the configured parameters are invalid. """ instance = self.create_instance(name) - if args: - instance.set_params_from_args(args=args) + if initializer_params: + instance.set_params_from_args(args=initializer_params) instance._validate_params(params=instance.params) return instance diff --git a/pyrit/setup/configuration_loader.py b/pyrit/setup/configuration_loader.py index e0e6b7bc36..d2edadf4f7 100644 --- a/pyrit/setup/configuration_loader.py +++ b/pyrit/setup/configuration_loader.py @@ -479,7 +479,7 @@ def resolve_initializers(self) -> Sequence["PyRITInitializer"]: for config in self._initializer_configs: try: - instance = registry.create_and_configure(config.name, args=config.args) + instance = registry.create_and_configure(config.name, initializer_params=config.args) except KeyError as exc: available = ", ".join(sorted(registry.get_class_names())) raise ValueError( diff --git a/tests/unit/setup/test_configuration_loader.py b/tests/unit/setup/test_configuration_loader.py index ee9b04693c..3c33fac5dc 100644 --- a/tests/unit/setup/test_configuration_loader.py +++ b/tests/unit/setup/test_configuration_loader.py @@ -371,7 +371,7 @@ async def test_initialize_pyrit_async_with_initializers(self, mock_registry_cls, await config.initialize_pyrit_async() # Verify registry was used to resolve and configure the initializer - mock_registry.create_and_configure.assert_called_once_with("simple", args=None) + mock_registry.create_and_configure.assert_called_once_with("simple", initializer_params=None) # Verify initialize was called with resolved initializers mock_init.assert_called_once() From a1b489dfa2134197e4e39b753d4175716da93073 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 2 Jul 2026 17:18:48 -0700 Subject: [PATCH 05/11] Add composable TagQuery filtering to instance registries Wire the existing registry-agnostic TagQuery into DefaultInstanceRegistry (and the InstanceRegistry protocol) via query_by_tags, so callers can filter held instances with arbitrary AND/OR/exclude tag predicates instead of only the single-key get_by_tag. Matching is on tag keys, consistent with TagQuery. Co-authored-by: Copilot App <223556219+Copilot@users.noreply.github.com> --- pyrit/registry/instance_registry.py | 23 +++++++++++++++ tests/unit/registry/test_instance_registry.py | 29 +++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/pyrit/registry/instance_registry.py b/pyrit/registry/instance_registry.py index b5c9ed8b98..08b74cd44f 100644 --- a/pyrit/registry/instance_registry.py +++ b/pyrit/registry/instance_registry.py @@ -29,6 +29,8 @@ if TYPE_CHECKING: from collections.abc import Callable, Iterator + from pyrit.registry.tag_query import TagQuery + T = TypeVar("T", bound=Identifiable) # The type of items stored @@ -96,6 +98,10 @@ def get_by_tag(self, *, tag: str, value: str | None = None) -> list[RegistryEntr """Return entries carrying ``tag`` (optionally matching ``value``), sorted by name.""" ... + def query_by_tags(self, *, query: TagQuery) -> list[RegistryEntry[T]]: + """Return entries whose tag keys satisfy the composable ``TagQuery``, sorted by name.""" + ... + def add_tags(self, *, name: str, tags: dict[str, str] | list[str]) -> None: """Add tags to an existing entry.""" ... @@ -318,6 +324,23 @@ def get_by_tag(self, *, tag: str, value: str | None = None) -> list[RegistryEntr results.append(entry) return results + def query_by_tags(self, *, query: TagQuery) -> list[RegistryEntry[T]]: + """ + Get entries whose tag keys satisfy a composable ``TagQuery``. + + Where ``get_by_tag`` matches a single key (optionally a value), this + evaluates an arbitrary AND / OR / exclude predicate built with ``TagQuery`` + (e.g. ``TagQuery.all("core") & TagQuery.any_of("fast", "cheap")``). Matching + is on the tag *keys* only; tag values are not considered. + + Args: + query (TagQuery): The predicate to evaluate against each entry's tag keys. + + Returns: + list[RegistryEntry[T]]: Matching entries sorted by name. + """ + return [entry for entry in self.get_all_instances() if query.matches(set(entry.tags))] + def add_tags(self, *, name: str, tags: dict[str, str] | list[str]) -> None: """ Add tags to an existing entry. diff --git a/tests/unit/registry/test_instance_registry.py b/tests/unit/registry/test_instance_registry.py index df66cc4506..ec607517e4 100644 --- a/tests/unit/registry/test_instance_registry.py +++ b/tests/unit/registry/test_instance_registry.py @@ -16,6 +16,7 @@ RegistryEntry, SupportsInstances, ) +from pyrit.registry.tag_query import TagQuery class _TestItem(Identifiable): @@ -258,6 +259,34 @@ def test_get_by_tag_no_match(self, registry: DefaultInstanceRegistry[_TestItem]) registry.register(_item("v1"), name="n1", tags=["fast"]) assert registry.get_by_tag(tag="missing") == [] + def test_query_by_tags_and_predicate(self, registry: DefaultInstanceRegistry[_TestItem]): + registry.register(_item("v1"), name="n1", tags=["core", "fast"]) + registry.register(_item("v2"), name="n2", tags=["core", "slow"]) + registry.register(_item("v3"), name="n3", tags=["fast"]) + query = TagQuery.all("core") & TagQuery.any_of("fast", "cheap") + assert [e.name for e in registry.query_by_tags(query=query)] == ["n1"] + + def test_query_by_tags_exclude(self, registry: DefaultInstanceRegistry[_TestItem]): + registry.register(_item("v1"), name="n1", tags=["core"]) + registry.register(_item("v2"), name="n2", tags=["core", "deprecated"]) + query = TagQuery.all("core") & TagQuery.none_of("deprecated") + assert [e.name for e in registry.query_by_tags(query=query)] == ["n1"] + + def test_query_by_tags_matches_keys_ignoring_values(self, registry: DefaultInstanceRegistry[_TestItem]): + registry.register(_item("v1"), name="n1", tags={"speed": "fast"}) + registry.register(_item("v2"), name="n2", tags={"speed": "slow"}) + assert [e.name for e in registry.query_by_tags(query=TagQuery.all("speed"))] == ["n1", "n2"] + + def test_query_by_tags_returns_sorted_by_name(self, registry: DefaultInstanceRegistry[_TestItem]): + registry.register(_item("v1"), name="zeta", tags=["t"]) + registry.register(_item("v2"), name="alpha", tags=["t"]) + assert [e.name for e in registry.query_by_tags(query=TagQuery.any_of("t"))] == ["alpha", "zeta"] + + def test_query_by_tags_empty_query_returns_all(self, registry: DefaultInstanceRegistry[_TestItem]): + registry.register(_item("v1"), name="n1", tags=["a"]) + registry.register(_item("v2"), name="n2", tags=["b"]) + assert [e.name for e in registry.query_by_tags(query=TagQuery())] == ["n1", "n2"] + def test_normalize_tags_none(self, registry: DefaultInstanceRegistry[_TestItem]): assert registry._normalize_tags(None) == {} From 0576ad2cc835037e4e4325e2650faf756f5ea132 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 2 Jul 2026 17:28:37 -0700 Subject: [PATCH 06/11] Delete legacy class_registries and object_registries scaffolding Remove the now-unused BaseClassRegistry/ClassEntry and BaseInstanceRegistry legacy stacks that the unified Registry migration made obsolete. RegistryEntry is re-pointed to its canonical home in instance_registry, and the RegistryProtocol docstring plus a test docstring are updated to drop references to the deleted classes. test_base.py keeps its ClassRegistryEntry and _matches_filters coverage. Co-authored-by: Copilot App <223556219+Copilot@users.noreply.github.com> --- pyrit/registry/__init__.py | 12 +- pyrit/registry/base.py | 5 +- pyrit/registry/class_registries/__init__.py | 22 - .../class_registries/base_class_registry.py | 414 ---------- pyrit/registry/object_registries/__init__.py | 23 - .../base_instance_registry.py | 342 -------- tests/unit/registry/test_base.py | 130 --- .../registry/test_base_instance_registry.py | 762 ------------------ tests/unit/setup/test_targets_initializer.py | 2 +- 9 files changed, 4 insertions(+), 1708 deletions(-) delete mode 100644 pyrit/registry/class_registries/__init__.py delete mode 100644 pyrit/registry/class_registries/base_class_registry.py delete mode 100644 pyrit/registry/object_registries/__init__.py delete mode 100644 pyrit/registry/object_registries/base_instance_registry.py delete mode 100644 tests/unit/registry/test_base_instance_registry.py diff --git a/pyrit/registry/__init__.py b/pyrit/registry/__init__.py index 1696936afd..0cabdbd0e1 100644 --- a/pyrit/registry/__init__.py +++ b/pyrit/registry/__init__.py @@ -4,10 +4,6 @@ """Registry module for PyRIT class and object registries.""" from pyrit.registry.base import RegistryProtocol -from pyrit.registry.class_registries import ( - BaseClassRegistry, - ClassEntry, -) from pyrit.registry.components import ( AttackTechniqueMetadata, AttackTechniqueRegistry, @@ -30,11 +26,8 @@ from pyrit.registry.instance_registry import ( DefaultInstanceRegistry, InstanceRegistry, - SupportsInstances, -) -from pyrit.registry.object_registries import ( - BaseInstanceRegistry, RegistryEntry, + SupportsInstances, ) from pyrit.registry.registry import Registry from pyrit.registry.tag_query import TagQuery @@ -42,15 +35,12 @@ __all__ = [ "AttackTechniqueRegistry", "AttackTechniqueMetadata", - "BaseClassRegistry", - "BaseInstanceRegistry", "ConverterRegistry", "ConverterMetadata", "DefaultInstanceRegistry", "InstanceRegistry", "Registry", "SupportsInstances", - "ClassEntry", "discover_in_directory", "discover_in_package", "discover_subclasses_in_loaded_modules", diff --git a/pyrit/registry/base.py b/pyrit/registry/base.py index f39e30e33c..7cbd5bfd32 100644 --- a/pyrit/registry/base.py +++ b/pyrit/registry/base.py @@ -98,9 +98,8 @@ class RegistryProtocol(Protocol[MetadataT]): """ Protocol defining the common interface for all registries. - Both class registries (BaseClassRegistry) and object registries - (BaseInstanceRegistry) implement this interface, enabling code that - works with either registry type. + Registries implement this interface, enabling code that works with any + registry type. Type Parameters: MetadataT: The metadata dataclass type (e.g., ScenarioMetadata). diff --git a/pyrit/registry/class_registries/__init__.py b/pyrit/registry/class_registries/__init__.py deleted file mode 100644 index 1156520842..0000000000 --- a/pyrit/registry/class_registries/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Class registries package. - -This package contains the transitional ``BaseClassRegistry`` base that predates -the unified ``Registry``. It survives only until the remaining domains migrate; -new registries should extend ``pyrit.registry.registry.Registry`` instead. - -For registries that store pre-configured instances, see object_registries/. -""" - -from pyrit.registry.class_registries.base_class_registry import ( - BaseClassRegistry, - ClassEntry, -) - -__all__ = [ - "BaseClassRegistry", - "ClassEntry", -] diff --git a/pyrit/registry/class_registries/base_class_registry.py b/pyrit/registry/class_registries/base_class_registry.py deleted file mode 100644 index b1668f6177..0000000000 --- a/pyrit/registry/class_registries/base_class_registry.py +++ /dev/null @@ -1,414 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Base class registry for PyRIT. - -This module provides the abstract base class for registries that store classes (type[T]). -These registries allow on-demand instantiation of registered classes. - -For registries that store pre-configured instances, see object_registries/. - -Terminology: -- **Metadata**: A TypedDict describing a registered class (e.g., ScenarioMetadata) -- **Class**: The actual Python class (type[T]) that can be instantiated -- **Instance**: A created object of that class -- **ClassEntry**: Internal wrapper holding a class plus optional factory/defaults -""" - -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Generic, TypeVar - -if TYPE_CHECKING: - from collections.abc import Callable, Iterator - - from typing_extensions import Self - -from pyrit.models import class_name_to_snake_case -from pyrit.registry.base import ClassRegistryEntry, RegistryProtocol - -# Type variable for the registered class type -T = TypeVar("T") -# Type variable for the metadata TypedDict -MetadataT = TypeVar("MetadataT") - - -class ClassEntry(Generic[T]): - """ - Internal wrapper for a registered class. - - This holds the class itself (type[T]) along with optional factory - and default parameters for creating instances. - - Note: This is an internal implementation detail. Users interact with - registries via get_class(), create_instance(), and list_metadata(). - - Attributes: - registered_class: The actual Python class (type[T]). - factory: Optional callable to create instances with custom logic. - default_kwargs: Default keyword arguments for instance creation. - """ - - def __init__( - self, - *, - registered_class: type[T], - factory: Callable[..., T] | None = None, - default_kwargs: dict[str, object] | None = None, - ) -> None: - """ - Initialize a class entry. - - Args: - registered_class: The actual Python class (type[T]). - factory: Optional callable that creates an instance. - default_kwargs: Default keyword arguments for instantiation. - """ - self.registered_class = registered_class - self.factory = factory - self.default_kwargs = default_kwargs or {} - - def get_description(self, *, fallback: str = "") -> str: - """ - Resolve description from docstring, falling back to provided default. - - Returns: - str: The resolved description string. - """ - return ClassRegistryEntry.description_from_docstring(self.registered_class, fallback=fallback) - - def create_instance(self, **kwargs: object) -> T: - """ - Create an instance of the registered class. - - Args: - **kwargs: Additional keyword arguments. These override default_kwargs. - - Returns: - An instance of type T. - """ - merged_kwargs = {**self.default_kwargs, **kwargs} - - if self.factory is not None: - return self.factory(**merged_kwargs) - return self.registered_class(**merged_kwargs) - - -class BaseClassRegistry(ABC, RegistryProtocol[MetadataT], Generic[T, MetadataT]): - """ - Abstract base class for registries that store classes (type[T]). - - This class implements RegistryProtocol and provides the common infrastructure - for class registries including: - - Lazy discovery of classes - - Registration of classes or factory callables - - Metadata caching - - Consistent API: get_class(), get_names(), list_metadata(), create_instance() - - Singleton pattern support via get_registry_singleton() - - Subclasses must implement: - - _discover(): Populate the registry with discovered classes - - _build_metadata(): Build a metadata TypedDict for a class - - Type Parameters: - T: The type of classes being registered (e.g., Scenario, PromptTarget). - MetadataT: The TypedDict type for metadata (e.g., ScenarioMetadata). - """ - - # Class-level singleton instances, keyed by registry class - _instances: dict[type, BaseClassRegistry[object, object]] = {} - - def __init__(self, *, lazy_discovery: bool = True) -> None: - """ - Initialize the registry. - - Args: - lazy_discovery: If True, discovery is deferred until first access. - If False, discovery runs immediately in constructor. - """ - # Maps registry names to ClassEntry wrappers - self._class_entries: dict[str, ClassEntry[T]] = {} - self._metadata_cache: list[MetadataT] | None = None - self._discovered = False - self._lazy_discovery = lazy_discovery - - if not lazy_discovery: - self._discover() - self._discovered = True - - @classmethod - def get_registry_singleton(cls) -> Self: - """ - Get the singleton instance of this registry. - - Creates the instance on first call with default parameters. - - Returns: - The singleton instance of this registry class. - """ - if cls not in cls._instances: - cls._instances[cls] = cls() # type: ignore[ty:invalid-assignment] - return cls._instances[cls] # type: ignore[ty:invalid-return-type] - - @classmethod - def reset_instance(cls) -> None: - """ - Reset the singleton instance. - - Useful for testing or when re-discovery is needed. - """ - if cls in cls._instances: - del cls._instances[cls] - - def _ensure_discovered(self) -> None: - """Ensure discovery has been performed. Runs discovery on first access.""" - if not self._discovered: - self._discover() - self._discovered = True - - @abstractmethod - def _discover(self) -> None: - """ - Perform discovery of registry classes. - - Subclasses implement this to populate self._class_entries with discovered classes. - """ - - @abstractmethod - def _build_metadata(self, name: str, entry: ClassEntry[T]) -> MetadataT: - """ - Build metadata dictionary for a registered class. - - Subclasses must implement this to provide registry-specific metadata. - - Args: - name: The registry name (snake_case identifier). - entry: The ClassEntry containing the registered class. - - Returns: - A metadata dataclass with descriptive information about the registered class. - """ - - def _require_entry(self, name: str) -> ClassEntry[T]: - """ - Resolve a registered ``ClassEntry`` by name or raise. - - Shared lookup used by ``get_class`` and ``create_instance`` so the - "not found" behavior (and its error message listing the class catalog) - lives in one place. - - Args: - name: The registry name (snake_case identifier). - - Returns: - The registered ``ClassEntry``. - - Raises: - KeyError: If the name is not registered. - """ - self._ensure_discovered() - entry = self._class_entries.get(name) - if entry is None: - available = ", ".join(self.get_names()) - raise KeyError(f"'{name}' not found in registry. Available: {available}") - return entry - - def get_class(self, name: str) -> type[T]: - """ - Get a registered class by name. - - Args: - name: The registry name (snake_case identifier). - - Returns: - The registered class (type[T]). - Note: This returns the class itself, not an instance. - - Raises: - KeyError: If the name is not registered. - """ - return self._require_entry(name).registered_class - - def get_entry(self, name: str) -> ClassEntry[T] | None: - """ - Get the full ClassEntry for a registered class. - - This is useful when you need access to factory or default_kwargs. - - Args: - name: The registry name. - - Returns: - The ClassEntry containing class, factory, and defaults, or None if not found. - """ - self._ensure_discovered() - return self._class_entries.get(name) - - def get_names(self) -> list[str]: - """ - Get a sorted list of all registered names. - - These are the snake_case registry keys (e.g., "encoding", "self_ask_refusal"), - not the actual class names (e.g., "EncodingScenario", "SelfAskRefusalScorer"). - - Returns: - Sorted list of registry names. - """ - self._ensure_discovered() - return sorted(self._class_entries.keys()) - - def list_metadata( - self, - *, - include_filters: dict[str, object] | None = None, - exclude_filters: dict[str, object] | None = None, - ) -> list[MetadataT]: - """ - List metadata for all registered classes, optionally filtered. - - Supports filtering on any metadata property: - - Simple types (str, int, bool): exact match - - List types: checks if filter value is in the list - - Args: - include_filters: Optional dict of filters that items must match. - Keys are metadata property names, values are the filter criteria. - All filters must match (AND logic). - exclude_filters: Optional dict of filters that items must NOT match. - Keys are metadata property names, values are the filter criteria. - Any matching filter excludes the item. - - Returns: - List of metadata dictionaries (TypedDict) describing each registered class. - Note: This returns descriptive info, not the classes themselves. - """ - from pyrit.registry.base import _matches_filters - - self._ensure_discovered() - - if self._metadata_cache is None: - self._metadata_cache = [ - self._build_metadata(name, entry) for name, entry in sorted(self._class_entries.items()) - ] - - if not include_filters and not exclude_filters: - return self._metadata_cache - - return [ - m - for m in self._metadata_cache - if _matches_filters(m, include_filters=include_filters, exclude_filters=exclude_filters) - ] - - def register( - self, - cls: type[T], - *, - name: str | None = None, - factory: Callable[..., T] | None = None, - default_kwargs: dict[str, object] | None = None, - ) -> None: - """ - Register a class with the registry. - - Args: - cls: The class to register (type[T], not an instance). - name: Optional custom registry name. If not provided, derived from class name. - factory: Optional callable for creating instances with custom logic. - default_kwargs: Default keyword arguments for instance creation. - """ - if name is None: - name = self._get_registry_name(cls) - - entry = ClassEntry( - registered_class=cls, - factory=factory, - default_kwargs=default_kwargs, - ) - self._class_entries[name] = entry - self._metadata_cache = None - - def unregister(self, name: str) -> None: - """ - Remove a registered class from the registry. - - Args: - name: The registry name of the class to remove. - - Raises: - KeyError: If the name is not registered. - """ - self._ensure_discovered() - if name not in self._class_entries: - available = ", ".join(self.get_names()) - raise KeyError(f"'{name}' not found in registry. Available: {available}") - del self._class_entries[name] - self._metadata_cache = None - - def create_instance(self, name: str, **kwargs: object) -> T: - """ - Create an instance of a registered class. - - Args: - name: The registry name of the class. - **kwargs: Keyword arguments to pass to the factory or constructor. - - Returns: - A new instance of type T. - - Raises: - KeyError: If the name is not registered. - """ - self._ensure_discovered() - entry = self._class_entries.get(name) - if entry is None: - available = ", ".join(self.get_names()) - raise KeyError(f"'{name}' not found in registry. Available: {available}") - return entry.create_instance(**kwargs) - - def _get_registry_name(self, cls: type[T]) -> str: - """ - Get the registry name for a class. - - Subclasses can override this to customize name derivation. - Default implementation converts CamelCase to snake_case. - - Args: - cls: The class to get a name for. - - Returns: - The registry name (snake_case identifier). - """ - return class_name_to_snake_case(cls.__name__) - - def __contains__(self, name: str) -> bool: - """ - Check if a name is registered. - - Returns: - True if the name is registered, False otherwise. - """ - self._ensure_discovered() - return name in self._class_entries - - def __len__(self) -> int: - """ - Get the count of registered classes. - - Returns: - The number of registered classes. - """ - self._ensure_discovered() - return len(self._class_entries) - - def __iter__(self) -> Iterator[str]: - """ - Iterate over registered names. - - Returns: - An iterator over sorted registered names. - """ - self._ensure_discovered() - return iter(sorted(self._class_entries.keys())) diff --git a/pyrit/registry/object_registries/__init__.py b/pyrit/registry/object_registries/__init__.py deleted file mode 100644 index 16de1a1fd2..0000000000 --- a/pyrit/registry/object_registries/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Object registries package. - -This package contains the legacy instance-only registry stack still used by -``AttackTechniqueRegistry``. Component registries that hold pre-configured -instances (converters, scorers, targets) now live in ``registry/components/`` as -``Registry`` subclasses that expose their instances via the ``.instances`` -property. -""" - -from pyrit.registry.object_registries.base_instance_registry import ( - BaseInstanceRegistry, - RegistryEntry, -) - -__all__ = [ - # Base classes - "BaseInstanceRegistry", - "RegistryEntry", -] diff --git a/pyrit/registry/object_registries/base_instance_registry.py b/pyrit/registry/object_registries/base_instance_registry.py deleted file mode 100644 index 39b53ba1c5..0000000000 --- a/pyrit/registry/object_registries/base_instance_registry.py +++ /dev/null @@ -1,342 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Base instance registry for PyRIT. - -.. note:: - - **Legacy stack — do not build new registries on this.** New component - registries should subclass ``Registry`` (a class catalog that can - build instances by name) and hold pre-configured instances via the - ``.instances`` property (a ``DefaultInstanceRegistry``). See - ``ConverterRegistry`` for the target shape. No production registry - subclasses this anymore; it is retained only for backward compatibility - and is removed once external dependents migrate. - -This module provides ``BaseInstanceRegistry``, the shared infrastructure for -registries that store ``Identifiable`` objects (not classes): singleton -lifecycle, registration, tags, metadata, container protocol. - -Subclass directly for registries that store factories or other -non-retrievable items. For registries where callers retrieve stored objects -directly, use ``Registry`` + the ``.instances`` property -(``DefaultInstanceRegistry``) instead. - -For registries that store classes (type[T]), see ``class_registries/``. -""" - -from __future__ import annotations - -from abc import ABC -from typing import TYPE_CHECKING, Any, Generic, TypeVar - -from pyrit.models import ComponentIdentifier, Identifiable -from pyrit.registry.base import RegistryProtocol -from pyrit.registry.instance_registry import RegistryEntry - -if TYPE_CHECKING: - from collections.abc import Iterator - - from typing_extensions import Self - -# Re-exported for back-compat; the canonical definition now lives in -# ``pyrit.registry.instance_registry`` alongside the new instance-registry capability. -__all__ = ["BaseInstanceRegistry", "RegistryEntry"] - -T = TypeVar("T", bound=Identifiable) # The type of items stored - - -class BaseInstanceRegistry(ABC, RegistryProtocol[ComponentIdentifier], Generic[T]): - """ - Abstract base class providing shared registry infrastructure. - - .. note:: - - **Legacy — do not subclass for new registries.** New component - registries subclass ``Registry`` and expose retained instances - via the ``.instances`` property (``DefaultInstanceRegistry``), which - carries this same surface (``register``/``get``/``get_by_tag``/ - ``add_tags``/``find_dependents_of_tag``/``list_metadata``). This class - is no longer subclassed by any production registry and is retained - only for backward compatibility. - - Provides singleton lifecycle, registration, tag-based lookup, metadata - filtering, and the standard container protocol (``__contains__``, - ``__len__``, ``__iter__``). - - Subclass directly when stored items should not be retrievable via - ``get()`` (e.g., factory registries). For registries that expose - direct item retrieval, use ``Registry`` + the ``.instances`` property - (``DefaultInstanceRegistry``) instead. - - All stored items must implement ``Identifiable``, which provides - ``get_identifier()`` for metadata generation. - - Type Parameters: - T: The type of items stored in the registry (must be Identifiable). - """ - - # Class-level singleton instances, keyed by registry class - _instances: dict[type, BaseInstanceRegistry[Any]] = {} - - @classmethod - def get_registry_singleton(cls) -> Self: - """ - Get the singleton instance of this registry. - - Creates the instance on first call with default parameters. - - Returns: - The singleton instance of this registry class. - """ - if cls not in cls._instances: - cls._instances[cls] = cls() - return cls._instances[cls] # type: ignore[ty:invalid-return-type] - - @classmethod - def reset_instance(cls) -> None: - """ - Reset the singleton instance. - - Useful for testing or reinitializing the registry. - """ - if cls in cls._instances: - del cls._instances[cls] - - @staticmethod - def _normalize_tags( - tags: dict[str, str] | list[str] | None = None, - ) -> dict[str, str]: - """ - Normalize tags into a ``dict[str, str]``. - - Args: - tags: Tags as a dict, a list of string keys (values default to ``""``), - or ``None`` (returns empty dict). - - Returns: - A ``dict[str, str]`` of normalised tags. - """ - if tags is None: - return {} - if isinstance(tags, list): - return dict.fromkeys(tags, "") - return dict(tags) - - def __init__(self) -> None: - """Initialize the registry.""" - # Maps registry names to registry entries - self._registry_items: dict[str, RegistryEntry[T]] = {} - self._metadata_cache: list[ComponentIdentifier] | None = None - - def register( - self, - instance: T, - *, - name: str, - tags: dict[str, str] | list[str] | None = None, - metadata: dict[str, Any] | None = None, - ) -> None: - """ - Register an item. - - Args: - instance: The item to register. - name: The registry name for this item. - tags: Optional tags for categorisation. Accepts a ``dict[str, str]`` - or a ``list[str]`` (each string becomes a key with value ``""``). - metadata: Optional metadata dict for capability flags or other - per-entry data that should not appear in tags. - """ - normalized = self._normalize_tags(tags) - self._registry_items[name] = RegistryEntry( - name=name, - instance=instance, - tags=normalized, - metadata=metadata or {}, - ) - self._metadata_cache = None - - def get_names(self) -> list[str]: - """ - Get a sorted list of all registered names. - - Returns: - Sorted list of registry names (keys). - """ - return sorted(self._registry_items.keys()) - - def get_by_tag( - self, - *, - tag: str, - value: str | None = None, - ) -> list[RegistryEntry[T]]: - """ - Get all entries that have a given tag, optionally matching a specific value. - - Args: - tag: The tag key to match. - value: If provided, only entries whose tag value equals this are returned. - If ``None``, any entry that has the tag key is returned regardless of value. - - Returns: - List of matching RegistryEntry objects sorted by name. - """ - results: list[RegistryEntry[T]] = [] - for name in sorted(self._registry_items.keys()): - entry = self._registry_items[name] - if tag in entry.tags and (value is None or entry.tags[tag] == value): - results.append(entry) - return results - - def add_tags( - self, - *, - name: str, - tags: dict[str, str] | list[str], - ) -> None: - """ - Add tags to an existing registry entry. - - Args: - name: The registry name of the entry to tag. - tags: Tags to add. Accepts a ``dict[str, str]`` - or a ``list[str]`` (each string becomes a key with value ``""``). - - Raises: - KeyError: If no entry with the given name exists. - """ - entry = self._registry_items.get(name) - if entry is None: - raise KeyError(f"No entry named '{name}' in registry.") - entry.tags.update(self._normalize_tags(tags)) - self._metadata_cache = None - - def find_dependents_of_tag(self, *, tag: str) -> list[RegistryEntry[T]]: - """ - Find entries whose children depend on entries with the given tag. - - Scans each registry entry's ``ComponentIdentifier`` tree and checks - whether any child's ``eval_hash`` matches the ``eval_hash`` of an - entry that carries *tag*. Entries that themselves carry *tag* are - excluded from the results. - - This enables automatic dependency detection: for example, tagging - base refusal scorers with ``"refusal"`` lets you discover all - wrapper scorers (inverters, composites) that embed a refusal scorer - without any explicit ``depends_on`` declaration. - - Args: - tag: The tag key that identifies the "base" entries. - - Returns: - List of ``RegistryEntry`` objects that depend on tagged entries, - sorted by name. - """ - # Collect eval_hashes of all tagged entries - tagged_hashes: set[str] = set() - tagged_names: set[str] = set() - for entry in self.get_by_tag(tag=tag): - tagged_names.add(entry.name) - identifier = self._build_metadata(entry.name, entry.instance) - if identifier.eval_hash: - tagged_hashes.add(identifier.eval_hash) - - if not tagged_hashes: - return [] - - # Find non-tagged entries whose children reference a tagged eval_hash - dependents: list[RegistryEntry[T]] = [] - for name in sorted(self._registry_items.keys()): - if name in tagged_names: - continue - entry = self._registry_items[name] - identifier = self._build_metadata(name, entry.instance) - child_hashes = identifier._collect_child_eval_hashes() - if child_hashes & tagged_hashes: - dependents.append(entry) - return dependents - - def list_metadata( - self, - *, - include_filters: dict[str, object] | None = None, - exclude_filters: dict[str, object] | None = None, - ) -> list[ComponentIdentifier]: - """ - List metadata for all registered items, optionally filtered. - - Supports filtering on any metadata property: - - Simple types (str, int, bool): exact match - - List types: checks if filter value is in the list - - Args: - include_filters: Optional dict of filters that items must match. - Keys are metadata property names, values are the filter criteria. - All filters must match (AND logic). - exclude_filters: Optional dict of filters that items must NOT match. - Keys are metadata property names, values are the filter criteria. - Any matching filter excludes the item. - - Returns: - List of ComponentIdentifier metadata for each registered item. - """ - from pyrit.registry.base import _matches_filters - - if self._metadata_cache is None: - items = [] - for name in sorted(self._registry_items.keys()): - entry = self._registry_items[name] - items.append(self._build_metadata(name, entry.instance)) - self._metadata_cache = items - - if not include_filters and not exclude_filters: - return self._metadata_cache - - return [ - m - for m in self._metadata_cache - if _matches_filters(m, include_filters=include_filters, exclude_filters=exclude_filters) - ] - - def _build_metadata(self, name: str, instance: T) -> ComponentIdentifier: - """ - Build metadata for an item via its ``Identifiable`` interface. - - Args: - name: The registry name of the item. - instance: The item. - - Returns: - The item's ComponentIdentifier. - """ - return instance.get_identifier() - - def __contains__(self, name: str) -> bool: - """ - Check if a name is registered. - - Returns: - True if the name is registered, False otherwise. - """ - return name in self._registry_items - - def __len__(self) -> int: - """ - Get the count of registered items. - - Returns: - The number of registered items. - """ - return len(self._registry_items) - - def __iter__(self) -> Iterator[str]: - """ - Iterate over registered names. - - Returns: - An iterator over sorted registered names. - """ - return iter(sorted(self._registry_items.keys())) diff --git a/tests/unit/registry/test_base.py b/tests/unit/registry/test_base.py index 380718b554..30bb73b806 100644 --- a/tests/unit/registry/test_base.py +++ b/tests/unit/registry/test_base.py @@ -3,10 +3,7 @@ from dataclasses import dataclass, field -import pytest - from pyrit.registry.base import ClassRegistryEntry, _matches_filters -from pyrit.registry.class_registries.base_class_registry import BaseClassRegistry, ClassEntry @dataclass(frozen=True) @@ -16,21 +13,6 @@ class MetadataWithTags(ClassRegistryEntry): tags: tuple[str, ...] = field(kw_only=True) -class _TestRegistry(BaseClassRegistry[object, ClassRegistryEntry]): - """Minimal concrete registry for testing BaseClassRegistry methods.""" - - def _discover(self) -> None: - pass - - def _build_metadata(self, name: str, entry: ClassEntry[object]) -> ClassRegistryEntry: - return ClassRegistryEntry( - class_name=entry.registered_class.__name__, - class_module=entry.registered_class.__module__, - class_description=entry.get_description(fallback=""), - registry_name=name, - ) - - class TestDescriptionFromDocstring: """Tests for ClassRegistryEntry.description_from_docstring.""" @@ -63,24 +45,6 @@ class NoDoc: assert result == "" -class TestClassEntryGetDescription: - """Tests for ClassEntry.get_description.""" - - def test_returns_docstring_description(self): - class Documented: - """A documented class.""" - - entry = ClassEntry(registered_class=Documented) - assert entry.get_description() == "A documented class." - - def test_returns_fallback_when_no_docstring(self): - class Undocumented: - pass - - entry = ClassEntry(registered_class=Undocumented) - assert entry.get_description(fallback="No description available") == "No description available" - - class TestMatchesFilters: """Tests for the _matches_filters function.""" @@ -226,97 +190,3 @@ def test_matches_filters_combined_include_and_exclude(self): ) is False ) - - -# ============================================================================ -# BaseClassRegistry.unregister Tests -# ============================================================================ - - -class _DummyClass: - """A dummy class for registry testing.""" - - -class _AnotherClass: - """Another dummy class.""" - - -def test_unregister_removes_entry(): - """Test that unregister removes a registered entry.""" - registry = _TestRegistry(lazy_discovery=True) - registry.register(_DummyClass, name="dummy") - assert "dummy" in registry - - registry.unregister("dummy") - assert "dummy" not in registry - assert len(registry) == 0 - - -def test_unregister_raises_key_error_for_missing(): - """Test that unregister raises KeyError when name is not registered.""" - registry = _TestRegistry(lazy_discovery=True) - - with pytest.raises(KeyError, match="not_here"): - registry.unregister("not_here") - - -def test_unregister_key_error_lists_available_names(): - """Test that the KeyError message includes available names.""" - registry = _TestRegistry(lazy_discovery=True) - registry.register(_DummyClass, name="alpha") - registry.register(_AnotherClass, name="beta") - - with pytest.raises(KeyError, match="alpha"): - registry.unregister("missing") - - -def test_unregister_invalidates_metadata_cache(): - """Test that unregister clears the metadata cache.""" - registry = _TestRegistry(lazy_discovery=True) - registry.register(_DummyClass, name="cached") - - registry.list_metadata() - assert registry._metadata_cache is not None - - registry.unregister("cached") - assert registry._metadata_cache is None - - -def test_unregister_does_not_affect_other_entries(): - """Test that unregistering one entry leaves others intact.""" - registry = _TestRegistry(lazy_discovery=True) - registry.register(_DummyClass, name="keep") - registry.register(_AnotherClass, name="remove") - - registry.unregister("remove") - - assert "keep" in registry - assert "remove" not in registry - assert registry.get_class("keep") is _DummyClass - - -def test_unregister_then_re_register(): - """Test that an entry can be re-registered after being unregistered.""" - registry = _TestRegistry(lazy_discovery=True) - registry.register(_DummyClass, name="reuse") - - registry.unregister("reuse") - assert "reuse" not in registry - - registry.register(_AnotherClass, name="reuse") - assert registry.get_class("reuse") is _AnotherClass - - -def test_unregister_makes_metadata_reflect_removal(): - """Test that list_metadata no longer includes the unregistered entry.""" - registry = _TestRegistry(lazy_discovery=True) - registry.register(_DummyClass, name="alpha") - registry.register(_AnotherClass, name="beta") - - assert len(registry.list_metadata()) == 2 - - registry.unregister("alpha") - metadata = registry.list_metadata() - - assert len(metadata) == 1 - assert metadata[0].registry_name == "beta" diff --git a/tests/unit/registry/test_base_instance_registry.py b/tests/unit/registry/test_base_instance_registry.py deleted file mode 100644 index 533ee95e1a..0000000000 --- a/tests/unit/registry/test_base_instance_registry.py +++ /dev/null @@ -1,762 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -import pytest - -from pyrit.models import ComponentIdentifier, Identifiable -from pyrit.registry.object_registries.base_instance_registry import ( - BaseInstanceRegistry, - RegistryEntry, -) - - -class _TestItem(Identifiable): - """Minimal Identifiable stub wrapping a string value for testing.""" - - def __init__(self, value: str) -> None: - self.value = value - - def _build_identifier(self) -> ComponentIdentifier: - return ComponentIdentifier( - class_name="_TestItem", - class_module="test", - params={"category": "test" if "test" in self.value.lower() else "other"}, - ) - - def __eq__(self, other: object) -> bool: - if isinstance(other, _TestItem): - return self.value == other.value - if isinstance(other, str): - return self.value == other - return NotImplemented - - def __hash__(self) -> int: - return hash(self.value) - - def __repr__(self) -> str: - return f"_TestItem({self.value!r})" - - -def _item(value: str) -> _TestItem: - """Shorthand factory for _TestItem.""" - return _TestItem(value) - - -class ConcreteTestRegistry(BaseInstanceRegistry["_TestItem"]): - """Concrete instance-holding registry (legacy base) used as a test double. - - Defines the direct-retrieval helpers (``get``/``get_entry``/ - ``get_all_instances``) so the shared ``BaseInstanceRegistry`` infrastructure - can be exercised through a retrievable surface. The canonical retrievable - implementation now lives on ``DefaultInstanceRegistry`` (the ``.instances`` - property); this double mirrors it only for these legacy-base tests. - """ - - def get(self, name: str) -> "_TestItem | None": - entry = self._registry_items.get(name) - return None if entry is None else entry.instance - - def get_entry(self, name: str) -> "RegistryEntry[_TestItem] | None": - return self._registry_items.get(name) - - def get_all_instances(self) -> "list[RegistryEntry[_TestItem]]": - return [self._registry_items[name] for name in sorted(self._registry_items.keys())] - - -class TestConcreteInstanceRegistrySingleton: - """Tests for the singleton pattern in the concrete instance registry.""" - - def setup_method(self): - """Reset the singleton before each test.""" - ConcreteTestRegistry.reset_instance() - - def teardown_method(self): - """Reset the singleton after each test.""" - ConcreteTestRegistry.reset_instance() - - def test_get_registry_singleton_returns_same_instance(self): - """Test that get_registry_singleton returns the same singleton each time.""" - instance1 = ConcreteTestRegistry.get_registry_singleton() - instance2 = ConcreteTestRegistry.get_registry_singleton() - - assert instance1 is instance2 - - def test_reset_instance_clears_singleton(self): - """Test that reset_instance clears the singleton.""" - instance1 = ConcreteTestRegistry.get_registry_singleton() - ConcreteTestRegistry.reset_instance() - instance2 = ConcreteTestRegistry.get_registry_singleton() - - assert instance1 is not instance2 - - def test_reset_instance_when_not_exists_does_not_raise(self): - """Test that reset_instance works even when no instance exists.""" - # Should not raise any exception - ConcreteTestRegistry.reset_instance() - ConcreteTestRegistry.reset_instance() - - -class TestConcreteInstanceRegistryRegistration: - """Tests for registration functionality in the concrete instance registry.""" - - def setup_method(self): - """Reset and get a fresh registry for each test.""" - ConcreteTestRegistry.reset_instance() - self.registry = ConcreteTestRegistry.get_registry_singleton() - - def teardown_method(self): - """Reset the singleton after each test.""" - ConcreteTestRegistry.reset_instance() - - def test_register_adds_instance(self): - """Test that register adds an instance to the registry.""" - self.registry.register(_item("test_value"), name="test_name") - - assert "test_name" in self.registry - assert self.registry.get("test_name") == "test_value" - - def test_register_multiple_instances(self): - """Test registering multiple instances.""" - self.registry.register(_item("value1"), name="name1") - self.registry.register(_item("value2"), name="name2") - self.registry.register(_item("value3"), name="name3") - - assert len(self.registry) == 3 - assert self.registry.get("name1") == "value1" - assert self.registry.get("name2") == "value2" - assert self.registry.get("name3") == "value3" - - def test_register_overwrites_existing(self): - """Test that registering with the same name overwrites the existing instance.""" - self.registry.register(_item("original"), name="name") - self.registry.register(_item("updated"), name="name") - - assert len(self.registry) == 1 - assert self.registry.get("name") == "updated" - - def test_register_invalidates_metadata_cache(self): - """Test that registering a new instance invalidates the metadata cache.""" - self.registry.register(_item("value1"), name="name1") - # Build cache by calling list_metadata - metadata1 = self.registry.list_metadata() - assert len(metadata1) == 1 - - # Register new instance - should invalidate cache - self.registry.register(_item("value2"), name="name2") - metadata2 = self.registry.list_metadata() - - assert len(metadata2) == 2 - - -class TestConcreteInstanceRegistryGet: - """Tests for get functionality in the concrete instance registry.""" - - def setup_method(self): - """Reset and get a fresh registry for each test.""" - ConcreteTestRegistry.reset_instance() - self.registry = ConcreteTestRegistry.get_registry_singleton() - self.registry.register(_item("test_value"), name="test_name") - - def teardown_method(self): - """Reset the singleton after each test.""" - ConcreteTestRegistry.reset_instance() - - def test_get_existing_instance(self): - """Test getting an existing instance by name.""" - result = self.registry.get("test_name") - assert result == "test_value" - - def test_get_nonexistent_returns_none(self): - """Test that getting a non-existent instance returns None.""" - result = self.registry.get("nonexistent") - assert result is None - - -class TestConcreteInstanceRegistryGetEntry: - """Tests for get_entry functionality in the concrete instance registry.""" - - def setup_method(self): - """Reset and get a fresh registry for each test.""" - ConcreteTestRegistry.reset_instance() - self.registry = ConcreteTestRegistry.get_registry_singleton() - self.registry.register(_item("test_value"), name="test_name", tags={"role": "scorer"}) - - def teardown_method(self): - """Reset the singleton after each test.""" - ConcreteTestRegistry.reset_instance() - - def test_get_entry_returns_registry_entry(self): - """Test that get_entry returns a RegistryEntry with correct fields.""" - entry = self.registry.get_entry("test_name") - assert entry is not None - assert isinstance(entry, RegistryEntry) - assert entry.name == "test_name" - assert entry.instance == "test_value" - assert entry.tags == {"role": "scorer"} - - def test_get_entry_nonexistent_returns_none(self): - """Test that get_entry returns None for a non-existent name.""" - result = self.registry.get_entry("nonexistent") - assert result is None - - -class TestConcreteInstanceRegistryGetNames: - """Tests for get_names functionality in the concrete instance registry.""" - - def setup_method(self): - """Reset and get a fresh registry for each test.""" - ConcreteTestRegistry.reset_instance() - self.registry = ConcreteTestRegistry.get_registry_singleton() - - def teardown_method(self): - """Reset the singleton after each test.""" - ConcreteTestRegistry.reset_instance() - - def test_get_names_empty_registry(self): - """Test get_names on an empty registry.""" - names = self.registry.get_names() - assert names == [] - - def test_get_names_returns_sorted_list(self): - """Test that get_names returns a sorted list of names.""" - self.registry.register(_item("value3"), name="zeta") - self.registry.register(_item("value1"), name="alpha") - self.registry.register(_item("value2"), name="beta") - - names = self.registry.get_names() - assert names == ["alpha", "beta", "zeta"] - - -class TestConcreteInstanceRegistryGetAllInstances: - """Tests for get_all_instances functionality in the concrete instance registry.""" - - def setup_method(self): - """Reset and get a fresh registry for each test.""" - ConcreteTestRegistry.reset_instance() - self.registry = ConcreteTestRegistry.get_registry_singleton() - - def teardown_method(self): - """Reset the singleton after each test.""" - ConcreteTestRegistry.reset_instance() - - def test_get_all_instances_returns_list_of_registry_entries(self): - """Test that get_all_instances returns a list of RegistryEntry objects.""" - self.registry.register(_item("value1"), name="name1") - self.registry.register(_item("value2"), name="name2") - - result = self.registry.get_all_instances() - assert isinstance(result, list) - assert len(result) == 2 - assert all(isinstance(entry, RegistryEntry) for entry in result) - - def test_get_all_instances_sorted_by_name(self): - """Test that get_all_instances returns entries sorted by name.""" - self.registry.register(_item("value_z"), name="zeta") - self.registry.register(_item("value_a"), name="alpha") - self.registry.register(_item("value_b"), name="beta") - - result = self.registry.get_all_instances() - assert [e.name for e in result] == ["alpha", "beta", "zeta"] - - def test_get_all_instances_preserves_tags(self): - """Test that get_all_instances preserves tags on entries.""" - self.registry.register(_item("value1"), name="name1", tags={"role": "scorer"}) - self.registry.register(_item("value2"), name="name2", tags=["fast"]) - - result = self.registry.get_all_instances() - entry_map = {e.name: e for e in result} - assert entry_map["name1"].tags == {"role": "scorer"} - assert entry_map["name2"].tags == {"fast": ""} - - def test_get_all_instances_empty_registry(self): - """Test that get_all_instances returns empty list on empty registry.""" - result = self.registry.get_all_instances() - assert result == [] - - -class TestConcreteInstanceRegistryListMetadata: - """Tests for list_metadata functionality in the concrete instance registry.""" - - def setup_method(self): - """Reset and get a fresh registry for each test.""" - ConcreteTestRegistry.reset_instance() - self.registry = ConcreteTestRegistry.get_registry_singleton() - self.registry.register(_item("test_item_1"), name="item1") - self.registry.register(_item("other_item_2"), name="item2") - self.registry.register(_item("test_item_3"), name="item3") - - def teardown_method(self): - """Reset the singleton after each test.""" - ConcreteTestRegistry.reset_instance() - - def test_list_metadata_returns_all_items(self): - """Test that list_metadata returns metadata for all items.""" - metadata = self.registry.list_metadata() - assert len(metadata) == 3 - - def test_list_metadata_sorted_by_name(self): - """Test that metadata is sorted by registry key order.""" - metadata = self.registry.list_metadata() - # Since unique_name is auto-computed, we verify we get 3 items in order - # The actual unique_name field is auto-computed from class_name::hash - assert len(metadata) == 3 - # All should have "_TestItem" in the unique_name since class_name is "_TestItem" - for m in metadata: - assert "_TestItem" in m.unique_name - - def test_list_metadata_with_filter(self): - """Test filtering metadata by a field.""" - metadata = self.registry.list_metadata(include_filters={"category": "test"}) - assert len(metadata) == 2 - assert all(m.params["category"] == "test" for m in metadata) - - def test_list_metadata_filter_no_match(self): - """Test filtering with no matches returns empty list.""" - metadata = self.registry.list_metadata(include_filters={"category": "nonexistent"}) - assert metadata == [] - - def test_list_metadata_with_exclude_filter(self): - """Test excluding metadata by a field.""" - metadata = self.registry.list_metadata(exclude_filters={"category": "test"}) - assert len(metadata) == 1 - assert all(m.params["category"] == "other" for m in metadata) - - def test_list_metadata_combined_include_and_exclude(self): - """Test combined include and exclude filters.""" - # Add another test item to have more variety - self.registry.register(_item("another_test_item"), name="item4") - - # Get items with category "test" but exclude by class_name "str" - # Since all have class_name="str", excluding by class_name would exclude all - # Instead, test with category filters - metadata = self.registry.list_metadata(include_filters={"category": "test"}) - assert len(metadata) == 3 # item1, item3, item4 (all have "test" in value) - assert all(m.params["category"] == "test" for m in metadata) - - def test_list_metadata_caching(self): - """Test that metadata is cached after first call.""" - # First call builds cache - metadata1 = self.registry.list_metadata() - # Second call uses cache - metadata2 = self.registry.list_metadata() - - # Should be the same list object (cached) - assert metadata1 is metadata2 - assert len(metadata1) == 3 - - -class TestConcreteInstanceRegistryTags: - """Tests for tag registration and retrieval in the concrete instance registry.""" - - def setup_method(self): - """Reset and get a fresh registry for each test.""" - ConcreteTestRegistry.reset_instance() - self.registry = ConcreteTestRegistry.get_registry_singleton() - - def teardown_method(self): - """Reset the singleton after each test.""" - ConcreteTestRegistry.reset_instance() - - def test_register_with_dict_tags(self): - """Test that dict tags are stored correctly.""" - self.registry.register(_item("value"), name="name1", tags={"role": "scorer", "provider": "azure"}) - - entry = self.registry.get_entry("name1") - assert entry is not None - assert entry.tags == {"role": "scorer", "provider": "azure"} - - def test_register_with_list_tags(self): - """Test that list tags are normalized to dict with empty string values.""" - self.registry.register(_item("value"), name="name1", tags=["fast", "default"]) - - entry = self.registry.get_entry("name1") - assert entry is not None - assert entry.tags == {"fast": "", "default": ""} - - def test_register_without_tags(self): - """Test that registering without tags defaults to empty dict.""" - self.registry.register(_item("value"), name="name1") - - entry = self.registry.get_entry("name1") - assert entry is not None - assert entry.tags == {} - - def test_get_by_tag_key_only(self): - """Test get_by_tag matching by key only (any value).""" - self.registry.register(_item("v1"), name="n1", tags={"role": "scorer"}) - self.registry.register(_item("v2"), name="n2", tags={"role": "target"}) - self.registry.register(_item("v3"), name="n3", tags={"provider": "azure"}) - - results = self.registry.get_by_tag(tag="role") - assert len(results) == 2 - assert {e.name for e in results} == {"n1", "n2"} - - def test_get_by_tag_key_and_value(self): - """Test get_by_tag matching by key and specific value.""" - self.registry.register(_item("v1"), name="n1", tags={"role": "scorer"}) - self.registry.register(_item("v2"), name="n2", tags={"role": "target"}) - self.registry.register(_item("v3"), name="n3", tags={"role": "scorer"}) - - results = self.registry.get_by_tag(tag="role", value="scorer") - assert len(results) == 2 - assert {e.name for e in results} == {"n1", "n3"} - - def test_get_by_tag_no_match(self): - """Test get_by_tag returns empty list when no entries match.""" - self.registry.register(_item("v1"), name="n1", tags={"role": "scorer"}) - - results = self.registry.get_by_tag(tag="nonexistent") - assert results == [] - - def test_get_by_tag_value_no_match(self): - """Test get_by_tag returns empty when key exists but value does not match.""" - self.registry.register(_item("v1"), name="n1", tags={"role": "scorer"}) - - results = self.registry.get_by_tag(tag="role", value="nonexistent") - assert results == [] - - def test_get_by_tag_returns_sorted_by_name(self): - """Test that get_by_tag results are sorted by name.""" - self.registry.register(_item("v3"), name="zeta", tags=["shared"]) - self.registry.register(_item("v1"), name="alpha", tags=["shared"]) - self.registry.register(_item("v2"), name="beta", tags=["shared"]) - - results = self.registry.get_by_tag(tag="shared") - assert [e.name for e in results] == ["alpha", "beta", "zeta"] - - def test_get_by_tag_with_list_tags(self): - """Test get_by_tag works with list-style tags (normalized to empty string values).""" - self.registry.register(_item("v1"), name="n1", tags=["fast", "default"]) - self.registry.register(_item("v2"), name="n2", tags=["slow"]) - - results = self.registry.get_by_tag(tag="fast") - assert len(results) == 1 - assert results[0].name == "n1" - - def test_get_by_tag_with_list_tags_value_empty_string(self): - """Test get_by_tag with explicit empty string value matches list-style tags.""" - self.registry.register(_item("v1"), name="n1", tags=["fast"]) - - results = self.registry.get_by_tag(tag="fast", value="") - assert len(results) == 1 - assert results[0].name == "n1" - - def test_normalize_tags_none(self): - """Test _normalize_tags returns empty dict for None.""" - assert BaseInstanceRegistry._normalize_tags(None) == {} - - def test_normalize_tags_list(self): - """Test _normalize_tags converts list to dict with empty values.""" - assert BaseInstanceRegistry._normalize_tags(["a", "b"]) == {"a": "", "b": ""} - - def test_normalize_tags_dict(self): - """Test _normalize_tags returns a copy of the dict.""" - original = {"key": "val"} - result = BaseInstanceRegistry._normalize_tags(original) - assert result == {"key": "val"} - assert result is not original - - -class TestConcreteInstanceRegistryDunderMethods: - """Tests for dunder methods (__contains__, __len__, __iter__) in the concrete instance registry.""" - - def setup_method(self): - """Reset and get a fresh registry for each test.""" - ConcreteTestRegistry.reset_instance() - self.registry = ConcreteTestRegistry.get_registry_singleton() - self.registry.register(_item("value1"), name="name1") - self.registry.register(_item("value2"), name="name2") - - def teardown_method(self): - """Reset the singleton after each test.""" - ConcreteTestRegistry.reset_instance() - - def test_contains_existing_name(self): - """Test __contains__ returns True for existing name.""" - assert "name1" in self.registry - assert "name2" in self.registry - - def test_contains_nonexistent_name(self): - """Test __contains__ returns False for non-existent name.""" - assert "nonexistent" not in self.registry - - def test_len_returns_count(self): - """Test __len__ returns the correct count.""" - assert len(self.registry) == 2 - - def test_len_empty_registry(self): - """Test __len__ returns 0 for empty registry.""" - ConcreteTestRegistry.reset_instance() - empty_registry = ConcreteTestRegistry.get_registry_singleton() - assert len(empty_registry) == 0 - - def test_iter_returns_sorted_names(self): - """Test __iter__ returns names in sorted order.""" - names = list(self.registry) - assert names == ["name1", "name2"] - - def test_iter_allows_for_loop(self): - """Test that the registry can be used in a for loop.""" - collected = list(self.registry) - assert collected == ["name1", "name2"] - - -class _ItemOnlyRegistry(BaseInstanceRegistry["_TestItem"]): - """Concrete BaseInstanceRegistry subclass — should NOT have get/get_entry/get_all_instances.""" - - -class TestBaseInstanceRegistryDoesNotExposeInstanceMethods: - """Verify that BaseInstanceRegistry subclasses lack instance-retrieval methods.""" - - def test_item_registry_has_no_get(self): - """BaseInstanceRegistry subclasses should not have a get() method.""" - assert not hasattr(_ItemOnlyRegistry, "get") - - def test_item_registry_has_no_get_entry(self): - """BaseInstanceRegistry subclasses should not have a get_entry() method.""" - assert not hasattr(_ItemOnlyRegistry, "get_entry") - - def test_item_registry_has_no_get_all_instances(self): - """BaseInstanceRegistry subclasses should not have a get_all_instances() method.""" - assert not hasattr(_ItemOnlyRegistry, "get_all_instances") - - def test_item_registry_shares_common_methods(self): - """BaseInstanceRegistry subclasses should have shared registry methods.""" - for method in ( - "register", - "get_names", - "get_by_tag", - "add_tags", - "list_metadata", - "find_dependents_of_tag", - "get_registry_singleton", - "reset_instance", - ): - assert hasattr(_ItemOnlyRegistry, method), f"Missing method: {method}" - - -class TestConcreteInstanceRegistryAddTags: - """Tests for add_tags functionality in the concrete instance registry.""" - - def setup_method(self): - """Reset and get a fresh registry for each test.""" - ConcreteTestRegistry.reset_instance() - self.registry = ConcreteTestRegistry.get_registry_singleton() - - def teardown_method(self): - """Reset the singleton after each test.""" - ConcreteTestRegistry.reset_instance() - - def test_add_tags_with_list(self): - """Test adding list-style tags to an existing entry.""" - self.registry.register(_item("value"), name="entry1") - self.registry.add_tags(name="entry1", tags=["fast", "default"]) - - entry = self.registry.get_entry("entry1") - assert entry is not None - assert entry.tags == {"fast": "", "default": ""} - - def test_add_tags_with_dict(self): - """Test adding dict-style tags to an existing entry.""" - self.registry.register(_item("value"), name="entry1") - self.registry.add_tags(name="entry1", tags={"role": "scorer"}) - - entry = self.registry.get_entry("entry1") - assert entry is not None - assert entry.tags == {"role": "scorer"} - - def test_add_tags_merges_with_existing(self): - """Test that add_tags merges new tags with existing ones.""" - self.registry.register(_item("value"), name="entry1", tags={"existing": "yes"}) - self.registry.add_tags(name="entry1", tags=["new_tag"]) - - entry = self.registry.get_entry("entry1") - assert entry is not None - assert entry.tags == {"existing": "yes", "new_tag": ""} - - def test_add_tags_raises_for_missing_entry(self): - """Test that add_tags raises KeyError for a non-existent entry.""" - with pytest.raises(KeyError, match="No entry named 'missing'"): - self.registry.add_tags(name="missing", tags=["tag"]) - - def test_add_tags_invalidates_metadata_cache(self): - """Test that add_tags invalidates the metadata cache.""" - self.registry.register(_item("value"), name="entry1") - self.registry.list_metadata() # Build cache - - self.registry.add_tags(name="entry1", tags=["new"]) - - # Cache should be invalidated (None), next call rebuilds - assert self.registry._metadata_cache is None - - def test_add_tags_entries_findable_by_get_by_tag(self): - """Test that entries are findable via get_by_tag after add_tags.""" - self.registry.register(_item("value"), name="entry1") - self.registry.add_tags(name="entry1", tags=["best_scorer"]) - - results = self.registry.get_by_tag(tag="best_scorer") - assert len(results) == 1 - assert results[0].name == "entry1" - - -class _IdentifiableStub(Identifiable): - """A minimal stub that holds a ComponentIdentifier for dependency tests.""" - - def __init__(self, identifier: ComponentIdentifier) -> None: - self._stored_identifier = identifier - - def _build_identifier(self) -> ComponentIdentifier: - return self._stored_identifier - - -class IdentifierTestRegistry(BaseInstanceRegistry["_IdentifiableStub"]): - """Registry for testing dependency-related functionality with ComponentIdentifier trees.""" - - -class TestFindDependentsOfTag: - """Tests for BaseInstanceRegistry.find_dependents_of_tag.""" - - def setup_method(self) -> None: - IdentifierTestRegistry.reset_instance() - self.registry = IdentifierTestRegistry.get_registry_singleton() - - def teardown_method(self) -> None: - IdentifierTestRegistry.reset_instance() - - def test_no_tagged_entries_returns_empty(self) -> None: - """Test that when no entries have the tag, an empty list is returned.""" - stub = _IdentifiableStub(ComponentIdentifier(class_name="A", class_module="mod")) - self.registry.register(stub, name="a") - assert self.registry.find_dependents_of_tag(tag="refusal") == [] - - def test_tagged_entry_not_returned_as_dependent(self) -> None: - """Test that an entry tagged with the tag is not returned as a dependent of itself.""" - stub = _IdentifiableStub(ComponentIdentifier(class_name="Refusal", class_module="mod", eval_hash="r1")) - self.registry.register(stub, name="refusal_scorer", tags=["refusal"]) - assert self.registry.find_dependents_of_tag(tag="refusal") == [] - - def test_dependent_found_by_child_eval_hash(self) -> None: - """Test that an entry whose child matches a tagged entry's eval_hash is found.""" - # Base scorer (tagged) - base_id = ComponentIdentifier(class_name="Refusal", class_module="mod", eval_hash="r_hash") - self.registry.register(_IdentifiableStub(base_id), name="refusal_scorer", tags=["refusal"]) - - # Wrapper scorer (child references the base scorer) - child_id = ComponentIdentifier(class_name="Refusal", class_module="mod", eval_hash="r_hash") - wrapper_id = ComponentIdentifier( - class_name="Inverter", - class_module="mod", - eval_hash="w_hash", - children={"sub_scorers": [child_id]}, - ) - self.registry.register(_IdentifiableStub(wrapper_id), name="inverter") - - dependents = self.registry.find_dependents_of_tag(tag="refusal") - assert len(dependents) == 1 - assert dependents[0].name == "inverter" - - def test_non_dependent_not_returned(self) -> None: - """Test that entries without matching child eval_hash are not returned.""" - base_id = ComponentIdentifier(class_name="Refusal", class_module="mod", eval_hash="r_hash") - self.registry.register(_IdentifiableStub(base_id), name="refusal_scorer", tags=["refusal"]) - - # Unrelated scorer (no children matching r_hash) - unrelated_id = ComponentIdentifier(class_name="Likert", class_module="mod", eval_hash="l_hash") - self.registry.register(_IdentifiableStub(unrelated_id), name="likert") - - assert self.registry.find_dependents_of_tag(tag="refusal") == [] - - def test_deeply_nested_dependency_found(self) -> None: - """Test that a deeply nested child eval_hash still triggers a match.""" - base_id = ComponentIdentifier(class_name="Refusal", class_module="mod", eval_hash="deep_r") - self.registry.register(_IdentifiableStub(base_id), name="refusal_scorer", tags=["refusal"]) - - # Composite with nested child - inner_child = ComponentIdentifier(class_name="Refusal", class_module="mod", eval_hash="deep_r") - inverter = ComponentIdentifier( - class_name="Inverter", - class_module="mod", - children={"sub_scorers": [inner_child]}, - ) - composite_id = ComponentIdentifier( - class_name="Composite", - class_module="mod", - children={"sub_scorers": [inverter]}, - ) - self.registry.register(_IdentifiableStub(composite_id), name="composite") - - dependents = self.registry.find_dependents_of_tag(tag="refusal") - assert len(dependents) == 1 - assert dependents[0].name == "composite" - - def test_multiple_dependents_returned_sorted(self) -> None: - """Test that multiple dependents are returned sorted by name.""" - base_id = ComponentIdentifier(class_name="Refusal", class_module="mod", eval_hash="r1") - self.registry.register(_IdentifiableStub(base_id), name="refusal_scorer", tags=["refusal"]) - - child = ComponentIdentifier(class_name="Refusal", class_module="mod", eval_hash="r1") - for wrapper_name in ["z_wrapper", "a_wrapper", "m_wrapper"]: - wrapper_id = ComponentIdentifier( - class_name="Wrapper", - class_module="mod", - children={"sub_scorers": [child]}, - ) - self.registry.register(_IdentifiableStub(wrapper_id), name=wrapper_name) - - dependents = self.registry.find_dependents_of_tag(tag="refusal") - assert [d.name for d in dependents] == ["a_wrapper", "m_wrapper", "z_wrapper"] - - def test_tagged_entries_without_eval_hash_returns_empty(self) -> None: - """Test that tagged entries without eval_hash yield no dependents.""" - base_id = ComponentIdentifier(class_name="Refusal", class_module="mod") - self.registry.register(_IdentifiableStub(base_id), name="refusal_scorer", tags=["refusal"]) - - child = ComponentIdentifier(class_name="Refusal", class_module="mod") - wrapper_id = ComponentIdentifier( - class_name="Wrapper", - class_module="mod", - children={"sub_scorers": [child]}, - ) - self.registry.register(_IdentifiableStub(wrapper_id), name="wrapper") - - assert self.registry.find_dependents_of_tag(tag="refusal") == [] - - -class TestConcreteInstanceRegistryMetadataField: - """Tests for the metadata field on RegistryEntry.""" - - def setup_method(self): - """Get a fresh registry instance.""" - ConcreteTestRegistry.reset_instance() - self.registry = ConcreteTestRegistry.get_registry_singleton() - - def teardown_method(self): - """Reset the singleton after each test.""" - ConcreteTestRegistry.reset_instance() - - def test_register_with_metadata_stores_it(self): - """Test that metadata dict is stored on the entry.""" - self.registry.register(_item("v1"), name="n1", metadata={"accepts_scorer_override": False, "priority": 5}) - - entry = self.registry.get_entry("n1") - assert entry is not None - assert entry.metadata == {"accepts_scorer_override": False, "priority": 5} - - def test_register_without_metadata_defaults_to_empty_dict(self): - """Test that registering without metadata defaults to empty dict.""" - self.registry.register(_item("v1"), name="n1") - - entry = self.registry.get_entry("n1") - assert entry is not None - assert entry.metadata == {} - - def test_metadata_does_not_affect_tags(self): - """Metadata and tags are independent.""" - self.registry.register(_item("v1"), name="n1", tags=["fast"], metadata={"key": "value"}) - - entry = self.registry.get_entry("n1") - assert entry is not None - assert entry.tags == {"fast": ""} - assert entry.metadata == {"key": "value"} - # Metadata keys don't appear in tag queries - assert self.registry.get_by_tag(tag="key") == [] diff --git a/tests/unit/setup/test_targets_initializer.py b/tests/unit/setup/test_targets_initializer.py index f7dc4ddf78..b81e1851d0 100644 --- a/tests/unit/setup/test_targets_initializer.py +++ b/tests/unit/setup/test_targets_initializer.py @@ -216,7 +216,7 @@ def test_target_configs_have_unique_registry_names(self): """Guard against typos: every ``registry_name`` in ``ENV_TARGET_CONFIGS`` must be unique. Duplicate names would silently overwrite each other when - ``TargetInitializer`` registers them (per ``BaseInstanceRegistry.register`` + ``TargetInitializer`` registers them (per instance-registry ``register`` semantics, characterized in ``test_target_registry.py``). Only the second entry would survive in the registry, which breaks downstream scenarios that resolve targets by name (e.g. ``AdversarialBenchmark``'s From 779ed21c4e127b7f7e1d3a43b6e22f868d40af00 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 2 Jul 2026 17:44:04 -0700 Subject: [PATCH 07/11] Rename ClassRegistryEntry to RegistryMetadata and delete base.py Move the shared registry metadata base out of base.py into its own registry_metadata.py module, renamed from ClassRegistryEntry to the clearer RegistryMetadata. All six component metadata classes plus scenario/initializer description helpers now extend/reference it from the new module. With that moved, base.py held only dead weight: RegistryProtocol had no production implementers after the legacy-registry deletion, and _matches_filters/ _get_metadata_value were byte-identical duplicates of the copies in registry.py. instance_registry now uses registry.py's _matches_filters, RegistryProtocol is dropped from the public API (and the registry doc updated), and base.py is deleted. Also trim discovery.py: discover_in_directory is still used by the initializer registry, but discover_in_package and discover_subclasses_in_loaded_modules were unused exports and are removed. Co-authored-by: Copilot App <223556219+Copilot@users.noreply.github.com> --- doc/code/registry/0_registry.md | 13 +- pyrit/registry/__init__.py | 12 +- pyrit/registry/base.py | 239 ------------------ .../components/attack_technique_registry.py | 6 +- .../registry/components/converter_registry.py | 4 +- .../components/initializer_registry.py | 6 +- .../registry/components/scenario_registry.py | 6 +- pyrit/registry/components/scorer_registry.py | 4 +- pyrit/registry/components/target_registry.py | 4 +- pyrit/registry/discovery.py | 111 +------- pyrit/registry/instance_registry.py | 2 +- pyrit/registry/registry.py | 8 +- pyrit/registry/registry_metadata.py | 88 +++++++ pyrit/scenario/core/scenario.py | 4 +- pyrit/setup/initializers/pyrit_initializer.py | 4 +- tests/unit/registry/test_registry.py | 18 +- ...test_base.py => test_registry_metadata.py} | 33 +-- 17 files changed, 150 insertions(+), 412 deletions(-) delete mode 100644 pyrit/registry/base.py create mode 100644 pyrit/registry/registry_metadata.py rename tests/unit/registry/{test_base.py => test_registry_metadata.py} (88%) diff --git a/doc/code/registry/0_registry.md b/doc/code/registry/0_registry.md index b29e0fa5fe..58998a56aa 100644 --- a/doc/code/registry/0_registry.md +++ b/doc/code/registry/0_registry.md @@ -18,9 +18,9 @@ PyRIT has two registry patterns for different use cases: | **Class Registry** | Classes (type[T]) | Components instantiated with user-provided parameters | | **Instance Registry** | Pre-configured instances | Components requiring complex setup before use | -## Common API (RegistryProtocol) +## Common API -Both registry types implement `RegistryProtocol`, sharing a consistent interface: +Registries share a consistent interface for discovery and introspection: | Method | Description | |--------|-------------| @@ -29,14 +29,17 @@ Both registry types implement `RegistryProtocol`, sharing a consistent interface | `list_metadata()` | Get descriptive metadata for all items | | `reset_instance()` | Reset the singleton (useful for testing) | -This protocol enables writing code that works with any registry type: +This makes it easy to write code that inspects any registry: ```python -from pyrit.registry import RegistryProtocol +from pyrit.registry import ScenarioRegistry -def show_registry_contents(registry: RegistryProtocol) -> None: +def show_registry_contents(registry) -> None: for name in registry.get_names(): print(name) + + +show_registry_contents(ScenarioRegistry.get_registry_singleton()) ``` diff --git a/pyrit/registry/__init__.py b/pyrit/registry/__init__.py index 0cabdbd0e1..7c3ff3e7fc 100644 --- a/pyrit/registry/__init__.py +++ b/pyrit/registry/__init__.py @@ -3,7 +3,6 @@ """Registry module for PyRIT class and object registries.""" -from pyrit.registry.base import RegistryProtocol from pyrit.registry.components import ( AttackTechniqueMetadata, AttackTechniqueRegistry, @@ -18,11 +17,7 @@ TargetMetadata, TargetRegistry, ) -from pyrit.registry.discovery import ( - discover_in_directory, - discover_in_package, - discover_subclasses_in_loaded_modules, -) +from pyrit.registry.discovery import discover_in_directory from pyrit.registry.instance_registry import ( DefaultInstanceRegistry, InstanceRegistry, @@ -30,6 +25,7 @@ SupportsInstances, ) from pyrit.registry.registry import Registry +from pyrit.registry.registry_metadata import RegistryMetadata from pyrit.registry.tag_query import TagQuery __all__ = [ @@ -40,14 +36,12 @@ "DefaultInstanceRegistry", "InstanceRegistry", "Registry", + "RegistryMetadata", "SupportsInstances", "discover_in_directory", - "discover_in_package", - "discover_subclasses_in_loaded_modules", "InitializerMetadata", "InitializerRegistry", "RegistryEntry", - "RegistryProtocol", "ScenarioMetadata", "ScenarioRegistry", "ScorerRegistry", diff --git a/pyrit/registry/base.py b/pyrit/registry/base.py deleted file mode 100644 index 7cbd5bfd32..0000000000 --- a/pyrit/registry/base.py +++ /dev/null @@ -1,239 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Shared base types for PyRIT registries. - -This module contains types shared between class registries (which store type[T]) -and object registries (which store T instances). -""" - -from __future__ import annotations - -import inspect -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Protocol, TypeVar, runtime_checkable - -if TYPE_CHECKING: - from collections.abc import Iterator, Mapping - - from typing_extensions import Self - - from pyrit.models.parameter import Parameter - -# Type variable for metadata (invariant for Protocol compatibility) -MetadataT = TypeVar("MetadataT") - - -@dataclass(frozen=True) -class ClassRegistryEntry: - """ - Minimal base for class-level registry metadata. - - Provides the common fields every registry metadata type needs for display, - lookup, and filtering in class registries. - - Attributes: - class_name (str): Python class name (e.g., "ContentHarmsScenario"). - class_module (str): Full module path (e.g., "pyrit.scenario.scenarios.content_harms"). - class_description (str): Human-readable description, typically from the class docstring. - registry_name (str): The suffix-stripped snake_case key used in the registry - (e.g., "content_harms" for ContentHarmsScenario). - parameters (tuple[Parameter, ...]): The derived build contract for the class. - Buildable registries (e.g. converters) populate this from the constructor - signature; scenarios/initializers use their own ``supported_parameters`` - today and will migrate to this unified shape. - class_attributes (Mapping[str, Any]): Values sourced from class attributes - (declared on the identifier via ``Param.ClassAttr``), letting the entry - describe class-level facts — e.g. a converter's supported input/output - types — without constructing an instance. Empty for entries with none. - """ - - class_name: str - class_module: str - class_description: str = "" - registry_name: str = "" - parameters: tuple[Parameter, ...] = field(kw_only=True, default=()) - class_attributes: Mapping[str, Any] = field(kw_only=True, default_factory=dict) - - @staticmethod - def description_from_docstring(cls: type, *, fallback: str = "") -> str: - """ - Extract a normalized description from a class docstring. - - Collapses all whitespace into single spaces. Returns fallback if - no docstring is present or the docstring is empty after cleaning. - - Returns: - str: The cleaned docstring or the fallback value. - """ - doc = cls.__doc__ or "" - cleaned = " ".join(doc.split()) - return cleaned or fallback - - @staticmethod - def summary_from_docstring(cls: type) -> str: - """ - Extract a short summary from the first paragraph of a class docstring. - - Uses the class's own docstring only (never an inherited one), normalizes - indentation, and collapses the first paragraph's whitespace onto one line. - Empty when the class has no docstring. This is the catalog-display - counterpart to ``description_from_docstring`` (which collapses the whole - docstring); buildable registries populate ``class_description`` from this - first-paragraph form. - - Returns: - str: The first-paragraph summary, or "" when there is no docstring. - """ - raw = cls.__doc__ - if not raw: - return "" - first_paragraph = inspect.cleandoc(raw).split("\n\n", 1)[0] - return " ".join(first_paragraph.split()) - - -@runtime_checkable -class RegistryProtocol(Protocol[MetadataT]): - """ - Protocol defining the common interface for all registries. - - Registries implement this interface, enabling code that works with any - registry type. - - Type Parameters: - MetadataT: The metadata dataclass type (e.g., ScenarioMetadata). - """ - - @classmethod - def get_registry_singleton(cls) -> Self: - """Get the singleton instance of this registry.""" - ... - - @classmethod - def reset_instance(cls) -> None: - """Reset the singleton instance.""" - ... - - def get_names(self) -> list[str]: - """Get a sorted list of all registered names.""" - ... - - def list_metadata( - self, - *, - include_filters: dict[str, Any] | None = None, - exclude_filters: dict[str, Any] | None = None, - ) -> list[MetadataT]: - """ - List metadata for all registered items, optionally filtered. - - Args: - include_filters: Optional dict of filters that items must match. - Keys are metadata property names, values are the filter criteria. - All filters must match (AND logic). - exclude_filters: Optional dict of filters that items must NOT match. - Keys are metadata property names, values are the filter criteria. - Any matching filter excludes the item. - - Returns: - List of metadata describing each registered item. - """ - ... - - def __contains__(self, name: str) -> bool: - """Check if a name is registered.""" - ... - - def __len__(self) -> int: - """Get the count of registered items.""" - ... - - def __iter__(self) -> Iterator[str]: - """Iterate over registered names.""" - ... - - -def _get_metadata_value(metadata: Any, key: str) -> tuple[bool, Any]: - """ - Get a value from a metadata object by key. - - Checks direct attributes first, then falls back to the ``params`` dict - (used by ComponentIdentifier). Returns a (found, value) tuple. - - Args: - metadata: The metadata object to look up. - key (str): The attribute or params key to find. - - Returns: - tuple: (True, value) if found, (False, None) otherwise. - """ - if hasattr(metadata, key): - return True, getattr(metadata, key) - - # Fall back to params dict (for ComponentIdentifier) - params = getattr(metadata, "params", None) - if isinstance(params, dict) and key in params: - return True, params[key] - - return False, None - - -def _matches_filters( - metadata: Any, - *, - include_filters: dict[str, Any] | None = None, - exclude_filters: dict[str, Any] | None = None, -) -> bool: - """ - Check if a metadata object matches all provided filters. - - Supports filtering on any property of the metadata dataclass or on keys - inside the ``params`` dict (for ComponentIdentifier metadata): - - For simple types (str, int, bool): exact match comparison - - For sequence types (list, tuple): checks if filter value is contained in the sequence - - Items must match ALL include_filters (AND logic) and must NOT match ANY exclude_filters. - - Args: - metadata: The metadata dataclass instance to check. - include_filters: Optional dict of filters that must ALL match. - Keys are metadata property names or params keys, values are the filter criteria. - exclude_filters: Optional dict of filters that must ALL NOT match. - Keys are metadata property names or params keys, values are the filter criteria. - - Returns: - True if all include_filters match and no exclude_filters match, False otherwise. - """ - # Check include filters - all must match - if include_filters: - for key, filter_value in include_filters.items(): - found, actual_value = _get_metadata_value(metadata, key) - if not found: - return False - - # Handle sequence types - check if filter value is in the sequence - if isinstance(actual_value, (list, tuple)): - if filter_value not in actual_value: - return False - # Simple exact match for other types - elif actual_value != filter_value: - return False - - # Check exclude filters - none must match - if exclude_filters: - for key, filter_value in exclude_filters.items(): - found, actual_value = _get_metadata_value(metadata, key) - if not found: - # If the key doesn't exist, it can't match the exclude filter - continue - - # Handle sequence types - exclude if filter value is in the sequence - if isinstance(actual_value, (list, tuple)): - if filter_value in actual_value: - return False - # Simple exact match for other types - exclude if it matches - elif actual_value == filter_value: - return False - - return True diff --git a/pyrit/registry/components/attack_technique_registry.py b/pyrit/registry/components/attack_technique_registry.py index 97f9ce8a04..affcd2cfc9 100644 --- a/pyrit/registry/components/attack_technique_registry.py +++ b/pyrit/registry/components/attack_technique_registry.py @@ -24,9 +24,9 @@ from dataclasses import dataclass from typing import TYPE_CHECKING -from pyrit.registry.base import ClassRegistryEntry from pyrit.registry.instance_registry import DefaultInstanceRegistry, InstanceRegistry from pyrit.registry.registry import Registry +from pyrit.registry.registry_metadata import RegistryMetadata if TYPE_CHECKING: from pyrit.registry.tag_query import TagQuery @@ -55,13 +55,13 @@ def _attack_technique_factory_type() -> type[AttackTechniqueFactory]: @dataclass(frozen=True) -class AttackTechniqueMetadata(ClassRegistryEntry): +class AttackTechniqueMetadata(RegistryMetadata): """ Metadata describing a registered attack-technique class. Placeholder for the buildable catalog, which is intentionally empty until the factory is decoupled into a buildable component. It carries only the common - ``ClassRegistryEntry`` fields today; technique-specific fields are added when + ``RegistryMetadata`` fields today; technique-specific fields are added when the catalog is lit up. """ diff --git a/pyrit/registry/components/converter_registry.py b/pyrit/registry/components/converter_registry.py index eb8bf68964..e3598e6e29 100644 --- a/pyrit/registry/components/converter_registry.py +++ b/pyrit/registry/components/converter_registry.py @@ -28,9 +28,9 @@ from pyrit.models.identifiers import ConverterIdentifier from pyrit.models.parameter import ComponentType -from pyrit.registry.base import ClassRegistryEntry from pyrit.registry.instance_registry import DefaultInstanceRegistry, InstanceRegistry from pyrit.registry.registry import Registry +from pyrit.registry.registry_metadata import RegistryMetadata if TYPE_CHECKING: from types import ModuleType @@ -39,7 +39,7 @@ @dataclass(frozen=True) -class ConverterMetadata(ClassRegistryEntry): +class ConverterMetadata(RegistryMetadata): """ Metadata describing a registered ``PromptConverter`` class. diff --git a/pyrit/registry/components/initializer_registry.py b/pyrit/registry/components/initializer_registry.py index 4bfaaebb3e..37f1a1c906 100644 --- a/pyrit/registry/components/initializer_registry.py +++ b/pyrit/registry/components/initializer_registry.py @@ -25,9 +25,9 @@ from typing import TYPE_CHECKING, Any from pyrit.models import class_name_to_snake_case, validate_registry_name -from pyrit.registry.base import ClassRegistryEntry from pyrit.registry.discovery import discover_in_directory from pyrit.registry.registry import Registry +from pyrit.registry.registry_metadata import RegistryMetadata # Compute PYRIT_PATH directly to avoid importing pyrit package # (which triggers heavy imports from __init__.py) @@ -45,7 +45,7 @@ @dataclass(frozen=True) -class InitializerMetadata(ClassRegistryEntry): +class InitializerMetadata(RegistryMetadata): """ Metadata describing a registered PyRITInitializer class. @@ -228,7 +228,7 @@ def _build_metadata(self, name: str, cls: type[PyRITInitializer]) -> Initializer Returns: InitializerMetadata describing the initializer class. """ - description = ClassRegistryEntry.description_from_docstring(cls, fallback="No description available") + description = RegistryMetadata.description_from_docstring(cls, fallback="No description available") try: instance = cls() diff --git a/pyrit/registry/components/scenario_registry.py b/pyrit/registry/components/scenario_registry.py index 2f28b87fe0..062cb56e05 100644 --- a/pyrit/registry/components/scenario_registry.py +++ b/pyrit/registry/components/scenario_registry.py @@ -19,8 +19,8 @@ from pyrit.models import class_name_to_snake_case from pyrit.models.identifiers.scenario_identifier import ScenarioIdentifier -from pyrit.registry.base import ClassRegistryEntry from pyrit.registry.registry import Registry +from pyrit.registry.registry_metadata import RegistryMetadata if TYPE_CHECKING: from types import ModuleType @@ -31,7 +31,7 @@ @dataclass(frozen=True) -class ScenarioMetadata(ClassRegistryEntry): +class ScenarioMetadata(RegistryMetadata): """ Metadata describing a registered Scenario class. @@ -128,7 +128,7 @@ def _build_metadata(self, name: str, cls: type[Scenario]) -> ScenarioMetadata: Raises: TypeError: If ``cls()`` cannot be called with no arguments. """ - description = ClassRegistryEntry.description_from_docstring(cls, fallback="No description available") + description = RegistryMetadata.description_from_docstring(cls, fallback="No description available") supported_parameters = tuple(cls.supported_parameters()) diff --git a/pyrit/registry/components/scorer_registry.py b/pyrit/registry/components/scorer_registry.py index 84f195091e..f6277da954 100644 --- a/pyrit/registry/components/scorer_registry.py +++ b/pyrit/registry/components/scorer_registry.py @@ -29,9 +29,9 @@ from pyrit.models.identifiers import ScorerIdentifier from pyrit.models.parameter import ComponentType -from pyrit.registry.base import ClassRegistryEntry from pyrit.registry.instance_registry import DefaultInstanceRegistry, InstanceRegistry from pyrit.registry.registry import Registry +from pyrit.registry.registry_metadata import RegistryMetadata if TYPE_CHECKING: from types import ModuleType @@ -40,7 +40,7 @@ @dataclass(frozen=True) -class ScorerMetadata(ClassRegistryEntry): +class ScorerMetadata(RegistryMetadata): """ Metadata describing a registered ``Scorer`` class. diff --git a/pyrit/registry/components/target_registry.py b/pyrit/registry/components/target_registry.py index 01e3c4e4dd..8233770127 100644 --- a/pyrit/registry/components/target_registry.py +++ b/pyrit/registry/components/target_registry.py @@ -27,9 +27,9 @@ from typing import TYPE_CHECKING from pyrit.models.identifiers import TargetIdentifier -from pyrit.registry.base import ClassRegistryEntry from pyrit.registry.instance_registry import DefaultInstanceRegistry, InstanceRegistry from pyrit.registry.registry import Registry +from pyrit.registry.registry_metadata import RegistryMetadata if TYPE_CHECKING: from types import ModuleType @@ -38,7 +38,7 @@ @dataclass(frozen=True) -class TargetMetadata(ClassRegistryEntry): +class TargetMetadata(RegistryMetadata): """ Metadata describing a registered ``PromptTarget`` class. diff --git a/pyrit/registry/discovery.py b/pyrit/registry/discovery.py index 34c1562bc3..91c5015272 100644 --- a/pyrit/registry/discovery.py +++ b/pyrit/registry/discovery.py @@ -8,12 +8,10 @@ used by registries to find and register items automatically. """ -import importlib import importlib.util import inspect import logging -import pkgutil -from collections.abc import Callable, Iterator +from collections.abc import Iterator from pathlib import Path from typing import TypeVar @@ -84,110 +82,3 @@ def _process_file(*, file_path: Path, base_class: type[T]) -> Iterator[tuple[str except Exception as e: logger.warning(f"Failed to load module from {file_path}: {e}") - - -def discover_in_package( - *, - package_path: Path, - package_name: str, - base_class: type[T], - recursive: bool = True, - name_builder: Callable[[str, str], str] | None = None, - _prefix: str = "", -) -> Iterator[tuple[str, type[T]]]: - """ - Discover all subclasses using pkgutil.iter_modules on a package. - - This function uses Python's package infrastructure to discover modules, - making it suitable for discovering classes in installed packages. - - Args: - package_path: The filesystem path to the package directory. - package_name: The dotted module name of the package (e.g., "pyrit.scenario.scenarios"). - base_class: The base class to filter subclasses of. - recursive: Whether to recursively search subpackages. Defaults to True. - name_builder: Optional callable to build the registry name from (prefix, module_name). - Defaults to returning just the module_name. - _prefix: Internal parameter to track the current subdirectory prefix. - - Yields: - Tuples of (registry_name, class) for each discovered subclass. - """ - if name_builder is None: - - def name_builder(prefix: str, name: str) -> str: - return name if not prefix else f"{prefix}.{name}" - - for _, module_name, is_pkg in pkgutil.iter_modules([str(package_path)]): - if module_name.startswith("_"): - continue - - full_module_name = f"{package_name}.{module_name}" - - try: - module = importlib.import_module(full_module_name) - - # For non-package modules, find and yield subclasses - if not is_pkg: - for _name, obj in inspect.getmembers(module, inspect.isclass): - if issubclass(obj, base_class) and obj is not base_class and not inspect.isabstract(obj): - # Build the registry name including any prefix - registry_name = name_builder(_prefix, module_name) - yield (registry_name, obj) - - # Recursively discover in subpackages - if recursive and is_pkg: - sub_path = package_path / module_name - # Pass the current module_name as part of the prefix for nested discoveries - new_prefix = name_builder(_prefix, module_name) - yield from discover_in_package( - package_path=sub_path, - package_name=full_module_name, - base_class=base_class, - recursive=True, - name_builder=name_builder, - _prefix=new_prefix, - ) - - except Exception as e: - logger.warning(f"Failed to load package module {full_module_name}: {e}") - - -def discover_subclasses_in_loaded_modules( - *, - base_class: type[T], - exclude_module_prefixes: tuple[str, ...] | None = None, -) -> Iterator[tuple[str, type[T]]]: - """ - Discover subclasses of a base class from already-loaded modules. - - This is useful for discovering user-defined classes that were loaded - via initialization scripts or dynamic imports. - - Args: - base_class: The base class to filter subclasses of. - exclude_module_prefixes: Module prefixes to exclude from search. - Defaults to common system modules. - - Yields: - Tuples of (module_name, class) for each discovered subclass. - """ - import sys - - if exclude_module_prefixes is None: - exclude_module_prefixes = ("builtins", "_", "sys", "os", "importlib") - - # Create a snapshot to avoid dictionary changed size during iteration - modules_snapshot = list(sys.modules.items()) - - for module_name, module in modules_snapshot: - if module is None or not hasattr(module, "__dict__"): - continue - - # Skip excluded modules - if any(module_name.startswith(prefix) for prefix in exclude_module_prefixes): - continue - - for _name, obj in inspect.getmembers(module, inspect.isclass): - if issubclass(obj, base_class) and obj is not base_class and not inspect.isabstract(obj): - yield (module_name, obj) diff --git a/pyrit/registry/instance_registry.py b/pyrit/registry/instance_registry.py index 08b74cd44f..404189e2d1 100644 --- a/pyrit/registry/instance_registry.py +++ b/pyrit/registry/instance_registry.py @@ -416,7 +416,7 @@ def list_metadata( Returns: list[ComponentIdentifier]: The identifier metadata for each instance. """ - from pyrit.registry.base import _matches_filters + from pyrit.registry.registry import _matches_filters if self._metadata_cache is None: self._metadata_cache = [ diff --git a/pyrit/registry/registry.py b/pyrit/registry/registry.py index c5d1154a80..3a6ce28135 100644 --- a/pyrit/registry/registry.py +++ b/pyrit/registry/registry.py @@ -31,7 +31,7 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Generic, TypeVar -from pyrit.registry.base import ClassRegistryEntry +from pyrit.registry.registry_metadata import RegistryMetadata from pyrit.registry.resolution import ( derive_parameters, resolve_constructor_args, @@ -49,7 +49,7 @@ logger = logging.getLogger(__name__) T = TypeVar("T") -MetadataT = TypeVar("MetadataT", bound=ClassRegistryEntry) +MetadataT = TypeVar("MetadataT", bound=RegistryMetadata) def _get_metadata_value(metadata: Any, key: str) -> tuple[bool, Any]: @@ -325,7 +325,7 @@ def _metadata_class(self) -> type[MetadataT]: Return the concrete metadata dataclass this registry builds. The base ``_build_metadata`` constructs this type from the common - ``ClassRegistryEntry`` fields. Subclasses whose metadata carries extra + ``RegistryMetadata`` fields. Subclasses whose metadata carries extra fields beyond the common shape override ``_build_metadata`` instead. Returns: @@ -336,7 +336,7 @@ def _build_metadata(self, name: str, cls: type[T]) -> MetadataT: """ Build the metadata descriptor for a registered class. - Populates the common ``ClassRegistryEntry`` fields — name/module, a + Populates the common ``RegistryMetadata`` fields — name/module, a first-paragraph description, the derived ``Parameter`` build contract, and any ``Param.ClassAttr`` class attributes — into the registry's ``_metadata_class``. Subclasses needing extra fields override this. diff --git a/pyrit/registry/registry_metadata.py b/pyrit/registry/registry_metadata.py new file mode 100644 index 0000000000..ec472dc96d --- /dev/null +++ b/pyrit/registry/registry_metadata.py @@ -0,0 +1,88 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Shared metadata base type for PyRIT registries. + +``RegistryMetadata`` is the minimal base every registry metadata dataclass +extends, carrying the common fields used for display, lookup, and filtering. +""" + +from __future__ import annotations + +import inspect +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections.abc import Mapping + + from pyrit.models.parameter import Parameter + + +@dataclass(frozen=True) +class RegistryMetadata: + """ + Minimal base for class-level registry metadata. + + Provides the common fields every registry metadata type needs for display, + lookup, and filtering in class registries. + + Attributes: + class_name (str): Python class name (e.g., "ContentHarmsScenario"). + class_module (str): Full module path (e.g., "pyrit.scenario.scenarios.content_harms"). + class_description (str): Human-readable description, typically from the class docstring. + registry_name (str): The suffix-stripped snake_case key used in the registry + (e.g., "content_harms" for ContentHarmsScenario). + parameters (tuple[Parameter, ...]): The derived build contract for the class. + Buildable registries (e.g. converters) populate this from the constructor + signature; scenarios/initializers use their own ``supported_parameters`` + today and will migrate to this unified shape. + class_attributes (Mapping[str, Any]): Values sourced from class attributes + (declared on the identifier via ``Param.ClassAttr``), letting the entry + describe class-level facts — e.g. a converter's supported input/output + types — without constructing an instance. Empty for entries with none. + """ + + class_name: str + class_module: str + class_description: str = "" + registry_name: str = "" + parameters: tuple[Parameter, ...] = field(kw_only=True, default=()) + class_attributes: Mapping[str, Any] = field(kw_only=True, default_factory=dict) + + @staticmethod + def description_from_docstring(cls: type, *, fallback: str = "") -> str: + """ + Extract a normalized description from a class docstring. + + Collapses all whitespace into single spaces. Returns fallback if + no docstring is present or the docstring is empty after cleaning. + + Returns: + str: The cleaned docstring or the fallback value. + """ + doc = cls.__doc__ or "" + cleaned = " ".join(doc.split()) + return cleaned or fallback + + @staticmethod + def summary_from_docstring(cls: type) -> str: + """ + Extract a short summary from the first paragraph of a class docstring. + + Uses the class's own docstring only (never an inherited one), normalizes + indentation, and collapses the first paragraph's whitespace onto one line. + Empty when the class has no docstring. This is the catalog-display + counterpart to ``description_from_docstring`` (which collapses the whole + docstring); buildable registries populate ``class_description`` from this + first-paragraph form. + + Returns: + str: The first-paragraph summary, or "" when there is no docstring. + """ + raw = cls.__doc__ + if not raw: + return "" + first_paragraph = inspect.cleandoc(raw).split("\n\n", 1)[0] + return " ".join(first_paragraph.split()) diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index de59d2c004..778f7fbac3 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -194,9 +194,9 @@ def __init__( The scenario description is automatically extracted from the class's docstring (__doc__) with whitespace normalized for display. """ - from pyrit.registry.base import ClassRegistryEntry + from pyrit.registry.registry_metadata import RegistryMetadata - description = ClassRegistryEntry.description_from_docstring(self.__class__) + description = RegistryMetadata.description_from_docstring(self.__class__) # The scenario identifier is the canonical per-run identity: the scenario # registry produces it and it is persisted on the ScenarioResult (carrying diff --git a/pyrit/setup/initializers/pyrit_initializer.py b/pyrit/setup/initializers/pyrit_initializer.py index adbedaed74..94ba4bac4a 100644 --- a/pyrit/setup/initializers/pyrit_initializer.py +++ b/pyrit/setup/initializers/pyrit_initializer.py @@ -87,9 +87,9 @@ def description(self) -> str: Returns: str: A description of the configuration changes this initializer makes. """ - from pyrit.registry.base import ClassRegistryEntry + from pyrit.registry.registry_metadata import RegistryMetadata - return ClassRegistryEntry.description_from_docstring(self.__class__, fallback=type(self).__name__) + return RegistryMetadata.description_from_docstring(self.__class__, fallback=type(self).__name__) @property def required_env_vars(self) -> list[str]: diff --git a/tests/unit/registry/test_registry.py b/tests/unit/registry/test_registry.py index 2a59870eb8..8f0aef0388 100644 --- a/tests/unit/registry/test_registry.py +++ b/tests/unit/registry/test_registry.py @@ -17,8 +17,8 @@ import pytest -from pyrit.registry.base import ClassRegistryEntry from pyrit.registry.registry import Registry, _get_metadata_value, _matches_filters +from pyrit.registry.registry_metadata import RegistryMetadata class SampleWidget: @@ -43,7 +43,7 @@ def __init__(self, *, size: int = 1) -> None: self.size = size -class WidgetRegistry(Registry[object, ClassRegistryEntry]): +class WidgetRegistry(Registry[object, RegistryMetadata]): """Minimal Registry subclass that keeps every base default.""" def __init__(self, *, lazy_discovery: bool = True) -> None: @@ -55,12 +55,12 @@ def _discover(self) -> None: self.register_class(SampleWidget) self.register_class(UndocumentedWidget) - def _metadata_class(self) -> type[ClassRegistryEntry]: - return ClassRegistryEntry + def _metadata_class(self) -> type[RegistryMetadata]: + return RegistryMetadata @dataclass(frozen=True) -class _TaggedMetadata(ClassRegistryEntry): +class _TaggedMetadata(RegistryMetadata): tags: tuple[str, ...] = field(kw_only=True, default=()) @@ -195,7 +195,7 @@ def test_matches_filters_list_containment(): def test_matches_filters_unknown_include_key_fails(): - meta = ClassRegistryEntry(class_name="X", class_module="m") + meta = RegistryMetadata(class_name="X", class_module="m") assert not _matches_filters(meta, include_filters={"nope": "x"}) @@ -222,7 +222,7 @@ class _ConcreteWidget(_WidgetBase): """A concrete widget.""" -class _PackageDrivenRegistry(Registry[object, ClassRegistryEntry]): +class _PackageDrivenRegistry(Registry[object, RegistryMetadata]): """Registry that uses the base's default ``_discover`` over a supplied package.""" def __init__(self, *, package: ModuleType) -> None: @@ -235,8 +235,8 @@ def _base_type(self) -> type: def _discovery_package(self) -> ModuleType: return self._package - def _metadata_class(self) -> type[ClassRegistryEntry]: - return ClassRegistryEntry + def _metadata_class(self) -> type[RegistryMetadata]: + return RegistryMetadata def test_discover_skips_spec_type_mock_exports(): diff --git a/tests/unit/registry/test_base.py b/tests/unit/registry/test_registry_metadata.py similarity index 88% rename from tests/unit/registry/test_base.py rename to tests/unit/registry/test_registry_metadata.py index 30bb73b806..a5599a8b18 100644 --- a/tests/unit/registry/test_base.py +++ b/tests/unit/registry/test_registry_metadata.py @@ -3,45 +3,46 @@ from dataclasses import dataclass, field -from pyrit.registry.base import ClassRegistryEntry, _matches_filters +from pyrit.registry.registry import _matches_filters +from pyrit.registry.registry_metadata import RegistryMetadata @dataclass(frozen=True) -class MetadataWithTags(ClassRegistryEntry): +class MetadataWithTags(RegistryMetadata): """Test metadata with a tags field for list filtering tests.""" tags: tuple[str, ...] = field(kw_only=True) class TestDescriptionFromDocstring: - """Tests for ClassRegistryEntry.description_from_docstring.""" + """Tests for RegistryMetadata.description_from_docstring.""" def test_extracts_docstring_and_normalizes_whitespace(self): class MyClass: """This is\n a docstring.""" - result = ClassRegistryEntry.description_from_docstring(MyClass) + result = RegistryMetadata.description_from_docstring(MyClass) assert result == "This is a docstring." def test_returns_fallback_when_no_docstring(self): class NoDoc: pass - result = ClassRegistryEntry.description_from_docstring(NoDoc, fallback="default") + result = RegistryMetadata.description_from_docstring(NoDoc, fallback="default") assert result == "default" def test_returns_fallback_when_empty_docstring(self): class EmptyDoc: """ """ - result = ClassRegistryEntry.description_from_docstring(EmptyDoc, fallback="fallback") + result = RegistryMetadata.description_from_docstring(EmptyDoc, fallback="fallback") assert result == "fallback" def test_returns_empty_string_when_no_docstring_and_no_fallback(self): class NoDoc: pass - result = ClassRegistryEntry.description_from_docstring(NoDoc) + result = RegistryMetadata.description_from_docstring(NoDoc) assert result == "" @@ -50,7 +51,7 @@ class TestMatchesFilters: def test_matches_filters_exact_match_string(self): """Test that exact string matches work.""" - metadata = ClassRegistryEntry( + metadata = RegistryMetadata( class_name="TestClass", class_module="test.module", class_description="A test item", @@ -60,7 +61,7 @@ def test_matches_filters_exact_match_string(self): def test_matches_filters_no_match_string(self): """Test that non-matching strings return False.""" - metadata = ClassRegistryEntry( + metadata = RegistryMetadata( class_name="TestClass", class_module="test.module", class_description="A test item", @@ -70,7 +71,7 @@ def test_matches_filters_no_match_string(self): def test_matches_filters_multiple_filters_all_match(self): """Test that all filters must match.""" - metadata = ClassRegistryEntry( + metadata = RegistryMetadata( class_name="TestClass", class_module="test.module", class_description="A test item", @@ -82,7 +83,7 @@ def test_matches_filters_multiple_filters_all_match(self): def test_matches_filters_multiple_filters_partial_match(self): """Test that partial matches return False when not all filters match.""" - metadata = ClassRegistryEntry( + metadata = RegistryMetadata( class_name="TestClass", class_module="test.module", class_description="A test item", @@ -94,7 +95,7 @@ def test_matches_filters_multiple_filters_partial_match(self): def test_matches_filters_key_not_in_metadata(self): """Test that filtering on a non-existent key returns False.""" - metadata = ClassRegistryEntry( + metadata = RegistryMetadata( class_name="TestClass", class_module="test.module", class_description="A test item", @@ -103,7 +104,7 @@ def test_matches_filters_key_not_in_metadata(self): def test_matches_filters_empty_filters(self): """Test that empty filters return True.""" - metadata = ClassRegistryEntry( + metadata = RegistryMetadata( class_name="TestClass", class_module="test.module", class_description="A test item", @@ -133,7 +134,7 @@ def test_matches_filters_list_value_not_contains_filter(self): def test_matches_filters_exclude_exact_match(self): """Test that exclude filters work for exact matches.""" - metadata = ClassRegistryEntry( + metadata = RegistryMetadata( class_name="TestClass", class_module="test.module", class_description="A test item", @@ -154,7 +155,7 @@ def test_matches_filters_exclude_list_value(self): def test_matches_filters_exclude_nonexistent_key(self): """Test that exclude filters for non-existent keys don't exclude the item.""" - metadata = ClassRegistryEntry( + metadata = RegistryMetadata( class_name="TestClass", class_module="test.module", class_description="A test item", @@ -164,7 +165,7 @@ def test_matches_filters_exclude_nonexistent_key(self): def test_matches_filters_combined_include_and_exclude(self): """Test combined include and exclude filters.""" - metadata = ClassRegistryEntry( + metadata = RegistryMetadata( class_name="TestClass", class_module="test.module", class_description="A test item", From e9ff5f044efe9a864f5e8c58c4468d3589fb89a9 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 2 Jul 2026 18:15:28 -0700 Subject: [PATCH 08/11] Address review: public validate_params, dedup scanners, dead-guard, tests - Add public PyRITInitializer.validate_params() so the registry no longer reaches into the private _validate_params; validate() delegates to it. - Collapse redundant _discovery_path None-guards in InitializerRegistry.__init__ and drop the now-unnecessary assert in _discover. - Rewrite _process_file to reuse _load_module_from_path + _module_defined_initializers, removing a duplicate module scan loop. - Add direct unit tests for create_and_configure (build, param set, unknown param ValueError, unknown name KeyError). Co-authored-by: Copilot App <223556219+Copilot@users.noreply.github.com> --- .../components/initializer_registry.py | 47 ++++------------- pyrit/setup/initializers/pyrit_initializer.py | 26 ++++------ .../registry/test_initializer_registry.py | 51 +++++++++++++++++++ tests/unit/setup/test_pyrit_initializer.py | 6 ++- 4 files changed, 76 insertions(+), 54 deletions(-) diff --git a/pyrit/registry/components/initializer_registry.py b/pyrit/registry/components/initializer_registry.py index 37f1a1c906..9f541eb9a9 100644 --- a/pyrit/registry/components/initializer_registry.py +++ b/pyrit/registry/components/initializer_registry.py @@ -80,17 +80,10 @@ def __init__(self, *, discovery_path: Path | None = None, lazy_discovery: bool = If None, defaults to pyrit/setup/initializers (discovers all). lazy_discovery: If True, discovery is deferred until first access. Defaults to False for backwards compatibility. - - Raises: - ValueError: If the discovery path could not be resolved. """ - self._discovery_path = discovery_path - if self._discovery_path is None: - self._discovery_path = Path(PYRIT_PATH) / "setup" / "initializers" - - # At this point _discovery_path is guaranteed to be a Path - if self._discovery_path is None: - raise ValueError("self._discovery_path is not initialized") + self._discovery_path: Path = ( + discovery_path if discovery_path is not None else Path(PYRIT_PATH) / "setup" / "initializers" + ) self._builtin_names: set[str] = set() super().__init__(lazy_discovery=lazy_discovery) @@ -111,7 +104,6 @@ def is_builtin(self, name: str) -> bool: def _discover(self) -> None: """Discover all initializers from the specified discovery path.""" discovery_path = self._discovery_path - assert discovery_path is not None # Set in __init__ if not discovery_path.exists(): logger.warning(f"Initializers directory not found: {discovery_path}") @@ -135,38 +127,21 @@ def _discover(self) -> None: def _process_file(self, *, file_path: Path, base_class: type, builtin: bool = False) -> None: """ - Process a Python file to extract initializer subclasses. + Load a single Python file and register the initializers it defines. Args: file_path: Path to the Python file to process. base_class: The PyRITInitializer base class. builtin: Whether discovered classes should be marked as built-in. """ - short_name = file_path.stem - try: - spec = importlib.util.spec_from_file_location(f"initializer.{short_name}", file_path) - if not spec or not spec.loader: - return - - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - - for attr_name in dir(module): - attr = getattr(module, attr_name) - if ( - inspect.isclass(attr) - and issubclass(attr, base_class) - and attr is not base_class - and not inspect.isabstract(attr) - ): - self._register_initializer( - initializer_class=attr, - builtin=builtin, - ) - + module = self._load_module_from_path(file_path=file_path, module_name=f"initializer.{file_path.stem}") except Exception as e: - logger.warning(f"Failed to load initializer module {short_name}: {e}") + logger.warning(f"Failed to load initializer module {file_path.stem}: {e}") + return + + for initializer_class in self._module_defined_initializers(module=module, base_class=base_class): + self._register_initializer(initializer_class=initializer_class, builtin=builtin) def _register_initializer( self, @@ -279,7 +254,7 @@ def create_and_configure( instance = self.create_instance(name) if initializer_params: instance.set_params_from_args(args=initializer_params) - instance._validate_params(params=instance.params) + instance.validate_params() return instance @staticmethod diff --git a/pyrit/setup/initializers/pyrit_initializer.py b/pyrit/setup/initializers/pyrit_initializer.py index 94ba4bac4a..5ef8c6be35 100644 --- a/pyrit/setup/initializers/pyrit_initializer.py +++ b/pyrit/setup/initializers/pyrit_initializer.py @@ -171,28 +171,22 @@ def validate(self) -> None: f"{', '.join(missing_vars)}" ) - # Validate configured params - if self.params: - self._validate_params(params=self.params) + self.validate_params() - def _validate_params(self, *, params: dict[str, list[str]]) -> None: + def validate_params(self) -> None: """ - Validate parameters against supported_parameters. + Validate the configured parameters against supported_parameters. - Checks that all provided params are declared in supported_parameters - and that all required params are present. - - Args: - params: The parameters to validate. + Checks that every parameter in ``self.params`` is declared in + ``supported_parameters``. Unlike ``validate()``, this does not check + required environment variables, so it can be used to fail fast on + parameter shape at configuration time — before the environment is set up. Raises: - ValueError: If unknown parameters are provided or required parameters are missing. + ValueError: If unknown parameters are provided. """ - supported = {p.name: p for p in self.supported_parameters} - supported_names = set(supported.keys()) - - # Check for unknown params - unknown = set(params.keys()) - supported_names + supported_names = {p.name for p in self.supported_parameters} + unknown = set(self.params.keys()) - supported_names if unknown: raise ValueError( f"Initializer '{type(self).__name__}' received unknown parameter(s): {', '.join(sorted(unknown))}. " diff --git a/tests/unit/registry/test_initializer_registry.py b/tests/unit/registry/test_initializer_registry.py index a9ba35d961..9e2c3dc52e 100644 --- a/tests/unit/registry/test_initializer_registry.py +++ b/tests/unit/registry/test_initializer_registry.py @@ -7,6 +7,7 @@ import pytest +from pyrit.models.parameter import Parameter from pyrit.registry.components.initializer_registry import PYRIT_PATH, InitializerRegistry from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer @@ -263,3 +264,53 @@ def test_create_from_script_paths_missing_file_raises(lazy_registry): """Test that a missing script path raises FileNotFoundError.""" with pytest.raises(FileNotFoundError): lazy_registry.create_from_script_paths(script_paths=["definitely_missing_script.py"]) + + +# ============================================================================ +# create_and_configure Tests +# ============================================================================ + + +class _ParamInitializer(PyRITInitializer): + """An initializer that accepts a single declared parameter.""" + + @property + def supported_parameters(self) -> list[Parameter]: + return [Parameter(name="mode", description="Operation mode", default="fast")] + + async def initialize_async(self) -> None: + pass + + +def test_create_and_configure_builds_and_sets_params(lazy_registry): + """Test that create_and_configure returns a configured instance with params set.""" + lazy_registry.register_class(_ParamInitializer, name="param_init") + + instance = lazy_registry.create_and_configure("param_init", initializer_params={"mode": "slow"}) + + assert isinstance(instance, _ParamInitializer) + assert instance.params == {"mode": ["slow"]} + + +def test_create_and_configure_without_params_leaves_instance_unconfigured(lazy_registry): + """Test that create_and_configure returns an unconfigured instance when no params are given.""" + lazy_registry.register_class(_ParamInitializer, name="param_init") + + instance = lazy_registry.create_and_configure("param_init") + + assert isinstance(instance, _ParamInitializer) + assert instance.params == {} + + +def test_create_and_configure_unknown_param_raises_value_error(lazy_registry): + """Test that an unknown parameter raises ValueError during configuration.""" + lazy_registry.register_class(_ParamInitializer, name="param_init") + + with pytest.raises(ValueError, match="unknown parameter"): + lazy_registry.create_and_configure("param_init", initializer_params={"bogus": "x"}) + + +def test_create_and_configure_unknown_name_raises_key_error(lazy_registry): + """Test that an unregistered name raises KeyError.""" + with pytest.raises(KeyError): + lazy_registry.create_and_configure("does_not_exist") diff --git a/tests/unit/setup/test_pyrit_initializer.py b/tests/unit/setup/test_pyrit_initializer.py index 7c5a3cd971..ac72bef4fe 100644 --- a/tests/unit/setup/test_pyrit_initializer.py +++ b/tests/unit/setup/test_pyrit_initializer.py @@ -529,8 +529,9 @@ async def initialize_async(self) -> None: pass init = StrictInit() + init.params = {"bogus": ["value"]} with pytest.raises(ValueError, match="unknown parameter"): - init._validate_params(params={"bogus": ["value"]}) + init.validate_params() def test_validate_params_accepts_valid(self) -> None: """Test that valid params pass validation.""" @@ -549,8 +550,9 @@ async def initialize_async(self) -> None: pass init = ValidInit() + init.params = {"key": ["abc"], "mode": ["slow"]} # Should not raise - init._validate_params(params={"key": ["abc"], "mode": ["slow"]}) + init.validate_params() def test_validate_checks_params_on_instance(self) -> None: """Test that validate() checks self.params.""" From 6d4d9198dfcb252483ea51b8a2b1ab3cf39a8197 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 2 Jul 2026 19:15:08 -0700 Subject: [PATCH 09/11] Fix flaky network PDF test and apply ruff format - test_filename_extension_existing_pdf downloaded fake_CV.pdf from raw.githubusercontent.com, so it failed in CI without network. Read the PDF from the local datasets directory instead (no network in unit tests). - Apply ruff format (collapse over-wrapped lines) to satisfy pre-commit. Co-authored-by: Copilot App <223556219+Copilot@users.noreply.github.com> --- pyrit/registry/components/initializer_registry.py | 4 +--- tests/unit/prompt_converter/test_pdf_converter.py | 8 ++++---- tests/unit/registry/test_initializer_registry.py | 4 +--- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/pyrit/registry/components/initializer_registry.py b/pyrit/registry/components/initializer_registry.py index 9f541eb9a9..18e955651f 100644 --- a/pyrit/registry/components/initializer_registry.py +++ b/pyrit/registry/components/initializer_registry.py @@ -225,9 +225,7 @@ def _build_metadata(self, name: str, cls: type[PyRITInitializer]) -> Initializer required_env_vars=(), ) - def create_and_configure( - self, name: str, *, initializer_params: dict[str, Any] | None = None - ) -> PyRITInitializer: + def create_and_configure(self, name: str, *, initializer_params: dict[str, Any] | None = None) -> PyRITInitializer: """ Build and parameterize an initializer in one call. diff --git a/tests/unit/prompt_converter/test_pdf_converter.py b/tests/unit/prompt_converter/test_pdf_converter.py index 9ba93066d4..541e57dfe6 100644 --- a/tests/unit/prompt_converter/test_pdf_converter.py +++ b/tests/unit/prompt_converter/test_pdf_converter.py @@ -469,14 +469,14 @@ async def test_filename_extension_default(sqlite_instance): async def test_filename_extension_existing_pdf(sqlite_instance): + import shutil import tempfile - import requests + from pyrit.common.path import DATASETS_PATH - url = "https://raw.githubusercontent.com/microsoft/PyRIT/main/pyrit/datasets/prompt_converters/pdf_converters/fake_CV.pdf" + source_pdf = DATASETS_PATH / "prompt_converters" / "pdf_converters" / "fake_CV.pdf" with tempfile.NamedTemporaryFile(delete=False, suffix=".tmp") as tmp_file: - response = requests.get(url) - tmp_file.write(response.content) + shutil.copyfile(source_pdf, tmp_file.name) cv_pdf_path = Path(tmp_file.name) diff --git a/tests/unit/registry/test_initializer_registry.py b/tests/unit/registry/test_initializer_registry.py index 9e2c3dc52e..ae09aeedb5 100644 --- a/tests/unit/registry/test_initializer_registry.py +++ b/tests/unit/registry/test_initializer_registry.py @@ -219,9 +219,7 @@ def _write_initializer_script(directory: Path, filename: str, *class_names: str) body = "from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer\n\n" for class_name in class_names: body += ( - f"class {class_name}(PyRITInitializer):\n" - f" async def initialize_async(self) -> None:\n" - f" pass\n\n" + f"class {class_name}(PyRITInitializer):\n async def initialize_async(self) -> None:\n pass\n\n" ) script_path = directory / filename script_path.write_text(body) From a119a29a2975384ff185453f9c514eca906b221e Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 2 Jul 2026 22:20:35 -0700 Subject: [PATCH 10/11] TEST: raise initializer registry coverage to 99% Add unit tests for discovery, registration, and metadata edge cases. Co-authored-by: Copilot App <223556219+Copilot@users.noreply.github.com> --- .../registry/test_initializer_registry.py | 152 ++++++++++++++++++ 1 file changed, 152 insertions(+) diff --git a/tests/unit/registry/test_initializer_registry.py b/tests/unit/registry/test_initializer_registry.py index ae09aeedb5..416e32d1b2 100644 --- a/tests/unit/registry/test_initializer_registry.py +++ b/tests/unit/registry/test_initializer_registry.py @@ -312,3 +312,155 @@ def test_create_and_configure_unknown_name_raises_key_error(lazy_registry): """Test that an unregistered name raises KeyError.""" with pytest.raises(KeyError): lazy_registry.create_and_configure("does_not_exist") + + +# ============================================================================ +# Discovery / registration edge-case Tests +# ============================================================================ + +_SOLO_SCRIPT = ( + "from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer\n\n" + "class SoloInitializer(PyRITInitializer):\n" + " async def initialize_async(self) -> None:\n" + " pass\n" +) + + +def test_discover_directory_registers_and_lists_metadata(): + """Test that a directory discovery path scans, registers, and builds metadata for initializers.""" + with tempfile.TemporaryDirectory() as temp_dir: + (Path(temp_dir) / "solo.py").write_text(_SOLO_SCRIPT) + + registry = InitializerRegistry(discovery_path=Path(temp_dir), lazy_discovery=False) + + assert "solo" in registry + metadata = registry.get_all_registered_class_metadata() + assert any(m.registry_name == "solo" for m in metadata) + + +def test_discover_single_file_registers_builtin(): + """Test that a discovery path pointing at a single file registers it as built-in.""" + with tempfile.TemporaryDirectory() as temp_dir: + script = Path(temp_dir) / "solo.py" + script.write_text(_SOLO_SCRIPT) + + registry = InitializerRegistry(discovery_path=script, lazy_discovery=False) + + assert "solo" in registry + assert registry.is_builtin("solo") is True + assert registry.is_builtin("not_registered") is False + + +def test_discover_missing_path_registers_nothing(): + """Test that a non-existent discovery path logs a warning and registers nothing.""" + missing = Path(tempfile.gettempdir()) / "pyrit_missing_initializers_dir_xyz" + registry = InitializerRegistry(discovery_path=missing, lazy_discovery=False) + + assert registry.get_class_names() == [] + + +def test_discover_single_file_load_failure_registers_nothing(): + """Test that a file that fails to import is skipped without raising.""" + with tempfile.TemporaryDirectory() as temp_dir: + bad = Path(temp_dir) / "bad.py" + bad.write_text("def bad syntax(:\n") + + registry = InitializerRegistry(discovery_path=bad, lazy_discovery=False) + + assert registry.get_class_names() == [] + + +def test_register_initializer_collision_keeps_first(lazy_registry): + """Test that a registry-name collision keeps the first registration.""" + + class DupInitializer(PyRITInitializer): + async def initialize_async(self) -> None: + pass + + lazy_registry._register_initializer(initializer_class=DupInitializer, builtin=True) + lazy_registry._register_initializer(initializer_class=DupInitializer) + + assert lazy_registry.get_class("dup") is DupInitializer + + +def test_register_initializer_swallows_registration_errors(lazy_registry): + """Test that a failure inside register_class is logged and swallowed.""" + + class BadInitializer(PyRITInitializer): + async def initialize_async(self) -> None: + pass + + with patch.object(lazy_registry, "register_class", side_effect=RuntimeError("boom")): + lazy_registry._register_initializer(initializer_class=BadInitializer) + + assert "bad" not in lazy_registry + + +def test_build_metadata_instantiation_failure_returns_fallback(lazy_registry): + """Test that _build_metadata falls back when the initializer cannot be instantiated.""" + + class ExplodingInitializer(PyRITInitializer): + """Exploding.""" + + def __init__(self) -> None: + raise RuntimeError("cannot construct") + + async def initialize_async(self) -> None: + pass + + metadata = lazy_registry._build_metadata("exploding", ExplodingInitializer) + + assert metadata.class_description == "Error loading initializer metadata" + assert metadata.required_env_vars == () + + +def test_load_module_from_path_no_spec_raises(): + """Test that _load_module_from_path raises when an import spec cannot be created.""" + with patch( + "pyrit.registry.components.initializer_registry.importlib.util.spec_from_file_location", + return_value=None, + ): + with pytest.raises(ValueError, match="Could not load initializer script"): + InitializerRegistry._load_module_from_path(file_path=Path("nope.py"), module_name="nope") + + +def test_create_from_script_paths_instantiation_failure_raises(lazy_registry): + """Test that a script whose only initializer fails to instantiate raises ValueError.""" + script = ( + "from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer\n\n" + "class BoomInitializer(PyRITInitializer):\n" + " def __init__(self):\n" + " raise RuntimeError('boom')\n" + " async def initialize_async(self) -> None:\n" + " pass\n" + ) + with tempfile.TemporaryDirectory() as temp_dir: + path = Path(temp_dir) / "boom.py" + path.write_text(script) + + with pytest.raises(ValueError, match="must contain at least one"): + lazy_registry.create_from_script_paths(script_paths=[path]) + + +def test_register_from_content_write_failure_raises(lazy_registry): + """Test that an OSError while writing the script surfaces as ValueError.""" + with patch.object(InitializerRegistry, "_get_custom_scripts_dir", return_value=Path(tempfile.mkdtemp())): + with patch("pathlib.Path.write_text", side_effect=OSError("disk full")): + with pytest.raises(ValueError, match="Failed to write initializer script"): + lazy_registry.register_from_content(name="write_fail", script_content=_VALID_SCRIPT) + + +def test_unregister_and_cleanup_unknown_name_raises(lazy_registry): + """Test that unregistering an unknown, non-built-in name raises KeyError.""" + with pytest.raises(KeyError, match="not found in registry"): + lazy_registry.unregister_and_cleanup("nonexistent") + + +def test_get_custom_scripts_dir_creates_directory(): + """Test that _get_custom_scripts_dir returns and creates the managed directory.""" + with tempfile.TemporaryDirectory() as temp_dir: + with patch("pyrit.common.path.CONFIGURATION_DIRECTORY_PATH", Path(temp_dir)): + result = InitializerRegistry._get_custom_scripts_dir() + + assert result == Path(temp_dir) / "custom_initializers" + assert result.exists() From 2d204547ef258741bea7cc4201b0c6f2ac2dc63e Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 2 Jul 2026 22:28:12 -0700 Subject: [PATCH 11/11] MAINT: fix ty type errors in initializer discovery helpers Type the discovery helper base_class/return as type[PyRITInitializer]. Co-authored-by: Copilot App <223556219+Copilot@users.noreply.github.com> --- pyrit/registry/components/initializer_registry.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pyrit/registry/components/initializer_registry.py b/pyrit/registry/components/initializer_registry.py index 18e955651f..4cc9cfa5cb 100644 --- a/pyrit/registry/components/initializer_registry.py +++ b/pyrit/registry/components/initializer_registry.py @@ -125,7 +125,7 @@ def _discover(self) -> None: builtin=True, ) - def _process_file(self, *, file_path: Path, base_class: type, builtin: bool = False) -> None: + def _process_file(self, *, file_path: Path, base_class: type[PyRITInitializer], builtin: bool = False) -> None: """ Load a single Python file and register the initializers it defines. @@ -278,7 +278,9 @@ def _load_module_from_path(*, file_path: Path, module_name: str) -> ModuleType: return module @staticmethod - def _module_defined_initializers(*, module: ModuleType, base_class: type) -> list[type]: + def _module_defined_initializers( + *, module: ModuleType, base_class: type[PyRITInitializer] + ) -> list[type[PyRITInitializer]]: """ Find concrete ``PyRITInitializer`` subclasses defined in *module*.