Skip to content

Commit 9d167d3

Browse files
committed
Better queue balancer
1 parent 43ff9de commit 9d167d3

2 files changed

Lines changed: 66 additions & 30 deletions

File tree

src/modules/ObjectDetectionCoral/objectdetection_coral_multitpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ def main():
275275
(wall_time * 1000 / args.count, args.count,
276276
(time.perf_counter() - start_one) * 1000))
277277

278-
logging.info('%.2fms avg time blocked across %d threads; %.2fms ea for final %d inferences' %
278+
logging.info('%.2fms avg time blocked across %d threads; %.3fms ea for final %d inferences' %
279279
(tot_infr_time / args.count, thread_cnt,
280280
half_wall_time * 1000 / half_infr_count, half_infr_count))
281281

src/modules/ObjectDetectionCoral/tpu_runner.py

Lines changed: 65 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -269,31 +269,34 @@ def enqueue(self, in_tensor, out_q: queue.Queue):
269269

270270
def balance_queues(self):
271271
# Don't bother if someone else is working on balancing
272-
if len(self.queues) <= 1 or len(self.tpu_list) <= 2 or \
273-
len(self.queues) == len(self.tpu_list) or \
272+
if len(self.queues) <= 1 or len(self.tpu_list) < 2 or \
274273
not self.balance_lock.acquire(blocking=False):
275274
return
276275

277276
def eval_timings(interpreter_counts):
278277
# How much time are we allocating for each segment
279278
time_alloc = []
280279

281-
for idx in range(len(self.interpreters)):
280+
for seg_i in range(len(self.interpreters)):
282281
# Find average runtime for this segment
283282
avg_times = []
284283
for interpreters in self.interpreters:
285-
avg_times += [i.timings[idx] / i.exec_count[idx] for i in interpreters if i.exec_count[idx] != 0]
284+
avg_times += [i.timings[seg_i] / i.exec_count[seg_i] for i in interpreters if i.exec_count[seg_i] != 0]
286285

287286
if avg_times:
288287
avg_time = sum(avg_times) / len(avg_times)
289288
else:
290-
return 0, 0, 0.0
289+
return 0, 0, 0.0, None
291290

292291
# Adjust for number of TPUs allocated to it
293-
time_alloc.append(avg_time / interpreter_counts[idx])
292+
if interpreter_counts[seg_i] > 0:
293+
time_alloc.append(avg_time / interpreter_counts[seg_i])
294+
else:
295+
# No interpreters allocated results in infinite time
296+
time_alloc.append(float('inf'))
294297

295-
min_t = 100000000
296-
min_i = -1
298+
min_gt1_t = float('inf')
299+
min_gt1_i = -1
297300
max_t = 0
298301
max_i = -1
299302

@@ -306,28 +309,67 @@ def eval_timings(interpreter_counts):
306309

307310
# Min time needs to be lengthened, so remove an interpreter,
308311
# but only if it has more than one interpreter
309-
if t < min_t and len(self.interpreters[i]) > 1:
310-
min_t = t
311-
min_i = i
312+
if t < min_gt1_t and len(self.interpreters[i]) > 1:
313+
min_gt1_t = t
314+
min_gt1_i = i
315+
316+
# See if we can do better than the current max timing
317+
untried_candidates = []
318+
for interp_i, interpreters in enumerate(self.interpreters):
319+
# Doesn't make sense to pull a TPU from a queue just to re-add it.
320+
if interp_i == max_i:
321+
continue
322+
# If it hasn't yet been tried for this segment
323+
# (or if it has already been found to be faster on this segment)
324+
if any([True for i in interpreters if i.exec_count[max_i] == 0 or max_t-1.0 > i.timings[max_i] / i.exec_count[max_i]]):
325+
untried_candidates.append(interp_i)
312326

313-
return min_i, max_i, max(time_alloc)
327+
return min_gt1_i, max_i, max(time_alloc), untried_candidates[0] if len(untried_candidates) > 0 else None
314328

315329
interpreter_counts = [len(i) for i in self.interpreters]
316-
min_i, max_i, current_max = eval_timings(interpreter_counts)
330+
min_i, max_i, current_max, min_untried_i = eval_timings(interpreter_counts)
317331
interpreter_counts[min_i] -= 1
318332
interpreter_counts[max_i] += 1
319-
_, _, new_max = eval_timings(interpreter_counts)
333+
_, _, new_max, _ = eval_timings(interpreter_counts)
320334

321-
# Return if we don't want to swap (+/- 1 ms)
335+
# Return if we don't want to swap
322336
if new_max+1.0 >= current_max:
323-
self.balance_lock.release()
324-
return
337+
if min_untried_i is None:
338+
self.balance_lock.release()
339+
return
340+
341+
# Swap slow segments with faster ones to see if we can run them faster.
342+
# It might be a good way to optimize for heterogeneous hardware.
343+
logging.info(f"Re-balancing between queues {min_untried_i} and {max_i}")
344+
345+
# Stop them
346+
new_max_i = self._rem_interpreter_from(min_untried_i)
347+
new_min_untried_i = self._rem_interpreter_from(max_i)
348+
349+
# Swap them
350+
new_max_i.start(max_i, self.fbytes_list[max_i])
351+
self.interpreters[max_i].append(new_max_i)
352+
353+
new_min_untried_i.start(min_untried_i, self.fbytes_list[min_untried_i])
354+
self.interpreters[min_untried_i].append(new_min_untried_i)
355+
356+
else:
357+
logging.info(f"Re-balancing from queue {min_i} to {max_i} (max from {current_max:.2f} to {new_max:.2f})")
325358

326-
logging.info(f"Re-balancing from queue {min_i} to {max_i} (max from {current_max:.2f} to {new_max:.2f})")
359+
realloc_interp = self._rem_interpreter_from(min_i)
327360

361+
# Add to large (too-slow) queue
362+
realloc_interp.start(max_i, self.fbytes_list[max_i])
363+
self.interpreters[max_i].append(realloc_interp)
364+
365+
self.balance_lock.release()
366+
self.print_queue_len()
367+
368+
369+
def _rem_interpreter_from(self, interp_i):
328370
# Sending False kills the processing loop
329371
self.rebalancing_lock.acquire()
330-
self.queues[min_i].put(False)
372+
self.queues[interp_i].put(False)
331373

332374
# This is ugly, but I can't think of something better
333375
# Threads are blocked by queues. Queues may not have a stream
@@ -338,21 +380,15 @@ def eval_timings(interpreter_counts):
338380
# Block & wait
339381
realloc_interp = None
340382
with self.rebalancing_lock:
341-
for idx, interpreter in enumerate(self.interpreters[min_i]):
383+
for idx, interpreter in enumerate(self.interpreters[interp_i]):
342384
if not interpreter.interpreter:
343-
realloc_interp = self.interpreters[min_i].pop(idx)
385+
realloc_interp = self.interpreters[interp_i].pop(idx)
344386
break
387+
345388
if not realloc_interp:
346389
logging.warning("Unable to find killed interpreter")
347390
self.balance_lock.release()
348-
return
349-
350-
# Add to large (too-slow) queue
351-
realloc_interp.start(max_i, self.fbytes_list[max_i])
352-
self.interpreters[max_i].append(realloc_interp)
353-
354-
self.balance_lock.release()
355-
self.print_queue_len()
391+
return realloc_interp
356392

357393

358394
def print_queue_len(self):

0 commit comments

Comments
 (0)