@@ -212,18 +212,18 @@ def __init__(self, tpu_list: list, fname_list: list):
212212
213213 self .max_pipeline_queue_length = MAX_PIPELINE_QUEUE_LEN
214214
215- self .fname_list = fname_list
216- self .tpu_list = tpu_list
217- self .interpreters = [[] for i in range (seg_count )]
215+ self .fname_list = fname_list
216+ self .tpu_list = tpu_list
217+ self .interpreters = [[] for i in range (seg_count )]
218218
219219 # Input queues for each segment; if we go over maxsize, something went wrong
220220 self .queues = [queue .Queue (maxsize = self .max_pipeline_queue_length ) for i in range (seg_count )]
221221
222222 # Lock for internal reorganization
223- self .balance_lock = threading .Lock ()
223+ self .balance_lock = threading .Lock ()
224224
225225 # Lock for interpreter use
226- self .rebalancing_lock = threading .Lock ()
226+ self .rebalancing_lock = threading .Lock ()
227227
228228 # Read file data
229229 self .fbytes_list = []
@@ -244,6 +244,7 @@ def __init__(self, tpu_list: list, fname_list: list):
244244
245245 def _init_interpreters (self ):
246246
247+ self .balance_ttl = len (self .tpu_list ) * 2
247248 start_boot_time = time .perf_counter_ns ()
248249
249250 # Fill TPUs with interpreters
@@ -270,7 +271,7 @@ def enqueue(self, in_tensor, out_q: queue.Queue):
270271
271272 def balance_queues (self ):
272273 # Don't bother if someone else is working on balancing
273- if len (self .queues ) <= 1 or len (self .tpu_list ) < 2 or \
274+ if len (self .queues ) <= 1 or len (self .tpu_list ) < 2 or self . balance_ttl <= 0 or \
274275 not self .balance_lock .acquire (blocking = False ):
275276 return
276277
@@ -363,6 +364,7 @@ def eval_timings(interpreter_counts):
363364 realloc_interp .start (max_i , self .fbytes_list [max_i ])
364365 self .interpreters [max_i ].append (realloc_interp )
365366
367+ self .balance_ttl -= 1
366368 self .balance_lock .release ()
367369 self .print_queue_len ()
368370
0 commit comments