Spaces:
Running
Running
Raghu
commited on
Commit
·
b8f0f36
1
Parent(s):
acf3ed2
Accuracy-first optimizations: improved total extraction, OCR-first field extraction, LayoutLM validation, adaptive OCR ensemble
Browse files
app.py
CHANGED
|
@@ -664,7 +664,7 @@ class ReceiptOCR:
|
|
| 664 |
|
| 665 |
return merged
|
| 666 |
|
| 667 |
-
def extract_with_positions(self, image, min_confidence=0.3, use_ensemble=
|
| 668 |
"""Extract text with positions using ensemble of OCR engines."""
|
| 669 |
if isinstance(image, Image.Image):
|
| 670 |
img_array = np.array(image)
|
|
@@ -785,21 +785,29 @@ class ReceiptOCR:
|
|
| 785 |
return None
|
| 786 |
|
| 787 |
def _extract_total(self, text):
|
| 788 |
-
"""Extract total amount
|
| 789 |
-
#
|
| 790 |
-
|
| 791 |
-
|
| 792 |
-
r'\$\s*(\d{1,3}(?:,\d{3})*(?:\.\d{2})?)', # Any dollar amount
|
| 793 |
-
]
|
| 794 |
|
| 795 |
-
|
| 796 |
-
|
| 797 |
-
if matches:
|
| 798 |
-
# Return largest amount (usually the total)
|
| 799 |
-
amounts = [float(m.replace(',', '')) for m in matches]
|
| 800 |
-
return f"{max(amounts):.2f}"
|
| 801 |
|
| 802 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 803 |
|
| 804 |
def _extract_time(self, text):
|
| 805 |
"""Extract time."""
|
|
@@ -1110,7 +1118,15 @@ def process_receipt(image):
|
|
| 1110 |
ocr_results = []
|
| 1111 |
try:
|
| 1112 |
if receipt_ocr:
|
| 1113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1114 |
ocr_image = draw_ocr_boxes(image, ocr_results)
|
| 1115 |
|
| 1116 |
lines = [f"{i+1}. [{r['confidence']:.0%}] {r['text']}" for i, r in enumerate(ocr_results)]
|
|
@@ -1119,14 +1135,47 @@ def process_receipt(image):
|
|
| 1119 |
except Exception as e:
|
| 1120 |
ocr_text = f"OCR error: {e}"
|
| 1121 |
|
| 1122 |
-
# 3. Field Extraction
|
| 1123 |
fields = {}
|
| 1124 |
fields_html = ""
|
| 1125 |
try:
|
| 1126 |
-
|
| 1127 |
-
|
| 1128 |
-
|
| 1129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1130 |
|
| 1131 |
fields_html = "<div style='padding: 16px; background: #111827; color: #e5e7eb; border-radius: 12px; border: 1px solid #1f2937;'><h4 style=\"color: #e5e7eb;\">Extracted Fields</h4>"
|
| 1132 |
for name, value in [
|
|
|
|
| 664 |
|
| 665 |
return merged
|
| 666 |
|
| 667 |
+
def extract_with_positions(self, image, min_confidence=0.3, use_ensemble=False):
|
| 668 |
"""Extract text with positions using ensemble of OCR engines."""
|
| 669 |
if isinstance(image, Image.Image):
|
| 670 |
img_array = np.array(image)
|
|
|
|
| 785 |
return None
|
| 786 |
|
| 787 |
def _extract_total(self, text):
|
| 788 |
+
"""Extract total amount - improved to find largest amount near TOTAL keyword."""
|
| 789 |
+
# First, find all dollar amounts in the text
|
| 790 |
+
all_amounts = re.findall(r'\$(\d{1,3}(?:,\d{3})*(?:\.\d{2})?)', text)
|
| 791 |
+
all_amounts = [float(a.replace(',', '')) for a in all_amounts]
|
|
|
|
|
|
|
| 792 |
|
| 793 |
+
if not all_amounts:
|
| 794 |
+
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 795 |
|
| 796 |
+
# Look for "TOTAL", "AMOUNT DUE", "BALANCE" keywords and find amount near them
|
| 797 |
+
lines = text.split('\n')
|
| 798 |
+
for i, line in enumerate(lines):
|
| 799 |
+
line_upper = line.upper()
|
| 800 |
+
if any(keyword in line_upper for keyword in ['TOTAL', 'AMOUNT DUE', 'BALANCE DUE', 'DUE']):
|
| 801 |
+
# Check this line and next 2 lines for amount
|
| 802 |
+
search_text = ' '.join(lines[i:min(i+3, len(lines))])
|
| 803 |
+
matches = re.findall(r'\$(\d{1,3}(?:,\d{3})*(?:\.\d{2})?)', search_text)
|
| 804 |
+
if matches:
|
| 805 |
+
amounts_near_total = [float(m.replace(',', '')) for m in matches]
|
| 806 |
+
# Return largest amount near TOTAL keyword
|
| 807 |
+
return f"{max(amounts_near_total):.2f}"
|
| 808 |
+
|
| 809 |
+
# Fallback: return largest amount overall (usually the total)
|
| 810 |
+
return f"{max(all_amounts):.2f}"
|
| 811 |
|
| 812 |
def _extract_time(self, text):
|
| 813 |
"""Extract time."""
|
|
|
|
| 1118 |
ocr_results = []
|
| 1119 |
try:
|
| 1120 |
if receipt_ocr:
|
| 1121 |
+
# Try fast OCR first (EasyOCR + Tesseract only)
|
| 1122 |
+
ocr_results = receipt_ocr.extract_with_positions(image, use_ensemble=False)
|
| 1123 |
+
|
| 1124 |
+
# If confidence is low, try full ensemble
|
| 1125 |
+
if ocr_results:
|
| 1126 |
+
avg_conf = np.mean([r['confidence'] for r in ocr_results])
|
| 1127 |
+
if avg_conf < 0.5 or len(ocr_results) < 5:
|
| 1128 |
+
# Low confidence or few results, try full ensemble
|
| 1129 |
+
ocr_results = receipt_ocr.extract_with_positions(image, use_ensemble=True)
|
| 1130 |
ocr_image = draw_ocr_boxes(image, ocr_results)
|
| 1131 |
|
| 1132 |
lines = [f"{i+1}. [{r['confidence']:.0%}] {r['text']}" for i, r in enumerate(ocr_results)]
|
|
|
|
| 1135 |
except Exception as e:
|
| 1136 |
ocr_text = f"OCR error: {e}"
|
| 1137 |
|
| 1138 |
+
# 3. Field Extraction (OCR-first, LayoutLM as fallback)
|
| 1139 |
fields = {}
|
| 1140 |
fields_html = ""
|
| 1141 |
try:
|
| 1142 |
+
# Try OCR regex first (faster and often more accurate for totals)
|
| 1143 |
+
ocr_fields = {}
|
| 1144 |
+
if receipt_ocr and ocr_results:
|
| 1145 |
+
ocr_fields = receipt_ocr.postprocess_receipt(ocr_results)
|
| 1146 |
+
fields = ocr_fields.copy()
|
| 1147 |
+
|
| 1148 |
+
# Use LayoutLM only to fill in missing fields or validate
|
| 1149 |
+
if layoutlm_extractor and ocr_results:
|
| 1150 |
+
layoutlm_fields = layoutlm_extractor.predict_fields(image, ocr_results)
|
| 1151 |
+
|
| 1152 |
+
# For each field, merge intelligently
|
| 1153 |
+
for field_name in ['vendor', 'date', 'total', 'time']:
|
| 1154 |
+
ocr_val = ocr_fields.get(field_name)
|
| 1155 |
+
layoutlm_val = layoutlm_fields.get(field_name)
|
| 1156 |
+
|
| 1157 |
+
if not ocr_val and layoutlm_val:
|
| 1158 |
+
# OCR didn't find it, use LayoutLM
|
| 1159 |
+
fields[field_name] = layoutlm_val
|
| 1160 |
+
elif ocr_val and layoutlm_val and field_name == 'total':
|
| 1161 |
+
# For total: validate LayoutLM against OCR text
|
| 1162 |
+
ocr_text = ' '.join([r['text'] for r in ocr_results])
|
| 1163 |
+
layoutlm_clean = str(layoutlm_val).replace('$', '').replace('.', '').replace(',', '').strip()
|
| 1164 |
+
ocr_clean = ocr_text.replace('$', '').replace('.', '').replace(',', '')
|
| 1165 |
+
|
| 1166 |
+
# Check if LayoutLM total appears in OCR text
|
| 1167 |
+
if layoutlm_clean in ocr_clean:
|
| 1168 |
+
# LayoutLM matches OCR, use it (might be more accurate)
|
| 1169 |
+
fields['total'] = layoutlm_val
|
| 1170 |
+
else:
|
| 1171 |
+
# LayoutLM doesn't match OCR, trust OCR (more reliable)
|
| 1172 |
+
fields['total'] = ocr_val
|
| 1173 |
+
elif ocr_val and layoutlm_val and field_name != 'total':
|
| 1174 |
+
# For other fields, prefer LayoutLM if it's longer/more complete
|
| 1175 |
+
if len(str(layoutlm_val)) > len(str(ocr_val)):
|
| 1176 |
+
fields[field_name] = layoutlm_val
|
| 1177 |
+
else:
|
| 1178 |
+
fields[field_name] = ocr_val
|
| 1179 |
|
| 1180 |
fields_html = "<div style='padding: 16px; background: #111827; color: #e5e7eb; border-radius: 12px; border: 1px solid #1f2937;'><h4 style=\"color: #e5e7eb;\">Extracted Fields</h4>"
|
| 1181 |
for name, value in [
|