rawalkhirodkar's picture
Add initial commit
28c256d
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from mmcv.transforms.loading import LoadImageFromFile
from mmcv.transforms.processing import TestTimeAug
from mmdet.datasets.transforms.formatting import PackDetInputs
from mmdet.datasets.transforms.loading import LoadAnnotations
from mmdet.datasets.transforms.transforms import Pad, RandomFlip, Resize
from mmdet.models.test_time_augs.det_tta import DetTTAModel
# Settings for the test-time-augmentation model wrapper.
# ``tta_cfg`` carries the NMS options (IoU threshold 0.6) and the
# ``max_per_img`` limit of 100 that are passed to ``DetTTAModel``.
tta_model = dict(
    type=DetTTAModel,
    tta_cfg=dict(
        nms=dict(type='nms', iou_threshold=0.6),
        max_per_img=100,
    ),
)
# Target scales for test-time augmentation; ``tta_pipeline`` builds one
# ``Resize`` transform per entry.
img_scales = [
    (640, 640),
    (320, 320),
    (960, 960),
]
# Data pipeline used for test-time augmentation: load the image, then hand
# it to ``TestTimeAug``, whose ``transforms`` argument is a list of stages.
# Each inner list holds the alternative transforms for that stage.
tta_pipeline = [
    dict(type=LoadImageFromFile, backend_args=None),
    dict(
        type=TestTimeAug,
        transforms=[
            # Stage 1: one aspect-ratio-preserving resize per entry of
            # ``img_scales``.
            [
                dict(type=Resize, scale=scale, keep_ratio=True)
                for scale in img_scales
            ],
            # Stage 2: a flipped variant (prob=1.0) and an unflipped one
            # (prob=0.0).
            # NOTE: ``RandomFlip`` has to come before ``Pad``; otherwise the
            # bounding box coordinates cannot be recovered correctly after
            # flipping.
            [
                dict(type=RandomFlip, prob=1.0),
                dict(type=RandomFlip, prob=0.0),
            ],
            # Stage 3: pad to a fixed 960x960 size (equal to the largest
            # entry in ``img_scales``) with a fill value of 114 per channel.
            [
                dict(
                    type=Pad,
                    size=(960, 960),
                    pad_val=dict(img=(114, 114, 114)),
                ),
            ],
            # Stage 4: load bounding-box annotations.
            [
                dict(type=LoadAnnotations, with_bbox=True),
            ],
            # Stage 5: pack the inputs, keeping the listed meta keys.
            [
                dict(
                    type=PackDetInputs,
                    meta_keys=(
                        'img_id',
                        'img_path',
                        'ori_shape',
                        'img_shape',
                        'scale_factor',
                        'flip',
                        'flip_direction',
                    ),
                ),
            ],
        ],
    ),
]