File size: 1,262 Bytes
bfae0e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import copy
from typing import Dict, Any

from transformers.image_processing_utils import ImageProcessingMixin, BatchFeature
from timm.data.transforms_factory import create_transform



class ClassificationImageProprocessor(ImageProcessingMixin):
    """Hugging Face image processor that applies a timm transform pipeline.

    The transform is built once in ``__init__`` by calling
    ``create_transform(**data_config)`` and is applied to every image passed to
    ``preprocess``. ``data_config`` is stored as an attribute so it is included
    in ``to_dict``; the transform object itself is excluded from serialization.

    NOTE(review): ``data_config`` presumably comes from timm's
    ``resolve_data_config`` for the target model — confirm against callers.
    """

    def __init__(self, data_config: Dict[str, Any], **kwargs):
        """Store ``data_config`` and build the timm transform from it.

        Args:
            data_config: Keyword arguments forwarded verbatim to timm's
                ``create_transform``.
            **kwargs: Forwarded unchanged to ``ImageProcessingMixin.__init__``.
        """
        super().__init__(**kwargs)
        self.data_config = data_config
        self._transform = create_transform(**data_config)

    def __call__(self, images, **kwargs) -> BatchFeature:
        """Preprocess an image or a batch of images."""
        return self.preprocess(images, **kwargs)

    @staticmethod
    def _as_batch(images) -> list:
        """Return ``images`` as a list, wrapping a single non-iterable image."""
        try:
            iterator = iter(images)
        except TypeError:
            # A single image object that is not iterable (e.g. a PIL image)
            # is treated as a batch of one instead of crashing.
            return [images]
        # Any iterable (list, tuple, generator, ...) is taken as the batch.
        # NOTE(review): a single numpy/torch image is iterable, so it must be
        # wrapped in a list by the caller.
        return list(iterator)

    def preprocess(self, images, return_tensors=None, **kwargs) -> BatchFeature:
        """Apply the timm transform to each image and pack the results.

        Args:
            images: A single non-iterable image (such as a PIL image) or an
                iterable of images accepted by the timm transform.
            return_tensors: Passed to ``BatchFeature`` as ``tensor_type``.
                When it is ``"pt"`` and every transformed image is a
                ``torch.Tensor``, the images are stacked into one batch tensor
                first, so conversion does not depend on ``BatchFeature`` being
                able to turn a list of tensors into a tensor.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            A ``BatchFeature`` with a single ``"pixel_values"`` entry.
        """
        pixel_values = [self._transform(image) for image in self._as_batch(images)]
        if return_tensors == "pt" and pixel_values:
            import torch  # local import: torch is only needed for this branch

            if all(isinstance(value, torch.Tensor) for value in pixel_values):
                pixel_values = torch.stack(pixel_values)
        return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this image processor instance.
        """
        # Leave the transform out *before* deep-copying: it is derived from
        # ``data_config`` in ``__init__``, so copying it is wasted work and
        # could fail if the pipeline holds state that cannot be deep-copied.
        state = {key: value for key, value in self.__dict__.items() if key != "_transform"}
        output = copy.deepcopy(state)
        output["image_processor_type"] = self.__class__.__name__

        return output