import torch
import gradio as gr
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, pipeline
from ultralytics import YOLO

LINE_MODEL_PATH = "Kansallisarkisto/multicentury-textline-detection"
OCR_MODEL_PATH = "microsoft/trocr-large-handwritten"
CORRECTOR_PATH = "oliverguhr/spelling-correction-english-base"


def process(path, progress=gr.Progress(), device="cpu"):
    progress(0, desc="Starting")

    # Load the TrOCR processor and model for handwritten text recognition.
    processor = TrOCRProcessor.from_pretrained(OCR_MODEL_PATH)
    model = VisionEncoderDecoderModel.from_pretrained(OCR_MODEL_PATH)
    model.to(device)

    # Open the uploaded image of handwritten text.
    image = Image.open(path).convert("RGB")

    progress(0, desc="Extracting text lines")
    try:
        # Download and load the trained text-line detection model.
        cached_model_path = hf_hub_download(repo_id=LINE_MODEL_PATH, filename="lines_20240827.pt")
        line_model = YOLO(cached_model_path)
    except Exception as e:
        raise gr.Error(f"Failed to load the line detection model: {e}")

    # Detect line bounding boxes and sort them top-to-bottom by their y1 coordinate.
    results = line_model.predict(source=image)[0]
    boxes = results.boxes.xyxy
    boxes = boxes[boxes[:, 1].sort().indices]

    # Crop each detected line and convert it to TrOCR pixel values.
    batch = []
    for box in progress.tqdm(boxes, desc="Preprocessing"):
        # Optionally widen the box horizontally:
        # box = box + torch.tensor([-10, 0, 10, 0])
        line_img = image.crop(tuple(coord.item() for coord in box))
        pixel_values = processor(line_img, return_tensors="pt").pixel_values
        batch.append(pixel_values)

    # Recognize all lines in a single batched generate() call.
    progress(0, desc="Recognizing...")
    batch = torch.cat(batch).to(device)
    generated_ids = model.generate(batch)

    progress(0, desc="Decoding (tokens -> text)")
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
    full_text = " ".join(generated_text)

    # Post-process the raw transcription with a spelling-correction model.
    # The token budget is a rough character-based estimate of the output length.
    progress(0, desc="Correcting spelling")
    fix_spelling = pipeline("text2text-generation", model=CORRECTOR_PATH)
    fixed_text = fix_spelling(full_text, max_new_tokens=len(full_text) + 100)
    return fixed_text[0]["generated_text"]


if __name__ == "__main__":
    demo = gr.Interface(fn=process, inputs=gr.Image(type="filepath"), outputs="text")
    # Queueing is required for the progress bar to update in the UI.
    demo.queue().launch()