From f5f7f4207a8cca1efba791a2bdb27fc2e5139817 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 2 Jul 2026 15:53:04 -0700 Subject: [PATCH 1/8] =?UTF-8?q?FEAT:=20Phase=204.5=20=E2=80=94=20route=20t?= =?UTF-8?q?arget=20service=20and=20frontend=20through=20TargetRegistry?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mirror the Phase 3.5 converter work for targets: push target-specific auth/endpoint-validation out of the backend target_service onto the target classes, add a registry-driven GET /targets/catalog, and wire the create-target dialog to it. - Move endpoint-trust checks onto the targets via shared pyrit.auth helpers (is_azure_openai_endpoint / is_azure_ml_endpoint). OpenAITarget's auto-Entra fallback now uses the strict recognized-host allowlist instead of a loose "azure" substring check (behavior change: unrecognized "azure"-ish endpoints now raise instead of silently minting a token). AzureMLChatTarget gains the Entra auto-fallback the service used to special-case. - Add declarative auth facts to PromptTarget (supported_auth_modes, get_api_key_environment_variable) for the catalog. - RoundRobinTarget owns its own multi-target dedup (object identity). - create_target_async is now generic: resolve class via the registry catalog and build with create_instance; for entra it just omits the api_key. No endpoint or per-class auth branching remains in the service. - Add TargetCatalogEntry/Response + list_target_catalog_async + /targets/catalog. - TargetInstance embeds an additive serialized identifier. - Frontend: targetsApi.listTargetCatalog(); dialog derives its type list + Entra support from the catalog and sends RoundRobin inner targets as `targets`. Co-authored-by: Copilot App <223556219+Copilot@users.noreply.github.com> --- .../Config/CreateTargetDialog.test.tsx | 21 + .../components/Config/CreateTargetDialog.tsx | 106 +++- frontend/src/services/api.ts | 6 + frontend/src/types/index.ts | 12 + pyrit/auth/__init__.py | 4 + pyrit/auth/azure_auth.py | 49 ++ pyrit/backend/mappers/target_mappers.py | 1 + pyrit/backend/models/targets.py | 32 ++ pyrit/backend/routes/targets.py | 32 +- pyrit/backend/services/target_service.py | 377 ++++----------- pyrit/models/catalog/target.py | 9 + pyrit/prompt_target/azure_ml_chat_target.py | 54 ++- pyrit/prompt_target/common/prompt_target.py | 32 +- pyrit/prompt_target/openai/openai_target.py | 42 +- pyrit/prompt_target/round_robin_target.py | 77 ++- tests/unit/backend/test_api_routes.py | 25 + tests/unit/backend/test_target_service.py | 452 +++++++----------- .../target/test_azure_ml_chat_target.py | 29 ++ .../target/test_openai_chat_target.py | 38 ++ 19 files changed, 776 insertions(+), 622 deletions(-) diff --git a/frontend/src/components/Config/CreateTargetDialog.test.tsx b/frontend/src/components/Config/CreateTargetDialog.test.tsx index c825c32ac5..f177769302 100644 --- a/frontend/src/components/Config/CreateTargetDialog.test.tsx +++ b/frontend/src/components/Config/CreateTargetDialog.test.tsx @@ -8,11 +8,24 @@ import { targetsApi } from "@/services/api"; jest.mock("@/services/api", () => ({ targetsApi: { createTarget: jest.fn(), + listTargetCatalog: jest.fn(), + listTargets: jest.fn(), }, })); const mockedTargetsApi = targetsApi as jest.Mocked; +// Representative target catalog covering the types the dialog renders. Mirrors +// the shape returned by GET /targets/catalog. +const TARGET_CATALOG = { + items: [ + { target_type: "OpenAIChatTarget", parameters: [], supported_auth_modes: ["api_key", "entra"], api_key_env_var: "OPENAI_CHAT_KEY" }, + { target_type: "OpenAIResponseTarget", parameters: [], supported_auth_modes: ["api_key", "entra"], api_key_env_var: "OPENAI_RESPONSES_KEY" }, + { target_type: "AzureMLChatTarget", parameters: [], supported_auth_modes: ["api_key", "entra"], api_key_env_var: "AZURE_ML_KEY" }, + { target_type: "RoundRobinTarget", parameters: [], supported_auth_modes: ["api_key"], api_key_env_var: null }, + ], +}; + const TestWrapper: React.FC<{ children: React.ReactNode }> = ({ children, }) => {children}; @@ -104,6 +117,13 @@ describe("CreateTargetDialog", () => { beforeEach(() => { jest.clearAllMocks(); + mockedTargetsApi.listTargetCatalog.mockResolvedValue( + TARGET_CATALOG as unknown as Awaited>, + ); + mockedTargetsApi.listTargets.mockResolvedValue({ + items: [], + pagination: { limit: 200, has_more: false, next_cursor: null, prev_cursor: null }, + } as unknown as Awaited>); }); it("should render dialog when open", () => { @@ -1199,6 +1219,7 @@ describe("CreateTargetDialog", () => { ); const call = mockedTargetsApi.createTarget.mock.calls[0][0]; expect(call.type).toBe("RoundRobinTarget"); + expect(call.params?.targets).toEqual(["a", "b"]); expect(call.params?.weights).toEqual([7, 42]); }, 30000); diff --git a/frontend/src/components/Config/CreateTargetDialog.tsx b/frontend/src/components/Config/CreateTargetDialog.tsx index 2c39e1666a..e236a83c7d 100644 --- a/frontend/src/components/Config/CreateTargetDialog.tsx +++ b/frontend/src/components/Config/CreateTargetDialog.tsx @@ -23,30 +23,48 @@ import { import { DeleteRegular } from '@fluentui/react-icons' import { targetsApi } from '@/services/api' import { toApiError } from '@/services/errors' -import type { TargetInstance } from '@/types' +import type { TargetInstance, TargetCatalogEntry } from '@/types' import { useCreateTargetDialogStyles } from './CreateTargetDialog.styles' import { MAX_WEIGHT, parseWeight } from './weightValidation' -interface TargetTypeConfig { - readonly kind: 'openai' | 'azureml' | 'roundrobin' - readonly supportsEntra: boolean -} - -const TARGET_TYPE_CONFIG: Record = { - OpenAIChatTarget: { kind: 'openai', supportsEntra: true }, - OpenAICompletionTarget: { kind: 'openai', supportsEntra: true }, - OpenAIImageTarget: { kind: 'openai', supportsEntra: true }, - OpenAIVideoTarget: { kind: 'openai', supportsEntra: true }, - OpenAITTSTarget: { kind: 'openai', supportsEntra: true }, - OpenAIResponseTarget: { kind: 'openai', supportsEntra: true }, - AzureMLChatTarget: { kind: 'azureml', supportsEntra: true }, - RoundRobinTarget: { kind: 'roundrobin', supportsEntra: false }, +/** + * Form shape for each target type the dialog knows how to render. + * + * The dialog renders bespoke, type-specific forms (endpoint/model for OpenAI, + * extra sampling params for Azure ML, an inner-target picker for RoundRobin), + * so this map declares *which* types are renderable and *how*. The list of + * available types and their auth flags come from the backend catalog + * (`/targets/catalog`); this map only governs the form layout. Types the + * backend offers but that aren't in this map are simply not shown, and types in + * this map that the backend doesn't offer fall back to being listed anyway + * (e.g. when the catalog fetch fails). + */ +type TargetFormShape = 'openai' | 'azureml' | 'roundrobin' + +const TARGET_FORM_SHAPES: Record = { + OpenAIChatTarget: 'openai', + OpenAICompletionTarget: 'openai', + OpenAIImageTarget: 'openai', + OpenAIVideoTarget: 'openai', + OpenAITTSTarget: 'openai', + OpenAIResponseTarget: 'openai', + AzureMLChatTarget: 'azureml', + RoundRobinTarget: 'roundrobin', } -const SUPPORTED_TARGET_TYPES = Object.keys(TARGET_TYPE_CONFIG) +const RENDERABLE_TARGET_TYPES = Object.keys(TARGET_FORM_SHAPES) type AuthMode = 'api_key' | 'entra' +/** + * Fallback for whether a target type supports Entra auth when the backend + * catalog hasn't loaded (or the fetch failed). Once the catalog is available it + * is authoritative; this only keeps the form usable offline / mid-load. + */ +function defaultSupportsEntra(shape: TargetFormShape | undefined): boolean { + return shape === 'openai' || shape === 'azureml' +} + // Mirrors backend's hostname-suffix check (list in target_service.py). // The backend still does the check and will reject unsupported endpoints, but this allows us to show a warning in the UI if the user selects Microsoft Entra authentication with a non-Azure OpenAI endpoint. const AZURE_OPENAI_HOSTNAME_SUFFIXES = [ @@ -161,11 +179,47 @@ export default function CreateTargetDialog({ open, onClose, onCreated, existingT // Targets the user has picked for the RoundRobinTarget, with their weights. const [selectedInnerTargets, setSelectedInnerTargets] = useState([]) - const targetConfig = TARGET_TYPE_CONFIG[targetType] - const isRoundRobin = targetConfig?.kind === 'roundrobin' - const isAzureML = targetConfig?.kind === 'azureml' - const isOpenAi = targetConfig?.kind === 'openai' - const supportsEntra = targetConfig?.supportsEntra ?? false + // --- Catalog state --- + // Available target types + their auth facts, fetched from the backend registry. + const [catalogEntries, setCatalogEntries] = useState([]) + const catalogByType = useMemo( + () => new Map(catalogEntries.map((entry) => [entry.target_type, entry])), + [catalogEntries], + ) + + // Fetch the target catalog once when the dialog opens. The backend is the + // authority on which types exist and which auth modes they support. + useEffect(() => { + if (!open) return + let cancelled = false + targetsApi.listTargetCatalog() + .then((res) => { + if (!cancelled) setCatalogEntries(res.items) + }) + .catch(() => { + // Ignore fetch errors — fall back to the locally-known renderable types. + }) + return () => { cancelled = true } + }, [open]) + + // The types offered in the dropdown: catalog types the dialog can render, + // preserving catalog order. Fall back to the locally-known types when the + // catalog hasn't loaded (or the fetch failed) so the form stays usable. + const targetTypeOptions = useMemo(() => { + const fromCatalog = catalogEntries + .map((entry) => entry.target_type) + .filter((type) => type in TARGET_FORM_SHAPES) + return fromCatalog.length > 0 ? fromCatalog : RENDERABLE_TARGET_TYPES + }, [catalogEntries]) + + const formShape = TARGET_FORM_SHAPES[targetType] + const isRoundRobin = formShape === 'roundrobin' + const isAzureML = formShape === 'azureml' + const isOpenAi = formShape === 'openai' + const catalogEntry = catalogByType.get(targetType) + const supportsEntra = catalogEntry + ? catalogEntry.supported_auth_modes.includes('entra') + : defaultSupportsEntra(formShape) const showAuthField = targetType !== '' && supportsEntra const isEntra = showAuthField && authMode === 'entra' const entraEndpointError: string | null = (() => { @@ -301,7 +355,7 @@ export default function CreateTargetDialog({ open, onClose, onCreated, existingT await targetsApi.createTarget({ type: 'RoundRobinTarget', params: { - target_registry_names: selectedInnerTargets.map((t) => t.registryName), + targets: selectedInnerTargets.map((t) => t.registryName), weights: parsedWeights, }, }) @@ -390,13 +444,17 @@ export default function CreateTargetDialog({ open, onClose, onCreated, existingT onChange={(_, data) => { const next = data.value setTargetType(next) - if (!(TARGET_TYPE_CONFIG[next]?.supportsEntra ?? false)) { + const nextEntry = catalogByType.get(next) + const nextSupportsEntra = nextEntry + ? nextEntry.supported_auth_modes.includes('entra') + : defaultSupportsEntra(TARGET_FORM_SHAPES[next]) + if (!nextSupportsEntra) { setAuthMode('api_key') } }} > - {SUPPORTED_TARGET_TYPES.map((type) => ( + {targetTypeOptions.map((type) => ( ))} diff --git a/frontend/src/services/api.ts b/frontend/src/services/api.ts index 3c04828cb0..5dfb44d8f0 100644 --- a/frontend/src/services/api.ts +++ b/frontend/src/services/api.ts @@ -5,6 +5,7 @@ import { getApiScopes } from '../auth/msalConfig' import type { TargetInstance, TargetListResponse, + TargetCatalogResponse, ConverterCatalogResponse, ConverterInstance, ConverterListResponse, @@ -146,6 +147,11 @@ export const versionApi = { } export const targetsApi = { + listTargetCatalog: async (): Promise => { + const response = await apiClient.get('/targets/catalog') + return response.data + }, + listTargets: async (limit = 50, cursor?: string): Promise => { const params: Record = { limit } if (cursor) params.cursor = cursor diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 236a5e49d9..db521405ee 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -143,6 +143,18 @@ export interface ConverterCatalogResponse { items: ConverterCatalogEntry[] } +export interface TargetCatalogEntry { + target_type: string + parameters: Parameter[] + supported_auth_modes: ('api_key' | 'entra')[] + api_key_env_var?: string | null + description?: string | null +} + +export interface TargetCatalogResponse { + items: TargetCatalogEntry[] +} + // --- Attacks --- export interface TargetInfo { diff --git a/pyrit/auth/__init__.py b/pyrit/auth/__init__.py index 3b17fcef61..ac37e89abf 100644 --- a/pyrit/auth/__init__.py +++ b/pyrit/auth/__init__.py @@ -15,6 +15,8 @@ get_azure_openai_auth, get_azure_token_provider, get_default_azure_scope, + is_azure_ml_endpoint, + is_azure_openai_endpoint, ) from pyrit.auth.azure_storage_auth import AzureStorageAuth from pyrit.auth.copilot_authenticator import CopilotAuthenticator @@ -33,4 +35,6 @@ "get_azure_async_token_provider", "get_default_azure_scope", "get_azure_openai_auth", + "is_azure_ml_endpoint", + "is_azure_openai_endpoint", ] diff --git a/pyrit/auth/azure_auth.py b/pyrit/auth/azure_auth.py index 8a6131d3a4..411f17504b 100644 --- a/pyrit/auth/azure_auth.py +++ b/pyrit/auth/azure_auth.py @@ -7,6 +7,7 @@ import logging import time from typing import TYPE_CHECKING, Any, cast +from urllib.parse import urlparse import msal from azure.core.credentials import AccessToken @@ -32,6 +33,54 @@ logger = logging.getLogger(__name__) +# Recognised Azure OpenAI / AI Foundry hostname suffixes. Used for strict +# endpoint validation before an Entra ID bearer token is minted, so a token is +# only ever issued for a known Microsoft-operated endpoint (a substring check +# such as ``"azure" in endpoint`` is not sufficient — anyone can host a domain +# that merely contains "azure"). +_AZURE_OPENAI_HOSTNAME_SUFFIXES = ( + ".openai.azure.com", + ".ai.azure.com", + ".services.ai.azure.com", + ".cognitiveservices.azure.com", +) + +# Recognised Azure Machine Learning managed online endpoint hostname suffixes. +_AZURE_ML_HOSTNAME_SUFFIXES = (".inference.ml.azure.com",) + + +def is_azure_openai_endpoint(endpoint: str | None) -> bool: + """ + Return True if ``endpoint`` resolves to a known Azure OpenAI / AI Foundry host. + + Uses a strict hostname-suffix check (not a substring search) so an Entra ID + token is only minted for a Microsoft-operated endpoint. + + Args: + endpoint (str | None): The endpoint URL to validate. + + Returns: + bool: True if the endpoint's hostname ends with a recognised Azure suffix. + """ + hostname = (urlparse(endpoint or "").hostname or "").lower() + return any(hostname.endswith(suffix) for suffix in _AZURE_OPENAI_HOSTNAME_SUFFIXES) + + +def is_azure_ml_endpoint(endpoint: str | None) -> bool: + """ + Return True if ``endpoint`` resolves to a known AML managed online host. + + Uses a strict hostname-suffix check (not a substring search). + + Args: + endpoint (str | None): The endpoint URL to validate. + + Returns: + bool: True if the endpoint's hostname ends with a recognised AML suffix. + """ + hostname = (urlparse(endpoint or "").hostname or "").lower() + return any(hostname.endswith(suffix) for suffix in _AZURE_ML_HOSTNAME_SUFFIXES) + class TokenProviderCredential: """ diff --git a/pyrit/backend/mappers/target_mappers.py b/pyrit/backend/mappers/target_mappers.py index efe89e13a4..1309dfbf3d 100644 --- a/pyrit/backend/mappers/target_mappers.py +++ b/pyrit/backend/mappers/target_mappers.py @@ -117,6 +117,7 @@ def target_object_to_instance(target_registry_name: str, target_obj: PromptTarge target_specific_params=combined_specific, inner_targets=inner_targets, identifier_hash=target_identifier.hash, + identifier=target_identifier, ) diff --git a/pyrit/backend/models/targets.py b/pyrit/backend/models/targets.py index e31d1c362a..359c373415 100644 --- a/pyrit/backend/models/targets.py +++ b/pyrit/backend/models/targets.py @@ -14,14 +14,46 @@ from pydantic import BaseModel, Field from pyrit.backend.models.common import PaginationInfo +from pyrit.models import Parameter from pyrit.models.catalog.target import TargetInstance __all__ = [ "CreateTargetRequest", + "TargetCatalogEntry", + "TargetCatalogResponse", "TargetListResponse", ] +def _default_auth_modes() -> list[Literal["api_key", "entra"]]: + return ["api_key"] + + +class TargetCatalogEntry(BaseModel): + """A target type available from the backend registry.""" + + target_type: str = Field(..., description="Target class name (e.g., 'OpenAIChatTarget')") + parameters: list[Parameter] = Field( + default_factory=list, + description="Constructor parameters for dynamic form generation", + ) + supported_auth_modes: list[Literal["api_key", "entra"]] = Field( + default_factory=_default_auth_modes, + description="Authentication modes this target type supports", + ) + api_key_env_var: str | None = Field( + None, + description="Environment variable name that supplies this target's API key, if any", + ) + description: str | None = Field(None, description="Short description of the target from its docstring") + + +class TargetCatalogResponse(BaseModel): + """Response for listing available target types from the registry.""" + + items: list[TargetCatalogEntry] = Field(..., description="List of available target types") + + class TargetListResponse(BaseModel): """Response for listing target instances.""" diff --git a/pyrit/backend/routes/targets.py b/pyrit/backend/routes/targets.py index c55a6d9ff2..3bac8a23b7 100644 --- a/pyrit/backend/routes/targets.py +++ b/pyrit/backend/routes/targets.py @@ -13,6 +13,7 @@ from pyrit.backend.models.common import ProblemDetail from pyrit.backend.models.targets import ( CreateTargetRequest, + TargetCatalogResponse, TargetListResponse, ) from pyrit.backend.services.target_service import get_target_service @@ -44,15 +45,38 @@ async def list_targets( # pyrit-async-suffix-exempt return await service.list_targets_async(limit=limit, cursor=cursor) +@router.get( + "/catalog", + response_model=TargetCatalogResponse, + responses={ + 500: {"model": ProblemDetail, "description": "Internal server error"}, + }, +) +async def list_target_catalog() -> TargetCatalogResponse: # pyrit-async-suffix-exempt + """ + List all available target types from the backend target registry. + + Returns: + TargetCatalogResponse: List of available target types. + """ + service = get_target_service() + return await service.list_target_catalog_async() + + @router.post( "", response_model=TargetInstance, status_code=status.HTTP_201_CREATED, responses={ - 400: {"model": ProblemDetail, "description": "Invalid target type or parameters"}, + 400: { + "model": ProblemDetail, + "description": "Invalid target type or parameters", + }, }, ) -async def create_target(request: CreateTargetRequest) -> TargetInstance: # pyrit-async-suffix-exempt +async def create_target( + request: CreateTargetRequest, +) -> TargetInstance: # pyrit-async-suffix-exempt """ Create a new target instance. @@ -87,7 +111,9 @@ async def create_target(request: CreateTargetRequest) -> TargetInstance: # pyri 404: {"model": ProblemDetail, "description": "Target not found"}, }, ) -async def get_target(target_registry_name: str) -> TargetInstance: # pyrit-async-suffix-exempt +async def get_target( + target_registry_name: str, +) -> TargetInstance: # pyrit-async-suffix-exempt """ Get a target instance by registry name. diff --git a/pyrit/backend/services/target_service.py b/pyrit/backend/services/target_service.py index 5e23bd9bad..e1819c7e00 100644 --- a/pyrit/backend/services/target_service.py +++ b/pyrit/backend/services/target_service.py @@ -15,156 +15,37 @@ import logging import os from functools import lru_cache -from typing import Any, ClassVar -from urllib.parse import urlparse +from typing import Any -from pyrit import prompt_target -from pyrit.auth import get_azure_async_token_provider, get_azure_openai_auth from pyrit.backend.mappers.target_mappers import target_object_to_instance from pyrit.backend.models.common import PaginationInfo from pyrit.backend.models.targets import ( CreateTargetRequest, + TargetCatalogEntry, + TargetCatalogResponse, TargetListResponse, ) from pyrit.models.catalog.target import TargetInstance from pyrit.prompt_target import PromptTarget -from pyrit.prompt_target.azure_ml_chat_target import AzureMLChatTarget -from pyrit.prompt_target.openai.openai_target import OpenAITarget -from pyrit.prompt_target.round_robin_target import RoundRobinTarget from pyrit.registry import TargetRegistry logger = logging.getLogger(__name__) -# Recognised Azure OpenAI / AI Foundry hostname suffixes. Used for strict -# endpoint validation when Entra ID auth is requested, so a bearer token is -# only ever issued for a known Microsoft-operated endpoint. -_AZURE_OPENAI_HOSTNAME_SUFFIXES = ( - ".openai.azure.com", - ".ai.azure.com", - ".services.ai.azure.com", - ".cognitiveservices.azure.com", -) - -# Recognised Azure Machine Learning managed online endpoint hostname suffixes. -# Used for the same strict endpoint validation when issuing Entra ID tokens -# against an AML scope. -_AZURE_ML_HOSTNAME_SUFFIXES = (".inference.ml.azure.com",) - - -def _is_azure_openai_endpoint(endpoint: str) -> bool: - """ - Return True if ``endpoint`` resolves to a known Azure OpenAI / AI Foundry host. - Uses a strict hostname-suffix check (not a substring search). - - Args: - endpoint (str): The endpoint URL to validate. - - Returns: - bool: True if the endpoint's hostname ends with a recognised Azure suffix; - False otherwise - """ - hostname = (urlparse(endpoint).hostname or "").lower() - return any(hostname.endswith(suffix) for suffix in _AZURE_OPENAI_HOSTNAME_SUFFIXES) - - -def _is_azure_ml_endpoint(endpoint: str) -> bool: - """ - Return True if ``endpoint`` resolves to a known AML managed host. - Uses a strict hostname-suffix check (not a substring search). - - Args: - endpoint (str): The endpoint URL to validate. - - Returns: - bool: True if the endpoint's hostname ends with a recognised AML suffix; - False otherwise. - """ - hostname = (urlparse(endpoint).hostname or "").lower() - return any(hostname.endswith(suffix) for suffix in _AZURE_ML_HOSTNAME_SUFFIXES) - - -def _resolve_api_key_env_var(target_class: type) -> str | None: - """ - Return the api_key environment variable name for a target class. - - Args: - target_class (type): The target class to inspect. - - Returns: - str | None: The env var name, or None if the class does not declare one. - """ - if issubclass(target_class, AzureMLChatTarget): - env_var = getattr(target_class, "api_key_environment_variable", None) - return env_var if isinstance(env_var, str) and env_var else None - if issubclass(target_class, OpenAITarget): - try: - instance = target_class.__new__(target_class) - instance._set_openai_env_configuration_vars() - except Exception: - return None - env_var = getattr(instance, "api_key_environment_variable", None) - return env_var if isinstance(env_var, str) and env_var else None - return None - - -def _build_target_class_registry() -> dict[str, type]: - """ - Build a registry mapping target class names to their classes. - - Uses the prompt_target module's __all__ to discover all available targets. - - Returns: - Dict mapping class name (str) to class (type). - """ - registry: dict[str, type] = {} - for name in prompt_target.__all__: - cls = getattr(prompt_target, name, None) - if cls is not None and isinstance(cls, type) and issubclass(cls, PromptTarget): - registry[name] = cls - return registry - - -# Module-level class registry (built once on import) -_TARGET_CLASS_REGISTRY: dict[str, type] = _build_target_class_registry() - class TargetService: """ Service for managing target instances. - Uses TargetRegistry as the sole source of truth. - API metadata is derived from the target objects' identifiers. + Uses TargetRegistry as the sole source of truth. Class discovery, + construction (incl. param coercion and reference resolution), and endpoint + validation are all owned by the registry and the target classes; this + service only orchestrates the request → registry hand-off. """ - # Scope for Azure Machine Learning managed online endpoints. - _AZURE_ML_SCOPE: ClassVar[str] = "https://ml.azure.com/.default" - def __init__(self) -> None: """Initialize the target service.""" self._registry = TargetRegistry.get_registry_singleton() - def _get_target_class(self, *, target_type: str) -> type: - """ - Get the target class for a given type name. - - Looks up the class in the module-level target class registry. - - Args: - target_type: The exact class name of the target (e.g., 'TextTarget'). - - Returns: - The target class. - - Raises: - ValueError: If the target type is not found. - """ - cls = _TARGET_CLASS_REGISTRY.get(target_type) - if cls is None: - raise ValueError( - f"Target type '{target_type}' not found. Available types: {sorted(_TARGET_CLASS_REGISTRY.keys())}" - ) - return cls - def _build_instance_from_object(self, *, target_registry_name: str, target_obj: Any) -> TargetInstance: """ Build a TargetInstance from a registry object. @@ -198,7 +79,12 @@ async def list_targets_async( next_cursor = page[-1].target_registry_name if has_more and page else None return TargetListResponse( items=page, - pagination=PaginationInfo(limit=limit, has_more=has_more, next_cursor=next_cursor, prev_cursor=cursor), + pagination=PaginationInfo( + limit=limit, + has_more=has_more, + next_cursor=next_cursor, + prev_cursor=cursor, + ), ) @staticmethod @@ -241,12 +127,49 @@ def get_target_object(self, *, target_registry_name: str) -> Any | None: """ return self._registry.instances.get(target_registry_name) + async def list_target_catalog_async(self) -> TargetCatalogResponse: + """ + List all available target types from the target class registry. + + Returns every constructible target with its derived constructor + parameters and declarative auth facts (which auth modes it supports and + its api-key env var). Deciding which entries to surface to a user is a + presentation concern owned by the caller (e.g. the frontend), not this + service. + + Returns: + TargetCatalogResponse containing all available target classes. + """ + items: list[TargetCatalogEntry] = [] + for metadata in self._registry.get_all_registered_class_metadata(): + target_cls = self._registry.get_class(metadata.class_name) + items.append( + TargetCatalogEntry( + target_type=metadata.class_name, + parameters=[p for p in metadata.parameters if p.is_string_coercible], + supported_auth_modes=list(target_cls.supported_auth_modes), + api_key_env_var=target_cls.get_api_key_environment_variable(), + description=metadata.class_description or None, + ) + ) + return TargetCatalogResponse(items=items) + async def create_target_async(self, *, request: CreateTargetRequest) -> TargetInstance: """ Create a new target instance from API request. - Instantiates the target with the given type and params, - then registers it in the registry under its registry name. + Class discovery is owned by the ``TargetRegistry``. Targets whose build + contract references other registry instances (e.g. ``RoundRobinTarget``'s + ``targets``) are constructed via ``registry.create_instance`` so the + resolver turns registry names into live objects; all other targets carry + their base configuration (``endpoint`` / ``model_name`` / ``api_key``) + through ``**kwargs``, which is not part of the registry's derived + parameter contract, so they are constructed directly from the registry + class. Endpoint trust and Entra token minting are owned by the target + classes themselves. This service only enforces the request-level auth + contract: for ``entra`` it confirms the target supports it and omits the + api_key so the target validates its own endpoint and mints the token; for + ``api_key`` it confirms a key is available. Args: request: The create target request with type, params, and auth_mode. @@ -255,187 +178,79 @@ async def create_target_async(self, *, request: CreateTargetRequest) -> TargetIn TargetInstance with the new target's details. Raises: - ValueError: if any of the following occur: - - Target type in request is not found in the class registry; - - Entra ID auth is requested but the target type does not support it; - - Entra ID auth is requested for an OpenAI target or AzureMLChatTarget - but the endpoint is not valid (not managed by correct hosts); - - If auth_mode='api_key' is set for a target but no key is supplied; - - For RoundRobinTarget: if target_registry_names are missing, any name - is not found, or inner targets fail compatibility checks. + ValueError: If the target type is not registered, Entra auth is + requested but unsupported by the target type, or api_key auth is + requested but no key is available. Construction errors (unknown + params, incompatible inner targets, unrecognized Entra endpoints) + are raised by the registry / target classes. """ - target_class = self._get_target_class(target_type=request.type) + if request.type not in self._registry: + raise ValueError( + f"Target type '{request.type}' not found. Available types: {self._registry.get_class_names()}" + ) - # RoundRobinTarget needs special handling: the user passes registry names - # of existing targets, and we resolve them to live objects. - if request.type == "RoundRobinTarget": - target_obj = self._create_round_robin_target(params=dict(request.params)) - else: - # Copy params so we can modify values (eg api_key) without changing request.params. - params: dict[str, Any] = dict(request.params) + target_cls = self._registry.get_class(request.type) + params: dict[str, Any] = dict(request.params) - if request.auth_mode == "entra": - params = self._apply_entra_auth(target_class=target_class, target_type=request.type, params=params) - else: - self._validate_api_key_auth(target_class=target_class, params=params) + if request.auth_mode == "entra": + if "entra" not in target_cls.supported_auth_modes: + raise ValueError( + f"Target type '{request.type}' does not support Entra ID authentication. " + "Supported types are OpenAI-family targets and AzureMLChatTarget." + ) + # Omit any api_key so the target validates its own endpoint and mints the token. + params.pop("api_key", None) + else: + self._validate_api_key_present(target_cls=target_cls, params=params) - target_obj = target_class(**params) + if self._has_reference_params(target_type=request.type): + # e.g. RoundRobinTarget: `targets` is a list of registry names the + # resolver turns into live target objects. + target_obj = self._registry.create_instance(request.type, **params) + else: + target_obj = target_cls(**params) self._registry.instances.register(target_obj) target_registry_name = target_obj.get_identifier().unique_name return self._build_instance_from_object(target_registry_name=target_registry_name, target_obj=target_obj) - def _create_round_robin_target(self, *, params: dict[str, Any]) -> RoundRobinTarget: - """ - Resolve registry names to target objects and create a RoundRobinTarget. - - Targets resolving to the same ``ComponentIdentifier.hash`` are deduplicated - before construction (mirroring ``TargetInitializer._auto_group_targets``) - so duplicate registry aliases for the same underlying endpoint do not - produce a rotation that hits one target twice. If fewer than 2 distinct - targets remain after dedup, a ``ValueError`` is raised. - - The RoundRobinTarget constructor validates all compatibility requirements - (same class, same configuration, same behavioral params, ≥2 targets). - - Args: - params: Must contain ``target_registry_names`` (list of registry name - strings). May contain ``weights`` (list of positive ints) of the - same length as ``target_registry_names``; weights for deduped - entries are dropped along with their target. - - Returns: - A new RoundRobinTarget wrapping the resolved (deduped) targets. - - Raises: - ValueError: If fewer than 2 names are supplied, a name is not found - in the registry, weights length does not match, dedup leaves - fewer than 2 distinct targets, or the RoundRobinTarget - constructor rejects the combination. + def _has_reference_params(self, *, target_type: str) -> bool: """ - registry_names: list[str] = params.get("target_registry_names", []) - if len(registry_names) < 2: - raise ValueError("RoundRobinTarget requires at least 2 target_registry_names in params.") - - raw_weights: list[int] | None = params.get("weights") or None - if raw_weights is not None and len(raw_weights) != len(registry_names): - raise ValueError( - f"weights length ({len(raw_weights)}) must match target_registry_names length ({len(registry_names)})." - ) - - # Deduplicate by ComponentIdentifier hash: two registry entries that - # resolve to the same identifier (same endpoint, model, api_version, etc.) - # would just hit the same target twice in the rotation. This mirrors the - # dedup in TargetInitializer._auto_group_targets so user-driven and - # auto-grouped flows behave the same. - seen_hashes: set[str | None] = set() - resolved_targets: list[PromptTarget] = [] - resolved_weights: list[int] = [] - duplicates: list[str] = [] - for idx, name in enumerate(registry_names): - target_obj = self._registry.instances.get(name) - if target_obj is None: - raise ValueError(f"Target '{name}' not found in the registry.") - target_hash = target_obj.get_identifier().hash - if target_hash in seen_hashes: - duplicates.append(name) - logger.debug(f"Skipping duplicate target '{name}' (hash {target_hash}) in RoundRobinTarget creation") - continue - seen_hashes.add(target_hash) - resolved_targets.append(target_obj) - if raw_weights is not None: - resolved_weights.append(raw_weights[idx]) - - if len(resolved_targets) < 2: - raise ValueError( - f"RoundRobinTarget requires at least 2 distinct targets, but the provided names " - f"resolved to {len(resolved_targets)} unique target(s) after deduplication. " - f"Duplicate names skipped: {duplicates}. Please select targets with different " - f"endpoints or configurations." - ) - - weights = resolved_weights if raw_weights is not None else None - - # The constructor validates same-class, same-config, behavioral consistency, etc. - return RoundRobinTarget(targets=resolved_targets, weights=weights) - - @staticmethod - def _apply_entra_auth(*, target_class: type, target_type: str, params: dict[str, Any]) -> dict[str, Any]: - """ - Replace ``api_key`` in ``params`` with an Entra ID token provider for - the given target class. + Return True if the target type's build contract references other registry + instances (so construction must go through the resolver). Args: - target_class (type): The target class being instantiated - target_type (str): The user-facing target type name - params (dict[str, Any]): The target constructor parameters from the request + target_type (str): The registered target class name. Returns: - dict[str, Any]: A new params dict with ``api_key`` replaced by an async - token-provider callable suitable for the target class. - - Raises: - ValueError: If the target type does not support Entra ID, if an - OpenAI target is given a non-Azure endpoint, or if an - AzureMLChatTarget is given a non-AML endpoint. + bool: True if any derived parameter is a registry reference. """ - new_params = dict(params) - if "api_key" in new_params: - logger.debug("Discarding 'api_key' from params because auth_mode='entra'.") - new_params.pop("api_key", None) - - if issubclass(target_class, OpenAITarget): - endpoint = new_params.get("endpoint") - if not isinstance(endpoint, str) or not endpoint: - raise ValueError("Entra ID authentication requires an 'endpoint' in params.") - if not _is_azure_openai_endpoint(endpoint): - raise ValueError( - "Entra ID authentication requires an Azure endpoint " - f"(*.openai.azure.com or *.ai.azure.com). Got: {endpoint}" - ) - new_params["api_key"] = get_azure_openai_auth(endpoint) - return new_params - - if issubclass(target_class, AzureMLChatTarget): - endpoint = new_params.get("endpoint") - if not isinstance(endpoint, str) or not endpoint: - raise ValueError("Entra ID authentication requires an 'endpoint' in params.") - if not _is_azure_ml_endpoint(endpoint): - raise ValueError( - "Entra ID authentication for AzureMLChatTarget requires an AML endpoint " - f"(*.inference.ml.azure.com). Got: {endpoint}" - ) - new_params["api_key"] = get_azure_async_token_provider(TargetService._AZURE_ML_SCOPE) - return new_params - - raise ValueError( - f"Target type '{target_type}' does not support Entra ID authentication. " - "Supported types are OpenAI-family targets and AzureMLChatTarget." - ) + metadata = self._registry.get_registered_class_metadata(target_type) + if metadata is None: + return False + return any(param.reference is not None for param in metadata.parameters) @staticmethod - def _validate_api_key_auth(*, target_class: type, params: dict[str, Any]) -> None: + def _validate_api_key_present(*, target_cls: type[PromptTarget], params: dict[str, Any]) -> None: """ Enforce that ``auth_mode='api_key'`` actually has a usable key. - Targets that do not authenticate via an api_key (e.g. ``TextTarget``) - are skipped since they have no env var and the underlying - constructor does not take any ``api_key`` arguments. + Reads the target class's declarative api-key env var + (``get_api_key_environment_variable``). Targets that do not authenticate + via an api_key (e.g. ``TextTarget``) declare no env var and are skipped. Args: - target_class (type): The target class being instantiated. + target_cls (type[PromptTarget]): The target class being instantiated. params (dict[str, Any]): The constructor parameters from the request. Raises: - ValueError: If no API key is provided in params or in the relevant - environment variable for a target class that authenticates via - an API key. + ValueError: If the target authenticates via an API key but none was + provided in params or the relevant environment variable. """ - env_var = _resolve_api_key_env_var(target_class) + env_var = target_cls.get_api_key_environment_variable() if env_var is None: return - if params.get("api_key"): return if os.environ.get(env_var): diff --git a/pyrit/models/catalog/target.py b/pyrit/models/catalog/target.py index b9251a0d4d..b13b907810 100644 --- a/pyrit/models/catalog/target.py +++ b/pyrit/models/catalog/target.py @@ -21,6 +21,8 @@ from pydantic import BaseModel, Field +from pyrit.models.identifiers.target_identifier import TargetIdentifier + class TargetCapabilitiesInfo(BaseModel): """ @@ -73,3 +75,10 @@ class TargetInstance(BaseModel): None, description="Inner targets for composite targets like RoundRobinTarget" ) identifier_hash: str | None = Field(None, description="ComponentIdentifier content hash for duplicate detection") + identifier: TargetIdentifier | None = Field( + None, + description=( + "The target's typed, lossless TargetIdentifier projection (class, promoted " + "params, and inner targets). Additive to the flattened presentation fields above." + ), + ) diff --git a/pyrit/prompt_target/azure_ml_chat_target.py b/pyrit/prompt_target/azure_ml_chat_target.py index 0bb58fa2d1..49b25b9ea6 100644 --- a/pyrit/prompt_target/azure_ml_chat_target.py +++ b/pyrit/prompt_target/azure_ml_chat_target.py @@ -3,11 +3,15 @@ import logging from collections.abc import Awaitable, Callable -from typing import Any, cast +from typing import Any, ClassVar, cast from httpx import HTTPStatusError -from pyrit.auth import ensure_async_token_provider +from pyrit.auth import ( + ensure_async_token_provider, + get_azure_async_token_provider, + is_azure_ml_endpoint, +) from pyrit.common import default_values, net_utility from pyrit.exceptions import ( EmptyResponseException, @@ -21,7 +25,7 @@ Message, construct_response_from_request, ) -from pyrit.prompt_target.common.prompt_target import PromptTarget +from pyrit.prompt_target.common.prompt_target import AuthMode, PromptTarget from pyrit.prompt_target.common.target_capabilities import TargetCapabilities from pyrit.prompt_target.common.target_configuration import TargetConfiguration from pyrit.prompt_target.common.utils import limit_requests_per_minute, validate_temperature, validate_top_p @@ -44,6 +48,13 @@ class AzureMLChatTarget(PromptTarget): endpoint_uri_environment_variable: str = "AZURE_ML_MANAGED_ENDPOINT" api_key_environment_variable: str = "AZURE_ML_KEY" + # AML managed online endpoints can authenticate with a Microsoft Entra ID + # token scoped to AML, so this target supports both auth modes. + supported_auth_modes: ClassVar[tuple[AuthMode, ...]] = ("api_key", "entra") + + # Entra ID token scope for Azure Machine Learning managed online endpoints. + _AZURE_ML_SCOPE: ClassVar[str] = "https://ml.azure.com/.default" + _DEFAULT_CONFIGURATION: TargetConfiguration = TargetConfiguration( capabilities=TargetCapabilities( supports_multi_message_pieces=True, @@ -53,6 +64,16 @@ class AzureMLChatTarget(PromptTarget): ) ) + @classmethod + def get_api_key_environment_variable(cls) -> str | None: + """ + Return the api-key env var (``AZURE_ML_KEY``) for this target. + + Returns: + str | None: The api-key env var name. + """ + return cls.api_key_environment_variable + def __init__( self, *, @@ -164,6 +185,11 @@ def _initialize_vars( The API key for accessing the Azure ML endpoint, or a callable which returns a bearer token, or None to fall back to the ``AZURE_ML_KEY`` env variable. + + Raises: + ValueError: If no api_key is supplied (via parameter or environment + variable) and the endpoint is not a recognized Azure ML managed + online endpoint for which Entra ID authentication can be used. """ self._endpoint = default_values.get_required_value( env_var_name=self.endpoint_uri_environment_variable, passed_value=endpoint @@ -176,10 +202,28 @@ def _initialize_vars( self._api_key = "" return - self._api_key_provider = None - self._api_key = default_values.get_required_value( + api_key_value = default_values.get_non_required_value( env_var_name=self.api_key_environment_variable, passed_value=api_key ) + if api_key_value: + self._api_key_provider = None + self._api_key = api_key_value + return + + # No key supplied: fall back to Microsoft Entra ID, but only for a + # recognized AML managed online endpoint so a bearer token is never + # minted for an arbitrary host. + if is_azure_ml_endpoint(self._endpoint): + normalized = ensure_async_token_provider(get_azure_async_token_provider(self._AZURE_ML_SCOPE)) + self._api_key_provider = cast("Callable[[], Awaitable[str]]", normalized) + self._api_key = "" + return + + raise ValueError( + f"Environment variable {self.api_key_environment_variable} is required unless the endpoint is a " + "recognized Azure ML managed online endpoint (*.inference.ml.azure.com), for which Entra ID " + "authentication is used automatically. Pass an api_key or a token provider callable instead." + ) @limit_requests_per_minute async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index 65000a6139..511b803a7a 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -3,7 +3,7 @@ import abc import logging -from typing import Any, final +from typing import Any, ClassVar, Literal, final from pyrit.common.deprecation import print_deprecation_message from pyrit.memory import CentralMemory, MemoryInterface @@ -18,6 +18,11 @@ logger = logging.getLogger(__name__) +# Authentication modes a target can expose to the create-target catalog / API. +# ``api_key`` passes a key (from params or the target's env var); ``entra`` omits +# the key so the target mints a Microsoft Entra ID token for its own endpoint. +AuthMode = Literal["api_key", "entra"] + class PromptTarget(Identifiable): """ @@ -46,6 +51,31 @@ class PromptTarget(Identifiable): # constructor parameter, which takes precedence over the class-level value. _DEFAULT_CONFIGURATION: TargetConfiguration = TargetConfiguration(capabilities=TargetCapabilities()) + # Declarative auth facts consumed by the create-target service and catalog + # (kept off ``TargetCapabilities`` / the identifier — auth is a construction + # /credential axis, not a message-handling capability or an identity input). + # + # ``supported_auth_modes`` lists the auth modes the create-target API accepts + # for this type. Base default is api-key only; families that can mint an Entra + # ID token for their own endpoint (e.g. OpenAI, Azure ML) override this to add + # ``"entra"``. ``get_api_key_environment_variable`` names the env var that + # supplies the api_key, or None for targets that do not authenticate with one. + supported_auth_modes: ClassVar[tuple[AuthMode, ...]] = ("api_key",) + + @classmethod + def get_api_key_environment_variable(cls) -> str | None: + """ + Return the name of the environment variable that supplies this target's + API key, or None if the target does not authenticate with an API key. + + Used by the create-target service/catalog to decide whether an api_key is + required (and which env var to hint) without constructing the target. + + Returns: + str | None: The api-key env var name, or None if the target uses none. + """ + return None + def __init_subclass__(cls, **kwargs: object) -> None: """ Validate that subclasses follow the keyword-only ``__init__`` contract. diff --git a/pyrit/prompt_target/openai/openai_target.py b/pyrit/prompt_target/openai/openai_target.py index bcd9b48162..cd8b6fe45b 100644 --- a/pyrit/prompt_target/openai/openai_target.py +++ b/pyrit/prompt_target/openai/openai_target.py @@ -6,7 +6,7 @@ import re from abc import abstractmethod from collections.abc import Awaitable, Callable -from typing import Any +from typing import Any, ClassVar from urllib.parse import urlparse from openai import ( @@ -22,14 +22,14 @@ AuthenticationError, ) -from pyrit.auth import ensure_async_token_provider, get_azure_openai_auth +from pyrit.auth import ensure_async_token_provider, get_azure_openai_auth, is_azure_openai_endpoint from pyrit.common import default_values from pyrit.exceptions.exception_classes import ( RateLimitException, handle_bad_request_exception, ) from pyrit.models import Message, MessagePiece -from pyrit.prompt_target.common.prompt_target import PromptTarget +from pyrit.prompt_target.common.prompt_target import AuthMode, PromptTarget from pyrit.prompt_target.common.target_capabilities import TargetCapabilities from pyrit.prompt_target.common.target_configuration import TargetConfiguration from pyrit.prompt_target.openai.openai_error_handling import ( @@ -57,12 +57,36 @@ class OpenAITarget(PromptTarget): capabilities=TargetCapabilities(supports_multi_message_pieces=True) ) + # OpenAI-family targets can mint an Entra ID token for a recognized Azure + # endpoint (see ``is_azure_openai_endpoint``), so they support both modes. + supported_auth_modes: ClassVar[tuple[AuthMode, ...]] = ("api_key", "entra") + model_name_environment_variable: str endpoint_environment_variable: str api_key_environment_variable: str _async_client: AsyncOpenAI | None = None + @classmethod + def get_api_key_environment_variable(cls) -> str | None: + """ + Return the api-key env var for this concrete OpenAI target subclass. + + The env var name is subclass-specific (e.g. ``OPENAI_CHAT_KEY``) and is + assigned in ``_set_openai_env_configuration_vars``. Resolve it without a + full construction by binding that setter to a bare instance. + + Returns: + str | None: The api-key env var name, or None if it cannot be resolved. + """ + try: + instance = cls.__new__(cls) + instance._set_openai_env_configuration_vars() + except Exception: + return None + env_var = getattr(instance, "api_key_environment_variable", None) + return env_var if isinstance(env_var, str) and env_var else None + @property def _client(self) -> AsyncOpenAI: """ @@ -96,8 +120,8 @@ def __init__( endpoint (str, Optional): The target URL for the OpenAI service. api_key (str | Callable[[], str | Awaitable[str]], Optional): The API key for accessing the OpenAI service, or a callable that returns an access token (sync or async). - For Azure endpoints, if no API key is provided (via parameter or environment variable), - Entra ID authentication is used automatically. + For recognized Azure OpenAI / AI Foundry endpoints, if no API key is provided + (via parameter or environment variable), Entra ID authentication is used automatically. You can also explicitly pass a token provider from pyrit.auth (e.g., get_azure_openai_auth(endpoint) for async, or get_azure_token_provider(scope) for sync). Synchronous token providers are automatically wrapped to work with async clients. @@ -116,7 +140,8 @@ def __init__( this target instance. If None, uses the class-level defaults. Defaults to None. Raises: - ValueError: If no API key is provided and the endpoint is not an Azure endpoint. + ValueError: If no API key is provided (via parameter or environment variable) and the + endpoint is not a recognized Azure OpenAI / AI Foundry endpoint. """ self._headers: dict[str, str] = {} self._httpx_client_kwargs = httpx_client_kwargs or {} @@ -157,12 +182,13 @@ def __init__( ) if api_key_value: resolved_api_key = api_key_value - elif "azure" in endpoint_value.lower(): + elif is_azure_openai_endpoint(endpoint_value): resolved_api_key = get_azure_openai_auth(endpoint_value) else: raise ValueError( f"Environment variable {self.api_key_environment_variable} is required for non-Azure endpoints. " - "For Azure endpoints, Entra ID authentication is used automatically." + "For recognized Azure OpenAI / AI Foundry endpoints, Entra ID authentication is used " + "automatically." ) # Ensure api_key is async-compatible (wrap sync token providers if needed) diff --git a/pyrit/prompt_target/round_robin_target.py b/pyrit/prompt_target/round_robin_target.py index 3086601d72..a0fb8a8edf 100644 --- a/pyrit/prompt_target/round_robin_target.py +++ b/pyrit/prompt_target/round_robin_target.py @@ -5,7 +5,12 @@ import logging from typing import Any -from pyrit.models import TARGET_EVAL_PARAM_FALLBACKS, TARGET_EVAL_PARAMS, ComponentIdentifier, Message +from pyrit.models import ( + TARGET_EVAL_PARAM_FALLBACKS, + TARGET_EVAL_PARAMS, + ComponentIdentifier, + Message, +) from pyrit.prompt_target.common.prompt_target import PromptTarget from pyrit.prompt_target.common.target_requirements import CHAT_TARGET_REQUIREMENTS @@ -65,11 +70,11 @@ def __init__( requests to the first target. Defaults to equal weight. Raises: - ValueError: If fewer than 2 targets are provided, targets are - different classes, a nested RoundRobinTarget is detected, - weights length doesn't match, weights contain non-positive - values, inner targets have different configurations, or - targets lack required capabilities. + ValueError: If fewer than 2 targets are provided, fewer than 2 distinct + target instances remain after deduplication, targets are different + classes, a nested RoundRobinTarget is detected, weights length doesn't + match, weights contain non-positive values, inner targets have different + configurations, or targets lack required capabilities. """ if len(targets) < 2: raise ValueError(f"RoundRobinTarget requires at least 2 targets, got {len(targets)}.") @@ -77,6 +82,27 @@ def __init__( if any(isinstance(t, RoundRobinTarget) for t in targets): raise ValueError("Nesting RoundRobinTarget inside another RoundRobinTarget is not supported.") + if weights is not None and len(weights) != len(targets): + raise ValueError(f"weights length ({len(weights)}) must match targets length ({len(targets)}).") + + # Deduplicate the same target *instance* referenced more than once (e.g. the + # build-from-registry-names path resolving two names to the same registered + # object): rotating over the literal same object twice is meaningless. We dedup + # by object identity, NOT by ComponentIdentifier hash, on purpose: the hash + # excludes credentials (api_key), so two genuinely distinct targets that share an + # endpoint/model but use different keys (e.g. round-robining across accounts to + # spread rate limits) must be preserved. Weights for dropped duplicates are + # dropped alongside their target. (Auto-grouping in TargetInitializer still dedups + # by hash upstream — that is a different, pool-discovery intent.) + targets, weights = self._deduplicate_targets(targets=targets, weights=weights) + + if len(targets) < 2: + raise ValueError( + "RoundRobinTarget requires at least 2 distinct targets, but the provided targets " + f"resolved to {len(targets)} unique target instance(s) after deduplication. The same " + "target instance was referenced more than once." + ) + first_type = type(targets[0]) mismatched = [(i, type(t).__name__) for i, t in enumerate(targets[1:], start=1) if type(t) is not first_type] if mismatched: @@ -86,8 +112,6 @@ def __init__( ) weights = weights or [1] * len(targets) - if len(weights) != len(targets): - raise ValueError(f"weights length ({len(weights)}) must match targets length ({len(targets)}).") if any(w <= 0 for w in weights): raise ValueError("All weights must be positive integers.") @@ -117,6 +141,43 @@ def __init__( self._counter: int = 0 + @staticmethod + def _deduplicate_targets( + *, targets: list[PromptTarget], weights: list[int] | None + ) -> tuple[list[PromptTarget], list[int] | None]: + """ + Drop the same target *instance* referenced more than once. + + Keeps the first occurrence of each distinct object (by identity); when + ``weights`` is provided, the weight of a dropped duplicate is dropped + alongside it. Deduplication is by object identity, not by + ``ComponentIdentifier.hash``: the hash excludes credentials, so + config-identical targets that differ only by api_key are intentionally + preserved. + + Args: + targets (list[PromptTarget]): The inner targets, possibly with duplicates. + weights (list[int] | None): Optional weights aligned with ``targets``. + + Returns: + tuple[list[PromptTarget], list[int] | None]: The deduplicated targets and + their aligned weights (or None if no weights were provided). + """ + seen_ids: set[int] = set() + unique_targets: list[PromptTarget] = [] + unique_weights: list[int] = [] + for idx, target in enumerate(targets): + target_id = id(target) + if target_id in seen_ids: + logger.debug("Skipping duplicate target instance in RoundRobinTarget.") + continue + seen_ids.add(target_id) + unique_targets.append(target) + if weights is not None: + unique_weights.append(weights[idx]) + + return unique_targets, (unique_weights if weights is not None else None) + def _next_target(self) -> PromptTarget: """ Return the next inner target in the weighted rotation. diff --git a/tests/unit/backend/test_api_routes.py b/tests/unit/backend/test_api_routes.py index e6299313ef..5d9fbd7b67 100644 --- a/tests/unit/backend/test_api_routes.py +++ b/tests/unit/backend/test_api_routes.py @@ -35,6 +35,7 @@ PreviewStep, ) from pyrit.backend.models.targets import ( + TargetCatalogResponse, TargetListResponse, ) from pyrit.backend.routes.labels import get_label_options @@ -795,6 +796,30 @@ def test_list_targets_returns_empty_list(self, client: TestClient) -> None: assert data["items"] == [] assert data["pagination"]["has_more"] is False + def test_list_target_catalog(self, client: TestClient) -> None: + """Test listing available target types from the target catalog.""" + with patch("pyrit.backend.routes.targets.get_target_service") as mock_get_service: + mock_service = MagicMock() + mock_service.list_target_catalog_async = AsyncMock( + return_value=TargetCatalogResponse( + items=[ + { + "target_type": "OpenAIChatTarget", + "supported_auth_modes": ["api_key", "entra"], + "api_key_env_var": "OPENAI_CHAT_KEY", + } + ] + ) + ) + mock_get_service.return_value = mock_service + + response = client.get("/api/targets/catalog") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["items"][0]["target_type"] == "OpenAIChatTarget" + assert data["items"][0]["supported_auth_modes"] == ["api_key", "entra"] + def test_create_target_success(self, client: TestClient) -> None: """Test successful target creation.""" with patch("pyrit.backend.routes.targets.get_target_service") as mock_get_service: diff --git a/tests/unit/backend/test_target_service.py b/tests/unit/backend/test_target_service.py index f8926b5778..497bba756e 100644 --- a/tests/unit/backend/test_target_service.py +++ b/tests/unit/backend/test_target_service.py @@ -15,6 +15,7 @@ from pyrit.models import ComponentIdentifier from pyrit.prompt_target import PromptTarget from pyrit.registry import TargetRegistry +from unit.mocks import MockPromptTarget @pytest.fixture(autouse=True) @@ -228,6 +229,31 @@ def test_get_target_object_returns_object_from_registry(self) -> None: assert result is mock_target +class TestListTargetCatalog: + """Tests for TargetService.list_target_catalog_async method.""" + + async def test_catalog_returns_known_target_types(self) -> None: + """The catalog exposes constructible target classes from the registry.""" + service = TargetService() + + result = await service.list_target_catalog_async() + + target_types = [item.target_type for item in result.items] + assert "OpenAIChatTarget" in target_types + assert "AzureMLChatTarget" in target_types + + async def test_catalog_includes_declarative_auth_facts(self) -> None: + """Catalog entries surface the per-class auth facts the frontend needs.""" + service = TargetService() + + result = await service.list_target_catalog_async() + + openai_entry = next(item for item in result.items if item.target_type == "OpenAIChatTarget") + assert "api_key" in openai_entry.supported_auth_modes + assert "entra" in openai_entry.supported_auth_modes + assert openai_entry.api_key_env_var == "OPENAI_CHAT_KEY" + + class TestCreateTarget: """Tests for TargetService.create_target method.""" @@ -313,189 +339,136 @@ async def test_create_target_with_different_underlying_model(self, sqlite_instan class TestCreateTargetEntraAuth: - """Test that creating targets with Entra auth mode properly authenticates and handles edge cases.""" - - async def test_create_openai_target_with_entra_injects_token_provider(self, sqlite_instance) -> None: - """Entra auth path: api_key is replaced with the authentication callable""" + """Entra auth at the service boundary: the service only omits the api_key and + confirms the target type supports Entra. Endpoint trust + token minting are the + target's job and are covered in the target-level tests (see + tests/unit/prompt_target/target/).""" - with patch( - "pyrit.backend.services.target_service.get_azure_openai_auth", - return_value=_test_token_provider, - ) as mock_get_auth: - service = TargetService() - - request = CreateTargetRequest( - type="OpenAIChatTarget", - params={ - "endpoint": "https://test.openai.azure.com/", - "model_name": "gpt-4o", - }, - auth_mode="entra", - ) + async def test_create_openai_target_with_entra_omits_key_and_target_mints_token(self, sqlite_instance) -> None: + """Entra path: the service omits the api_key so the target mints its own token.""" - result = await service.create_target_async(request=request) - - mock_get_auth.assert_called_once_with("https://test.openai.azure.com/") - target_obj = service.get_target_object(target_registry_name=result.target_registry_name) - assert target_obj is not None - # OpenAI target preserves async callables verbatim through ensure_async_token_provider. - assert target_obj._api_key is _test_token_provider # type: ignore[attr-defined] + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("OPENAI_CHAT_KEY", None) + with patch( + "pyrit.prompt_target.openai.openai_target.get_azure_openai_auth", + return_value=_test_token_provider, + ) as mock_get_auth: + service = TargetService() + + request = CreateTargetRequest( + type="OpenAIChatTarget", + params={ + "endpoint": "https://test.openai.azure.com/", + "model_name": "gpt-4o", + }, + auth_mode="entra", + ) + + result = await service.create_target_async(request=request) + + mock_get_auth.assert_called_once_with("https://test.openai.azure.com/") + target_obj = service.get_target_object(target_registry_name=result.target_registry_name) + assert target_obj is not None + # OpenAI target preserves async callables verbatim through ensure_async_token_provider. + assert target_obj._api_key is _test_token_provider # type: ignore[attr-defined] async def test_create_openai_target_with_entra_drops_user_api_key(self, sqlite_instance) -> None: """Any api_key supplied alongside auth_mode='entra' must be discarded.""" - with patch( - "pyrit.backend.services.target_service.get_azure_openai_auth", - return_value=_test_token_provider, - ): - service = TargetService() - - request = CreateTargetRequest( - type="OpenAIChatTarget", - params={ - "endpoint": "https://test.openai.azure.com/", - "model_name": "gpt-4o", - "api_key": "should-be-ignored", - }, - auth_mode="entra", - ) - - result = await service.create_target_async(request=request) - - target_obj = service.get_target_object(target_registry_name=result.target_registry_name) - assert target_obj is not None - assert target_obj._api_key is _test_token_provider # type: ignore[attr-defined] - # The literal "should-be-ignored" string must never appear. - assert target_obj._api_key != "should-be-ignored" # type: ignore[attr-defined] + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("OPENAI_CHAT_KEY", None) + with patch( + "pyrit.prompt_target.openai.openai_target.get_azure_openai_auth", + return_value=_test_token_provider, + ): + service = TargetService() + + request = CreateTargetRequest( + type="OpenAIChatTarget", + params={ + "endpoint": "https://test.openai.azure.com/", + "model_name": "gpt-4o", + "api_key": "should-be-ignored", + }, + auth_mode="entra", + ) + + result = await service.create_target_async(request=request) + + target_obj = service.get_target_object(target_registry_name=result.target_registry_name) + assert target_obj is not None + assert target_obj._api_key is _test_token_provider # type: ignore[attr-defined] + # The literal "should-be-ignored" string must never appear. + assert target_obj._api_key != "should-be-ignored" # type: ignore[attr-defined] async def test_create_openai_target_with_entra_does_not_mutate_request_params(self, sqlite_instance) -> None: """The CreateTargetRequest.params object must remain unchanged after creation.""" - with patch( - "pyrit.backend.services.target_service.get_azure_openai_auth", - return_value=_test_token_provider, - ): - service = TargetService() - - original_params = { - "endpoint": "https://test.openai.azure.com/", - "model_name": "gpt-4o", - "api_key": "original-key", - } - request = CreateTargetRequest( - type="OpenAIChatTarget", - params=dict(original_params), - auth_mode="entra", - ) - - await service.create_target_async(request=request) - - # The caller's request.params must be unchanged after the call. - assert request.params == original_params - - async def test_create_openai_target_with_entra_non_azure_endpoint_raises(self, sqlite_instance) -> None: - """Entra ID requires a known Azure OpenAI / AI Foundry hostname suffix.""" - service = TargetService() - - request = CreateTargetRequest( - type="OpenAIChatTarget", - params={"endpoint": "https://api.openai.com/"}, - auth_mode="entra", - ) - - with pytest.raises(ValueError, match="Azure endpoint"): - await service.create_target_async(request=request) - - async def test_create_openai_target_with_entra_substring_lookalike_endpoint_raises(self, sqlite_instance) -> None: - """Substring 'azure' in the hostname must not be enough to pass Entra validation.""" - service = TargetService() + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("OPENAI_CHAT_KEY", None) + with patch( + "pyrit.prompt_target.openai.openai_target.get_azure_openai_auth", + return_value=_test_token_provider, + ): + service = TargetService() - request = CreateTargetRequest( - type="OpenAIChatTarget", - # Hostname contains 'azure' but does NOT end with an approved suffix. - params={"endpoint": "https://evil-azure.example.com/"}, - auth_mode="entra", - ) + original_params = { + "endpoint": "https://test.openai.azure.com/", + "model_name": "gpt-4o", + "api_key": "original-key", + } + request = CreateTargetRequest( + type="OpenAIChatTarget", + params=dict(original_params), + auth_mode="entra", + ) - with pytest.raises(ValueError, match="Azure endpoint"): - await service.create_target_async(request=request) + await service.create_target_async(request=request) - async def test_create_openai_target_with_entra_missing_endpoint_raises(self, sqlite_instance) -> None: - """Entra ID for OpenAI must reject a missing endpoint with a clear error.""" - service = TargetService() + # The caller's request.params must be unchanged after the call. + assert request.params == original_params - request = CreateTargetRequest( - type="OpenAIChatTarget", - params={}, - auth_mode="entra", - ) + async def test_create_azureml_target_with_entra_omits_key_and_target_mints_token(self, sqlite_instance) -> None: + """AzureML Entra path: the service omits the key so the target mints the ML scope token.""" - with pytest.raises(ValueError, match="endpoint"): - await service.create_target_async(request=request) + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("AZURE_ML_KEY", None) + with patch( + "pyrit.prompt_target.azure_ml_chat_target.get_azure_async_token_provider", + return_value=_test_token_provider, + ) as mock_get_provider: + service = TargetService() + + request = CreateTargetRequest( + type="AzureMLChatTarget", + params={"endpoint": "https://my-aml.region.inference.ml.azure.com/score"}, + auth_mode="entra", + ) + + result = await service.create_target_async(request=request) + + mock_get_provider.assert_called_once_with("https://ml.azure.com/.default") + target_obj = service.get_target_object(target_registry_name=result.target_registry_name) + assert target_obj is not None + # AzureMLChatTarget stores the provider on _api_key_provider; static _api_key is cleared. + assert target_obj._api_key_provider is _test_token_provider # type: ignore[attr-defined] + assert target_obj._api_key == "" # type: ignore[attr-defined] - async def test_create_azureml_target_with_entra_injects_token_provider(self, sqlite_instance) -> None: - """AzureML Entra path: api_key is replaced with the ML scope token provider.""" + async def test_create_openai_target_with_entra_non_azure_endpoint_raises(self, sqlite_instance) -> None: + """The target (not the service) rejects an unrecognized endpoint under Entra.""" - with patch( - "pyrit.backend.services.target_service.get_azure_async_token_provider", - return_value=_test_token_provider, - ) as mock_get_provider: + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("OPENAI_CHAT_KEY", None) service = TargetService() request = CreateTargetRequest( - type="AzureMLChatTarget", - params={"endpoint": "https://my-aml.region.inference.ml.azure.com/score"}, + type="OpenAIChatTarget", + params={"endpoint": "https://api.openai.com/", "model_name": "gpt-4o"}, auth_mode="entra", ) - result = await service.create_target_async(request=request) - - mock_get_provider.assert_called_once_with("https://ml.azure.com/.default") - target_obj = service.get_target_object(target_registry_name=result.target_registry_name) - assert target_obj is not None - # AzureMLChatTarget stores the provider on _api_key_provider; static _api_key is cleared. - assert target_obj._api_key_provider is _test_token_provider # type: ignore[attr-defined] - assert target_obj._api_key == "" # type: ignore[attr-defined] - - async def test_create_azureml_target_with_entra_non_aml_endpoint_raises(self, sqlite_instance) -> None: - """Entra ID for AzureMLChatTarget requires a known AML hostname suffix.""" - service = TargetService() - - request = CreateTargetRequest( - type="AzureMLChatTarget", - params={"endpoint": "https://example.com/score"}, - auth_mode="entra", - ) - - with pytest.raises(ValueError, match="AML endpoint"): - await service.create_target_async(request=request) - - async def test_create_azureml_target_with_entra_substring_lookalike_endpoint_raises(self, sqlite_instance) -> None: - """Substring 'inference.ml.azure.com' in the hostname must not be enough to pass AML validation.""" - service = TargetService() - - request = CreateTargetRequest( - type="AzureMLChatTarget", - # Hostname contains the AML suffix as a substring but does NOT end with it. - params={"endpoint": "https://evil-inference.ml.azure.com.attacker.com/score"}, - auth_mode="entra", - ) - - with pytest.raises(ValueError, match="AML endpoint"): - await service.create_target_async(request=request) - - async def test_create_azureml_target_with_entra_missing_endpoint_raises(self, sqlite_instance) -> None: - """Entra ID for AzureMLChatTarget must reject a missing endpoint with a clear error.""" - service = TargetService() - - request = CreateTargetRequest( - type="AzureMLChatTarget", - params={}, - auth_mode="entra", - ) - - with pytest.raises(ValueError, match="endpoint"): - await service.create_target_async(request=request) + with pytest.raises(ValueError, match="non-Azure endpoints"): + await service.create_target_async(request=request) async def test_create_target_entra_unsupported_type_raises(self, sqlite_instance) -> None: """Entra ID is only supported for OpenAI-family and AzureMLChatTarget.""" @@ -600,165 +573,60 @@ async def test_create_text_target_api_key_mode_skips_validation(self, sqlite_ins class TestCreateRoundRobinTarget: - """Tests for creating RoundRobinTarget via the service.""" + """Service-level tests for building RoundRobinTarget through the registry. + + The service passes ``targets`` (registry names) to ``registry.create_instance``; + the resolver turns the names into live target objects and RoundRobinTarget owns + its own construction validation (dedup, class/config consistency). Those rules + are covered in tests/unit/prompt_target/test_round_robin_target.py — here we only + exercise the service wiring. + """ async def test_create_round_robin_target_resolves_registry_names(self, sqlite_instance) -> None: """RoundRobinTarget creation resolves registry names to live target objects.""" service = TargetService() - # Register two mock targets in the registry to serve as inner targets. - # We mock the RoundRobinTarget constructor because it does deep validation - # (same class, multi-turn, editable history) that requires real compatible - # targets. The service's job is to resolve registry names and pass them - # through — the constructor validation is tested in RoundRobinTarget's own tests. - mock_a = MagicMock(spec=PromptTarget) - mock_a.get_identifier.return_value = _mock_target_identifier( - class_name="OpenAIChatTarget", endpoint="https://a.openai.azure.com", model_name="gpt-4o" - ) - mock_b = MagicMock(spec=PromptTarget) - mock_b.get_identifier.return_value = _mock_target_identifier( - class_name="OpenAIChatTarget", endpoint="https://b.openai.azure.com", model_name="gpt-4o" - ) - service._registry.instances.register(mock_a, name="target-a") - service._registry.instances.register(mock_b, name="target-b") - - # Patch RoundRobinTarget so the constructor returns a mock that behaves - # like a registered target (has get_identifier, capabilities, etc.) - mock_rr = MagicMock(spec=PromptTarget) - mock_rr.get_identifier.return_value = ComponentIdentifier( - class_name="RoundRobinTarget", - class_module="pyrit.prompt_target.round_robin_target", - params={"weights": [2, 1]}, - ) - mock_rr._targets = [mock_a, mock_b] - - with patch( - "pyrit.backend.services.target_service.RoundRobinTarget", - return_value=mock_rr, - ) as mock_rr_cls: - rr_request = CreateTargetRequest( - type="RoundRobinTarget", - params={ - "target_registry_names": ["target-a", "target-b"], - "weights": [2, 1], - }, - ) - - result = await service.create_target_async(request=rr_request) - - # Verify the constructor was called with the resolved targets and weights - mock_rr_cls.assert_called_once_with(targets=[mock_a, mock_b], weights=[2, 1]) - assert result.target_type == "RoundRobinTarget" - - async def test_create_round_robin_target_fewer_than_2_raises(self, sqlite_instance) -> None: - """RoundRobinTarget with fewer than 2 registry names raises ValueError.""" - service = TargetService() - - rr_request = CreateTargetRequest( - type="RoundRobinTarget", - params={"target_registry_names": ["only-one"]}, - ) - - with pytest.raises(ValueError, match="at least 2"): - await service.create_target_async(request=rr_request) - - async def test_create_round_robin_target_unknown_name_raises(self, sqlite_instance) -> None: - """RoundRobinTarget with a non-existent registry name raises ValueError.""" - service = TargetService() + target_a = MockPromptTarget() + target_b = MockPromptTarget() + service._registry.instances.register(target_a, name="target-a") + service._registry.instances.register(target_b, name="target-b") rr_request = CreateTargetRequest( type="RoundRobinTarget", - params={"target_registry_names": ["does-not-exist-a", "does-not-exist-b"]}, - ) - - with pytest.raises(ValueError, match="not found"): - await service.create_target_async(request=rr_request) - - async def test_create_round_robin_target_deduplicates_identical_targets(self, sqlite_instance) -> None: - """Targets that resolve to the same identifier hash are deduplicated, and - the corresponding weights are dropped alongside them.""" - service = TargetService() - - # mock_a and mock_a_alias share the same identifier params, so their - # ComponentIdentifier.hash is identical — they should dedupe to one entry. - identifier_a = _mock_target_identifier( - class_name="OpenAIChatTarget", endpoint="https://a.openai.azure.com", model_name="gpt-4o" - ) - mock_a = MagicMock(spec=PromptTarget) - mock_a.get_identifier.return_value = identifier_a - mock_a_alias = MagicMock(spec=PromptTarget) - mock_a_alias.get_identifier.return_value = identifier_a - - mock_b = MagicMock(spec=PromptTarget) - mock_b.get_identifier.return_value = _mock_target_identifier( - class_name="OpenAIChatTarget", endpoint="https://b.openai.azure.com", model_name="gpt-4o" + params={"targets": ["target-a", "target-b"], "weights": [2, 1]}, ) - service._registry.instances.register(mock_a, name="target-a") - service._registry.instances.register(mock_a_alias, name="target-a-alias") - service._registry.instances.register(mock_b, name="target-b") - - mock_rr = MagicMock(spec=PromptTarget) - mock_rr.get_identifier.return_value = ComponentIdentifier( - class_name="RoundRobinTarget", - class_module="pyrit.prompt_target.round_robin_target", - params={"weights": [3, 1]}, - ) - mock_rr._targets = [mock_a, mock_b] - - with patch( - "pyrit.backend.services.target_service.RoundRobinTarget", - return_value=mock_rr, - ) as mock_rr_cls: - rr_request = CreateTargetRequest( - type="RoundRobinTarget", - params={ - "target_registry_names": ["target-a", "target-a-alias", "target-b"], - "weights": [3, 2, 1], - }, - ) - - await service.create_target_async(request=rr_request) + result = await service.create_target_async(request=rr_request) - # The duplicate alias and its weight (2) should be dropped. - mock_rr_cls.assert_called_once_with(targets=[mock_a, mock_b], weights=[3, 1]) + assert result.target_type == "RoundRobinTarget" + target_obj = service.get_target_object(target_registry_name=result.target_registry_name) + assert target_obj._targets == [target_a, target_b] + assert target_obj._weights == [2, 1] - async def test_create_round_robin_target_all_duplicates_raises(self, sqlite_instance) -> None: - """If dedup leaves fewer than 2 distinct targets, raise a clear error.""" + async def test_create_round_robin_target_fewer_than_2_raises(self, sqlite_instance) -> None: + """A single inner target bubbles up RoundRobinTarget's own validation error.""" service = TargetService() - identifier = _mock_target_identifier( - class_name="OpenAIChatTarget", endpoint="https://a.openai.azure.com", model_name="gpt-4o" - ) - mock_a = MagicMock(spec=PromptTarget) - mock_a.get_identifier.return_value = identifier - mock_a_alias = MagicMock(spec=PromptTarget) - mock_a_alias.get_identifier.return_value = identifier - - service._registry.instances.register(mock_a, name="target-a") - service._registry.instances.register(mock_a_alias, name="target-a-alias") + service._registry.instances.register(MockPromptTarget(), name="only-one") rr_request = CreateTargetRequest( type="RoundRobinTarget", - params={"target_registry_names": ["target-a", "target-a-alias"]}, + params={"targets": ["only-one"]}, ) - with pytest.raises(ValueError, match="at least 2 distinct targets"): + with pytest.raises(ValueError, match="at least 2 targets"): await service.create_target_async(request=rr_request) - async def test_create_round_robin_target_weights_length_mismatch_raises(self, sqlite_instance) -> None: - """Mismatched weights length raises before any registry lookups.""" + async def test_create_round_robin_target_unknown_name_raises(self, sqlite_instance) -> None: + """A non-existent registry name is rejected by the resolver.""" service = TargetService() rr_request = CreateTargetRequest( type="RoundRobinTarget", - params={ - "target_registry_names": ["a", "b", "c"], - "weights": [1, 2], - }, + params={"targets": ["does-not-exist-a", "does-not-exist-b"]}, ) - with pytest.raises(ValueError, match="weights length"): + with pytest.raises(ValueError, match="not found"): await service.create_target_async(request=rr_request) diff --git a/tests/unit/prompt_target/target/test_azure_ml_chat_target.py b/tests/unit/prompt_target/target/test_azure_ml_chat_target.py index a219a8fab5..1b48070264 100644 --- a/tests/unit/prompt_target/target/test_azure_ml_chat_target.py +++ b/tests/unit/prompt_target/target/test_azure_ml_chat_target.py @@ -54,6 +54,35 @@ def test_initialization_with_no_api_raises(): AzureMLChatTarget(api_key="xxxxx") +def test_no_key_recognized_aml_endpoint_auto_mints_entra(patch_central_database): + """With no key and a recognized *.inference.ml.azure.com endpoint, the target + auto-mints an Entra token provider for the AML scope.""" + + async def _provider() -> str: + return "aml-entra-token" + + with ( + patch.dict(os.environ, {AzureMLChatTarget.api_key_environment_variable: ""}), + patch( + "pyrit.prompt_target.azure_ml_chat_target.get_azure_async_token_provider", + return_value=_provider, + ) as mock_provider, + ): + target = AzureMLChatTarget(endpoint="https://my-aml.region.inference.ml.azure.com/score") + + mock_provider.assert_called_once_with(AzureMLChatTarget._AZURE_ML_SCOPE) + assert target._api_key_provider is _provider + assert target._api_key == "" + + +def test_no_key_non_aml_endpoint_raises(patch_central_database): + """With no key and an endpoint that is not a recognized AML host, the target + refuses to mint a bearer token.""" + with patch.dict(os.environ, {AzureMLChatTarget.api_key_environment_variable: ""}): + with pytest.raises(ValueError, match="recognized Azure ML"): + AzureMLChatTarget(endpoint="https://example.com/score") + + async def test_complete_chat_async(aml_online_chat: AzureMLChatTarget): messages = [ Message(message_pieces=[MessagePiece(role="user", conversation_id="123", original_value="user content")]), diff --git a/tests/unit/prompt_target/target/test_openai_chat_target.py b/tests/unit/prompt_target/target/test_openai_chat_target.py index 452decb0ac..dbf1018a1b 100644 --- a/tests/unit/prompt_target/target/test_openai_chat_target.py +++ b/tests/unit/prompt_target/target/test_openai_chat_target.py @@ -729,6 +729,44 @@ def test_set_auth_with_api_key(patch_central_database): assert target._api_key == "test_api_key_456" +def test_no_key_recognized_azure_endpoint_auto_mints_entra(patch_central_database): + """With no key and a recognized Azure OpenAI endpoint, the target auto-mints + an Entra token provider for that endpoint.""" + + async def _provider() -> str: + return "aoai-entra-token" + + with ( + patch.dict(os.environ, {}, clear=True), + patch( + "pyrit.prompt_target.openai.openai_target.get_azure_openai_auth", + return_value=_provider, + ) as mock_get_auth, + ): + target = OpenAIChatTarget( + model_name="gpt-4", + endpoint="https://test.openai.azure.com/", + ) + + mock_get_auth.assert_called_once_with("https://test.openai.azure.com/") + assert target._api_key is _provider + + +def test_no_key_non_azure_endpoint_raises(patch_central_database): + """With no key and a non-Azure endpoint, the target refuses to mint a token.""" + with patch.dict(os.environ, {}, clear=True): + with pytest.raises(ValueError, match="non-Azure endpoints"): + OpenAIChatTarget(model_name="gpt-4", endpoint="https://api.openai.com/") + + +def test_no_key_substring_lookalike_endpoint_raises(patch_central_database): + """A hostname merely containing 'azure' (but not a recognized suffix) must not + trigger auto-Entra minting (loose->strict hardening).""" + with patch.dict(os.environ, {}, clear=True): + with pytest.raises(ValueError, match="non-Azure endpoints"): + OpenAIChatTarget(model_name="gpt-4", endpoint="https://evil-azure.example.com/") + + def test_url_validation_no_warning_for_custom_endpoint(caplog, patch_central_database): """Test that URL validation doesn't warn for custom endpoint paths.""" with patch.dict(os.environ, {}, clear=True), caplog.at_level(logging.WARNING): From 9b97764acad2516ac227effbe79d83929fa84a45 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 2 Jul 2026 17:58:58 -0700 Subject: [PATCH 2/8] MAINT: Phase 4.5 cleanup - slim TargetInstance, drop api-key env machinery, simplify RoundRobin Follow-up cleanup on the Phase 4.5 target-registry migration: - Remove get_api_key_environment_variable classmethod (base + OpenAI/AzureML overrides), the catalog api_key_env_var field, backend _validate_api_key_present pre-check, and the frontend type field. Target constructors already raise on missing credentials, so the create-time pre-check was redundant. The instance attribute api_key_environment_variable (the real auth mechanism) is retained. - Slim TargetInstance to embed a required serialized TargetIdentifier and TargetCapabilities; migrate mappers, backend, frontend, and tests to read through identifier.* instead of flattened scalars. - Simplify RoundRobinTarget: drop the _deduplicate_targets staticmethod and the second length check. Now a single length check plus a raise on duplicate target instances, compared by object identity (the identifier hash excludes api_key, so hash comparison would falsely collapse distinct different-key targets). Co-authored-by: Copilot App <223556219+Copilot@users.noreply.github.com> --- frontend/src/App.test.tsx | 7 +- frontend/src/App.tsx | 15 +- .../components/Chat/ChatInputArea.test.tsx | 73 ++++----- .../src/components/Chat/ChatWindow.test.tsx | 35 ++--- frontend/src/components/Chat/ChatWindow.tsx | 7 +- .../src/components/Chat/TargetBadge.test.tsx | 20 ++- frontend/src/components/Chat/TargetBadge.tsx | 40 ++--- .../Config/CreateTargetDialog.test.tsx | 93 ++++++------ .../components/Config/CreateTargetDialog.tsx | 29 ++-- .../components/Config/TargetConfig.test.tsx | 17 ++- .../components/Config/TargetTable.test.tsx | 29 ++-- .../src/components/Config/TargetTable.tsx | 35 +++-- frontend/src/components/Home/Home.test.tsx | 5 +- frontend/src/components/Home/Home.tsx | 7 +- frontend/src/test-utils/targetFixtures.ts | 44 ++++++ frontend/src/types/index.ts | 31 ++-- frontend/src/utils/targetIdentity.ts | 51 +++++++ pyrit/backend/mappers/target_mappers.py | 111 +++++--------- pyrit/backend/models/targets.py | 9 +- pyrit/backend/services/target_service.py | 36 ----- pyrit/cli/_output.py | 9 +- pyrit/models/catalog/__init__.py | 2 - pyrit/models/catalog/target.py | 72 ++++----- pyrit/models/target_capabilities.py | 49 +++++- pyrit/prompt_target/azure_ml_chat_target.py | 10 -- pyrit/prompt_target/common/prompt_target.py | 17 +-- pyrit/prompt_target/openai/openai_target.py | 20 --- pyrit/prompt_target/round_robin_target.py | 73 ++------- tests/unit/backend/test_api_routes.py | 37 ++--- tests/unit/backend/test_mappers.py | 81 +++++----- tests/unit/backend/test_target_service.py | 140 ++++-------------- tests/unit/cli/test_api_client.py | 17 +-- tests/unit/cli/test_output.py | 25 ++-- tests/unit/cli/test_pyrit_scan.py | 5 +- .../prompt_target/test_round_robin_target.py | 7 + 35 files changed, 583 insertions(+), 675 deletions(-) create mode 100644 frontend/src/test-utils/targetFixtures.ts create mode 100644 frontend/src/utils/targetIdentity.ts diff --git a/frontend/src/App.test.tsx b/frontend/src/App.test.tsx index 5088924cb2..7663291a4d 100644 --- a/frontend/src/App.test.tsx +++ b/frontend/src/App.test.tsx @@ -167,6 +167,7 @@ jest.mock("./components/Chat/ChatWindow", () => { }); jest.mock("./components/Config/TargetConfig", () => { + const { makeTarget } = jest.requireActual("@/test-utils/targetFixtures") as typeof import("@/test-utils/targetFixtures"); const MockTargetConfig = ({ activeTarget, onSetActiveTarget, @@ -181,12 +182,10 @@ jest.mock("./components/Config/TargetConfig", () => {