File size: 13,871 Bytes
a249588
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
# Copyright (c) OpenMMLab. All rights reserved.
from itertools import zip_longest
from typing import Tuple, Union

import torch
from torch import Tensor

from mmpose.models.utils import check_and_update_config
from mmpose.models.utils.tta import flip_coordinates
from mmpose.registry import MODELS
from mmpose.utils.typing import (ConfigType, InstanceList, OptConfigType,
                                 Optional, OptMultiConfig, OptSampleList,
                                 PixelDataList, SampleList)
from .base import BasePoseEstimator


@MODELS.register_module()
class PoseLifter(BasePoseEstimator):
    """Base class for pose lifter.

    Args:
        backbone (dict): The backbone config
        neck (dict, optional): The neck config. Defaults to ``None``
        head (dict, optional): The head config. Defaults to ``None``
        traj_backbone (dict, optional): The backbone config for trajectory
            model. Defaults to ``None``
        traj_neck (dict, optional): The neck config for trajectory model.
            Defaults to ``None``
        traj_head (dict, optional): The head config for trajectory model.
            Defaults to ``None``
        semi_loss (dict, optional): The semi-supervised loss config.
            Defaults to ``None``
        train_cfg (dict, optional): The runtime config for training process.
            Defaults to ``None``
        test_cfg (dict, optional): The runtime config for testing process.
            Defaults to ``None``
        data_preprocessor (dict, optional): The data preprocessing config to
            build the instance of :class:`BaseDataPreprocessor`. Defaults to
            ``None``
        init_cfg (dict, optional): The config to control the initialization.
            Defaults to ``None``
        metainfo (dict): Meta information for dataset, such as keypoints
            definition and properties. If set, the metainfo of the input data
            batch will be overridden. For more details, please refer to
            https://mmpose.readthedocs.io/en/latest/user_guides/
            prepare_datasets.html#create-a-custom-dataset-info-
            config-file-for-the-dataset. Defaults to ``None``
    """

    def __init__(self,
                 backbone: ConfigType,
                 neck: OptConfigType = None,
                 head: OptConfigType = None,
                 traj_backbone: OptConfigType = None,
                 traj_neck: OptConfigType = None,
                 traj_head: OptConfigType = None,
                 semi_loss: OptConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None,
                 metainfo: Optional[dict] = None):
        """Build the pose branch, the optional trajectory branch and the
        optional semi-supervised loss from their configs."""
        super().__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg,
            metainfo=metainfo)

        # --- trajectory branch ---
        # The branch only exists when a trajectory head is configured. If no
        # dedicated trajectory backbone is given, the pose features are
        # reused for the trajectory head.
        has_traj_head = traj_head is not None
        self.share_backbone = has_traj_head and traj_backbone is None

        if has_traj_head:
            if not self.share_backbone:
                self.traj_backbone = MODELS.build(traj_backbone)

            # PR #2108 and #2126 changed the neck/head interface. This helper
            # detects outdated configs, updates them and reports the changes
            # it made.
            traj_neck, traj_head = check_and_update_config(
                traj_neck, traj_head)

            if traj_neck is not None:
                self.traj_neck = MODELS.build(traj_neck)

            self.traj_head = MODELS.build(traj_head)

        # --- semi-supervised loss ---
        self.semi_supervised = semi_loss is not None
        if self.semi_supervised:
            # at least one of the two heads must be configured
            assert head or traj_head
            self.semi_loss = MODELS.build(semi_loss)

    @property
    def with_traj_backbone(self):
        """bool: ``True`` if a non-``None`` ``traj_backbone`` attribute is
        set on the pose lifter."""
        return getattr(self, 'traj_backbone', None) is not None

    @property
    def with_traj_neck(self):
        """bool: ``True`` if a non-``None`` ``traj_neck`` attribute is set on
        the pose lifter."""
        return getattr(self, 'traj_neck', None) is not None

    @property
    def with_traj(self):
        """bool: ``True`` if the pose lifter has a ``traj_head`` attribute
        (regardless of its value)."""
        missing = object()
        return getattr(self, 'traj_head', missing) is not missing

    @property
    def causal(self):
        """bool: Whether the pose lifter is causal.

        The value is read from the ``causal`` attribute of the backbone.

        Raises:
            AttributeError: If the backbone has no ``causal`` attribute.
        """
        if hasattr(self.backbone, 'causal'):
            return self.backbone.causal

        # The trailing space after "if" keeps the concatenated literals from
        # running together as "ifit".
        raise AttributeError('A PoseLifter\'s backbone should have '
                             'the bool attribute "causal" to indicate if '
                             'it performs causal inference.')

    def extract_feat(self, inputs: Tensor) -> Tuple[Tensor]:
        """Run the backbone(s) and neck(s) on the inputs.

        Args:
            inputs (Tensor): Image tensor with shape (N, K, C, T).

        Returns:
            tuple[Tensor]: Multi-level features that may have various
            resolutions. When the model has a trajectory head, a pair
            ``(pose_feats, traj_feats)`` is returned instead of the pose
            features alone.
        """
        # pose branch: backbone followed by the optional neck
        pose_feats = self.backbone(inputs)
        if self.with_neck:
            pose_feats = self.neck(pose_feats)

        if not self.with_traj:
            return pose_feats

        # trajectory branch: reuse the pose features when the backbone is
        # shared, otherwise run the dedicated trajectory backbone
        traj_feats = (
            pose_feats if self.share_backbone else self.traj_backbone(inputs))
        if self.with_traj_neck:
            traj_feats = self.traj_neck(traj_feats)

        return pose_feats, traj_feats

    def _forward(self,
                 inputs: Tensor,
                 data_samples: OptSampleList = None
                 ) -> Union[Tensor, Tuple[Tensor]]:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.

        Args:
            inputs (Tensor): Inputs with shape (N, K, C, T).
            data_samples (List[:obj:`PoseDataSample`], optional): Not used
                by this method. Defaults to ``None``

        Returns:
            Union[Tensor | Tuple[Tensor]]: forward output of the network.
            With a trajectory head this is the pair
            ``(pose_output, traj_output)``.
        """
        feats = self.extract_feat(inputs)

        if not self.with_traj:
            # pose branch only; features pass through when there is no head
            return self.head.forward(feats) if self.with_head else feats

        # pose branch followed by trajectory branch
        pose_out, traj_out = feats
        if self.with_head:
            pose_out = self.head.forward(pose_out)
        traj_out = self.traj_head.forward(traj_out)

        return pose_out, traj_out

    def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
        """Compute the training losses for a batch.

        Args:
            inputs (Tensor): Inputs with shape (N, K, C, T).
            data_samples (List[:obj:`PoseDataSample`]): The batch
                data samples.

        Returns:
            dict: A dictionary of losses, merged from the trajectory head,
            the pose head and (under the key ``semi_loss``) the
            semi-supervised loss, for whichever of them are configured.
        """
        feats = self.extract_feat(inputs)
        loss_dict = dict()

        if self.with_traj:
            pose_feats, traj_feats = feats
            # trajectory model loss
            traj_losses = self.traj_head.loss(
                traj_feats, data_samples, train_cfg=self.train_cfg)
            loss_dict.update(traj_losses)
        else:
            pose_feats = feats

        if self.with_head:
            # pose model loss
            pose_losses = self.head.loss(
                pose_feats, data_samples, train_cfg=self.train_cfg)
            loss_dict.update(pose_losses)

        # TODO: support semi-supervised learning
        if self.semi_supervised:
            loss_dict['semi_loss'] = self.semi_loss(inputs, data_samples)

        return loss_dict

    def predict(self, inputs: Tensor, data_samples: SampleList) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Note:
            - batch_size: B
            - num_input_keypoints: K
            - input_keypoint_dim: C
            - input_sequence_len: T

        Args:
            inputs (Tensor): Inputs with shape like (B, K, C, T).
            data_samples (List[:obj:`PoseDataSample`]): The batch
                data samples

        Returns:
            list[:obj:`PoseDataSample`]: The pose estimation results of the
            input images. The return value is `PoseDataSample` instances with
            ``pred_instances`` and ``pred_fields``(optional) field , and
            ``pred_instances`` usually contains the following keys:

                - keypoints (Tensor): predicted keypoint coordinates in shape
                    (num_instances, K, D) where K is the keypoint number and D
                    is the keypoint dimension
                - keypoint_scores (Tensor): predicted keypoint scores in shape
                    (num_instances, K)
        """
        assert self.with_head, (
            'The model must have head to perform prediction.')

        traj_x = None
        if self.test_cfg.get('flip_test', False):
            flip_indices = data_samples[0].metainfo['flip_indices']
            shift_coords = self.test_cfg.get('shift_coords', True)
            _feats = self.extract_feat(inputs)
            _feats_flip = self.extract_feat(
                torch.stack([
                    flip_coordinates(
                        _input,
                        flip_indices=flip_indices,
                        shift_coords=shift_coords,
                        input_size=(1, 1)) for _input in inputs
                ],
                            dim=0))

            if self.with_traj:
                # ``extract_feat`` returns ``(pose_feats, traj_feats)`` for
                # each pass. Regroup them so that every head receives its
                # own ``[original, flipped]`` pair, the same form the pose
                # head gets when there is no trajectory branch. Unpacking
                # ``[_feats, _feats_flip]`` directly would hand the pose head
                # the un-flipped (pose, traj) tuple and the trajectory head
                # the flipped one.
                x = [_feats[0], _feats_flip[0]]
                traj_x = [_feats[1], _feats_flip[1]]
            else:
                x = [_feats, _feats_flip]
        else:
            feats = self.extract_feat(inputs)
            if self.with_traj:
                x, traj_x = feats
            else:
                x = feats

        pose_preds, batch_pred_instances, batch_pred_fields = None, None, None
        traj_preds, batch_traj_instances, batch_traj_fields = None, None, None

        if self.with_traj:
            traj_preds = self.traj_head.predict(
                traj_x, data_samples, test_cfg=self.test_cfg)

        if self.with_head:
            pose_preds = self.head.predict(
                x, data_samples, test_cfg=self.test_cfg)

        # A head may return either the instances alone or an
        # ``(instances, fields)`` tuple.
        if isinstance(pose_preds, tuple):
            batch_pred_instances, batch_pred_fields = pose_preds
        else:
            batch_pred_instances = pose_preds

        if isinstance(traj_preds, tuple):
            batch_traj_instances, batch_traj_fields = traj_preds
        else:
            batch_traj_instances = traj_preds

        results = self.add_pred_to_datasample(batch_pred_instances,
                                              batch_pred_fields,
                                              batch_traj_instances,
                                              batch_traj_fields, data_samples)

        return results

    def add_pred_to_datasample(
        self,
        batch_pred_instances: InstanceList,
        batch_pred_fields: Optional[PixelDataList],
        batch_traj_instances: InstanceList,
        batch_traj_fields: Optional[PixelDataList],
        batch_data_samples: SampleList,
    ) -> SampleList:
        """Add predictions into data samples.

        If ``output_keypoint_indices`` is set in ``test_cfg``, every
        ``keypoint*`` field of the predicted instances is restricted to those
        indices, and so is every predicted field whose first dimension equals
        the original number of keypoints.

        Args:
            batch_pred_instances (List[InstanceData]): The predicted instances
                of the input data batch
            batch_pred_fields (List[PixelData], optional): The predicted
                fields (e.g. heatmaps) of the input batch
            batch_traj_instances (List[InstanceData]): The predicted
                trajectory instances of the input data batch. Accepted for
                interface compatibility; currently not written to the data
                samples
            batch_traj_fields (List[PixelData], optional): The predicted
                trajectory fields of the input batch. Accepted for interface
                compatibility; currently not written to the data samples
            batch_data_samples (List[PoseDataSample]): The input data batch

        Returns:
            List[PoseDataSample]: A list of data samples where the predictions
            are stored in the ``pred_instances`` field (and, when available,
            the ``pred_fields`` field) of each data sample.
        """
        assert len(batch_pred_instances) == len(batch_data_samples)

        # Normalise the optional pose fields on their own. The trajectory
        # arguments are never read below, so they are left out of the
        # iteration: a ``None`` there must not break the zip, and pose
        # fields must not depend on the presence of trajectory fields.
        if batch_pred_fields is None:
            batch_pred_fields = []

        output_keypoint_indices = self.test_cfg.get('output_keypoint_indices',
                                                    None)

        # ``zip_longest`` pads the (possibly empty) fields list with ``None``
        for pred_instances, pred_fields, data_sample in zip_longest(
                batch_pred_instances, batch_pred_fields, batch_data_samples):

            if output_keypoint_indices is not None:
                # select output keypoints with given indices; remember the
                # original keypoint count before the fields are sliced
                num_keypoints = pred_instances.keypoints.shape[1]
                for key, value in pred_instances.all_items():
                    if key.startswith('keypoint'):
                        pred_instances.set_field(
                            value[:, output_keypoint_indices], key)

            data_sample.pred_instances = pred_instances

            if pred_fields is not None:
                if output_keypoint_indices is not None:
                    # select output heatmap channels with keypoint indices
                    # when the number of heatmap channel matches num_keypoints
                    for key, value in pred_fields.all_items():
                        if value.shape[0] != num_keypoints:
                            continue
                        pred_fields.set_field(value[output_keypoint_indices],
                                              key)
                data_sample.pred_fields = pred_fields

        return batch_data_samples