Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import List, Tuple, Union | |
import numpy as np | |
from mmcv.transforms import BaseTransform | |
from mmpose.registry import TRANSFORMS | |
# NOTE(review): ``TRANSFORMS`` is imported at the top of this file but is not
# used by anything visible here. If this transform is meant to be built from a
# config (``type='KeypointConverter'``) it needs
# ``@TRANSFORMS.register_module()`` -- confirm whether that was intended.
class KeypointConverter(BaseTransform):
    """Change the order of keypoints according to the given mapping.

    Required Keys:

        - keypoints
        - keypoints_visible (optional; defaults to all ones)
        - keypoints_3d, target_idx (only read when ``keypoints_3d`` exists)
        - flip_indices (optional)

    Modified Keys:

        - keypoints
        - keypoints_visible
        - flip_indices
        - keypoints_3d, lifting_target, lifting_target_visible (only when
          ``keypoints_3d`` exists)

    Args:
        num_keypoints (int): The number of keypoints in target dataset.
        mapping (list): A list containing mapping indexes. Each element has
            format ``(source_index, target_index)``. ``source_index`` may
            also be a pair ``(i, j)``: the target keypoint is then the
            midpoint of source keypoints ``i`` and ``j`` and its visibility
            is the product of their visibilities.

    Example:
        >>> import numpy as np
        >>> # case 1: 1-to-1 mapping
        >>> # (0, 0) means target[0] = source[0]
        >>> self = KeypointConverter(
        ...     num_keypoints=3,
        ...     mapping=[
        ...         (0, 0), (1, 1), (2, 2)
        ...     ])
        >>> results = dict(
        ...     keypoints=np.arange(12).reshape(2, 3, 2),
        ...     keypoints_visible=np.arange(12).reshape(2, 3, 2) % 2)
        >>> results = self(results)
        >>> assert np.equal(results['keypoints'],
        ...                 np.arange(12).reshape(2, 3, 2)).all()
        >>> # the output visibility is stacked with loss weights on axis 2
        >>> expected_vis = np.arange(12).reshape(2, 3, 2)[..., 0] % 2
        >>> assert np.equal(results['keypoints_visible'][..., 0],
        ...                 expected_vis).all()
        >>>
        >>> # case 2: 2-to-1 mapping
        >>> # ((1, 2), 0) means target[0] = (source[1] + source[2]) / 2
        >>> self = KeypointConverter(
        ...     num_keypoints=3,
        ...     mapping=[
        ...         ((1, 2), 0), (1, 1), (2, 2)
        ...     ])
        >>> results = dict(
        ...     keypoints=np.arange(12).reshape(2, 3, 2),
        ...     keypoints_visible=np.arange(12).reshape(2, 3, 2) % 2)
        >>> results = self(results)
    """

    def __init__(self, num_keypoints: int,
                 mapping: Union[List[Tuple[int, int]],
                                List[Tuple[Tuple, int]]]):
        self.num_keypoints = num_keypoints
        self.mapping = mapping
        if len(mapping):
            source_index, target_index = zip(*mapping)
        else:
            source_index, target_index = [], []

        # Split every source entry into two index lists. A plain index is
        # duplicated into both, so "interpolating" it yields the point itself.
        src1, src2 = [], []
        interpolation = False
        for x in source_index:
            if isinstance(x, (list, tuple)):
                assert len(x) == 2, 'source_index should be a list/tuple of ' \
                                    'length 2'
                src1.append(x[0])
                src2.append(x[1])
                interpolation = True
            else:
                src1.append(x)
                src2.append(x)

        # When paired source_indexes are input,
        # keep a self.source_index2 for interpolation
        if interpolation:
            self.source_index2 = src2

        self.source_index = src1
        self.target_index = list(target_index)
        self.interpolation = interpolation

    def _convert_flip_indices(self, flip_indices) -> List[int]:
        """Build flip indices for the converted keypoints.

        For each mapping entry ``i``, a 1-to-1 entry reuses the flip partner
        of its source keypoint, while an interpolated entry flips onto
        itself. A partner index that is ``>= num_keypoints`` also falls back
        to ``i``. ``flip_indices`` is only read, never modified, so every
        lookup sees the original table.

        Args:
            flip_indices (Sequence[int]): Flip indices of the source layout.

        Returns:
            list[int]: One flip index per mapping entry.
        """
        # NOTE(review): the partner index is taken from source space as-is
        # and is not translated through the mapping into target space; this
        # is only correct when source and target orders agree for these
        # entries -- confirm against the datasets that use interpolation.
        converted = []
        pairs = zip(self.source_index, self.source_index2)
        for i, (x1, x2) in enumerate(pairs):
            idx = flip_indices[x1] if x1 == x2 else i
            converted.append(idx if idx < self.num_keypoints else i)
        return converted

    def transform(self, results: dict) -> dict:
        """Transforms the keypoint results to match the target keypoints.

        Args:
            results (dict): The result dict. ``results['keypoints']`` must be
                an array whose first two axes are (instances, keypoints).

        Returns:
            dict: ``results`` updated in place. ``keypoints`` becomes
            ``(N, num_keypoints, 2)``. ``keypoints_visible`` becomes
            ``(N, num_keypoints, 2)``: the converted visibility is stacked
            with a loss-weight mask that is 1 for mapped targets and 0 for
            unmapped ones. ``flip_indices`` is only remapped when the mapping
            contains paired (interpolated) sources; otherwise it is written
            back unchanged (``None`` if absent).
        """
        num_instances = results['keypoints'].shape[0]

        if 'keypoints_visible' not in results:
            results['keypoints_visible'] = np.ones(
                (num_instances, results['keypoints'].shape[1]))

        visible = results['keypoints_visible']
        if visible.ndim > 2:
            # Keep only the first channel of a multi-channel visibility.
            visible = visible[:, :, 0]
            results['keypoints_visible'] = visible

        # Initialize output arrays; unmapped targets stay zero.
        keypoints = np.zeros((num_instances, self.num_keypoints, 3))
        keypoints_visible = np.zeros((num_instances, self.num_keypoints))
        key = 'keypoints_3d' if 'keypoints_3d' in results else 'keypoints'
        c = results[key].shape[-1]

        flip_indices = results.get('flip_indices', None)

        # Create a mask to weight visibility loss
        keypoints_visible_weights = keypoints_visible.copy()
        keypoints_visible_weights[:, self.target_index] = 1.0

        if self.interpolation:
            # Midpoint of the paired source keypoints (a plain index is
            # paired with itself, so it is copied unchanged).
            keypoints[:, self.target_index, :c] = 0.5 * (
                results[key][:, self.source_index] +
                results[key][:, self.source_index2])
            keypoints_visible[:, self.target_index] = (
                visible[:, self.source_index] *
                visible[:, self.source_index2])
            if flip_indices is not None:
                flip_indices = self._convert_flip_indices(flip_indices)
        else:
            # Otherwise just copy from the source index
            keypoints[:, self.target_index, :c] = \
                results[key][:, self.source_index]
            keypoints_visible[:, self.target_index] = \
                visible[:, self.source_index]

        # Update the results dict
        results['keypoints'] = keypoints[..., :2]
        results['keypoints_visible'] = np.stack(
            [keypoints_visible, keypoints_visible_weights], axis=2)
        if 'keypoints_3d' in results:
            results['keypoints_3d'] = keypoints
            results['lifting_target'] = keypoints[results['target_idx']]
            results['lifting_target_visible'] = keypoints_visible[
                results['target_idx']]
        results['flip_indices'] = flip_indices

        return results

    def __repr__(self) -> str:
        """print the basic information of the transform.

        Returns:
            str: Formatted string.
        """
        repr_str = self.__class__.__name__
        repr_str += f'(num_keypoints={self.num_keypoints}, '\
                    f'mapping={self.mapping})'
        return repr_str
# NOTE(review): ``TRANSFORMS`` is imported at the top of this file but is not
# used by anything visible here; if this transform should be buildable from a
# config it needs ``@TRANSFORMS.register_module()`` -- confirm.
class SingleHandConverter(BaseTransform):
    """Mapping a single hand keypoints into double hands according to the given
    mapping and hand type.

    Required Keys:

        - keypoints
        - keypoints_visible
        - hand_type

    Modified Keys:

        - keypoints
        - keypoints_visible

    Args:
        num_keypoints (int): The number of keypoints in target dataset.
        left_hand_mapping (list): A list containing mapping indexes. Each
            element has format (source_index, target_index)
        right_hand_mapping (list): A list containing mapping indexes. Each
            element has format (source_index, target_index)

    Note:
        ``hand_type`` equal to ``[[0, 1]]`` selects the left-hand mapping and
        ``[[1, 0]]`` selects the right-hand mapping. Every instance in
        ``results`` must share the same hand type, because a single mapping
        is applied to all of them.

    Example:
        >>> import numpy as np
        >>> self = SingleHandConverter(
        ...     num_keypoints=42,
        ...     left_hand_mapping=[
        ...         (0, 0), (1, 1), (2, 2), (3, 3)
        ...     ],
        ...     right_hand_mapping=[
        ...         (0, 21), (1, 22), (2, 23), (3, 24)
        ...     ])
        >>> results = dict(
        ...     keypoints=np.arange(84).reshape(2, 21, 2),
        ...     keypoints_visible=np.arange(84).reshape(2, 21, 2) % 2,
        ...     hand_type=np.array([[0, 1], [0, 1]]))
        >>> results = self(results)
    """

    def __init__(self, num_keypoints: int,
                 left_hand_mapping: Union[List[Tuple[int, int]],
                                          List[Tuple[Tuple, int]]],
                 right_hand_mapping: Union[List[Tuple[int, int]],
                                           List[Tuple[Tuple, int]]]):
        self.num_keypoints = num_keypoints
        self.left_hand_converter = KeypointConverter(num_keypoints,
                                                     left_hand_mapping)
        self.right_hand_converter = KeypointConverter(num_keypoints,
                                                      right_hand_mapping)

    @staticmethod
    def _matches(hand_type: np.ndarray, code) -> bool:
        """Return True if every row of ``hand_type`` equals ``code``.

        Uses the element-wise absolute difference. Summing signed differences
        would let ``[1, 0] - [0, 1] = [1, -1]`` cancel to zero and match.
        """
        return bool(np.all(np.abs(hand_type - code) <= 1e-6))

    def transform(self, results: dict) -> dict:
        """Transforms the keypoint results to match the target keypoints.

        Args:
            results (dict): The result dict; must contain ``hand_type``.

        Returns:
            dict: The result of the left- or right-hand ``KeypointConverter``.

        Raises:
            ValueError: If ``hand_type`` is neither all-left ``[0, 1]`` nor
                all-right ``[1, 0]``.
        """
        assert 'hand_type' in results, (
            'hand_type should be provided in results')
        hand_type = np.asarray(results['hand_type'])

        if self._matches(hand_type, [[0, 1]]):
            # left hand
            results = self.left_hand_converter(results)
        elif self._matches(hand_type, [[1, 0]]):
            # right hand
            results = self.right_hand_converter(results)
        else:
            raise ValueError('hand_type should be left or right')

        return results

    def __repr__(self) -> str:
        """print the basic information of the transform.

        Returns:
            str: Formatted string.
        """
        repr_str = self.__class__.__name__
        repr_str += f'(num_keypoints={self.num_keypoints}, '\
                    f'left_hand_converter={self.left_hand_converter}, '\
                    f'right_hand_converter={self.right_hand_converter})'
        return repr_str