Skip to content

Commit 1b16ada

Browse files
authored
Addle TPU swap logic
1 parent fc93511 commit 1b16ada

1 file changed

Lines changed: 12 additions & 7 deletions

File tree

src/modules/ObjectDetectionCoral/tpu_runner.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -272,12 +272,13 @@ def enqueue(self, in_tensor, out_q: queue.Queue):
272272
def _eval_timings(interpreter_counts):
273273
# How much time are we allocating for each segment
274274
time_alloc = []
275+
VALID_CNT_THRESH = 50
275276

276277
for seg_i in range(len(self.interpreters)):
277278
# Find average runtime for this segment
278279
avg_times = []
279280
for interpreters in self.interpreters:
280-
avg_times += [i.timings[seg_i] / i.exec_count[seg_i] for i in interpreters if i.exec_count[seg_i] != 0]
281+
avg_times += [i.timings[seg_i] / i.exec_count[seg_i] for i in interpreters if i.exec_count[seg_i] > VALID_CNT_THRESH]
281282

282283
if avg_times:
283284
avg_time = sum(avg_times) / len(avg_times)
@@ -293,7 +294,7 @@ def _eval_timings(interpreter_counts):
293294

294295
min_gt1_t = float('inf')
295296
min_gt1_i = -1
296-
max_t = 0
297+
max_t = 0.0
297298
max_i = -1
298299

299300
# Find segments that maybe should swap
@@ -309,6 +310,10 @@ def _eval_timings(interpreter_counts):
309310
min_gt1_t = t
310311
min_gt1_i = i
311312

313+
# Only eval swapping max segment if we have many samples
314+
if VALID_CNT_THRESH > sum([i.exec_count[max_i] for i in self.interpreters[max_i]]):
315+
return min_gt1_i, max_i, max(time_alloc), None
316+
312317
# See if we can do better than the current max timing with swapping
313318
swap_i = None
314319
for interp_i, interpreters in enumerate(self.interpreters):
@@ -319,17 +324,17 @@ def _eval_timings(interpreter_counts):
319324
# Test all TPUs in this segment
320325
for i in interpreters:
321326
# Only calc valid time after a few runs
322-
new_max_t = 0
323-
if i.exec_count[max_i] > 10:
327+
new_max_t = 0.0
328+
if i.exec_count[max_i] > VALID_CNT_THRESH:
324329
new_max_t = i.timings[max_i] / i.exec_count[max_i]
325-
new_swap_t = 0
326-
if i.exec_count[interp_i] > 10:
330+
new_swap_t = 0.0
331+
if i.exec_count[interp_i] > VALID_CNT_THRESH:
327332
new_swap_t = i.timings[interp_i] / i.exec_count[interp_i]
328333

329334
# If it hasn't yet been tried for this segment or
330335
# If it has already found to be faster on this segment
331336
# and we aren't making the other segment the new worst.
332-
if i.exec_count[max_i] < 10 or (max_t > new_max_t and max_t > new_swap_t):
337+
if i.exec_count[max_i] < VALID_CNT_THRESH or (max_t > new_max_t and max_t > new_swap_t):
333338
swap_i = interp_i
334339
break
335340

0 commit comments

Comments
 (0)