File size: 2,466 Bytes
939af4a
e76d06f
 
81788df
 
98361de
81788df
 
 
4130caf
208e4e2
81788df
 
98361de
81788df
 
 
 
4130caf
adee107
81788df
f4f66f6
adee107
208e4e2
81788df
 
 
 
 
 
adee107
81788df
 
 
 
adee107
 
81788df
 
 
 
adee107
81788df
adee107
 
 
 
4130caf
 
eaa916e
adee107
4130caf
 
c9bae28
adee107
c9bae28
81788df
adee107
98361de
 
 
 
 
e76d06f
9ff8035
93bb55b
f4f66f6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import functools

import torch
import gradio as gr

from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from huggingface_hub import hf_hub_download
from transformers import pipeline
from ultralytics import YOLO
from PIL import Image

@functools.lru_cache(maxsize=4)
def _load_ocr(model_path, device):
    """Load the TrOCR processor and model once per (model_path, device)."""
    processor = TrOCRProcessor.from_pretrained(model_path)
    model = VisionEncoderDecoderModel.from_pretrained(model_path)
    model.to(device)
    return processor, model


@functools.lru_cache(maxsize=4)
def _load_line_model(repo_id, filename):
    """Download (through the Hugging Face cache) and load the YOLO line detector once."""
    try:
        cached_model_path = hf_hub_download(repo_id=repo_id, filename=filename)
        return YOLO(cached_model_path)
    except Exception as e:
        # Surface the real cause instead of continuing without a model, which
        # would only fail later with an unrelated NameError.
        # lru_cache does not cache exceptions, so a later call retries the load.
        raise RuntimeError(f"Failed to load the line detection model: {e}") from e


@functools.lru_cache(maxsize=4)
def _load_corrector(model_path, device):
    """Build the text2text spelling-correction pipeline once per (model_path, device)."""
    return pipeline("text2text-generation", model=model_path, device=device)


def process(path, progress=gr.Progress(), device='cpu'):
    """Transcribe handwritten text in an image and spell-correct the result.

    Text lines are detected with a YOLO model and ordered by the y-coordinate
    of each box's top edge. Each line crop is recognized with TrOCR in a single
    batch. The recognized lines are joined with spaces and passed through a
    text2text spelling-correction model.

    Models are loaded lazily on first use and cached for subsequent calls.

    Args:
        path: Filesystem path of the input image (Gradio ``type="filepath"``).
        progress: Gradio progress tracker.
        device: Torch device for the OCR model and the spelling corrector.

    Returns:
        The corrected transcription, or ``""`` if no text lines were detected.

    Raises:
        gr.Error: If no image was supplied.
        RuntimeError: If the line detection model cannot be loaded.
    """
    if path is None:
        # Gradio passes None when the user submits without uploading an image.
        raise gr.Error("Please upload an image.")

    progress(0, desc="Starting")
    LINE_MODEL_PATH = "Kansallisarkisto/multicentury-textline-detection"
    LINE_MODEL_FILE = "lines_20240827.pt"
    OCR_MODEL_PATH = "microsoft/trocr-large-handwritten"
    CORRECTOR_PATH = "oliverguhr/spelling-correction-english-base"

    # Load (or reuse cached) recognition model and processor.
    processor, model = _load_ocr(OCR_MODEL_PATH, device)

    # Open the image of handwritten text; the context manager closes the file
    # handle once the decoded RGB copy exists.
    with Image.open(path) as img:
        image = img.convert("RGB")

    progress(0, desc="Extracting Text Lines")
    line_model = _load_line_model(LINE_MODEL_PATH, LINE_MODEL_FILE)
    results = line_model.predict(source=image)[0]
    boxes = results.boxes.xyxy
    if len(boxes) == 0:
        # No lines means no text; torch.cat below would fail on an empty list.
        return ""

    # Order boxes by y1 (column 1 of xyxy) so lines are read top to bottom,
    # then convert to plain Python floats for PIL.
    boxes = boxes[boxes[:, 1].sort().indices].tolist()

    batch = []
    for box in progress.tqdm(boxes, desc="Preprocessing"):
        line_img = image.crop(tuple(box))
        batch.append(processor(line_img, return_tensors="pt").pixel_values)

    # Predict and decode the entire batch.
    progress(0, desc="Recognizing..")
    batch = torch.cat(batch).to(device)
    print("batch.shape", batch.shape)
    generated_ids = model.generate(batch)
    progress(0, desc="Decoding (token -> str)")
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)

    print(generated_text)
    full_text = " ".join(generated_text)
    print(full_text)

    progress(0, desc="Correction..")
    fix_spelling = _load_corrector(CORRECTOR_PATH, device)
    # Allow the corrector some room to grow the text beyond its input length.
    fixed_text = fix_spelling(full_text, max_new_tokens=len(full_text) + 100)
    return fixed_text[0]['generated_text']

if __name__ == "__main__":
    # Script entry point: serve `process` as a web UI that takes an uploaded
    # image (passed to the function as a file path) and shows text output.
    image_input = gr.Image(type="filepath")
    app = gr.Interface(
        fn=process,
        inputs=image_input,
        outputs="text",
    )
    app.launch()