# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple, Union

import torch
from torch import Tensor, nn

from mmpose.models.utils.tta import flip_visibility
from mmpose.registry import MODELS
from mmpose.utils.tensor_utils import to_numpy
from mmpose.utils.typing import (ConfigType, InstanceList, OptConfigType,
                                 OptSampleList, Predictions)
from ..base_head import BaseHead


@MODELS.register_module()
class VisPredictHead(BaseHead):
    """VisPredictHead must be used together with other heads. It can predict
    keypoints coordinates of and their visibility simultaneously. In the
    current version, it only supports top-down approaches.

    Args:
        pose_cfg (Config): Config to construct the keypoint prediction head.
        loss (Config): Config for the visibility loss. Defaults to
            :class:`BCELoss`. If ``use_sigmoid=True`` is set in this
            config, a :class:`nn.Sigmoid` layer is appended to the
            visibility branch so that it outputs probabilities rather
            than logits.
        init_cfg (Config, optional): Config to control the initialization.
            See :attr:`default_init_cfg` for default settings.
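
    Example::

        >>> # A minimal usage sketch. The config values below are
        >>> # illustrative placeholders, not taken from a released model.
        >>> import torch
        >>> head = VisPredictHead(
        ...     pose_cfg=dict(
        ...         type='HeatmapHead',
        ...         in_channels=32,
        ...         out_channels=17,
        ...         loss=dict(type='KeypointMSELoss',
        ...                   use_target_weight=True)))
        >>> feats = (torch.rand(2, 32, 8, 6),)
        >>> heatmaps, vis = head.forward(feats)
        >>> vis.shape
        torch.Size([2, 17])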
    """

    def __init__(self,
                 pose_cfg: ConfigType,
                 loss: ConfigType = dict(
                     type='BCELoss', use_target_weight=False,
                     use_sigmoid=True),
                 init_cfg: OptConfigType = None):

        if init_cfg is None:
            init_cfg = self.default_init_cfg

        super().__init__(init_cfg)

        self.in_channels = pose_cfg['in_channels']
        if pose_cfg.get('num_joints', None) is not None:
            self.out_channels = pose_cfg['num_joints']
        elif pose_cfg.get('out_channels', None) is not None:
            self.out_channels = pose_cfg['out_channels']
        else:
            raise ValueError('VisPredictHead requires \'num_joints\' or'
                             ' \'out_channels\' in the pose_cfg.')

        self.loss_module = MODELS.build(loss)

        self.pose_head = MODELS.build(pose_cfg)
        self.pose_cfg = pose_cfg

        self.use_sigmoid = loss.get('use_sigmoid', False)

        # visibility branch: global average pooling over the last feature
        # map, followed by a single linear layer that produces one score
        # per keypoint
        modules = [
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(self.in_channels, self.out_channels)
        ]
        if self.use_sigmoid:
            modules.append(nn.Sigmoid())

        self.vis_head = nn.Sequential(*modules)

    def vis_forward(self, feats: Tuple[Tensor]):
        """Forward the vis_head. The input is multi scale feature maps and the
        output is coordinates visibility.

        Args:
            feats (Tuple[Tensor]): Multi scale feature maps.

        Returns:
            Tensor: output coordinates visibility.
        """
        x = feats[-1]
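        # pad trailing dims (out-of-place, to avoid mutating the shared
        # feature tensor) until AdaptiveAvgPool2d sees a 4-D (N, C, H, W)
        # input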
        while len(x.shape) < 4:
            x = x.unsqueeze(-1)
        x = self.vis_head(x)
        return x.reshape(-1, self.out_channels)

    def forward(self, feats: Tuple[Tensor]):
        """Forward the network. The input is multi scale feature maps and the
        output is coordinates and coordinates visibility.

        Args:
            feats (Tuple[Tensor]): Multi scale feature maps.

        Returns:
            Tuple[Tensor]: output coordinates and coordinates visibility.
        """
        x_pose = self.pose_head.forward(feats)
        x_vis = self.vis_forward(feats)

        return x_pose, x_vis

    def integrate(self, batch_vis: Tensor,
                  pose_preds: Union[Tuple, Predictions]) -> Tuple:
        """Add the keypoint visibility prediction to the pose prediction.

        The visibility scores are written to the ``keypoints_visible``
        field of each predicted instance.
        """
        if isinstance(pose_preds, tuple):
            pose_pred_instances, pose_pred_fields = pose_preds
        else:
            pose_pred_instances = pose_preds
            pose_pred_fields = None

        # split the batched tensor into a list of per-sample arrays
        batch_vis_np = to_numpy(batch_vis, unzip=True)

        assert len(pose_pred_instances) == len(batch_vis_np)
        for index, _ in enumerate(pose_pred_instances):
            pose_pred_instances[index].keypoints_visible = batch_vis_np[index]

        return pose_pred_instances, pose_pred_fields

    def predict(self,
                feats: Tuple[Tensor],
                batch_data_samples: OptSampleList,
                test_cfg: ConfigType = {}) -> Predictions:
        """Predict results from features.

        Args:
            feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage
                features (or multiple multi-stage features in TTA)
            batch_data_samples (List[:obj:`PoseDataSample`]): The batch
                data samples
            test_cfg (dict): The runtime config for testing process. Defaults
                to {}

        Returns:
            Union[InstanceList, Tuple[InstanceList, PixelDataList]]: If
            the pose head's ``test_cfg['output_heatmaps']`` is ``True``,
            return both pose and heatmap predictions; otherwise return
            only the pose prediction.

            The pose prediction is a list of ``InstanceData``, each contains
            the following fields:

                - keypoints (np.ndarray): predicted keypoint coordinates in
                    shape (num_instances, K, D) where K is the keypoint number
                    and D is the keypoint dimension
                - keypoint_scores (np.ndarray): predicted keypoint scores in
                    shape (num_instances, K)
                - keypoints_visible (np.ndarray): predicted keypoint
                    visibility in shape (num_instances, K)

            The heatmap prediction is a list of ``PixelData``, each contains
            the following fields:

                - heatmaps (Tensor): The predicted heatmaps in shape (K, h, w)
        """
        if test_cfg.get('flip_test', False):
            # TTA: flip test -> feats = [orig, flipped]
            assert isinstance(feats, list) and len(feats) == 2
            flip_indices = batch_data_samples[0].metainfo['flip_indices']
            _feats, _feats_flip = feats

            _batch_vis = self.vis_forward(_feats)
            _batch_vis_flip = flip_visibility(
                self.vis_forward(_feats_flip), flip_indices=flip_indices)
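            # average the two views; flip_visibility has already restored
            # the original keypoint order for the flipped prediction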
            batch_vis = (_batch_vis + _batch_vis_flip) * 0.5
        else:
            batch_vis = self.vis_forward(feats)  # (B, K)

        batch_vis.unsqueeze_(dim=1)  # (B, 1, K): add the instance dim

        if not self.use_sigmoid:
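            # without a sigmoid layer in the head, the outputs are raw
            # logits; convert them to probabilities before export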
            batch_vis = torch.sigmoid(batch_vis)

        batch_pose = self.pose_head.predict(feats, batch_data_samples,
                                            test_cfg)

        return self.integrate(batch_vis, batch_pose)

    @torch.no_grad()
    def vis_accuracy(self, vis_pred_outputs, vis_labels, vis_weights=None):
        """Calculate visibility prediction accuracy."""
        if not self.use_sigmoid:
            vis_pred_outputs = torch.sigmoid(vis_pred_outputs)
        threshold = 0.5
        predictions = (vis_pred_outputs >= threshold).float()
        correct = (predictions == vis_labels).float()
        if vis_weights is not None:
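            # weighted mean: keypoints with zero weight do not contribute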
            accuracy = (correct * vis_weights).sum(dim=1) / (
                vis_weights.sum(dim=1) + 1e-6)
        else:
            accuracy = correct.mean(dim=1)
        return accuracy.mean()

    def loss(self,
             feats: Tuple[Tensor],
             batch_data_samples: OptSampleList,
             train_cfg: OptConfigType = {}) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            feats (Tuple[Tensor]): The multi-stage features
            batch_data_samples (List[:obj:`PoseDataSample`]): The batch
                data samples
            train_cfg (dict): The runtime config for training process.
                Defaults to {}

        Returns:
            dict: A dictionary of losses.
        """
        vis_pred_outputs = self.vis_forward(feats)
        vis_labels = []
        vis_weights = [] if self.loss_module.use_target_weight else None
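        # keypoint_weights doubles as the visibility target here: non-zero
        # for annotated (visible) keypoints, zero otherwise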
        for d in batch_data_samples:
            vis_label = d.gt_instance_labels.keypoint_weights.float()
            vis_labels.append(vis_label)
            if vis_weights is not None:
                vis_weights.append(
                    getattr(d.gt_instance_labels, 'keypoints_visible_weights',
                            vis_label.new_ones(vis_label.shape)))
        vis_labels = torch.cat(vis_labels)
        vis_weights = torch.cat(vis_weights) if vis_weights else None

        # calculate vis losses
        losses = dict()
        loss_vis = self.loss_module(vis_pred_outputs, vis_labels, vis_weights)

        losses.update(loss_vis=loss_vis)

        # calculate vis accuracy
        acc_vis = self.vis_accuracy(vis_pred_outputs, vis_labels, vis_weights)
        losses.update(acc_vis=acc_vis)

        # calculate keypoints losses
        loss_kpt = self.pose_head.loss(feats, batch_data_samples)
        losses.update(loss_kpt)

        return losses

    @property
    def default_init_cfg(self):
        init_cfg = [dict(type='Normal', layer=['Linear'], std=0.01, bias=0)]
        return init_cfg