Skip to content

Commit fc93511

Browse files
authored
Tweaks to TPU balancer
1 parent 753310f commit fc93511

1 file changed

Lines changed: 82 additions & 70 deletions

File tree

src/modules/ObjectDetectionCoral/tpu_runner.py

Lines changed: 82 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def __init__(self, tpu_list: list, fname_list: list):
243243
self._init_interpreters()
244244

245245
def _init_interpreters(self):
246-
246+
# Set a Time To Live for balancing so we don't swap for inf in corner cases
247247
self.balance_ttl = len(self.tpu_list) * 2
248248
start_boot_time = time.perf_counter_ns()
249249

@@ -269,73 +269,87 @@ def enqueue(self, in_tensor, out_q: queue.Queue):
269269
self.queues[0].put(({self.first_name: in_tensor}, out_q))
270270

271271

272-
def balance_queues(self):
273-
# Don't bother if someone else is working on balancing
274-
if len(self.queues) <= 1 or len(self.tpu_list) < 2 or self.balance_ttl <= 0 or \
275-
not self.balance_lock.acquire(blocking=False):
276-
return
272+
def _eval_timings(interpreter_counts):
273+
# How much time are we allocating for each segment
274+
time_alloc = []
277275

278-
def eval_timings(interpreter_counts):
279-
# How much time are we allocating for each segment
280-
time_alloc = []
276+
for seg_i in range(len(self.interpreters)):
277+
# Find average runtime for this segment
278+
avg_times = []
279+
for interpreters in self.interpreters:
280+
avg_times += [i.timings[seg_i] / i.exec_count[seg_i] for i in interpreters if i.exec_count[seg_i] != 0]
281281

282-
for seg_i in range(len(self.interpreters)):
283-
# Find average runtime for this segment
284-
avg_times = []
285-
for interpreters in self.interpreters:
286-
avg_times += [i.timings[seg_i] / i.exec_count[seg_i] for i in interpreters if i.exec_count[seg_i] != 0]
282+
if avg_times:
283+
avg_time = sum(avg_times) / len(avg_times)
284+
else:
285+
return 0, 0, 0.0, None
287286

288-
if avg_times:
289-
avg_time = sum(avg_times) / len(avg_times)
290-
else:
291-
return 0, 0, 0.0, None
287+
# Adjust for number of TPUs allocated to it
288+
if interpreter_counts[seg_i] > 0:
289+
time_alloc.append(avg_time / interpreter_counts[seg_i])
290+
else:
291+
                        # No interpreters results in infinite time
292+
time_alloc.append(float('inf'))
293+
294+
min_gt1_t = float('inf')
295+
min_gt1_i = -1
296+
max_t = 0
297+
max_i = -1
298+
299+
# Find segments that maybe should swap
300+
for i, t in enumerate(time_alloc):
301+
# Max time needs to be shortened so add an interpreter.
302+
if t > max_t:
303+
max_t = t
304+
max_i = i
305+
306+
                # Min time needs to be lengthened so remove an interpreter,
307+
# but only if it has more than one interpreter
308+
if t < min_gt1_t and len(self.interpreters[i]) > 1:
309+
min_gt1_t = t
310+
min_gt1_i = i
311+
312+
# See if we can do better than the current max timing with swapping
313+
swap_i = None
314+
for interp_i, interpreters in enumerate(self.interpreters):
315+
# Doesn't make sense to pull a TPU from a queue just to re-add it.
316+
if interp_i == max_i:
317+
continue
292318

293-
# Adjust for number of TPUs allocated to it
294-
if interpreter_counts[seg_i] > 0:
295-
time_alloc.append(avg_time / interpreter_counts[seg_i])
296-
else:
297-
                # No interpreters results in infinite time
298-
time_alloc.append(float('inf'))
299-
300-
min_gt1_t = float('inf')
301-
min_gt1_i = -1
302-
max_t = 0
303-
max_i = -1
304-
305-
# Find segments that maybe should swap
306-
for i, t in enumerate(time_alloc):
307-
# Max time needs to be shortened so add an interpreter.
308-
if t > max_t:
309-
max_t = t
310-
max_i = i
311-
312-
            # Min time needs to be lengthened so remove an interpreter,
313-
# but only if it has more than one interpreter
314-
if t < min_gt1_t and len(self.interpreters[i]) > 1:
315-
min_gt1_t = t
316-
min_gt1_i = i
317-
318-
# See if we can do better than the current max timing
319-
untried_candidates = []
320-
for interp_i, interpreters in enumerate(self.interpreters):
321-
# Doesn't make sense to pull a TPU from a queue just to re-add it.
322-
if interp_i == max_i:
323-
continue
324-
# If it hasn't yet been tried for this segment
325-
# (or if it has already found to be faster on this segment)
326-
if any([True for i in interpreters if i.exec_count[max_i] < 10 or max_t-0.1 > i.timings[max_i] / i.exec_count[max_i]]):
327-
untried_candidates.append(interp_i)
319+
# Test all TPUs in this segment
320+
for i in interpreters:
321+
# Only calc valid time after a few runs
322+
new_max_t = 0
323+
if i.exec_count[max_i] > 10:
324+
new_max_t = i.timings[max_i] / i.exec_count[max_i]
325+
new_swap_t = 0
326+
if i.exec_count[interp_i] > 10:
327+
new_swap_t = i.timings[interp_i] / i.exec_count[interp_i]
328+
329+
# If it hasn't yet been tried for this segment or
330+
# If it has already found to be faster on this segment
331+
# and we aren't making the other segment the new worst.
332+
if i.exec_count[max_i] < 10 or (max_t > new_max_t and max_t > new_swap_t):
333+
swap_i = interp_i
334+
break
328335

329-
return min_gt1_i, max_i, max(time_alloc), untried_candidates[0] if len(untried_candidates) > 0 else None
336+
return min_gt1_i, max_i, max(time_alloc), swap_i
337+
338+
339+
def balance_queues(self):
340+
# Don't bother if someone else is working on balancing
341+
if len(self.queues) <= 1 or len(self.tpu_list) < 2 or self.balance_ttl <= 0 or \
342+
not self.balance_lock.acquire(blocking=False):
343+
return
330344

331345
interpreter_counts = [len(i) for i in self.interpreters]
332-
min_i, max_i, current_max, min_untried_i = eval_timings(interpreter_counts)
346+
min_i, max_i, current_max, swap_i = self._eval_timings(interpreter_counts)
333347
interpreter_counts[min_i] -= 1
334348
interpreter_counts[max_i] += 1
335-
_, _, new_max, _ = eval_timings(interpreter_counts)
349+
_, _, new_max, _ = self._eval_timings(interpreter_counts)
336350

337351
if new_max+1.0 < current_max:
338-
# Allocate more TPUs to slow segments
352+
# 1st Priority: Allocate more TPUs to slow segments
339353
logging.info(f"Re-balancing from queue {min_i} to {max_i} (max from {current_max:.2f} to {new_max:.2f})")
340354

341355
realloc_interp = self._rem_interpreter_from(min_i)
@@ -344,24 +358,22 @@ def eval_timings(interpreter_counts):
344358
realloc_interp.start(max_i, self.fbytes_list[max_i])
345359
self.interpreters[max_i].append(realloc_interp)
346360

347-
elif min_untried_i is not None:
348-
# Swap slow segments with faster ones to see if we can run them faster.
349-
            # It might be a good way to optimize for heterogeneous hardware.
350-
logging.info(f"Re-balancing between queues {min_untried_i} and {max_i}")
361+
elif swap_i is not None:
362+
# 2nd Priority: Swap slow segments with faster ones to see if we can
363+
# run them faster. Hopefully still a good way to optimize for
364+
            # heterogeneous hardware.
365+
logging.info(f"Re-balancing between queues {swap_i} and {max_i}")
351366

352367
# Stop them
353-
new_max_i = self._rem_interpreter_from(min_untried_i)
354-
new_min_untried_i = self._rem_interpreter_from(max_i)
368+
new_max = self._rem_interpreter_from(swap_i)
369+
new_swap = self._rem_interpreter_from(max_i)
355370

356371
# Swap them
357-
new_max_i.start(max_i, self.fbytes_list[max_i])
358-
self.interpreters[max_i].append(new_max_i)
359-
360-
new_min_untried_i.start(min_untried_i, self.fbytes_list[min_untried_i])
361-
self.interpreters[min_untried_i].append(new_min_untried_i)
372+
new_max.start(max_i, self.fbytes_list[max_i])
373+
self.interpreters[max_i].append(new_max)
362374

363-
# FIXME: After we have TPUs evaluated and otherwise balanced, we could
364-
# further optimize by ensuring the slowest segment doesn't contain any slow TPUs.
375+
new_swap.start(swap_i, self.fbytes_list[swap_i])
376+
self.interpreters[swap_i].append(new_swap)
365377
else:
366378
# Return if we don't want to swap
367379
self.balance_lock.release()

0 commit comments

Comments
 (0)