jacobmp committed on
Commit
adee107
·
verified ·
1 Parent(s): 98361de

Do recognition and decoding in batch to speed up processing

Browse files
Files changed (1) hide show
  1. app.py +15 -8
app.py CHANGED
@@ -15,9 +15,10 @@ def process(path, progress = gr.Progress()):
15
  # Load the model and processor
16
  processor = TrOCRProcessor.from_pretrained(OCR_MODEL_PATH)
17
  model = VisionEncoderDecoderModel.from_pretrained(OCR_MODEL_PATH)
18
-
19
  # Open an image of handwritten text
20
  image = Image.open(path).convert("RGB")
 
21
  progress(0, desc="Extracting Text Lines")
22
  try:
23
  # Load the trained line detection model
@@ -25,23 +26,29 @@ def process(path, progress = gr.Progress()):
25
  line_model = YOLO(cached_model_path)
26
  except Exception as e:
27
  print('Failed to load the line detection model: %s' % e)
28
-
29
  results = line_model.predict(source = image)[0]
30
- full_text = ""
31
  boxes = results.boxes.xyxy
32
  indices = boxes[:,1].sort().indices
33
  boxes = boxes[indices]
34
- for box in progress.tqdm(boxes, desc="Text Recognition"):
 
35
  #box = box + torch.tensor([-10,0, 10, 0])
36
  box = [tensor.item() for tensor in box]
37
  lineImg = image.crop(tuple(list(box)))
38
 
39
- # Preprocess and predict
40
  pixel_values = processor(lineImg, return_tensors="pt").pixel_values
41
- generated_ids = model.generate(pixel_values)
42
- generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
43
- full_text += generated_text
 
 
 
 
 
44
 
 
45
  fix_spelling = pipeline("text2text-generation",model=CORRECTOR_PATH)
46
  fixed_text = fix_spelling(full_text, max_new_tokens=len(full_text)+100)
47
  fixed_text = fixed_text[0]['generated_text']
 
15
  # Load the model and processor
16
  processor = TrOCRProcessor.from_pretrained(OCR_MODEL_PATH)
17
  model = VisionEncoderDecoderModel.from_pretrained(OCR_MODEL_PATH)
18
+
19
  # Open an image of handwritten text
20
  image = Image.open(path).convert("RGB")
21
+
22
  progress(0, desc="Extracting Text Lines")
23
  try:
24
  # Load the trained line detection model
 
26
  line_model = YOLO(cached_model_path)
27
  except Exception as e:
28
  print('Failed to load the line detection model: %s' % e)
29
+
30
  results = line_model.predict(source = image)[0]
 
31
  boxes = results.boxes.xyxy
32
  indices = boxes[:,1].sort().indices
33
  boxes = boxes[indices]
34
+ batch = []
35
+ for box in progress.tqdm(boxes, desc="Preprocessing"):
36
  #box = box + torch.tensor([-10,0, 10, 0])
37
  box = [tensor.item() for tensor in box]
38
  lineImg = image.crop(tuple(list(box)))
39
 
40
+ # Preprocess
41
  pixel_values = processor(lineImg, return_tensors="pt").pixel_values
42
+ batch.append(pixel_values)
43
+
44
+ #Predict and decode the entire batch
45
+ progress(0, desc="Recognizing..")
46
+ generated_ids = model.generate(torch.cat(batch))
47
+ progress(0, desc="Decoding (token -> str)")
48
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
49
+ full_text = " ".join(generated_text)
50
 
51
+ progress(0, desc="Correction..")
52
  fix_spelling = pipeline("text2text-generation",model=CORRECTOR_PATH)
53
  fixed_text = fix_spelling(full_text, max_new_tokens=len(full_text)+100)
54
  fixed_text = fixed_text[0]['generated_text']