@@ -243,7 +243,7 @@ def __init__(self, tpu_list: list, fname_list: list):
243243 self ._init_interpreters ()
244244
245245 def _init_interpreters (self ):
246- # Set a Time To Live for balancing so we don't swap for inf in corner cases
246+ # Set a Time To Live for balancing so we don't thrash
247247 self .balance_ttl = len (self .tpu_list ) * 2
248248 start_boot_time = time .perf_counter_ns ()
249249
@@ -269,7 +269,7 @@ def enqueue(self, in_tensor, out_q: queue.Queue):
269269 self .queues [0 ].put (({self .first_name : in_tensor }, out_q ))
270270
271271
272- def _eval_timings (interpreter_counts ):
272+ def _eval_timings (self , interpreter_counts ):
273273 # How much time are we allocating for each segment
274274 time_alloc = []
275275 VALID_CNT_THRESH = 50
@@ -310,33 +310,42 @@ def _eval_timings(interpreter_counts):
310310 min_gt1_t = t
311311 min_gt1_i = i
312312
313- # Only eval swapping max segment if we have many samples
314- if VALID_CNT_THRESH > sum ([i .exec_count [max_i ] for i in self .interpreters [max_i ]]):
315- return min_gt1_i , max_i , max (time_alloc ), None
313+ # Only eval swapping max time segment if we have many samples in the current setup
314+ for i in self .interpreters [max_i ]:
315+ if i .exec_count [max_i ] < VALID_CNT_THRESH :
316+ return min_gt1_i , max_i , max (time_alloc ), None
316317
317- # See if we can do better than the current max timing with swapping
318+ # Undo avg interp count adjustment for TPU-to-TPU comparisons
319+ max_t = max ([i .timings [max_i ] / i .exec_count [max_i ] for i in self .interpreters [max_i ]])
320+
321+ # See if we can do better than the current max time by swapping segments between TPUs
318322 swap_i = None
323+ swap_t = float ('inf' )
319324 for interp_i , interpreters in enumerate (self .interpreters ):
320325 # Doesn't make sense to pull a TPU from a queue just to re-add it.
321326 if interp_i == max_i :
322327 continue
323328
324329 # Test all TPUs in this segment
325330 for i in interpreters :
331+ # If TPU hasn't yet been tried for this segment or ...
332+ if i .exec_count [max_i ] < VALID_CNT_THRESH :
333+ return min_gt1_i , max_i , max (time_alloc ), interp_i
334+
326335 # Only calc valid time after a few runs
327336 new_max_t = 0.0
328337 if i .exec_count [max_i ] > VALID_CNT_THRESH :
329338 new_max_t = i .timings [max_i ] / i .exec_count [max_i ]
330339 new_swap_t = 0.0
331340 if i .exec_count [interp_i ] > VALID_CNT_THRESH :
332341 new_swap_t = i .timings [interp_i ] / i .exec_count [interp_i ]
333-
334- # If it hasn't yet been tried for this segment or
335- # If it has already found to be faster on this segment
336- # and we aren't making the other segment the new worst .
337- if i . exec_count [ max_i ] < VALID_CNT_THRESH or ( max_t > new_max_t and max_t > new_swap_t ) :
342+
343+ # If the TPU has already been found to be faster on this segment
344+ # and we aren't making the other segment the new worst
345+ # and we are choosing the best available candidate.
346+ if max_t - 0.5 > new_max_t and max_t > new_swap_t and swap_t > new_max_t :
338347 swap_i = interp_i
339- break
348+ swap_t = new_max_t
340349
341350 return min_gt1_i , max_i , max (time_alloc ), swap_i
342351
@@ -367,7 +376,7 @@ def balance_queues(self):
367376 # 2nd Priority: Swap slow segments with faster ones to see if we can
368377 # run them faster. Hopefully still a good way to optimize for
369378 # heterogeneous hardware.
370- logging .info (f"Re-balancing between queues { swap_i } and { max_i } " )
379+ logging .info (f"Auto-tuning between queues { swap_i } and { max_i } " )
371380
372381 # Stop them
373382 new_max = self ._rem_interpreter_from (swap_i )
0 commit comments