Skip to content

Commit 2a6f862

Browse files
committed
Better tpu balancing
1 parent 1b16ada commit 2a6f862

2 files changed

Lines changed: 23 additions & 14 deletions

File tree

src/modules/ObjectDetectionCoral/objectdetection_coral_multitpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def main():
244244
tot_infr_time += infr_time
245245

246246
# Start a timer for the last ~half of the run for more accurate benchmark
247-
if chunk_i > (args.count-1) / 3.0:
247+
if chunk_i > (args.count-1) / 2.0:
248248
half_infr_count += 1
249249
if half_wall_start is None:
250250
half_wall_start = time.perf_counter()

src/modules/ObjectDetectionCoral/tpu_runner.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def __init__(self, tpu_list: list, fname_list: list):
243243
self._init_interpreters()
244244

245245
def _init_interpreters(self):
246-
# Set a Time To Live for balancing so we don't swap for inf in corner cases
246+
# Set a Time To Live for balancing so we don't thrash
247247
self.balance_ttl = len(self.tpu_list) * 2
248248
start_boot_time = time.perf_counter_ns()
249249

@@ -269,7 +269,7 @@ def enqueue(self, in_tensor, out_q: queue.Queue):
269269
self.queues[0].put(({self.first_name: in_tensor}, out_q))
270270

271271

272-
def _eval_timings(interpreter_counts):
272+
def _eval_timings(self, interpreter_counts):
273273
# How much time are we allocating for each segment
274274
time_alloc = []
275275
VALID_CNT_THRESH = 50
@@ -310,33 +310,42 @@ def _eval_timings(interpreter_counts):
310310
min_gt1_t = t
311311
min_gt1_i = i
312312

313-
# Only eval swapping max segment if we have many samples
314-
if VALID_CNT_THRESH > sum([i.exec_count[max_i] for i in self.interpreters[max_i]]):
315-
return min_gt1_i, max_i, max(time_alloc), None
313+
# Only eval swapping max time segment if we have many samples in the current setup
314+
for i in self.interpreters[max_i]:
315+
if i.exec_count[max_i] < VALID_CNT_THRESH:
316+
return min_gt1_i, max_i, max(time_alloc), None
316317

317-
# See if we can do better than the current max timing with swapping
318+
# Undo avg interp count adjustment for TPU-to-TPU comparisons
319+
max_t = max([i.timings[max_i] / i.exec_count[max_i] for i in self.interpreters[max_i]])
320+
321+
# See if we can do better than the current max time by swapping segments between TPUs
318322
swap_i = None
323+
swap_t = float('inf')
319324
for interp_i, interpreters in enumerate(self.interpreters):
320325
# Doesn't make sense to pull a TPU from a queue just to re-add it.
321326
if interp_i == max_i:
322327
continue
323328

324329
# Test all TPUs in this segment
325330
for i in interpreters:
331+
# If TPU hasn't yet been tried for this segment or ...
332+
if i.exec_count[max_i] < VALID_CNT_THRESH:
333+
return min_gt1_i, max_i, max(time_alloc), interp_i
334+
326335
# Only calc valid time after a few runs
327336
new_max_t = 0.0
328337
if i.exec_count[max_i] > VALID_CNT_THRESH:
329338
new_max_t = i.timings[max_i] / i.exec_count[max_i]
330339
new_swap_t = 0.0
331340
if i.exec_count[interp_i] > VALID_CNT_THRESH:
332341
new_swap_t = i.timings[interp_i] / i.exec_count[interp_i]
333-
334-
# If it hasn't yet been tried for this segment or
335-
# If it has already been found to be faster on this segment
336-
# and we aren't making the other segment the new worst.
337-
if i.exec_count[max_i] < VALID_CNT_THRESH or (max_t > new_max_t and max_t > new_swap_t):
342+
343+
# If TPU has already been found to be faster on this segment
344+
# and we aren't making the other segment the new worst
345+
# and we are choosing the best available candidate.
346+
if max_t-0.5 > new_max_t and max_t > new_swap_t and swap_t > new_max_t:
338347
swap_i = interp_i
339-
break
348+
swap_t = new_max_t
340349

341350
return min_gt1_i, max_i, max(time_alloc), swap_i
342351

@@ -367,7 +376,7 @@ def balance_queues(self):
367376
# 2nd Priority: Swap slow segments with faster ones to see if we can
368377
# run them faster. Hopefully still a good way to optimize for
369378
# heterogeneous hardware.
370-
logging.info(f"Re-balancing between queues {swap_i} and {max_i}")
379+
logging.info(f"Auto-tuning between queues {swap_i} and {max_i}")
371380

372381
# Stop them
373382
new_max = self._rem_interpreter_from(swap_i)

0 commit comments

Comments
 (0)