Skip to content

Commit 6bfc83e

Browse files
ChaoWaoclaude
andcommitted
Fix: unify --rounds flag and disable profiling when rounds > 1
Remove -n short form for --rounds in run_example.py and scene_test.py standalone entry to match pytest's conftest.py (which only has --rounds). Automatically disable --enable-profiling when --rounds > 1 with a warning, since profiling only captures the first round and multi-round mode is for benchmarking. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 3eae046 commit 6bfc83e

2 files changed

Lines changed: 16 additions & 2 deletions

File tree

examples/scripts/run_example.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,6 @@ def compute_golden(tensors: dict, params: dict) -> None:
171171
)
172172

173173
parser.add_argument(
174-
"-n",
175174
"--rounds",
176175
type=int,
177176
default=None,
@@ -205,6 +204,10 @@ def compute_golden(tensors: dict, params: dict) -> None:
205204

206205
configure_logging(args.log_level)
207206

207+
if args.rounds is not None and args.rounds > 1 and args.enable_profiling:
208+
logger.warning("Profiling disabled: --rounds > 1")
209+
args.enable_profiling = False
210+
208211
# Validate paths
209212
kernels_path = Path(args.kernels)
210213
golden_path = Path(args.golden)

simpler_setup/scene_test.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -874,6 +874,11 @@ def test_run(self, st_platform, st_worker, request):
874874
skip_golden = request.config.getoption("--skip-golden", default=False)
875875
enable_profiling = request.config.getoption("--enable-profiling", default=False)
876876
enable_dump_tensor = request.config.getoption("--dump-tensor", default=False)
877+
if rounds > 1 and enable_profiling:
878+
import warnings # noqa: PLC0415
879+
880+
warnings.warn("Profiling disabled: --rounds > 1", stacklevel=1)
881+
enable_profiling = False
877882

878883
cls_name = type(self).__name__
879884
callable_obj = self.build_callable(st_platform)
@@ -946,7 +951,7 @@ def run_module(module_name):
946951
default="exclude",
947952
help="Manual case handling: exclude (default), include, only",
948953
)
949-
parser.add_argument("-n", "--rounds", type=int, default=1, help="Run each case N times (default: 1)")
954+
parser.add_argument("--rounds", type=int, default=1, help="Run each case N times (default: 1)")
950955
parser.add_argument("--skip-golden", action="store_true", help="Skip golden comparison (benchmark mode)")
951956
parser.add_argument("--enable-profiling", action="store_true", help="Enable profiling (first round only)")
952957
parser.add_argument("--dump-tensor", action="store_true", help="Dump per-task tensor I/O at runtime")
@@ -958,6 +963,12 @@ def run_module(module_name):
958963
help=f"Root logger level (default: {DEFAULT_LOG_LEVEL})",
959964
)
960965
args = parser.parse_args()
966+
if args.rounds > 1 and args.enable_profiling:
967+
import warnings # noqa: PLC0415
968+
969+
warnings.warn("Profiling disabled: --rounds > 1", stacklevel=1)
970+
args.enable_profiling = False
971+
961972
configure_logging(args.log_level)
962973

963974
module = sys.modules[module_name]

0 commit comments

Comments
 (0)