Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,389 Bytes
42f2c22 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.
import math
import random
from typing import Union
import torch
from PIL import Image
from torchvision.transforms import functional as TVF
from torchvision.transforms.functional import InterpolationMode
class AreaResize:
    """Resize an image so that its pixel area is approximately ``max_area``.

    Both sides are multiplied by the same factor
    ``sqrt(max_area / (height * width))`` and rounded to the nearest integer,
    so the aspect ratio is preserved up to rounding.

    Args:
        max_area: Target number of pixels (height * width) after resizing.
            Must be positive.
        downsample_only: If True, images whose area is already at or below
            ``max_area`` keep their original size (they are still passed
            through ``TVF.resize`` with an unchanged size).
        interpolation: Interpolation mode forwarded to ``TVF.resize``.

    Raises:
        ValueError: If ``max_area`` is not positive.
    """

    def __init__(
        self,
        max_area: float,
        downsample_only: bool = False,
        interpolation: InterpolationMode = InterpolationMode.BICUBIC,
    ):
        if max_area <= 0:
            raise ValueError(f"max_area must be positive, got {max_area!r}")
        self.max_area = max_area
        self.downsample_only = downsample_only
        self.interpolation = interpolation

    def __call__(self, image: Union[torch.Tensor, Image.Image]):
        """Resize ``image`` to roughly ``max_area`` pixels.

        Args:
            image: A tensor whose last two dimensions are (height, width), or
                a PIL image.

        Returns:
            The result of ``TVF.resize`` on ``image``.

        Raises:
            NotImplementedError: If ``image`` is neither a tensor nor a PIL
                image.
        """
        if isinstance(image, torch.Tensor):
            height, width = image.shape[-2:]
        elif isinstance(image, Image.Image):
            width, height = image.size  # PIL reports (width, height)
        else:
            raise NotImplementedError(
                f"AreaResize does not support inputs of type {type(image).__name__}"
            )

        scale = math.sqrt(self.max_area / (height * width))
        # Keep the original height and width for images that are already small.
        if self.downsample_only and scale >= 1:
            scale = 1
        # Never request a 0-pixel side: rounding a tiny scaled side down to 0
        # would make the resize fail.
        resized_height = max(1, round(height * scale))
        resized_width = max(1, round(width * scale))
        return TVF.resize(
            image,
            size=(resized_height, resized_width),
            interpolation=self.interpolation,
        )
class AreaRandomCrop:
    """Randomly crop a window whose area is approximately ``max_area``.

    The crop window has the same aspect ratio as the input image (up to
    rounding). If the image already fits inside the window in both
    dimensions, the full image extent is passed to ``TVF.crop``.

    Args:
        max_area: Target number of pixels (height * width) of the crop.
            Must be positive.

    Raises:
        ValueError: If ``max_area`` is not positive.
    """

    def __init__(
        self,
        max_area: float,
    ):
        if max_area <= 0:
            raise ValueError(f"max_area must be positive, got {max_area!r}")
        self.max_area = max_area

    def get_params(self, input_size, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            input_size (tuple): ``(height, width)`` of the image to be cropped.
            output_size (tuple): Expected ``(height, width)`` of the crop.

        Returns:
            tuple: params ``(i, j, h, w)`` to be passed to ``crop`` for a
            random crop, where ``(i, j)`` is the top-left corner. If the image
            fits inside ``output_size`` (both sides at most the requested
            ones), ``(0, 0, height, width)`` is returned without drawing
            random numbers. Otherwise any requested side larger than the
            image side is clamped to the image side.
        """
        h, w = input_size
        th, tw = output_size
        if w <= tw and h <= th:
            return 0, 0, h, w
        # Clamp each side independently: a window larger than the image in
        # only one dimension would otherwise give randint a negative range.
        th, tw = min(th, h), min(tw, w)
        # random.randint is inclusive on both ends, so every offset is reachable.
        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, image: Union[torch.Tensor, Image.Image]):
        """Crop a random window of roughly ``max_area`` pixels from ``image``.

        Args:
            image: A tensor whose last two dimensions are (height, width), or
                a PIL image.

        Returns:
            The result of ``TVF.crop`` on ``image``.

        Raises:
            NotImplementedError: If ``image`` is neither a tensor nor a PIL
                image.
        """
        if isinstance(image, torch.Tensor):
            height, width = image.shape[-2:]
        elif isinstance(image, Image.Image):
            width, height = image.size  # PIL reports (width, height)
        else:
            raise NotImplementedError(
                f"AreaRandomCrop does not support inputs of type {type(image).__name__}"
            )

        aspect_ratio = width / height
        # Solve crop_h * crop_w == max_area subject to crop_w / crop_h == aspect_ratio.
        resized_height = math.sqrt(self.max_area / aspect_ratio)
        resized_width = aspect_ratio * resized_height
        # Never request a 0-pixel side for very small max_area.
        resized_height = max(1, round(resized_height))
        resized_width = max(1, round(resized_width))

        i, j, h, w = self.get_params((height, width), (resized_height, resized_width))
        return TVF.crop(image, i, j, h, w)
class ScaleResize:
    """Resize an image by a fixed factor applied to both height and width.

    Tensors are resized with bilinear interpolation and PIL images with
    Lanczos interpolation.

    Args:
        scale: Multiplier applied to both sides; each resulting side is
            rounded to the nearest integer. Must be positive.

    Raises:
        ValueError: If ``scale`` is not positive.
    """

    def __init__(
        self,
        scale: float,
    ):
        if scale <= 0:
            raise ValueError(f"scale must be positive, got {scale!r}")
        self.scale = scale

    def __call__(self, image: Union[torch.Tensor, Image.Image]):
        """Resize ``image`` by ``self.scale``.

        Args:
            image: A tensor whose last two dimensions are (height, width), or
                a PIL image.

        Returns:
            The result of ``TVF.resize`` on ``image``.

        Raises:
            NotImplementedError: If ``image`` is neither a tensor nor a PIL
                image.
        """
        if isinstance(image, torch.Tensor):
            height, width = image.shape[-2:]
            interpolation_mode = InterpolationMode.BILINEAR
            # NOTE(review): the string "warn" is forwarded to TVF.resize
            # unchanged; how it is interpreted depends on the installed
            # torchvision version — confirm against the pinned version.
            antialias = True if image.ndim == 4 else "warn"
        elif isinstance(image, Image.Image):
            width, height = image.size  # PIL reports (width, height)
            interpolation_mode = InterpolationMode.LANCZOS
            antialias = "warn"
        else:
            raise NotImplementedError(
                f"ScaleResize does not support inputs of type {type(image).__name__}"
            )

        scale = self.scale
        # Never request a 0-pixel side when downscaling a very small image.
        resized_height = max(1, round(height * scale))
        resized_width = max(1, round(width * scale))
        return TVF.resize(
            image,
            size=(resized_height, resized_width),
            interpolation=interpolation_mode,
            antialias=antialias,
        )
|