from __future__ import absolute_import import warnings import numpy as np import torch from torchvision.transforms import functional as F from ..core.transforms_interface import BasicTransform __all__ = ["ToTensorV2"] def img_to_tensor(im, normalize=None): tensor = torch.from_numpy(np.moveaxis(im / (255.0 if im.dtype == np.uint8 else 1), -1, 0).astype(np.float32)) if normalize is not None: return F.normalize(tensor, **normalize) return tensor def mask_to_tensor(mask, num_classes, sigmoid): if num_classes > 1: if not sigmoid: # softmax long_mask = np.zeros((mask.shape[:2]), dtype=np.int64) if len(mask.shape) == 3: for c in range(mask.shape[2]): long_mask[mask[..., c] > 0] = c else: long_mask[mask > 127] = 1 long_mask[mask == 0] = 0 mask = long_mask else: mask = np.moveaxis(mask / (255.0 if mask.dtype == np.uint8 else 1), -1, 0).astype(np.float32) else: mask = np.expand_dims(mask / (255.0 if mask.dtype == np.uint8 else 1), 0).astype(np.float32) return torch.from_numpy(mask) class ToTensor(BasicTransform): """Convert image and mask to `torch.Tensor` and divide by 255 if image or mask are `uint8` type. This transform is now removed from custom_albumentations. If you need it downgrade the library to version 0.5.2. Args: num_classes (int): only for segmentation sigmoid (bool, optional): only for segmentation, transform mask to LongTensor or not. normalize (dict, optional): dict with keys [mean, std] to pass it into torchvision.normalize """ def __init__(self, num_classes=1, sigmoid=True, normalize=None): raise RuntimeError( "`ToTensor` is obsolete and it was removed from custom_albumentations. Please use `ToTensorV2` instead - " "https://albumentations.ai/docs/api_reference/pytorch/transforms/" "#albumentations.pytorch.transforms.ToTensorV2. " "\n\nIf you need `ToTensor` downgrade Albumentations to version 0.5.2." ) class ToTensorV2(BasicTransform): """Convert image and mask to `torch.Tensor`. The numpy `HWC` image is converted to pytorch `CHW` tensor. If the image is in `HW` format (grayscale image), it will be converted to pytorch `HW` tensor. This is a simplified and improved version of the old `ToTensor` transform (`ToTensor` was deprecated, and now it is not present in Albumentations. You should use `ToTensorV2` instead). Args: transpose_mask (bool): If True and an input mask has three dimensions, this transform will transpose dimensions so the shape `[height, width, num_channels]` becomes `[num_channels, height, width]`. The latter format is a standard format for PyTorch Tensors. Default: False. always_apply (bool): Indicates whether this transformation should be always applied. Default: True. p (float): Probability of applying the transform. Default: 1.0. """ def __init__(self, transpose_mask=False, always_apply=True, p=1.0): super(ToTensorV2, self).__init__(always_apply=always_apply, p=p) self.transpose_mask = transpose_mask @property def targets(self): return {"image": self.apply, "mask": self.apply_to_mask, "masks": self.apply_to_masks} def apply(self, img, **params): # skipcq: PYL-W0613 if len(img.shape) not in [2, 3]: raise ValueError("Albumentations only supports images in HW or HWC format") if len(img.shape) == 2: img = np.expand_dims(img, 2) return torch.from_numpy(img.transpose(2, 0, 1)) def apply_to_mask(self, mask, **params): # skipcq: PYL-W0613 if self.transpose_mask and mask.ndim == 3: mask = mask.transpose(2, 0, 1) return torch.from_numpy(mask) def apply_to_masks(self, masks, **params): return [self.apply_to_mask(mask, **params) for mask in masks] def get_transform_init_args_names(self): return ("transpose_mask",) def get_params_dependent_on_targets(self, params): return {}