Miroslav Purkrabek
add code
a249588
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import List, Tuple, Union

import numpy as np
from mmcv.transforms import BaseTransform

from mmpose.registry import TRANSFORMS
@TRANSFORMS.register_module()
class KeypointConverter(BaseTransform):
    """Change the order of keypoints according to the given mapping.

    Required Keys:

        - keypoints
        - keypoints_visible (created as all-ones when missing)

    Modified Keys:

        - keypoints
        - keypoints_visible
        - flip_indices
        - keypoints_3d, lifting_target, lifting_target_visible (only when
          ``keypoints_3d`` is present in the input)

    The output ``keypoints_visible`` has shape ``(N, num_keypoints, 2)``:
    channel 0 holds the converted visibility and channel 1 holds a weight
    mask that is 1 for every target index named in ``mapping`` and 0
    elsewhere.

    Args:
        num_keypoints (int): The number of keypoints in target dataset.
        mapping (list): A list containing mapping indexes. Each element has
            format (source_index, target_index). ``source_index`` may also
            be a pair ``(a, b)``, in which case the target keypoint is the
            mean of the two source keypoints and its visibility is the
            product of their visibilities.

    Example:
        >>> import numpy as np
        >>> # case 1: 1-to-1 mapping
        >>> # (0, 0) means target[0] = source[0]
        >>> self = KeypointConverter(
        >>>     num_keypoints=3,
        >>>     mapping=[
        >>>         (0, 0), (1, 1), (2, 2)
        >>>     ])
        >>> results = dict(
        >>>     keypoints=np.arange(12).reshape(2, 3, 2),
        >>>     keypoints_visible=np.arange(6).reshape(2, 3) % 2)
        >>> results = self(results)
        >>> assert np.equal(results['keypoints'],
        >>>                 np.arange(12).reshape(2, 3, 2)).all()
        >>> assert np.equal(results['keypoints_visible'][..., 0],
        >>>                 np.arange(6).reshape(2, 3) % 2).all()
        >>>
        >>> # case 2: 2-to-1 mapping
        >>> # ((1, 2), 0) means target[0] = (source[1] + source[2]) / 2
        >>> self = KeypointConverter(
        >>>     num_keypoints=3,
        >>>     mapping=[
        >>>         ((1, 2), 0), (1, 1), (2, 2)
        >>>     ])
        >>> source = np.arange(12).reshape(2, 3, 2)
        >>> results = dict(
        >>>     keypoints=source.copy(),
        >>>     keypoints_visible=np.arange(6).reshape(2, 3) % 2)
        >>> results = self(results)
        >>> assert np.equal(results['keypoints'][:, 0],
        >>>                 0.5 * (source[:, 1] + source[:, 2])).all()
    """

    def __init__(self, num_keypoints: int,
                 mapping: Union[List[Tuple[int, int]], List[Tuple[Tuple,
                                                                  int]]]):
        self.num_keypoints = num_keypoints
        self.mapping = mapping
        if len(mapping):
            source_index, target_index = zip(*mapping)
        else:
            source_index, target_index = [], []

        # Split every source entry into two parallel index lists. A plain
        # index is duplicated so that both lists always have equal length.
        src1, src2 = [], []
        interpolation = False
        for x in source_index:
            if isinstance(x, (list, tuple)):
                assert len(x) == 2, 'source_index should be a list/tuple of ' \
                                    'length 2'
                src1.append(x[0])
                src2.append(x[1])
                interpolation = True
            else:
                src1.append(x)
                src2.append(x)

        # When paired source_indexes are input,
        # keep a self.source_index2 for interpolation
        if interpolation:
            self.source_index2 = src2

        self.source_index = src1
        self.target_index = list(target_index)
        self.interpolation = interpolation

    def transform(self, results: dict) -> dict:
        """Transforms the keypoint results to match the target keypoints.

        Args:
            results (dict): The result dict. Must contain ``keypoints``;
                ``keypoints_visible``, ``keypoints_3d``, ``target_idx`` and
                ``flip_indices`` are read when present.

        Returns:
            dict: The same dict with the converted keys written back.
        """
        num_instances = results['keypoints'].shape[0]

        if 'keypoints_visible' not in results:
            results['keypoints_visible'] = np.ones(
                (num_instances, results['keypoints'].shape[1]))

        # A 3-D visibility array is reduced to its first channel.
        if len(results['keypoints_visible'].shape) > 2:
            results['keypoints_visible'] = \
                results['keypoints_visible'][:, :, 0]

        # Initialize output arrays
        keypoints = np.zeros((num_instances, self.num_keypoints, 3))
        keypoints_visible = np.zeros((num_instances, self.num_keypoints))
        key = 'keypoints_3d' if 'keypoints_3d' in results else 'keypoints'
        c = results[key].shape[-1]

        flip_indices = results.get('flip_indices', None)

        # Create a mask to weight visibility loss
        keypoints_visible_weights = keypoints_visible.copy()
        keypoints_visible_weights[:, self.target_index] = 1.0

        # Interpolate keypoints if pairs of source indexes provided
        if self.interpolation:
            keypoints[:, self.target_index, :c] = 0.5 * (
                results[key][:, self.source_index] +
                results[key][:, self.source_index2])
            keypoints_visible[:, self.target_index] = results[
                'keypoints_visible'][:, self.source_index] * results[
                    'keypoints_visible'][:, self.source_index2]

            # Flip keypoints if flip_indices provided
            if flip_indices is not None:
                # Work on a shallow copy so the object supplied by the
                # caller is not modified in place; the values computed
                # below are the same as when editing the original.
                flip_indices = copy.copy(flip_indices)
                for i, (x1, x2) in enumerate(
                        zip(self.source_index, self.source_index2)):
                    idx = flip_indices[x1] if x1 == x2 else i
                    flip_indices[i] = idx if idx < self.num_keypoints else i
                flip_indices = flip_indices[:len(self.source_index)]
        # Otherwise just copy from the source index
        else:
            keypoints[:, self.target_index, :c] = \
                results[key][:, self.source_index]
            keypoints_visible[:, self.target_index] = results[
                'keypoints_visible'][:, self.source_index]

        # Update the results dict
        results['keypoints'] = keypoints[..., :2]
        results['keypoints_visible'] = np.stack(
            [keypoints_visible, keypoints_visible_weights], axis=2)
        if 'keypoints_3d' in results:
            results['keypoints_3d'] = keypoints
            results['lifting_target'] = keypoints[results['target_idx']]
            results['lifting_target_visible'] = keypoints_visible[
                results['target_idx']]
        results['flip_indices'] = flip_indices

        return results

    def __repr__(self) -> str:
        """print the basic information of the transform.

        Returns:
            str: Formatted string.
        """
        repr_str = self.__class__.__name__
        repr_str += f'(num_keypoints={self.num_keypoints}, '\
                    f'mapping={self.mapping})'
        return repr_str
@TRANSFORMS.register_module()
class SingleHandConverter(BaseTransform):
    """Mapping a single hand keypoints into double hands according to the given
    mapping and hand type.

    Required Keys:

        - keypoints
        - keypoints_visible
        - hand_type

    Modified Keys:

        - keypoints
        - keypoints_visible

    Args:
        num_keypoints (int): The number of keypoints in target dataset.
        left_hand_mapping (list): A list containing mapping indexes. Each
            element has format (source_index, target_index)
        right_hand_mapping (list): A list containing mapping indexes. Each
            element has format (source_index, target_index)

    Example:
        >>> import numpy as np
        >>> self = SingleHandConverter(
        >>>     num_keypoints=42,
        >>>     left_hand_mapping=[
        >>>         (0, 0), (1, 1), (2, 2), (3, 3)
        >>>     ],
        >>>     right_hand_mapping=[
        >>>         (0, 21), (1, 22), (2, 23), (3, 24)
        >>>     ])
        >>> # hand_type [[0, 1]] selects the left-hand mapping,
        >>> # [[1, 0]] selects the right-hand mapping
        >>> results = dict(
        >>>     keypoints=np.arange(42).reshape(1, 21, 2),
        >>>     keypoints_visible=np.arange(21).reshape(1, 21) % 2,
        >>>     hand_type=np.array([[0, 1]]))
        >>> results = self(results)
    """

    def __init__(self, num_keypoints: int,
                 left_hand_mapping: Union[List[Tuple[int, int]],
                                          List[Tuple[Tuple, int]]],
                 right_hand_mapping: Union[List[Tuple[int, int]],
                                           List[Tuple[Tuple, int]]]):
        self.num_keypoints = num_keypoints
        self.left_hand_converter = KeypointConverter(num_keypoints,
                                                     left_hand_mapping)
        self.right_hand_converter = KeypointConverter(num_keypoints,
                                                      right_hand_mapping)

    def transform(self, results: dict) -> dict:
        """Transforms the keypoint results to match the target keypoints.

        Args:
            results (dict): The result dict; must contain ``hand_type``.

        Returns:
            dict: The result dict after applying the converter selected
            by ``hand_type``.

        Raises:
            ValueError: If ``hand_type`` matches neither ``[[0, 1]]`` nor
                ``[[1, 0]]``.
        """
        assert 'hand_type' in results, (
            'hand_type should be provided in results')
        hand_type = np.asarray(results['hand_type'])

        # Compare with the *absolute* difference. A signed sum cancels out
        # (e.g. [[1, 0]] - [[0, 1]] sums to 0), which would make every
        # right hand look like a left hand.
        if np.abs(hand_type - [[0, 1]]).sum() <= 1e-6:
            # left hand
            results = self.left_hand_converter(results)
        elif np.abs(hand_type - [[1, 0]]).sum() <= 1e-6:
            # right hand
            results = self.right_hand_converter(results)
        else:
            raise ValueError('hand_type should be left or right')

        return results

    def __repr__(self) -> str:
        """print the basic information of the transform.

        Returns:
            str: Formatted string.
        """
        repr_str = self.__class__.__name__
        repr_str += f'(num_keypoints={self.num_keypoints}, '\
                    f'left_hand_converter={self.left_hand_converter}, '\
                    f'right_hand_converter={self.right_hand_converter})'
        return repr_str