@@ -272,12 +272,13 @@ def enqueue(self, in_tensor, out_q: queue.Queue):
272272 def _eval_timings (interpreter_counts ):
273273 # How much time are we allocating for each segment
274274 time_alloc = []
275+ VALID_CNT_THRESH = 50
275276
276277 for seg_i in range (len (self .interpreters )):
277278 # Find average runtime for this segment
278279 avg_times = []
279280 for interpreters in self .interpreters :
280- avg_times += [i .timings [seg_i ] / i .exec_count [seg_i ] for i in interpreters if i .exec_count [seg_i ] != 0 ]
281+ avg_times += [i .timings [seg_i ] / i .exec_count [seg_i ] for i in interpreters if i .exec_count [seg_i ] > VALID_CNT_THRESH ]
281282
282283 if avg_times :
283284 avg_time = sum (avg_times ) / len (avg_times )
@@ -293,7 +294,7 @@ def _eval_timings(interpreter_counts):
293294
294295 min_gt1_t = float ('inf' )
295296 min_gt1_i = - 1
296- max_t = 0
297+ max_t = 0.0
297298 max_i = - 1
298299
299300 # Find segments that maybe should swap
@@ -309,6 +310,10 @@ def _eval_timings(interpreter_counts):
309310 min_gt1_t = t
310311 min_gt1_i = i
311312
313+ # Only eval swapping max segment if we have many samples
314+ if VALID_CNT_THRESH > sum ([i .exec_count [max_i ] for i in self .interpreters [max_i ]]):
315+ return min_gt1_i , max_i , max (time_alloc ), None
316+
312317 # See if we can do better than the current max timing with swapping
313318 swap_i = None
314319 for interp_i , interpreters in enumerate (self .interpreters ):
@@ -319,17 +324,17 @@ def _eval_timings(interpreter_counts):
319324 # Test all TPUs in this segment
320325 for i in interpreters :
321326 # Only calc valid time after a few runs
322- new_max_t = 0
323- if i .exec_count [max_i ] > 10 :
327+ new_max_t = 0.0
328+ if i .exec_count [max_i ] > VALID_CNT_THRESH :
324329 new_max_t = i .timings [max_i ] / i .exec_count [max_i ]
325- new_swap_t = 0
326- if i .exec_count [interp_i ] > 10 :
330+ new_swap_t = 0.0
331+ if i .exec_count [interp_i ] > VALID_CNT_THRESH :
327332 new_swap_t = i .timings [interp_i ] / i .exec_count [interp_i ]
328333
329334 # If it hasn't yet been tried for this segment or
330335 # If it has already found to be faster on this segment
331336 # and we aren't making the other segment the new worst.
332- if i .exec_count [max_i ] < 10 or (max_t > new_max_t and max_t > new_swap_t ):
337+ if i .exec_count [max_i ] < VALID_CNT_THRESH or (max_t > new_max_t and max_t > new_swap_t ):
333338 swap_i = interp_i
334339 break
335340
0 commit comments