@@ -323,7 +323,7 @@ def eval_timings(interpreter_counts):
323323 continue
324324 # If it hasn't yet been tried for this segment
325325 # (or if it has already been found to be faster on this segment)
326- if any ([True for i in interpreters if i .exec_count [max_i ] == 0 or max_t - 1.0 > i .timings [max_i ] / i .exec_count [max_i ]]):
326+ if any ([True for i in interpreters if i .exec_count [max_i ] < 10 or max_t - 0.1 > i .timings [max_i ] / i .exec_count [max_i ]]):
327327 untried_candidates .append (interp_i )
328328
329329 return min_gt1_i , max_i , max (time_alloc ), untried_candidates [0 ] if len (untried_candidates ) > 0 else None
@@ -334,12 +334,17 @@ def eval_timings(interpreter_counts):
334334 interpreter_counts [max_i ] += 1
335335 _ , _ , new_max , _ = eval_timings (interpreter_counts )
336336
337- # Return if we don't want to swap
338- if new_max + 1.0 >= current_max :
339- if min_untried_i is None :
340- self .balance_lock .release ()
341- return
337+ if new_max + 1.0 < current_max :
338+ # Allocate more TPUs to slow segments
339+ logging .info (f"Re-balancing from queue { min_i } to { max_i } (max from { current_max :.2f} to { new_max :.2f} )" )
342340
341+ realloc_interp = self ._rem_interpreter_from (min_i )
342+
343+ # Add to large (too-slow) queue
344+ realloc_interp .start (max_i , self .fbytes_list [max_i ])
345+ self .interpreters [max_i ].append (realloc_interp )
346+
347+ elif min_untried_i is not None :
343348 # Swap slow segments with faster ones to see if we can run them faster.
344349 # It might be a good way to optimize for heterogeneous hardware.
345350 logging .info (f"Re-balancing between queues { min_untried_i } and { max_i } " )
@@ -355,14 +360,12 @@ def eval_timings(interpreter_counts):
355360 new_min_untried_i .start (min_untried_i , self .fbytes_list [min_untried_i ])
356361 self .interpreters [min_untried_i ].append (new_min_untried_i )
357362
363+ # FIXME: After we have TPUs evaluated and otherwise balanced, we could
364+ # further optimize by ensuring the slowest segment doesn't contain any slow TPUs.
358365 else :
359- logging .info (f"Re-balancing from queue { min_i } to { max_i } (max from { current_max :.2f} to { new_max :.2f} )" )
360-
361- realloc_interp = self ._rem_interpreter_from (min_i )
362-
363- # Add to large (too-slow) queue
364- realloc_interp .start (max_i , self .fbytes_list [max_i ])
365- self .interpreters [max_i ].append (realloc_interp )
366+ # Return if we don't want to swap
367+ self .balance_lock .release ()
368+ return
366369
367370 self .balance_ttl -= 1
368371 self .balance_lock .release ()
0 commit comments