Miroslav Purkrabek
add code
a249588
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import Dict
import numpy as np
from mmcv.transforms import BaseTransform
from mmpose.registry import TRANSFORMS
from mmpose.structures.keypoint import flip_keypoints_custom_center
@TRANSFORMS.register_module()
class RandomFlipAroundRoot(BaseTransform):
    """Data augmentation with random horizontal joint flip around a root joint.

    Args:
        keypoints_flip_cfg (dict): Configurations of the
            ``flip_keypoints_custom_center`` function for ``keypoints``. Please
            refer to the docstring of the ``flip_keypoints_custom_center``
            function for more details. Recognized keys (with the defaults
            used when absent): ``center_mode`` (``'static'``), ``center_x``
            (``0.5``) and ``center_index`` (``0``).
        target_flip_cfg (dict): Configurations of the
            ``flip_keypoints_custom_center`` function for ``lifting_target``.
            Please refer to the docstring of the
            ``flip_keypoints_custom_center`` function for more details.
            Recognized keys are the same as for ``keypoints_flip_cfg``.
        flip_prob (float): Probability of flip. Default: 0.5.
        flip_camera (bool): Whether to flip horizontal distortion coefficients.
            Default: ``False``.
        flip_label (bool): Whether to flip labels instead of data.
            Default: ``False``.

    Required keys:

        - keypoints or keypoint_labels
        - lifting_target or lifting_target_label
        - keypoints_visible or keypoint_labels_visible (optional)
        - lifting_target_visible (optional)
        - flip_indices (optional)
        - camera_param (only required when ``flip_camera`` is ``True``)

    Modified keys:

        - keypoints or keypoint_labels (optional)
        - keypoints_visible or keypoint_labels_visible (optional)
        - lifting_target or lifting_target_label (optional)
        - lifting_target_visible (optional)
        - camera_param (optional)
    """

    def __init__(self,
                 keypoints_flip_cfg: dict,
                 target_flip_cfg: dict,
                 flip_prob: float = 0.5,
                 flip_camera: bool = False,
                 flip_label: bool = False):
        self.keypoints_flip_cfg = keypoints_flip_cfg
        self.target_flip_cfg = target_flip_cfg
        self.flip_prob = flip_prob
        self.flip_camera = flip_camera
        self.flip_label = flip_label

    @staticmethod
    def _center_kwargs(flip_cfg: dict) -> dict:
        """Build the centering kwargs for ``flip_keypoints_custom_center``."""
        return dict(
            center_mode=flip_cfg.get('center_mode', 'static'),
            center_x=flip_cfg.get('center_x', 0.5),
            center_index=flip_cfg.get('center_index', 0))

    @staticmethod
    def _flip_camera_param(results: Dict) -> None:
        """Negate the first entry of ``camera_param['c']`` (and of
        ``camera_param['p']`` when present), updating ``results`` in place.

        A deep copy is modified and then merged back with ``dict.update``, so
        the array/list objects originally stored in ``camera_param`` are not
        mutated in place; only the dict's entries are replaced.
        """
        camera_param = deepcopy(results['camera_param'])
        assert 'c' in camera_param
        camera_param['c'][0] *= -1
        if 'p' in camera_param:
            camera_param['p'][0] *= -1
        results['camera_param'].update(camera_param)

    def transform(self, results: Dict) -> dict:
        """The transform function of :class:`RandomFlipAroundRoot`.

        See ``transform()`` method of :class:`BaseTransform` for details.

        With probability ``flip_prob`` the 2D keypoints and the lifting target
        (or their ``*_label`` counterparts when ``flip_label`` is set) are
        flipped with ``flip_keypoints_custom_center`` using
        ``keypoints_flip_cfg`` / ``target_flip_cfg`` respectively, and, if
        ``flip_camera`` is set, the camera parameters are flipped as well.
        Otherwise ``results`` is returned unchanged.

        Args:
            results (dict): The result dict

        Returns:
            dict: The result dict (the same object, updated in place).

        Raises:
            AssertionError: If a required key is missing from ``results``.
        """
        # np.random.rand() samples from [0, 1); a strict ``<`` makes the flip
        # happen with probability exactly ``flip_prob`` (0 never flips).
        if np.random.rand() >= self.flip_prob:
            return results

        if self.flip_label:
            keypoints_key = 'keypoint_labels'
            keypoints_visible_key = 'keypoint_labels_visible'
            target_key = 'lifting_target_label'
        else:
            keypoints_key = 'keypoints'
            keypoints_visible_key = 'keypoints_visible'
            target_key = 'lifting_target'
        assert keypoints_key in results
        assert target_key in results
        if self.flip_camera:
            # Validate up front so a missing key cannot leave the sample
            # half-flipped (keypoints flipped but camera untouched).
            assert 'camera_param' in results, \
                'Camera parameters are missing.'

        keypoints = results[keypoints_key]
        if keypoints_visible_key in results:
            keypoints_visible = results[keypoints_visible_key]
        else:
            keypoints_visible = np.ones(
                keypoints.shape[:-1], dtype=np.float32)

        lifting_target = results[target_key]
        if 'lifting_target_visible' in results:
            lifting_target_visible = results['lifting_target_visible']
        else:
            lifting_target_visible = np.ones(
                lifting_target.shape[:-1], dtype=np.float32)

        if 'flip_indices' in results:
            flip_indices = results['flip_indices']
        else:
            # No symmetry mapping provided: keep the joint order (identity).
            # The joint count is the second-to-last axis of ``keypoints``,
            # matching the visibility default built above from shape[:-1].
            flip_indices = list(range(keypoints.shape[-2]))

        # flip joint coordinates
        keypoints, keypoints_visible = flip_keypoints_custom_center(
            keypoints, keypoints_visible, flip_indices,
            **self._center_kwargs(self.keypoints_flip_cfg))
        lifting_target, lifting_target_visible = flip_keypoints_custom_center(
            lifting_target, lifting_target_visible, flip_indices,
            **self._center_kwargs(self.target_flip_cfg))

        results[keypoints_key] = keypoints
        results[keypoints_visible_key] = keypoints_visible
        results[target_key] = lifting_target
        results['lifting_target_visible'] = lifting_target_visible

        # flip horizontal distortion coefficients
        if self.flip_camera:
            self._flip_camera_param(results)

        return results