Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,531 Bytes
a249588 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import Dict
import numpy as np
from mmcv.transforms import BaseTransform
from mmpose.registry import TRANSFORMS
from mmpose.structures.keypoint import flip_keypoints_custom_center
@TRANSFORMS.register_module()
class RandomFlipAroundRoot(BaseTransform):
    """Data augmentation with random horizontal joint flip around a root joint.

    Args:
        keypoints_flip_cfg (dict): Configurations of the
            ``flip_keypoints_custom_center`` function for ``keypoints``.
            The keys ``center_mode``, ``center_x`` and ``center_index`` are
            read (defaults: ``'static'``, ``0.5`` and ``0``). Please refer to
            the docstring of the ``flip_keypoints_custom_center`` function
            for more details.
        target_flip_cfg (dict): Configurations of the
            ``flip_keypoints_custom_center`` function for ``lifting_target``.
            The same keys and defaults as ``keypoints_flip_cfg`` apply.
            Please refer to the docstring of the
            ``flip_keypoints_custom_center`` function for more details.
        flip_prob (float): Probability of flip. Default: 0.5.
        flip_camera (bool): Whether to flip horizontal distortion
            coefficients, i.e. negate the first entry of
            ``camera_param['c']`` and, if present, ``camera_param['p']``.
            Default: ``False``.
        flip_label (bool): Whether to flip labels instead of data.
            Default: ``False``.

    Required keys:
        - keypoints or keypoint_labels
        - lifting_target or lifting_target_label
        - keypoints_visible or keypoint_labels_visible (optional)
        - lifting_target_visible (optional)
        - flip_indices (optional)
        - camera_param (only required when ``flip_camera=True``)

    Modified keys:
        - keypoints or keypoint_labels (optional)
        - keypoints_visible or keypoint_labels_visible (optional)
        - lifting_target or lifting_target_label (optional)
        - lifting_target_visible (optional)
        - camera_param (optional)
    """

    def __init__(self,
                 keypoints_flip_cfg: dict,
                 target_flip_cfg: dict,
                 flip_prob: float = 0.5,
                 flip_camera: bool = False,
                 flip_label: bool = False):
        self.keypoints_flip_cfg = keypoints_flip_cfg
        self.target_flip_cfg = target_flip_cfg
        self.flip_prob = flip_prob
        self.flip_camera = flip_camera
        self.flip_label = flip_label

    def transform(self, results: Dict) -> dict:
        """The transform function of :class:`RandomFlipAroundRoot`.

        See ``transform()`` method of :class:`BaseTransform` for details.

        Args:
            results (dict): The result dict

        Returns:
            dict: The result dict. It is the same object that was passed in,
            updated in place when the flip is applied.
        """
        # Guard clause: leave the sample untouched when no flip is drawn.
        if np.random.rand() > self.flip_prob:
            return results

        if self.flip_label:
            assert 'keypoint_labels' in results
            assert 'lifting_target_label' in results
            keypoints_key = 'keypoint_labels'
            keypoints_visible_key = 'keypoint_labels_visible'
            target_key = 'lifting_target_label'
        else:
            assert 'keypoints' in results
            assert 'lifting_target' in results
            keypoints_key = 'keypoints'
            keypoints_visible_key = 'keypoints_visible'
            target_key = 'lifting_target'

        keypoints = results[keypoints_key]
        if keypoints_visible_key in results:
            keypoints_visible = results[keypoints_visible_key]
        else:
            # No visibility given: treat every keypoint as visible.
            keypoints_visible = np.ones(
                keypoints.shape[:-1], dtype=np.float32)

        lifting_target = results[target_key]
        if 'lifting_target_visible' in results:
            lifting_target_visible = results['lifting_target_visible']
        else:
            lifting_target_visible = np.ones(
                lifting_target.shape[:-1], dtype=np.float32)

        if 'flip_indices' in results:
            flip_indices = results['flip_indices']
        else:
            # Fall back to the identity permutation (no left/right swap).
            # The class has no ``num_keypoints`` attribute, so the count is
            # taken from the data instead.
            # NOTE(review): assumes the keypoint axis is the second-to-last
            # one, i.e. keypoints are shaped (..., K, D) -- confirm against
            # ``flip_keypoints_custom_center``.
            flip_indices = list(range(keypoints.shape[-2]))

        # flip joint coordinates
        keypoints, keypoints_visible = flip_keypoints_custom_center(
            keypoints,
            keypoints_visible,
            flip_indices,
            center_mode=self.keypoints_flip_cfg.get('center_mode', 'static'),
            center_x=self.keypoints_flip_cfg.get('center_x', 0.5),
            center_index=self.keypoints_flip_cfg.get('center_index', 0))
        lifting_target, lifting_target_visible = \
            flip_keypoints_custom_center(
                lifting_target,
                lifting_target_visible,
                flip_indices,
                center_mode=self.target_flip_cfg.get('center_mode', 'static'),
                center_x=self.target_flip_cfg.get('center_x', 0.5),
                center_index=self.target_flip_cfg.get('center_index', 0))

        results[keypoints_key] = keypoints
        results[keypoints_visible_key] = keypoints_visible
        results[target_key] = lifting_target
        results['lifting_target_visible'] = lifting_target_visible

        # flip horizontal distortion coefficients
        if self.flip_camera:
            assert 'camera_param' in results, \
                'Camera parameters are missing.'
            # Camera parameters are only touched when requested, so samples
            # without ``camera_param`` still work with ``flip_camera=False``.
            # Work on a deep copy and write it back via ``update`` so the
            # original nested containers are not mutated in place.
            _camera_param = deepcopy(results['camera_param'])
            assert 'c' in _camera_param
            _camera_param['c'][0] *= -1

            if 'p' in _camera_param:
                _camera_param['p'][0] *= -1

            results['camera_param'].update(_camera_param)

        return results
|