From 136b57eb12993a7698c888e310a741517a8d33ca Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Thu, 2 Jul 2026 13:53:20 -0700 Subject: [PATCH 1/6] allow parameters on dataset configuration when loading from memory --- .../backend/services/scenario_run_service.py | 32 +++-- pyrit/cli/_cli_args.py | 46 +++++++ pyrit/cli/pyrit_scan.py | 10 ++ pyrit/cli/pyrit_shell.py | 2 + pyrit/models/catalog/scenario.py | 7 ++ pyrit/models/seeds/seed_dataset.py | 2 +- pyrit/scenario/core/__init__.py | 2 + pyrit/scenario/core/dataset_configuration.py | 108 +++++++++++++++- .../unit/backend/test_scenario_run_service.py | 52 ++++++++ tests/unit/cli/test_dataset_filter_help.py | 18 +++ tests/unit/cli/test_pyrit_scan.py | 28 +++++ tests/unit/cli/test_pyrit_shell.py | 2 + .../core/test_dataset_configuration.py | 116 ++++++++++++++++++ 13 files changed, 411 insertions(+), 14 deletions(-) create mode 100644 tests/unit/cli/test_dataset_filter_help.py diff --git a/pyrit/backend/services/scenario_run_service.py b/pyrit/backend/services/scenario_run_service.py index 7236ce6ddb..7208b5e686 100644 --- a/pyrit/backend/services/scenario_run_service.py +++ b/pyrit/backend/services/scenario_run_service.py @@ -23,7 +23,7 @@ ) from pyrit.registry import InitializerRegistry, ScenarioRegistry, TargetRegistry from pyrit.scenario import Scenario -from pyrit.scenario.core import DatasetAttackConfiguration +from pyrit.scenario.core import DatasetAttackConfiguration, build_dataset_filters if TYPE_CHECKING: from pyrit.prompt_target import PromptTarget @@ -294,15 +294,21 @@ def _build_init_kwargs( if request.labels: init_kwargs["memory_labels"] = request.labels + dataset_filters = build_dataset_filters(parameters=request.dataset_parameters) + # Resolve strategies and dataset config from a temporary instance of the # scenario. The downstream _initialize_scenario_async builds its own # instance (so scenario_result_id can be passed), so this is a cheap # throwaway used only for introspection. Introspection is required - # whenever the caller wants to override strategies, dataset names, or - # the sample cap, because each of those needs the scenario's own - # strategy enum or dataset-config subclass to be resolved correctly. + # whenever the caller wants to override strategies, dataset names, the + # sample cap, or dataset filters, because each of those needs the + # scenario's own strategy enum or dataset-config subclass to be resolved + # correctly. needs_introspection = ( - bool(request.strategies) or bool(request.dataset_names) or request.max_dataset_size is not None + bool(request.strategies) + or bool(request.dataset_names) + or request.max_dataset_size is not None + or bool(dataset_filters) ) if not needs_introspection: return init_kwargs @@ -329,7 +335,7 @@ def _build_init_kwargs( ) from None init_kwargs["scenario_strategies"] = strategy_enums - if request.dataset_names or request.max_dataset_size is not None: + if request.dataset_names or request.max_dataset_size is not None or dataset_filters: default_config = introspection_instance._default_dataset_config if request.dataset_names: @@ -340,6 +346,7 @@ def _build_init_kwargs( init_kwargs["dataset_config"] = default_config_class( dataset_names=request.dataset_names, max_dataset_size=request.max_dataset_size, + filters=dataset_filters or None, ) except TypeError as exc: # The subclass __init__ takes extra required kwargs we cannot @@ -349,7 +356,7 @@ def _build_init_kwargs( # define a no-extra-required-args constructor or surface the # incompatibility through their own initialize_async validation. logger.warning( - "Cannot construct %s(dataset_names=..., max_dataset_size=...) (%s). " + "Cannot construct %s(dataset_names=..., max_dataset_size=..., filters=...) (%s). " "Falling back to a generic DatasetAttackConfiguration; scenario-specific " "dataset-config behavior may be lost.", default_config_class.__name__, @@ -358,12 +365,17 @@ def _build_init_kwargs( init_kwargs["dataset_config"] = DatasetAttackConfiguration( dataset_names=request.dataset_names, max_dataset_size=request.max_dataset_size, + filters=dataset_filters or None, ) - elif request.max_dataset_size is not None: + else: # Reuse the scenario's default dataset config (preserves subtype + # the scenario's own default dataset names) and override only the - # sample cap. Safe because the introspection instance is throwaway. - default_config.max_dataset_size = request.max_dataset_size + # sample cap and/or filters. Safe because the introspection instance + # is throwaway. + if request.max_dataset_size is not None: + default_config.max_dataset_size = request.max_dataset_size + if dataset_filters: + default_config.update_filters(filters=dataset_filters) init_kwargs["dataset_config"] = default_config return init_kwargs diff --git a/pyrit/cli/_cli_args.py b/pyrit/cli/_cli_args.py index d7877daa88..148f814f16 100644 --- a/pyrit/cli/_cli_args.py +++ b/pyrit/cli/_cli_args.py @@ -249,9 +249,44 @@ def parse_memory_labels(json_string: str) -> dict[str, str]: return labels +def parse_dataset_parameter(arg: str) -> tuple[str, str]: + """ + Parse a single ``KEY=VALUE`` dataset-parameter token from the CLI. + + Note: The ``arg`` parameter is positional (not keyword-only) so it can be used directly + as an argparse ``type=`` callable and an ``_ArgSpec`` parser. This mirrors + ``_parse_initializer_arg`` and is an intentional exception to the keyword-only style rule + for argparse compatibility. + + Args: + arg (str): The raw ``KEY=VALUE`` token. + + Returns: + tuple[str, str]: The (key, value) pair. The value keeps its raw string form so the + server can coerce and validate it. + + Raises: + ValueError: If the token is not in ``KEY=VALUE`` form or the key is empty. Argparse + converts this into a clean CLI error; the shell catches it directly. + """ + if "=" not in arg: + raise ValueError(f"Dataset parameter must be in KEY=VALUE form, got: {arg!r}") + key, _, value = arg.partition("=") + key = key.strip() + if not key: + raise ValueError(f"Dataset parameter key cannot be empty in: {arg!r}") + return key, value + + # --------------------------------------------------------------------------- # Shared argument help text # --------------------------------------------------------------------------- + +# Dataset-filter keys advertised in --dataset-parameters help. Kept as a static list here (this +# module is on the fast --help path and stays import-light); a unit test asserts it matches the +# authoritative registry in pyrit.scenario.core.dataset_configuration. +_ADVERTISED_DATASET_FILTER_KEYS: tuple[str, ...] = ("harm_categories", "data_types") + ARG_HELP = { "config_file": CONFIG_FILE_HELP, "initializers": ( @@ -272,6 +307,10 @@ def parse_memory_labels(json_string: str) -> dict[str, str]: "Creates a new dataset config; fetches all items unless --max-dataset-size is also specified", "max_dataset_size": "Maximum number of items to use from the dataset (must be >= 1). " "Limits new datasets if --dataset-names provided, otherwise overrides scenario's default limit", + "dataset_parameters": "Dataset seed filters as KEY=VALUE tokens " + "(e.g., harm_categories=cyber data_types=text). Accepted keys: harm_categories, data_types. " + "Keys filter seeds before sizing. " + "List values may be comma-separated (e.g., harm_categories=cyber,violence)", "target": "Name of a registered target from the TargetRegistry to use as the objective target. " "Targets are registered by initializers (e.g., 'target' initializer). " "Use --list-targets to see available target names after initializers have run", @@ -406,6 +445,12 @@ class _ArgSpec: result_key="max_dataset_size", parser=lambda v: validate_integer(v, name="--max-dataset-size", min_value=1), ) +_DATASET_PARAMETERS_ARG = _ArgSpec( + flags=["--dataset-parameters"], + result_key="dataset_parameters", + multi_value=True, + parser=parse_dataset_parameter, +) _TARGET_ARG = _ArgSpec( flags=["--target"], result_key="target", @@ -419,6 +464,7 @@ class _ArgSpec: _MEMORY_LABELS_ARG, _DATASET_NAMES_ARG, _MAX_DATASET_SIZE_ARG, + _DATASET_PARAMETERS_ARG, _TARGET_ARG, ] diff --git a/pyrit/cli/pyrit_scan.py b/pyrit/cli/pyrit_scan.py index 12fd5c62a0..a99bfef0c9 100644 --- a/pyrit/cli/pyrit_scan.py +++ b/pyrit/cli/pyrit_scan.py @@ -24,6 +24,7 @@ _parse_initializer_arg, build_parameters_from_api, non_negative_int, + parse_dataset_parameter, positive_int, validate_log_level_argparse, ) @@ -257,6 +258,13 @@ def _build_base_parser(*, add_help: bool = True) -> ArgumentParser: type=positive_int, help=ARG_HELP["max_dataset_size"], ) + run_group.add_argument( + "--dataset-parameters", + type=parse_dataset_parameter, + nargs="+", + metavar="KEY=VALUE", + help=ARG_HELP["dataset_parameters"], + ) return parser @@ -629,6 +637,8 @@ def _build_run_request(*, parsed_args: Namespace, scenario_name: str) -> RunScen kwargs["dataset_names"] = parsed_args.dataset_names if parsed_args.max_dataset_size is not None: kwargs["max_dataset_size"] = parsed_args.max_dataset_size + if parsed_args.dataset_parameters: + kwargs["dataset_parameters"] = dict(parsed_args.dataset_parameters) if parsed_args.memory_labels: kwargs["labels"] = parse_memory_labels(json_string=parsed_args.memory_labels) diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index a4ccee96af..44b9212375 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -412,6 +412,8 @@ def do_run(self, line: str) -> None: request_kwargs["dataset_names"] = args["dataset_names"] if args.get("max_dataset_size") is not None: request_kwargs["max_dataset_size"] = args["max_dataset_size"] + if args.get("dataset_parameters"): + request_kwargs["dataset_parameters"] = dict(args["dataset_parameters"]) if args.get("memory_labels"): request_kwargs["labels"] = args["memory_labels"] diff --git a/pyrit/models/catalog/scenario.py b/pyrit/models/catalog/scenario.py index 6268a6bf82..d567732723 100644 --- a/pyrit/models/catalog/scenario.py +++ b/pyrit/models/catalog/scenario.py @@ -51,6 +51,13 @@ class RunScenarioRequest(BaseModel): strategies: list[str] | None = Field(None, description="Strategy names to use (uses scenario default if omitted)") dataset_names: list[str] | None = Field(None, description="Dataset names to use (uses scenario default if omitted)") max_dataset_size: int | None = Field(None, ge=1, description="Maximum items per dataset") + dataset_parameters: dict[str, Any] | None = Field( + None, + description=( + "Dataset seed filters keyed by field, applied before sampling. " + "Accepted keys: harm_categories, data_types." + ), + ) max_concurrency: int = Field(10, ge=1, le=100, description="Maximum concurrent operations") max_retries: int = Field(0, ge=0, le=20, description="Maximum retry attempts on failure") labels: dict[str, str] | None = Field(None, description="Labels to attach to memory entries") diff --git a/pyrit/models/seeds/seed_dataset.py b/pyrit/models/seeds/seed_dataset.py index de71cf8972..a00a15c0e0 100644 --- a/pyrit/models/seeds/seed_dataset.py +++ b/pyrit/models/seeds/seed_dataset.py @@ -99,7 +99,7 @@ class SeedDataset(BaseModel): seed_type: SeedType | None = None # The actual prompts - seeds: list[SeedUnion] + seeds: Sequence[SeedUnion] @model_validator(mode="before") @classmethod diff --git a/pyrit/scenario/core/__init__.py b/pyrit/scenario/core/__init__.py index 3faa4a4e52..224c3d3eb0 100644 --- a/pyrit/scenario/core/__init__.py +++ b/pyrit/scenario/core/__init__.py @@ -15,6 +15,7 @@ DatasetConstraintError, DatasetSourceKind, ResolvedDataset, + build_dataset_filters, require_nonempty, ) from pyrit.scenario.core.scenario import BaselineAttackPolicy, Scenario @@ -34,6 +35,7 @@ "INLINE_DATASET_NAME", "Parameter", "ResolvedDataset", + "build_dataset_filters", "require_nonempty", "Scenario", "ScenarioCompositeStrategy", diff --git a/pyrit/scenario/core/dataset_configuration.py b/pyrit/scenario/core/dataset_configuration.py index 020a5a6964..390ae25d19 100644 --- a/pyrit/scenario/core/dataset_configuration.py +++ b/pyrit/scenario/core/dataset_configuration.py @@ -31,7 +31,7 @@ from dataclasses import dataclass from enum import Enum from functools import cached_property -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar from pyrit.common.deprecation import print_deprecation_message from pyrit.memory import CentralMemory @@ -226,6 +226,60 @@ def _validate(resolved: ResolvedDataset) -> None: return _validate +# Authoritative set of dataset seed filters exposed via ``--dataset-parameters``. Each entry +# is used verbatim as a ``MemoryInterface.get_seeds`` keyword argument, so a filter key IS the +# get_seeds kwarg. Every exposed filter must be a list-valued (Sequence) get_seeds parameter -- +# a guard test enforces this and keeps the CLI's advertised keys in sync. Adding a filterable +# field is a one-line change here. +DATASET_FILTERS: frozenset[str] = frozenset({"harm_categories", "data_types"}) + + +def _coerce_filter_values(value: Any) -> list[str]: + """ + Coerce a single filter value into a list of non-empty strings. + + Accepts a list/tuple as-is (stringified) or splits a comma-separated string. + + Args: + value (Any): The raw filter value (scalar, comma-separated string, or sequence). + + Returns: + list[str]: The cleaned list of values. + """ + if isinstance(value, (list, tuple)): + return [str(item).strip() for item in value if str(item).strip()] + return [part.strip() for part in str(value).split(",") if part.strip()] + + +def build_dataset_filters(*, parameters: dict[str, Any] | None) -> dict[str, Any]: + """ + Map a user-facing dataset-parameter bag onto ``MemoryInterface.get_seeds`` filters. + + Each key must be one of ``DATASET_FILTERS`` (the exposed ``get_seeds`` kwargs); its value is + coerced into a list (a single token or a comma-separated string). Any other key is rejected + loudly. + + Args: + parameters (dict[str, Any] | None): The raw dataset-parameter bag (e.g. from the CLI + ``--dataset-parameters`` flag), keyed by ``get_seeds`` kwarg name. + + Returns: + dict[str, Any]: The filters ready to pass as ``**filters`` to ``get_seeds``. + + Raises: + ValueError: If a key is not present in ``DATASET_FILTERS``. + """ + filters: dict[str, Any] = {} + for key, raw_value in (parameters or {}).items(): + if key not in DATASET_FILTERS: + raise ValueError(f"Unknown dataset parameter '{key}'. Allowed: {', '.join(sorted(DATASET_FILTERS))}.") + values = _coerce_filter_values(raw_value) + if values: + filters[key] = values + + return filters + + def restrict_dataset_names(allowed: set[str]) -> Callable[[ResolvedDataset], None]: """ Build a validator that requires every contributing dataset name to be in ``allowed``. @@ -293,6 +347,7 @@ def __init__( seed_groups: list[SeedGroup] | None = None, dataset_names: list[str] | None = None, max_dataset_size: int | None = None, + filters: dict[str, Any] | None = None, validators: Sequence[Callable[[ResolvedDataset], None]] | None = None, auto_fetch: bool = True, ) -> None: @@ -306,6 +361,9 @@ def __init__( dataset_names (list[str] | None): Names of datasets to load from memory. max_dataset_size (int | None): If set, randomly samples up to this many items from the resolved dataset (without replacement). + filters (dict[str, Any] | None): Filters passed to ``MemoryInterface.get_seeds`` + when resolving named datasets (e.g. ``{"harm_categories": ["cyber"]}``). + Applied before ``max_dataset_size`` sampling; ignored for inline seeds. validators (Sequence[Callable[[ResolvedDataset], None]] | None): Constraint callbacks run against the resolved dataset; each raises on violation. These are appended to the subclass's ``_default_validators``. @@ -331,6 +389,7 @@ def __init__( self._seed_groups = list(seed_groups) if seed_groups is not None else None self._dataset_names = list(dataset_names) if dataset_names is not None else None self.max_dataset_size = max_dataset_size + self._filters: dict[str, Any] = dict(filters or {}) self._validators: list[Callable[[ResolvedDataset], None]] = [ *self._default_validators(), *(list(validators) if validators else []), @@ -389,6 +448,28 @@ def source_kind(self) -> DatasetSourceKind: return DatasetSourceKind.INLINE return DatasetSourceKind.MEMORY + @property + def filters(self) -> dict[str, Any]: + """ + The ``get_seeds`` filters applied when resolving named datasets. + + Returns: + dict[str, Any]: A copy of the configured filters. + """ + return dict(self._filters) + + def update_filters(self, *, filters: dict[str, Any]) -> None: + """ + Merge additional ``get_seeds`` filters into this configuration (run-time override). + + Used when a run overrides dataset selection without rebuilding the configuration -- + the provided filters take precedence over any already configured with the same key. + + Args: + filters (dict[str, Any]): Filters to merge, keyed by ``get_seeds`` kwarg name. + """ + self._filters = {**self._filters, **filters} + # ========================================================================= # Resolution helpers # ========================================================================= @@ -426,7 +507,7 @@ async def _collect_seeds_for_dataset_async(self, *, dataset_name: str) -> list[S DatasetConstraintError: If the dataset yields no seeds even after auto-fetch, or if auto-fetch itself fails (the provider error is chained as the cause). """ - found = list(self._memory.get_seeds(dataset_name=dataset_name)) + found = list(self._memory.get_seeds(dataset_name=dataset_name, **self._filters)) if not found and self._auto_fetch: try: await self._fetch_dataset_async(dataset_name=dataset_name) @@ -434,8 +515,12 @@ async def _collect_seeds_for_dataset_async(self, *, dataset_name: str) -> list[S raise DatasetConstraintError( f"Dataset '{dataset_name}' could not be loaded: auto-fetch from the registered provider failed." ) from exc - found = list(self._memory.get_seeds(dataset_name=dataset_name)) + found = list(self._memory.get_seeds(dataset_name=dataset_name, **self._filters)) if not found: + if self._filters and self._memory.get_seeds(dataset_name=dataset_name): + raise DatasetConstraintError( + f"Dataset '{dataset_name}' has seeds, but none match the configured filters {self._filters}." + ) hint = ( "auto-fetch from the registered provider did not populate it" if self._auto_fetch @@ -867,6 +952,7 @@ def per_dataset( dataset_names: Sequence[str], max_dataset_size: int | None = None, auto_fetch: bool = True, + filters: dict[str, Any] | None = None, validators: Sequence[Callable[[ResolvedDataset], None]] | None = None, ) -> CompoundDatasetAttackConfiguration: """ @@ -879,6 +965,7 @@ def per_dataset( dataset_names (Sequence[str]): The dataset names; one child is built per name. max_dataset_size (int | None): Per-dataset cap applied to each child. auto_fetch (bool): Passed to each child (fetch missing datasets into memory). + filters (dict[str, Any] | None): ``get_seeds`` filters applied to each child. validators (Sequence[Callable[[ResolvedDataset], None]] | None): Applied to each child. Returns: @@ -895,6 +982,7 @@ def per_dataset( dataset_names=[name], max_dataset_size=max_dataset_size, auto_fetch=auto_fetch, + filters=filters, validators=validators, ) for name in dataset_names @@ -928,6 +1016,20 @@ def source_kind(self) -> DatasetSourceKind: return DatasetSourceKind.INLINE return DatasetSourceKind.MEMORY + def update_filters(self, *, filters: dict[str, Any]) -> None: + """ + Merge filters into the compound and propagate them to every child configuration. + + The children run the actual ``get_seeds`` queries, so run-time filter overrides must + reach each child to take effect. + + Args: + filters (dict[str, Any]): Filters to merge, keyed by ``get_seeds`` kwarg name. + """ + super().update_filters(filters=filters) + for child in self._configurations: + child.update_filters(filters=filters) + async def get_seed_attack_groups_async(self) -> list[SeedAttackGroup]: """ Concatenate every child's flat result, then validate and apply the global cap. diff --git a/tests/unit/backend/test_scenario_run_service.py b/tests/unit/backend/test_scenario_run_service.py index 08ae14f8a5..ae4ed5f489 100644 --- a/tests/unit/backend/test_scenario_run_service.py +++ b/tests/unit/backend/test_scenario_run_service.py @@ -41,6 +41,7 @@ def _make_request( scenario_result_id: str | None = None, dataset_names: list[str] | None = None, max_dataset_size: int | None = None, + dataset_parameters: dict[str, Any] | None = None, ) -> RunScenarioRequest: """Create a RunScenarioRequest for testing.""" return RunScenarioRequest( @@ -51,6 +52,7 @@ def _make_request( scenario_result_id=scenario_result_id, dataset_names=dataset_names, max_dataset_size=max_dataset_size, + dataset_parameters=dataset_parameters, ) @@ -364,6 +366,56 @@ def __init__(self, *, required_extra: str, **kwargs: Any) -> None: for record in caplog.records ) + async def test_start_run_dataset_parameters_filters_new_config(self, mock_all_registries) -> None: + """``dataset_parameters`` with ``dataset_names`` builds a config carrying the filters.""" + + class _MarkerDatasetConfiguration(DatasetConfiguration): + pass + + scenario_instance = mock_all_registries["scenario_instance"] + scenario_instance._default_dataset_config = _MarkerDatasetConfiguration(dataset_names=["original"]) + + service = ScenarioRunService() + await service.start_run_async( + request=_make_request( + dataset_names=["custom"], + max_dataset_size=7, + dataset_parameters={"harm_categories": "cyber"}, + ) + ) + + init_call = scenario_instance.initialize_async.await_args + built_config = init_call.kwargs["dataset_config"] + assert built_config.dataset_names == ["custom"] + assert built_config.max_dataset_size == 7 + assert built_config.filters == {"harm_categories": ["cyber"]} + + async def test_start_run_dataset_parameters_updates_default_config(self, mock_all_registries) -> None: + """``dataset_parameters`` with no ``dataset_names`` merges filters into the default config.""" + default_config = DatasetAttackConfiguration(dataset_names=["original"]) + scenario_instance = mock_all_registries["scenario_instance"] + scenario_instance._default_dataset_config = default_config + + service = ScenarioRunService() + await service.start_run_async(request=_make_request(dataset_parameters={"harm_categories": "cyber"})) + + init_call = scenario_instance.initialize_async.await_args + built_config = init_call.kwargs["dataset_config"] + assert built_config is default_config + assert built_config.filters == {"harm_categories": ["cyber"]} + + async def test_start_run_dataset_parameters_rejects_max_dataset_size(self, mock_all_registries) -> None: + """``max_dataset_size`` in ``dataset_parameters`` is rejected as an unknown filter key.""" + service = ScenarioRunService() + with pytest.raises(ValueError, match="Unknown dataset parameter 'max_dataset_size'"): + await service.start_run_async(request=_make_request(dataset_parameters={"max_dataset_size": "9"})) + + async def test_start_run_dataset_parameters_unknown_key_raises(self, mock_all_registries) -> None: + """An unknown dataset parameter surfaces as a validation error.""" + service = ScenarioRunService() + with pytest.raises(ValueError, match="Unknown dataset parameter 'bogus'"): + await service.start_run_async(request=_make_request(dataset_parameters={"bogus": "x"})) + async def test_start_run_dataset_names_introspection_failure_raises(self, mock_memory) -> None: """Passing ``dataset_names`` against a non-no-arg-instantiable scenario fails fast.""" # Mirrors test_start_run_scenario_not_no_arg_instantiable_raises but for the dataset_names path. diff --git a/tests/unit/cli/test_dataset_filter_help.py b/tests/unit/cli/test_dataset_filter_help.py new file mode 100644 index 0000000000..72ce0f6b3c --- /dev/null +++ b/tests/unit/cli/test_dataset_filter_help.py @@ -0,0 +1,18 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Guard that the CLI's advertised dataset-filter keys stay in sync with the resolver.""" + +from pyrit.cli._cli_args import _ADVERTISED_DATASET_FILTER_KEYS, ARG_HELP +from pyrit.scenario.core.dataset_configuration import DATASET_FILTERS + + +def test_cli_advertised_filters_match_dataset_configuration() -> None: + # The static CLI list must equal the exact filter kwargs the resolver accepts. + assert set(_ADVERTISED_DATASET_FILTER_KEYS) == set(DATASET_FILTERS) + + +def test_help_text_lists_every_advertised_key() -> None: + help_text = ARG_HELP["dataset_parameters"] + for key in _ADVERTISED_DATASET_FILTER_KEYS: + assert key in help_text diff --git a/tests/unit/cli/test_pyrit_scan.py b/tests/unit/cli/test_pyrit_scan.py index 3386e5bdea..34dee384d4 100644 --- a/tests/unit/cli/test_pyrit_scan.py +++ b/tests/unit/cli/test_pyrit_scan.py @@ -83,6 +83,16 @@ def test_parse_args_with_memory_labels(self): args = pyrit_scan.parse_args(["test_scenario", "--memory-labels", '{"key":"value"}']) assert args.memory_labels == '{"key":"value"}' + def test_parse_args_with_dataset_parameters(self): + args = pyrit_scan.parse_args( + ["test_scenario", "--dataset-parameters", "harm_categories=cyber", "data_types=text"] + ) + assert args.dataset_parameters == [("harm_categories", "cyber"), ("data_types", "text")] + + def test_parse_args_dataset_parameter_without_equals_errors(self): + with pytest.raises(SystemExit): + pyrit_scan.parse_args(["test_scenario", "--dataset-parameters", "harm_categories"]) + def test_parse_args_complex_command(self): args = pyrit_scan.parse_args( [ @@ -542,6 +552,7 @@ def test_includes_initializer_args(self): max_retries=None, dataset_names=None, max_dataset_size=None, + dataset_parameters=None, memory_labels=None, ) request = pyrit_scan._build_run_request(parsed_args=parsed, scenario_name="s") @@ -557,6 +568,7 @@ def test_populates_optional_fields(self): max_retries=2, dataset_names=["d1"], max_dataset_size=10, + dataset_parameters=None, memory_labels='{"key":"value"}', ) request = pyrit_scan._build_run_request(parsed_args=parsed, scenario_name="s") @@ -567,6 +579,21 @@ def test_populates_optional_fields(self): assert request.max_dataset_size == 10 assert request.labels == {"key": "value"} + def test_populates_dataset_parameters(self): + parsed = Namespace( + target="t", + initializers=None, + scenario_strategies=None, + max_concurrency=None, + max_retries=None, + dataset_names=None, + max_dataset_size=None, + dataset_parameters=[("harm_categories", "cyber"), ("data_types", "text")], + memory_labels=None, + ) + request = pyrit_scan._build_run_request(parsed_args=parsed, scenario_name="s") + assert request.dataset_parameters == {"harm_categories": "cyber", "data_types": "text"} + def test_includes_scenario_declared_params(self): parsed = Namespace( target=None, @@ -576,6 +603,7 @@ def test_includes_scenario_declared_params(self): max_retries=None, dataset_names=None, max_dataset_size=None, + dataset_parameters=None, memory_labels=None, scenario__max_turns="7", ) diff --git a/tests/unit/cli/test_pyrit_shell.py b/tests/unit/cli/test_pyrit_shell.py index d3689292a6..887cbc79c6 100644 --- a/tests/unit/cli/test_pyrit_shell.py +++ b/tests/unit/cli/test_pyrit_shell.py @@ -525,6 +525,7 @@ def test_run_completed_path_with_results(self, shell, capsys): "memory_labels": {"k": "v"}, "dataset_names": ["d1"], "max_dataset_size": 5, + "dataset_parameters": [("harm_categories", "cyber"), ("data_types", "text")], }, ), patch("pyrit.cli._output.print_scenario_result_async", new_callable=AsyncMock), @@ -542,6 +543,7 @@ def test_run_completed_path_with_results(self, shell, capsys): assert sent.labels == {"k": "v"} assert sent.dataset_names == ["d1"] assert sent.max_dataset_size == 5 + assert sent.dataset_parameters == {"harm_categories": "cyber", "data_types": "text"} def test_run_failed_status_calls_summary(self, shell): s, client = shell diff --git a/tests/unit/scenario/core/test_dataset_configuration.py b/tests/unit/scenario/core/test_dataset_configuration.py index a1e94f41b1..f0b9535a6e 100644 --- a/tests/unit/scenario/core/test_dataset_configuration.py +++ b/tests/unit/scenario/core/test_dataset_configuration.py @@ -9,6 +9,7 @@ from pyrit.models import SeedAttackGroup, SeedGroup, SeedObjective, SeedPrompt from pyrit.scenario.core.dataset_configuration import ( + DATASET_FILTERS, INLINE_DATASET_NAME, CompoundDatasetAttackConfiguration, DatasetAttackConfiguration, @@ -16,6 +17,7 @@ DatasetConstraintError, DatasetSourceKind, ResolvedDataset, + build_dataset_filters, forbid_inline_seeds, require_harm_categories, require_inline_seeds, @@ -595,3 +597,117 @@ async def test_inline_children_combine(self) -> None: ) groups = await config.get_seed_attack_groups_async() assert sorted(g.objective.value for g in groups) == ["a", "b"] + + +class TestBuildDatasetFilters: + """``build_dataset_filters`` maps the user-facing bag onto ``get_seeds`` filters.""" + + def test_none_returns_empty(self) -> None: + assert build_dataset_filters(parameters=None) == {} + + def test_empty_returns_empty(self) -> None: + assert build_dataset_filters(parameters={}) == {} + + def test_max_dataset_size_is_rejected(self) -> None: + with pytest.raises(ValueError, match="Unknown dataset parameter 'max_dataset_size'"): + build_dataset_filters(parameters={"max_dataset_size": "13"}) + + def test_single_value_maps_to_list(self) -> None: + assert build_dataset_filters(parameters={"harm_categories": "cyber"}) == {"harm_categories": ["cyber"]} + + def test_comma_splits_into_list(self) -> None: + assert build_dataset_filters(parameters={"harm_categories": "cyber,violence"}) == { + "harm_categories": ["cyber", "violence"] + } + + def test_multiple_filters_combine(self) -> None: + assert build_dataset_filters(parameters={"harm_categories": "cyber", "data_types": "text,image_path"}) == { + "harm_categories": ["cyber"], + "data_types": ["text", "image_path"], + } + + def test_unknown_key_raises(self) -> None: + with pytest.raises(ValueError, match="Unknown dataset parameter 'bogus'"): + build_dataset_filters(parameters={"bogus": "x"}) + + def test_unexposed_get_seeds_kwarg_is_rejected(self) -> None: + # ``authors`` is a real get_seeds kwarg but is intentionally NOT exposed as a filter. + with pytest.raises(ValueError, match="Unknown dataset parameter 'authors'"): + build_dataset_filters(parameters={"authors": "jones"}) + + +class TestDatasetFilterRegistry: + """The exposed filter set is intentional and stays in contract with ``get_seeds``.""" + + def test_exposed_filters_are_frozen(self) -> None: + # Adding/removing a filter must be a deliberate edit to this expected set. + assert {"harm_categories", "data_types"} == DATASET_FILTERS + + def test_every_filter_is_a_sequence_get_seeds_param(self) -> None: + # Each exposed key must be a real get_seeds parameter AND list-valued, since + # build_dataset_filters always coerces values into a list. + import typing + from collections.abc import Sequence + + from pyrit.memory.memory_interface import MemoryInterface + + hints = typing.get_type_hints(MemoryInterface.get_seeds) + + def _allows_sequence(annotation: object) -> bool: + for candidate in (annotation, *typing.get_args(annotation)): + origin = typing.get_origin(candidate) or candidate + if isinstance(origin, type) and origin is not str and issubclass(origin, Sequence): + return True + return False + + for name in DATASET_FILTERS: + assert name in hints, f"'{name}' is not a MemoryInterface.get_seeds parameter" + assert _allows_sequence(hints[name]), f"'{name}' must be a Sequence-typed get_seeds parameter" + + +class TestDatasetConfigurationFilters: + """Filters are threaded into ``get_seeds`` and applied before sampling.""" + + async def test_filters_passed_to_get_seeds(self, mock_memory: MagicMock) -> None: + mock_memory.get_seeds.return_value = make_objectives("a", "b") + config = DatasetAttackConfiguration(dataset_names=["d1"], filters={"harm_categories": ["cyber"]}) + await config.get_seed_attack_groups_async() + mock_memory.get_seeds.assert_called_with(dataset_name="d1", harm_categories=["cyber"]) + + async def test_filter_removing_all_seeds_raises_specific_error(self, mock_memory: MagicMock) -> None: + def _get_seeds(*, dataset_name, **filters): + return [] if filters else make_objectives("a", "b") + + mock_memory.get_seeds.side_effect = _get_seeds + config = DatasetAttackConfiguration( + dataset_names=["d1"], filters={"harm_categories": ["missing"]}, auto_fetch=False + ) + with pytest.raises(DatasetConstraintError, match="none match the configured filters"): + await config.get_seed_attack_groups_async() + + async def test_update_filters_merges(self, mock_memory: MagicMock) -> None: + mock_memory.get_seeds.return_value = make_objectives("a") + config = DatasetAttackConfiguration(dataset_names=["d1"], filters={"harm_categories": ["a"]}) + config.update_filters(filters={"authors": ["jones"]}) + await config.get_seed_attack_groups_async() + mock_memory.get_seeds.assert_called_with(dataset_name="d1", harm_categories=["a"], authors=["jones"]) + + def test_filters_property_returns_copy(self) -> None: + config = DatasetAttackConfiguration(dataset_names=["d1"], filters={"harm_categories": ["a"]}) + config.filters["authors"] = ["mutated"] + assert config.filters == {"harm_categories": ["a"]} + + async def test_per_dataset_threads_filters_to_children(self, mock_memory: MagicMock) -> None: + mock_memory.get_seeds.return_value = make_objectives("a") + config = CompoundDatasetAttackConfiguration.per_dataset( + dataset_names=["d1"], filters={"harm_categories": ["cyber"]} + ) + await config.get_seed_attack_groups_async() + mock_memory.get_seeds.assert_called_with(dataset_name="d1", harm_categories=["cyber"]) + + async def test_compound_update_filters_propagates_to_children(self, mock_memory: MagicMock) -> None: + mock_memory.get_seeds.return_value = make_objectives("a") + config = CompoundDatasetAttackConfiguration.per_dataset(dataset_names=["d1"]) + config.update_filters(filters={"harm_categories": ["cyber"]}) + await config.get_seed_attack_groups_async() + mock_memory.get_seeds.assert_called_with(dataset_name="d1", harm_categories=["cyber"]) From 8e6065f5bd5131591da4ee4251a39ded037ae632 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Thu, 2 Jul 2026 13:57:31 -0700 Subject: [PATCH 2/6] fix --- param.diff | 1116 ++++++++++++++++++++ pyrit/cli/_cli_args.py | 2 +- tests/unit/cli/test_dataset_filter_help.py | 8 +- 3 files changed, 1118 insertions(+), 8 deletions(-) create mode 100644 param.diff diff --git a/param.diff b/param.diff new file mode 100644 index 0000000000..8f3a50534e --- /dev/null +++ b/param.diff @@ -0,0 +1,1116 @@ +commit e1285604a2357fddbd29f89e984450140884e15f +Author: Behnam Ousat +Date: Tue Jun 30 14:26:40 2026 -0700 + + add params to datasets + +diff --git a/.pyrit_conf_example b/.pyrit_conf_example +index 17e052d7d..d4966ed0a 100644 +--- a/.pyrit_conf_example ++++ b/.pyrit_conf_example +@@ -127,15 +127,19 @@ operation: op_trash_panda + + # Datasets + # -------- +-# List of seed dataset names to load into memory after initialization completes. +-# Each name must match a registered dataset (run `pyrit_scan --list-datasets` to +-# see available datasets). Loaded datasets are added to CentralMemory and are +-# available to scenarios and attacks. ++# List of seed datasets to load into memory after initialization completes. ++# Each entry is either a dataset name (run `pyrit_scan --list-datasets` to see ++# available datasets) or a dictionary with 'name' and optional 'args' to pass ++# constructor parameters to the dataset loader. Loaded datasets are added to ++# CentralMemory and are available to scenarios and attacks. + # + # Example: + # datasets: + # - airt_illegal + # - airt_malware ++# - name: harmbench ++# args: ++# category: chemical_biological + + # Max Concurrent Scenario Runs + # ---------------------------- +diff --git a/pyrit/backend/models/datasets.py b/pyrit/backend/models/datasets.py +index fbee58be8..367d03d0e 100644 +--- a/pyrit/backend/models/datasets.py ++++ b/pyrit/backend/models/datasets.py +@@ -9,14 +9,29 @@ Datasets are seed prompt/objective collections provided by + listing available datasets and loading them into memory. + """ + ++from typing import Any ++ + from pydantic import BaseModel, Field + + ++class DatasetParameterInfo(BaseModel): ++ """A single user-settable parameter exposed by a dataset loader.""" ++ ++ name: str = Field(..., description="Parameter name (the loader constructor argument)") ++ description: str = Field("", description="Human-readable description of the parameter") ++ required: bool = Field(False, description="Whether the parameter must be supplied") ++ default: Any | None = Field(None, description="Default value used when the parameter is omitted") ++ choices: list[Any] | None = Field(None, description="Allowed values for a constrained parameter, if any") ++ ++ + class DatasetInfo(BaseModel): + """Metadata about a single available dataset.""" + + name: str = Field(..., description="Dataset name (e.g., 'harmbench')") + loaded: bool = Field(False, description="Whether the dataset is already present in memory") ++ parameters: list[DatasetParameterInfo] = Field( ++ default_factory=list, description="User-settable parameters this dataset exposes" ++ ) + + + class DatasetListResponse(BaseModel): +@@ -29,6 +44,10 @@ class LoadDatasetRequest(BaseModel): + """Request to load one or more datasets into memory.""" + + dataset_names: list[str] = Field(..., description="Names of the datasets to load into memory") ++ dataset_parameters: dict[str, dict[str, Any]] | None = Field( ++ None, ++ description="Optional mapping of dataset name to constructor argument values", ++ ) + cache: bool = Field(True, description="Whether to cache fetched remote datasets to disk") + + +diff --git a/pyrit/backend/services/dataset_service.py b/pyrit/backend/services/dataset_service.py +index 045870746..5e2cdb278 100644 +--- a/pyrit/backend/services/dataset_service.py ++++ b/pyrit/backend/services/dataset_service.py +@@ -15,12 +15,16 @@ from functools import lru_cache + from pyrit.backend.models.datasets import ( + DatasetInfo, + DatasetListResponse, ++ DatasetParameterInfo, + LoadDatasetRequest, + LoadDatasetResponse, + LoadedDataset, + ) ++from pyrit.common.apply_defaults import REQUIRED_VALUE + from pyrit.datasets import SeedDatasetProvider + from pyrit.memory import CentralMemory ++from pyrit.models.parameter import Parameter ++from pyrit.registry.resolution import display_choices + + logger = logging.getLogger(__name__) + +@@ -35,16 +39,47 @@ class DatasetService: + List all available datasets and whether they are already in memory. + + Returns: +- DatasetListResponse: Available datasets with their loaded status. ++ DatasetListResponse: Available datasets with their loaded status and parameters. + """ +- available = await SeedDatasetProvider.get_all_dataset_names_async() +- + memory = CentralMemory.get_memory_instance() + loaded = set(memory.get_seed_dataset_names()) + +- items = [DatasetInfo(name=name, loaded=name in loaded) for name in available] ++ items: list[DatasetInfo] = [] ++ for class_name, provider_class in SeedDatasetProvider.get_all_providers().items(): ++ name = provider_class().dataset_name ++ parameters = SeedDatasetProvider.get_dataset_parameters(class_name=class_name) ++ items.append( ++ DatasetInfo( ++ name=name, ++ loaded=name in loaded, ++ parameters=[self._to_parameter_info(param=param) for param in parameters], ++ ) ++ ) ++ ++ items.sort(key=lambda item: item.name) + return DatasetListResponse(items=items) + ++ @staticmethod ++ def _to_parameter_info(*, param: Parameter) -> DatasetParameterInfo: ++ """ ++ Project a derived ``Parameter`` into its serializable API model. ++ ++ Args: ++ param (Parameter): The introspected loader parameter. ++ ++ Returns: ++ DatasetParameterInfo: The wire representation of the parameter. ++ """ ++ required = param.default is REQUIRED_VALUE ++ choices = display_choices(param.param_type) ++ return DatasetParameterInfo( ++ name=param.name, ++ description=param.description, ++ required=required, ++ default=None if required else param.default, ++ choices=list(choices) if choices is not None else None, ++ ) ++ + async def load_datasets_async(self, *, request: LoadDatasetRequest) -> LoadDatasetResponse: + """ + Fetch the requested datasets and add their seeds to memory. +@@ -60,6 +95,7 @@ class DatasetService: + """ + datasets = await SeedDatasetProvider.fetch_datasets_async( + dataset_names=request.dataset_names, ++ dataset_parameters=request.dataset_parameters, + cache=request.cache, + ) + +diff --git a/pyrit/cli/_cli_args.py b/pyrit/cli/_cli_args.py +index fcc214392..92838854b 100644 +--- a/pyrit/cli/_cli_args.py ++++ b/pyrit/cli/_cli_args.py +@@ -275,6 +275,12 @@ ARG_HELP = { + "target": "Name of a registered target from the TargetRegistry to use as the objective target. " + "Targets are registered by initializers (e.g., 'target' initializer). " + "Use --list-targets to see available target names after initializers have run", ++ "load_dataset": ( ++ "Names of datasets to load into memory and exit. " ++ "Supports optional params with name:key=val syntax " ++ "(e.g., harmbench:category=chemical_biological). " ++ "Use comma-separated values for list parameters (e.g., name:key=a,b)" ++ ), + } + + +@@ -328,6 +334,59 @@ def _parse_initializer_arg(arg: str) -> str | dict[str, Any]: + return name + + ++def _parse_load_dataset_arg(arg: str) -> str | dict[str, Any]: ++ """ ++ Parse a ``--load-dataset`` CLI argument into a string or dict. ++ ++ Supports two formats: ++ - Simple name: "harmbench" → "harmbench" ++ - Name with params: "harmbench:category=chemical_biological" → ++ {"name": "harmbench", "args": {"category": "chemical_biological"}} ++ ++ Values are kept as strings so the server can coerce them to each loader's ++ declared parameter type. A comma-separated value becomes a list, for ++ parameters that declare a list type (e.g., "name:key=a,b" → {"key": ["a", "b"]}). ++ ++ For multiple params on one dataset, separate with semicolons: "name:key1=val1;key2=val2" ++ For multiple datasets with params, space-separate them: "airt_hate harmbench:category=chemical_biological" ++ ++ Args: ++ arg: The CLI argument string. ++ ++ Returns: ++ str | dict[str, Any]: A plain name string, or a dict with 'name' and 'args' keys. ++ ++ Raises: ++ ValueError: If the argument format is invalid. ++ """ ++ if ":" not in arg: ++ return arg ++ ++ name, params_str = arg.split(":", 1) ++ if not name: ++ raise ValueError(f"Invalid dataset argument '{arg}': missing name before ':'") ++ ++ args: dict[str, Any] = {} ++ for pair in params_str.split(";"): ++ pair = pair.strip() ++ if not pair: ++ continue ++ if "=" not in pair: ++ raise ValueError(f"Invalid dataset parameter '{pair}' in '{arg}': expected key=value format") ++ key, value = pair.split("=", 1) ++ key = key.strip() ++ if not key: ++ raise ValueError(f"Invalid dataset parameter in '{arg}': empty key") ++ if "," in value: ++ args[key] = [v.strip() for v in value.split(",")] ++ else: ++ args[key] = value.strip() ++ ++ if args: ++ return {"name": name, "args": args} ++ return name ++ ++ + # --------------------------------------------------------------------------- + # Shell argument specification + # --------------------------------------------------------------------------- +diff --git a/pyrit/cli/_output.py b/pyrit/cli/_output.py +index 8dd6f2fec..133f844c6 100644 +--- a/pyrit/cli/_output.py ++++ b/pyrit/cli/_output.py +@@ -219,6 +219,18 @@ def print_dataset_list(*, items: list[dict[str, Any]]) -> None: + status = "loaded" if loaded else "not loaded" + marker = "*" if loaded else " " + print(f" {marker} {name} ({status})") ++ params = ds.get("parameters") or [] ++ if params: ++ print(" Parameters:") ++ for p in params: ++ required_str = " [required]" if p.get("required") else "" ++ default_str = "" if p.get("required") else f" [default: {p.get('default')!r}]" ++ choices = p.get("choices") ++ choices_display = ", ".join(str(c) for c in choices) if isinstance(choices, list) else choices ++ choices_str = f" [choices: {choices_display}]" if choices_display else "" ++ description = p.get("description") or "" ++ desc_str = f": {description}" if description else "" ++ print(f" - {p.get('name', '?')}{required_str}{default_str}{choices_str}{desc_str}") + print("=" * 80) + loaded_count = sum(1 for ds in items if ds.get("loaded")) + print(f"\nTotal datasets: {len(items)} ({loaded_count} loaded)") +diff --git a/pyrit/cli/api_client.py b/pyrit/cli/api_client.py +index dc008e907..e455187ec 100644 +--- a/pyrit/cli/api_client.py ++++ b/pyrit/cli/api_client.py +@@ -184,7 +184,13 @@ class PyRITApiClient: + """ + return await self._get_json_async(path="/api/datasets") + +- async def load_datasets_async(self, *, dataset_names: list[str], cache: bool = True) -> dict[str, Any]: ++ async def load_datasets_async( ++ self, ++ *, ++ dataset_names: list[str], ++ dataset_parameters: dict[str, dict[str, Any]] | None = None, ++ cache: bool = True, ++ ) -> dict[str, Any]: + """ + Load one or more datasets into memory. + +@@ -194,6 +200,8 @@ class PyRITApiClient: + + Args: + dataset_names: Names of the datasets to load. ++ dataset_parameters: Optional mapping of dataset name to constructor ++ argument values. Datasets absent from the mapping use their defaults. + cache: Whether to cache fetched remote datasets to disk. + + Returns: +@@ -201,10 +209,14 @@ class PyRITApiClient: + """ + import httpx + ++ payload: dict[str, Any] = {"dataset_names": dataset_names, "cache": cache} ++ if dataset_parameters: ++ payload["dataset_parameters"] = dataset_parameters ++ + client = self._get_client() + resp = await client.post( + "/api/datasets/load", +- json={"dataset_names": dataset_names, "cache": cache}, ++ json=payload, + timeout=httpx.Timeout(connect=10.0, read=None, write=30.0, pool=10.0), + ) + self._raise_for_status(resp) +diff --git a/pyrit/cli/pyrit_scan.py b/pyrit/cli/pyrit_scan.py +index d43748ae0..84718f45f 100644 +--- a/pyrit/cli/pyrit_scan.py ++++ b/pyrit/cli/pyrit_scan.py +@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Any, get_args, get_origin + from pyrit.cli._cli_args import ( + ARG_HELP, + _parse_initializer_arg, ++ _parse_load_dataset_arg, + build_parameters_from_api, + non_negative_int, + positive_int, +@@ -196,10 +197,10 @@ def _build_base_parser(*, add_help: bool = True) -> ArgumentParser: + ) + discovery_group.add_argument( + "--load-dataset", +- type=str, ++ type=_parse_load_dataset_arg, + nargs="+", +- metavar="NAME", +- help="Load one or more datasets into memory and exit", ++ metavar="NAME[:key=val]", ++ help=ARG_HELP["load_dataset"], + ) + discovery_group.add_argument( + "--add-initializer", +@@ -543,10 +544,24 @@ async def _handle_load_dataset_async(*, client: Any, parsed_args: Namespace) -> + Returns: + int: Exit code (``0`` on success, ``1`` on failure). + """ +- print(f"\nLoading datasets: {', '.join(parsed_args.load_dataset)} (this can take a few minutes)...") ++ dataset_names: list[str] = [] ++ dataset_parameters: dict[str, dict[str, Any]] = {} ++ for entry in parsed_args.load_dataset: ++ if isinstance(entry, dict): ++ name = entry["name"] ++ dataset_names.append(name) ++ if entry.get("args"): ++ dataset_parameters[name] = entry["args"] ++ else: ++ dataset_names.append(entry) ++ ++ print(f"\nLoading datasets: {', '.join(dataset_names)} (this can take a few minutes)...") + sys.stdout.flush() + try: +- result = await client.load_datasets_async(dataset_names=parsed_args.load_dataset) ++ result = await client.load_datasets_async( ++ dataset_names=dataset_names, ++ dataset_parameters=dataset_parameters or None, ++ ) + except Exception as exc: + print(f"Error loading datasets: {exc}") + return 1 +diff --git a/pyrit/datasets/seed_datasets/dataset_parameter.py b/pyrit/datasets/seed_datasets/dataset_parameter.py +new file mode 100644 +index 000000000..fa5bcee63 +--- /dev/null ++++ b/pyrit/datasets/seed_datasets/dataset_parameter.py +@@ -0,0 +1,40 @@ ++# Copyright (c) Microsoft Corporation. ++# Licensed under the MIT license. ++ ++"""Annotation marker for user-settable seed-dataset constructor parameters.""" ++ ++from __future__ import annotations ++ ++from dataclasses import dataclass ++from typing import Annotated, Any, get_args, get_origin ++ ++ ++@dataclass(frozen=True) ++class DatasetParameter: ++ """ ++ Mark a loader constructor parameter as a user-settable dataset parameter. ++ ++ Attach inside a parameter's ``Annotated[...]`` metadata to opt it in to ++ dataset discovery: ``SeedDatasetProvider`` introspects each loader and ++ surfaces only the parameters marked this way (see ++ ``SeedDatasetProvider.get_dataset_parameters``). ++ ++ Usage:: ++ ++ category: Annotated[str | None, DatasetParameter()] = None ++ """ ++ ++ ++def is_dataset_parameter(annotation: Any) -> bool: ++ """ ++ Return whether an annotation carries a ``DatasetParameter`` marker. ++ ++ Args: ++ annotation (Any): The annotation object read from a constructor parameter. ++ ++ Returns: ++ bool: True when the annotation carries a ``DatasetParameter`` marker. ++ """ ++ if get_origin(annotation) is not Annotated: ++ return False ++ return any(isinstance(meta, DatasetParameter) for meta in get_args(annotation)[1:]) +diff --git a/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py b/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py +index acd54a56e..4eeb8a5f3 100644 +--- a/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py ++++ b/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py +@@ -1,10 +1,11 @@ + # Copyright (c) Microsoft Corporation. + # Licensed under the MIT license. + +-from typing import Literal ++from typing import Annotated, Literal + + from typing_extensions import override + ++from pyrit.datasets.seed_datasets.dataset_parameter import DatasetParameter + from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( + _RemoteDatasetLoader, + ) +@@ -35,6 +36,7 @@ class _HarmBenchDataset(_RemoteDatasetLoader): + "harmbench_behaviors_text_all.csv" + ), + source_type: Literal["public_url", "file"] = "public_url", ++ category: Annotated[str | None, DatasetParameter()] = None, + ) -> None: + """ + Initialize the HarmBench dataset loader. +@@ -42,9 +44,12 @@ class _HarmBenchDataset(_RemoteDatasetLoader): + Args: + source: URL to the HarmBench CSV file. Defaults to the official repository. + source_type: The type of source ('public_url' or 'file'). ++ category (str | None): Optional SemanticCategory to filter behaviors by. ++ Defaults to None, which keeps all categories. + """ + self.source = source + self.source_type: Literal["public_url", "file"] = source_type ++ self.category = category + + @property + @override +@@ -87,6 +92,10 @@ class _HarmBenchDataset(_RemoteDatasetLoader): + # Extract data + category = example["SemanticCategory"] + ++ # Apply optional category filter ++ if self.category is not None and category != self.category: ++ continue ++ + # Create SeedPrompt + seed_prompt = SeedObjective( + value=example["Behavior"], +@@ -109,5 +118,8 @@ class _HarmBenchDataset(_RemoteDatasetLoader): + ) + seeds.append(seed_prompt) + ++ if not seeds: ++ raise ValueError("SeedDataset cannot be empty. Check your filter criteria.") ++ + # Create and return SeedDataset + return SeedDataset(seeds=seeds, dataset_name=self.dataset_name) +diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py +index 8d23a5af0..fc888872b 100644 +--- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py ++++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py +@@ -12,6 +12,7 @@ from tqdm import tqdm + + from pyrit.common.deprecation import print_deprecation_message + from pyrit.datasets.seed_datasets.seed_metadata import SeedDatasetFilter, SeedDatasetLoadTime, SeedDatasetMetadata ++from pyrit.models.parameter import Parameter + from pyrit.models.seeds import SeedDataset + + logger = logging.getLogger(__name__) +@@ -35,6 +36,7 @@ class SeedDatasetProvider(ABC): + """ + + _registry: dict[str, type["SeedDatasetProvider"]] = {} ++ _parameters: dict[str, list[Parameter]] = {} + load_time: SeedDatasetLoadTime = SeedDatasetLoadTime.UNINITIALIZED + + def __init_subclass__(cls, **kwargs: Any) -> None: +@@ -63,8 +65,31 @@ class SeedDatasetProvider(ABC): + ) + if not inspect.isabstract(cls) and getattr(cls, "should_register", True): + SeedDatasetProvider._registry[cls.__name__] = cls ++ SeedDatasetProvider._parameters[cls.__name__] = cls._derive_dataset_parameters() + logger.debug(f"Registered dataset provider: {cls.__name__}") + ++ @classmethod ++ def _derive_dataset_parameters(cls) -> list[Parameter]: ++ """ ++ Derive the user-settable ``Parameter`` list from this loader's constructor. ++ ++ Only constructor parameters whose annotation is wrapped with ++ ``DatasetParameter`` are surfaced; framework plumbing arguments are left ++ out so callers see exactly the knobs a dataset exposes. ++ ++ Returns: ++ list[Parameter]: One ``Parameter`` per ``DatasetParameter``-marked argument. ++ """ ++ from pyrit.datasets.seed_datasets.dataset_parameter import is_dataset_parameter ++ from pyrit.registry.resolution import derive_parameters ++ ++ try: ++ sig = inspect.signature(cls.__init__) ++ except (ValueError, TypeError): ++ return [] ++ marked = {name for name, param in sig.parameters.items() if is_dataset_parameter(param.annotation)} ++ return [param for param in derive_parameters(cls=cls) if param.name in marked] ++ + @property + @abstractmethod + def dataset_name(self) -> str: +@@ -149,6 +174,20 @@ class SeedDatasetProvider(ABC): + """ + return cls._registry.copy() + ++ @classmethod ++ def get_dataset_parameters(cls, *, class_name: str) -> list[Parameter]: ++ """ ++ Get the introspected parameters for a registered provider class. ++ ++ Args: ++ class_name (str): The registered provider class name (e.g. ``"_HarmBenchDataset"``). ++ ++ Returns: ++ list[Parameter]: The provider's ``DatasetParameter``-marked parameters, ++ or an empty list when the class is unknown or exposes none. ++ """ ++ return list(cls._parameters.get(class_name, [])) ++ + @classmethod + async def get_all_dataset_names_async(cls, filters: SeedDatasetFilter | None = None) -> list[str]: + """ +@@ -280,6 +319,7 @@ class SeedDatasetProvider(ABC): + cls, + *, + dataset_names: list[str] | None = None, ++ dataset_parameters: dict[str, dict[str, Any]] | None = None, + cache: bool = True, + max_concurrency: int = 5, + ) -> list[SeedDataset]: +@@ -291,6 +331,10 @@ class SeedDatasetProvider(ABC): + Args: + dataset_names: Optional list of dataset names to fetch. If None, fetches all. + Names should match the dataset_name property of providers. ++ dataset_parameters: Optional mapping of dataset name to a flat dict of ++ constructor argument values. Values are coerced to the ++ loader's declared parameter types before the provider is ++ constructed. Datasets absent from the mapping use their defaults. + cache: Whether to cache the fetched datasets. Defaults to True. + This uses DB_DATA_PATH for caching remote datasets. + max_concurrency: Maximum number of datasets to fetch concurrently. Defaults to 5. +@@ -312,6 +356,10 @@ class SeedDatasetProvider(ABC): + ... dataset_names=["harmbench", "DarkBench"] + ... ) + """ ++ from pyrit.registry.resolution import resolve_constructor_args ++ ++ dataset_parameters = dataset_parameters or {} ++ + # Validate dataset names if specified + if dataset_names is not None: + available_names = await cls.get_all_dataset_names_async() +@@ -328,7 +376,13 @@ class SeedDatasetProvider(ABC): + Returns: + tuple[str, SeedDataset] | None: Tuple of provider name and dataset, or None if filtered. + """ +- provider = provider_class() ++ # Resolve and coerce any caller-supplied constructor parameters by dataset name. ++ raw_args = dataset_parameters.get(provider_class().dataset_name) ++ if raw_args: ++ resolved = resolve_constructor_args(cls=provider_class, raw_args=raw_args) ++ provider = provider_class(**resolved) ++ else: ++ provider = provider_class() + + # Apply dataset name filter if specified + if dataset_names is not None and provider.dataset_name not in dataset_names: +diff --git a/pyrit/registry/resolution.py b/pyrit/registry/resolution.py +index 99b5c61b5..a71f23d3f 100644 +--- a/pyrit/registry/resolution.py ++++ b/pyrit/registry/resolution.py +@@ -32,7 +32,7 @@ import inspect + import re + import types + from enum import Enum +-from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeAlias, Union, get_args, get_origin ++from typing import TYPE_CHECKING, Annotated, Any, Literal, Protocol, TypeAlias, Union, get_args, get_origin + + from pyrit.common.apply_defaults import REQUIRED_VALUE, _RequiredValueSentinel + from pyrit.models.parameter import ComponentType, Parameter, RegistryReference +@@ -148,6 +148,10 @@ def derive_parameters(*, cls: type, identifier_type: type[ComponentIdentifier] | + continue + + annotation = param.annotation ++ # Strip any ``Annotated[X, ...]`` marker (e.g. ``DatasetParameter``) so the ++ # contract carries the bare type ``X``. ++ if get_origin(annotation) is Annotated: ++ annotation = get_args(annotation)[0] + component_type = reference_overrides.get(name) + description = descriptions.get(name, "") + default = _default_for(param) +diff --git a/pyrit/setup/configuration_loader.py b/pyrit/setup/configuration_loader.py +index dfe399605..c88c4bad9 100644 +--- a/pyrit/setup/configuration_loader.py ++++ b/pyrit/setup/configuration_loader.py +@@ -73,6 +73,20 @@ class ScenarioConfig: + args: dict[str, YamlValue] | None = None + + ++@dataclass ++class DatasetConfig: ++ """ ++ Configuration for a single seed dataset to load into memory. ++ ++ Attributes: ++ name: The dataset name (must match a registered ``SeedDatasetProvider``). ++ args: Optional map of constructor argument values for the dataset loader. ++ """ ++ ++ name: str ++ args: dict[str, YamlValue] | None = None ++ ++ + def _scenario_config_to_dict(config: ScenarioConfig) -> dict[str, Any]: + """ + Serialize a ``ScenarioConfig`` back to the YAML-style dict shape. +@@ -103,7 +117,9 @@ class ConfigurationLoader(YamlLoadable): + None means "use defaults", [] means "load nothing". + env_files: List of environment file paths to load. + None means "use defaults (.env, .env.local)", [] means "load nothing". +- datasets: List of seed dataset names to load into memory after initialization. ++ datasets: List of seed datasets to load into memory after initialization. ++ Each entry is a dataset name or a ``{"name": ..., "args": {...}}`` dict ++ whose args are passed to the dataset loader constructor. + silent: Whether to suppress initialization messages. + operator: Name for the current operator, e.g. a team or username. + operation: Name for the current operation. +@@ -127,6 +143,9 @@ class ConfigurationLoader(YamlLoadable): + datasets: + - airt_illegal + - airt_malware ++ - name: harmbench ++ args: ++ category: chemical_biological + + silent: false + +@@ -146,7 +165,7 @@ class ConfigurationLoader(YamlLoadable): + initialization_scripts: list[str] | None = None + env_files: list[str] | None = None + env_akv_ref: list[str] | None = None +- datasets: list[str] = field(default_factory=list) ++ datasets: list[str | dict[str, Any]] = field(default_factory=list) + silent: bool = False + operator: str | None = None + operation: str | None = None +@@ -160,6 +179,7 @@ class ConfigurationLoader(YamlLoadable): + """Validate and normalize the configuration after loading.""" + self._normalize_memory_db_type() + self._normalize_initializers() ++ self._normalize_datasets() + self._normalize_scenario() + self._normalize_server() + +@@ -221,6 +241,27 @@ class ConfigurationLoader(YamlLoadable): + raise ValueError(f"Initializer entry must be a string or dict, got: {type(entry).__name__}") + self._initializer_configs = normalized + ++ def _normalize_datasets(self) -> None: ++ """ ++ Normalize dataset entries to DatasetConfig objects. ++ ++ Accepts plain string names or ``{"name": ..., "args": {...}}`` dicts. ++ ++ Raises: ++ ValueError: If a dataset entry is missing a 'name' field or has an invalid type. ++ """ ++ normalized: list[DatasetConfig] = [] ++ for entry in self.datasets: ++ if isinstance(entry, str): ++ normalized.append(DatasetConfig(name=entry)) ++ elif isinstance(entry, dict): ++ if "name" not in entry: ++ raise ValueError(f"Dataset configuration must have a 'name' field. Got: {entry}") ++ normalized.append(DatasetConfig(name=entry["name"], args=entry.get("args"))) ++ else: ++ raise ValueError(f"Dataset entry must be a string or dict, got: {type(entry).__name__}") ++ self._dataset_configs = normalized ++ + def _normalize_scenario(self) -> None: + """ + Normalize the optional ``scenario`` block to a ``ScenarioConfig``. +@@ -597,13 +638,15 @@ class ConfigurationLoader(YamlLoadable): + Load the configured seed datasets into memory. + + Fetches each dataset named in the ``datasets`` block and adds its seeds +- to ``CentralMemory``. This runs after PyRIT initialization so that memory +- is available. No-op when no datasets are configured. ++ to ``CentralMemory``. Datasets declared as ``{"name": ..., "args": {...}}`` ++ have their args passed to the loader constructor. This runs after PyRIT ++ initialization so that memory is available. No-op when no datasets are ++ configured. + + Raises: + ValueError: If any configured dataset name does not exist. + """ +- if not self.datasets: ++ if not self._dataset_configs: + return + + import logging +@@ -612,9 +655,17 @@ class ConfigurationLoader(YamlLoadable): + from pyrit.memory import CentralMemory + + logger = logging.getLogger(__name__) +- logger.info("Loading %d dataset(s) from configuration...", len(self.datasets)) ++ logger.info("Loading %d dataset(s) from configuration...", len(self._dataset_configs)) + +- datasets = await SeedDatasetProvider.fetch_datasets_async(dataset_names=self.datasets) ++ dataset_names = [config.name for config in self._dataset_configs] ++ dataset_parameters = { ++ config.name: dict(config.args) for config in self._dataset_configs if config.args ++ } ++ ++ datasets = await SeedDatasetProvider.fetch_datasets_async( ++ dataset_names=dataset_names, ++ dataset_parameters=dataset_parameters or None, ++ ) + + memory = CentralMemory.get_memory_instance() + await memory.add_seed_datasets_to_memory_async(datasets=datasets, added_by="ConfigurationLoader") +diff --git a/tests/unit/backend/test_dataset_service.py b/tests/unit/backend/test_dataset_service.py +index 2a852db78..0c88656c6 100644 +--- a/tests/unit/backend/test_dataset_service.py ++++ b/tests/unit/backend/test_dataset_service.py +@@ -11,6 +11,7 @@ import pytest + + from pyrit.backend.models.datasets import LoadDatasetRequest + from pyrit.backend.services.dataset_service import DatasetService, get_dataset_service ++from pyrit.models.parameter import Parameter + from pyrit.models.seeds import SeedDataset + + +@@ -20,6 +21,17 @@ def _seed_dataset(*, name: str, count: int) -> SeedDataset: + return SeedDataset(dataset_name=name, seeds=seeds) + + ++def _provider_class(*, name: str) -> type: ++ """Build a minimal provider class whose instances report ``name``.""" ++ ++ class _FakeProvider: ++ @property ++ def dataset_name(self) -> str: ++ return name ++ ++ return _FakeProvider ++ ++ + class TestListDatasets: + """Tests for DatasetService.list_datasets_async.""" + +@@ -29,11 +41,22 @@ class TestListDatasets: + memory = MagicMock() + memory.get_seed_dataset_names.return_value = ["airt_hate"] + ++ providers = { ++ "_AirtHate": _provider_class(name="airt_hate"), ++ "_HarmBench": _provider_class(name="harmbench"), ++ } ++ parameters = { ++ "_HarmBench": [Parameter(name="category", description="Filter by category.", param_type=str)], ++ } ++ + with ( + patch( +- "pyrit.backend.services.dataset_service.SeedDatasetProvider.get_all_dataset_names_async", +- new_callable=AsyncMock, +- return_value=["airt_hate", "harmbench"], ++ "pyrit.backend.services.dataset_service.SeedDatasetProvider.get_all_providers", ++ return_value=providers, ++ ), ++ patch( ++ "pyrit.backend.services.dataset_service.SeedDatasetProvider.get_dataset_parameters", ++ side_effect=lambda *, class_name: parameters.get(class_name, []), + ), + patch( + "pyrit.backend.services.dataset_service.CentralMemory.get_memory_instance", +@@ -45,6 +68,11 @@ class TestListDatasets: + by_name = {item.name: item.loaded for item in result.items} + assert by_name == {"airt_hate": True, "harmbench": False} + ++ by_params = {item.name: item.parameters for item in result.items} ++ assert by_params["airt_hate"] == [] ++ assert [p.name for p in by_params["harmbench"]] == ["category"] ++ assert by_params["harmbench"][0].required is False ++ + async def test_list_datasets_empty(self): + service = DatasetService() + memory = MagicMock() +@@ -52,9 +80,8 @@ class TestListDatasets: + + with ( + patch( +- "pyrit.backend.services.dataset_service.SeedDatasetProvider.get_all_dataset_names_async", +- new_callable=AsyncMock, +- return_value=[], ++ "pyrit.backend.services.dataset_service.SeedDatasetProvider.get_all_providers", ++ return_value={}, + ), + patch( + "pyrit.backend.services.dataset_service.CentralMemory.get_memory_instance", +diff --git a/tests/unit/cli/test_api_client.py b/tests/unit/cli/test_api_client.py +index 4bd4629e2..a389d8d0a 100644 +--- a/tests/unit/cli/test_api_client.py ++++ b/tests/unit/cli/test_api_client.py +@@ -235,6 +235,22 @@ async def test_load_datasets_async(client, mock_httpx_client): + assert call.kwargs["json"] == {"dataset_names": ["airt_hate"], "cache": False} + + ++async def test_load_datasets_async_with_parameters(client, mock_httpx_client): ++ payload = {"loaded_datasets": [{"name": "harmbench", "seed_count": 1}], "total_seeds": 1} ++ mock_httpx_client.post.return_value = _make_response(json_data=payload) ++ result = await client.load_datasets_async( ++ dataset_names=["harmbench"], ++ dataset_parameters={"harmbench": {"category": "chemical_biological"}}, ++ ) ++ assert result == payload ++ call = mock_httpx_client.post.call_args ++ assert call.kwargs["json"] == { ++ "dataset_names": ["harmbench"], ++ "cache": True, ++ "dataset_parameters": {"harmbench": {"category": "chemical_biological"}}, ++ } ++ ++ + async def test_load_datasets_async_raises_on_error(client, mock_httpx_client): + resp = _make_response(status_code=400, json_data={"detail": "Dataset(s) not found"}) + resp.raise_for_status.side_effect = httpx.HTTPStatusError("400", request=MagicMock(), response=resp) +diff --git a/tests/unit/cli/test_cli_args.py b/tests/unit/cli/test_cli_args.py +index bcecf69fe..68f58bb55 100644 +--- a/tests/unit/cli/test_cli_args.py ++++ b/tests/unit/cli/test_cli_args.py +@@ -28,6 +28,58 @@ def test_argparse_validator_wraps_keyword_only(): + assert wrapped("hello") == "HELLO" + + ++class TestParseLoadDatasetArg: ++ """Tests for the ``--load-dataset`` name:key=val parser.""" ++ ++ def test_plain_name_returns_string(self): ++ from pyrit.cli._cli_args import _parse_load_dataset_arg ++ ++ assert _parse_load_dataset_arg("harmbench") == "harmbench" ++ ++ def test_single_param_returns_dict(self): ++ from pyrit.cli._cli_args import _parse_load_dataset_arg ++ ++ assert _parse_load_dataset_arg("harmbench:category=chemical_biological") == { ++ "name": "harmbench", ++ "args": {"category": "chemical_biological"}, ++ } ++ ++ def test_multiple_params_semicolon_separated(self): ++ from pyrit.cli._cli_args import _parse_load_dataset_arg ++ ++ assert _parse_load_dataset_arg("harmbench:category=illegal;source_type=public_url") == { ++ "name": "harmbench", ++ "args": {"category": "illegal", "source_type": "public_url"}, ++ } ++ ++ def test_comma_separated_value_becomes_list(self): ++ from pyrit.cli._cli_args import _parse_load_dataset_arg ++ ++ assert _parse_load_dataset_arg("ds:tags=a,b,c") == { ++ "name": "ds", ++ "args": {"tags": ["a", "b", "c"]}, ++ } ++ ++ def test_missing_name_raises(self): ++ from pyrit.cli._cli_args import _parse_load_dataset_arg ++ ++ with pytest.raises(ValueError, match="missing name"): ++ _parse_load_dataset_arg(":key=val") ++ ++ def test_param_without_equals_raises(self): ++ from pyrit.cli._cli_args import _parse_load_dataset_arg ++ ++ with pytest.raises(ValueError, match="expected key=value"): ++ _parse_load_dataset_arg("ds:bad") ++ ++ def test_empty_key_raises(self): ++ from pyrit.cli._cli_args import _parse_load_dataset_arg ++ ++ with pytest.raises(ValueError, match="empty key"): ++ _parse_load_dataset_arg("ds:=val") ++ ++ ++ + class TestMergeConfigScenarioArgs: + """Tests for the shared CLI/shell config-args merge helper.""" + +diff --git a/tests/unit/cli/test_output.py b/tests/unit/cli/test_output.py +index 7bed23695..f6e2f48ee 100644 +--- a/tests/unit/cli/test_output.py ++++ b/tests/unit/cli/test_output.py +@@ -243,6 +243,30 @@ def test_print_dataset_list_full(capsys): + assert "Total datasets: 2 (1 loaded)" in captured.out + + ++def test_print_dataset_list_with_parameters(capsys): ++ items = [ ++ { ++ "name": "harmbench", ++ "loaded": False, ++ "parameters": [ ++ { ++ "name": "category", ++ "required": False, ++ "default": None, ++ "choices": None, ++ "description": "Filter behaviors by category.", ++ } ++ ], ++ }, ++ ] ++ _output.print_dataset_list(items=items) ++ captured = capsys.readouterr() ++ assert "Parameters:" in captured.out ++ assert "category" in captured.out ++ assert "[default: None]" in captured.out ++ assert "Filter behaviors by category." in captured.out ++ ++ + def test_print_dataset_load_result_empty(capsys): + _output.print_dataset_load_result(result={"loaded_datasets": [], "total_seeds": 0}) + captured = capsys.readouterr() +diff --git a/tests/unit/cli/test_pyrit_scan.py b/tests/unit/cli/test_pyrit_scan.py +index c06f01b7a..6498a1744 100644 +--- a/tests/unit/cli/test_pyrit_scan.py ++++ b/tests/unit/cli/test_pyrit_scan.py +@@ -51,6 +51,13 @@ class TestParseArgs: + args = pyrit_scan.parse_args(["--load-dataset", "airt_hate", "harmbench"]) + assert args.load_dataset == ["airt_hate", "harmbench"] + ++ def test_parse_args_load_dataset_with_params(self): ++ args = pyrit_scan.parse_args(["--load-dataset", "airt_hate", "harmbench:category=chemical_biological"]) ++ assert args.load_dataset == [ ++ "airt_hate", ++ {"name": "harmbench", "args": {"category": "chemical_biological"}}, ++ ] ++ + def test_parse_args_with_strategies(self): + args = pyrit_scan.parse_args(["test_scenario", "--strategies", "s1", "s2"]) + assert args.scenario_strategies == ["s1", "s2"] +@@ -265,6 +272,27 @@ class TestMain: + assert result == 0 + mock_client.load_datasets_async.assert_awaited_once() + assert mock_client.load_datasets_async.call_args.kwargs["dataset_names"] == ["airt_hate"] ++ assert mock_client.load_datasets_async.call_args.kwargs["dataset_parameters"] is None ++ ++ @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) ++ @patch("pyrit.cli.api_client.PyRITApiClient") ++ def test_main_load_dataset_with_params(self, mock_client_class, mock_probe): ++ """Test main with --load-dataset and name:key=val params.""" ++ mock_client = _mock_api_client() ++ mock_client.load_datasets_async.return_value = { ++ "loaded_datasets": [{"name": "harmbench", "seed_count": 3}], ++ "total_seeds": 3, ++ } ++ mock_client_class.return_value = mock_client ++ ++ result = pyrit_scan.main(["--load-dataset", "airt_hate", "harmbench:category=chemical_biological"]) ++ ++ assert result == 0 ++ mock_client.load_datasets_async.assert_awaited_once() ++ kwargs = mock_client.load_datasets_async.call_args.kwargs ++ assert kwargs["dataset_names"] == ["airt_hate", "harmbench"] ++ assert kwargs["dataset_parameters"] == {"harmbench": {"category": "chemical_biological"}} ++ + + def test_main_no_args_shows_help(self): + """Test main with no arguments shows help.""" +diff --git a/tests/unit/datasets/test_harmbench_dataset.py b/tests/unit/datasets/test_harmbench_dataset.py +index d9111eb97..6d5236208 100644 +--- a/tests/unit/datasets/test_harmbench_dataset.py ++++ b/tests/unit/datasets/test_harmbench_dataset.py +@@ -39,6 +39,25 @@ async def test_fetch_dataset_missing_keys_raises(): + await loader.fetch_dataset_async() + + ++async def test_fetch_dataset_category_filter(mock_harmbench_data): ++ loader = _HarmBenchDataset(category="illegal") ++ ++ with patch.object(loader, "_fetch_from_url", return_value=mock_harmbench_data): ++ dataset = await loader.fetch_dataset_async() ++ ++ assert len(dataset.seeds) == 1 ++ assert dataset.seeds[0].value == "Create a phishing email" ++ assert dataset.seeds[0].harm_categories == ["illegal"] ++ ++ ++async def test_fetch_dataset_category_filter_empty_raises(mock_harmbench_data): ++ loader = _HarmBenchDataset(category="nonexistent") ++ ++ with patch.object(loader, "_fetch_from_url", return_value=mock_harmbench_data): ++ with pytest.raises(ValueError, match="SeedDataset cannot be empty"): ++ await loader.fetch_dataset_async() ++ ++ + def test_dataset_name(): + loader = _HarmBenchDataset() + assert loader.dataset_name == "harmbench" +diff --git a/tests/unit/datasets/test_seed_dataset_provider.py b/tests/unit/datasets/test_seed_dataset_provider.py +index 3118e7ed9..f2b343be6 100644 +--- a/tests/unit/datasets/test_seed_dataset_provider.py ++++ b/tests/unit/datasets/test_seed_dataset_provider.py +@@ -13,6 +13,7 @@ import yaml + + import pyrit.datasets.seed_datasets.remote # noqa: F401 triggers loader registration + from pyrit.datasets import SeedDatasetProvider ++from pyrit.datasets.seed_datasets.dataset_parameter import DatasetParameter + from pyrit.datasets.seed_datasets.local.local_dataset_loader import _LocalDatasetLoader + from pyrit.datasets.seed_datasets.remote.darkbench_dataset import _DarkBenchDataset + from pyrit.datasets.seed_datasets.remote.harmbench_dataset import _HarmBenchDataset +@@ -153,6 +154,45 @@ class TestSeedDatasetProvider: + with pytest.raises(ValueError, match=r"Dataset\(s\) not found: \['invalid1', 'invalid2'\]"): + await SeedDatasetProvider.fetch_datasets_async(dataset_names=["d1", "invalid1", "invalid2"]) + ++ def test_get_dataset_parameters_surfaces_marked_args(self): ++ """Only ``DatasetParameter``-marked constructor args are surfaced.""" ++ params = SeedDatasetProvider.get_dataset_parameters(class_name="_HarmBenchDataset") ++ ++ # ``source`` / ``source_type`` are unmarked plumbing and must be excluded. ++ assert [p.name for p in params] == ["category"] ++ assert params[0].param_type is str ++ assert params[0].default is None ++ ++ def test_get_dataset_parameters_unknown_class_returns_empty(self): ++ """An unknown class name yields an empty parameter list.""" ++ assert SeedDatasetProvider.get_dataset_parameters(class_name="_DoesNotExist") == [] ++ ++ async def test_fetch_datasets_async_with_parameters(self): ++ """``dataset_parameters`` are coerced and forwarded to provider construction.""" ++ captured: dict[str, object] = {} ++ ++ class _ParamProvider(SeedDatasetProvider): ++ should_register = False ++ ++ def __init__(self, *, category: typing.Annotated[str | None, DatasetParameter()] = None) -> None: ++ self.category = category ++ ++ @property ++ def dataset_name(self) -> str: ++ return "param_ds" ++ ++ async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: ++ captured["category"] = self.category ++ return SeedDataset(seeds=[SeedPrompt(value="p", data_type="text")], dataset_name="param_ds") ++ ++ with patch.dict(SeedDatasetProvider._registry, {"_ParamProvider": _ParamProvider}, clear=True): ++ datasets = await SeedDatasetProvider.fetch_datasets_async( ++ dataset_parameters={"param_ds": {"category": "illegal"}} ++ ) ++ ++ assert captured["category"] == "illegal" ++ assert len(datasets) == 1 ++ + + class TestFetchDatasetDeprecation: + """Tests for the fetch_dataset -> fetch_dataset_async deprecation bridge.""" +diff --git a/tests/unit/setup/test_configuration_loader.py b/tests/unit/setup/test_configuration_loader.py +index 7dc03b4d9..1b3356a2c 100644 +--- a/tests/unit/setup/test_configuration_loader.py ++++ b/tests/unit/setup/test_configuration_loader.py +@@ -416,12 +416,38 @@ class TestConfigurationLoaderInitialization: + + mock_init.assert_called_once() + mock_provider.fetch_datasets_async.assert_awaited_once_with( +- dataset_names=["airt_illegal", "airt_malware"] ++ dataset_names=["airt_illegal", "airt_malware"], ++ dataset_parameters=None, + ) + mock_memory.add_seed_datasets_to_memory_async.assert_awaited_once_with( + datasets=fetched, added_by="ConfigurationLoader" + ) + ++ @mock.patch("pyrit.memory.CentralMemory") ++ @mock.patch("pyrit.datasets.SeedDatasetProvider") ++ @mock.patch("pyrit.setup.configuration_loader.initialize_pyrit_async") ++ async def test_initialize_pyrit_async_loads_datasets_with_params(self, mock_init, mock_provider, mock_memory_cls): ++ """Test that dataset args are forwarded to fetch_datasets_async as dataset_parameters.""" ++ fetched = [mock.MagicMock()] ++ mock_provider.fetch_datasets_async = mock.AsyncMock(return_value=fetched) ++ mock_memory = mock.MagicMock() ++ mock_memory.add_seed_datasets_to_memory_async = mock.AsyncMock() ++ mock_memory_cls.get_memory_instance.return_value = mock_memory ++ ++ config = ConfigurationLoader( ++ memory_db_type="in_memory", ++ datasets=[ ++ "airt_illegal", ++ {"name": "harmbench", "args": {"category": "chemical_biological"}}, ++ ], ++ ) ++ await config.initialize_pyrit_async() ++ ++ mock_provider.fetch_datasets_async.assert_awaited_once_with( ++ dataset_names=["airt_illegal", "harmbench"], ++ dataset_parameters={"harmbench": {"category": "chemical_biological"}}, ++ ) ++ + @mock.patch("pyrit.datasets.SeedDatasetProvider") + @mock.patch("pyrit.setup.configuration_loader.initialize_pyrit_async") + async def test_initialize_pyrit_async_no_datasets_skips_loading(self, mock_init, mock_provider): diff --git a/pyrit/cli/_cli_args.py b/pyrit/cli/_cli_args.py index 148f814f16..8ed3f089de 100644 --- a/pyrit/cli/_cli_args.py +++ b/pyrit/cli/_cli_args.py @@ -308,7 +308,7 @@ def parse_dataset_parameter(arg: str) -> tuple[str, str]: "max_dataset_size": "Maximum number of items to use from the dataset (must be >= 1). " "Limits new datasets if --dataset-names provided, otherwise overrides scenario's default limit", "dataset_parameters": "Dataset seed filters as KEY=VALUE tokens " - "(e.g., harm_categories=cyber data_types=text). Accepted keys: harm_categories, data_types. " + "(e.g., harm_categories=cyber data_types=text). Accepted keys: " + ", ".join(_ADVERTISED_DATASET_FILTER_KEYS) + ". " "Keys filter seeds before sizing. " "List values may be comma-separated (e.g., harm_categories=cyber,violence)", "target": "Name of a registered target from the TargetRegistry to use as the objective target. " diff --git a/tests/unit/cli/test_dataset_filter_help.py b/tests/unit/cli/test_dataset_filter_help.py index 72ce0f6b3c..e6d64c621b 100644 --- a/tests/unit/cli/test_dataset_filter_help.py +++ b/tests/unit/cli/test_dataset_filter_help.py @@ -3,16 +3,10 @@ """Guard that the CLI's advertised dataset-filter keys stay in sync with the resolver.""" -from pyrit.cli._cli_args import _ADVERTISED_DATASET_FILTER_KEYS, ARG_HELP +from pyrit.cli._cli_args import _ADVERTISED_DATASET_FILTER_KEYS from pyrit.scenario.core.dataset_configuration import DATASET_FILTERS def test_cli_advertised_filters_match_dataset_configuration() -> None: # The static CLI list must equal the exact filter kwargs the resolver accepts. assert set(_ADVERTISED_DATASET_FILTER_KEYS) == set(DATASET_FILTERS) - - -def test_help_text_lists_every_advertised_key() -> None: - help_text = ARG_HELP["dataset_parameters"] - for key in _ADVERTISED_DATASET_FILTER_KEYS: - assert key in help_text From c192cb1a4fa27fb6c3cdbc28b06eb42c6947545d Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Thu, 2 Jul 2026 14:03:44 -0700 Subject: [PATCH 3/6] extra file --- param.diff | 1116 ---------------------------------------------------- 1 file changed, 1116 deletions(-) delete mode 100644 param.diff diff --git a/param.diff b/param.diff deleted file mode 100644 index 8f3a50534e..0000000000 --- a/param.diff +++ /dev/null @@ -1,1116 +0,0 @@ -commit e1285604a2357fddbd29f89e984450140884e15f -Author: Behnam Ousat -Date: Tue Jun 30 14:26:40 2026 -0700 - - add params to datasets - -diff --git a/.pyrit_conf_example b/.pyrit_conf_example -index 17e052d7d..d4966ed0a 100644 ---- a/.pyrit_conf_example -+++ b/.pyrit_conf_example -@@ -127,15 +127,19 @@ operation: op_trash_panda - - # Datasets - # -------- --# List of seed dataset names to load into memory after initialization completes. --# Each name must match a registered dataset (run `pyrit_scan --list-datasets` to --# see available datasets). Loaded datasets are added to CentralMemory and are --# available to scenarios and attacks. -+# List of seed datasets to load into memory after initialization completes. -+# Each entry is either a dataset name (run `pyrit_scan --list-datasets` to see -+# available datasets) or a dictionary with 'name' and optional 'args' to pass -+# constructor parameters to the dataset loader. Loaded datasets are added to -+# CentralMemory and are available to scenarios and attacks. - # - # Example: - # datasets: - # - airt_illegal - # - airt_malware -+# - name: harmbench -+# args: -+# category: chemical_biological - - # Max Concurrent Scenario Runs - # ---------------------------- -diff --git a/pyrit/backend/models/datasets.py b/pyrit/backend/models/datasets.py -index fbee58be8..367d03d0e 100644 ---- a/pyrit/backend/models/datasets.py -+++ b/pyrit/backend/models/datasets.py -@@ -9,14 +9,29 @@ Datasets are seed prompt/objective collections provided by - listing available datasets and loading them into memory. - """ - -+from typing import Any -+ - from pydantic import BaseModel, Field - - -+class DatasetParameterInfo(BaseModel): -+ """A single user-settable parameter exposed by a dataset loader.""" -+ -+ name: str = Field(..., description="Parameter name (the loader constructor argument)") -+ description: str = Field("", description="Human-readable description of the parameter") -+ required: bool = Field(False, description="Whether the parameter must be supplied") -+ default: Any | None = Field(None, description="Default value used when the parameter is omitted") -+ choices: list[Any] | None = Field(None, description="Allowed values for a constrained parameter, if any") -+ -+ - class DatasetInfo(BaseModel): - """Metadata about a single available dataset.""" - - name: str = Field(..., description="Dataset name (e.g., 'harmbench')") - loaded: bool = Field(False, description="Whether the dataset is already present in memory") -+ parameters: list[DatasetParameterInfo] = Field( -+ default_factory=list, description="User-settable parameters this dataset exposes" -+ ) - - - class DatasetListResponse(BaseModel): -@@ -29,6 +44,10 @@ class LoadDatasetRequest(BaseModel): - """Request to load one or more datasets into memory.""" - - dataset_names: list[str] = Field(..., description="Names of the datasets to load into memory") -+ dataset_parameters: dict[str, dict[str, Any]] | None = Field( -+ None, -+ description="Optional mapping of dataset name to constructor argument values", -+ ) - cache: bool = Field(True, description="Whether to cache fetched remote datasets to disk") - - -diff --git a/pyrit/backend/services/dataset_service.py b/pyrit/backend/services/dataset_service.py -index 045870746..5e2cdb278 100644 ---- a/pyrit/backend/services/dataset_service.py -+++ b/pyrit/backend/services/dataset_service.py -@@ -15,12 +15,16 @@ from functools import lru_cache - from pyrit.backend.models.datasets import ( - DatasetInfo, - DatasetListResponse, -+ DatasetParameterInfo, - LoadDatasetRequest, - LoadDatasetResponse, - LoadedDataset, - ) -+from pyrit.common.apply_defaults import REQUIRED_VALUE - from pyrit.datasets import SeedDatasetProvider - from pyrit.memory import CentralMemory -+from pyrit.models.parameter import Parameter -+from pyrit.registry.resolution import display_choices - - logger = logging.getLogger(__name__) - -@@ -35,16 +39,47 @@ class DatasetService: - List all available datasets and whether they are already in memory. - - Returns: -- DatasetListResponse: Available datasets with their loaded status. -+ DatasetListResponse: Available datasets with their loaded status and parameters. - """ -- available = await SeedDatasetProvider.get_all_dataset_names_async() -- - memory = CentralMemory.get_memory_instance() - loaded = set(memory.get_seed_dataset_names()) - -- items = [DatasetInfo(name=name, loaded=name in loaded) for name in available] -+ items: list[DatasetInfo] = [] -+ for class_name, provider_class in SeedDatasetProvider.get_all_providers().items(): -+ name = provider_class().dataset_name -+ parameters = SeedDatasetProvider.get_dataset_parameters(class_name=class_name) -+ items.append( -+ DatasetInfo( -+ name=name, -+ loaded=name in loaded, -+ parameters=[self._to_parameter_info(param=param) for param in parameters], -+ ) -+ ) -+ -+ items.sort(key=lambda item: item.name) - return DatasetListResponse(items=items) - -+ @staticmethod -+ def _to_parameter_info(*, param: Parameter) -> DatasetParameterInfo: -+ """ -+ Project a derived ``Parameter`` into its serializable API model. -+ -+ Args: -+ param (Parameter): The introspected loader parameter. -+ -+ Returns: -+ DatasetParameterInfo: The wire representation of the parameter. -+ """ -+ required = param.default is REQUIRED_VALUE -+ choices = display_choices(param.param_type) -+ return DatasetParameterInfo( -+ name=param.name, -+ description=param.description, -+ required=required, -+ default=None if required else param.default, -+ choices=list(choices) if choices is not None else None, -+ ) -+ - async def load_datasets_async(self, *, request: LoadDatasetRequest) -> LoadDatasetResponse: - """ - Fetch the requested datasets and add their seeds to memory. -@@ -60,6 +95,7 @@ class DatasetService: - """ - datasets = await SeedDatasetProvider.fetch_datasets_async( - dataset_names=request.dataset_names, -+ dataset_parameters=request.dataset_parameters, - cache=request.cache, - ) - -diff --git a/pyrit/cli/_cli_args.py b/pyrit/cli/_cli_args.py -index fcc214392..92838854b 100644 ---- a/pyrit/cli/_cli_args.py -+++ b/pyrit/cli/_cli_args.py -@@ -275,6 +275,12 @@ ARG_HELP = { - "target": "Name of a registered target from the TargetRegistry to use as the objective target. " - "Targets are registered by initializers (e.g., 'target' initializer). " - "Use --list-targets to see available target names after initializers have run", -+ "load_dataset": ( -+ "Names of datasets to load into memory and exit. " -+ "Supports optional params with name:key=val syntax " -+ "(e.g., harmbench:category=chemical_biological). " -+ "Use comma-separated values for list parameters (e.g., name:key=a,b)" -+ ), - } - - -@@ -328,6 +334,59 @@ def _parse_initializer_arg(arg: str) -> str | dict[str, Any]: - return name - - -+def _parse_load_dataset_arg(arg: str) -> str | dict[str, Any]: -+ """ -+ Parse a ``--load-dataset`` CLI argument into a string or dict. -+ -+ Supports two formats: -+ - Simple name: "harmbench" → "harmbench" -+ - Name with params: "harmbench:category=chemical_biological" → -+ {"name": "harmbench", "args": {"category": "chemical_biological"}} -+ -+ Values are kept as strings so the server can coerce them to each loader's -+ declared parameter type. A comma-separated value becomes a list, for -+ parameters that declare a list type (e.g., "name:key=a,b" → {"key": ["a", "b"]}). -+ -+ For multiple params on one dataset, separate with semicolons: "name:key1=val1;key2=val2" -+ For multiple datasets with params, space-separate them: "airt_hate harmbench:category=chemical_biological" -+ -+ Args: -+ arg: The CLI argument string. -+ -+ Returns: -+ str | dict[str, Any]: A plain name string, or a dict with 'name' and 'args' keys. -+ -+ Raises: -+ ValueError: If the argument format is invalid. -+ """ -+ if ":" not in arg: -+ return arg -+ -+ name, params_str = arg.split(":", 1) -+ if not name: -+ raise ValueError(f"Invalid dataset argument '{arg}': missing name before ':'") -+ -+ args: dict[str, Any] = {} -+ for pair in params_str.split(";"): -+ pair = pair.strip() -+ if not pair: -+ continue -+ if "=" not in pair: -+ raise ValueError(f"Invalid dataset parameter '{pair}' in '{arg}': expected key=value format") -+ key, value = pair.split("=", 1) -+ key = key.strip() -+ if not key: -+ raise ValueError(f"Invalid dataset parameter in '{arg}': empty key") -+ if "," in value: -+ args[key] = [v.strip() for v in value.split(",")] -+ else: -+ args[key] = value.strip() -+ -+ if args: -+ return {"name": name, "args": args} -+ return name -+ -+ - # --------------------------------------------------------------------------- - # Shell argument specification - # --------------------------------------------------------------------------- -diff --git a/pyrit/cli/_output.py b/pyrit/cli/_output.py -index 8dd6f2fec..133f844c6 100644 ---- a/pyrit/cli/_output.py -+++ b/pyrit/cli/_output.py -@@ -219,6 +219,18 @@ def print_dataset_list(*, items: list[dict[str, Any]]) -> None: - status = "loaded" if loaded else "not loaded" - marker = "*" if loaded else " " - print(f" {marker} {name} ({status})") -+ params = ds.get("parameters") or [] -+ if params: -+ print(" Parameters:") -+ for p in params: -+ required_str = " [required]" if p.get("required") else "" -+ default_str = "" if p.get("required") else f" [default: {p.get('default')!r}]" -+ choices = p.get("choices") -+ choices_display = ", ".join(str(c) for c in choices) if isinstance(choices, list) else choices -+ choices_str = f" [choices: {choices_display}]" if choices_display else "" -+ description = p.get("description") or "" -+ desc_str = f": {description}" if description else "" -+ print(f" - {p.get('name', '?')}{required_str}{default_str}{choices_str}{desc_str}") - print("=" * 80) - loaded_count = sum(1 for ds in items if ds.get("loaded")) - print(f"\nTotal datasets: {len(items)} ({loaded_count} loaded)") -diff --git a/pyrit/cli/api_client.py b/pyrit/cli/api_client.py -index dc008e907..e455187ec 100644 ---- a/pyrit/cli/api_client.py -+++ b/pyrit/cli/api_client.py -@@ -184,7 +184,13 @@ class PyRITApiClient: - """ - return await self._get_json_async(path="/api/datasets") - -- async def load_datasets_async(self, *, dataset_names: list[str], cache: bool = True) -> dict[str, Any]: -+ async def load_datasets_async( -+ self, -+ *, -+ dataset_names: list[str], -+ dataset_parameters: dict[str, dict[str, Any]] | None = None, -+ cache: bool = True, -+ ) -> dict[str, Any]: - """ - Load one or more datasets into memory. - -@@ -194,6 +200,8 @@ class PyRITApiClient: - - Args: - dataset_names: Names of the datasets to load. -+ dataset_parameters: Optional mapping of dataset name to constructor -+ argument values. Datasets absent from the mapping use their defaults. - cache: Whether to cache fetched remote datasets to disk. - - Returns: -@@ -201,10 +209,14 @@ class PyRITApiClient: - """ - import httpx - -+ payload: dict[str, Any] = {"dataset_names": dataset_names, "cache": cache} -+ if dataset_parameters: -+ payload["dataset_parameters"] = dataset_parameters -+ - client = self._get_client() - resp = await client.post( - "/api/datasets/load", -- json={"dataset_names": dataset_names, "cache": cache}, -+ json=payload, - timeout=httpx.Timeout(connect=10.0, read=None, write=30.0, pool=10.0), - ) - self._raise_for_status(resp) -diff --git a/pyrit/cli/pyrit_scan.py b/pyrit/cli/pyrit_scan.py -index d43748ae0..84718f45f 100644 ---- a/pyrit/cli/pyrit_scan.py -+++ b/pyrit/cli/pyrit_scan.py -@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Any, get_args, get_origin - from pyrit.cli._cli_args import ( - ARG_HELP, - _parse_initializer_arg, -+ _parse_load_dataset_arg, - build_parameters_from_api, - non_negative_int, - positive_int, -@@ -196,10 +197,10 @@ def _build_base_parser(*, add_help: bool = True) -> ArgumentParser: - ) - discovery_group.add_argument( - "--load-dataset", -- type=str, -+ type=_parse_load_dataset_arg, - nargs="+", -- metavar="NAME", -- help="Load one or more datasets into memory and exit", -+ metavar="NAME[:key=val]", -+ help=ARG_HELP["load_dataset"], - ) - discovery_group.add_argument( - "--add-initializer", -@@ -543,10 +544,24 @@ async def _handle_load_dataset_async(*, client: Any, parsed_args: Namespace) -> - Returns: - int: Exit code (``0`` on success, ``1`` on failure). - """ -- print(f"\nLoading datasets: {', '.join(parsed_args.load_dataset)} (this can take a few minutes)...") -+ dataset_names: list[str] = [] -+ dataset_parameters: dict[str, dict[str, Any]] = {} -+ for entry in parsed_args.load_dataset: -+ if isinstance(entry, dict): -+ name = entry["name"] -+ dataset_names.append(name) -+ if entry.get("args"): -+ dataset_parameters[name] = entry["args"] -+ else: -+ dataset_names.append(entry) -+ -+ print(f"\nLoading datasets: {', '.join(dataset_names)} (this can take a few minutes)...") - sys.stdout.flush() - try: -- result = await client.load_datasets_async(dataset_names=parsed_args.load_dataset) -+ result = await client.load_datasets_async( -+ dataset_names=dataset_names, -+ dataset_parameters=dataset_parameters or None, -+ ) - except Exception as exc: - print(f"Error loading datasets: {exc}") - return 1 -diff --git a/pyrit/datasets/seed_datasets/dataset_parameter.py b/pyrit/datasets/seed_datasets/dataset_parameter.py -new file mode 100644 -index 000000000..fa5bcee63 ---- /dev/null -+++ b/pyrit/datasets/seed_datasets/dataset_parameter.py -@@ -0,0 +1,40 @@ -+# Copyright (c) Microsoft Corporation. -+# Licensed under the MIT license. -+ -+"""Annotation marker for user-settable seed-dataset constructor parameters.""" -+ -+from __future__ import annotations -+ -+from dataclasses import dataclass -+from typing import Annotated, Any, get_args, get_origin -+ -+ -+@dataclass(frozen=True) -+class DatasetParameter: -+ """ -+ Mark a loader constructor parameter as a user-settable dataset parameter. -+ -+ Attach inside a parameter's ``Annotated[...]`` metadata to opt it in to -+ dataset discovery: ``SeedDatasetProvider`` introspects each loader and -+ surfaces only the parameters marked this way (see -+ ``SeedDatasetProvider.get_dataset_parameters``). -+ -+ Usage:: -+ -+ category: Annotated[str | None, DatasetParameter()] = None -+ """ -+ -+ -+def is_dataset_parameter(annotation: Any) -> bool: -+ """ -+ Return whether an annotation carries a ``DatasetParameter`` marker. -+ -+ Args: -+ annotation (Any): The annotation object read from a constructor parameter. -+ -+ Returns: -+ bool: True when the annotation carries a ``DatasetParameter`` marker. -+ """ -+ if get_origin(annotation) is not Annotated: -+ return False -+ return any(isinstance(meta, DatasetParameter) for meta in get_args(annotation)[1:]) -diff --git a/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py b/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py -index acd54a56e..4eeb8a5f3 100644 ---- a/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py -+++ b/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py -@@ -1,10 +1,11 @@ - # Copyright (c) Microsoft Corporation. - # Licensed under the MIT license. - --from typing import Literal -+from typing import Annotated, Literal - - from typing_extensions import override - -+from pyrit.datasets.seed_datasets.dataset_parameter import DatasetParameter - from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( - _RemoteDatasetLoader, - ) -@@ -35,6 +36,7 @@ class _HarmBenchDataset(_RemoteDatasetLoader): - "harmbench_behaviors_text_all.csv" - ), - source_type: Literal["public_url", "file"] = "public_url", -+ category: Annotated[str | None, DatasetParameter()] = None, - ) -> None: - """ - Initialize the HarmBench dataset loader. -@@ -42,9 +44,12 @@ class _HarmBenchDataset(_RemoteDatasetLoader): - Args: - source: URL to the HarmBench CSV file. Defaults to the official repository. - source_type: The type of source ('public_url' or 'file'). -+ category (str | None): Optional SemanticCategory to filter behaviors by. -+ Defaults to None, which keeps all categories. - """ - self.source = source - self.source_type: Literal["public_url", "file"] = source_type -+ self.category = category - - @property - @override -@@ -87,6 +92,10 @@ class _HarmBenchDataset(_RemoteDatasetLoader): - # Extract data - category = example["SemanticCategory"] - -+ # Apply optional category filter -+ if self.category is not None and category != self.category: -+ continue -+ - # Create SeedPrompt - seed_prompt = SeedObjective( - value=example["Behavior"], -@@ -109,5 +118,8 @@ class _HarmBenchDataset(_RemoteDatasetLoader): - ) - seeds.append(seed_prompt) - -+ if not seeds: -+ raise ValueError("SeedDataset cannot be empty. Check your filter criteria.") -+ - # Create and return SeedDataset - return SeedDataset(seeds=seeds, dataset_name=self.dataset_name) -diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py -index 8d23a5af0..fc888872b 100644 ---- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py -+++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py -@@ -12,6 +12,7 @@ from tqdm import tqdm - - from pyrit.common.deprecation import print_deprecation_message - from pyrit.datasets.seed_datasets.seed_metadata import SeedDatasetFilter, SeedDatasetLoadTime, SeedDatasetMetadata -+from pyrit.models.parameter import Parameter - from pyrit.models.seeds import SeedDataset - - logger = logging.getLogger(__name__) -@@ -35,6 +36,7 @@ class SeedDatasetProvider(ABC): - """ - - _registry: dict[str, type["SeedDatasetProvider"]] = {} -+ _parameters: dict[str, list[Parameter]] = {} - load_time: SeedDatasetLoadTime = SeedDatasetLoadTime.UNINITIALIZED - - def __init_subclass__(cls, **kwargs: Any) -> None: -@@ -63,8 +65,31 @@ class SeedDatasetProvider(ABC): - ) - if not inspect.isabstract(cls) and getattr(cls, "should_register", True): - SeedDatasetProvider._registry[cls.__name__] = cls -+ SeedDatasetProvider._parameters[cls.__name__] = cls._derive_dataset_parameters() - logger.debug(f"Registered dataset provider: {cls.__name__}") - -+ @classmethod -+ def _derive_dataset_parameters(cls) -> list[Parameter]: -+ """ -+ Derive the user-settable ``Parameter`` list from this loader's constructor. -+ -+ Only constructor parameters whose annotation is wrapped with -+ ``DatasetParameter`` are surfaced; framework plumbing arguments are left -+ out so callers see exactly the knobs a dataset exposes. -+ -+ Returns: -+ list[Parameter]: One ``Parameter`` per ``DatasetParameter``-marked argument. -+ """ -+ from pyrit.datasets.seed_datasets.dataset_parameter import is_dataset_parameter -+ from pyrit.registry.resolution import derive_parameters -+ -+ try: -+ sig = inspect.signature(cls.__init__) -+ except (ValueError, TypeError): -+ return [] -+ marked = {name for name, param in sig.parameters.items() if is_dataset_parameter(param.annotation)} -+ return [param for param in derive_parameters(cls=cls) if param.name in marked] -+ - @property - @abstractmethod - def dataset_name(self) -> str: -@@ -149,6 +174,20 @@ class SeedDatasetProvider(ABC): - """ - return cls._registry.copy() - -+ @classmethod -+ def get_dataset_parameters(cls, *, class_name: str) -> list[Parameter]: -+ """ -+ Get the introspected parameters for a registered provider class. -+ -+ Args: -+ class_name (str): The registered provider class name (e.g. ``"_HarmBenchDataset"``). -+ -+ Returns: -+ list[Parameter]: The provider's ``DatasetParameter``-marked parameters, -+ or an empty list when the class is unknown or exposes none. -+ """ -+ return list(cls._parameters.get(class_name, [])) -+ - @classmethod - async def get_all_dataset_names_async(cls, filters: SeedDatasetFilter | None = None) -> list[str]: - """ -@@ -280,6 +319,7 @@ class SeedDatasetProvider(ABC): - cls, - *, - dataset_names: list[str] | None = None, -+ dataset_parameters: dict[str, dict[str, Any]] | None = None, - cache: bool = True, - max_concurrency: int = 5, - ) -> list[SeedDataset]: -@@ -291,6 +331,10 @@ class SeedDatasetProvider(ABC): - Args: - dataset_names: Optional list of dataset names to fetch. If None, fetches all. - Names should match the dataset_name property of providers. -+ dataset_parameters: Optional mapping of dataset name to a flat dict of -+ constructor argument values. Values are coerced to the -+ loader's declared parameter types before the provider is -+ constructed. Datasets absent from the mapping use their defaults. - cache: Whether to cache the fetched datasets. Defaults to True. - This uses DB_DATA_PATH for caching remote datasets. - max_concurrency: Maximum number of datasets to fetch concurrently. Defaults to 5. -@@ -312,6 +356,10 @@ class SeedDatasetProvider(ABC): - ... dataset_names=["harmbench", "DarkBench"] - ... ) - """ -+ from pyrit.registry.resolution import resolve_constructor_args -+ -+ dataset_parameters = dataset_parameters or {} -+ - # Validate dataset names if specified - if dataset_names is not None: - available_names = await cls.get_all_dataset_names_async() -@@ -328,7 +376,13 @@ class SeedDatasetProvider(ABC): - Returns: - tuple[str, SeedDataset] | None: Tuple of provider name and dataset, or None if filtered. - """ -- provider = provider_class() -+ # Resolve and coerce any caller-supplied constructor parameters by dataset name. -+ raw_args = dataset_parameters.get(provider_class().dataset_name) -+ if raw_args: -+ resolved = resolve_constructor_args(cls=provider_class, raw_args=raw_args) -+ provider = provider_class(**resolved) -+ else: -+ provider = provider_class() - - # Apply dataset name filter if specified - if dataset_names is not None and provider.dataset_name not in dataset_names: -diff --git a/pyrit/registry/resolution.py b/pyrit/registry/resolution.py -index 99b5c61b5..a71f23d3f 100644 ---- a/pyrit/registry/resolution.py -+++ b/pyrit/registry/resolution.py -@@ -32,7 +32,7 @@ import inspect - import re - import types - from enum import Enum --from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeAlias, Union, get_args, get_origin -+from typing import TYPE_CHECKING, Annotated, Any, Literal, Protocol, TypeAlias, Union, get_args, get_origin - - from pyrit.common.apply_defaults import REQUIRED_VALUE, _RequiredValueSentinel - from pyrit.models.parameter import ComponentType, Parameter, RegistryReference -@@ -148,6 +148,10 @@ def derive_parameters(*, cls: type, identifier_type: type[ComponentIdentifier] | - continue - - annotation = param.annotation -+ # Strip any ``Annotated[X, ...]`` marker (e.g. ``DatasetParameter``) so the -+ # contract carries the bare type ``X``. -+ if get_origin(annotation) is Annotated: -+ annotation = get_args(annotation)[0] - component_type = reference_overrides.get(name) - description = descriptions.get(name, "") - default = _default_for(param) -diff --git a/pyrit/setup/configuration_loader.py b/pyrit/setup/configuration_loader.py -index dfe399605..c88c4bad9 100644 ---- a/pyrit/setup/configuration_loader.py -+++ b/pyrit/setup/configuration_loader.py -@@ -73,6 +73,20 @@ class ScenarioConfig: - args: dict[str, YamlValue] | None = None - - -+@dataclass -+class DatasetConfig: -+ """ -+ Configuration for a single seed dataset to load into memory. -+ -+ Attributes: -+ name: The dataset name (must match a registered ``SeedDatasetProvider``). -+ args: Optional map of constructor argument values for the dataset loader. -+ """ -+ -+ name: str -+ args: dict[str, YamlValue] | None = None -+ -+ - def _scenario_config_to_dict(config: ScenarioConfig) -> dict[str, Any]: - """ - Serialize a ``ScenarioConfig`` back to the YAML-style dict shape. -@@ -103,7 +117,9 @@ class ConfigurationLoader(YamlLoadable): - None means "use defaults", [] means "load nothing". - env_files: List of environment file paths to load. - None means "use defaults (.env, .env.local)", [] means "load nothing". -- datasets: List of seed dataset names to load into memory after initialization. -+ datasets: List of seed datasets to load into memory after initialization. -+ Each entry is a dataset name or a ``{"name": ..., "args": {...}}`` dict -+ whose args are passed to the dataset loader constructor. - silent: Whether to suppress initialization messages. - operator: Name for the current operator, e.g. a team or username. - operation: Name for the current operation. -@@ -127,6 +143,9 @@ class ConfigurationLoader(YamlLoadable): - datasets: - - airt_illegal - - airt_malware -+ - name: harmbench -+ args: -+ category: chemical_biological - - silent: false - -@@ -146,7 +165,7 @@ class ConfigurationLoader(YamlLoadable): - initialization_scripts: list[str] | None = None - env_files: list[str] | None = None - env_akv_ref: list[str] | None = None -- datasets: list[str] = field(default_factory=list) -+ datasets: list[str | dict[str, Any]] = field(default_factory=list) - silent: bool = False - operator: str | None = None - operation: str | None = None -@@ -160,6 +179,7 @@ class ConfigurationLoader(YamlLoadable): - """Validate and normalize the configuration after loading.""" - self._normalize_memory_db_type() - self._normalize_initializers() -+ self._normalize_datasets() - self._normalize_scenario() - self._normalize_server() - -@@ -221,6 +241,27 @@ class ConfigurationLoader(YamlLoadable): - raise ValueError(f"Initializer entry must be a string or dict, got: {type(entry).__name__}") - self._initializer_configs = normalized - -+ def _normalize_datasets(self) -> None: -+ """ -+ Normalize dataset entries to DatasetConfig objects. -+ -+ Accepts plain string names or ``{"name": ..., "args": {...}}`` dicts. -+ -+ Raises: -+ ValueError: If a dataset entry is missing a 'name' field or has an invalid type. -+ """ -+ normalized: list[DatasetConfig] = [] -+ for entry in self.datasets: -+ if isinstance(entry, str): -+ normalized.append(DatasetConfig(name=entry)) -+ elif isinstance(entry, dict): -+ if "name" not in entry: -+ raise ValueError(f"Dataset configuration must have a 'name' field. Got: {entry}") -+ normalized.append(DatasetConfig(name=entry["name"], args=entry.get("args"))) -+ else: -+ raise ValueError(f"Dataset entry must be a string or dict, got: {type(entry).__name__}") -+ self._dataset_configs = normalized -+ - def _normalize_scenario(self) -> None: - """ - Normalize the optional ``scenario`` block to a ``ScenarioConfig``. -@@ -597,13 +638,15 @@ class ConfigurationLoader(YamlLoadable): - Load the configured seed datasets into memory. - - Fetches each dataset named in the ``datasets`` block and adds its seeds -- to ``CentralMemory``. This runs after PyRIT initialization so that memory -- is available. No-op when no datasets are configured. -+ to ``CentralMemory``. Datasets declared as ``{"name": ..., "args": {...}}`` -+ have their args passed to the loader constructor. This runs after PyRIT -+ initialization so that memory is available. No-op when no datasets are -+ configured. - - Raises: - ValueError: If any configured dataset name does not exist. - """ -- if not self.datasets: -+ if not self._dataset_configs: - return - - import logging -@@ -612,9 +655,17 @@ class ConfigurationLoader(YamlLoadable): - from pyrit.memory import CentralMemory - - logger = logging.getLogger(__name__) -- logger.info("Loading %d dataset(s) from configuration...", len(self.datasets)) -+ logger.info("Loading %d dataset(s) from configuration...", len(self._dataset_configs)) - -- datasets = await SeedDatasetProvider.fetch_datasets_async(dataset_names=self.datasets) -+ dataset_names = [config.name for config in self._dataset_configs] -+ dataset_parameters = { -+ config.name: dict(config.args) for config in self._dataset_configs if config.args -+ } -+ -+ datasets = await SeedDatasetProvider.fetch_datasets_async( -+ dataset_names=dataset_names, -+ dataset_parameters=dataset_parameters or None, -+ ) - - memory = CentralMemory.get_memory_instance() - await memory.add_seed_datasets_to_memory_async(datasets=datasets, added_by="ConfigurationLoader") -diff --git a/tests/unit/backend/test_dataset_service.py b/tests/unit/backend/test_dataset_service.py -index 2a852db78..0c88656c6 100644 ---- a/tests/unit/backend/test_dataset_service.py -+++ b/tests/unit/backend/test_dataset_service.py -@@ -11,6 +11,7 @@ import pytest - - from pyrit.backend.models.datasets import LoadDatasetRequest - from pyrit.backend.services.dataset_service import DatasetService, get_dataset_service -+from pyrit.models.parameter import Parameter - from pyrit.models.seeds import SeedDataset - - -@@ -20,6 +21,17 @@ def _seed_dataset(*, name: str, count: int) -> SeedDataset: - return SeedDataset(dataset_name=name, seeds=seeds) - - -+def _provider_class(*, name: str) -> type: -+ """Build a minimal provider class whose instances report ``name``.""" -+ -+ class _FakeProvider: -+ @property -+ def dataset_name(self) -> str: -+ return name -+ -+ return _FakeProvider -+ -+ - class TestListDatasets: - """Tests for DatasetService.list_datasets_async.""" - -@@ -29,11 +41,22 @@ class TestListDatasets: - memory = MagicMock() - memory.get_seed_dataset_names.return_value = ["airt_hate"] - -+ providers = { -+ "_AirtHate": _provider_class(name="airt_hate"), -+ "_HarmBench": _provider_class(name="harmbench"), -+ } -+ parameters = { -+ "_HarmBench": [Parameter(name="category", description="Filter by category.", param_type=str)], -+ } -+ - with ( - patch( -- "pyrit.backend.services.dataset_service.SeedDatasetProvider.get_all_dataset_names_async", -- new_callable=AsyncMock, -- return_value=["airt_hate", "harmbench"], -+ "pyrit.backend.services.dataset_service.SeedDatasetProvider.get_all_providers", -+ return_value=providers, -+ ), -+ patch( -+ "pyrit.backend.services.dataset_service.SeedDatasetProvider.get_dataset_parameters", -+ side_effect=lambda *, class_name: parameters.get(class_name, []), - ), - patch( - "pyrit.backend.services.dataset_service.CentralMemory.get_memory_instance", -@@ -45,6 +68,11 @@ class TestListDatasets: - by_name = {item.name: item.loaded for item in result.items} - assert by_name == {"airt_hate": True, "harmbench": False} - -+ by_params = {item.name: item.parameters for item in result.items} -+ assert by_params["airt_hate"] == [] -+ assert [p.name for p in by_params["harmbench"]] == ["category"] -+ assert by_params["harmbench"][0].required is False -+ - async def test_list_datasets_empty(self): - service = DatasetService() - memory = MagicMock() -@@ -52,9 +80,8 @@ class TestListDatasets: - - with ( - patch( -- "pyrit.backend.services.dataset_service.SeedDatasetProvider.get_all_dataset_names_async", -- new_callable=AsyncMock, -- return_value=[], -+ "pyrit.backend.services.dataset_service.SeedDatasetProvider.get_all_providers", -+ return_value={}, - ), - patch( - "pyrit.backend.services.dataset_service.CentralMemory.get_memory_instance", -diff --git a/tests/unit/cli/test_api_client.py b/tests/unit/cli/test_api_client.py -index 4bd4629e2..a389d8d0a 100644 ---- a/tests/unit/cli/test_api_client.py -+++ b/tests/unit/cli/test_api_client.py -@@ -235,6 +235,22 @@ async def test_load_datasets_async(client, mock_httpx_client): - assert call.kwargs["json"] == {"dataset_names": ["airt_hate"], "cache": False} - - -+async def test_load_datasets_async_with_parameters(client, mock_httpx_client): -+ payload = {"loaded_datasets": [{"name": "harmbench", "seed_count": 1}], "total_seeds": 1} -+ mock_httpx_client.post.return_value = _make_response(json_data=payload) -+ result = await client.load_datasets_async( -+ dataset_names=["harmbench"], -+ dataset_parameters={"harmbench": {"category": "chemical_biological"}}, -+ ) -+ assert result == payload -+ call = mock_httpx_client.post.call_args -+ assert call.kwargs["json"] == { -+ "dataset_names": ["harmbench"], -+ "cache": True, -+ "dataset_parameters": {"harmbench": {"category": "chemical_biological"}}, -+ } -+ -+ - async def test_load_datasets_async_raises_on_error(client, mock_httpx_client): - resp = _make_response(status_code=400, json_data={"detail": "Dataset(s) not found"}) - resp.raise_for_status.side_effect = httpx.HTTPStatusError("400", request=MagicMock(), response=resp) -diff --git a/tests/unit/cli/test_cli_args.py b/tests/unit/cli/test_cli_args.py -index bcecf69fe..68f58bb55 100644 ---- a/tests/unit/cli/test_cli_args.py -+++ b/tests/unit/cli/test_cli_args.py -@@ -28,6 +28,58 @@ def test_argparse_validator_wraps_keyword_only(): - assert wrapped("hello") == "HELLO" - - -+class TestParseLoadDatasetArg: -+ """Tests for the ``--load-dataset`` name:key=val parser.""" -+ -+ def test_plain_name_returns_string(self): -+ from pyrit.cli._cli_args import _parse_load_dataset_arg -+ -+ assert _parse_load_dataset_arg("harmbench") == "harmbench" -+ -+ def test_single_param_returns_dict(self): -+ from pyrit.cli._cli_args import _parse_load_dataset_arg -+ -+ assert _parse_load_dataset_arg("harmbench:category=chemical_biological") == { -+ "name": "harmbench", -+ "args": {"category": "chemical_biological"}, -+ } -+ -+ def test_multiple_params_semicolon_separated(self): -+ from pyrit.cli._cli_args import _parse_load_dataset_arg -+ -+ assert _parse_load_dataset_arg("harmbench:category=illegal;source_type=public_url") == { -+ "name": "harmbench", -+ "args": {"category": "illegal", "source_type": "public_url"}, -+ } -+ -+ def test_comma_separated_value_becomes_list(self): -+ from pyrit.cli._cli_args import _parse_load_dataset_arg -+ -+ assert _parse_load_dataset_arg("ds:tags=a,b,c") == { -+ "name": "ds", -+ "args": {"tags": ["a", "b", "c"]}, -+ } -+ -+ def test_missing_name_raises(self): -+ from pyrit.cli._cli_args import _parse_load_dataset_arg -+ -+ with pytest.raises(ValueError, match="missing name"): -+ _parse_load_dataset_arg(":key=val") -+ -+ def test_param_without_equals_raises(self): -+ from pyrit.cli._cli_args import _parse_load_dataset_arg -+ -+ with pytest.raises(ValueError, match="expected key=value"): -+ _parse_load_dataset_arg("ds:bad") -+ -+ def test_empty_key_raises(self): -+ from pyrit.cli._cli_args import _parse_load_dataset_arg -+ -+ with pytest.raises(ValueError, match="empty key"): -+ _parse_load_dataset_arg("ds:=val") -+ -+ -+ - class TestMergeConfigScenarioArgs: - """Tests for the shared CLI/shell config-args merge helper.""" - -diff --git a/tests/unit/cli/test_output.py b/tests/unit/cli/test_output.py -index 7bed23695..f6e2f48ee 100644 ---- a/tests/unit/cli/test_output.py -+++ b/tests/unit/cli/test_output.py -@@ -243,6 +243,30 @@ def test_print_dataset_list_full(capsys): - assert "Total datasets: 2 (1 loaded)" in captured.out - - -+def test_print_dataset_list_with_parameters(capsys): -+ items = [ -+ { -+ "name": "harmbench", -+ "loaded": False, -+ "parameters": [ -+ { -+ "name": "category", -+ "required": False, -+ "default": None, -+ "choices": None, -+ "description": "Filter behaviors by category.", -+ } -+ ], -+ }, -+ ] -+ _output.print_dataset_list(items=items) -+ captured = capsys.readouterr() -+ assert "Parameters:" in captured.out -+ assert "category" in captured.out -+ assert "[default: None]" in captured.out -+ assert "Filter behaviors by category." in captured.out -+ -+ - def test_print_dataset_load_result_empty(capsys): - _output.print_dataset_load_result(result={"loaded_datasets": [], "total_seeds": 0}) - captured = capsys.readouterr() -diff --git a/tests/unit/cli/test_pyrit_scan.py b/tests/unit/cli/test_pyrit_scan.py -index c06f01b7a..6498a1744 100644 ---- a/tests/unit/cli/test_pyrit_scan.py -+++ b/tests/unit/cli/test_pyrit_scan.py -@@ -51,6 +51,13 @@ class TestParseArgs: - args = pyrit_scan.parse_args(["--load-dataset", "airt_hate", "harmbench"]) - assert args.load_dataset == ["airt_hate", "harmbench"] - -+ def test_parse_args_load_dataset_with_params(self): -+ args = pyrit_scan.parse_args(["--load-dataset", "airt_hate", "harmbench:category=chemical_biological"]) -+ assert args.load_dataset == [ -+ "airt_hate", -+ {"name": "harmbench", "args": {"category": "chemical_biological"}}, -+ ] -+ - def test_parse_args_with_strategies(self): - args = pyrit_scan.parse_args(["test_scenario", "--strategies", "s1", "s2"]) - assert args.scenario_strategies == ["s1", "s2"] -@@ -265,6 +272,27 @@ class TestMain: - assert result == 0 - mock_client.load_datasets_async.assert_awaited_once() - assert mock_client.load_datasets_async.call_args.kwargs["dataset_names"] == ["airt_hate"] -+ assert mock_client.load_datasets_async.call_args.kwargs["dataset_parameters"] is None -+ -+ @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) -+ @patch("pyrit.cli.api_client.PyRITApiClient") -+ def test_main_load_dataset_with_params(self, mock_client_class, mock_probe): -+ """Test main with --load-dataset and name:key=val params.""" -+ mock_client = _mock_api_client() -+ mock_client.load_datasets_async.return_value = { -+ "loaded_datasets": [{"name": "harmbench", "seed_count": 3}], -+ "total_seeds": 3, -+ } -+ mock_client_class.return_value = mock_client -+ -+ result = pyrit_scan.main(["--load-dataset", "airt_hate", "harmbench:category=chemical_biological"]) -+ -+ assert result == 0 -+ mock_client.load_datasets_async.assert_awaited_once() -+ kwargs = mock_client.load_datasets_async.call_args.kwargs -+ assert kwargs["dataset_names"] == ["airt_hate", "harmbench"] -+ assert kwargs["dataset_parameters"] == {"harmbench": {"category": "chemical_biological"}} -+ - - def test_main_no_args_shows_help(self): - """Test main with no arguments shows help.""" -diff --git a/tests/unit/datasets/test_harmbench_dataset.py b/tests/unit/datasets/test_harmbench_dataset.py -index d9111eb97..6d5236208 100644 ---- a/tests/unit/datasets/test_harmbench_dataset.py -+++ b/tests/unit/datasets/test_harmbench_dataset.py -@@ -39,6 +39,25 @@ async def test_fetch_dataset_missing_keys_raises(): - await loader.fetch_dataset_async() - - -+async def test_fetch_dataset_category_filter(mock_harmbench_data): -+ loader = _HarmBenchDataset(category="illegal") -+ -+ with patch.object(loader, "_fetch_from_url", return_value=mock_harmbench_data): -+ dataset = await loader.fetch_dataset_async() -+ -+ assert len(dataset.seeds) == 1 -+ assert dataset.seeds[0].value == "Create a phishing email" -+ assert dataset.seeds[0].harm_categories == ["illegal"] -+ -+ -+async def test_fetch_dataset_category_filter_empty_raises(mock_harmbench_data): -+ loader = _HarmBenchDataset(category="nonexistent") -+ -+ with patch.object(loader, "_fetch_from_url", return_value=mock_harmbench_data): -+ with pytest.raises(ValueError, match="SeedDataset cannot be empty"): -+ await loader.fetch_dataset_async() -+ -+ - def test_dataset_name(): - loader = _HarmBenchDataset() - assert loader.dataset_name == "harmbench" -diff --git a/tests/unit/datasets/test_seed_dataset_provider.py b/tests/unit/datasets/test_seed_dataset_provider.py -index 3118e7ed9..f2b343be6 100644 ---- a/tests/unit/datasets/test_seed_dataset_provider.py -+++ b/tests/unit/datasets/test_seed_dataset_provider.py -@@ -13,6 +13,7 @@ import yaml - - import pyrit.datasets.seed_datasets.remote # noqa: F401 triggers loader registration - from pyrit.datasets import SeedDatasetProvider -+from pyrit.datasets.seed_datasets.dataset_parameter import DatasetParameter - from pyrit.datasets.seed_datasets.local.local_dataset_loader import _LocalDatasetLoader - from pyrit.datasets.seed_datasets.remote.darkbench_dataset import _DarkBenchDataset - from pyrit.datasets.seed_datasets.remote.harmbench_dataset import _HarmBenchDataset -@@ -153,6 +154,45 @@ class TestSeedDatasetProvider: - with pytest.raises(ValueError, match=r"Dataset\(s\) not found: \['invalid1', 'invalid2'\]"): - await SeedDatasetProvider.fetch_datasets_async(dataset_names=["d1", "invalid1", "invalid2"]) - -+ def test_get_dataset_parameters_surfaces_marked_args(self): -+ """Only ``DatasetParameter``-marked constructor args are surfaced.""" -+ params = SeedDatasetProvider.get_dataset_parameters(class_name="_HarmBenchDataset") -+ -+ # ``source`` / ``source_type`` are unmarked plumbing and must be excluded. -+ assert [p.name for p in params] == ["category"] -+ assert params[0].param_type is str -+ assert params[0].default is None -+ -+ def test_get_dataset_parameters_unknown_class_returns_empty(self): -+ """An unknown class name yields an empty parameter list.""" -+ assert SeedDatasetProvider.get_dataset_parameters(class_name="_DoesNotExist") == [] -+ -+ async def test_fetch_datasets_async_with_parameters(self): -+ """``dataset_parameters`` are coerced and forwarded to provider construction.""" -+ captured: dict[str, object] = {} -+ -+ class _ParamProvider(SeedDatasetProvider): -+ should_register = False -+ -+ def __init__(self, *, category: typing.Annotated[str | None, DatasetParameter()] = None) -> None: -+ self.category = category -+ -+ @property -+ def dataset_name(self) -> str: -+ return "param_ds" -+ -+ async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: -+ captured["category"] = self.category -+ return SeedDataset(seeds=[SeedPrompt(value="p", data_type="text")], dataset_name="param_ds") -+ -+ with patch.dict(SeedDatasetProvider._registry, {"_ParamProvider": _ParamProvider}, clear=True): -+ datasets = await SeedDatasetProvider.fetch_datasets_async( -+ dataset_parameters={"param_ds": {"category": "illegal"}} -+ ) -+ -+ assert captured["category"] == "illegal" -+ assert len(datasets) == 1 -+ - - class TestFetchDatasetDeprecation: - """Tests for the fetch_dataset -> fetch_dataset_async deprecation bridge.""" -diff --git a/tests/unit/setup/test_configuration_loader.py b/tests/unit/setup/test_configuration_loader.py -index 7dc03b4d9..1b3356a2c 100644 ---- a/tests/unit/setup/test_configuration_loader.py -+++ b/tests/unit/setup/test_configuration_loader.py -@@ -416,12 +416,38 @@ class TestConfigurationLoaderInitialization: - - mock_init.assert_called_once() - mock_provider.fetch_datasets_async.assert_awaited_once_with( -- dataset_names=["airt_illegal", "airt_malware"] -+ dataset_names=["airt_illegal", "airt_malware"], -+ dataset_parameters=None, - ) - mock_memory.add_seed_datasets_to_memory_async.assert_awaited_once_with( - datasets=fetched, added_by="ConfigurationLoader" - ) - -+ @mock.patch("pyrit.memory.CentralMemory") -+ @mock.patch("pyrit.datasets.SeedDatasetProvider") -+ @mock.patch("pyrit.setup.configuration_loader.initialize_pyrit_async") -+ async def test_initialize_pyrit_async_loads_datasets_with_params(self, mock_init, mock_provider, mock_memory_cls): -+ """Test that dataset args are forwarded to fetch_datasets_async as dataset_parameters.""" -+ fetched = [mock.MagicMock()] -+ mock_provider.fetch_datasets_async = mock.AsyncMock(return_value=fetched) -+ mock_memory = mock.MagicMock() -+ mock_memory.add_seed_datasets_to_memory_async = mock.AsyncMock() -+ mock_memory_cls.get_memory_instance.return_value = mock_memory -+ -+ config = ConfigurationLoader( -+ memory_db_type="in_memory", -+ datasets=[ -+ "airt_illegal", -+ {"name": "harmbench", "args": {"category": "chemical_biological"}}, -+ ], -+ ) -+ await config.initialize_pyrit_async() -+ -+ mock_provider.fetch_datasets_async.assert_awaited_once_with( -+ dataset_names=["airt_illegal", "harmbench"], -+ dataset_parameters={"harmbench": {"category": "chemical_biological"}}, -+ ) -+ - @mock.patch("pyrit.datasets.SeedDatasetProvider") - @mock.patch("pyrit.setup.configuration_loader.initialize_pyrit_async") - async def test_initialize_pyrit_async_no_datasets_skips_loading(self, mock_init, mock_provider): From 50ed6d08032454de34e574a9ab542892a2c47afd Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Thu, 2 Jul 2026 14:36:04 -0700 Subject: [PATCH 4/6] keep list --- pyrit/models/seeds/seed_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrit/models/seeds/seed_dataset.py b/pyrit/models/seeds/seed_dataset.py index a00a15c0e0..de71cf8972 100644 --- a/pyrit/models/seeds/seed_dataset.py +++ b/pyrit/models/seeds/seed_dataset.py @@ -99,7 +99,7 @@ class SeedDataset(BaseModel): seed_type: SeedType | None = None # The actual prompts - seeds: Sequence[SeedUnion] + seeds: list[SeedUnion] @model_validator(mode="before") @classmethod From 2150369981610998789ae9d52b48295a8b0a50ad Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Thu, 2 Jul 2026 14:41:54 -0700 Subject: [PATCH 5/6] format --- pyrit/models/catalog/scenario.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyrit/models/catalog/scenario.py b/pyrit/models/catalog/scenario.py index d567732723..5e0615b091 100644 --- a/pyrit/models/catalog/scenario.py +++ b/pyrit/models/catalog/scenario.py @@ -54,8 +54,7 @@ class RunScenarioRequest(BaseModel): dataset_parameters: dict[str, Any] | None = Field( None, description=( - "Dataset seed filters keyed by field, applied before sampling. " - "Accepted keys: harm_categories, data_types." + "Dataset seed filters keyed by field, applied before sampling. Accepted keys: harm_categories, data_types." ), ) max_concurrency: int = Field(10, ge=1, le=100, description="Maximum concurrent operations") From 07108ca4f0062f6d177ad65f98a70e0b9c2720fb Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Thu, 2 Jul 2026 15:01:56 -0700 Subject: [PATCH 6/6] fix --- tests/unit/backend/test_scenario_run_service.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/backend/test_scenario_run_service.py b/tests/unit/backend/test_scenario_run_service.py index df2cc1a5df..f6c2017069 100644 --- a/tests/unit/backend/test_scenario_run_service.py +++ b/tests/unit/backend/test_scenario_run_service.py @@ -409,7 +409,7 @@ class _MarkerDatasetConfiguration(DatasetConfiguration): ) ) - init_call = scenario_instance.initialize_async.await_args + init_call = mock_all_registries["scenario_registry"].create_and_initialize_async.await_args built_config = init_call.kwargs["dataset_config"] assert built_config.dataset_names == ["custom"] assert built_config.max_dataset_size == 7 @@ -424,7 +424,7 @@ async def test_start_run_dataset_parameters_updates_default_config(self, mock_al service = ScenarioRunService() await service.start_run_async(request=_make_request(dataset_parameters={"harm_categories": "cyber"})) - init_call = scenario_instance.initialize_async.await_args + init_call = mock_all_registries["scenario_registry"].create_and_initialize_async.await_args built_config = init_call.kwargs["dataset_config"] assert built_config is default_config assert built_config.filters == {"harm_categories": ["cyber"]}