# Copyright (c) OpenMMLab. All rights reserved.
import random
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine import MessageHub
from mmengine.dist import barrier, broadcast, get_dist_info
from mmengine.structures import PixelData
from torch import Tensor

from mmpose.registry import MODELS
from mmpose.structures import PoseDataSample


@MODELS.register_module()
class BatchSyncRandomResize(nn.Module):
    """Batch random resize which synchronizes the random size across ranks.

    Every ``interval`` iterations a new target size is drawn on rank 0 and
    broadcast to all other ranks, so multi-scale training stays consistent
    under distributed data parallelism.

    Args:
        random_size_range (tuple): The multi-scale random range during
            multi-scale training.
        interval (int): The iter interval of change image size. Defaults
            to 10.
        size_divisor (int): Image size divisible factor. Defaults to 32.
    """

    def __init__(self,
                 random_size_range: Tuple[int, int],
                 interval: int = 10,
                 size_divisor: int = 32) -> None:
        super().__init__()
        self.rank, self.world_size = get_dist_info()
        # Target (h, w); lazily initialized from the first batch seen.
        self._input_size = None
        # Store the random range in units of ``size_divisor`` so a drawn
        # integer maps directly to a divisible pixel size.
        self._random_size_range = (round(random_size_range[0] / size_divisor),
                                   round(random_size_range[1] / size_divisor))
        self._interval = interval
        self._size_divisor = size_divisor

    def forward(
        self, inputs: Tensor, data_samples: List[PoseDataSample]
    ) -> Tuple[Tensor, List[PoseDataSample]]:
        """Resize a batch of images and annotations to ``self._input_size``
        and periodically draw a new rank-synchronized target size."""
        src_h, src_w = inputs.shape[-2:]
        if self._input_size is None:
            self._input_size = (src_h, src_w)
        ratio_h = self._input_size[0] / src_h
        ratio_w = self._input_size[1] / src_w

        # Only touch images/annotations when the size actually changes.
        if not (ratio_w == 1 and ratio_h == 1):
            inputs = F.interpolate(
                inputs,
                size=self._input_size,
                mode='bilinear',
                align_corners=False)
            for sample in data_samples:
                self._rescale_sample(sample, ratio_w, ratio_h)

        # Every ``self._interval`` iters, pick the next synchronized size.
        message_hub = MessageHub.get_current_instance()
        if (message_hub.get_info('iter') + 1) % self._interval == 0:
            self._input_size = self._get_random_size(
                aspect_ratio=float(src_w / src_h), device=inputs.device)
        return inputs, data_samples

    def _rescale_sample(self, sample: PoseDataSample, ratio_w: float,
                        ratio_h: float) -> None:
        """Rescale one sample's metainfo and annotations in place."""
        sample.set_metainfo({
            'img_shape': (int(sample.img_shape[0] * ratio_h),
                          int(sample.img_shape[1] * ratio_w)),
            'pad_shape': (int(sample.pad_shape[0] * ratio_h),
                          int(sample.pad_shape[1] * ratio_w)),
            'batch_input_shape': self._input_size
        })

        # Samples without instance labels carry nothing else to rescale.
        if 'gt_instance_labels' not in sample:
            return

        labels = sample.gt_instance_labels
        if 'bboxes' in labels:
            # Boxes are (x1, y1, x2, y2): even indices are x, odd are y.
            labels.bboxes[..., 0::2] *= ratio_w
            labels.bboxes[..., 1::2] *= ratio_h
        if 'keypoints' in labels:
            labels.keypoints[..., 0] *= ratio_w
            labels.keypoints[..., 1] *= ratio_h
        if 'areas' in labels:
            labels.areas *= ratio_w * ratio_h

        if 'gt_fields' in sample and 'heatmap_mask' in sample.gt_fields:
            mask = sample.gt_fields.heatmap_mask.unsqueeze(0)
            resized_mask = F.interpolate(
                mask.float(),
                size=self._input_size,
                mode='bilinear',
                align_corners=False).squeeze(0)
            fields = PixelData()
            fields.set_field(resized_mask, 'heatmap_mask')
            sample.gt_fields = fields

    def _get_random_size(self, aspect_ratio: float,
                         device: torch.device) -> Tuple[int, int]:
        """Randomly generate a shape in ``_random_size_range`` and broadcast
        it from rank 0 to all ranks."""
        size_tensor = torch.zeros(2, dtype=torch.long, device=device)
        if self.rank == 0:
            base = random.randint(*self._random_size_range)
            # Height is ``divisor * base``; width follows the batch's
            # aspect ratio, also snapped to the divisor.
            size_tensor[0] = self._size_divisor * base
            size_tensor[1] = self._size_divisor * int(aspect_ratio * base)
        barrier()
        broadcast(size_tensor, 0)
        return (size_tensor[0].item(), size_tensor[1].item())