|
import torch |
|
from mmcv.ops import batched_nms |
|
|
|
from mmdet.core import (bbox2result, bbox2roi, bbox_mapping, merge_aug_bboxes, |
|
multiclass_nms) |
|
from mmdet.models.roi_heads.standard_roi_head import StandardRoIHead |
|
from ..builder import HEADS |
|
|
|
|
|
@HEADS.register_module()
class TridentRoIHead(StandardRoIHead):
    """Trident roi head.

    Args:
        num_branch (int): Number of branches in TridentNet.
        test_branch_idx (int): In inference, all ``num_branch`` branches
            will be used if `test_branch_idx==-1`, otherwise only branch
            with index `test_branch_idx` will be used.
        **kwargs: Forwarded unchanged to the parent class initializer.
    """

    def __init__(self, num_branch, test_branch_idx, **kwargs):
        # Branch settings are stored before the parent initializer runs.
        self.num_branch = num_branch
        self.test_branch_idx = test_branch_idx
        super(TridentRoIHead, self).__init__(**kwargs)

    def merge_trident_bboxes(self, trident_det_bboxes, trident_det_labels):
        """Merge bbox predictions of each branch.

        The concatenated detections are passed through ``batched_nms`` using
        the labels as the grouping index and ``self.test_cfg['nms']`` as the
        NMS config. The result is then truncated to
        ``self.test_cfg['max_per_img']`` entries when that value is positive.

        Args:
            trident_det_bboxes (Tensor): 2-D tensor of detections. Columns
                0-3 are given to NMS as boxes and column 4 as scores.
            trident_det_labels (Tensor): Labels aligned with the rows of
                ``trident_det_bboxes``.

        Returns:
            tuple: ``(det_bboxes, det_labels)``. For empty input these are
            zero-filled tensors of shape ``(0, 5)`` and ``(0, )`` (the
            latter with dtype ``torch.long``). Otherwise ``det_bboxes`` is
            the first output of ``batched_nms`` and ``det_labels`` holds the
            labels selected by its ``keep`` indices.
        """
        # Nothing to suppress: return correctly shaped empty results.
        if trident_det_bboxes.numel() == 0:
            det_bboxes = trident_det_bboxes.new_zeros((0, 5))
            det_labels = trident_det_bboxes.new_zeros((0, ), dtype=torch.long)
            return det_bboxes, det_labels

        det_bboxes, keep = batched_nms(
            trident_det_bboxes[:, :4],
            trident_det_bboxes[:, 4].contiguous(),
            trident_det_labels,
            self.test_cfg['nms'])
        det_labels = trident_det_labels[keep]

        # A non-positive limit means "keep everything". Truncation keeps the
        # leading entries in the order produced by batched_nms.
        max_per_img = self.test_cfg['max_per_img']
        if max_per_img > 0:
            det_labels = det_labels[:max_per_img]
            det_bboxes = det_bboxes[:max_per_img]

        return det_bboxes, det_labels

    def simple_test(self,
                    x,
                    proposal_list,
                    img_metas,
                    proposals=None,
                    rescale=False):
        """Test without augmentation as follows:

        1. Compute prediction bbox and label per branch.
        2. Merge predictions of each branch according to scores of
           bboxes, i.e., bboxes with higher score are kept to give
           top-k prediction.

        Args:
            x: Features forwarded to ``self.simple_test_bboxes``.
            proposal_list: Proposals forwarded to
                ``self.simple_test_bboxes``.
            img_metas (list): Image meta entries. Its length divided by the
                number of active branches gives the number of merged
                results.
            proposals: Unused by this method.
            rescale (bool): Forwarded to ``self.simple_test_bboxes``.

        Returns:
            list: One ``bbox2result`` output per merged group of branches.
        """
        assert self.with_bbox, 'Bbox head must be implemented.'
        det_bboxes_list, det_labels_list = self.simple_test_bboxes(
            x, img_metas, proposal_list, self.test_cfg, rescale=rescale)
        num_branch = self.num_branch if self.test_branch_idx == -1 else 1

        # Replace row-less results with an explicit (0, 5) tensor --
        # presumably so the concatenation below sees a consistent column
        # count; TODO confirm against simple_test_bboxes.
        for idx, bboxes in enumerate(det_bboxes_list):
            if bboxes.shape[0] == 0:
                det_bboxes_list[idx] = bboxes.new_empty((0, 5))

        # Every run of `num_branch` consecutive entries is concatenated and
        # merged into a single result.
        det_bboxes, det_labels = [], []
        for group_idx in range(len(img_metas) // num_branch):
            group = slice(group_idx * num_branch,
                          (group_idx + 1) * num_branch)
            merged_bboxes, merged_labels = self.merge_trident_bboxes(
                torch.cat(det_bboxes_list[group]),
                torch.cat(det_labels_list[group]))
            det_bboxes.append(merged_bboxes)
            det_labels.append(merged_labels)

        bbox_results = [
            bbox2result(bboxes, labels, self.bbox_head.num_classes)
            for bboxes, labels in zip(det_bboxes, det_labels)
        ]
        return bbox_results

    def aug_test_bboxes(self, feats, img_metas, proposal_list, rcnn_test_cfg):
        """Test det bboxes with test time augmentation.

        Args:
            feats: Iterable of features, one item per augmentation, iterated
                in lockstep with ``img_metas``.
            img_metas: Iterable of meta sequences. Only element ``0`` of each
                is read, for the keys ``'img_shape'``, ``'scale_factor'``,
                ``'flip'`` and ``'flip_direction'``.
            proposal_list (list): Proposals. Only its length and its first
                element are used here.
            rcnn_test_cfg: Test config providing ``score_thr``, ``nms`` and
                ``max_per_img``; also passed to ``merge_aug_bboxes``.

        Returns:
            tuple: ``(det_bboxes, det_labels)`` as returned by
            ``multiclass_nms`` on the merged augmented predictions.
        """
        aug_bboxes = []
        aug_scores = []
        for x, img_meta in zip(feats, img_metas):
            meta = img_meta[0]
            img_shape = meta['img_shape']
            scale_factor = meta['scale_factor']
            flip = meta['flip']
            flip_direction = meta['flip_direction']

            trident_bboxes, trident_scores = [], []
            # NOTE(review): the loop index is never used and
            # `proposal_list[0]` is read on every iteration, so each pass
            # repeats the same inputs. This looks like it was meant to
            # select per-branch proposals -- confirm intent before changing;
            # behaviour is kept as-is here.
            for _branch_idx in range(len(proposal_list)):
                proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
                                         scale_factor, flip, flip_direction)
                rois = bbox2roi([proposals])
                bbox_results = self._bbox_forward(x, rois)
                bboxes, scores = self.bbox_head.get_bboxes(
                    rois,
                    bbox_results['cls_score'],
                    bbox_results['bbox_pred'],
                    img_shape,
                    scale_factor,
                    rescale=False,
                    cfg=None)
                trident_bboxes.append(bboxes)
                trident_scores.append(scores)

            # Stack the per-iteration predictions of this augmentation.
            aug_bboxes.append(torch.cat(trident_bboxes, 0))
            aug_scores.append(torch.cat(trident_scores, 0))

        merged_bboxes, merged_scores = merge_aug_bboxes(
            aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
        det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
                                                rcnn_test_cfg.score_thr,
                                                rcnn_test_cfg.nms,
                                                rcnn_test_cfg.max_per_img)
        return det_bboxes, det_labels
|
|