# Last commit: eaa916e by jacobmp - "remove redundant torch.cat()" (verified).
import functools
import logging

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from transformers import pipeline
from ultralytics import YOLO
logger = logging.getLogger(__name__)

LINE_MODEL_PATH = "Kansallisarkisto/multicentury-textline-detection"
LINE_MODEL_FILENAME = "lines_20240827.pt"
OCR_MODEL_PATH = "microsoft/trocr-large-handwritten"
CORRECTOR_PATH = "oliverguhr/spelling-correction-english-base"


@functools.lru_cache(maxsize=4)
def _load_ocr(device):
    """Load the TrOCR processor and model once per device and reuse them across calls."""
    processor = TrOCRProcessor.from_pretrained(OCR_MODEL_PATH)
    model = VisionEncoderDecoderModel.from_pretrained(OCR_MODEL_PATH)
    model.to(device)
    return processor, model


@functools.lru_cache(maxsize=1)
def _load_line_model():
    """Fetch the YOLO text-line detection weights from the Hub and load them once.

    Raises:
        gradio.Error: If the weights cannot be downloaded or loaded. Failures are
            not cached by ``lru_cache``, so a later call retries.
    """
    try:
        cached_model_path = hf_hub_download(repo_id=LINE_MODEL_PATH, filename=LINE_MODEL_FILENAME)
        return YOLO(cached_model_path)
    except Exception as e:
        # Fail loudly: continuing without a detector would only crash later
        # with a confusing NameError/AttributeError.
        logger.exception("Failed to load the line detection model")
        raise gr.Error(f"Failed to load the line detection model: {e}") from e


@functools.lru_cache(maxsize=1)
def _load_corrector():
    """Build the text2text spelling-correction pipeline once and reuse it."""
    return pipeline("text2text-generation", model=CORRECTOR_PATH)


def process(path, progress=gr.Progress(), device='cpu'):
    """Transcribe a handwritten page image and return spelling-corrected text.

    Steps: detect text lines with a YOLO model, sort them top to bottom, recognize
    each line crop with TrOCR, join the lines with single spaces, then run a
    text2text spelling-correction model over the joined text.

    Args:
        path: Filesystem path of the input image (the UI passes it via
            ``gr.Image(type="filepath")``).
        progress: Gradio progress tracker, injected by Gradio when called from the UI.
        device: Torch device the OCR model and its input batch are moved to.

    Returns:
        The corrected transcription, or ``""`` when no text lines are detected.

    Raises:
        gradio.Error: If the line-detection model cannot be loaded.
    """
    progress(0, desc="Starting")
    processor, model = _load_ocr(device)

    # Close the underlying file handle once the pixels are decoded.
    with Image.open(path) as img:
        image = img.convert("RGB")

    progress(0.1, desc="Extracting Text Lines")
    line_model = _load_line_model()
    results = line_model.predict(source=image)[0]
    boxes = results.boxes.xyxy
    if len(boxes) == 0:
        # Nothing to recognize; batching an empty list would raise downstream.
        logger.info("No text lines detected in %s", path)
        return ""
    # Sort boxes by their top edge (column 1 of xyxy) so lines read top to bottom.
    boxes = boxes[boxes[:, 1].argsort()]

    line_images = [
        image.crop(tuple(box.tolist()))
        for box in progress.tqdm(boxes, desc="Preprocessing")
    ]
    # One processor call preprocesses every crop and returns a stacked batch.
    batch = processor(images=line_images, return_tensors="pt").pixel_values.to(device)

    progress(0.5, desc="Recognizing..")
    logger.debug("batch.shape %s", tuple(batch.shape))
    generated_ids = model.generate(batch)

    progress(0.7, desc="Decoding (token -> str)")
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
    logger.debug("line texts: %s", generated_text)
    full_text = " ".join(generated_text)
    logger.debug("full text: %s", full_text)

    progress(0.8, desc="Correction..")
    fix_spelling = _load_corrector()
    # NOTE(review): max_new_tokens is a token budget but is derived from the
    # character count plus slack; presumably a generous upper bound - confirm.
    fixed_text = fix_spelling(full_text, max_new_tokens=len(full_text) + 100)
    return fixed_text[0]['generated_text']
if __name__ == "__main__":
    # Web UI: the uploaded image is handed to `process` as a file path and
    # the returned transcription is shown in a text box.
    demo = gr.Interface(
        fn=process,
        inputs=gr.Image(type="filepath"),
        outputs="text",
    )
    demo.launch()