Skip to content

Commit 654151c

Browse files
committed
refactor(compare): clarify baseline vs test semantics in compare scripts
Print baseline/test paths at the start of output and update argument help text. In compare_tps, flip signed_change to (test-baseline)/baseline so positive means test is faster and negative means regression.
1 parent d4aeb70 commit 654151c

2 files changed

Lines changed: 25 additions & 17 deletions

File tree

scripts/compare_loss.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#!/usr/bin/env python3
22
# Usage:
33
# python tools/compare_loss.py \
4-
# /data/shared/InfiniTrain-dev/logs/202511_a800/20260105/feature/add_1F1B_f2a383a/logs \
5-
# /data/shared/InfiniTrain-dev/logs/202511_a800/20251223/feature/tp-pp-split-stream/logs \
4+
# /path/to/baseline/logs \
5+
# /path/to/test/logs \
66
# --threshold-fp32 1e-5 --threshold-bf16 1e-2
77

88
import re
@@ -50,8 +50,8 @@ def compare_files(file1, file2, threshold):
5050

5151
def main():
5252
parser = ArgumentParser(description='Compare training loss between two log directories')
53-
parser.add_argument('dir1', type=Path, help='First log directory')
54-
parser.add_argument('dir2', type=Path, help='Second log directory')
53+
parser.add_argument('dir1', type=Path, help='Baseline log directory')
54+
parser.add_argument('dir2', type=Path, help='Test log directory')
5555
parser.add_argument('--threshold', type=float, help='Loss difference threshold (deprecated, use --threshold-fp32 and --threshold-bf16)')
5656
parser.add_argument('--threshold-fp32', type=float, default=1e-5, help='Loss difference threshold for fp32 (default: 1e-5)')
5757
parser.add_argument('--threshold-bf16', type=float, default=1e-2, help='Loss difference threshold for bfloat16 (default: 1e-2)')
@@ -63,6 +63,10 @@ def main():
6363
args.threshold_fp32 = args.threshold
6464
args.threshold_bf16 = args.threshold
6565

66+
print(f"Baseline: {args.dir1.resolve()}")
67+
print(f"Test: {args.dir2.resolve()}")
68+
print()
69+
6670
files1, duplicates1 = collect_log_files(args.dir1)
6771
files2, duplicates2 = collect_log_files(args.dir2)
6872
exit_if_duplicate_logs(args.dir1, duplicates1)
@@ -73,9 +77,9 @@ def main():
7377
common = set(files1.keys()) & set(files2.keys())
7478

7579
if only_in_1:
76-
print(f"Files only in {args.dir1.resolve()}: {', '.join(sorted(only_in_1))}")
80+
print(f"Files only in baseline: {', '.join(sorted(only_in_1))}")
7781
if only_in_2:
78-
print(f"Files only in {args.dir2.resolve()}: {', '.join(sorted(only_in_2))}")
82+
print(f"Files only in test: {', '.join(sorted(only_in_2))}")
7983
if only_in_1 or only_in_2:
8084
print()
8185

scripts/compare_tps.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#!/usr/bin/env python3
22
# Usage:
33
# python tools/compare_tps.py \
4-
# /path/to/logs/dir1 \
5-
# /path/to/logs/dir2 \
4+
# /path/to/baseline/logs \
5+
# /path/to/test/logs \
66
# --threshold 0.20
77

88
import re
@@ -38,33 +38,37 @@ def compare_files(file1, file2, threshold):
3838
avg1 = sum(tps1.values()) / len(tps1)
3939
avg2 = sum(tps2.values()) / len(tps2)
4040

41-
# Calculate signed relative change: positive means dir1 faster, negative means dir1 slower
42-
signed_change = (avg1 - avg2) / avg2 if avg2 > 0 else 0
41+
# Calculate signed relative change of test vs baseline: positive means test faster, negative means test slower
42+
signed_change = (avg2 - avg1) / avg1 if avg1 > 0 else 0
4343

4444
messages = []
4545
failed = False
4646
if abs(signed_change) > threshold:
4747
sign = "+" if signed_change >= 0 else ""
4848
if signed_change < 0:
49-
# dir1 slower than dir2 -> failure
49+
# test slower than baseline -> failure
5050
label = "✗ SLOWER"
5151
failed = True
5252
else:
53-
# dir1 faster than dir2 -> pass but notify
53+
# test faster than baseline -> pass but notify
5454
label = "↑ FASTER"
55-
messages.append(f" Average tok/s: {avg1:.2f} vs {avg2:.2f} {label} ({sign}{signed_change*100:.1f}%, threshold: ±{threshold*100:.0f}%)")
55+
messages.append(f" Average tok/s: {avg1:.2f} (baseline) vs {avg2:.2f} (test) {label} ({sign}{signed_change*100:.1f}%, threshold: ±{threshold*100:.0f}%)")
5656
messages.append(f" Steps compared: {len(tps1)} vs {len(tps2)} (excluding step 1)")
5757

5858
return 1, failed, messages, avg1, avg2, signed_change, len(tps1), len(tps2)
5959

6060
def main():
6161
parser = ArgumentParser(description='Compare tok/s between two log directories')
62-
parser.add_argument('dir1', type=Path, help='First log directory')
63-
parser.add_argument('dir2', type=Path, help='Second log directory')
62+
parser.add_argument('dir1', type=Path, help='Baseline log directory')
63+
parser.add_argument('dir2', type=Path, help='Test log directory')
6464
parser.add_argument('--threshold', type=float, default=0.20, help='Relative error threshold (default: 0.20 = 20%%)')
6565
parser.add_argument('--verbose', action='store_true', help='Print detailed output for all files, including passed ones')
6666
args = parser.parse_args()
6767

68+
print(f"Baseline: {args.dir1.resolve()}")
69+
print(f"Test: {args.dir2.resolve()}")
70+
print()
71+
6872
files1, duplicates1 = collect_log_files(args.dir1)
6973
files2, duplicates2 = collect_log_files(args.dir2)
7074
exit_if_duplicate_logs(args.dir1, duplicates1)
@@ -75,9 +79,9 @@ def main():
7579
common = set(files1.keys()) & set(files2.keys())
7680

7781
if only_in_1:
78-
print(f"Files only in {args.dir1.resolve()}: {', '.join(sorted(only_in_1))}")
82+
print(f"Files only in baseline: {', '.join(sorted(only_in_1))}")
7983
if only_in_2:
80-
print(f"Files only in {args.dir2.resolve()}: {', '.join(sorted(only_in_2))}")
84+
print(f"Files only in test: {', '.join(sorted(only_in_2))}")
8185
if only_in_1 or only_in_2:
8286
print()
8387

0 commit comments

Comments
 (0)