Raghu commited on
Commit
b8f0f36
·
1 Parent(s): acf3ed2

Accuracy-first optimizations: improved total extraction, OCR-first field extraction, LayoutLM validation, adaptive OCR ensemble

Browse files
Files changed (1) hide show
  1. app.py +69 -20
app.py CHANGED
@@ -664,7 +664,7 @@ class ReceiptOCR:
664
 
665
  return merged
666
 
667
- def extract_with_positions(self, image, min_confidence=0.3, use_ensemble=True):
668
  """Extract text with positions using ensemble of OCR engines."""
669
  if isinstance(image, Image.Image):
670
  img_array = np.array(image)
@@ -785,21 +785,29 @@ class ReceiptOCR:
785
  return None
786
 
787
  def _extract_total(self, text):
788
- """Extract total amount with improved patterns."""
789
- # Look for TOTAL, AMOUNT, DUE keywords
790
- patterns = [
791
- r'(?:TOTAL|AMOUNT|DUE|BALANCE)[:\s]*\$?\s*(\d{1,3}(?:,\d{3})*(?:\.\d{2})?)',
792
- r'\$\s*(\d{1,3}(?:,\d{3})*(?:\.\d{2})?)', # Any dollar amount
793
- ]
794
 
795
- for pattern in patterns:
796
- matches = re.findall(pattern, text, re.IGNORECASE)
797
- if matches:
798
- # Return largest amount (usually the total)
799
- amounts = [float(m.replace(',', '')) for m in matches]
800
- return f"{max(amounts):.2f}"
801
 
802
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
803
 
804
  def _extract_time(self, text):
805
  """Extract time."""
@@ -1110,7 +1118,15 @@ def process_receipt(image):
1110
  ocr_results = []
1111
  try:
1112
  if receipt_ocr:
1113
- ocr_results = receipt_ocr.extract_with_positions(image)
 
 
 
 
 
 
 
 
1114
  ocr_image = draw_ocr_boxes(image, ocr_results)
1115
 
1116
  lines = [f"{i+1}. [{r['confidence']:.0%}] {r['text']}" for i, r in enumerate(ocr_results)]
@@ -1119,14 +1135,47 @@ def process_receipt(image):
1119
  except Exception as e:
1120
  ocr_text = f"OCR error: {e}"
1121
 
1122
- # 3. Field Extraction
1123
  fields = {}
1124
  fields_html = ""
1125
  try:
1126
- if layoutlm_extractor:
1127
- fields = layoutlm_extractor.predict_fields(image, ocr_results)
1128
- elif receipt_ocr and ocr_results:
1129
- fields = receipt_ocr.postprocess_receipt(ocr_results)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1130
 
1131
  fields_html = "<div style='padding: 16px; background: #111827; color: #e5e7eb; border-radius: 12px; border: 1px solid #1f2937;'><h4 style=\"color: #e5e7eb;\">Extracted Fields</h4>"
1132
  for name, value in [
 
664
 
665
  return merged
666
 
667
+ def extract_with_positions(self, image, min_confidence=0.3, use_ensemble=False):
668
  """Extract text with positions using ensemble of OCR engines."""
669
  if isinstance(image, Image.Image):
670
  img_array = np.array(image)
 
785
  return None
786
 
787
  def _extract_total(self, text):
788
+ """Extract total amount - improved to find largest amount near TOTAL keyword."""
789
+ # First, find all dollar amounts in the text
790
+ all_amounts = re.findall(r'\$(\d{1,3}(?:,\d{3})*(?:\.\d{2})?)', text)
791
+ all_amounts = [float(a.replace(',', '')) for a in all_amounts]
 
 
792
 
793
+ if not all_amounts:
794
+ return None
 
 
 
 
795
 
796
+ # Look for "TOTAL", "AMOUNT DUE", "BALANCE" keywords and find amount near them
797
+ lines = text.split('\n')
798
+ for i, line in enumerate(lines):
799
+ line_upper = line.upper()
800
+ if any(keyword in line_upper for keyword in ['TOTAL', 'AMOUNT DUE', 'BALANCE DUE', 'DUE']):
801
+ # Check this line and next 2 lines for amount
802
+ search_text = ' '.join(lines[i:min(i+3, len(lines))])
803
+ matches = re.findall(r'\$(\d{1,3}(?:,\d{3})*(?:\.\d{2})?)', search_text)
804
+ if matches:
805
+ amounts_near_total = [float(m.replace(',', '')) for m in matches]
806
+ # Return largest amount near TOTAL keyword
807
+ return f"{max(amounts_near_total):.2f}"
808
+
809
+ # Fallback: return largest amount overall (usually the total)
810
+ return f"{max(all_amounts):.2f}"
811
 
812
  def _extract_time(self, text):
813
  """Extract time."""
 
1118
  ocr_results = []
1119
  try:
1120
  if receipt_ocr:
1121
+ # Try fast OCR first (EasyOCR + Tesseract only)
1122
+ ocr_results = receipt_ocr.extract_with_positions(image, use_ensemble=False)
1123
+
1124
+ # If confidence is low, try full ensemble
1125
+ if ocr_results:
1126
+ avg_conf = np.mean([r['confidence'] for r in ocr_results])
1127
+ if avg_conf < 0.5 or len(ocr_results) < 5:
1128
+ # Low confidence or few results, try full ensemble
1129
+ ocr_results = receipt_ocr.extract_with_positions(image, use_ensemble=True)
1130
  ocr_image = draw_ocr_boxes(image, ocr_results)
1131
 
1132
  lines = [f"{i+1}. [{r['confidence']:.0%}] {r['text']}" for i, r in enumerate(ocr_results)]
 
1135
  except Exception as e:
1136
  ocr_text = f"OCR error: {e}"
1137
 
1138
+ # 3. Field Extraction (OCR-first, LayoutLM as fallback)
1139
  fields = {}
1140
  fields_html = ""
1141
  try:
1142
+ # Try OCR regex first (faster and often more accurate for totals)
1143
+ ocr_fields = {}
1144
+ if receipt_ocr and ocr_results:
1145
+ ocr_fields = receipt_ocr.postprocess_receipt(ocr_results)
1146
+ fields = ocr_fields.copy()
1147
+
1148
+ # Use LayoutLM only to fill in missing fields or validate
1149
+ if layoutlm_extractor and ocr_results:
1150
+ layoutlm_fields = layoutlm_extractor.predict_fields(image, ocr_results)
1151
+
1152
+ # For each field, merge intelligently
1153
+ for field_name in ['vendor', 'date', 'total', 'time']:
1154
+ ocr_val = ocr_fields.get(field_name)
1155
+ layoutlm_val = layoutlm_fields.get(field_name)
1156
+
1157
+ if not ocr_val and layoutlm_val:
1158
+ # OCR didn't find it, use LayoutLM
1159
+ fields[field_name] = layoutlm_val
1160
+ elif ocr_val and layoutlm_val and field_name == 'total':
1161
+ # For total: validate LayoutLM against OCR text
1162
+ ocr_text = ' '.join([r['text'] for r in ocr_results])
1163
+ layoutlm_clean = str(layoutlm_val).replace('$', '').replace('.', '').replace(',', '').strip()
1164
+ ocr_clean = ocr_text.replace('$', '').replace('.', '').replace(',', '')
1165
+
1166
+ # Check if LayoutLM total appears in OCR text
1167
+ if layoutlm_clean in ocr_clean:
1168
+ # LayoutLM matches OCR, use it (might be more accurate)
1169
+ fields['total'] = layoutlm_val
1170
+ else:
1171
+ # LayoutLM doesn't match OCR, trust OCR (more reliable)
1172
+ fields['total'] = ocr_val
1173
+ elif ocr_val and layoutlm_val and field_name != 'total':
1174
+ # For other fields, prefer LayoutLM if it's longer/more complete
1175
+ if len(str(layoutlm_val)) > len(str(ocr_val)):
1176
+ fields[field_name] = layoutlm_val
1177
+ else:
1178
+ fields[field_name] = ocr_val
1179
 
1180
  fields_html = "<div style='padding: 16px; background: #111827; color: #e5e7eb; border-radius: 12px; border: 1px solid #1f2937;'><h4 style=\"color: #e5e7eb;\">Extracted Fields</h4>"
1181
  for name, value in [