From ca1d60610c6431a7700f80d52a9bed20475c6951 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk <21087696+oleksandr-pavlyk@users.noreply.github.com> Date: Wed, 27 May 2026 12:16:40 -0500 Subject: [PATCH 1/2] Use robust summaries in nvbench_compare classification Teach nvbench_compare to parse GPU timing summaries into structured values and prefer the robust median/IQR summaries when both compared measurements provide them. Fall back to the existing mean/stdev summaries when robust summaries are not available. Classify comparisons with the larger available relative noise estimate instead of the smaller one, keep unavailable noise distinct from encoded infinite noise, and report improvements separately from regressions. Keep the process exit code as success for completed comparisons; regression counts are reported in the summary instead of being used as the process status. Make plotting tolerate unavailable noise by leaving gaps in confidence bands, sort plotted series by the plotted axis, and avoid reusing pyplot state across plot calls. Add focused Python tests for robust-summary preference, unavailable-noise classification, non-finite timing centers, plot-along handling when the selected axis is absent, and the exit-code contract. --- python/scripts/nvbench_compare.py | 478 ++++++++++++++++++++-------- python/test/test_nvbench_compare.py | 302 ++++++++++++++++++ 2 files changed, 643 insertions(+), 137 deletions(-) create mode 100644 python/test/test_nvbench_compare.py diff --git a/python/scripts/nvbench_compare.py b/python/scripts/nvbench_compare.py index c6370332..cacb419f 100644 --- a/python/scripts/nvbench_compare.py +++ b/python/scripts/nvbench_compare.py @@ -1,10 +1,14 @@ #!/usr/bin/env python +# +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import argparse import math import os import sys -from enum import StrEnum +from dataclasses import dataclass +from enum import Enum import jsondiff import tabulate @@ -23,15 +27,47 @@ def version_tuple(v): tabulate_version = version_tuple(tabulate.__version__) -all_ref_devices = [] -all_cmp_devices = [] +all_ref_devices: list[dict] = [] +all_cmp_devices: list[dict] = [] config_count = 0 unknown_count = 0 -failure_count = 0 +improvement_count = 0 +regression_count = 0 pass_count = 0 +GPU_TIME_MIN_TAG = "nv/cold/time/gpu/min" +GPU_TIME_MAX_TAG = "nv/cold/time/gpu/max" +GPU_TIME_MEAN_TAG = "nv/cold/time/gpu/mean" +GPU_TIME_STDEV_TAG = "nv/cold/time/gpu/stdev/absolute" +GPU_TIME_STDEV_RELATIVE_TAG = "nv/cold/time/gpu/stdev/relative" +GPU_TIME_MEDIAN_TAG = "nv/cold/time/gpu/median" +GPU_TIME_IR_TAG = "nv/cold/time/gpu/ir/absolute" +GPU_TIME_IR_RELATIVE_TAG = "nv/cold/time/gpu/ir/relative" -class Emoji(StrEnum): +# These dataclasses are treated as parsed value objects. frozen=True prevents +# accidental field reassignment but does not imply deep immutability. + + +@dataclass(frozen=True) +class GpuTimeSummary: + minimum: float | None + maximum: float | None + mean: float | None + stdev: float | None + stdev_relative: float | None + median: float | None + interquartile_range: float | None + interquartile_range_relative: float | None + + +@dataclass(frozen=True) +class TimeEstimate: + center: float | None + relative_dispersion: float | None + + +# TODO(opavlyk): replace with Emoji(StrEnum) after EOL of Python 3.10 +class Emoji(str, Enum): YELLOW = "\U0001f7e1" BLUE = "\U0001f535" GREEN = "\U0001f7e2" @@ -42,13 +78,153 @@ class Emoji(StrEnum): def colorize(msg: str, fore: Fore, emoji: Emoji, no_color: bool) -> str: if no_color: prefix = "" - if emoji_s := str(emoji): + if emoji_s := emoji.value: prefix = f"{emoji_s} " return f"{prefix}{msg}" else: return f"{fore}{msg}{Fore.RESET}" +def lookup_summary(summaries, tag): + return next((summary for summary in summaries if summary["tag"] == tag), None) + + +def extract_summary_value(summary): + summary_tag = summary.get("tag", "") + for value_data in summary.get("data", []): + if value_data.get("name") != "value": + continue + + value_type = value_data.get("type") + if value_type != "float64": + raise ValueError( + f"summary {summary_tag!r} field 'value' has type " + f"{value_type!r}; expected 'float64'" + ) + if "value" not in value_data: + raise ValueError(f"summary {summary_tag!r} field 'value' is missing value") + return value_data["value"] + + raise ValueError(f"summary {summary_tag!r} is missing field 'value'") + + +def normalize_float_value(value, *, null_value=None): + if value is None: + return null_value + return float(value) + + +def extract_summary_float(summaries, tag, *, null_value=None): + summary = lookup_summary(summaries, tag) + if summary is None: + return None + return normalize_float_value(extract_summary_value(summary), null_value=null_value) + + +def extract_gpu_time_summary(summaries): + return GpuTimeSummary( + minimum=extract_summary_float(summaries, GPU_TIME_MIN_TAG), + maximum=extract_summary_float(summaries, GPU_TIME_MAX_TAG), + mean=extract_summary_float(summaries, GPU_TIME_MEAN_TAG), + stdev=extract_summary_float(summaries, GPU_TIME_STDEV_TAG, null_value=math.inf), + stdev_relative=extract_summary_float( + summaries, GPU_TIME_STDEV_RELATIVE_TAG, null_value=math.inf + ), + median=extract_summary_float(summaries, GPU_TIME_MEDIAN_TAG), + interquartile_range=extract_summary_float( + summaries, GPU_TIME_IR_TAG, null_value=math.inf + ), + interquartile_range_relative=extract_summary_float( + summaries, GPU_TIME_IR_RELATIVE_TAG, null_value=math.inf + ), + ) + + +def compute_relative_dispersion(dispersion, center): + if ( + dispersion is None + or center is None + or center <= 0 + or not math.isfinite(center) + or dispersion < 0 + or math.isnan(dispersion) + ): + return None + return dispersion / center + + +def has_robust_estimate(summary): + return summary.median is not None and ( + summary.interquartile_range_relative is not None + or summary.interquartile_range is not None + ) + + +def has_mean_estimate(summary): + return summary.mean is not None and ( + summary.stdev_relative is not None or summary.stdev is not None + ) + + +def select_relative_dispersion(relative_dispersion, absolute_dispersion, center): + if relative_dispersion is not None: + return relative_dispersion + return compute_relative_dispersion(absolute_dispersion, center) + + +def compute_common_time_estimates(ref_summary, cmp_summary): + if has_robust_estimate(ref_summary) and has_robust_estimate(cmp_summary): + return ( + TimeEstimate( + center=ref_summary.median, + relative_dispersion=select_relative_dispersion( + ref_summary.interquartile_range_relative, + ref_summary.interquartile_range, + ref_summary.median, + ), + ), + TimeEstimate( + center=cmp_summary.median, + relative_dispersion=select_relative_dispersion( + cmp_summary.interquartile_range_relative, + cmp_summary.interquartile_range, + cmp_summary.median, + ), + ), + ) + + if has_mean_estimate(ref_summary) and has_mean_estimate(cmp_summary): + return ( + TimeEstimate( + center=ref_summary.mean, + relative_dispersion=select_relative_dispersion( + ref_summary.stdev_relative, ref_summary.stdev, ref_summary.mean + ), + ), + TimeEstimate( + center=cmp_summary.mean, + relative_dispersion=select_relative_dispersion( + cmp_summary.stdev_relative, cmp_summary.stdev, cmp_summary.mean + ), + ), + ) + + return ( + TimeEstimate( + center=ref_summary.mean, + relative_dispersion=compute_relative_dispersion( + ref_summary.stdev, ref_summary.mean + ), + ), + TimeEstimate( + center=cmp_summary.mean, + relative_dispersion=compute_relative_dispersion( + cmp_summary.stdev, cmp_summary.mean + ), + ), + ) + + def find_matching_bench(needle, haystack): for hay in haystack: if hay["name"] == needle["name"]: @@ -69,8 +245,8 @@ def format_int64_axis_value(axis_name, axis_value, axes): value = int(axis_value["value"]) if axis_flags == "pow2": value = math.log2(value) - return "2^%d" % value - return "%d" % value + return f"2^{value:.0f}" + return f"{value:d}" def format_float64_axis_value(axis_name, axis_value, axes): @@ -78,11 +254,11 @@ def format_float64_axis_value(axis_name, axis_value, axes): def format_type_axis_value(axis_name, axis_value, axes): - return "%s" % axis_value["value"] + return f"{axis_value['value']}" def format_string_axis_value(axis_name, axis_value, axes): - return "%s" % axis_value["value"] + return f"{axis_value['value']}" def format_axis_value(axis_name, axis_value, axes): @@ -98,10 +274,10 @@ def format_axis_value(axis_name, axis_value, axes): return format_string_axis_value(axis_name, axis_value, axes) -def make_display(name: str, display_values: [list[str]]) -> str: +def make_display(name: str, display_values: list[str]) -> str: open_bracket, close_bracket = ("[", "]") if len(display_values) > 1 else ("", "") - display_values = ",".join(display_values) - return f"{name}={open_bracket}{display_values}{close_bracket}" + joined_values = ",".join(display_values) + return f"{name}={open_bracket}{joined_values}{close_bracket}" def parse_axis_filters(axis_args): @@ -188,16 +364,21 @@ def format_duration(seconds): else: multiplier = 1e6 units = "us" - return "%0.3f %s" % (seconds * multiplier, units) + return f"{seconds * multiplier:0.3f} {units}" def format_percentage(percentage): - # When there aren't enough samples for a meaningful noise measurement, - # the noise is recorded as infinity. Unfortunately, JSON spec doesn't - # allow for inf, so these get turned into null. if percentage is None: + return "n/a" + if math.isnan(percentage): + return "n/a" + if math.isinf(percentage): return "inf" - return "%0.2f%%" % (percentage * 100.0) + return f"{percentage * 100.0:0.2f}%" + + +def has_finite_noise(noise): + return noise is not None and math.isfinite(noise) def format_axis_values(axis_values, axes, axis_filters=None): @@ -373,108 +554,80 @@ def compare_benches( if not ref_summaries or not cmp_summaries: continue - def lookup_summary(summaries, tag): - return next(filter(lambda s: s["tag"] == tag, summaries), None) - - cmp_time_summary = lookup_summary( - cmp_summaries, "nv/cold/time/gpu/mean" - ) - ref_time_summary = lookup_summary( - ref_summaries, "nv/cold/time/gpu/mean" - ) - cmp_noise_summary = lookup_summary( - cmp_summaries, "nv/cold/time/gpu/stdev/relative" - ) - ref_noise_summary = lookup_summary( - ref_summaries, "nv/cold/time/gpu/stdev/relative" - ) - # TODO: Use other timings, too. Maybe multiple rows, with a # "Timing" column + values "CPU/GPU/Batch"? - if not all( - [ - cmp_time_summary, - ref_time_summary, - cmp_noise_summary, - ref_noise_summary, - ] - ): + cmp_gpu_time = extract_gpu_time_summary(cmp_summaries) + ref_gpu_time = extract_gpu_time_summary(ref_summaries) + ref_estimate, cmp_estimate = compute_common_time_estimates( + ref_gpu_time, cmp_gpu_time + ) + + cmp_time = cmp_estimate.center + ref_time = ref_estimate.center + + if cmp_time is None or ref_time is None: continue - def extract_value(summary): - summary_data = summary["data"] - value_data = next( - filter(lambda v: v["name"] == "value", summary_data) - ) - assert value_data["type"] == "float64" - return value_data["value"] + if not math.isfinite(cmp_time) or not math.isfinite(ref_time): + continue - cmp_time = extract_value(cmp_time_summary) - ref_time = extract_value(ref_time_summary) - cmp_noise = extract_value(cmp_noise_summary) - ref_noise = extract_value(ref_noise_summary) + if cmp_time <= 0.0 or ref_time <= 0.0: + continue - # Convert string encoding to expected numerics: - cmp_time = float(cmp_time) - ref_time = float(ref_time) + cmp_noise = cmp_estimate.relative_dispersion + ref_noise = ref_estimate.relative_dispersion diff = cmp_time - ref_time frac_diff = diff / ref_time - if ref_noise and cmp_noise: - ref_noise = float(ref_noise) - cmp_noise = float(cmp_noise) - min_noise = min(ref_noise, cmp_noise) - elif ref_noise: - ref_noise = float(ref_noise) - min_noise = ref_noise - elif cmp_noise: - cmp_noise = float(cmp_noise) - min_noise = cmp_noise + if not has_finite_noise(ref_noise) or not has_finite_noise(cmp_noise): + max_noise = None else: - min_noise = None # Noise is inf + max_noise = max(ref_noise, cmp_noise) if plot_along: axis_name = [] - axis_value = "--" + axis_value = None for av in axis_values: if av["name"] != plot_along: axis_name.append(f"""{av["name"]} = {av["value"]}""") else: axis_value = float(av["value"]) - axis_name = ", ".join(axis_name) + if axis_value is not None: + axis_name = ", ".join(axis_name) - if axis_name not in plot_data["cmp"]: - plot_data["cmp"][axis_name] = {} - plot_data["ref"][axis_name] = {} - plot_data["cmp_noise"][axis_name] = {} - plot_data["ref_noise"][axis_name] = {} + if axis_name not in plot_data["cmp"]: + plot_data["cmp"][axis_name] = {} + plot_data["ref"][axis_name] = {} + plot_data["cmp_noise"][axis_name] = {} + plot_data["ref_noise"][axis_name] = {} - plot_data["cmp"][axis_name][axis_value] = cmp_time - plot_data["ref"][axis_name][axis_value] = ref_time - plot_data["cmp_noise"][axis_name][axis_value] = cmp_noise - plot_data["ref_noise"][axis_name][axis_value] = ref_noise + plot_data["cmp"][axis_name][axis_value] = cmp_time + plot_data["ref"][axis_name][axis_value] = ref_time + plot_data["cmp_noise"][axis_name][axis_value] = cmp_noise + plot_data["ref_noise"][axis_name][axis_value] = ref_noise global config_count global unknown_count global pass_count - global failure_count + global improvement_count + global regression_count config_count += 1 - if not min_noise: + if max_noise is None: unknown_count += 1 status_label = "????" status = colorize(status_label, Fore.YELLOW, Emoji.YELLOW, no_color) - elif abs(frac_diff) <= min_noise: + elif abs(frac_diff) <= max_noise: pass_count += 1 status_label = "SAME" status = colorize(status_label, Fore.BLUE, Emoji.BLUE, no_color) elif diff < 0: - failure_count += 1 + improvement_count += 1 status_label = "FAST" status = colorize(status_label, Fore.GREEN, Emoji.GREEN, no_color) else: - failure_count += 1 + regression_count += 1 status_label = "SLOW" status = colorize(status_label, Fore.RED, Emoji.RED, no_color) @@ -510,16 +663,11 @@ def extract_value(summary): ref_device = find_device_by_id(ref_state["device"], all_ref_devices) if cmp_device == ref_device: - print("## [%d] %s\n" % (cmp_device["id"], cmp_device["name"])) + print(f"## [{cmp_device['id']}] {cmp_device['name']}\n") else: print( - "## [%d] %s vs. [%d] %s\n" - % ( - ref_device["id"], - ref_device["name"], - cmp_device["id"], - cmp_device["name"], - ) + f"## [{ref_device['id']}] {ref_device['name']} vs. " + f"[{cmp_device['id']}] {cmp_device['name']}\n" ) # colalign and github format require tabulate 0.8.3 if tabulate_version >= (0, 8, 3): @@ -534,30 +682,75 @@ def extract_value(summary): print("") if plot_along: - plt.xscale("log") - plt.yscale("log") - plt.xlabel(plot_along) - plt.ylabel("time [s]") - plt.title(cmp_device["name"]) - - def plot_line(key, shape, label): - x = [float(x) for x in plot_data[key][axis].keys()] - y = list(plot_data[key][axis].values()) - - noise = list(plot_data[key + "_noise"][axis].values()) - - top = [y[i] + y[i] * noise[i] for i in range(len(x))] - bottom = [y[i] - y[i] * noise[i] for i in range(len(x))] - - p = plt.plot(x, y, shape, marker="o", label=label) - plt.fill_between(x, bottom, top, color=p[0].get_color(), alpha=0.1) - - for axis in plot_data["cmp"].keys(): - plot_line("cmp", "-", axis) - plot_line("ref", "--", axis + " ref") - - plt.legend() - plt.show() + fig = plt.figure() + try: + plt.xscale("log") + plt.yscale("log") + plt.xlabel(plot_along) + plt.ylabel("time [s]") + plt.title(cmp_device["name"]) + + def plot_line(key, shape, label, data_axis, data=plot_data): + axis_times = data[key][data_axis] + if not axis_times: + return + axis_noise = data[key + "_noise"][data_axis] + series = sorted( + ( + ( + float(axis_value), + axis_times[axis_value], + axis_noise[axis_value], + ) + for axis_value in axis_times + ), + key=lambda item: item[0], + ) + x, y, noise = map(list, zip(*series, strict=True)) + + p = plt.plot(x, y, shape, marker="o", label=label) + + def plot_confidence_band(first, last): + if last - first < 2: + return + + band_x = x[first:last] + band_y = y[first:last] + band_noise = noise[first:last] + top = [ + band_y[i] + band_y[i] * band_noise[i] + for i in range(len(band_x)) + ] + bottom = [ + max( + band_y[i] - band_y[i] * band_noise[i], + band_y[i] * 0.001, + ) + for i in range(len(band_x)) + ] + plt.fill_between( + band_x, bottom, top, color=p[0].get_color(), alpha=0.1 + ) + + start = None + for i, noise_value in enumerate(noise): + if has_finite_noise(noise_value) and start is None: + start = i + if not has_finite_noise(noise_value) and start is not None: + plot_confidence_band(start, i) + start = None + + if start is not None: + plot_confidence_band(start, len(x)) + + for axis in plot_data["cmp"].keys(): + plot_line("cmp", "-", axis, axis) + plot_line("ref", "--", axis + " ref", axis) + + plt.legend() + plt.show() + finally: + plt.close(fig) if plot: title = "%SOL Bandwidth change" @@ -574,7 +767,14 @@ def plot_line(key, shape, label): plot_comparison_entries(comparison_entries, title=title, dark=dark) -def main(): +def main() -> int: + """ + Returns a process exit code. + - 0 means the comparison completed successfully. + - 1 signals an error has occurred. + + The number of detected regressions is reported in the summary output. + """ help_text = "%(prog)s [reference.json compare.json | reference_dir/ compare_dir/]" parser = argparse.ArgumentParser(prog="nvbench_compare", usage=help_text) parser.add_argument( @@ -628,16 +828,15 @@ def main(): ) args, files_or_dirs = parser.parse_known_args() - print(files_or_dirs) try: axis_filters = parse_axis_filters(args.axis) except ValueError as exc: print(str(exc)) - sys.exit(1) + return 1 if len(files_or_dirs) != 2: parser.print_help() - sys.exit(1) + return 1 # if provided two directories, find all the exactly named files # in both and treat them as the reference and compare @@ -679,26 +878,31 @@ def main(): ) ) if not args.ignore_devices: - sys.exit(1) - - compare_benches( - ref_root["benchmarks"], - cmp_root["benchmarks"], - args.threshold, - args.plot_along, - args.plot, - args.dark, - axis_filters, - args.benchmark, - args.no_color, - ) + return 1 + + try: + compare_benches( + ref_root["benchmarks"], + cmp_root["benchmarks"], + args.threshold, + args.plot_along, + args.plot, + args.dark, + axis_filters, + args.benchmark, + args.no_color, + ) + except ValueError as exc: + print(str(exc)) + return 1 print("# Summary\n") - print("- Total Matches: %d" % config_count) - print(" - Pass (diff <= min_noise): %d" % pass_count) - print(" - Unknown (infinite noise): %d" % unknown_count) - print(" - Failure (diff > min_noise): %d" % failure_count) - return failure_count + print(f"- Total Matches: {config_count}") + print(f" - Pass (abs(%Diff) <= max_noise): {pass_count}") + print(f" - Improvement (abs(%Diff) > max_noise, %Diff < 0): {improvement_count}") + print(f" - Regression (abs(%Diff) > max_noise, %Diff > 0): {regression_count}") + print(f" - Unknown (infinite or unavailable noise): {unknown_count}") + return 0 if __name__ == "__main__": diff --git a/python/test/test_nvbench_compare.py b/python/test/test_nvbench_compare.py new file mode 100644 index 00000000..8d82accd --- /dev/null +++ b/python/test/test_nvbench_compare.py @@ -0,0 +1,302 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import importlib.util +import sys +import types +from pathlib import Path + +import pytest + + +@pytest.fixture +def nvbench_compare(monkeypatch): + class DummyLine: + def get_color(self): + return "black" + + pyplot = types.ModuleType("matplotlib.pyplot") + pyplot.figure = lambda *args, **kwargs: None + pyplot.xscale = lambda *args, **kwargs: None + pyplot.yscale = lambda *args, **kwargs: None + pyplot.xlabel = lambda *args, **kwargs: None + pyplot.ylabel = lambda *args, **kwargs: None + pyplot.title = lambda *args, **kwargs: None + pyplot.plot = lambda *args, **kwargs: [DummyLine()] + pyplot.fill_between = lambda *args, **kwargs: None + pyplot.legend = lambda *args, **kwargs: None + pyplot.show = lambda *args, **kwargs: None + pyplot.close = lambda *args, **kwargs: None + + matplotlib = types.ModuleType("matplotlib") + matplotlib.pyplot = pyplot + monkeypatch.setitem(sys.modules, "matplotlib", matplotlib) + monkeypatch.setitem(sys.modules, "matplotlib.pyplot", pyplot) + monkeypatch.setitem( + sys.modules, + "seaborn", + types.SimpleNamespace(set_theme=lambda *args, **kwargs: None), + ) + monkeypatch.setitem( + sys.modules, "jsondiff", types.SimpleNamespace(diff=lambda *args, **kwargs: {}) + ) + monkeypatch.setitem( + sys.modules, + "tabulate", + types.SimpleNamespace( + __version__="0.8.10", tabulate=lambda *args, **kwargs: "" + ), + ) + monkeypatch.setitem( + sys.modules, + "colorama", + types.SimpleNamespace( + Fore=types.SimpleNamespace( + BLUE="", + GREEN="", + RED="", + RESET="", + YELLOW="", + ) + ), + ) + + module_path = Path(__file__).resolve().parents[1] / "scripts" / "nvbench_compare.py" + spec = importlib.util.spec_from_file_location("nvbench_compare", module_path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def make_state( + nvbench_compare, name, *, mean="1.0", noise="0.01", axis_value=None, device=0 +): + return { + "name": name, + "device": device, + "axis_values": [] + if axis_value is None + else [{"name": "A", "type": "int64", "value": axis_value}], + "summaries": [ + { + "tag": nvbench_compare.GPU_TIME_MEAN_TAG, + "data": [{"name": "value", "type": "float64", "value": mean}], + }, + { + "tag": nvbench_compare.GPU_TIME_STDEV_RELATIVE_TAG, + "data": [{"name": "value", "type": "float64", "value": noise}], + }, + ], + } + + +def make_summary(nvbench_compare, tag, value): + return { + "tag": getattr(nvbench_compare, tag), + "data": [{"name": "value", "type": "float64", "value": value}], + } + + +def make_benchmark(states, *, name="bench"): + devices = [] + for state in states: + if state["device"] not in devices: + devices.append(state["device"]) + + return { + "name": name, + "devices": devices, + "axes": [{"name": "A", "type": "int64", "flags": ""}] + if any(state["axis_values"] for state in states) + else [], + "states": states, + } + + +def set_test_devices(monkeypatch, nvbench_compare): + devices = [{"id": 0, "name": "Test GPU"}] + monkeypatch.setattr(nvbench_compare, "all_ref_devices", devices) + monkeypatch.setattr(nvbench_compare, "all_cmp_devices", devices) + monkeypatch.setattr(nvbench_compare, "config_count", 0) + monkeypatch.setattr(nvbench_compare, "pass_count", 0) + monkeypatch.setattr(nvbench_compare, "improvement_count", 0) + monkeypatch.setattr(nvbench_compare, "regression_count", 0) + monkeypatch.setattr(nvbench_compare, "unknown_count", 0) + + +def compare_benches(nvbench_compare, ref_benches, cmp_benches, **kwargs): + nvbench_compare.compare_benches( + ref_benches, + cmp_benches, + threshold=kwargs.get("threshold", 0.0), + plot_along=kwargs.get("plot_along"), + plot=kwargs.get("plot", False), + dark=False, + axis_filters=kwargs.get("axis_filters", []), + benchmark_filters=kwargs.get("benchmark_filters", []), + no_color=True, + ) + + +def test_compare_benches_skips_non_finite_centers(monkeypatch, nvbench_compare): + set_test_devices(monkeypatch, nvbench_compare) + + ref_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "finite", mean="1.0"), + make_state(nvbench_compare, "nan", mean="nan"), + make_state(nvbench_compare, "inf", mean="inf"), + ] + ) + ] + cmp_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "finite", mean="1.0"), + make_state(nvbench_compare, "nan", mean="1.0"), + make_state(nvbench_compare, "inf", mean="1.0"), + ] + ) + ] + + compare_benches(nvbench_compare, ref_benches, cmp_benches) + + assert nvbench_compare.config_count == 1 + assert nvbench_compare.pass_count == 1 + assert nvbench_compare.improvement_count == 0 + assert nvbench_compare.regression_count == 0 + assert nvbench_compare.unknown_count == 0 + + +def test_compare_benches_prefers_median_and_iqr_when_available( + monkeypatch, nvbench_compare +): + set_test_devices(monkeypatch, nvbench_compare) + + ref_state = make_state(nvbench_compare, "state", mean="1.0", noise="0.01") + ref_state["summaries"].extend( + [ + make_summary(nvbench_compare, "GPU_TIME_MEDIAN_TAG", "1.0"), + make_summary(nvbench_compare, "GPU_TIME_IR_RELATIVE_TAG", "0.01"), + ] + ) + cmp_state = make_state(nvbench_compare, "state", mean="1.0", noise="0.01") + cmp_state["summaries"].extend( + [ + make_summary(nvbench_compare, "GPU_TIME_MEDIAN_TAG", "1.2"), + make_summary(nvbench_compare, "GPU_TIME_IR_RELATIVE_TAG", "0.01"), + ] + ) + + compare_benches( + nvbench_compare, [make_benchmark([ref_state])], [make_benchmark([cmp_state])] + ) + + assert nvbench_compare.config_count == 1 + assert nvbench_compare.pass_count == 0 + assert nvbench_compare.improvement_count == 0 + assert nvbench_compare.regression_count == 1 + assert nvbench_compare.unknown_count == 0 + + +def test_compare_benches_marks_unavailable_noise_unknown(monkeypatch, nvbench_compare): + set_test_devices(monkeypatch, nvbench_compare) + + missing_noise_ref = make_state(nvbench_compare, "missing_noise") + missing_noise_ref["summaries"] = [ + make_summary(nvbench_compare, "GPU_TIME_MEAN_TAG", "1.0") + ] + missing_noise_cmp = make_state(nvbench_compare, "missing_noise") + missing_noise_cmp["summaries"] = [ + make_summary(nvbench_compare, "GPU_TIME_MEAN_TAG", "1.001") + ] + + null_noise_ref = make_state(nvbench_compare, "null_noise") + null_noise_ref["summaries"] = [ + make_summary(nvbench_compare, "GPU_TIME_MEAN_TAG", "1.0"), + make_summary(nvbench_compare, "GPU_TIME_STDEV_RELATIVE_TAG", None), + ] + null_noise_cmp = make_state(nvbench_compare, "null_noise") + null_noise_cmp["summaries"] = [ + make_summary(nvbench_compare, "GPU_TIME_MEAN_TAG", "1.001"), + make_summary(nvbench_compare, "GPU_TIME_STDEV_RELATIVE_TAG", None), + ] + + compare_benches( + nvbench_compare, + [make_benchmark([missing_noise_ref, null_noise_ref])], + [make_benchmark([missing_noise_cmp, null_noise_cmp])], + ) + + assert nvbench_compare.config_count == 2 + assert nvbench_compare.pass_count == 0 + assert nvbench_compare.improvement_count == 0 + assert nvbench_compare.regression_count == 0 + assert nvbench_compare.unknown_count == 2 + + +def test_plot_along_skips_states_without_selected_axis(monkeypatch, nvbench_compare): + set_test_devices(monkeypatch, nvbench_compare) + + ref_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "with_axis", axis_value=1), + make_state(nvbench_compare, "without_axis"), + ] + ) + ] + cmp_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "with_axis", axis_value=1), + make_state(nvbench_compare, "without_axis"), + ] + ) + ] + + compare_benches( + nvbench_compare, + ref_benches, + cmp_benches, + plot_along="A", + ) + + assert nvbench_compare.config_count == 2 + assert nvbench_compare.pass_count == 2 + assert nvbench_compare.improvement_count == 0 + assert nvbench_compare.regression_count == 0 + assert nvbench_compare.unknown_count == 0 + + +def test_main_returns_success_exit_code_when_regressions_are_detected( + monkeypatch, capsys, nvbench_compare +): + devices = [{"id": 0, "name": "Test GPU"}] + ref_root = { + "devices": devices, + "benchmarks": [ + make_benchmark([make_state(nvbench_compare, "state", mean="1.0")]) + ], + } + cmp_root = { + "devices": devices, + "benchmarks": [ + make_benchmark([make_state(nvbench_compare, "state", mean="1.2")]) + ], + } + + def read_file(path): + return ref_root if path == "ref.json" else cmp_root + + monkeypatch.setattr(nvbench_compare.reader, "read_file", read_file) + monkeypatch.setattr(sys, "argv", ["nvbench_compare", "ref.json", "cmp.json"]) + + assert nvbench_compare.main() == 0 + assert nvbench_compare.regression_count == 1 + assert ( + "Regression (abs(%Diff) > max_noise, %Diff > 0): 1" in capsys.readouterr().out + ) From 1d13b49996cbb1f1e1b643baf9e630e102c790f3 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk <21087696+oleksandr-pavlyk@users.noreply.github.com> Date: Wed, 27 May 2026 12:21:40 -0500 Subject: [PATCH 2/2] Add scoped filtering and device pairing to nvbench_compare Teach nvbench_compare to keep the order of --benchmark and --axis arguments so axis filters can apply either globally or to the most recent benchmark. Build a filter plan from the ordered CLI arguments and apply the same plan to table output and plotting labels. Add explicit --reference-devices and --compare-devices filters. The filters accept all, a single device id, or a comma-separated list of ids; ordered lists and duplicates are preserved so selected reference and compare devices can be paired by position. Device-section mismatches remain fatal for unfiltered all-vs-all comparisons, but become warnings when the user explicitly selects devices and the selected device counts match. Match duplicate benchmark states by occurrence within each filtered device section instead of matching only by state name across the whole benchmark. This keeps repeated axis values and filtered duplicate states aligned between the reference and compare inputs, and reports mismatched occurrence counts instead of silently dropping extra states. Add Python tests for duplicate-state matching, axis filtering before matching, device filter parsing and validation, explicit cross-device pairing, and benchmark-scoped axis filters. Original commit messages folded into this change: Tweaks for nvbench_compare 1. When JSON files contain multiple entries with the same name and axis values, make sure that scripts compares corresponding entries. Previous logic would extract the first entry from ref data, and would compare measurements for each state in cmp against the first entry from ref. The change introduces a counter to know which nth entry we process for a particular axis value, and retrieve corresponding entry in ref. Scope occurrence matching by device. Device pairing in nvbench_compare.py is strictly index-based under --ignore-devices, reused IDs in a different order no longer pair against the wrong reference device. Require devices in ref and cmp to have the same cardinality Handle mismatch when number of duplicates in ref data is not same as in cmp data Use pytest monkeypatch fixture to pretend third-party package dependencies are available during test run for nvbench_compare without introducing test-time dependency Added the happy-path test and fixed its direct-call setup by initializing the device globals that main() normally populates. Fix to filter-before-matching. - compare_benches() now pairs devices by selected position instead of taking a device id. - For each device pair, compare_benches() now builds: - ref_device_states: matching reference device and axis filters - cmp_device_states: matching compare device and axis filters - State occurrence counts and duplicate occurrence matching now operate only on those filtered per-device lists. - Removed the later matches_axis_filters() skip inside the compare-state loop because filtering now happens before matching. Added a regression test where ref/cmp have duplicate state names in opposite order, and --axis keeps only one of them. The test verifies the kept compare state is matched against the kept reference state, not the first unfiltered occurrence. Introduce device filtering in nvbench_compare - --reference-devices all|ID|ID,ID,... - --compare-devices all|ID|ID,ID,... - Integer lists preserve order and duplicates. - Requested IDs are validated against the file-level device list. - Filtered reference/compare device counts must match before comparison. - compare_benches() pairs selected reference and compare devices by position. - Each benchmark validates that requested device IDs are present in its own devices list. Implemented benchmark-scoped --axis handling. - --axis and --benchmark now share an ordered argparse action, so their relative CLI order is preserved. - -a before any -b becomes a global axis filter. - -a after -b applies to that most recent benchmark only. - Repeated -b entries are treated as separate filter scopes and combined as alternatives for that benchmark. - Device filtering remains global and is applied independently. Allow non-matching devices for explicit device selection Now the device-section equality check remains fatal only for unfiltered all-vs-all comparisons. If either --reference-devices or --compare-devices is explicit, mismatched selected device metadata is printed as a warning, but comparison proceeds after the selected device counts have been validated. Fix for resolve_benchmark_device_ids, add comments The return value of resolve_benchmark_device_ids now always owns its list. Use monkeypatch class in set_test_devices helper Stricted device id validation Test for device id validation --- python/scripts/nvbench_compare.py | 332 +++++++++++++++++++++++++--- python/test/test_nvbench_compare.py | 323 +++++++++++++++++++++++++-- 2 files changed, 603 insertions(+), 52 deletions(-) diff --git a/python/scripts/nvbench_compare.py b/python/scripts/nvbench_compare.py index cacb419f..99d64854 100644 --- a/python/scripts/nvbench_compare.py +++ b/python/scripts/nvbench_compare.py @@ -7,6 +7,7 @@ import math import os import sys +from collections import Counter from dataclasses import dataclass from enum import Enum @@ -66,6 +67,121 @@ class TimeEstimate: relative_dispersion: float | None +@dataclass(frozen=True) +class BenchmarkFilterScope: + benchmark_name: str + axis_filters: list[dict] + + +@dataclass(frozen=True) +class BenchmarkFilterPlan: + global_axis_filters: list[dict] + benchmark_scopes: list[BenchmarkFilterScope] + + +class OrderedBenchmarkFilterAction(argparse.Action): + def __call__(self, parser, namespace, values, option_string=None): + actions = getattr(namespace, self.dest, None) + actions = [] if actions is None else list(actions) + action_kind = "axis" if option_string in {"-a", "--axis"} else "benchmark" + actions.append((action_kind, values)) + setattr(namespace, self.dest, actions) + + +def state_match_key(state): + device_prefix = f"Device={state['device']}" + state_name = state["name"] + if state_name == device_prefix: + return "" + if state_name.startswith(f"{device_prefix} "): + return state_name[len(device_prefix) + 1 :] + return state_name + + +def group_states_by_match_key(states): + grouped = {} + for state in states: + grouped.setdefault(state_match_key(state), []).append(state) + return grouped + + +def state_group_counts(grouped_states): + return Counter( + {state_name: len(states) for state_name, states in grouped_states.items()} + ) + + +def format_device_ids(device_ids): + return ", ".join(str(device_id) for device_id in device_ids) + + +def parse_device_filter(device_arg, option_name): + device_arg = device_arg.strip() + if device_arg.lower() == "all": + return None + + values = [value.strip() for value in device_arg.split(",")] + if not all(values): + raise ValueError( + f"{option_name} must be 'all', a non-negative integer, " + "or comma-separated non-negative integers" + ) + + try: + device_ids = [int(value) for value in values] + except ValueError as exc: + raise ValueError( + f"{option_name} must be 'all', a non-negative integer, " + "or comma-separated non-negative integers" + ) from exc + if any(device_id < 0 for device_id in device_ids): + raise ValueError( + f"{option_name} must be 'all', a non-negative integer, " + "or comma-separated non-negative integers" + ) + return device_ids + + +def select_devices(all_devices, device_filter, option_name): + if device_filter is None: + return list(all_devices) + + devices_by_id = {device["id"]: device for device in all_devices} + missing_ids = [ + device_id for device_id in device_filter if device_id not in devices_by_id + ] + if missing_ids: + raise ValueError( + f"{option_name} requested device id(s) not present in input: " + f"{format_device_ids(missing_ids)}" + ) + + return [devices_by_id[device_id] for device_id in device_filter] + + +def resolve_benchmark_device_ids(bench, device_filter, option_name): + if device_filter is None: + return list(bench["devices"]) + + benchmark_device_ids = set(bench["devices"]) + missing_ids = [ + device_id + for device_id in device_filter + if device_id not in benchmark_device_ids + ] + if missing_ids: + raise ValueError( + f"benchmark {bench['name']!r} does not contain {option_name} " + f"device id(s): {format_device_ids(missing_ids)}" + ) + + return device_filter + + +def require_matching_device_sections(reference_device_filter, compare_device_filter): + return reference_device_filter is None and compare_device_filter is None + + # TODO(opavlyk): replace with Emoji(StrEnum) after EOL of Python 3.10 class Emoji(str, Enum): YELLOW = "\U0001f7e1" @@ -328,6 +444,53 @@ def parse_axis_filters(axis_args): return filters +def build_benchmark_filter_plan(filter_actions): + global_axis_args = [] + benchmark_scopes = [] + current_scope = None + + for action_kind, action_value in filter_actions or []: + if action_kind == "benchmark": + current_scope = {"benchmark_name": action_value, "axis_args": []} + benchmark_scopes.append(current_scope) + elif current_scope is None: + global_axis_args.append(action_value) + else: + current_scope["axis_args"].append(action_value) + + return BenchmarkFilterPlan( + global_axis_filters=parse_axis_filters(global_axis_args), + benchmark_scopes=[ + BenchmarkFilterScope( + benchmark_name=scope["benchmark_name"], + axis_filters=parse_axis_filters(scope["axis_args"]), + ) + for scope in benchmark_scopes + ], + ) + + +def benchmark_is_selected(benchmark_name, filter_plan): + return not filter_plan.benchmark_scopes or any( + scope.benchmark_name == benchmark_name for scope in filter_plan.benchmark_scopes + ) + + +def axis_filter_groups_for_benchmark(benchmark_name, filter_plan): + if not filter_plan.benchmark_scopes: + return [filter_plan.global_axis_filters] + + matching_scopes = [ + scope + for scope in filter_plan.benchmark_scopes + if scope.benchmark_name == benchmark_name + ] + return [ + filter_plan.global_axis_filters + scope.axis_filters + for scope in matching_scopes + ] + + def matches_axis_filters(state, axis_filters): if not axis_filters: return True @@ -351,6 +514,23 @@ def matches_axis_filters(state, axis_filters): return True +def matches_axis_filter_groups(state, axis_filter_groups): + return any( + matches_axis_filters(state, axis_filters) for axis_filters in axis_filter_groups + ) + + +def matching_axis_filters(state, axis_filter_groups): + return next( + ( + axis_filters + for axis_filters in axis_filter_groups + if matches_axis_filters(state, axis_filters) + ), + [], + ) + + def format_duration(seconds): if seconds >= 1: multiplier = 1.0 @@ -479,9 +659,10 @@ def compare_benches( plot_along, plot, dark, - axis_filters, - benchmark_filters, + filter_plan, no_color, + reference_device_filter=None, + compare_device_filter=None, ): if plot_along: import matplotlib.pyplot as plt @@ -495,12 +676,28 @@ def compare_benches( ref_bench = find_matching_bench(cmp_bench, ref_benches) if not ref_bench: continue - if benchmark_filters and cmp_bench["name"] not in benchmark_filters: + if not benchmark_is_selected(cmp_bench["name"], filter_plan): continue + axis_filter_groups = axis_filter_groups_for_benchmark( + cmp_bench["name"], filter_plan + ) + + cmp_device_ids = resolve_benchmark_device_ids( + cmp_bench, compare_device_filter, "--compare-devices" + ) + ref_device_ids = resolve_benchmark_device_ids( + ref_bench, reference_device_filter, "--reference-devices" + ) + if len(cmp_device_ids) != len(ref_device_ids): + raise ValueError( + f"benchmark {cmp_bench['name']!r} has {len(ref_device_ids)} " + f"reference device(s) but {len(cmp_device_ids)} compare device(s); " + "nvbench_compare pairs devices by position, so each compared " + "benchmark must contain the same number of devices" + ) print(f"""# {cmp_bench["name"]}\n""") - cmp_device_ids = cmp_bench["devices"] axes = cmp_bench["axes"] ref_states = ref_bench["states"] cmp_states = cmp_bench["states"] @@ -525,20 +722,43 @@ def compare_benches( headers.append("Status") colalign.append("center") - for cmp_device_id in cmp_device_ids: - rows = [] - plot_data = {"cmp": {}, "ref": {}, "cmp_noise": {}, "ref_noise": {}} - - for cmp_state in cmp_states: - cmp_state_name = cmp_state["name"] - ref_state = next( - filter(lambda st: st["name"] == cmp_state_name, ref_states), None + for cmp_device_index, cmp_device_id in enumerate(cmp_device_ids): + ref_device_id = ref_device_ids[cmp_device_index] + ref_device_states = [ + state + for state in ref_states + if state["device"] == ref_device_id + and matches_axis_filter_groups(state, axis_filter_groups) + ] + cmp_device_states = [ + state + for state in cmp_states + if state["device"] == cmp_device_id + and matches_axis_filter_groups(state, axis_filter_groups) + ] + ref_states_by_name = group_states_by_match_key(ref_device_states) + cmp_states_by_name = group_states_by_match_key(cmp_device_states) + ref_state_counts = state_group_counts(ref_states_by_name) + cmp_state_counts = state_group_counts(cmp_states_by_name) + if ref_state_counts != cmp_state_counts: + raise ValueError( + f"benchmark {cmp_bench['name']!r} device pair " + f"ref={ref_device_id} cmp={cmp_device_id} has mismatched " + f"state occurrences: ref={dict(ref_state_counts)}, " + f"cmp={dict(cmp_state_counts)}" ) - if not ref_state: - continue - if not matches_axis_filters(cmp_state, axis_filters): - continue + rows = [] + plot_data = {"cmp": {}, "ref": {}, "cmp_noise": {}, "ref_noise": {}} + counters = {} + + for cmp_state in cmp_device_states: + cmp_state_name = state_match_key(cmp_state) + occurrence = counters.get(cmp_state_name, 0) + counters[cmp_state_name] = occurrence + 1 + # Duplicate state names are matched by occurrence order within + # the filtered device section. + ref_state = ref_states_by_name[cmp_state_name][occurrence] axis_values = cmp_state["axis_values"] if not axis_values: axis_values = [] @@ -632,6 +852,7 @@ def compare_benches( status = colorize(status_label, Fore.RED, Emoji.RED, no_color) if abs(frac_diff) >= threshold: + axis_filters = matching_axis_filters(cmp_state, axis_filter_groups) row.append(format_duration(ref_time)) row.append(format_percentage(ref_noise)) row.append(format_duration(cmp_time)) @@ -660,7 +881,12 @@ def compare_benches( continue cmp_device = find_device_by_id(cmp_device_id, all_cmp_devices) - ref_device = find_device_by_id(ref_state["device"], all_ref_devices) + ref_device = find_device_by_id(ref_device_id, all_ref_devices) + if ref_device is None or cmp_device is None: + raise ValueError( + f"benchmark {cmp_bench['name']!r} references device pair " + f"ref={ref_device_id} cmp={cmp_device_id}, but device metadata is missing" + ) if cmp_device == ref_device: print(f"## [{cmp_device['id']}] {cmp_device['name']}\n") @@ -756,10 +982,10 @@ def plot_confidence_band(first, last): title = "%SOL Bandwidth change" if len(comparison_device_names) == 1: title = f"{title} - {next(iter(comparison_device_names))}" - if axis_filters: + if filter_plan.global_axis_filters: axis_label = ", ".join( axis_filter["display"] - for axis_filter in axis_filters + for axis_filter in filter_plan.global_axis_filters if len(axis_filter["values"]) == 1 ) if axis_label: @@ -812,24 +1038,44 @@ def main() -> int: action="store_true", help="Use emoji instead of ANSI color codes (useful for GitHub issues/PRs)", ) + parser.add_argument( + "--reference-devices", + default="all", + help="Reference devices to compare: all, a non-negative integer id, or comma-separated ids", + ) + parser.add_argument( + "--compare-devices", + default="all", + help="Compare devices to compare: all, a non-negative integer id, or comma-separated ids", + ) parser.add_argument( "-a", "--axis", - action="append", - default=[], - help="Filter on axis value, e.g. -a Elements{io}=2^20 (can repeat)", + dest="filter_actions", + action=OrderedBenchmarkFilterAction, + help=( + "Filter on axis value, e.g. -a Elements{io}=2^20. Applies to the " + "most recent --benchmark, or all benchmarks if specified before any " + "--benchmark arguments." + ), ) parser.add_argument( "-b", "--benchmark", - action="append", - default=[], + dest="filter_actions", + action=OrderedBenchmarkFilterAction, help="Filter by benchmark name (can repeat)", ) args, files_or_dirs = parser.parse_known_args() try: - axis_filters = parse_axis_filters(args.axis) + filter_plan = build_benchmark_filter_plan(args.filter_actions) + reference_device_filter = parse_device_filter( + args.reference_devices, "--reference-devices" + ) + compare_device_filter = parse_device_filter( + args.compare_devices, "--compare-devices" + ) except ValueError as exc: print(str(exc)) return 1 @@ -863,21 +1109,34 @@ def main() -> int: global all_ref_devices global all_cmp_devices - all_ref_devices = ref_root["devices"] - all_cmp_devices = cmp_root["devices"] + try: + all_ref_devices = select_devices( + ref_root["devices"], reference_device_filter, "--reference-devices" + ) + all_cmp_devices = select_devices( + cmp_root["devices"], compare_device_filter, "--compare-devices" + ) + except ValueError as exc: + print(str(exc)) + return 1 + + if len(all_ref_devices) != len(all_cmp_devices): + print( + f"--reference-devices selected {len(all_ref_devices)} device(s), " + f"but --compare-devices selected {len(all_cmp_devices)} device(s)" + ) + return 1 - if ref_root["devices"] != cmp_root["devices"]: + if all_ref_devices != all_cmp_devices: warn_fore = Fore.YELLOW if args.ignore_devices else Fore.RED msg_text = "Device sections do not match" print(colorize(msg_text, warn_fore, Emoji.NONE, args.no_color), end="") print(": ", end="") - print( - jsondiff.diff( - ref_root["devices"], cmp_root["devices"], syntax="symmetric" - ) - ) - if not args.ignore_devices: + print(jsondiff.diff(all_ref_devices, all_cmp_devices, syntax="symmetric")) + if not args.ignore_devices and require_matching_device_sections( + reference_device_filter, compare_device_filter + ): return 1 try: @@ -888,9 +1147,10 @@ def main() -> int: args.plot_along, args.plot, args.dark, - axis_filters, - args.benchmark, + filter_plan, args.no_color, + reference_device_filter, + compare_device_filter, ) except ValueError as exc: print(str(exc)) diff --git a/python/test/test_nvbench_compare.py b/python/test/test_nvbench_compare.py index 8d82accd..c6d8c147 100644 --- a/python/test/test_nvbench_compare.py +++ b/python/test/test_nvbench_compare.py @@ -115,10 +115,18 @@ def make_benchmark(states, *, name="bench"): } -def set_test_devices(monkeypatch, nvbench_compare): +def set_test_devices(monkeypatch, nvbench_compare, ref_devices=None, cmp_devices=None): devices = [{"id": 0, "name": "Test GPU"}] - monkeypatch.setattr(nvbench_compare, "all_ref_devices", devices) - monkeypatch.setattr(nvbench_compare, "all_cmp_devices", devices) + monkeypatch.setattr( + nvbench_compare, + "all_ref_devices", + devices if ref_devices is None else ref_devices, + ) + monkeypatch.setattr( + nvbench_compare, + "all_cmp_devices", + devices if cmp_devices is None else cmp_devices, + ) monkeypatch.setattr(nvbench_compare, "config_count", 0) monkeypatch.setattr(nvbench_compare, "pass_count", 0) monkeypatch.setattr(nvbench_compare, "improvement_count", 0) @@ -126,19 +134,132 @@ def set_test_devices(monkeypatch, nvbench_compare): monkeypatch.setattr(nvbench_compare, "unknown_count", 0) -def compare_benches(nvbench_compare, ref_benches, cmp_benches, **kwargs): +def make_filter_plan(nvbench_compare, filter_actions=None): + return nvbench_compare.build_benchmark_filter_plan(filter_actions or []) + + +def test_compare_benches_accepts_matching_duplicate_state_counts( + monkeypatch, nvbench_compare +): + set_test_devices(monkeypatch, nvbench_compare) + + ref_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "state1"), + make_state(nvbench_compare, "state1"), + make_state(nvbench_compare, "state2"), + ] + ) + ] + cmp_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "state1", mean="1.005"), + make_state(nvbench_compare, "state1", mean="1.005"), + make_state(nvbench_compare, "state2", mean="1.005"), + ] + ) + ] + + nvbench_compare.compare_benches( + ref_benches, + cmp_benches, + threshold=0.0, + plot_along=None, + plot=False, + dark=False, + filter_plan=make_filter_plan(nvbench_compare), + no_color=True, + ) + + assert nvbench_compare.config_count == 3 + assert nvbench_compare.pass_count == 3 + assert nvbench_compare.improvement_count == 0 + assert nvbench_compare.regression_count == 0 + assert nvbench_compare.unknown_count == 0 + + +def test_compare_benches_rejects_swapped_duplicate_state_counts( + monkeypatch, nvbench_compare +): + set_test_devices(monkeypatch, nvbench_compare) + + ref_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "state1"), + make_state(nvbench_compare, "state1"), + make_state(nvbench_compare, "state1"), + make_state(nvbench_compare, "state2"), + make_state(nvbench_compare, "state2"), + ] + ) + ] + cmp_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "state1"), + make_state(nvbench_compare, "state1"), + make_state(nvbench_compare, "state2"), + make_state(nvbench_compare, "state2"), + make_state(nvbench_compare, "state2"), + ] + ) + ] + + with pytest.raises(ValueError, match="mismatched state occurrences"): + nvbench_compare.compare_benches( + ref_benches, + cmp_benches, + threshold=0.0, + plot_along=None, + plot=False, + dark=False, + filter_plan=make_filter_plan(nvbench_compare), + no_color=True, + ) + + +def test_compare_benches_matches_duplicate_states_after_axis_filter( + monkeypatch, nvbench_compare +): + set_test_devices(monkeypatch, nvbench_compare) + + ref_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "state", mean="1.0", axis_value=1), + make_state(nvbench_compare, "state", mean="2.0", axis_value=2), + ] + ) + ] + cmp_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "state", mean="2.0", axis_value=2), + make_state(nvbench_compare, "state", mean="1.0", axis_value=1), + ] + ) + ] + nvbench_compare.compare_benches( ref_benches, cmp_benches, - threshold=kwargs.get("threshold", 0.0), - plot_along=kwargs.get("plot_along"), - plot=kwargs.get("plot", False), + threshold=0.0, + plot_along=None, + plot=False, dark=False, - axis_filters=kwargs.get("axis_filters", []), - benchmark_filters=kwargs.get("benchmark_filters", []), + filter_plan=make_filter_plan(nvbench_compare, [("axis", "A=2")]), no_color=True, ) + assert nvbench_compare.config_count == 1 + assert nvbench_compare.pass_count == 1 + assert nvbench_compare.improvement_count == 0 + assert nvbench_compare.regression_count == 0 + assert nvbench_compare.unknown_count == 0 + def test_compare_benches_skips_non_finite_centers(monkeypatch, nvbench_compare): set_test_devices(monkeypatch, nvbench_compare) @@ -162,7 +283,16 @@ def test_compare_benches_skips_non_finite_centers(monkeypatch, nvbench_compare): ) ] - compare_benches(nvbench_compare, ref_benches, cmp_benches) + nvbench_compare.compare_benches( + ref_benches, + cmp_benches, + threshold=0.0, + plot_along=None, + plot=False, + dark=False, + filter_plan=make_filter_plan(nvbench_compare), + no_color=True, + ) assert nvbench_compare.config_count == 1 assert nvbench_compare.pass_count == 1 @@ -191,8 +321,15 @@ def test_compare_benches_prefers_median_and_iqr_when_available( ] ) - compare_benches( - nvbench_compare, [make_benchmark([ref_state])], [make_benchmark([cmp_state])] + nvbench_compare.compare_benches( + [make_benchmark([ref_state])], + [make_benchmark([cmp_state])], + threshold=0.0, + plot_along=None, + plot=False, + dark=False, + filter_plan=make_filter_plan(nvbench_compare), + no_color=True, ) assert nvbench_compare.config_count == 1 @@ -225,10 +362,15 @@ def test_compare_benches_marks_unavailable_noise_unknown(monkeypatch, nvbench_co make_summary(nvbench_compare, "GPU_TIME_STDEV_RELATIVE_TAG", None), ] - compare_benches( - nvbench_compare, + nvbench_compare.compare_benches( [make_benchmark([missing_noise_ref, null_noise_ref])], [make_benchmark([missing_noise_cmp, null_noise_cmp])], + threshold=0.0, + plot_along=None, + plot=False, + dark=False, + filter_plan=make_filter_plan(nvbench_compare), + no_color=True, ) assert nvbench_compare.config_count == 2 @@ -258,11 +400,15 @@ def test_plot_along_skips_states_without_selected_axis(monkeypatch, nvbench_comp ) ] - compare_benches( - nvbench_compare, + nvbench_compare.compare_benches( ref_benches, cmp_benches, + threshold=0.0, plot_along="A", + plot=False, + dark=False, + filter_plan=make_filter_plan(nvbench_compare), + no_color=True, ) assert nvbench_compare.config_count == 2 @@ -272,6 +418,151 @@ def test_plot_along_skips_states_without_selected_axis(monkeypatch, nvbench_comp assert nvbench_compare.unknown_count == 0 +def test_device_filter_parser_accepts_all_and_duplicate_ids(nvbench_compare): + assert nvbench_compare.parse_device_filter(" all ", "--reference-devices") is None + assert nvbench_compare.parse_device_filter("0", "--reference-devices") == [0] + assert nvbench_compare.parse_device_filter("0, 2,0", "--reference-devices") == [ + 0, + 2, + 0, + ] + + +@pytest.mark.parametrize( + "device_arg", + [ + "", + " ", + "gpu", + "-1", + "0,gpu", + "0,-1", + "0,", + ",0", + ], +) +def test_device_filter_parser_rejects_invalid_values(nvbench_compare, device_arg): + with pytest.raises(ValueError, match="must be 'all'"): + nvbench_compare.parse_device_filter(device_arg, "--reference-devices") + + +def test_explicit_device_filters_downgrade_device_mismatch_to_warning(nvbench_compare): + assert nvbench_compare.require_matching_device_sections(None, None) + assert not nvbench_compare.require_matching_device_sections([0], None) + assert not nvbench_compare.require_matching_device_sections(None, [1]) + assert not nvbench_compare.require_matching_device_sections([0], [1]) + + +def test_compare_benches_pairs_filtered_devices_by_position( + monkeypatch, nvbench_compare +): + set_test_devices( + monkeypatch, + nvbench_compare, + ref_devices=[ + {"id": 0, "name": "Reference GPU 0"}, + {"id": 1, "name": "Reference GPU 1"}, + ], + cmp_devices=[ + {"id": 0, "name": "Compare GPU 0"}, + {"id": 1, "name": "Compare GPU 1"}, + ], + ) + + ref_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "Device=0", mean="1.0", device=0), + make_state(nvbench_compare, "Device=1", mean="9.0", device=1), + ] + ) + ] + cmp_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "Device=0", mean="9.0", device=0), + make_state(nvbench_compare, "Device=1", mean="1.0", device=1), + ] + ) + ] + + nvbench_compare.compare_benches( + ref_benches, + cmp_benches, + threshold=0.0, + plot_along=None, + plot=False, + dark=False, + filter_plan=make_filter_plan(nvbench_compare), + no_color=True, + reference_device_filter=[0], + compare_device_filter=[1], + ) + + assert nvbench_compare.config_count == 1 + assert nvbench_compare.pass_count == 1 + assert nvbench_compare.improvement_count == 0 + assert nvbench_compare.regression_count == 0 + assert nvbench_compare.unknown_count == 0 + + +def test_axis_filter_applies_to_most_recent_benchmark(monkeypatch, nvbench_compare): + set_test_devices(monkeypatch, nvbench_compare) + + ref_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "state", mean="1.0", axis_value=1), + make_state(nvbench_compare, "state", mean="2.0", axis_value=2), + ], + name="bench1", + ), + make_benchmark( + [ + make_state(nvbench_compare, "state", mean="3.0", axis_value=1), + make_state(nvbench_compare, "state", mean="4.0", axis_value=2), + ], + name="bench2", + ), + ] + cmp_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "state", mean="1.0", axis_value=1), + make_state(nvbench_compare, "state", mean="2.0", axis_value=2), + ], + name="bench1", + ), + make_benchmark( + [ + make_state(nvbench_compare, "state", mean="3.0", axis_value=1), + make_state(nvbench_compare, "state", mean="4.0", axis_value=2), + ], + name="bench2", + ), + ] + + nvbench_compare.compare_benches( + ref_benches, + cmp_benches, + threshold=0.0, + plot_along=None, + plot=False, + dark=False, + filter_plan=make_filter_plan( + nvbench_compare, + [("benchmark", "bench1"), ("axis", "A=2"), ("benchmark", "bench2")], + ), + no_color=True, + ) + + assert nvbench_compare.config_count == 3 + assert nvbench_compare.pass_count == 3 + assert nvbench_compare.improvement_count == 0 + assert nvbench_compare.regression_count == 0 + assert nvbench_compare.unknown_count == 0 + + def test_main_returns_success_exit_code_when_regressions_are_detected( monkeypatch, capsys, nvbench_compare ):