Skip to content

Commit eac8483

Browse files
committed
Improve benchmark script.
1 parent 65cee7c commit eac8483

1 file changed

Lines changed: 82 additions & 39 deletions

File tree

tools/runbenchmark.py

Lines changed: 82 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -34,21 +34,37 @@ def humanize(t):
3434

3535

3636
class Benchmark:
37-
def __init__(self, path):
38-
self.path = os.path.abspath(path)
37+
def __init__(self, dirs, error_exit=False):
38+
self.files = []
39+
for i in dirs:
40+
self.files.extend(
41+
glob.glob(os.path.join(i, "**/benchmark_*.py"), recursive=True)
42+
)
3943

40-
self.files = glob.glob(os.path.join(path, "**/benchmark_*.py"), recursive=True)
4144
self.rv = {}
4245
self.finish = False
46+
self.error_exit = error_exit
4347

4448
def _exec_file(self, file):
4549
with open(file, "rt") as f:
4650
s = f.read()
4751
g = {"benchmark_setup": benchmark_setup}
48-
exec(s, g)
52+
try:
53+
exec(s, g)
54+
except Exception as e:
55+
sys.stderr.write("Error load file ")
56+
sys.stderr.write(file)
57+
sys.stderr.write(": ")
58+
sys.stderr.write(repr(e))
59+
sys.stderr.write("\n")
60+
sys.stderr.flush()
61+
if self.error_exit:
62+
sys.exit(1)
63+
return None
4964
return g
5065

51-
def _is_benchmark(self, name, obj):
66+
@staticmethod
67+
def _is_benchmark(name, obj):
5268
if not inspect.isfunction(obj):
5369
return False
5470
if name == "benchmark_setup":
@@ -64,19 +80,36 @@ def _filter_benchmarks(self, g):
6480
return rv
6581

6682
def start(self):
67-
print("rootdir:", self.path)
68-
print("Collect %d files" % len(self.files), flush=True)
69-
for file in self.files:
83+
print("Total %d files\n" % len(self.files), flush=True)
84+
for idx, file in enumerate(self.files):
7085
g = self._exec_file(file)
86+
if g is None:
87+
continue
88+
7189
benchmarks = self._filter_benchmarks(g)
72-
print(
73-
"\nFound {} benchmarks in file {}".format(len(benchmarks), file),
74-
flush=True,
75-
)
90+
title = "Collect {} benchmarks in {}".format(len(benchmarks), file)
91+
print(title, flush=True)
92+
print("-" * len(title), flush=True)
7693
for name in benchmarks:
77-
b_rv = self.run_str(name, g)
78-
s = "{}::{} {}".format(file, name, self._pretty_rv(*b_rv))
94+
try:
95+
b_rv = self.run_str(name, g)
96+
except Exception as e:
97+
sys.stderr.write("Error run benchmark '")
98+
sys.stderr.write(name)
99+
sys.stderr.write("' in file ")
100+
sys.stderr.write(file)
101+
sys.stderr.write(": ")
102+
sys.stderr.write(repr(e))
103+
sys.stderr.write("\n")
104+
sys.stderr.flush()
105+
if self.error_exit:
106+
sys.exit(1)
107+
else:
108+
continue
109+
110+
s = "* {}::{} {}".format(file, name, self._pretty_rv(*b_rv))
79111
print(s, flush=True)
112+
print()
80113

81114
def run_str(self, name, g):
82115
call = g[name]
@@ -88,21 +121,22 @@ def run_str(self, name, g):
88121

89122
setup = ";".join(lines)
90123

91-
repeat = 10
92-
total_cost = 8
124+
repeat = 1
125+
max_cost = 8
126+
mask = 20
93127
cost = sum(timeit.repeat(callstr, setup=setup, globals=g, number=1, repeat=1))
94-
if cost > total_cost:
95-
loop = 3
128+
if cost >= max_cost:
129+
loop = 1
130+
elif cost * mask >= max_cost:
131+
loop = int(round(max_cost / cost))
132+
loop += loop % 10
96133
else:
97-
num = int(round(total_cost / cost))
134+
num = int(round(max_cost / cost))
135+
while num > mask * repeat and repeat < 10:
136+
repeat += 1
137+
98138
loop = num // repeat
99-
if loop <= 3:
100-
loop = 3
101-
else:
102-
count = 10
103-
while count < loop:
104-
count *= 10
105-
loop = count
139+
loop += loop % 10
106140

107141
costs = timeit.repeat(
108142
callstr, setup=setup, globals=g, number=loop, repeat=repeat
@@ -113,25 +147,34 @@ def run_str(self, name, g):
113147
@staticmethod
114148
def _pretty_rv(loop, repeat, costs):
115149
mean, std = mean_std(costs)
116-
return "{} ± {}\t(each {:,} runs, {:,} loops)".format(
117-
humanize(mean), humanize(std), repeat, loop
150+
return "{} ± {}\t({:,} loops, each {:,} runs)".format(
151+
humanize(mean), humanize(std), loop, repeat
118152
)
119153

120154

121155
def main():
122156
parser = argparse.ArgumentParser()
123-
parser.add_argument("path", default=".")
124-
if len(sys.argv) != 2:
125-
path = "."
126-
else:
127-
path = sys.argv[1]
128-
if path.startswith("-"):
129-
parser.print_help()
130-
sys.exit(1)
131-
132-
benchmark = Benchmark(path)
133-
157+
parser.add_argument("path", default=["."], nargs="+")
158+
parser.add_argument("-p", "--python-path", default=["."], nargs="*")
159+
parser.add_argument(
160+
"-e",
161+
"--error-exit",
162+
default=False,
163+
action="store_true",
164+
help="Exit if error occurred",
165+
)
166+
167+
args = parser.parse_args()
168+
path = args.path
169+
py_path = args.python_path
170+
error_exit = args.error_exit
171+
if py_path:
172+
for i in py_path:
173+
sys.path.insert(0, i)
174+
175+
benchmark = Benchmark(path, error_exit)
134176
benchmark.start()
177+
135178
sys.exit(0)
136179

137180

0 commit comments

Comments
 (0)