from typing import List, Optional, Union

from PIL import Image
import torch
from transformers.image_processing_base import BatchFeature
from transformers.image_processing_utils_fast import (BaseImageProcessorFast,
                                                      divide_to_patches)
from transformers.image_utils import (ChannelDimension, SizeDict,
                                      get_image_size, make_list_of_images,
                                      get_image_type, ImageInput, ImageType)
from transformers.utils import TensorType


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the tiling grid (columns, rows) whose aspect ratio and covered area best match the input image."""
    best_factor = float('-inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        # Score each candidate grid by how much of the original area its tiles cover (capped at 0.6)
        # and how closely its aspect ratio matches the image's aspect ratio.
        factor_based_on_area_n_ratio = min(
            (ratio[0] * ratio[1] * image_size * image_size) / area, 0.6
        ) * min(
            target_aspect_ratio / aspect_ratio, aspect_ratio / target_aspect_ratio
        )
        if factor_based_on_area_n_ratio > best_factor:
            best_factor = factor_based_on_area_n_ratio
            best_ratio = ratio
    return best_ratio
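
# Illustrative example (assumed numbers, not taken from the model card): for a 1024x512 image
# (aspect ratio 2.0) with image_size=512, the (1, 1) grid scores
# min(512*512 / (1024*512), 0.6) * min(1/2, 2/1) = 0.25, while the (2, 1) grid scores
# min(2*512*512 / (1024*512), 0.6) * min(2/2, 2/2) = 0.6, so the two-tiles-wide grid is selected.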


class LlamaNemotronNanoVLImageProcessor(BaseImageProcessorFast):
    """Fast image processor that tiles each image into aspect-ratio-aware patches of ``image_size`` pixels."""

    model_input_names = ["pixel_values"]

    def __init__(self, image_size=512, max_num_tiles=12, use_thumbnail=True, **kwargs):
        super().__init__(**kwargs)
        self.image_size = image_size
        self.max_num_tiles = max_num_tiles
        self.use_thumbnail = use_thumbnail

    # Based on https://github.com/OpenGVLab/InternVL/blob/c62fa4f7c850165d7386bdc48ac6bc5a6fab0864/internvl_chat/internvl/train/dataset.py#L702
    def dynamic_preprocess(self, image, image_size=448, max_num_tiles=12, use_thumbnail=False):
        orig_height, orig_width = get_image_size(image, channel_dim=ChannelDimension.FIRST)
        aspect_ratio = orig_width / orig_height

        # enumerate candidate tiling grids (columns, rows) that use at most max_num_tiles tiles
        target_ratios = set(
            (i, j)
            for n in range(1, max_num_tiles + 1)
            for i in range(1, n + 1)
            for j in range(1, n + 1)
            if 1 <= i * j <= max_num_tiles
        )
        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

        # find the closest aspect ratio to the target
        target_aspect_ratio = find_closest_aspect_ratio(
            aspect_ratio, target_ratios, orig_width, orig_height, image_size)

        # calculate the target width and height
        target_width = image_size * target_aspect_ratio[0]
        target_height = image_size * target_aspect_ratio[1]

        resized_img = self.resize(image, SizeDict(height=target_height, width=target_width))
        patches = divide_to_patches(resized_img, image_size)
        if use_thumbnail and len(patches) != 1:
            # when the image was split into multiple tiles, also append a full-image thumbnail as a global view
            patches.append(self.resize(image, SizeDict(height=image_size, width=image_size)))

        return patches

    def _process_image(
        self,
        image: ImageInput,
        **kwargs,
    ) -> torch.Tensor:
        image_type = get_image_type(image)
        if image_type not in [ImageType.PIL]:
            raise ValueError(f"Unsupported input image type {image_type}. Only PIL images are supported")
        # Upsample the image 2x with bilinear interpolation before tiling.
        image = image.resize((image.width * 2, image.height * 2), Image.BILINEAR)
        return super()._process_image(image, **kwargs)

    def _preprocess(
        self,
        images: List[torch.Tensor],
        image_size: Optional[int] = None,
        max_num_tiles: Optional[int] = None,
        use_thumbnail: Optional[bool] = None,
        do_rescale: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        image_size = image_size if image_size is not None else self.image_size
        max_num_tiles = max_num_tiles if max_num_tiles is not None else self.max_num_tiles
        use_thumbnail = use_thumbnail if use_thumbnail is not None else self.use_thumbnail
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale

        images = make_list_of_images(images)

        all_patches = []
        num_patches = []
        for image in images:
            patches = self.dynamic_preprocess(
                image, image_size, max_num_tiles, use_thumbnail
            )
            all_patches.extend(patches)
            num_patches.append(len(patches))

        # stack every tile from every image into one batch, then rescale/normalize in a single pass
        pixel_values = torch.stack(all_patches, dim=0)
        pixel_values = self.rescale_and_normalize(
            pixel_values,
            do_rescale,
            self.rescale_factor,
            do_normalize=self.do_normalize,
            image_mean=self.image_mean,
            image_std=self.image_std,
        )

        return BatchFeature(data={"pixel_values": pixel_values, "num_patches": num_patches}, tensor_type=return_tensors)
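

# A minimal usage sketch (illustrative only; the rescale/normalize settings below are assumptions,
# normally supplied by the checkpoint's preprocessor_config.json):
#
#     from PIL import Image
#
#     processor = LlamaNemotronNanoVLImageProcessor(
#         image_size=512,
#         max_num_tiles=12,
#         use_thumbnail=True,
#         do_rescale=True,
#         rescale_factor=1 / 255,
#         do_normalize=True,
#         image_mean=[0.485, 0.456, 0.406],  # assumed ImageNet statistics, for illustration
#         image_std=[0.229, 0.224, 0.225],
#     )
#     batch = processor(images=Image.open("example.jpg"), return_tensors="pt")
#     # batch["pixel_values"]: (total_tiles, 3, image_size, image_size)
#     # batch["num_patches"]: tiles produced per input image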