File size: 5,531 Bytes
a249588
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import Dict

import numpy as np
from mmcv.transforms import BaseTransform

from mmpose.registry import TRANSFORMS
from mmpose.structures.keypoint import flip_keypoints_custom_center


@TRANSFORMS.register_module()
class RandomFlipAroundRoot(BaseTransform):
    """Data augmentation with random horizontal joint flip around a root joint.

    With probability ``flip_prob`` the input keypoints and the lifting target
    are both passed through ``flip_keypoints_custom_center``; optionally the
    horizontal camera coefficients are negated as well.

    Args:
        keypoints_flip_cfg (dict): Configurations of the
            ``flip_keypoints_custom_center`` function for ``keypoints``. Please
            refer to the docstring of the ``flip_keypoints_custom_center``
            function for more details. The keys ``center_mode``, ``center_x``
            and ``center_index`` are read (defaults: ``'static'``, ``0.5``
            and ``0``).
        target_flip_cfg (dict): Configurations of the
            ``flip_keypoints_custom_center`` function for ``lifting_target``.
            Please refer to the docstring of the
            ``flip_keypoints_custom_center`` function for more details. The
            same keys and defaults as ``keypoints_flip_cfg`` apply.
        flip_prob (float): Probability of flip. Default: 0.5.
        flip_camera (bool): Whether to flip horizontal distortion coefficients.
            Default: ``False``.
        flip_label (bool): Whether to flip labels instead of data.
            Default: ``False``.

    Required keys:
        - keypoints or keypoint_labels
        - lifting_target or lifting_target_label
        - keypoints_visible or keypoint_labels_visible (optional)
        - lifting_target_visible (optional)
        - flip_indices (optional)
        - camera_param (required only if ``flip_camera=True``)

    Modified keys:
        - keypoints or keypoint_labels (optional)
        - keypoints_visible or keypoint_labels_visible (optional)
        - lifting_target or lifting_target_label (optional)
        - lifting_target_visible (optional)
        - camera_param (optional, only if ``flip_camera=True``)
    """

    def __init__(self,
                 keypoints_flip_cfg: dict,
                 target_flip_cfg: dict,
                 flip_prob: float = 0.5,
                 flip_camera: bool = False,
                 flip_label: bool = False):
        self.keypoints_flip_cfg = keypoints_flip_cfg
        self.target_flip_cfg = target_flip_cfg
        self.flip_prob = flip_prob
        self.flip_camera = flip_camera
        self.flip_label = flip_label

    def transform(self, results: Dict) -> dict:
        """The transform function of :class:`RandomFlipAroundRoot`.

        See ``transform()`` method of :class:`BaseTransform` for details.

        Args:
            results (dict): The result dict

        Returns:
            dict: The result dict. It is the same object that was passed in;
            when the flip is applied, the keys listed under "Modified keys"
            in the class docstring are overwritten in place.

        Raises:
            AssertionError: If the keypoint/target keys selected by
                ``flip_label`` are missing, or if ``flip_camera`` is set and
                ``camera_param`` (or its ``'c'`` entry) is missing.
        """

        if np.random.rand() <= self.flip_prob:
            # Pick which pair of entries to flip: encoded labels or raw data.
            if self.flip_label:
                assert 'keypoint_labels' in results
                assert 'lifting_target_label' in results
                keypoints_key = 'keypoint_labels'
                keypoints_visible_key = 'keypoint_labels_visible'
                target_key = 'lifting_target_label'
            else:
                assert 'keypoints' in results
                assert 'lifting_target' in results
                keypoints_key = 'keypoints'
                keypoints_visible_key = 'keypoints_visible'
                target_key = 'lifting_target'

            # Missing visibility defaults to "all visible", with one value
            # per keypoint (the coordinate axis is dropped).
            keypoints = results[keypoints_key]
            if keypoints_visible_key in results:
                keypoints_visible = results[keypoints_visible_key]
            else:
                keypoints_visible = np.ones(
                    keypoints.shape[:-1], dtype=np.float32)

            lifting_target = results[target_key]
            if 'lifting_target_visible' in results:
                lifting_target_visible = results['lifting_target_visible']
            else:
                lifting_target_visible = np.ones(
                    lifting_target.shape[:-1], dtype=np.float32)

            if 'flip_indices' in results:
                flip_indices = results['flip_indices']
            else:
                # No symmetric-pair mapping given: fall back to the identity
                # permutation so only the coordinates are mirrored. The
                # keypoint count is read from the array itself (this class
                # has no ``num_keypoints`` attribute).
                # NOTE(review): assumes the keypoint axis is the
                # second-to-last one, consistent with the visibility shape
                # ``keypoints.shape[:-1]`` used above -- confirm.
                flip_indices = list(range(keypoints.shape[-2]))

            # flip joint coordinates
            keypoints, keypoints_visible = flip_keypoints_custom_center(
                keypoints,
                keypoints_visible,
                flip_indices,
                center_mode=self.keypoints_flip_cfg.get(
                    'center_mode', 'static'),
                center_x=self.keypoints_flip_cfg.get('center_x', 0.5),
                center_index=self.keypoints_flip_cfg.get('center_index', 0))
            lifting_target, lifting_target_visible = \
                flip_keypoints_custom_center(
                    lifting_target,
                    lifting_target_visible,
                    flip_indices,
                    center_mode=self.target_flip_cfg.get(
                        'center_mode', 'static'),
                    center_x=self.target_flip_cfg.get('center_x', 0.5),
                    center_index=self.target_flip_cfg.get('center_index', 0))

            results[keypoints_key] = keypoints
            results[keypoints_visible_key] = keypoints_visible
            results[target_key] = lifting_target
            results['lifting_target_visible'] = lifting_target_visible

            # flip horizontal distortion coefficients
            if self.flip_camera:
                assert 'camera_param' in results, \
                    'Camera parameters are missing.'

                # Work on a deep copy so the arrays originally stored in
                # ``camera_param`` are not mutated in place. The copy is
                # made only here, so ``camera_param`` stays optional when
                # ``flip_camera`` is False.
                _camera_param = deepcopy(results['camera_param'])

                assert 'c' in _camera_param
                _camera_param['c'][0] *= -1

                if 'p' in _camera_param:
                    _camera_param['p'][0] *= -1

                results['camera_param'].update(_camera_param)

        return results