Do recognition and decoding in batch to speed up
Browse files
app.py
CHANGED
@@ -15,9 +15,10 @@ def process(path, progress = gr.Progress()):
|
|
15 |
# Load the model and processor
|
16 |
processor = TrOCRProcessor.from_pretrained(OCR_MODEL_PATH)
|
17 |
model = VisionEncoderDecoderModel.from_pretrained(OCR_MODEL_PATH)
|
18 |
-
|
19 |
# Open an image of handwritten text
|
20 |
image = Image.open(path).convert("RGB")
|
|
|
21 |
progress(0, desc="Extracting Text Lines")
|
22 |
try:
|
23 |
# Load the trained line detection model
|
@@ -25,23 +26,29 @@ def process(path, progress = gr.Progress()):
|
|
25 |
line_model = YOLO(cached_model_path)
|
26 |
except Exception as e:
|
27 |
print('Failed to load the line detection model: %s' % e)
|
28 |
-
|
29 |
results = line_model.predict(source = image)[0]
|
30 |
-
full_text = ""
|
31 |
boxes = results.boxes.xyxy
|
32 |
indices = boxes[:,1].sort().indices
|
33 |
boxes = boxes[indices]
|
34 |
-
|
|
|
35 |
#box = box + torch.tensor([-10,0, 10, 0])
|
36 |
box = [tensor.item() for tensor in box]
|
37 |
lineImg = image.crop(tuple(list(box)))
|
38 |
|
39 |
-
# Preprocess
|
40 |
pixel_values = processor(lineImg, return_tensors="pt").pixel_values
|
41 |
-
|
42 |
-
|
43 |
-
|
|
|
|
|
|
|
|
|
|
|
44 |
|
|
|
45 |
fix_spelling = pipeline("text2text-generation",model=CORRECTOR_PATH)
|
46 |
fixed_text = fix_spelling(full_text, max_new_tokens=len(full_text)+100)
|
47 |
fixed_text = fixed_text[0]['generated_text']
|
|
|
15 |
# Load the model and processor
|
16 |
processor = TrOCRProcessor.from_pretrained(OCR_MODEL_PATH)
|
17 |
model = VisionEncoderDecoderModel.from_pretrained(OCR_MODEL_PATH)
|
18 |
+
|
19 |
# Open an image of handwritten text
|
20 |
image = Image.open(path).convert("RGB")
|
21 |
+
|
22 |
progress(0, desc="Extracting Text Lines")
|
23 |
try:
|
24 |
# Load the trained line detection model
|
|
|
26 |
line_model = YOLO(cached_model_path)
|
27 |
except Exception as e:
|
28 |
print('Failed to load the line detection model: %s' % e)
|
29 |
+
|
30 |
results = line_model.predict(source = image)[0]
|
|
|
31 |
boxes = results.boxes.xyxy
|
32 |
indices = boxes[:,1].sort().indices
|
33 |
boxes = boxes[indices]
|
34 |
+
batch = []
|
35 |
+
for box in progress.tqdm(boxes, desc="Preprocessing"):
|
36 |
#box = box + torch.tensor([-10,0, 10, 0])
|
37 |
box = [tensor.item() for tensor in box]
|
38 |
lineImg = image.crop(tuple(list(box)))
|
39 |
|
40 |
+
# Preprocess
|
41 |
pixel_values = processor(lineImg, return_tensors="pt").pixel_values
|
42 |
+
batch.append(pixel_values)
|
43 |
+
|
44 |
+
#Predict and decode the entire batch
|
45 |
+
progress(0, desc="Recognizing..")
|
46 |
+
generated_ids = model.generate(torch.cat(batch))
|
47 |
+
progress(0, desc="Decoding (token -> str)")
|
48 |
+
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
49 |
+
full_text = " ".join(generated_text)
|
50 |
|
51 |
+
progress(0, desc="Correction..")
|
52 |
fix_spelling = pipeline("text2text-generation",model=CORRECTOR_PATH)
|
53 |
fixed_text = fix_spelling(full_text, max_new_tokens=len(full_text)+100)
|
54 |
fixed_text = fixed_text[0]['generated_text']
|