@@ -243,7 +243,7 @@ def __init__(self, tpu_list: list, fname_list: list):
243243 self ._init_interpreters ()
244244
245245 def _init_interpreters (self ):
246-
246+ # Set a Time To Live for balancing so we don't swap for inf in corner cases
247247 self .balance_ttl = len (self .tpu_list ) * 2
248248 start_boot_time = time .perf_counter_ns ()
249249
@@ -269,73 +269,87 @@ def enqueue(self, in_tensor, out_q: queue.Queue):
269269 self .queues [0 ].put (({self .first_name : in_tensor }, out_q ))
270270
271271
272- def balance_queues (self ):
273- # Don't bother if someone else is working on balancing
274- if len (self .queues ) <= 1 or len (self .tpu_list ) < 2 or self .balance_ttl <= 0 or \
275- not self .balance_lock .acquire (blocking = False ):
276- return
272+ def _eval_timings (interpreter_counts ):
273+ # How much time are we allocating for each segment
274+ time_alloc = []
277275
278- def eval_timings (interpreter_counts ):
279- # How much time are we allocating for each segment
280- time_alloc = []
276+ for seg_i in range (len (self .interpreters )):
277+ # Find average runtime for this segment
278+ avg_times = []
279+ for interpreters in self .interpreters :
280+ avg_times += [i .timings [seg_i ] / i .exec_count [seg_i ] for i in interpreters if i .exec_count [seg_i ] != 0 ]
281281
282- for seg_i in range (len (self .interpreters )):
283- # Find average runtime for this segment
284- avg_times = []
285- for interpreters in self .interpreters :
286- avg_times += [i .timings [seg_i ] / i .exec_count [seg_i ] for i in interpreters if i .exec_count [seg_i ] != 0 ]
282+ if avg_times :
283+ avg_time = sum (avg_times ) / len (avg_times )
284+ else :
285+ return 0 , 0 , 0.0 , None
287286
288- if avg_times :
289- avg_time = sum (avg_times ) / len (avg_times )
290- else :
291- return 0 , 0 , 0.0 , None
287+ # Adjust for number of TPUs allocated to it
288+ if interpreter_counts [seg_i ] > 0 :
289+ time_alloc .append (avg_time / interpreter_counts [seg_i ])
290+ else :
291+ # No interpreters result inf time
292+ time_alloc .append (float ('inf' ))
293+
294+ min_gt1_t = float ('inf' )
295+ min_gt1_i = - 1
296+ max_t = 0
297+ max_i = - 1
298+
299+ # Find segments that maybe should swap
300+ for i , t in enumerate (time_alloc ):
301+ # Max time needs to be shortened so add an interpreter.
302+ if t > max_t :
303+ max_t = t
304+ max_i = i
305+
306+ # Min time needs to be lengthened so rem an interpreter,
307+ # but only if it has more than one interpreter
308+ if t < min_gt1_t and len (self .interpreters [i ]) > 1 :
309+ min_gt1_t = t
310+ min_gt1_i = i
311+
312+ # See if we can do better than the current max timing with swapping
313+ swap_i = None
314+ for interp_i , interpreters in enumerate (self .interpreters ):
315+ # Doesn't make sense to pull a TPU from a queue just to re-add it.
316+ if interp_i == max_i :
317+ continue
292318
293- # Adjust for number of TPUs allocated to it
294- if interpreter_counts [seg_i ] > 0 :
295- time_alloc .append (avg_time / interpreter_counts [seg_i ])
296- else :
297- # No interpreters result inf time
298- time_alloc .append (float ('inf' ))
299-
300- min_gt1_t = float ('inf' )
301- min_gt1_i = - 1
302- max_t = 0
303- max_i = - 1
304-
305- # Find segments that maybe should swap
306- for i , t in enumerate (time_alloc ):
307- # Max time needs to be shortened so add an interpreter.
308- if t > max_t :
309- max_t = t
310- max_i = i
311-
312- # Min time needs to be lengthened so rem an interpreter,
313- # but only if it has more than one interpreter
314- if t < min_gt1_t and len (self .interpreters [i ]) > 1 :
315- min_gt1_t = t
316- min_gt1_i = i
317-
318- # See if we can do better than the current max timing
319- untried_candidates = []
320- for interp_i , interpreters in enumerate (self .interpreters ):
321- # Doesn't make sense to pull a TPU from a queue just to re-add it.
322- if interp_i == max_i :
323- continue
324- # If it hasn't yet been tried for this segment
325- # (or if it has already found to be faster on this segment)
326- if any ([True for i in interpreters if i .exec_count [max_i ] < 10 or max_t - 0.1 > i .timings [max_i ] / i .exec_count [max_i ]]):
327- untried_candidates .append (interp_i )
319+ # Test all TPUs in this segment
320+ for i in interpreters :
321+ # Only calc valid time after a few runs
322+ new_max_t = 0
323+ if i .exec_count [max_i ] > 10 :
324+ new_max_t = i .timings [max_i ] / i .exec_count [max_i ]
325+ new_swap_t = 0
326+ if i .exec_count [interp_i ] > 10 :
327+ new_swap_t = i .timings [interp_i ] / i .exec_count [interp_i ]
328+
329+ # If it hasn't yet been tried for this segment or
330+ # If it has already found to be faster on this segment
331+ # and we aren't making the other segment the new worst.
332+ if i .exec_count [max_i ] < 10 or (max_t > new_max_t and max_t > new_swap_t ):
333+ swap_i = interp_i
334+ break
328335
329- return min_gt1_i , max_i , max (time_alloc ), untried_candidates [0 ] if len (untried_candidates ) > 0 else None
336+ return min_gt1_i , max_i , max (time_alloc ), swap_i
337+
338+
339+ def balance_queues (self ):
340+ # Don't bother if someone else is working on balancing
341+ if len (self .queues ) <= 1 or len (self .tpu_list ) < 2 or self .balance_ttl <= 0 or \
342+ not self .balance_lock .acquire (blocking = False ):
343+ return
330344
331345 interpreter_counts = [len (i ) for i in self .interpreters ]
332- min_i , max_i , current_max , min_untried_i = eval_timings (interpreter_counts )
346+ min_i , max_i , current_max , swap_i = self . _eval_timings (interpreter_counts )
333347 interpreter_counts [min_i ] -= 1
334348 interpreter_counts [max_i ] += 1
335- _ , _ , new_max , _ = eval_timings (interpreter_counts )
349+ _ , _ , new_max , _ = self . _eval_timings (interpreter_counts )
336350
337351 if new_max + 1.0 < current_max :
338- # Allocate more TPUs to slow segments
352+ # 1st Priority: Allocate more TPUs to slow segments
339353 logging .info (f"Re-balancing from queue { min_i } to { max_i } (max from { current_max :.2f} to { new_max :.2f} )" )
340354
341355 realloc_interp = self ._rem_interpreter_from (min_i )
@@ -344,24 +358,22 @@ def eval_timings(interpreter_counts):
344358 realloc_interp .start (max_i , self .fbytes_list [max_i ])
345359 self .interpreters [max_i ].append (realloc_interp )
346360
347- elif min_untried_i is not None :
348- # Swap slow segments with faster ones to see if we can run them faster.
349- # It might be a good way to optimize for heterogenous hardware.
350- logging .info (f"Re-balancing between queues { min_untried_i } and { max_i } " )
361+ elif swap_i is not None :
362+ # 2nd Priority: Swap slow segments with faster ones to see if we can
363+ # run them faster. Hopefully still a good way to optimize for
364+ # heterogenous hardware.
365+ logging .info (f"Re-balancing between queues { swap_i } and { max_i } " )
351366
352367 # Stop them
353- new_max_i = self ._rem_interpreter_from (min_untried_i )
354- new_min_untried_i = self ._rem_interpreter_from (max_i )
368+ new_max = self ._rem_interpreter_from (swap_i )
369+ new_swap = self ._rem_interpreter_from (max_i )
355370
356371 # Swap them
357- new_max_i .start (max_i , self .fbytes_list [max_i ])
358- self .interpreters [max_i ].append (new_max_i )
359-
360- new_min_untried_i .start (min_untried_i , self .fbytes_list [min_untried_i ])
361- self .interpreters [min_untried_i ].append (new_min_untried_i )
372+ new_max .start (max_i , self .fbytes_list [max_i ])
373+ self .interpreters [max_i ].append (new_max )
362374
363- # FIXME: After we have TPUs evaluated and otherwise balanced, we could
364- # further optimize by ensuring the slowest segment doesn't contain any slow TPUs.
375+ new_swap . start ( swap_i , self . fbytes_list [ swap_i ])
376+ self . interpreters [ swap_i ]. append ( new_swap )
365377 else :
366378 # Return if we don't want to swap
367379 self .balance_lock .release ()
0 commit comments