Raghu commited on
Commit
acf3ed2
·
1 Parent(s): 23980e2

Add TrOCR and PaddleOCR to OCR ensemble for improved accuracy

Browse files
Files changed (2) hide show
  1. app.py +255 -116
  2. requirements.txt +1 -0
app.py CHANGED
@@ -369,16 +369,51 @@ class EnsembleDocumentClassifier:
369
  # ============================================================================
370
 
371
  class ReceiptOCR:
372
- """Enhanced OCR with EasyOCR + Tesseract fallback, better preprocessing, and retry logic."""
373
 
374
  def __init__(self):
375
  self.reader = None
 
 
376
  self.use_tesseract = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
377
  try:
378
  import pytesseract
379
  self.use_tesseract = True
380
  except ImportError:
381
- pass
382
 
383
  def load(self):
384
  if self.reader is None:
@@ -410,7 +445,7 @@ class ReceiptOCR:
410
  # Denoise
411
  denoised = cv2.fastNlMeansDenoising(enhanced, h=10)
412
 
413
- # Convert back to RGB for EasyOCR
414
  return cv2.cvtColor(denoised, cv2.COLOR_GRAY2RGB)
415
 
416
  elif method == 'sharpen':
@@ -424,20 +459,106 @@ class ReceiptOCR:
424
  sharpened = cv2.cvtColor(sharpened, cv2.COLOR_GRAY2RGB)
425
  return sharpened
426
 
427
- elif method == 'binarize':
428
- # Adaptive thresholding
429
- if len(img_array.shape) == 3:
430
- gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
431
- else:
432
- gray = img_array
433
- binary = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
434
- cv2.THRESH_BINARY, 11, 2)
435
- return cv2.cvtColor(binary, cv2.COLOR_GRAY2RGB)
436
-
437
  return img_array
438
 
439
- def _extract_with_tesseract(self, image):
440
- """Fallback OCR using Tesseract."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
441
  if not self.use_tesseract:
442
  return []
443
 
@@ -449,7 +570,6 @@ class ReceiptOCR:
449
  else:
450
  pil_image = Image.fromarray(image).convert('RGB')
451
 
452
- # Get detailed output with bounding boxes
453
  data = pytesseract.image_to_data(pil_image, output_type=pytesseract.Output.DICT)
454
 
455
  results = []
@@ -473,44 +593,6 @@ class ReceiptOCR:
473
  print(f"Tesseract OCR error: {e}")
474
  return []
475
 
476
- def _merge_ocr_results(self, easyocr_results, tesseract_results):
477
- """Merge results from multiple OCR engines, preferring higher confidence."""
478
- if not tesseract_results:
479
- return easyocr_results
480
-
481
- # Create a map of EasyOCR results by approximate position
482
- merged = []
483
- used_tesseract = set()
484
-
485
- for easy_result in easyocr_results:
486
- best_match = None
487
- best_iou = 0
488
-
489
- # Find best matching Tesseract result
490
- for i, tess_result in enumerate(tesseract_results):
491
- if i in used_tesseract:
492
- continue
493
-
494
- # Simple IoU calculation
495
- iou = self._compute_iou(easy_result['bbox'], tess_result['bbox'])
496
- if iou > best_iou and iou > 0.3: # 30% overlap threshold
497
- best_iou = iou
498
- best_match = (i, tess_result)
499
-
500
- if best_match and best_match[1]['confidence'] > easy_result['confidence']:
501
- # Use Tesseract result if it's more confident
502
- merged.append(best_match[1])
503
- used_tesseract.add(best_match[0])
504
- else:
505
- merged.append(easy_result)
506
-
507
- # Add unused Tesseract results
508
- for i, tess_result in enumerate(tesseract_results):
509
- if i not in used_tesseract:
510
- merged.append(tess_result)
511
-
512
- return merged
513
-
514
  def _compute_iou(self, box1, box2):
515
  """Compute Intersection over Union for bounding boxes."""
516
  x1_1, y1_1, x2_1, y2_1 = box1
@@ -528,78 +610,135 @@ class ReceiptOCR:
528
 
529
  return inter_area / union_area if union_area > 0 else 0
530
 
531
- def extract_with_positions(self, image, min_confidence=0.3, use_fallback=True):
532
- """Extract text with positions using EasyOCR + optional Tesseract fallback."""
533
- if self.reader is None:
534
- self.load()
 
 
 
 
 
 
535
 
536
- original_image = image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
537
  if isinstance(image, Image.Image):
538
- image = np.array(image)
 
 
 
 
539
 
540
- # Try EasyOCR first
541
  try:
542
- results = self.reader.readtext(image)
 
 
543
  except Exception as e:
544
  print(f"EasyOCR error: {e}")
545
- results = []
546
 
547
- extracted = []
548
- for bbox, text, conf in results:
549
- if conf >= min_confidence:
550
- x_coords = [p[0] for p in bbox]
551
- y_coords = [p[1] for p in bbox]
552
- extracted.append({
553
- 'text': text.strip(),
554
- 'confidence': conf,
555
- 'bbox': [min(x_coords), min(y_coords), max(x_coords), max(y_coords)],
556
- 'engine': 'easyocr'
557
- })
558
-
559
- # Check if we need fallback (low confidence or few results)
560
- avg_confidence = np.mean([r['confidence'] for r in extracted]) if extracted else 0
561
- needs_fallback = use_fallback and (len(extracted) < 3 or avg_confidence < 0.5)
562
-
563
- if needs_fallback and self.use_tesseract:
564
- # Try preprocessing + Tesseract
565
- preprocessed = self._preprocess_image(original_image, method='enhance')
566
- tesseract_results = self._extract_with_tesseract(preprocessed)
567
-
568
- if tesseract_results:
569
- # Merge results
570
- extracted = self._merge_ocr_results(extracted, tesseract_results)
571
-
572
- # If still poor results, try with preprocessing
573
- if len(extracted) < 3 or avg_confidence < 0.4:
574
- for method in ['enhance', 'sharpen']:
575
- try:
576
- preprocessed = self._preprocess_image(original_image, method=method)
577
- retry_results = self.reader.readtext(preprocessed)
578
-
579
- retry_extracted = []
580
- for bbox, text, conf in retry_results:
581
- if conf >= min_confidence:
582
- x_coords = [p[0] for p in bbox]
583
- y_coords = [p[1] for p in bbox]
584
- retry_extracted.append({
585
- 'text': text.strip(),
586
- 'confidence': conf,
587
- 'bbox': [min(x_coords), min(y_coords), max(x_coords), max(y_coords)],
588
- 'engine': 'easyocr'
589
- })
590
-
591
- # Use retry if it's better
592
- retry_avg = np.mean([r['confidence'] for r in retry_extracted]) if retry_extracted else 0
593
- if retry_avg > avg_confidence:
594
- extracted = retry_extracted
595
- break
596
- except Exception as e:
597
- continue
 
 
598
 
599
  # Sort by confidence (highest first)
600
- extracted.sort(key=lambda x: x['confidence'], reverse=True)
601
 
602
- return extracted
603
 
604
  def postprocess_receipt(self, ocr_results):
605
  """Extract structured fields from OCR results with improved patterns."""
 
369
  # ============================================================================
370
 
371
  class ReceiptOCR:
372
+ """Enhanced OCR with EasyOCR + TrOCR + PaddleOCR + Tesseract ensemble."""
373
 
374
  def __init__(self):
375
  self.reader = None
376
+ self.trocr_engine = None
377
+ self.paddleocr_engine = None
378
  self.use_tesseract = False
379
+
380
+ # Engine weights for ensemble
381
+ self.engine_weights = {
382
+ 'trocr': 0.40, # Highest weight - best quality
383
+ 'easyocr': 0.35,
384
+ 'paddleocr': 0.30,
385
+ 'tesseract': 0.20
386
+ }
387
+
388
+ # Try to initialize TrOCR
389
+ try:
390
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
391
+ self.trocr_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
392
+ self.trocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
393
+ self.trocr_model = self.trocr_model.to(DEVICE)
394
+ self.trocr_model.eval()
395
+ self.trocr_available = True
396
+ print("TrOCR initialized")
397
+ except Exception as e:
398
+ self.trocr_available = False
399
+ print(f"TrOCR not available: {e}")
400
+
401
+ # Try to initialize PaddleOCR
402
+ try:
403
+ from paddleocr import PaddleOCR
404
+ self.paddleocr_engine = PaddleOCR(use_angle_cls=True, lang='en', show_log=False)
405
+ self.paddleocr_available = True
406
+ print("PaddleOCR initialized")
407
+ except Exception as e:
408
+ self.paddleocr_available = False
409
+ print(f"PaddleOCR not available: {e}")
410
+
411
+ # Try to initialize Tesseract
412
  try:
413
  import pytesseract
414
  self.use_tesseract = True
415
  except ImportError:
416
+ self.use_tesseract = False
417
 
418
  def load(self):
419
  if self.reader is None:
 
445
  # Denoise
446
  denoised = cv2.fastNlMeansDenoising(enhanced, h=10)
447
 
448
+ # Convert back to RGB for OCR engines
449
  return cv2.cvtColor(denoised, cv2.COLOR_GRAY2RGB)
450
 
451
  elif method == 'sharpen':
 
459
  sharpened = cv2.cvtColor(sharpened, cv2.COLOR_GRAY2RGB)
460
  return sharpened
461
 
 
 
 
 
 
 
 
 
 
 
462
  return img_array
463
 
464
+ def _run_easyocr(self, image):
465
+ """Run EasyOCR."""
466
+ if self.reader is None:
467
+ self.load()
468
+
469
+ results = self.reader.readtext(image)
470
+ extracted = []
471
+ for bbox, text, conf in results:
472
+ x_coords = [p[0] for p in bbox]
473
+ y_coords = [p[1] for p in bbox]
474
+ extracted.append({
475
+ 'text': text.strip(),
476
+ 'confidence': conf,
477
+ 'bbox': [min(x_coords), min(y_coords), max(x_coords), max(y_coords)],
478
+ 'engine': 'easyocr'
479
+ })
480
+ return extracted
481
+
482
+ def _run_trocr(self, image, boxes):
483
+ """Run TrOCR on detected text regions."""
484
+ if not self.trocr_available:
485
+ return []
486
+
487
+ if isinstance(image, np.ndarray):
488
+ pil_image = Image.fromarray(image).convert('RGB')
489
+ else:
490
+ pil_image = image.convert('RGB')
491
+
492
+ results = []
493
+ for box in boxes:
494
+ try:
495
+ if isinstance(box, list) and len(box) >= 4:
496
+ # Convert to [x1, y1, x2, y2]
497
+ if isinstance(box[0], list):
498
+ x1 = int(min(p[0] for p in box))
499
+ y1 = int(min(p[1] for p in box))
500
+ x2 = int(max(p[0] for p in box))
501
+ y2 = int(max(p[1] for p in box))
502
+ else:
503
+ x1, y1, x2, y2 = [int(b) for b in box[:4]]
504
+
505
+ # Crop and recognize
506
+ cropped = pil_image.crop((x1, y1, x2, y2))
507
+
508
+ # TrOCR recognition
509
+ pixel_values = self.trocr_processor(images=cropped, return_tensors="pt").pixel_values.to(DEVICE)
510
+ with torch.no_grad():
511
+ generated_ids = self.trocr_model.generate(
512
+ pixel_values,
513
+ max_length=128,
514
+ num_beams=4,
515
+ early_stopping=True
516
+ )
517
+ text = self.trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
518
+
519
+ if text.strip():
520
+ results.append({
521
+ 'text': text.strip(),
522
+ 'confidence': 0.9, # TrOCR doesn't provide confidence, use high default
523
+ 'bbox': [x1, y1, x2, y2],
524
+ 'engine': 'trocr'
525
+ })
526
+ except Exception as e:
527
+ continue
528
+
529
+ return results
530
+
531
+ def _run_paddleocr(self, image):
532
+ """Run PaddleOCR."""
533
+ if not self.paddleocr_available:
534
+ return []
535
+
536
+ try:
537
+ result = self.paddleocr_engine.ocr(image, cls=True)
538
+
539
+ if result is None or len(result) == 0 or result[0] is None:
540
+ return []
541
+
542
+ extracted = []
543
+ for line in result[0]:
544
+ if line is None:
545
+ continue
546
+ bbox, (text, conf) = line
547
+ x_coords = [p[0] for p in bbox]
548
+ y_coords = [p[1] for p in bbox]
549
+ extracted.append({
550
+ 'text': text.strip(),
551
+ 'confidence': conf,
552
+ 'bbox': [min(x_coords), min(y_coords), max(x_coords), max(y_coords)],
553
+ 'engine': 'paddleocr'
554
+ })
555
+ return extracted
556
+ except Exception as e:
557
+ print(f"PaddleOCR error: {e}")
558
+ return []
559
+
560
+ def _run_tesseract(self, image):
561
+ """Run Tesseract OCR."""
562
  if not self.use_tesseract:
563
  return []
564
 
 
570
  else:
571
  pil_image = Image.fromarray(image).convert('RGB')
572
 
 
573
  data = pytesseract.image_to_data(pil_image, output_type=pytesseract.Output.DICT)
574
 
575
  results = []
 
593
  print(f"Tesseract OCR error: {e}")
594
  return []
595
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
596
  def _compute_iou(self, box1, box2):
597
  """Compute Intersection over Union for bounding boxes."""
598
  x1_1, y1_1, x2_1, y2_1 = box1
 
610
 
611
  return inter_area / union_area if union_area > 0 else 0
612
 
613
+ def _merge_results(self, all_results):
614
+ """Merge results from multiple OCR engines using weighted voting."""
615
+ if not all_results:
616
+ return []
617
+
618
+ # Use the engine with most detections as base
619
+ base_engine = max(all_results.keys(), key=lambda k: len(all_results[k]))
620
+ base_results = all_results[base_engine]
621
+
622
+ merged = []
623
 
624
+ for base_result in base_results:
625
+ base_box = base_result['bbox']
626
+ base_text = base_result['text']
627
+ base_conf = base_result['confidence']
628
+
629
+ # Find matching results from other engines
630
+ matches = [(base_text, base_conf, self.engine_weights.get(base_engine, 0.3))]
631
+
632
+ for engine_name, results in all_results.items():
633
+ if engine_name == base_engine:
634
+ continue
635
+
636
+ for result in results:
637
+ iou = self._compute_iou(base_box, result['bbox'])
638
+ if iou > 0.3: # Same text region
639
+ weight = self.engine_weights.get(engine_name, 0.2)
640
+ matches.append((result['text'], result['confidence'], weight))
641
+
642
+ # Vote on the best text
643
+ if len(matches) == 1:
644
+ final_text = base_text
645
+ final_conf = base_conf
646
+ else:
647
+ # Weighted voting
648
+ text_scores = {}
649
+ for text, conf, weight in matches:
650
+ if text not in text_scores:
651
+ text_scores[text] = 0
652
+ text_scores[text] += conf * weight
653
+
654
+ final_text = max(text_scores.keys(), key=lambda t: text_scores[t])
655
+ total_weight = sum(w for _, _, w in matches)
656
+ final_conf = min(0.99, text_scores[final_text] / total_weight if total_weight > 0 else 0.5)
657
+
658
+ merged.append({
659
+ 'text': final_text,
660
+ 'confidence': final_conf,
661
+ 'bbox': base_box,
662
+ 'engines_used': len(matches)
663
+ })
664
+
665
+ return merged
666
+
667
+ def extract_with_positions(self, image, min_confidence=0.3, use_ensemble=True):
668
+ """Extract text with positions using ensemble of OCR engines."""
669
  if isinstance(image, Image.Image):
670
+ img_array = np.array(image)
671
+ else:
672
+ img_array = image.copy()
673
+
674
+ all_results = {}
675
 
676
+ # Run EasyOCR (always available)
677
  try:
678
+ easyocr_results = self._run_easyocr(img_array)
679
+ if easyocr_results:
680
+ all_results['easyocr'] = easyocr_results
681
  except Exception as e:
682
  print(f"EasyOCR error: {e}")
 
683
 
684
+ # Run PaddleOCR if available
685
+ if self.paddleocr_available and use_ensemble:
686
+ try:
687
+ paddleocr_results = self._run_paddleocr(img_array)
688
+ if paddleocr_results:
689
+ all_results['paddleocr'] = paddleocr_results
690
+ except Exception as e:
691
+ print(f"PaddleOCR error: {e}")
692
+
693
+ # Run Tesseract if available
694
+ if self.use_tesseract and use_ensemble:
695
+ try:
696
+ tesseract_results = self._run_tesseract(img_array)
697
+ if tesseract_results:
698
+ all_results['tesseract'] = tesseract_results
699
+ except Exception as e:
700
+ print(f"Tesseract error: {e}")
701
+
702
+ # Run TrOCR on detected boxes (needs boxes from other engines)
703
+ if self.trocr_available and use_ensemble and all_results:
704
+ try:
705
+ # Get boxes from best available engine
706
+ source_engine = max(all_results.keys(), key=lambda k: len(all_results[k]))
707
+ boxes = [r['bbox'] for r in all_results[source_engine]]
708
+ trocr_results = self._run_trocr(img_array, boxes)
709
+ if trocr_results:
710
+ all_results['trocr'] = trocr_results
711
+ except Exception as e:
712
+ print(f"TrOCR error: {e}")
713
+
714
+ # Merge results if ensemble, otherwise use EasyOCR only
715
+ if use_ensemble and len(all_results) > 1:
716
+ merged = self._merge_results(all_results)
717
+ elif 'easyocr' in all_results:
718
+ merged = all_results['easyocr']
719
+ else:
720
+ merged = []
721
+
722
+ # Filter by confidence
723
+ filtered = [r for r in merged if r['confidence'] >= min_confidence]
724
+
725
+ # If results are poor, try with preprocessing
726
+ avg_confidence = np.mean([r['confidence'] for r in filtered]) if filtered else 0
727
+ if len(filtered) < 3 or avg_confidence < 0.4:
728
+ try:
729
+ preprocessed = self._preprocess_image(image, method='enhance')
730
+ retry_results = self._run_easyocr(preprocessed)
731
+ retry_filtered = [r for r in retry_results if r['confidence'] >= min_confidence]
732
+ retry_avg = np.mean([r['confidence'] for r in retry_filtered]) if retry_filtered else 0
733
+ if retry_avg > avg_confidence:
734
+ filtered = retry_filtered
735
+ except Exception:
736
+ pass
737
 
738
  # Sort by confidence (highest first)
739
+ filtered.sort(key=lambda x: x['confidence'], reverse=True)
740
 
741
+ return filtered
742
 
743
  def postprocess_receipt(self, ocr_results):
744
  """Extract structured fields from OCR results with improved patterns."""
requirements.txt CHANGED
@@ -2,6 +2,7 @@ torch>=2.0.0
2
  torchvision>=0.15.0
3
  transformers>=4.30.0
4
  easyocr>=1.7.0
 
5
  # Pin Gradio/gradio_client to a stable pair to avoid json_schema issues on Spaces
6
  gradio==3.41.2
7
  gradio_client==0.5.0
 
2
  torchvision>=0.15.0
3
  transformers>=4.30.0
4
  easyocr>=1.7.0
5
+ paddleocr>=2.7.0
6
  # Pin Gradio/gradio_client to a stable pair to avoid json_schema issues on Spaces
7
  gradio==3.41.2
8
  gradio_client==0.5.0