# Copyright (c) OpenMMLab. All rights reserved. from typing import Dict, Optional, Tuple import cv2 import numpy as np from mmcv.transforms import BaseTransform from mmengine import is_seq_of from mmpose.registry import TRANSFORMS from mmpose.structures.bbox import get_udp_warp_matrix, get_warp_matrix, bbox_cs2xyxy, bbox_xyxy2cs @TRANSFORMS.register_module() class TopdownAffine(BaseTransform): """Get the bbox image as the model input by affine transform. Required Keys: - img - bbox_center - bbox_scale - bbox_rotation (optional) - keypoints (optional) Modified Keys: - img - bbox_scale Added Keys: - input_size - transformed_keypoints - bbox_mask Args: input_size (Tuple[int, int]): The input image size of the model in [w, h]. The bbox region will be cropped and resize to `input_size` use_udp (bool): Whether use unbiased data processing. See `UDP (CVPR 2020)`_ for details. Defaults to ``False`` .. _`UDP (CVPR 2020)`: https://arxiv.org/abs/1911.07524 """ def __init__(self, input_size: Tuple[int, int], input_padding: float = 1.25, use_udp: bool = False) -> None: super().__init__() assert is_seq_of(input_size, int) and len(input_size) == 2, ( f'Invalid input_size {input_size}') self.input_size = input_size self.use_udp = use_udp self.input_padding = input_padding @staticmethod def _fix_aspect_ratio(bbox_scale: np.ndarray, aspect_ratio: float): """Reshape the bbox to a fixed aspect ratio. Args: bbox_scale (np.ndarray): The bbox scales (w, h) in shape (n, 2) aspect_ratio (float): The ratio of ``w/h`` Returns: np.darray: The reshaped bbox scales in (n, 2) """ w, h = np.hsplit(bbox_scale, [1]) bbox_scale = np.where(w > h * aspect_ratio, np.hstack([w, w / aspect_ratio]), np.hstack([h * aspect_ratio, h])) return bbox_scale def transform(self, results: Dict) -> Optional[dict]: """The transform function of :class:`TopdownAffine`. See ``transform()`` method of :class:`BaseTransform` for details. Args: results (dict): The result dict Returns: dict: The result dict. """ w, h = self.input_size warp_size = (int(w), int(h)) img_h, img_w = results['img'].shape[:2] bbox_xyxy = results['bbox_xyxy_wrt_input'].flatten() bbox_xyxy[:2] = np.maximum(bbox_xyxy[:2], 0) bbox_xyxy[2:4] = np.minimum(bbox_xyxy[2:4], [img_w, img_h]) x0, y0, x1, y1 = bbox_xyxy[:4].astype(int) bbox_mask = np.zeros((img_h, img_w), dtype=np.uint8) bbox_mask[y0:y1, x0:x1] = 1 # Take the bbox wrt the input bbox_xyxy_wrt_input = results.get('bbox_xyxy_wrt_input', None) if bbox_xyxy_wrt_input is not None: _c, _s = bbox_xyxy2cs(bbox_xyxy_wrt_input, padding=self.input_padding) results['bbox_center'] = _c.reshape(1, 2) results['bbox_scale'] = _s.reshape(1, 2) # reshape bbox to fixed aspect ratio results['bbox_scale'] = self._fix_aspect_ratio( results['bbox_scale'], aspect_ratio=w / h) # TODO: support multi-instance assert results['bbox_center'].shape[0] == 1, ( 'Top-down heatmap only supports single instance. Got invalid ' f'shape of bbox_center {results["bbox_center"].shape}.') center = results['bbox_center'][0] scale = results['bbox_scale'][0] if 'bbox_rotation' in results: rot = results['bbox_rotation'][0] else: rot = 0. if self.use_udp: warp_mat = get_udp_warp_matrix( center, scale, rot, output_size=(w, h)) else: warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h)) if isinstance(results['img'], list): results['img'] = [ cv2.warpAffine( img, warp_mat, warp_size, flags=cv2.INTER_LINEAR) for img in results['img'] ] else: results['img'] = cv2.warpAffine( results['img'], warp_mat, warp_size, flags=cv2.INTER_LINEAR) bbox_mask = cv2.warpAffine( bbox_mask, warp_mat, warp_size, flags=cv2.INTER_LINEAR) bbox_mask = bbox_mask.reshape(1, h, w) results['bbox_mask'] = bbox_mask if results.get('keypoints', None) is not None: if results.get('transformed_keypoints', None) is not None: transformed_keypoints = results['transformed_keypoints'].copy() else: transformed_keypoints = results['keypoints'].copy() # Only transform (x, y) coordinates transformed_keypoints[..., :2] = cv2.transform( transformed_keypoints[..., :2], warp_mat) results['transformed_keypoints'] = transformed_keypoints if results.get('bbox_xyxy_wrt_input', None) is not None: bbox_xyxy_wrt_input = results['bbox_xyxy_wrt_input'].copy() bbox_xyxy_wrt_input = bbox_xyxy_wrt_input.reshape(1, 2, 2) bbox_xyxy_wrt_input = cv2.transform( bbox_xyxy_wrt_input, warp_mat) results['bbox_xyxy_wrt_input'] = bbox_xyxy_wrt_input.reshape(1, 4) results['input_size'] = (w, h) results['input_center'] = center results['input_scale'] = scale return results def __repr__(self) -> str: """print the basic information of the transform. Returns: str: Formatted string. """ repr_str = self.__class__.__name__ repr_str += f'(input_size={self.input_size}, ' repr_str += f'use_udp={self.use_udp})' return repr_str