import copy
from typing import Any, Dict

from timm.data.transforms_factory import create_transform
from transformers.image_processing_utils import BatchFeature, ImageProcessingMixin


class ClassificationImageProprocessor(ImageProcessingMixin):
    """Image processor that runs a timm transform over input images.

    The transform is built once in ``__init__`` by calling
    ``create_transform(**data_config)`` and is applied to each image in
    ``preprocess``. The transform itself is never serialized; only
    ``data_config`` (and any other instance attributes) are written out by
    ``to_dict``.

    NOTE(review): the class name looks like a misspelling of "Preprocessor";
    it is kept unchanged because renaming would break existing callers and
    the ``image_processor_type`` value stored in saved configs.
    """

    def __init__(self, data_config, **kwargs):
        """Store ``data_config`` and build the transform from it.

        Args:
            data_config: Mapping of keyword arguments forwarded verbatim to
                ``timm.data.transforms_factory.create_transform``.
            **kwargs: Forwarded to ``ImageProcessingMixin.__init__``.
        """
        super().__init__(**kwargs)
        self.data_config = data_config
        # Kept under a private name so that ``to_dict`` can exclude it.
        self._transform = create_transform(**data_config)

    def __call__(self, images, **kwargs) -> BatchFeature:
        """Preprocess an image or a batch of images.

        Thin alias for ``preprocess``; all arguments are passed through.
        """
        return self.preprocess(images, **kwargs)

    def preprocess(self, images, return_tensors=None, **kwargs) -> BatchFeature:
        """Apply the transform to every image and pack the results.

        Args:
            images: A single image or an iterable of images. Each image is
                passed as-is to the transform, so it must be of a type the
                transform accepts (presumably PIL images — confirm against
                the configured timm transform).
            return_tensors: Passed to ``BatchFeature`` as ``tensor_type``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            A ``BatchFeature`` whose ``"pixel_values"`` entry holds one
            transformed item per input image.
        """
        # A lone image that is not iterable used to fail with ``TypeError``
        # in the loop below; treat it as a batch of one. Iterable inputs
        # (lists, tuples, arrays, tensors) are processed exactly as before.
        try:
            iter(images)
        except TypeError:
            images = [images]

        pixel_values = [self._transform(image) for image in images]
        return BatchFeature(
            data={"pixel_values": pixel_values},
            tensor_type=return_tensors,
        )

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this instance to a Python dictionary.

        Returns:
            `Dict[str, Any]`: A deep copy of the instance attributes, minus
            the ``_transform`` object, plus an ``"image_processor_type"``
            entry holding the class name.
        """
        # Filter the transform out *before* copying so the pipeline object is
        # never deep-copied just to be thrown away.
        serializable = {
            key: value
            for key, value in self.__dict__.items()
            if key != "_transform"
        }
        output = copy.deepcopy(serializable)
        output["image_processor_type"] = self.__class__.__name__
        return output