@@ -239,20 +239,20 @@ def __init__(self, tpu_list: list, fname_list: list):
239239 with open (fname , "rb" ) as fd :
240240 self .fbytes_list .append (fd .read ())
241241
242- self ._init_interpreters ()
242+ with self .balance_lock :
243+ self ._init_interpreters ()
243244
244245 def _init_interpreters (self ):
245246
246247 start_boot_time = time .perf_counter_ns ()
247248
248249 # Fill TPUs with interpreters
249- with self .balance_lock :
250- for i , tpu_name in enumerate (self .tpu_list ):
251- seg_idx = i % len (self .fname_list )
250+ for i , tpu_name in enumerate (self .tpu_list ):
251+ seg_idx = i % len (self .fname_list )
252252
253- i = DynamicInterpreter (self .fname_list , tpu_name , self .queues , self .rebalancing_lock )
254- i .start (seg_idx , self .fbytes_list [seg_idx ])
255- self .interpreters [seg_idx ].append (i )
253+ i = DynamicInterpreter (self .fname_list , tpu_name , self .queues , self .rebalancing_lock )
254+ i .start (seg_idx , self .fbytes_list [seg_idx ])
255+ self .interpreters [seg_idx ].append (i )
256256
257257 self .first_name = self .interpreters [0 ][0 ].input_details [0 ]['name' ]
258258
@@ -261,8 +261,9 @@ def _init_interpreters(self):
261261
262262
263263 def enqueue (self , in_tensor , out_q : queue .Queue ):
264- if not self .first_name :
265- self ._init_interpreters ()
264+ with self .balance_lock :
265+ if not self .first_name :
266+ self ._init_interpreters ()
266267
267268 self .queues [0 ].put (({self .first_name : in_tensor }, out_q ))
268269
@@ -500,11 +501,11 @@ def __init__(self, tpu_limit: int = -1):
500501 def _watchdog (self ):
501502 self .watchdog_time = time .time ()
502503 while not self .watchdog_shutdown :
503- if self .pipe and \
504+ if self .pipe and self . pipe . first_name is None and \
504505 time .time () - self .watchdog_time > self .max_idle_secs_before_recycle :
505506 logging .warning ("No work in {} seconds, watchdog shutting down TPUs." .format (self .max_idle_secs_before_recycle ))
506507 self .runner_lock .acquire (timeout = MAX_WAIT_TIME )
507- if self .pipe . first_name :
508+ if self .pipe :
508509 self .pipe .delete ()
509510 self .runner_lock .release ()
510511 # Pipeline will reinitialize itself as needed
0 commit comments