Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import Dict

import numpy as np
from mmcv.transforms import BaseTransform

from mmpose.registry import TRANSFORMS
from mmpose.structures.keypoint import flip_keypoints_custom_center
# NOTE(review): ``TRANSFORMS`` is imported by this module but no
# ``@TRANSFORMS.register_module()`` decorator is visible on this class;
# confirm whether registry registration was intended.
class RandomFlipAroundRoot(BaseTransform):
    """Data augmentation with random horizontal joint flip around a root joint.

    Args:
        keypoints_flip_cfg (dict): Configurations of the
            ``flip_keypoints_custom_center`` function for ``keypoints``. Please
            refer to the docstring of the ``flip_keypoints_custom_center``
            function for more details. The keys ``center_mode``,
            ``center_x`` and ``center_index`` are read, defaulting to
            ``'static'``, ``0.5`` and ``0`` respectively.
        target_flip_cfg (dict): Configurations of the
            ``flip_keypoints_custom_center`` function for ``lifting_target``.
            Same keys and defaults as ``keypoints_flip_cfg``.
        flip_prob (float): Probability of flip. Default: 0.5.
        flip_camera (bool): Whether to also flip camera parameters: the
            first entry of ``camera_param['c']`` and, if present, of
            ``camera_param['p']`` are negated. Default: ``False``.
        flip_label (bool): Whether to flip labels instead of data.
            Default: ``False``.

    Required keys:

        - keypoints or keypoint_labels
        - lifting_target or lifting_target_label
        - keypoints_visible or keypoint_labels_visible (optional)
        - lifting_target_visible (optional)
        - flip_indices (optional)
        - camera_param (only when ``flip_camera=True``)

    Modified keys:

        - keypoints or keypoint_labels (optional)
        - keypoints_visible or keypoint_labels_visible (optional)
        - lifting_target or lifting_target_label (optional)
        - lifting_target_visible (optional)
        - camera_param (optional)
    """

    def __init__(self,
                 keypoints_flip_cfg: dict,
                 target_flip_cfg: dict,
                 flip_prob: float = 0.5,
                 flip_camera: bool = False,
                 flip_label: bool = False):
        self.keypoints_flip_cfg = keypoints_flip_cfg
        self.target_flip_cfg = target_flip_cfg
        self.flip_prob = flip_prob
        self.flip_camera = flip_camera
        self.flip_label = flip_label

    @staticmethod
    def _flip_with_cfg(coords, visible, flip_indices, flip_cfg: dict):
        """Run ``flip_keypoints_custom_center`` with options from a cfg dict.

        Returns whatever ``flip_keypoints_custom_center`` returns; the caller
        unpacks it as ``(coords, visible)``.
        """
        return flip_keypoints_custom_center(
            coords,
            visible,
            flip_indices,
            center_mode=flip_cfg.get('center_mode', 'static'),
            center_x=flip_cfg.get('center_x', 0.5),
            center_index=flip_cfg.get('center_index', 0))

    @staticmethod
    def _flip_camera_param(results: Dict) -> None:
        """Negate ``c[0]`` (and ``p[0]`` if present) of ``camera_param``.

        The values are modified on a deep copy and then written back with
        ``dict.update``, so arrays/lists that other references may share
        with ``results['camera_param']`` are not mutated in place.

        Raises:
            AssertionError: If ``camera_param`` is missing from ``results``
                or has no ``'c'`` entry.
        """
        assert 'camera_param' in results, \
            'Camera parameters are missing.'
        camera_param = deepcopy(results['camera_param'])
        assert 'c' in camera_param
        camera_param['c'][0] *= -1
        if 'p' in camera_param:
            camera_param['p'][0] *= -1
        results['camera_param'].update(camera_param)

    def transform(self, results: Dict) -> dict:
        """The transform function of :class:`RandomFlipAroundRoot`.

        See ``transform()`` method of :class:`BaseTransform` for details.

        Args:
            results (dict): The result dict

        Returns:
            dict: The result dict, flipped in place with probability
            ``flip_prob`` and returned unchanged otherwise.
        """
        # ``np.random.rand()`` samples from [0, 1); flipping only when the
        # sample is strictly below ``flip_prob`` gives a flip probability of
        # exactly ``flip_prob`` (and never flips when ``flip_prob == 0``).
        if np.random.rand() >= self.flip_prob:
            return results

        if self.flip_label:
            assert 'keypoint_labels' in results
            assert 'lifting_target_label' in results
            keypoints_key = 'keypoint_labels'
            keypoints_visible_key = 'keypoint_labels_visible'
            target_key = 'lifting_target_label'
        else:
            assert 'keypoints' in results
            assert 'lifting_target' in results
            keypoints_key = 'keypoints'
            keypoints_visible_key = 'keypoints_visible'
            target_key = 'lifting_target'

        keypoints = results[keypoints_key]
        if keypoints_visible_key in results:
            keypoints_visible = results[keypoints_visible_key]
        else:
            # Default: all visible, one flag per entry of every axis except
            # the last (coordinate) axis.
            keypoints_visible = np.ones(
                keypoints.shape[:-1], dtype=np.float32)

        lifting_target = results[target_key]
        if 'lifting_target_visible' in results:
            lifting_target_visible = results['lifting_target_visible']
        else:
            lifting_target_visible = np.ones(
                lifting_target.shape[:-1], dtype=np.float32)

        if 'flip_indices' in results:
            flip_indices = results['flip_indices']
        else:
            # Identity mapping: mirror coordinates without swapping joints.
            # The class stores no keypoint count, so take it from the
            # second-to-last axis (assumes keypoints are (..., K, C), as the
            # ``shape[:-1]`` visibility default implies -- confirm).
            flip_indices = list(range(keypoints.shape[-2]))

        # flip joint coordinates
        keypoints, keypoints_visible = self._flip_with_cfg(
            keypoints, keypoints_visible, flip_indices,
            self.keypoints_flip_cfg)
        lifting_target, lifting_target_visible = self._flip_with_cfg(
            lifting_target, lifting_target_visible, flip_indices,
            self.target_flip_cfg)

        results[keypoints_key] = keypoints
        results[keypoints_visible_key] = keypoints_visible
        results[target_key] = lifting_target
        results['lifting_target_visible'] = lifting_target_visible

        # Only touch (and only require) ``camera_param`` when asked to.
        if self.flip_camera:
            self._flip_camera_param(results)

        return results