# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Sequence, Union

import numpy as np
import torch
import torch.nn as nn
from mmengine.model import ImgDataPreprocessor
from mmengine.utils import is_seq_of

from mmpose.registry import MODELS


@MODELS.register_module()
class PoseDataPreprocessor(ImgDataPreprocessor):
    """Image pre-processor for pose estimation tasks.

    Compared with :class:`ImgDataPreprocessor`,

    1. It additionally appends ``batch_input_shape`` and ``pad_shape`` to the
       metainfo of each data sample, for DETR-based pose estimation tasks.
    2. It supports image augmentation transforms on batched data.

    It provides the data pre-processing as follows:

    - Collate and move data to the target device.
    - Pad inputs to the maximum size of the current batch with the defined
      ``pad_value``. The padded size is divisible by the defined
      ``pad_size_divisor``.
    - Stack inputs to batch_inputs.
    - Convert inputs from bgr to rgb if the shape of input is (3, H, W).
    - Normalize images with the defined std and mean.
    - Apply batch augmentation transforms (training only).

    Args:
        mean (sequence of float, optional): The pixel mean of R, G, B
            channels. Defaults to None.
        std (sequence of float, optional): The pixel standard deviation of
            R, G, B channels. Defaults to None.
        pad_size_divisor (int): The size of the padded image should be
            divisible by ``pad_size_divisor``. Defaults to 1.
        pad_value (float or int): The padded pixel value. Defaults to 0.
        bgr_to_rgb (bool): Whether to convert images from BGR to RGB.
            Defaults to False.
        rgb_to_bgr (bool): Whether to convert images from RGB to BGR.
            Defaults to False.
        non_blocking (bool): Whether to block the current process when
            transferring data to the device. Defaults to False.
        batch_augments (list of dict, optional): Configs of augmentation
            transforms on batched data. Defaults to None.
    """

    def __init__(self,
                 mean: Optional[Sequence[float]] = None,
                 std: Optional[Sequence[float]] = None,
                 pad_size_divisor: int = 1,
                 pad_value: Union[float, int] = 0,
                 bgr_to_rgb: bool = False,
                 rgb_to_bgr: bool = False,
                 non_blocking: Optional[bool] = False,
                 batch_augments: Optional[List[dict]] = None):
        super().__init__(
            mean=mean,
            std=std,
            pad_size_divisor=pad_size_divisor,
            pad_value=pad_value,
            bgr_to_rgb=bgr_to_rgb,
            rgb_to_bgr=rgb_to_bgr,
            non_blocking=non_blocking)

        # Each config is built through the MODELS registry; wrapping the
        # result in nn.ModuleList registers the augment modules as children.
        if batch_augments is not None:
            self.batch_augments = nn.ModuleList(
                [MODELS.build(aug) for aug in batch_augments])
        else:
            self.batch_augments = None

    def forward(self, data: dict, training: bool = False) -> dict:
        """Pre-process a batch and attach batch-level metainfo.

        Normalization, padding and color conversion are delegated to
        :meth:`ImgDataPreprocessor.forward`. Afterwards every data sample (if
        any) receives ``batch_input_shape`` and ``pad_shape`` metainfo. When
        ``training`` is True, the configured batch augmentations are then
        applied.

        Args:
            data (dict): Data sampled from dataloader.
            training (bool): Whether to enable training time augmentation.

        Returns:
            dict: Data in the same format as the model input, with keys
            ``'inputs'`` and ``'data_samples'``.
        """
        # Per-image pad shapes are computed from the incoming data, i.e.
        # before the parent class processes (stacks/pads) the inputs.
        batch_pad_shape = self._get_pad_shape(data)
        data = super().forward(data=data, training=training)
        inputs, data_samples = data['inputs'], data['data_samples']

        if data_samples is not None:
            # Update metainfo since the image shape might change after
            # batch padding.
            batch_input_shape = tuple(inputs[0].size()[-2:])
            for data_sample, pad_shape in zip(data_samples, batch_pad_shape):
                data_sample.set_metainfo({
                    'batch_input_shape': batch_input_shape,
                    'pad_shape': pad_shape
                })

        # Apply batch augmentations (training time only).
        if training and self.batch_augments is not None:
            for batch_aug in self.batch_augments:
                inputs, data_samples = batch_aug(inputs, data_samples)

        return {'inputs': inputs, 'data_samples': data_samples}

    def _round_up_to_divisor(self, size: int) -> int:
        """Round ``size`` up to the nearest multiple of
        ``self.pad_size_divisor``."""
        divisor = self.pad_size_divisor
        return int(np.ceil(size / divisor)) * divisor

    def _get_pad_shape(self, data: dict) -> List[tuple]:
        """Get the padded ``(H, W)`` of each image.

        Each spatial size is rounded up to a multiple of
        ``self.pad_size_divisor``.

        Args:
            data (dict): Data sampled from the dataloader. ``data['inputs']``
                is either a sequence of ``(C, H, W)`` tensors (from
                ``pseudo_collate``) or a single ``(N, C, H, W)`` tensor (from
                ``default_collate``).

        Returns:
            List[tuple]: One ``(pad_h, pad_w)`` tuple per image.

        Raises:
            TypeError: If ``data['inputs']`` is neither a sequence of tensors
                nor a tensor.
        """
        batch_inputs = data['inputs']

        # Process data with `pseudo_collate`: a sequence of CHW tensors, so
        # H and W live at indices 1 and 2 of each tensor.
        if is_seq_of(batch_inputs, torch.Tensor):
            return [(self._round_up_to_divisor(img.shape[1]),
                     self._round_up_to_divisor(img.shape[2]))
                    for img in batch_inputs]

        # Process data with `default_collate`: a single NCHW tensor.
        if isinstance(batch_inputs, torch.Tensor):
            assert batch_inputs.dim() == 4, (
                'The input of `ImgDataPreprocessor` should be a NCHW tensor '
                'or a list of tensor, but got a tensor with shape: '
                f'{batch_inputs.shape}')
            # For NCHW the spatial dims are at indices 2 (H) and 3 (W);
            # indices 1 and 2 would read (C, H) instead.
            pad_h = self._round_up_to_divisor(batch_inputs.shape[2])
            pad_w = self._round_up_to_divisor(batch_inputs.shape[3])
            return [(pad_h, pad_w)] * batch_inputs.shape[0]

        raise TypeError('Output of `cast_data` should be a dict '
                        'or a tuple with inputs and data_samples, but got '
                        f'{type(data)}: {data}')