|
import copy |
|
from typing import Dict, Any |
|
|
|
from transformers.image_processing_utils import ImageProcessingMixin, BatchFeature |
|
from timm.data.transforms_factory import create_transform |
|
|
|
|
|
|
|
class ClassificationImageProprocessor(ImageProcessingMixin):
    """Image processor that runs a timm transform pipeline over input images.

    The pipeline is built once, at construction time, by passing
    ``data_config`` to ``create_transform``. The pipeline object itself is
    stored on the private ``_transform`` attribute and is deliberately left
    out of the dictionary produced by :meth:`to_dict`.
    """

    def __init__(self, data_config, **kwargs):
        """Store the configuration and build the transform pipeline.

        Args:
            data_config: Mapping whose items are forwarded as keyword
                arguments to ``create_transform``.
            **kwargs: Forwarded unchanged to ``ImageProcessingMixin.__init__``.
        """
        super().__init__(**kwargs)
        self.data_config = data_config
        self._transform = create_transform(**data_config)

    def __call__(self, images, **kwargs) -> BatchFeature:
        """Preprocess an image or a batch of images.

        Thin alias for :meth:`preprocess`; all arguments are passed through.
        """
        return self.preprocess(images, **kwargs)

    def preprocess(self, images, return_tensors=None, **kwargs) -> BatchFeature:
        """Apply the transform pipeline to one image or to a batch of images.

        Args:
            images: Either an iterable of images (each element is passed to
                the transform individually) or a single non-iterable image
                object, which is treated as a batch of one.
            return_tensors: Passed to ``BatchFeature`` as ``tensor_type``.
            **kwargs: Accepted for call-site compatibility; not used here.

        Returns:
            A ``BatchFeature`` whose ``"pixel_values"`` entry is the list of
            transformed images, in input order.
        """
        # A lone image object that cannot be iterated (for instance a PIL
        # image) used to crash the loop below with a TypeError. Wrap it so the
        # single-image call documented on ``__call__`` actually works. Inputs
        # that are already iterable are handled exactly as before.
        try:
            iter(images)
        except TypeError:
            images = [images]

        pixel_values = [self._transform(image) for image in images]
        return BatchFeature(
            data={"pixel_values": pixel_values},
            tensor_type=return_tensors,
        )

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.

        The ``_transform`` pipeline is excluded from the result, and an
        ``"image_processor_type"`` entry holding the class name is added.

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this image processor instance.
        """
        # Drop the transform *before* copying so the pipeline object is never
        # deep-copied just to be thrown away.
        serializable = {
            key: value
            for key, value in self.__dict__.items()
            if key != "_transform"
        }
        output = copy.deepcopy(serializable)
        output["image_processor_type"] = self.__class__.__name__
        return output