Skip to content

Commit 83d2041

Browse files
authored
Fix: unify --rounds flag and disable profiling when rounds > 1 (#574)
Remove -n short form for --rounds in run_example.py and scene_test.py standalone entry to match pytest's conftest.py (which only defines --rounds). Automatically disable --enable-profiling when --rounds > 1 with a logger.warning, since profiling only captures the first round and multi-round mode is for benchmarking.
1 parent 8be358b commit 83d2041

2 files changed

Lines changed: 15 additions & 2 deletions

File tree

examples/scripts/run_example.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,6 @@ def compute_golden(tensors: dict, params: dict) -> None:
171171
)
172172

173173
parser.add_argument(
174-
"-n",
175174
"--rounds",
176175
type=int,
177176
default=None,
@@ -205,6 +204,10 @@ def compute_golden(tensors: dict, params: dict) -> None:
205204

206205
configure_logging(args.log_level)
207206

207+
if args.rounds is not None and args.rounds > 1 and args.enable_profiling:
208+
logger.warning("Profiling disabled: --rounds > 1")
209+
args.enable_profiling = False
210+
208211
# Validate paths
209212
kernels_path = Path(args.kernels)
210213
golden_path = Path(args.golden)

simpler_setup/scene_test.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from __future__ import annotations
2424

2525
import inspect
26+
import logging
2627
import os
2728
import sys
2829
from contextlib import contextmanager
@@ -31,6 +32,8 @@
3132

3233
from .log_config import DEFAULT_LOG_LEVEL, LOG_LEVEL_CHOICES, configure_logging
3334

35+
logger = logging.getLogger(__name__)
36+
3437
_compile_cache: dict[tuple[str, str, str], object] = {}
3538

3639

@@ -874,6 +877,9 @@ def test_run(self, st_platform, st_worker, request):
874877
skip_golden = request.config.getoption("--skip-golden", default=False)
875878
enable_profiling = request.config.getoption("--enable-profiling", default=False)
876879
enable_dump_tensor = request.config.getoption("--dump-tensor", default=False)
880+
if rounds > 1 and enable_profiling:
881+
logger.warning("Profiling disabled: --rounds > 1")
882+
enable_profiling = False
877883

878884
cls_name = type(self).__name__
879885
callable_obj = self.build_callable(st_platform)
@@ -946,7 +952,7 @@ def run_module(module_name):
946952
default="exclude",
947953
help="Manual case handling: exclude (default), include, only",
948954
)
949-
parser.add_argument("-n", "--rounds", type=int, default=1, help="Run each case N times (default: 1)")
955+
parser.add_argument("--rounds", type=int, default=1, help="Run each case N times (default: 1)")
950956
parser.add_argument("--skip-golden", action="store_true", help="Skip golden comparison (benchmark mode)")
951957
parser.add_argument("--enable-profiling", action="store_true", help="Enable profiling (first round only)")
952958
parser.add_argument("--dump-tensor", action="store_true", help="Dump per-task tensor I/O at runtime")
@@ -960,6 +966,10 @@ def run_module(module_name):
960966
args = parser.parse_args()
961967
configure_logging(args.log_level)
962968

969+
if args.rounds > 1 and args.enable_profiling:
970+
logger.warning("Profiling disabled: --rounds > 1")
971+
args.enable_profiling = False
972+
963973
module = sys.modules[module_name]
964974
test_classes = [
965975
v

0 commit comments

Comments
 (0)