Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 22 additions & 10 deletions pyrit/backend/services/scenario_run_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
TargetRegistry,
)
from pyrit.scenario import Scenario
from pyrit.scenario.core import DatasetAttackConfiguration
from pyrit.scenario.core import DatasetAttackConfiguration, build_dataset_filters

if TYPE_CHECKING:
from pyrit.prompt_converter import PromptConverter
Expand Down Expand Up @@ -302,15 +302,21 @@ def _build_init_kwargs(
if request.labels:
init_kwargs["memory_labels"] = request.labels

dataset_filters = build_dataset_filters(parameters=request.dataset_parameters)

# Resolve strategies and dataset config from a temporary instance of the
# scenario. The downstream _initialize_scenario_async builds its own
# instance (so scenario_result_id can be passed), so this is a cheap
# throwaway used only for introspection. Introspection is required
# whenever the caller wants to override strategies, dataset names, or
# the sample cap, because each of those needs the scenario's own
# strategy enum or dataset-config subclass to be resolved correctly.
# whenever the caller wants to override strategies, dataset names, the
# sample cap, or dataset filters, because each of those needs the
# scenario's own strategy enum or dataset-config subclass to be resolved
# correctly.
needs_introspection = (
bool(request.strategies) or bool(request.dataset_names) or request.max_dataset_size is not None
bool(request.strategies)
or bool(request.dataset_names)
or request.max_dataset_size is not None
or bool(dataset_filters)
)
if not needs_introspection:
return init_kwargs
Expand All @@ -334,7 +340,7 @@ def _build_init_kwargs(
if strategy_converters:
init_kwargs["strategy_converters"] = strategy_converters

if request.dataset_names or request.max_dataset_size is not None:
if request.dataset_names or request.max_dataset_size is not None or dataset_filters:
default_config = introspection_instance._default_dataset_config

if request.dataset_names:
Expand All @@ -345,6 +351,7 @@ def _build_init_kwargs(
init_kwargs["dataset_config"] = default_config_class(
dataset_names=request.dataset_names,
max_dataset_size=request.max_dataset_size,
filters=dataset_filters or None,
)
except TypeError as exc:
# The subclass __init__ takes extra required kwargs we cannot
Expand All @@ -354,7 +361,7 @@ def _build_init_kwargs(
# define a no-extra-required-args constructor or surface the
# incompatibility through their own initialize_async validation.
logger.warning(
"Cannot construct %s(dataset_names=..., max_dataset_size=...) (%s). "
"Cannot construct %s(dataset_names=..., max_dataset_size=..., filters=...) (%s). "
"Falling back to a generic DatasetAttackConfiguration; scenario-specific "
"dataset-config behavior may be lost.",
default_config_class.__name__,
Expand All @@ -363,12 +370,17 @@ def _build_init_kwargs(
init_kwargs["dataset_config"] = DatasetAttackConfiguration(
dataset_names=request.dataset_names,
max_dataset_size=request.max_dataset_size,
filters=dataset_filters or None,
)
elif request.max_dataset_size is not None:
else:
# Reuse the scenario's default dataset config (preserves subtype +
# the scenario's own default dataset names) and override only the
# sample cap. Safe because the introspection instance is throwaway.
default_config.max_dataset_size = request.max_dataset_size
# sample cap and/or filters. Safe because the introspection instance
# is throwaway.
if request.max_dataset_size is not None:
default_config.max_dataset_size = request.max_dataset_size
if dataset_filters:
default_config.update_filters(filters=dataset_filters)
init_kwargs["dataset_config"] = default_config

return init_kwargs
Expand Down
46 changes: 46 additions & 0 deletions pyrit/cli/_cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,44 @@ def parse_memory_labels(json_string: str) -> dict[str, str]:
return labels


def parse_dataset_parameter(arg: str) -> tuple[str, str]:
"""
Parse a single ``KEY=VALUE`` dataset-parameter token from the CLI.

Note: The ``arg`` parameter is positional (not keyword-only) so it can be used directly
as an argparse ``type=`` callable and an ``_ArgSpec`` parser. This mirrors
``_parse_initializer_arg`` and is an intentional exception to the keyword-only style rule
for argparse compatibility.

Args:
arg (str): The raw ``KEY=VALUE`` token.

Returns:
tuple[str, str]: The (key, value) pair. The value keeps its raw string form so the
server can coerce and validate it.

Raises:
ValueError: If the token is not in ``KEY=VALUE`` form or the key is empty. Argparse
converts this into a clean CLI error; the shell catches it directly.
"""
if "=" not in arg:
raise ValueError(f"Dataset parameter must be in KEY=VALUE form, got: {arg!r}")
key, _, value = arg.partition("=")
key = key.strip()
if not key:
raise ValueError(f"Dataset parameter key cannot be empty in: {arg!r}")
return key, value


# ---------------------------------------------------------------------------
# Shared argument help text
# ---------------------------------------------------------------------------

# Dataset-filter keys advertised in --dataset-parameters help. Kept as a static list here (this
# module is on the fast --help path and stays import-light); a unit test asserts it matches the
# authoritative registry in pyrit.scenario.core.dataset_configuration.
_ADVERTISED_DATASET_FILTER_KEYS: tuple[str, ...] = ("harm_categories", "data_types")

ARG_HELP = {
"config_file": CONFIG_FILE_HELP,
"initializers": (
Expand All @@ -275,6 +310,10 @@ def parse_memory_labels(json_string: str) -> dict[str, str]:
"Creates a new dataset config; fetches all items unless --max-dataset-size is also specified",
"max_dataset_size": "Maximum number of items to use from the dataset (must be >= 1). "
"Limits new datasets if --dataset-names provided, otherwise overrides scenario's default limit",
"dataset_parameters": "Dataset seed filters as KEY=VALUE tokens "

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rich agrees with this comment but it is copilot generated:

The comma-list syntax presents one uniform surface, but the two exposed filters resolve with different semantics in get_seeds:

  • harm_categories_add_list_conditions ANDs a .contains() per value, so harm_categories=cyber,violence means "tagged with both cyber AND violence," and it's a substring match.
  • data_types.in_(values), so data_types=text,image_path means "either type," exact match.

A user writing harm_categories=cyber,violence will almost certainly expect a union and instead silently get the (likely empty) intersection. At minimum the help text and the filters docstring should spell out that comma-list semantics are per-field (AND/substring for harm categories, OR/exact for data types). Worth a test that pins each field's behavior against real seeds, since this is the surface most likely to confuse people.

"(e.g., harm_categories=cyber data_types=text). Accepted keys: " + ", ".join(_ADVERTISED_DATASET_FILTER_KEYS) + ". "
"Keys filter seeds before sizing. "
"List values may be comma-separated (e.g., harm_categories=cyber,violence)",
"target": "Name of a registered target from the TargetRegistry to use as the objective target. "
"Targets are registered by initializers (e.g., 'target' initializer). "
"Use --list-targets to see available target names after initializers have run",
Expand Down Expand Up @@ -409,6 +448,12 @@ class _ArgSpec:
result_key="max_dataset_size",
parser=lambda v: validate_integer(v, name="--max-dataset-size", min_value=1),
)
_DATASET_PARAMETERS_ARG = _ArgSpec(
flags=["--dataset-parameters"],
result_key="dataset_parameters",
multi_value=True,
parser=parse_dataset_parameter,
)
_TARGET_ARG = _ArgSpec(
flags=["--target"],
result_key="target",
Expand All @@ -422,6 +467,7 @@ class _ArgSpec:
_MEMORY_LABELS_ARG,
_DATASET_NAMES_ARG,
_MAX_DATASET_SIZE_ARG,
_DATASET_PARAMETERS_ARG,
_TARGET_ARG,
]

Expand Down
10 changes: 10 additions & 0 deletions pyrit/cli/pyrit_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
_parse_initializer_arg,
build_parameters_from_api,
non_negative_int,
parse_dataset_parameter,
positive_int,
validate_log_level_argparse,
)
Expand Down Expand Up @@ -267,6 +268,13 @@ def _build_base_parser(*, add_help: bool = True) -> ArgumentParser:
type=positive_int,
help=ARG_HELP["max_dataset_size"],
)
run_group.add_argument(
"--dataset-parameters",
type=parse_dataset_parameter,
nargs="+",
metavar="KEY=VALUE",
help=ARG_HELP["dataset_parameters"],
)

return parser

Expand Down Expand Up @@ -644,6 +652,8 @@ def _build_run_request(*, parsed_args: Namespace, scenario_name: str) -> RunScen
kwargs["dataset_names"] = parsed_args.dataset_names
if parsed_args.max_dataset_size is not None:
kwargs["max_dataset_size"] = parsed_args.max_dataset_size
if parsed_args.dataset_parameters:
kwargs["dataset_parameters"] = dict(parsed_args.dataset_parameters)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rich agrees with this comment but it is copilot generated:

--dataset-parameters harm_categories=cyber harm_categories=violencedict(...) keeps only the last, silently dropping cyber. Given harm_categories is AND/substring, a user who repeats the key is almost certainly trying to add constraints and gets none of the earlier ones with no warning. Prefer failing loud: detect duplicate keys when folding the token list into a dict and raise ValueError("Duplicate dataset parameter 'harm_categories'; combine values with commas: harm_categories=cyber,violence"). Cheap to do at the same spot the tuples are collapsed (same applies to the shell path in pyrit_shell.py).

if parsed_args.memory_labels:
kwargs["labels"] = parse_memory_labels(json_string=parsed_args.memory_labels)

Expand Down
2 changes: 2 additions & 0 deletions pyrit/cli/pyrit_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,8 @@ def do_run(self, line: str) -> None:
request_kwargs["dataset_names"] = args["dataset_names"]
if args.get("max_dataset_size") is not None:
request_kwargs["max_dataset_size"] = args["max_dataset_size"]
if args.get("dataset_parameters"):
request_kwargs["dataset_parameters"] = dict(args["dataset_parameters"])
if args.get("memory_labels"):
request_kwargs["labels"] = args["memory_labels"]

Expand Down
6 changes: 6 additions & 0 deletions pyrit/models/catalog/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ class RunScenarioRequest(BaseModel):
strategies: list[str] | None = Field(None, description="Strategy names to use (uses scenario default if omitted)")
dataset_names: list[str] | None = Field(None, description="Dataset names to use (uses scenario default if omitted)")
max_dataset_size: int | None = Field(None, ge=1, description="Maximum items per dataset")
dataset_parameters: dict[str, Any] | None = Field(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rich agrees with this comment but it is copilot generated:

Three names for one concept makes this hard to follow: it's dataset_parameters at the CLI/request boundary, filters internally, and the comment/test call it a "registry." It isn't a registry — PyRIT registries build and store components (TargetRegistry, AttackTechniqueRegistry); this is a hardcoded frozenset of two strings. Calling it a registry implies an extension point that doesn't exist (you can't register a filter, you edit a literal). Please drop the "registry" framing (the comment text and the TestDatasetFilterRegistry name), and ideally settle on one term — I'd pick "filters" end-to-end, or if dataset_parameters is the intended public name, at least don't call the constant a registry.

None,
description=(
"Dataset seed filters keyed by field, applied before sampling. Accepted keys: harm_categories, data_types."
),
)
max_concurrency: int = Field(10, ge=1, le=100, description="Maximum concurrent operations")
max_retries: int = Field(0, ge=0, le=20, description="Maximum retry attempts on failure")
labels: dict[str, str] | None = Field(None, description="Labels to attach to memory entries")
Expand Down
2 changes: 2 additions & 0 deletions pyrit/scenario/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
DatasetConstraintError,
DatasetSourceKind,
ResolvedDataset,
build_dataset_filters,
require_nonempty,
)
from pyrit.scenario.core.scenario import BaselineAttackPolicy, Scenario
Expand All @@ -34,6 +35,7 @@
"INLINE_DATASET_NAME",
"Parameter",
"ResolvedDataset",
"build_dataset_filters",
"require_nonempty",
"Scenario",
"ScenarioCompositeStrategy",
Expand Down
Loading
Loading