File size: 9,172 Bytes
a249588
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple, Union

import numpy as np
from mmcv.transforms import BaseTransform

from mmpose.registry import TRANSFORMS


@TRANSFORMS.register_module()
class KeypointConverter(BaseTransform):
    """Change the order of keypoints according to the given mapping.

    Required Keys:

        - keypoints
        - keypoints_visible (optional, treated as all ones when missing)

    Modified Keys:

        - keypoints
        - keypoints_visible
        - flip_indices
        - keypoints_3d (only when present in the input)
        - lifting_target (only when ``keypoints_3d`` is present)
        - lifting_target_visible (only when ``keypoints_3d`` is present)

    The output ``keypoints`` has shape (num_instances, num_keypoints, 2).
    The output ``keypoints_visible`` has shape
    (num_instances, num_keypoints, 2): channel 0 holds the converted
    visibility and channel 1 holds a weight that is 1.0 for every target
    index named in ``mapping`` and 0.0 elsewhere.

    Args:
        num_keypoints (int): The number of keypoints in target dataset.
        mapping (list): A list containing mapping indexes. Each element has
            format (source_index, target_index). ``source_index`` may also
            be a pair (source_index_a, source_index_b), in which case the
            target keypoint is the mean of the two source keypoints.

    Example:
        >>> import numpy as np
        >>> # case 1: 1-to-1 mapping
        >>> # (0, 0) means target[0] = source[0]
        >>> self = KeypointConverter(
        >>>     num_keypoints=3,
        >>>     mapping=[
        >>>         (0, 0), (1, 1), (2, 2)
        >>>     ])
        >>> results = dict(
        >>>     keypoints=np.arange(12).reshape(2, 3, 2),
        >>>     keypoints_visible=np.arange(6).reshape(2, 3) % 2)
        >>> results = self(results)
        >>> assert np.equal(results['keypoints'],
        >>>                 np.arange(12).reshape(2, 3, 2)).all()
        >>> # channel 0 is the visibility, channel 1 the weights
        >>> assert np.equal(results['keypoints_visible'][..., 0],
        >>>                 np.arange(6).reshape(2, 3) % 2).all()
        >>>
        >>> # case 2: 2-to-1 mapping
        >>> # ((1, 2), 0) means target[0] = (source[1] + source[2]) / 2
        >>> self = KeypointConverter(
        >>>     num_keypoints=3,
        >>>     mapping=[
        >>>         ((1, 2), 0), (1, 1), (2, 2)
        >>>     ])
        >>> results = dict(
        >>>     keypoints=np.arange(12).reshape(2, 3, 2),
        >>>     keypoints_visible=np.arange(6).reshape(2, 3) % 2)
        >>> results = self(results)
    """

    def __init__(self, num_keypoints: int,
                 mapping: Union[List[Tuple[int, int]], List[Tuple[Tuple,
                                                                  int]]]):
        self.num_keypoints = num_keypoints
        self.mapping = mapping
        if len(mapping):
            source_index, target_index = zip(*mapping)
        else:
            source_index, target_index = [], []

        # Split every source entry into two parallel index lists. A plain
        # index is duplicated so that averaging it with itself is a no-op.
        src1, src2 = [], []
        interpolation = False
        for x in source_index:
            if isinstance(x, (list, tuple)):
                assert len(x) == 2, 'source_index should be a list/tuple of ' \
                                    'length 2'
                src1.append(x[0])
                src2.append(x[1])
                interpolation = True
            else:
                src1.append(x)
                src2.append(x)

        # When paired source_indexes are input,
        # keep a self.source_index2 for interpolation
        if interpolation:
            self.source_index2 = src2

        self.source_index = src1
        self.target_index = list(target_index)
        self.interpolation = interpolation

    def transform(self, results: dict) -> dict:
        """Transforms the keypoint results to match the target keypoints.

        Args:
            results (dict): The data dict. ``keypoints`` is read for the
                number of instances; the coordinates themselves are taken
                from ``keypoints_3d`` when that key exists, otherwise from
                ``keypoints``.

        Returns:
            dict: The same dict with the converted entries written back.
        """
        num_instances = results['keypoints'].shape[0]

        # Fall back to "everything visible" when no visibility is given.
        if 'keypoints_visible' in results:
            source_visible = results['keypoints_visible']
        else:
            source_visible = np.ones(
                (num_instances, results['keypoints'].shape[1]))

        # A visibility array with a trailing channel axis is reduced to its
        # first channel before conversion.
        if len(source_visible.shape) > 2:
            source_visible = source_visible[:, :, 0]

        # Initialize output arrays
        keypoints = np.zeros((num_instances, self.num_keypoints, 3))
        keypoints_visible = np.zeros((num_instances, self.num_keypoints))
        key = 'keypoints_3d' if 'keypoints_3d' in results else 'keypoints'
        c = results[key].shape[-1]

        flip_indices = results.get('flip_indices', None)

        # Create a mask to weight visibility loss
        keypoints_visible_weights = keypoints_visible.copy()
        keypoints_visible_weights[:, self.target_index] = 1.0

        # Interpolate keypoints if pairs of source indexes provided
        if self.interpolation:
            keypoints[:, self.target_index, :c] = 0.5 * (
                results[key][:, self.source_index] +
                results[key][:, self.source_index2])
            # An averaged keypoint is visible only if both sources are.
            keypoints_visible[:, self.target_index] = (
                source_visible[:, self.source_index] *
                source_visible[:, self.source_index2])
            # Flip keypoints if flip_indices provided
            if flip_indices is not None:
                # Build a fresh list from the untouched source indices.
                # Overwriting the source list while it is still being read
                # would let an earlier write corrupt a later lookup, and
                # would also mutate the object owned by the caller.
                converted_flip = []
                for i, (x1, x2) in enumerate(
                        zip(self.source_index, self.source_index2)):
                    idx = flip_indices[x1] if x1 == x2 else i
                    converted_flip.append(
                        idx if idx < self.num_keypoints else i)
                flip_indices = converted_flip
        # Otherwise just copy from the source index
        else:
            keypoints[:,
                      self.target_index, :c] = results[key][:,
                                                            self.source_index]
            keypoints_visible[:, self.target_index] = source_visible[
                :, self.source_index]

        # Update the results dict
        results['keypoints'] = keypoints[..., :2]
        results['keypoints_visible'] = np.stack(
            [keypoints_visible, keypoints_visible_weights], axis=2)
        if 'keypoints_3d' in results:
            results['keypoints_3d'] = keypoints
            results['lifting_target'] = keypoints[results['target_idx']]
            results['lifting_target_visible'] = keypoints_visible[
                results['target_idx']]
        results['flip_indices'] = flip_indices

        return results

    def __repr__(self) -> str:
        """print the basic information of the transform.

        Returns:
            str: Formatted string.
        """
        repr_str = self.__class__.__name__
        repr_str += f'(num_keypoints={self.num_keypoints}, '\
                    f'mapping={self.mapping})'
        return repr_str


@TRANSFORMS.register_module()
class SingleHandConverter(BaseTransform):
    """Mapping a single hand keypoints into double hands according to the given
    mapping and hand type.

    Required Keys:

        - keypoints
        - keypoints_visible
        - hand_type

    Modified Keys:

        - keypoints
        - keypoints_visible

    Args:
        num_keypoints (int): The number of keypoints in target dataset.
        left_hand_mapping (list): A list containing mapping indexes. Each
            element has format (source_index, target_index)
        right_hand_mapping (list): A list containing mapping indexes. Each
            element has format (source_index, target_index)

    Example:
        >>> import numpy as np
        >>> self = SingleHandConverter(
        >>>     num_keypoints=42,
        >>>     left_hand_mapping=[
        >>>         (0, 0), (1, 1), (2, 2), (3, 3)
        >>>     ],
        >>>     right_hand_mapping=[
        >>>         (0, 21), (1, 22), (2, 23), (3, 24)
        >>>     ])
        >>> results = dict(
        >>>     keypoints=np.arange(42).reshape(1, 21, 2),
        >>>     keypoints_visible=np.arange(21).reshape(1, 21) % 2,
        >>>     hand_type=np.array([[0, 1]]))
        >>> results = self(results)
    """

    def __init__(self, num_keypoints: int,
                 left_hand_mapping: Union[List[Tuple[int, int]],
                                          List[Tuple[Tuple, int]]],
                 right_hand_mapping: Union[List[Tuple[int, int]],
                                           List[Tuple[Tuple, int]]]):
        self.num_keypoints = num_keypoints
        # One converter per hand; both write into the same target layout.
        self.left_hand_converter = KeypointConverter(num_keypoints,
                                                     left_hand_mapping)
        self.right_hand_converter = KeypointConverter(num_keypoints,
                                                      right_hand_mapping)

    def transform(self, results: dict) -> dict:
        """Transforms the keypoint results to match the target keypoints.

        ``hand_type`` equal to ``[[0, 1]]`` selects the left-hand converter
        and ``[[1, 0]]`` selects the right-hand converter.

        Args:
            results (dict): The data dict; must contain ``hand_type``.

        Returns:
            dict: The dict returned by the selected converter.

        Raises:
            AssertionError: If ``hand_type`` is missing from ``results``.
            ValueError: If ``hand_type`` matches neither encoding.
        """
        assert 'hand_type' in results, (
            'hand_type should be provided in results')
        hand_type = results['hand_type']

        # Compare with the sum of ABSOLUTE differences. A signed sum lets
        # +1 and -1 cancel, so [[1, 0]] would wrongly match [[0, 1]].
        if np.sum(np.abs(hand_type - [[0, 1]])) <= 1e-6:
            # left hand
            results = self.left_hand_converter(results)
        elif np.sum(np.abs(hand_type - [[1, 0]])) <= 1e-6:
            # right hand
            results = self.right_hand_converter(results)
        else:
            raise ValueError('hand_type should be left or right')

        return results

    def __repr__(self) -> str:
        """print the basic information of the transform.

        Returns:
            str: Formatted string.
        """
        repr_str = self.__class__.__name__
        repr_str += f'(num_keypoints={self.num_keypoints}, '\
                    f'left_hand_converter={self.left_hand_converter}, '\
                    f'right_hand_converter={self.right_hand_converter})'
        return repr_str