Skip to content

Commit 1ef657a

Browse files
committed
[refactor] YOLO 객체 탐지 코드 gpu 강제 코드 추가
1 parent f8e7231 commit 1ef657a

1 file changed

Lines changed: 32 additions & 5 deletions

File tree

everTale/app/service/yolo_service.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,35 @@
99

1010
YOLO_MODEL_PATH = os.environ["YOLO_MODEL_PATH"]
1111

12+
from ultralytics import YOLO
13+
import os, torch
14+
15+
def _resolve_yolo_path() -> str:
16+
path = os.getenv("YOLO_MODEL_PATH", "/models/my_yolo_model.pt")
17+
if not os.path.exists(path):
18+
raise FileNotFoundError(f"YOLO model not found at: {path}")
19+
return path
20+
21+
def _require_gpu_for_yolo(stage: str = "YOLO load"):
22+
if torch.cuda.is_available():
23+
return 0 # device index for CUDA
24+
# MPS는 Ultralytics 지원이 제한적이므로 필요한 경우만 허용
25+
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
26+
return "mps"
27+
raise RuntimeError(f"[ERROR] No GPU backend during {stage}. CPU is not allowed for YOLO.")
28+
1229
def load_model() -> YOLO:
    """Load the YOLO model onto a mandatory GPU device and warm it up.

    Returns:
        The loaded ``YOLO`` model, ready for inference on ``device``.

    Raises:
        FileNotFoundError: if the weights file is missing
            (propagated from ``_resolve_yolo_path``).
        RuntimeError: if no GPU backend is available
            (from ``_require_gpu_for_yolo``), or if loading/warm-up fails.
    """
    path = _resolve_yolo_path()
    device = _require_gpu_for_yolo("YOLO load")
    try:
        model = YOLO(path)
        # Optional warm-up: one tiny dummy inference forces weight and
        # device-memory initialization up front instead of on first request.
        model.predict(
            source=np.zeros((64, 64, 3), dtype=np.uint8),
            device=device,
            imgsz=64,
            verbose=False,
        )
        print(f"[INFO] YOLO loaded on device={device} from {path}")
        return model
    except Exception as e:
        # Chain the original exception (`from e`) so the root cause stays
        # visible in the traceback instead of being swallowed.
        raise RuntimeError(f"Failed to load YOLO model at {path}: {e}") from e
2041

2142
def _url_to_bgr(url: str) -> np.ndarray:
2243
resp = requests.get(url, timeout=10)
@@ -35,13 +56,19 @@ def detect_object(image_paths: List[str]) -> Dict[str, Any]:
3556
탐지 후보가 전혀 없으면 {"index": None, "url": None, "detection": None}
3657
"""
3758
model = load_model()
59+
device = 0 if torch.cuda.is_available() else "mps" # 위와 일치
3860
urls = image_paths[:8]
3961
candidates: List[Dict[str, Any]] = []
4062

4163
for idx, url in enumerate(urls):
4264
try:
4365
img = _url_to_bgr(url)
44-
results = model.predict(source=img, verbose=False)
66+
results = model.predict(
67+
source=img,
68+
device=device,
69+
half=torch.cuda.is_available(),
70+
verbose=False
71+
)
4572
if not results or results[0].boxes is None or results[0].boxes.shape[0] == 0:
4673
continue
4774

0 commit comments

Comments
 (0)