import torch
import gradio as gr

from transformers import TrOCRProcessor, VisionEncoderDecoderModel, pipeline
from huggingface_hub import hf_hub_download
from ultralytics import YOLO
from PIL import Image


def process(path, progress=gr.Progress(), device="cpu"):
    progress(0, desc="Starting")

    LINE_MODEL_PATH = "Kansallisarkisto/multicentury-textline-detection"
    OCR_MODEL_PATH = "microsoft/trocr-large-handwritten"
    CORRECTOR_PATH = "oliverguhr/spelling-correction-english-base"
    processor = TrOCRProcessor.from_pretrained(OCR_MODEL_PATH)
    model = VisionEncoderDecoderModel.from_pretrained(OCR_MODEL_PATH)
    model.to(device)

    image = Image.open(path).convert("RGB")

    progress(0, desc="Extracting Text Lines")
    try:
        cached_model_path = hf_hub_download(repo_id=LINE_MODEL_PATH, filename="lines_20240827.pt")
        line_model = YOLO(cached_model_path)
    except Exception as e:
        # Re-raise so we do not continue with line_model undefined.
        print(f"Failed to load the line detection model: {e}")
        raise
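
    # Detect text lines and sort the boxes top-to-bottom by their upper y coordinate.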
    results = line_model.predict(source=image)[0]
    boxes = results.boxes.xyxy
    indices = boxes[:, 1].sort().indices
    boxes = boxes[indices]

    batch = []
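    # Crop each detected line and convert it to TrOCR pixel values.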
    for box in progress.tqdm(boxes, desc="Preprocessing"):
        box = [coord.item() for coord in box]
        line_img = image.crop(tuple(box))
        pixel_values = processor(line_img, return_tensors="pt").pixel_values
        batch.append(pixel_values)
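
    # Run all line crops through TrOCR in a single batch.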
    progress(0, desc="Recognizing")
    batch = torch.cat(batch).to(device)
    print("batch.shape", batch.shape)
    generated_ids = model.generate(batch)

    progress(0, desc="Decoding (token -> str)")
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
    print(generated_text)
    full_text = " ".join(generated_text)
    print(full_text)
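
    # Clean up the raw transcription with a seq2seq spelling correction model.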
    progress(0, desc="Correcting")
    fix_spelling = pipeline("text2text-generation", model=CORRECTOR_PATH)
    fixed_text = fix_spelling(full_text, max_new_tokens=len(full_text) + 100)
    fixed_text = fixed_text[0]["generated_text"]

    return fixed_text


if __name__ == "__main__":
    demo = gr.Interface(fn=process, inputs=gr.Image(type="filepath"), outputs="text")
    demo.launch()