-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathcondinst_detector.py
More file actions
60 lines (56 loc) · 2.57 KB
/
condinst_detector.py
File metadata and controls
60 lines (56 loc) · 2.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from mmdet.models.builder import DETECTORS
from mmdet.models.detectors.single_stage import SingleStageDetector
import torch
@DETECTORS.register_module()
class CondInst(SingleStageDetector):
    """Single-stage detector wrapper for a CondInst-style head.

    During training the ground-truth masks are forwarded to ``bbox_head``
    together with the boxes and labels; at test time ``bbox_head.get_bboxes``
    is expected to return both box results and segmentation results.
    """

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        """Forward all construction arguments unchanged to the parent."""
        super().__init__(backbone, neck, bbox_head, train_cfg, test_cfg,
                         pretrained)

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None):
        """Compute the training losses for a batch.

        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): A List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                :class:`mmdet.datasets.pipelines.Collect`.
            gt_bboxes (list[Tensor]): Each item are the truth boxes for each
                image in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): Class indices corresponding to each box
            gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
                boxes can be ignored when computing the loss.
            gt_masks (list): Per-image ground-truth mask containers. Each
                item must provide ``pad(shape)`` and
                ``to_tensor(dtype=..., device=...)`` (presumably mmdet
                ``BitmapMasks`` -- confirm against the data pipeline).
                Required, despite the ``None`` default kept for signature
                compatibility.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.

        Raises:
            TypeError: If ``gt_masks`` is ``None``.
        """
        # Fail early with a clear message instead of the opaque
        # "'NoneType' object is not iterable" raised further down.
        if gt_masks is None:
            raise TypeError(
                'CondInst.forward_train requires gt_masks, but got None.')

        # The lookup intentionally starts *after* SingleStageDetector in the
        # MRO, so SingleStageDetector.forward_train itself is skipped.
        # NOTE(review): presumably this runs only the base-detector
        # bookkeeping that records 'batch_input_shape' in img_metas --
        # confirm against the installed mmdet version.
        super(SingleStageDetector, self).forward_train(img, img_metas)

        # Pad every image's masks to the common batch shape and convert them
        # to boolean tensors on the same device as the images.
        batch_input_shape = img_metas[0]['batch_input_shape']
        gt_masks = [
            gt_mask.pad(batch_input_shape).to_tensor(
                dtype=torch.bool, device=img.device)
            for gt_mask in gt_masks
        ]

        x = self.extract_feat(img)
        losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes,
                                              gt_labels, gt_masks,
                                              gt_bboxes_ignore)
        return losses

    def simple_test(self, img, img_meta, rescale=False):
        """Run inference without test-time augmentation.

        Args:
            img (Tensor): Input images.
            img_meta (list[dict]): Image meta information, passed through to
                ``bbox_head.get_bboxes``.
            rescale (bool): Passed through to ``bbox_head.get_bboxes``.

        Returns:
            list[tuple]: ``(bbox_result, segm_result)`` pairs, obtained by
            zipping the two sequences returned by ``bbox_head.get_bboxes``
            (presumably one pair per image -- confirm against the head).
        """
        x = self.extract_feat(img)
        # The head output must be a tuple: it is concatenated with the extra
        # positional arguments for get_bboxes.
        outs = self.bbox_head(x)
        bbox_inputs = outs + (img_meta, self.test_cfg, rescale)
        bbox_results, segm_results = self.bbox_head.get_bboxes(*bbox_inputs)
        return list(zip(bbox_results, segm_results))

    def aug_test(self, imgs, img_metas, rescale=False):
        """Test-time augmentation is not implemented for this detector."""
        raise NotImplementedError(
            'CondInst does not support test-time augmentation.')