File size: 2,466 Bytes
939af4a e76d06f 81788df 98361de 81788df 4130caf 208e4e2 81788df 98361de 81788df 4130caf adee107 81788df f4f66f6 adee107 208e4e2 81788df adee107 81788df adee107 81788df adee107 81788df adee107 4130caf eaa916e adee107 4130caf c9bae28 adee107 c9bae28 81788df adee107 98361de e76d06f 9ff8035 93bb55b f4f66f6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import torch
import gradio as gr
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from huggingface_hub import hf_hub_download
from transformers import pipeline
from ultralytics import YOLO
from PIL import Image
def process(path, progress=gr.Progress(), device='cpu'):
    """Transcribe a handwritten page image into spell-corrected text.

    Steps, in order:
      1. Load the TrOCR processor and model and move the model to ``device``.
      2. Detect text-line boxes with a YOLO model downloaded from the
         Hugging Face Hub.
      3. Sort the boxes by the second coordinate of their ``xyxy`` rows
         (the top edge) so lines are read top to bottom.
      4. Crop every line, preprocess it, and run TrOCR on the whole batch.
      5. Join the decoded lines with spaces and pass the result through a
         text2text spelling-correction pipeline.

    Args:
        path: Path of the image file to transcribe.
        progress: Gradio progress tracker used to report each stage. The
            ``gr.Progress()`` default is how Gradio recognises that the
            function wants progress reporting.
        device: Torch device used for the OCR model and the spelling
            corrector (default ``'cpu'``).

    Returns:
        The corrected transcription as a string, or ``""`` when the line
        detector finds no text lines in the image.

    Raises:
        gr.Error: If the line detection model cannot be downloaded or loaded.
    """
    progress(0, desc="Starting")
    LINE_MODEL_PATH = "Kansallisarkisto/multicentury-textline-detection"
    OCR_MODEL_PATH = "microsoft/trocr-large-handwritten"
    CORRECTOR_PATH = "oliverguhr/spelling-correction-english-base"

    # Load the OCR model and its processor.
    processor = TrOCRProcessor.from_pretrained(OCR_MODEL_PATH)
    model = VisionEncoderDecoderModel.from_pretrained(OCR_MODEL_PATH)
    model.to(device)

    # Open the image of handwritten text; convert() returns an independent
    # copy, so the underlying file can be closed right away.
    with Image.open(path) as source_image:
        image = source_image.convert("RGB")

    progress(0.1, desc="Extracting Text Lines")
    try:
        # Load the trained line detection model.
        cached_model_path = hf_hub_download(
            repo_id=LINE_MODEL_PATH, filename="lines_20240827.pt"
        )
        line_model = YOLO(cached_model_path)
    except Exception as e:
        print('Failed to load the line detection model: %s' % e)
        # Without the detector nothing below can run; surface the real
        # cause instead of failing later with a NameError on line_model.
        raise gr.Error(
            'Failed to load the line detection model: %s' % e
        ) from e

    results = line_model.predict(source=image)[0]
    boxes = results.boxes.xyxy
    if boxes.shape[0] == 0:
        # No text lines detected: torch.cat() on an empty list would raise,
        # and there is nothing to recognise anyway.
        return ""

    # Order the lines from top to bottom by their top edge.
    boxes = boxes[boxes[:, 1].sort().indices]

    batch = []
    for box in progress.tqdm(boxes, desc="Preprocessing"):
        line_img = image.crop(tuple(box.tolist()))
        # Preprocess each crop into the pixel tensor TrOCR expects.
        batch.append(processor(line_img, return_tensors="pt").pixel_values)

    # Predict and decode the entire batch.
    progress(0.5, desc="Recognizing..")
    batch = torch.cat(batch).to(device)
    print("batch.shape", batch.shape)
    generated_ids = model.generate(batch)

    progress(0.8, desc="Decoding (token -> str)")
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
    print(generated_text)
    full_text = " ".join(generated_text)
    print(full_text)

    progress(0.9, desc="Correction..")
    # Run the corrector on the same device as the OCR model.
    fix_spelling = pipeline("text2text-generation", model=CORRECTOR_PATH, device=device)
    # The budget is the character count plus slack; presumably generous
    # enough that the corrected text is never cut short.
    fixed_text = fix_spelling(full_text, max_new_tokens=len(full_text) + 100)
    return fixed_text[0]['generated_text']
if __name__ == "__main__":
    # Build a Gradio UI that passes the uploaded image to process() as a
    # file path and shows the returned text, then start the app.
    demo = gr.Interface(fn=process, inputs=gr.Image(type="filepath"), outputs="text")
    demo.launch()