jacobmp committed on
Commit
98361de
·
verified ·
1 Parent(s): 208e4e2

Pass OCR output through an LLM for spelling and grammar correction

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -2,16 +2,15 @@ import gradio as gr
2
 
3
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
4
  from huggingface_hub import hf_hub_download
5
- from transformers import AutoModel
6
  from ultralytics import YOLO
7
  from PIL import Image
8
- import torch
9
 
10
  def process(path, progress = gr.Progress()):
11
  progress(0, desc="Starting")
12
  LINE_MODEL_PATH = "Kansallisarkisto/multicentury-textline-detection"
13
- #OCR_MODEL_PATH = "Kansallisarkisto/multicentury-htr-model"
14
  OCR_MODEL_PATH = "microsoft/trocr-large-handwritten"
 
15
 
16
  # Load the model and processor
17
  processor = TrOCRProcessor.from_pretrained(OCR_MODEL_PATH)
@@ -43,7 +42,11 @@ def process(path, progress = gr.Progress()):
43
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
44
  full_text += generated_text
45
 
46
- return full_text
 
 
 
 
47
 
48
  if __name__ == "__main__":
49
  demo = gr.Interface(fn=process, inputs=gr.Image(type="filepath"), outputs="text")
 
2
 
3
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
4
  from huggingface_hub import hf_hub_download
5
+ from transformers import pipeline
6
  from ultralytics import YOLO
7
  from PIL import Image
 
8
 
9
  def process(path, progress = gr.Progress()):
10
  progress(0, desc="Starting")
11
  LINE_MODEL_PATH = "Kansallisarkisto/multicentury-textline-detection"
 
12
  OCR_MODEL_PATH = "microsoft/trocr-large-handwritten"
13
+ CORRECTOR_PATH = "oliverguhr/spelling-correction-english-base"
14
 
15
  # Load the model and processor
16
  processor = TrOCRProcessor.from_pretrained(OCR_MODEL_PATH)
 
42
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
43
  full_text += generated_text
44
 
45
+ fix_spelling = pipeline("text2text-generation",model=CORRECTOR_PATH)
46
+ fixed_text = fix_spelling(full_text, max_new_tokens=len(full_text)+100)
47
+ fixed_text = fixed_text[0]['generated_text']
48
+
49
+ return fixed_text
50
 
51
  if __name__ == "__main__":
52
  demo = gr.Interface(fn=process, inputs=gr.Image(type="filepath"), outputs="text")