File size: 1,262 Bytes
bfae0e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
import copy
from typing import Dict, Any
from transformers.image_processing_utils import ImageProcessingMixin, BatchFeature
from timm.data.transforms_factory import create_transform
class ClassificationImageProprocessor(ImageProcessingMixin):
    """Image processor that applies a timm transform built from ``data_config``.

    The transform is created with ``create_transform(**data_config)`` in
    ``__init__``, so only ``data_config`` needs to be serialized; the transform
    object itself is excluded from ``to_dict``.

    NOTE: the class name keeps its original spelling ("Proprocessor") on
    purpose. ``to_dict`` records ``self.__class__.__name__`` as
    ``image_processor_type``, so renaming the class would change that value and
    presumably break loading of configs saved under the old name.
    """

    def __init__(self, data_config, **kwargs):
        """Build the processor and its timm transform.

        Args:
            data_config: Mapping of keyword arguments forwarded to timm's
                ``create_transform``. It is also kept on the instance so that
                ``to_dict`` serializes it. Presumably the output of
                ``timm.data.resolve_data_config`` -- TODO confirm with callers.
            **kwargs: Forwarded unchanged to ``ImageProcessingMixin.__init__``.
        """
        super().__init__(**kwargs)
        # Shallow copy so that later mutation of the caller's dict cannot make
        # the serialized config disagree with the transform built below.
        self.data_config = dict(data_config)
        self._transform = create_transform(**self.data_config)

    def __call__(self, images, **kwargs) -> BatchFeature:
        """Preprocess an image or a batch of images.

        Thin alias for :meth:`preprocess`; all keyword arguments are forwarded.
        """
        return self.preprocess(images, **kwargs)

    def preprocess(self, images, return_tensors=None, **kwargs) -> BatchFeature:
        """Apply the timm transform to every image.

        Args:
            images: A single image or an iterable of images. An object that
                cannot be iterated (e.g. a single PIL image) is treated as a
                batch of one. NOTE(review): an iterable object such as a
                single array/tensor is still iterated along its first axis,
                exactly as before.
            return_tensors: Passed to ``BatchFeature`` as ``tensor_type``;
                ``None`` requests no conversion.
            **kwargs: Accepted for API compatibility and ignored.

        Returns:
            BatchFeature with a single ``"pixel_values"`` entry holding the
            transformed images, in input order.
        """
        try:
            iter(images)
        except TypeError:
            # Not iterable, so this is a lone image. Previously this input
            # raised, even though ``__call__`` documents single-image support.
            images = [images]
        pixel_values = [self._transform(image) for image in images]
        return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.

        The transform object is omitted because ``__init__`` rebuilds it from
        ``data_config``. The class name is recorded under
        ``"image_processor_type"``.

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this image processor instance.
        """
        # Filter before copying so the (possibly large or non-copyable)
        # transform is never deep-copied only to be thrown away.
        state = {key: value for key, value in self.__dict__.items() if key != "_transform"}
        output = copy.deepcopy(state)
        output["image_processor_type"] = self.__class__.__name__
        return output