-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcompile_trt_hybrid24.py
More file actions
55 lines (49 loc) · 2.25 KB
/
compile_trt_hybrid24.py
File metadata and controls
55 lines (49 loc) · 2.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
"""Compile TRT models for mouseHybrid24 predict2D pipeline."""
import sys, os
from pathlib import Path
import torch, torch_tensorrt
# Folder containing this script (symlinks resolved); its parent is one of
# the places searched for a JARVIS-HybridNet checkout.
_SCRIPT_DIR = Path(__file__).resolve().parent

# Candidate checkout locations, in priority order: the user's home
# directory, the directory above this script's folder, and a fixed path.
_JARVIS_CANDIDATES = [
    base / 'JARVIS-HybridNet'
    for base in (Path.home(), _SCRIPT_DIR.parent, Path('/home/user/src'))
]

# Take the first candidate that is an existing directory. When none of them
# exists, the first list entry is used anyway as a fallback.
JARVIS_DIR = next(
    filter(Path.is_dir, _JARVIS_CANDIDATES),
    _JARVIS_CANDIDATES[0],
)

# Put the chosen directory at the front of the module search path so the
# `jarvis` imports that follow are looked up there first.
sys.path.insert(0, str(JARVIS_DIR))
from jarvis.config.project_manager import ProjectManager
from jarvis.efficienttrack.efficienttrack import EfficientTrack
# Output folder for the compiled models, located inside the selected JARVIS
# directory. It is created together with any missing parents, and an
# already existing folder is not an error.
trt_dir = os.fspath(
    JARVIS_DIR.joinpath('projects', 'mouseHybrid24', 'trt-models', 'predict2D')
)
os.makedirs(trt_dir, exist_ok=True)

# Load the 'mouseHybrid24' project and keep a handle on its configuration.
# NOTE(review): the return value of load() is not inspected here -- confirm
# how ProjectManager reports a project that cannot be loaded.
p = ProjectManager()
p.load('mouseHybrid24')
cfg = p.cfg

# Sizes taken from the configuration; they give the height and width of the
# example inputs used when the two networks are traced and compiled.
img_size_cd = cfg.CENTERDETECT.IMAGE_SIZE
bbox_size = cfg.KEYPOINTDETECT.BOUNDING_BOX_SIZE
# ---- CenterDetect ----
print(f"Compiling CenterDetect ({img_size_cd}x{img_size_cd}) ...")

# Build the CenterDetect inference wrapper from the project config, then put
# its underlying model into eval mode and move it to the GPU.
cd_model = EfficientTrack('CenterDetectInference', cfg, 'latest').model
cd_model = cd_model.eval().cuda()

# Batch-of-16 shape: used for the tracing example and for both the optimal
# and the maximum shape of the TensorRT input specification.
cd_shape = (16, 3, img_size_cd, img_size_cd)

# Trace with a random example; the tensor is created inline so that no
# reference to it is left behind once tracing has finished.
traced = torch.jit.trace(cd_model, torch.randn(*cd_shape, device='cuda'))

# The batch dimension may range from 1 to 16; the spatial size is fixed.
cd_input = torch_tensorrt.Input(
    min_shape=(1, 3, img_size_cd, img_size_cd),
    opt_shape=cd_shape,
    max_shape=cd_shape,
    dtype=torch.float,
)

# Compile the traced module through the TorchScript frontend with float16
# as the enabled precision.
trt_cd = torch_tensorrt.compile(
    traced,
    ir='ts',
    inputs=[cd_input],
    enabled_precisions={torch.float16},
)

cd_path = os.path.join(trt_dir, 'centerDetect.pt')
torch.jit.save(trt_cd, cd_path)
print(" Saved centerDetect.pt")

# Drop every reference to the CenterDetect modules and release cached GPU
# memory before the next network is built.
del cd_model, traced, trt_cd
torch.cuda.empty_cache()
# ---- KeypointDetect ----
print(f"Compiling KeypointDetect ({bbox_size}x{bbox_size}) ...")

# Build the KeypointDetect inference wrapper from the project config, then
# put its underlying model into eval mode and move it to the GPU.
kd_model = EfficientTrack('KeypointDetectInference', cfg, 'latest').model
kd_model = kd_model.eval().cuda()

# Trace with a random batch of 8 crops of bbox_size x bbox_size.
traced = torch.jit.trace(
    kd_model,
    torch.randn(8, 3, bbox_size, bbox_size, device='cuda'),
)

# The batch dimension may range from 1 to 16, with 8 given as the optimal
# shape; the spatial size is fixed.
kd_input = torch_tensorrt.Input(
    min_shape=(1, 3, bbox_size, bbox_size),
    opt_shape=(8, 3, bbox_size, bbox_size),
    max_shape=(16, 3, bbox_size, bbox_size),
    dtype=torch.float,
)

# Compile the traced module through the TorchScript frontend with float16
# as the enabled precision.
trt_kd = torch_tensorrt.compile(
    traced,
    ir='ts',
    inputs=[kd_input],
    enabled_precisions={torch.float16},
)

kd_path = os.path.join(trt_dir, 'keypointDetect.pt')
torch.jit.save(trt_kd, kd_path)
print(" Saved keypointDetect.pt")

print("TRT compilation done!")