@@ -269,31 +269,34 @@ def enqueue(self, in_tensor, out_q: queue.Queue):
269269
270270 def balance_queues (self ):
271271 # Don't bother if someone else is working on balancing
272- if len (self .queues ) <= 1 or len (self .tpu_list ) <= 2 or \
273- len (self .queues ) == len (self .tpu_list ) or \
272+ if len (self .queues ) <= 1 or len (self .tpu_list ) < 2 or \
274273 not self .balance_lock .acquire (blocking = False ):
275274 return
276275
277276 def eval_timings (interpreter_counts ):
278277 # How much time are we allocating for each segment
279278 time_alloc = []
280279
281- for idx in range (len (self .interpreters )):
280+ for seg_i in range (len (self .interpreters )):
282281 # Find average runtime for this segment
283282 avg_times = []
284283 for interpreters in self .interpreters :
285- avg_times += [i .timings [idx ] / i .exec_count [idx ] for i in interpreters if i .exec_count [idx ] != 0 ]
284+ avg_times += [i .timings [seg_i ] / i .exec_count [seg_i ] for i in interpreters if i .exec_count [seg_i ] != 0 ]
286285
287286 if avg_times :
288287 avg_time = sum (avg_times ) / len (avg_times )
289288 else :
290- return 0 , 0 , 0.0
289+ return 0 , 0 , 0.0 , None
291290
292291 # Adjust for number of TPUs allocated to it
293- time_alloc .append (avg_time / interpreter_counts [idx ])
292+ if interpreter_counts [seg_i ] > 0 :
293+ time_alloc .append (avg_time / interpreter_counts [seg_i ])
294+ else :
295+ # No interpreters allocated, so treat this segment's time as infinite
296+ time_alloc .append (float ('inf' ))
294297
295- min_t = 100000000
296- min_i = - 1
298+ min_gt1_t = float ( 'inf' )
299+ min_gt1_i = - 1
297300 max_t = 0
298301 max_i = - 1
299302
@@ -306,28 +309,67 @@ def eval_timings(interpreter_counts):
306309
307310 # Min time needs to be lengthened, so remove an interpreter,
308311 # but only if it has more than one interpreter
309- if t < min_t and len (self .interpreters [i ]) > 1 :
310- min_t = t
311- min_i = i
312+ if t < min_gt1_t and len (self .interpreters [i ]) > 1 :
313+ min_gt1_t = t
314+ min_gt1_i = i
315+
316+ # See if we can do better than the current max timing
317+ untried_candidates = []
318+ for interp_i , interpreters in enumerate (self .interpreters ):
319+ # Doesn't make sense to pull a TPU from a queue just to re-add it.
320+ if interp_i == max_i :
321+ continue
322+ # If it hasn't yet been tried for this segment
323+ # (or if it has already been found to be faster on this segment)
324+ if any ([True for i in interpreters if i .exec_count [max_i ] == 0 or max_t - 1.0 > i .timings [max_i ] / i .exec_count [max_i ]]):
325+ untried_candidates .append (interp_i )
312326
313- return min_i , max_i , max (time_alloc )
327+ return min_gt1_i , max_i , max (time_alloc ), untried_candidates [ 0 ] if len ( untried_candidates ) > 0 else None
314328
315329 interpreter_counts = [len (i ) for i in self .interpreters ]
316- min_i , max_i , current_max = eval_timings (interpreter_counts )
330+ min_i , max_i , current_max , min_untried_i = eval_timings (interpreter_counts )
317331 interpreter_counts [min_i ] -= 1
318332 interpreter_counts [max_i ] += 1
319- _ , _ , new_max = eval_timings (interpreter_counts )
333+ _ , _ , new_max , _ = eval_timings (interpreter_counts )
320334
321- # Return if we don't want to swap (+/- 1 ms)
335+ # Return if we don't want to swap
322336 if new_max + 1.0 >= current_max :
323- self .balance_lock .release ()
324- return
337+ if min_untried_i is None :
338+ self .balance_lock .release ()
339+ return
340+
341+ # Swap slow segments with faster ones to see if we can run them faster.
342+ # It might be a good way to optimize for heterogeneous hardware.
343+ logging .info (f"Re-balancing between queues { min_untried_i } and { max_i } " )
344+
345+ # Stop them
346+ new_max_i = self ._rem_interpreter_from (min_untried_i )
347+ new_min_untried_i = self ._rem_interpreter_from (max_i )
348+
349+ # Swap them
350+ new_max_i .start (max_i , self .fbytes_list [max_i ])
351+ self .interpreters [max_i ].append (new_max_i )
352+
353+ new_min_untried_i .start (min_untried_i , self .fbytes_list [min_untried_i ])
354+ self .interpreters [min_untried_i ].append (new_min_untried_i )
355+
356+ else :
357+ logging .info (f"Re-balancing from queue { min_i } to { max_i } (max from { current_max :.2f} to { new_max :.2f} )" )
325358
326- logging . info ( f"Re-balancing from queue { min_i } to { max_i } (max from { current_max :.2f } to { new_max :.2f } )" )
359+ realloc_interp = self . _rem_interpreter_from ( min_i )
327360
361+ # Add to large (too-slow) queue
362+ realloc_interp .start (max_i , self .fbytes_list [max_i ])
363+ self .interpreters [max_i ].append (realloc_interp )
364+
365+ self .balance_lock .release ()
366+ self .print_queue_len ()
367+
368+
369+ def _rem_interpreter_from (self , interp_i ):
328370 # Sending False kills the processing loop
329371 self .rebalancing_lock .acquire ()
330- self .queues [min_i ].put (False )
372+ self .queues [interp_i ].put (False )
331373
332374 # This is ugly, but I can't think of something better
333375 # Threads are blocked by queues. Queues may not have a stream
@@ -338,21 +380,15 @@ def eval_timings(interpreter_counts):
338380 # Block & wait
339381 realloc_interp = None
340382 with self .rebalancing_lock :
341- for idx , interpreter in enumerate (self .interpreters [min_i ]):
383+ for idx , interpreter in enumerate (self .interpreters [interp_i ]):
342384 if not interpreter .interpreter :
343- realloc_interp = self .interpreters [min_i ].pop (idx )
385+ realloc_interp = self .interpreters [interp_i ].pop (idx )
344386 break
387+
345388 if not realloc_interp :
346389 logging .warning ("Unable to find killed interpreter" )
347390 self .balance_lock .release ()
348- return
349-
350- # Add to large (too-slow) queue
351- realloc_interp .start (max_i , self .fbytes_list [max_i ])
352- self .interpreters [max_i ].append (realloc_interp )
353-
354- self .balance_lock .release ()
355- self .print_queue_len ()
391+ return realloc_interp
356392
357393
358394 def print_queue_len (self ):
0 commit comments