import io
import os
import re
from typing import List, Union

import numpy as np
import PIL.Image
import PIL.ImageOps
import requests
import torch
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel, VisionEncoderDecoderConfig

# Matches the first "<...>" tag in a decoded sequence (non-greedy).
_FIRST_TAG_RE = re.compile(r"<.*?>")


def load_image(image: Union[str, PIL.Image.Image], timeout: float = 30.0) -> PIL.Image.Image:
    """
    Loads `image` to a PIL Image.

    Args:
        image (`str` or `PIL.Image.Image`):
            The image to convert to the PIL Image format. A string is treated
            as a URL when it starts with `http://` or `https://`, otherwise as
            a local file path.
        timeout (`float`, defaults to 30.0):
            Seconds to wait for the HTTP server when `image` is a URL.

    Returns:
        `PIL.Image.Image`: The image with its EXIF orientation applied,
        converted to RGB.

    Raises:
        ValueError: If `image` is a string that is neither an http(s) URL nor
            an existing file, or if it is neither a string nor a PIL image.
        requests.RequestException: If the download fails, times out, or the
            server answers with an HTTP error status.
    """
    if isinstance(image, str):
        if image.startswith(("http://", "https://")):
            # Download the whole body first: this surfaces HTTP errors instead
            # of handing an error page to PIL, releases the connection, and
            # gives PIL a seekable buffer.
            response = requests.get(image, timeout=timeout)
            response.raise_for_status()
            image = PIL.Image.open(io.BytesIO(response.content))
        elif os.path.isfile(image):
            image = PIL.Image.open(image)
        else:
            raise ValueError(
                f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
            )
    elif not isinstance(image, PIL.Image.Image):
        raise ValueError(
            "Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image."
        )

    image = PIL.ImageOps.exif_transpose(image)
    image = image.convert("RGB")
    return image


def aspect_ratio_preserving_resize_and_crop(image, target_width, target_height):
    """
    Fit `image` inside a `target_width` x `target_height` white canvas.

    The image is shrunk (never enlarged) so that it fits inside the target
    box while keeping its aspect ratio, then pasted centred on a white RGB
    canvas of exactly the target size. Despite the function name, nothing is
    cropped: the remaining area is padding.

    Args:
        image (`PIL.Image.Image`): The image to fit.
        target_width (`int`): Width of the returned canvas in pixels.
        target_height (`int`): Height of the returned canvas in pixels.

    Returns:
        `PIL.Image.Image`: An RGB image of size `(target_width, target_height)`.
    """
    width, height = image.size
    aspect = width / height

    too_wide = width > target_width
    too_tall = height > target_height
    # When both sides overflow, the side with the larger overflow ratio
    # decides the scale so that the other side fits as well.
    width_dominates = (width / target_width) > (height / target_height)

    if too_wide and (not too_tall or width_dominates):
        new_width = target_width
        new_height = int(new_width / aspect)
    elif too_tall:
        new_height = target_height
        new_width = int(new_height * aspect)
    else:
        new_width, new_height = width, height

    # int() truncation can yield 0 for extreme aspect ratios, which
    # Image.resize rejects; keep at least one pixel per side.
    new_width = max(1, new_width)
    new_height = max(1, new_height)

    resized_image = image.resize((new_width, new_height), Image.LANCZOS)

    padded_image = Image.new("RGB", (target_width, target_height), (255, 255, 255))
    offset_x = (target_width - new_width) // 2
    offset_y = (target_height - new_height) // 2
    padded_image.paste(resized_image, (offset_x, offset_y))
    return padded_image


class Image2Text:
    """
    Wrapper around a `VisionEncoderDecoderModel` and its `DonutProcessor`
    that turns images into text.

    Args:
        model_path: Identifier or path passed to the `from_pretrained` calls.
        hf_token: Token forwarded as `token=` to the `from_pretrained` calls.
        device: Device the model and the input tensors are moved to.
        max_length (`int`, defaults to 1024): `max_length` used for generation.
    """

    def __init__(self, model_path, hf_token, device, max_length=1024):
        self.device = device
        self.hf_token = hf_token
        self.model_path = model_path
        self.max_length = max_length
        self.model, self.processor = self.load_model(self.model_path)
        # A single decoder start token; repeated per batch item in `generate`.
        self.decoder_input_ids = torch.tensor([[self.model.config.decoder_start_token_id]]).to(self.device)

    def load_model(self, model_path):
        """
        Load the config, processor and model from `model_path`.

        Returns:
            A `(model, processor)` tuple; the model is moved to `self.device`
            and put in eval mode.
        """
        config = VisionEncoderDecoderConfig.from_pretrained(model_path, token=self.hf_token)
        processor = DonutProcessor.from_pretrained(model_path, token=self.hf_token)
        model = VisionEncoderDecoderModel.from_pretrained(model_path, config=config, token=self.hf_token).to(self.device)
        model.eval()
        return model, processor

    def load_img(self, inputs, width=480, height=480):
        """
        Load and preprocess images into a batch of pixel values.

        Args:
            inputs: An iterable of URLs, local paths or PIL images. A single
                string or PIL image is also accepted and treated as a batch
                of one.
            width (`int`, defaults to 480): Target canvas width.
            height (`int`, defaults to 480): Target canvas height.

        Returns:
            The processor's `pixel_values` tensor, moved to `self.device`.
        """
        # A bare string would otherwise be iterated character by character.
        if isinstance(inputs, (str, PIL.Image.Image)):
            inputs = [inputs]

        images = [
            aspect_ratio_preserving_resize_and_crop(load_image(input_), target_width=width, target_height=height)
            for input_ in inputs
        ]
        # The padded canvases are already RGB, so no further conversion is needed.
        imgs = self.processor(images, return_tensors="pt", size=(width, height)).pixel_values
        pixel_values = imgs.to(self.device)
        return pixel_values

    def generate(self, pixel_values, num_beams):
        """
        Run generation on a batch of pixel values.

        Args:
            pixel_values: Batch tensor as returned by `load_img`; its first
                dimension is used as the batch size.
            num_beams (`int`): Number of beams passed to `model.generate`.

        Returns:
            The object returned by `model.generate` with
            `return_dict_in_generate=True`.
        """
        batch_size = pixel_values.shape[0]
        outputs = self.model.generate(
            pixel_values,
            decoder_input_ids=self.decoder_input_ids.repeat(batch_size, 1),
            max_length=self.max_length,
            early_stopping=True,
            pad_token_id=self.processor.tokenizer.pad_token_id,
            eos_token_id=self.processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=num_beams,
            # Forbid the unknown token in the generated output.
            bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )
        return outputs

    def postprocessing(self, outputs):
        """
        Decode generated sequences into plain text.

        EOS and pad tokens are stripped, the first `<...>` tag is removed
        (presumably the task start token -- confirm against the model), and
        the remainder is parsed with `processor.token2json`. The text is read
        from the `content` key, falling back to `text_sequence`, and every
        `[newline]` marker is replaced with a line break.

        Args:
            outputs: Generation output exposing a `sequences` attribute.

        Returns:
            `List[str]`: One decoded text per sequence.
        """
        tokenizer = self.processor.tokenizer
        contents: List[str] = []
        for seq in self.processor.batch_decode(outputs.sequences):
            seq = seq.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")
            seq = _FIRST_TAG_RE.sub("", seq, count=1).strip()
            parsed = self.processor.token2json(seq)
            try:
                content = parsed['content']
            except (KeyError, TypeError):
                # No `content` field was parsed; use the raw text sequence.
                content = parsed['text_sequence']
            contents.append(content.replace('[newline]', '\n'))
        return contents

    def get_text(self, img_path, num_beams=4):
        """
        Run the full pipeline: load images, generate, and decode.

        Args:
            img_path: Forwarded to `load_img` (an iterable of URLs, paths or
                PIL images, or a single one of those).
            num_beams (`int`, defaults to 4): Number of beams for generation.

        Returns:
            `List[str]`: One decoded text per input image.
        """
        pixel_values = self.load_img(img_path)
        outputs = self.generate(pixel_values, num_beams)
        contents = self.postprocessing(outputs)
        return contents