File size: 4,720 Bytes
a249588
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# Copyright (c) OpenMMLab. All rights reserved.
import random
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine import MessageHub
from mmengine.dist import barrier, broadcast, get_dist_info
from mmengine.structures import PixelData
from torch import Tensor

from mmpose.registry import MODELS
from mmpose.structures import PoseDataSample


@MODELS.register_module()
class BatchSyncRandomResize(nn.Module):
    """Batch random resize which synchronizes the random size across ranks.

    Each call resizes the batch to the currently stored target size. Whenever
    ``(iter + 1) % interval == 0`` a new target size is drawn on rank 0 and
    broadcast, and it takes effect from the *next* call on. The first call
    adopts the incoming batch size as the target, so it never resizes.

    Args:
        random_size_range (tuple): The multi-scale random range during
            multi-scale training. It bounds the target image height; the
            width follows from the aspect ratio of the incoming batch.
        interval (int): The iter interval of change
            image size. Defaults to 10.
        size_divisor (int): Image size divisible factor.
            Defaults to 32.
    """

    def __init__(self,
                 random_size_range: Tuple[int, int],
                 interval: int = 10,
                 size_divisor: int = 32) -> None:
        super().__init__()
        self.rank, self.world_size = get_dist_info()
        # Target (height, width); lazily initialised by the first forward.
        self._input_size = None
        # Stored in units of ``size_divisor`` so that any integer drawn from
        # this range maps to a size divisible by ``size_divisor``.
        self._random_size_range = (round(random_size_range[0] / size_divisor),
                                   round(random_size_range[1] / size_divisor))
        self._interval = interval
        self._size_divisor = size_divisor

    def forward(self, inputs: Tensor, data_samples: List[PoseDataSample]
                ) -> Tuple[Tensor, List[PoseDataSample]]:
        """Resize a batch of images and their annotations to
        ``self._input_size``.

        Args:
            inputs (Tensor): Batched images whose last two dimensions are
                height and width.
            data_samples (List[PoseDataSample]): Data samples of the batch.
                Their meta info and labels are modified in place.

        Returns:
            Tuple[Tensor, List[PoseDataSample]]: The (possibly resized)
            images and the updated data samples.
        """
        h, w = inputs.shape[-2:]
        if self._input_size is None:
            self._input_size = (h, w)
        scale_y = self._input_size[0] / h
        scale_x = self._input_size[1] / w
        if scale_x != 1 or scale_y != 1:
            inputs = F.interpolate(
                inputs,
                size=self._input_size,
                mode='bilinear',
                align_corners=False)
            for data_sample in data_samples:
                self._rescale_data_sample(data_sample, scale_x, scale_y)

        # Pick the size used by the following iterations.
        # NOTE(review): assumes 'iter' has been written to the message hub
        # (e.g. by the training loop) -- confirm for other call sites.
        message_hub = MessageHub.get_current_instance()
        if (message_hub.get_info('iter') + 1) % self._interval == 0:
            self._input_size = self._get_random_size(
                aspect_ratio=w / h, device=inputs.device)
        return inputs, data_samples

    def _rescale_data_sample(self, data_sample: PoseDataSample,
                             scale_x: float, scale_y: float) -> None:
        """Rescale the meta info and labels of one data sample in place.

        Args:
            data_sample (PoseDataSample): The sample to update.
            scale_x (float): Scale factor along the width.
            scale_y (float): Scale factor along the height.
        """
        img_shape = (int(data_sample.img_shape[0] * scale_y),
                     int(data_sample.img_shape[1] * scale_x))
        pad_shape = (int(data_sample.pad_shape[0] * scale_y),
                     int(data_sample.pad_shape[1] * scale_x))
        data_sample.set_metainfo({
            'img_shape': img_shape,
            'pad_shape': pad_shape,
            'batch_input_shape': self._input_size
        })

        if 'gt_instance_labels' in data_sample:
            labels = data_sample.gt_instance_labels

            if 'bboxes' in labels:
                # Even indices of the last dim are scaled as x coordinates,
                # odd indices as y coordinates.
                labels.bboxes[..., 0::2] *= scale_x
                labels.bboxes[..., 1::2] *= scale_y

            if 'keypoints' in labels:
                labels.keypoints[..., 0] *= scale_x
                labels.keypoints[..., 1] *= scale_y

            if 'areas' in labels:
                labels.areas *= scale_x * scale_y

        # The mask lives in ``gt_fields`` and must follow the image size even
        # when the sample carries no instance labels; otherwise it would keep
        # the old resolution while the inputs were resized.
        if 'gt_fields' in data_sample \
                and 'heatmap_mask' in data_sample.gt_fields:

            # A leading dim is added for ``F.interpolate`` and removed again
            # afterwards.
            mask = data_sample.gt_fields.heatmap_mask.unsqueeze(0)
            # A fresh ``PixelData`` is assigned, so ``gt_fields`` afterwards
            # holds only the resized ``heatmap_mask``.
            gt_fields = PixelData()
            gt_fields.set_field(
                F.interpolate(
                    mask.float(),
                    size=self._input_size,
                    mode='bilinear',
                    align_corners=False).squeeze(0), 'heatmap_mask')

            data_sample.gt_fields = gt_fields

    def _get_random_size(self, aspect_ratio: float,
                         device: torch.device) -> Tuple[int, int]:
        """Randomly generate a shape in ``_random_size_range`` and broadcast to
        all ranks.

        Args:
            aspect_ratio (float): Width / height ratio used to derive the
                width from the randomly drawn height.
            device (torch.device): Device on which the tensor used for the
                broadcast is allocated.

        Returns:
            Tuple[int, int]: The new target size as (height, width).
        """
        # Allocate directly on the target device with defined contents; only
        # rank 0 fills in the real values before the broadcast.
        tensor = torch.zeros(2, dtype=torch.long, device=device)
        if self.rank == 0:
            size = random.randint(*self._random_size_range)
            size = (self._size_divisor * size,
                    self._size_divisor * int(aspect_ratio * size))
            tensor[0] = size[0]
            tensor[1] = size[1]
        barrier()
        broadcast(tensor, 0)
        input_size = (tensor[0].item(), tensor[1].item())
        return input_size