File size: 1,731 Bytes
28c256d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from mmcv.transforms.loading import LoadImageFromFile
from mmcv.transforms.processing import TestTimeAug

from mmdet.datasets.transforms.formatting import PackDetInputs
from mmdet.datasets.transforms.loading import LoadAnnotations
from mmdet.datasets.transforms.transforms import Pad, RandomFlip, Resize
from mmdet.models.test_time_augs.det_tta import DetTTAModel

# Test-time-augmentation model settings: ``DetTTAModel`` is configured with
# NMS at an IoU threshold of 0.6 and ``max_per_img=100``.
# NOTE(review): ``max_per_img`` presumably caps the detections kept per
# image after merging -- confirm against ``DetTTAModel``.
tta_model = {
    'type': DetTTAModel,
    'tta_cfg': {
        'nms': {'type': 'nms', 'iou_threshold': 0.6},
        'max_per_img': 100,
    },
}

# Square scales, given by side length, handed to the ``Resize`` variants of
# ``tta_pipeline``. Order is preserved: 640, then 320, then 960.
img_scales = [(side, side) for side in (640, 320, 960)]

# Test-time-augmentation data pipeline: load the image, then hand it to
# ``TestTimeAug`` whose ``transforms`` entry is a list of per-stage lists of
# transform configs.
# NOTE(review): how ``TestTimeAug`` combines the entries of each stage is
# defined by mmcv -- see its documentation.
tta_pipeline = [
    {'type': LoadImageFromFile, 'backend_args': None},
    {
        'type': TestTimeAug,
        'transforms': [
            # One aspect-ratio-preserving ``Resize`` per entry of
            # ``img_scales``.
            [
                {'type': Resize, 'scale': scale, 'keep_ratio': True}
                for scale in img_scales
            ],
            # Two flip variants: always flip (prob=1.) and never flip
            # (prob=0.).
            # ``RandomFlip`` must be placed before ``Pad``, otherwise
            # bounding box coordinates after flipping cannot be
            # recovered correctly.
            [
                {'type': RandomFlip, 'prob': flip_prob}
                for flip_prob in (1., 0.)
            ],
            # Pad to 960x960 (the largest entry of ``img_scales``) using
            # (114, 114, 114) as the image pad value.
            [
                {
                    'type': Pad,
                    'size': (960, 960),
                    'pad_val': {'img': (114, 114, 114)},
                },
            ],
            [{'type': LoadAnnotations, 'with_bbox': True}],
            # Pack the results, keeping only the listed meta keys.
            [
                {
                    'type': PackDetInputs,
                    'meta_keys': (
                        'img_id', 'img_path', 'ori_shape', 'img_shape',
                        'scale_factor', 'flip', 'flip_direction',
                    ),
                },
            ],
        ],
    },
]