Receipt_Agent / app.py
Raghu
Fix vendor extraction: detect known vendors like Einstein, skip numbers/IDs, prefer longer names. Remove Time field from display.
4b81303
"""
Receipt Processing Pipeline - Hugging Face Spaces App
Ensemble classification, OCR, field extraction, anomaly detection, and agentic routing.
"""
import os
import torch
import torch.nn as nn
import numpy as np
import gradio as gr
import gradio.routes as gr_routes
import easyocr
import json
import re
from PIL import Image, ImageDraw
from datetime import datetime
from torchvision import transforms, models
from transformers import (
ViTForImageClassification,
ViTImageProcessor,
LayoutLMv3ForTokenClassification,
LayoutLMv3Processor,
)
from sklearn.ensemble import IsolationForest
import warnings
warnings.filterwarnings('ignore')
# ---------------------------------------------------------------------------
# Work around Gradio json_schema traversal crash:
# - guard bool schema entries
# ---------------------------------------------------------------------------
import gradio_client.utils as grc_utils
_orig_get_type = grc_utils.get_type
_orig_json_schema_to_python_type = grc_utils.json_schema_to_python_type


def _safe_get_type(schema):
    """Wrap ``gradio_client.utils.get_type`` so boolean schema nodes don't crash.

    JSON Schema allows ``true``/``false`` as a whole schema (for example as the
    value of ``additionalProperties``); such nodes are reported as ``"any"``
    instead of being handed to the original function.
    """
    if isinstance(schema, bool):
        return "any"
    return _orig_get_type(schema)


def _safe_json_schema_to_python_type(schema, defs=None):
    """Wrap ``gradio_client.utils.json_schema_to_python_type`` defensively.

    Boolean schemas map to ``"any"``; any exception raised while describing a
    schema is also reported as ``"any"`` so API-info generation cannot crash
    the app.
    """
    if isinstance(schema, bool):
        return "any"
    try:
        # Only forward ``defs`` when a caller actually supplied it. Some
        # gradio_client releases define this function with a single ``schema``
        # parameter; unconditionally passing a second argument raises TypeError
        # there, which the handler below would turn into "any" for EVERY schema.
        if defs is None:
            return _orig_json_schema_to_python_type(schema)
        return _orig_json_schema_to_python_type(schema, defs)
    except Exception:
        # Deliberate best effort: an indescribable schema should not take the
        # whole app down.
        return "any"


grc_utils.get_type = _safe_get_type
grc_utils.json_schema_to_python_type = _safe_json_schema_to_python_type
# ---------------------------------------------------------------------------
# JSON sanitation helper (convert numpy types & PIL-friendly outputs)
# ---------------------------------------------------------------------------
def to_jsonable(obj):
    """Recursively convert *obj* into values that ``json.dumps`` can serialize.

    - dicts are rebuilt with converted values; NumPy scalar keys are unwrapped
      to native Python scalars (``json`` rejects keys such as ``np.int64``).
    - lists, tuples, sets and frozensets become lists.
    - NumPy booleans/integers/floats become ``bool``/``int``/``float``.
    - NumPy arrays become (nested) lists via ``ndarray.tolist()``.
    - PIL images become ``None`` (images are deliberately not serialized).
    - Anything else is returned unchanged.
    """
    if isinstance(obj, dict):
        return {
            (key.item() if isinstance(key, np.generic) else key): to_jsonable(value)
            for key, value in obj.items()
        }
    if isinstance(obj, (list, tuple, set, frozenset)):
        return [to_jsonable(value) for value in obj]
    # bool must be tested before the integer check: bool is an int subclass.
    if isinstance(obj, (np.bool_, bool)):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    # Plain JSON primitives need no further inspection.
    if obj is None or isinstance(obj, (str, int, float)):
        return obj
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, Image.Image):
        return None  # avoid serializing images; skip in JSON
    return obj
# ---------------------------------------------------------------------------
# Feedback persistence helper (CSV; optionally include section label)
# ---------------------------------------------------------------------------
def save_feedback(assessment, notes, results_json_str, section="overall", path="feedback_logs.csv"):
    """Append one feedback entry as a CSV row and return a status message.

    Args:
        assessment: Reviewer's verdict; ``None`` is stored as "".
        notes: Free-text notes; ``None`` is stored as "".
        results_json_str: The pipeline results, normally a JSON string. An
            already-parsed dict/list is stored as-is; text that is not valid
            JSON is stored wrapped as ``{"raw": <value>}``; empty -> ``{}``.
        section: Which part of the output the feedback refers to; ``None`` -> "".
        path: CSV file to append to. A header row is written when the file is
            new or empty.

    Returns:
        A confirmation message naming the file written.
    """
    import csv
    from datetime import timezone

    if isinstance(results_json_str, (dict, list)):
        parsed = results_json_str
    else:
        try:
            parsed = json.loads(results_json_str) if results_json_str else {}
        except (TypeError, ValueError):
            parsed = {"raw": results_json_str}

    row = {
        # Timezone-aware UTC timestamp (datetime.utcnow() is deprecated).
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "section": section or "",
        "assessment": assessment or "",
        "notes": notes or "",
        # default=str: best-effort logging should never fail on an odd payload.
        "results": json.dumps(parsed, default=str),
    }
    with open(path, "a", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=list(row))
        # Check emptiness rather than existence: a pre-created or truncated
        # log file would otherwise never receive a header row.
        if f.tell() == 0:
            writer.writeheader()
        writer.writerow(row)
    return f"✅ Feedback saved. (Stored in {path})"
# ============================================================================
# Configuration
# ============================================================================
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Directory scanned for fine-tuned *.pt checkpoints.
MODELS_DIR = "models"

print(f"Device: {DEVICE}")
print(f"Models directory: {MODELS_DIR}")
# ============================================================================
# Model Classes
# ============================================================================
class DocumentClassifier:
    """ViT-based document classifier (receipt vs other).

    Output index 1 is reported as "receipt" and every other index as "other".
    """

    def __init__(self, num_labels=2, model_path=None):
        self.num_labels = num_labels
        self.model = None
        self.processor = None
        self.model_path = model_path or os.path.join(MODELS_DIR, 'rvl_classifier.pt')
        # Hub checkpoint for the backbone; the head is resized to num_labels.
        self.pretrained = 'WinKawaks/vit-tiny-patch16-224'

    def load_model(self):
        """Load the image processor and ViT model onto DEVICE in eval mode.

        Returns:
            The loaded ``ViTForImageClassification`` model.
        """
        try:
            self.processor = ViTImageProcessor.from_pretrained(self.pretrained)
        except Exception:
            # Fall back to Google's reference ViT preprocessing config.
            self.processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
        self.model = ViTForImageClassification.from_pretrained(
            self.pretrained,
            num_labels=self.num_labels,
            ignore_mismatched_sizes=True,  # head size differs from the hub checkpoint
        )
        self.model = self.model.to(DEVICE)
        self.model.eval()
        return self.model

    def load_weights(self, path):
        """Load fine-tuned weights from ``path`` if that file exists.

        Accepts either a raw state dict or a dict that wraps one under
        ``'model_state_dict'`` or ``'state_dict'``. Loading is non-strict, so
        missing/unexpected keys are ignored. Loads the base model first if
        ``load_model`` has not been called yet.

        NOTE: ``torch.load`` unpickles the file; only load trusted checkpoints.
        """
        if not os.path.exists(path):
            return
        if self.model is None:
            # Weights need a model to be loaded into.
            self.load_model()
        checkpoint = torch.load(path, map_location=DEVICE)
        state_dict = checkpoint
        if isinstance(checkpoint, dict):
            for key in ('model_state_dict', 'state_dict'):
                if key in checkpoint:
                    state_dict = checkpoint[key]
                    break
        self.model.load_state_dict(state_dict, strict=False)
        print(f" Loaded ViT weights from {path}")

    def predict(self, image):
        """Classify one image.

        Args:
            image: PIL image or numpy array (converted to RGB).

        Returns:
            dict with 'is_receipt' (bool), 'confidence' (probability of the
            predicted class), 'label' ("receipt"/"other") and 'probabilities'
            (softmax over all output classes, as a list).
        """
        if self.model is None:
            self.load_model()
        self.model.eval()
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        image = image.convert('RGB')
        inputs = self.processor(images=image, return_tensors="pt")
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs)
        probs = torch.softmax(outputs.logits, dim=-1)
        pred = torch.argmax(probs, dim=-1).item()
        conf = probs[0, pred].item()
        is_receipt = pred == 1
        return {
            'is_receipt': is_receipt,
            'confidence': conf,
            'label': "receipt" if is_receipt else "other",
            'probabilities': probs[0].cpu().numpy().tolist(),
        }
class ResNetDocumentClassifier:
    """ResNet18-based document classifier.

    Works with 2-class checkpoints (output index 1 = receipt) and with
    16-class RVL-CDIP checkpoints, where a single class is treated as
    "receipt" and the probability mass of all other classes as "other".
    """

    # RVL-CDIP has no "receipt" class; in its standard 16-label order index 11
    # is "invoice", the closest analogue. Used when a checkpoint carries no
    # usable id2label mapping.
    RVL_CDIP_RECEIPT_IDX = 11

    def __init__(self, num_labels=2, model_path=None):
        self.num_labels = num_labels
        self.model = None
        self.model_path = model_path or os.path.join(MODELS_DIR, 'resnet18_rvlcdip.pt')
        self.use_class_mapping = False
        self.receipt_class_idx = self.RVL_CDIP_RECEIPT_IDX
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def load_model(self):
        """Create an untrained ResNet18 on DEVICE in eval mode and return it.

        ``load_weights`` resizes the final layer to match the checkpoint.
        """
        self.model = models.resnet18(weights=None)
        self.model = self.model.to(DEVICE)
        self.model.eval()
        return self.model

    @staticmethod
    def _receipt_class_index(id2label, default=RVL_CDIP_RECEIPT_IDX):
        """Return the output index to treat as "receipt".

        Looks for a label named "receipt" (preferred) or "invoice"
        (case-insensitive) in ``id2label``, which may be a dict
        ``{index: name}`` (index as int or str) or a list of names. Falls back
        to ``default`` when no such label is present.
        """
        if isinstance(id2label, dict):
            items = id2label.items()
        elif isinstance(id2label, (list, tuple)):
            items = enumerate(id2label)
        else:
            return default
        index_by_name = {}
        for idx, name in items:
            index_by_name.setdefault(str(name).strip().lower(), idx)
        for wanted in ('receipt', 'invoice'):
            if wanted in index_by_name:
                try:
                    return int(index_by_name[wanted])
                except (TypeError, ValueError):
                    continue
        return default

    def load_weights(self, path):
        """Load a checkpoint from ``path`` (no-op if the file does not exist).

        The final ``fc`` layer is rebuilt when the checkpoint's class count
        differs from the current model. For 16-class checkpoints the receipt
        class index is taken from the checkpoint's ``id2label`` when present.

        NOTE: ``torch.load`` unpickles the file; only load trusted checkpoints.
        """
        if not os.path.exists(path):
            return
        if self.model is None:
            self.load_model()
        checkpoint = torch.load(path, map_location=DEVICE)
        if isinstance(checkpoint, dict):
            state_dict = checkpoint.get('model_state_dict', checkpoint.get('state_dict', checkpoint))
            id2label = checkpoint.get('id2label', None)
        else:
            state_dict = checkpoint
            id2label = None
        # Number of classes comes from the checkpoint's final layer if present.
        fc_weight_key = 'fc.weight'
        if fc_weight_key in state_dict:
            num_classes = state_dict[fc_weight_key].shape[0]
        else:
            num_classes = self.num_labels
        if num_classes != self.model.fc.out_features:
            self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
            self.model = self.model.to(DEVICE)
        self.model.load_state_dict(state_dict, strict=False)
        # Recomputed per checkpoint so a later 2-class checkpoint does not
        # inherit the mapping from an earlier 16-class one.
        self.use_class_mapping = num_classes == 16
        if self.use_class_mapping:
            self.receipt_class_idx = self._receipt_class_index(id2label, self.RVL_CDIP_RECEIPT_IDX)
        print(f" Loaded ResNet weights from {path} ({num_classes} classes)")

    def predict(self, image):
        """Classify one image.

        Returns:
            dict with 'is_receipt', 'confidence', 'label' and 'probabilities'.
            In 16-class mode 'probabilities' is ``[other, receipt]``; otherwise
            it is the softmax over all output classes.
        """
        if self.model is None:
            self.load_model()
        self.model.eval()
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        image = image.convert('RGB')
        input_tensor = self.transform(image).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            outputs = self.model(input_tensor)
        probs = torch.softmax(outputs, dim=-1)
        if self.use_class_mapping:
            receipt_prob = probs[0, self.receipt_class_idx].item()
            other_prob = 1.0 - receipt_prob
            is_receipt = receipt_prob > 0.5
            conf = receipt_prob if is_receipt else other_prob
            final_probs = [other_prob, receipt_prob]
        else:
            pred = torch.argmax(probs, dim=-1).item()
            conf = probs[0, pred].item()
            is_receipt = pred == 1
            final_probs = probs[0].cpu().numpy().tolist()
        return {
            'is_receipt': is_receipt,
            'confidence': conf,
            'label': "receipt" if is_receipt else "other",
            'probabilities': final_probs,
        }
class EnsembleDocumentClassifier:
    """Weighted soft-voting ensemble of ViT and ResNet classifiers.

    Each member's first two probabilities (``[other, receipt]``) are combined
    with a weighted average; index 1 of the result means "receipt".
    """

    def __init__(self, model_configs=None, weights=None):
        """
        Args:
            model_configs: list of ``{'name': str, 'path': str}`` dicts.
                Configs whose checkpoint file is missing (or whose path is
                None/empty) are dropped. If none remain, a single untrained
                default ViT (``{'name': 'vit_default', 'path': None}``) is used.
            weights: optional per-model weights. If given with one entry per
                input config, entries for dropped configs are dropped too. If
                given with one entry per remaining config, used as-is. Otherwise
                uniform weights are used.
        """
        configs = list(model_configs or [
            {'name': 'vit_base', 'path': os.path.join(MODELS_DIR, 'rvl_classifier.pt')},
            {'name': 'resnet18', 'path': os.path.join(MODELS_DIR, 'resnet18_rvlcdip.pt')},
        ])
        # Filter to existing models. A None/empty path counts as missing:
        # os.path.exists(None) raises TypeError instead of returning False.
        keep = [i for i, cfg in enumerate(configs)
                if cfg.get('path') and os.path.exists(cfg['path'])]
        if weights is not None and len(weights) == len(configs):
            # Keep each surviving model paired with its own weight.
            weights = [weights[i] for i in keep]
        self.model_configs = [configs[i] for i in keep]
        if not self.model_configs:
            print("Warning: No model files found, will use default ViT")
            self.model_configs = [{'name': 'vit_default', 'path': None}]
            weights = None
        if weights is None or len(weights) != len(self.model_configs):
            weights = [1.0 / len(self.model_configs)] * len(self.model_configs)
        self.weights = list(weights)
        self.classifiers = []
        self.processor = None

    def load_models(self):
        """Instantiate and load every configured member; returns ``self``.

        A member whose weights fail to load is kept with its base weights and a
        warning is printed.
        """
        print(f"Loading ensemble with {len(self.model_configs)} models...")
        # Reset so calling load_models() twice does not duplicate members.
        self.classifiers = []
        for cfg in self.model_configs:
            path = cfg.get('path')
            # `path or ''`: the default config carries path=None.
            is_resnet = 'resnet' in cfg['name'].lower() or 'resnet' in (path or '').lower()
            if is_resnet:
                classifier = ResNetDocumentClassifier(num_labels=2, model_path=path)
            else:
                classifier = DocumentClassifier(num_labels=2, model_path=path)
            classifier.load_model()
            if path and os.path.exists(path):
                try:
                    classifier.load_weights(path)
                except Exception as e:
                    print(f" Warning: Could not load {cfg['name']}: {e}")
            self.classifiers.append(classifier)
            # Expose the first member's preprocessing object.
            if self.processor is None:
                if hasattr(classifier, 'processor'):
                    self.processor = classifier.processor
                elif hasattr(classifier, 'transform'):
                    self.processor = classifier.transform
        print(f"Ensemble ready with {len(self.classifiers)} models")
        return self

    def predict(self, image, return_individual=False):
        """Classify ``image`` with every member and combine the results.

        Returns:
            dict with 'is_receipt', 'confidence', 'label' and 'probabilities'
            (``[other, receipt]``, normalised by the total weight). With
            ``return_individual=True`` also 'individual_results', one entry per
            member.
        """
        if not self.classifiers:
            self.load_models()
        ensemble_probs = np.zeros(2)
        total_weight = 0.0
        individual_results = []
        for cfg, weight, classifier in zip(self.model_configs, self.weights, self.classifiers):
            result = classifier.predict(image)
            probs = result.get('probabilities', [0.5, 0.5])
            if len(probs) < 2:
                probs = [1 - result['confidence'], result['confidence']]
            ensemble_probs += np.asarray(probs[:2], dtype=float) * weight
            total_weight += weight
            individual_results.append({
                'name': cfg['name'],
                'prediction': result['label'],
                'confidence': result['confidence'],
                'probabilities': probs,
            })
        # Normalise so the confidence stays a probability even when the
        # weights do not sum to 1.
        if total_weight > 0:
            ensemble_probs /= total_weight
        pred = int(np.argmax(ensemble_probs))
        is_receipt = pred == 1
        result = {
            'is_receipt': is_receipt,
            'confidence': float(ensemble_probs[pred]),
            'label': "receipt" if is_receipt else "other",
            'probabilities': ensemble_probs.tolist(),
        }
        if return_individual:
            result['individual_results'] = individual_results
        return result
# ============================================================================
# OCR
# ============================================================================
class ReceiptOCR:
    """Enhanced OCR with EasyOCR + TrOCR + PaddleOCR + Tesseract ensemble.

    EasyOCR always runs. PaddleOCR, Tesseract and TrOCR are probed in
    ``__init__`` and only used when available and ``use_ensemble=True``.
    ``postprocess_receipt`` turns OCR output into vendor/date/total/time
    fields with regex heuristics.
    """

    # Money amount body: comma-grouped thousands ("1,234") or a plain digit run
    # ("1234"), with optional cents. The plain-run alternative is required: a
    # body limited to \d{1,3} reads "$1234.56" as 123.
    _AMOUNT_RE = r'((?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d{2})?)'
    # Same amount body, but cents are mandatory (used when no "$"/"S" marker is seen).
    _DECIMAL_AMOUNT_RE = r'((?:\d{1,3}(?:,\d{3})+|\d+)\.\d{2})'

    def __init__(self):
        self.reader = None
        self.trocr_engine = None
        self.paddleocr_engine = None
        self.use_tesseract = False
        # Engine weights for ensemble voting in _merge_results.
        self.engine_weights = {
            'trocr': 0.40,  # Highest weight - best quality
            'easyocr': 0.35,
            'paddleocr': 0.30,
            'tesseract': 0.20,
        }
        # Try to initialize TrOCR (recognition only; it reads boxes found by other engines).
        try:
            from transformers import TrOCRProcessor, VisionEncoderDecoderModel
            self.trocr_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
            self.trocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
            self.trocr_model = self.trocr_model.to(DEVICE)
            self.trocr_model.eval()
            self.trocr_available = True
            print("TrOCR initialized")
        except Exception as e:
            self.trocr_available = False
            print(f"TrOCR not available: {e}")
        # Try to initialize PaddleOCR
        try:
            from paddleocr import PaddleOCR
            self.paddleocr_engine = PaddleOCR(use_angle_cls=True, lang='en', show_log=False)
            self.paddleocr_available = True
            print("PaddleOCR initialized")
        except Exception as e:
            self.paddleocr_available = False
            print(f"PaddleOCR not available: {e}")
        # Tesseract is usable only if the pytesseract wrapper imports.
        try:
            import pytesseract  # noqa: F401
            self.use_tesseract = True
        except ImportError:
            self.use_tesseract = False

    def load(self):
        """Create the EasyOCR reader on first use (GPU if available); returns self."""
        if self.reader is None:
            print("Loading EasyOCR...")
            self.reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())
            print("EasyOCR ready")
        return self

    def _preprocess_image(self, image, method='enhance'):
        """Apply image preprocessing to improve OCR accuracy.

        Args:
            image: PIL image or numpy array.
            method: 'enhance' (grayscale + CLAHE + non-local-means denoising,
                returned as 3-channel RGB) or 'sharpen' (3x3 sharpening
                kernel). Any other value returns the input as an array.
        """
        import cv2
        if isinstance(image, Image.Image):
            img_array = np.array(image)
        else:
            img_array = image.copy()
        if method == 'enhance':
            if len(img_array.shape) == 3:
                gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
            else:
                gray = img_array
            # Contrast Limited Adaptive Histogram Equalization
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
            enhanced = clahe.apply(gray)
            denoised = cv2.fastNlMeansDenoising(enhanced, h=10)
            # Convert back to RGB for OCR engines
            return cv2.cvtColor(denoised, cv2.COLOR_GRAY2RGB)
        elif method == 'sharpen':
            kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])
            if len(img_array.shape) == 3:
                sharpened = cv2.filter2D(img_array, -1, kernel)
            else:
                sharpened = cv2.filter2D(img_array, -1, kernel)
                sharpened = cv2.cvtColor(sharpened, cv2.COLOR_GRAY2RGB)
            return sharpened
        return img_array

    def _run_easyocr(self, image):
        """Run EasyOCR; returns dicts with text, confidence, bbox [x1, y1, x2, y2], engine."""
        if self.reader is None:
            self.load()
        results = self.reader.readtext(image)
        extracted = []
        for bbox, text, conf in results:
            x_coords = [p[0] for p in bbox]
            y_coords = [p[1] for p in bbox]
            extracted.append({
                'text': text.strip(),
                'confidence': conf,
                'bbox': [min(x_coords), min(y_coords), max(x_coords), max(y_coords)],
                'engine': 'easyocr',
            })
        return extracted

    def _run_trocr(self, image, boxes):
        """Re-read the given boxes with TrOCR.

        Args:
            image: PIL image or numpy array.
            boxes: list of boxes, each either [x1, y1, x2, y2] or a list of
                corner points [[x, y], ...]. Non-list boxes are skipped.
        """
        if not self.trocr_available:
            return []
        if isinstance(image, np.ndarray):
            pil_image = Image.fromarray(image).convert('RGB')
        else:
            pil_image = image.convert('RGB')
        results = []
        for box in boxes:
            try:
                if isinstance(box, list) and len(box) >= 4:
                    if isinstance(box[0], list):
                        x1 = int(min(p[0] for p in box))
                        y1 = int(min(p[1] for p in box))
                        x2 = int(max(p[0] for p in box))
                        y2 = int(max(p[1] for p in box))
                    else:
                        x1, y1, x2, y2 = [int(b) for b in box[:4]]
                    cropped = pil_image.crop((x1, y1, x2, y2))
                    pixel_values = self.trocr_processor(images=cropped, return_tensors="pt").pixel_values.to(DEVICE)
                    with torch.no_grad():
                        generated_ids = self.trocr_model.generate(
                            pixel_values,
                            max_length=128,
                            num_beams=4,
                            early_stopping=True,
                        )
                    text = self.trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
                    if text.strip():
                        results.append({
                            'text': text.strip(),
                            'confidence': 0.9,  # TrOCR doesn't provide confidence, use high default
                            'bbox': [x1, y1, x2, y2],
                            'engine': 'trocr',
                        })
            except Exception:
                # Best effort: an unreadable/degenerate crop just skips this box.
                continue
        return results

    def _run_paddleocr(self, image):
        """Run PaddleOCR; returns [] if unavailable or on error."""
        if not self.paddleocr_available:
            return []
        try:
            result = self.paddleocr_engine.ocr(image, cls=True)
            if result is None or len(result) == 0 or result[0] is None:
                return []
            extracted = []
            for line in result[0]:
                if line is None:
                    continue
                bbox, (text, conf) = line
                x_coords = [p[0] for p in bbox]
                y_coords = [p[1] for p in bbox]
                extracted.append({
                    'text': text.strip(),
                    'confidence': conf,
                    'bbox': [min(x_coords), min(y_coords), max(x_coords), max(y_coords)],
                    'engine': 'paddleocr',
                })
            return extracted
        except Exception as e:
            print(f"PaddleOCR error: {e}")
            return []

    def _run_tesseract(self, image):
        """Run Tesseract word-level OCR; returns [] if unavailable or on error."""
        if not self.use_tesseract:
            return []
        try:
            import pytesseract
            if isinstance(image, Image.Image):
                pil_image = image.convert('RGB')
            else:
                pil_image = Image.fromarray(image).convert('RGB')
            data = pytesseract.image_to_data(pil_image, output_type=pytesseract.Output.DICT)
            results = []
            for i in range(len(data['text'])):
                text = data['text'][i].strip()
                # Confidences may arrive as ints, floats or strings such as
                # '96.5' or '-1'; int('96.5') would raise and lose every word.
                try:
                    conf = float(data['conf'][i])
                except (TypeError, ValueError):
                    continue
                if text and conf > 0:
                    x, y, w, h = data['left'][i], data['top'][i], data['width'][i], data['height'][i]
                    results.append({
                        'text': text,
                        'confidence': conf / 100.0,
                        'bbox': [x, y, x + w, y + h],
                        'engine': 'tesseract',
                    })
            return results
        except Exception as e:
            print(f"Tesseract OCR error: {e}")
            return []

    def _compute_iou(self, box1, box2):
        """Compute Intersection over Union of two [x1, y1, x2, y2] boxes (0 if union is empty)."""
        x1_1, y1_1, x2_1, y2_1 = box1
        x1_2, y1_2, x2_2, y2_2 = box2
        xi1 = max(x1_1, x1_2)
        yi1 = max(y1_1, y1_2)
        xi2 = min(x2_1, x2_2)
        yi2 = min(y2_1, y2_2)
        inter_area = max(0, xi2 - xi1) * max(0, yi2 - yi1)
        box1_area = (x2_1 - x1_1) * (y2_1 - y1_1)
        box2_area = (x2_2 - x1_2) * (y2_2 - y1_2)
        union_area = box1_area + box2_area - inter_area
        return inter_area / union_area if union_area > 0 else 0

    def _merge_results(self, all_results):
        """Merge results from multiple OCR engines using weighted voting.

        The engine with the most detections supplies the boxes. For each box,
        detections from other engines with IoU > 0.3 vote for a text with
        weight ``confidence * engine_weight``.
        """
        if not all_results:
            return []
        base_engine = max(all_results.keys(), key=lambda k: len(all_results[k]))
        base_results = all_results[base_engine]
        merged = []
        for base_result in base_results:
            base_box = base_result['bbox']
            base_text = base_result['text']
            base_conf = base_result['confidence']
            matches = [(base_text, base_conf, self.engine_weights.get(base_engine, 0.3))]
            for engine_name, results in all_results.items():
                if engine_name == base_engine:
                    continue
                for result in results:
                    if self._compute_iou(base_box, result['bbox']) > 0.3:  # Same text region
                        weight = self.engine_weights.get(engine_name, 0.2)
                        matches.append((result['text'], result['confidence'], weight))
            if len(matches) == 1:
                final_text = base_text
                final_conf = base_conf
            else:
                text_scores = {}
                for text, conf, weight in matches:
                    text_scores[text] = text_scores.get(text, 0) + conf * weight
                final_text = max(text_scores.keys(), key=lambda t: text_scores[t])
                total_weight = sum(w for _, _, w in matches)
                final_conf = min(0.99, text_scores[final_text] / total_weight if total_weight > 0 else 0.5)
            merged.append({
                'text': final_text,
                'confidence': final_conf,
                'bbox': base_box,
                'engines_used': len(matches),
            })
        return merged

    def extract_with_positions(self, image, min_confidence=0.3, use_ensemble=False):
        """Extract text with positions, optionally using the multi-engine ensemble.

        Args:
            image: PIL image or numpy array.
            min_confidence: detections below this confidence are dropped.
            use_ensemble: also run PaddleOCR/Tesseract/TrOCR when available and
                merge their output with EasyOCR's.

        Returns:
            List of dicts with at least 'text', 'confidence' and 'bbox'
            ([x1, y1, x2, y2]), sorted by confidence, highest first. If the
            first pass looks poor (< 3 results or mean confidence < 0.4), an
            EasyOCR pass on a contrast-enhanced image replaces it when its mean
            confidence is higher.
        """
        if isinstance(image, Image.Image):
            img_array = np.array(image)
        else:
            img_array = image.copy()
        all_results = {}
        # EasyOCR always runs.
        try:
            easyocr_results = self._run_easyocr(img_array)
            if easyocr_results:
                all_results['easyocr'] = easyocr_results
        except Exception as e:
            print(f"EasyOCR error: {e}")
        if self.paddleocr_available and use_ensemble:
            try:
                paddleocr_results = self._run_paddleocr(img_array)
                if paddleocr_results:
                    all_results['paddleocr'] = paddleocr_results
            except Exception as e:
                print(f"PaddleOCR error: {e}")
        if self.use_tesseract and use_ensemble:
            try:
                tesseract_results = self._run_tesseract(img_array)
                if tesseract_results:
                    all_results['tesseract'] = tesseract_results
            except Exception as e:
                print(f"Tesseract error: {e}")
        # TrOCR only recognises text, so it re-reads the boxes of the engine
        # with the most detections.
        if self.trocr_available and use_ensemble and all_results:
            try:
                source_engine = max(all_results.keys(), key=lambda k: len(all_results[k]))
                boxes = [r['bbox'] for r in all_results[source_engine]]
                trocr_results = self._run_trocr(img_array, boxes)
                if trocr_results:
                    all_results['trocr'] = trocr_results
            except Exception as e:
                print(f"TrOCR error: {e}")
        if use_ensemble and len(all_results) > 1:
            merged = self._merge_results(all_results)
        elif 'easyocr' in all_results:
            merged = all_results['easyocr']
        elif all_results:
            # EasyOCR failed but another engine produced output: use it rather
            # than silently returning nothing.
            merged = next(iter(all_results.values()))
        else:
            merged = []
        filtered = [r for r in merged if r['confidence'] >= min_confidence]
        avg_confidence = np.mean([r['confidence'] for r in filtered]) if filtered else 0
        if len(filtered) < 3 or avg_confidence < 0.4:
            try:
                preprocessed = self._preprocess_image(image, method='enhance')
                retry_results = self._run_easyocr(preprocessed)
                retry_filtered = [r for r in retry_results if r['confidence'] >= min_confidence]
                retry_avg = np.mean([r['confidence'] for r in retry_filtered]) if retry_filtered else 0
                if retry_avg > avg_confidence:
                    filtered = retry_filtered
            except Exception:
                # The preprocessing retry is best effort; keep first-pass results.
                pass
        filtered.sort(key=lambda x: x['confidence'], reverse=True)
        return filtered

    def postprocess_receipt(self, ocr_results):
        """Extract vendor/date/total/time fields from OCR results.

        Each result's text is passed through ``_fix_ocr_text`` before the
        date/total/time regexes run; vendor detection uses the raw results.

        NOTE(review): texts are joined with spaces, so ``_extract_total``'s
        per-line keyword search sees a single "line". It effectively returns
        the largest "$"/"S" amount whenever any TOTAL/DUE keyword appears.
        Results from ``extract_with_positions`` are in confidence order, not
        reading order.
        """
        fixed_results = []
        for r in ocr_results:
            fixed_r = r.copy()
            fixed_r['text'] = self._fix_ocr_text(r['text'])
            fixed_results.append(fixed_r)
        full_text = ' '.join([r['text'] for r in fixed_results])
        return {
            'vendor': self._extract_vendor(ocr_results),
            'date': self._extract_date(full_text),
            'total': self._extract_total(full_text),
            'time': self._extract_time(full_text),
        }

    def _extract_vendor(self, ocr_results):
        """Extract vendor name from the 10 top-most OCR results.

        A known chain name wins. Otherwise the longest alphabetic candidate is
        chosen, ties broken by confidence; skip words, short strings and
        number/ID-like strings are excluded.
        """
        if not ocr_results:
            return None
        sorted_results = sorted(
            ocr_results,
            key=lambda x: x['bbox'][1] if isinstance(x['bbox'], list) and len(x['bbox']) > 1 else 0,
        )
        top_results = sorted_results[:10]
        skip_words = {'TOTAL', 'DATE', 'TIME', 'RECEIPT', 'THANK', 'YOU', 'STORE', 'HOST',
                      'ORDER', 'TYPE', 'TOGO', 'DINE', 'IN', 'CHECK', 'CLOSED', 'AMEX',
                      'VISA', 'MASTERCARD', 'CASH', 'CHANGE', 'SUBTOTAL', 'TAX'}
        known_vendors = ['EINSTEIN', 'STARBUCKS', 'MCDONALDS', 'WALMART', 'TARGET',
                         'CHIPOTLE', 'PANERA', 'DUNKIN', 'SUBWAY', 'CHICK-FIL-A']
        for result in top_results:
            text = result['text'].strip().upper()
            for vendor in known_vendors:
                if vendor in text:
                    return result['text'].strip()
        candidates = []
        for result in top_results:
            text = result['text'].strip()
            text_upper = text.upper()
            if len(text) < 3:
                continue
            if text_upper in skip_words:
                continue
            if re.match(r'^[\d\s\-\/\.\$\,]+$', text):  # Skip pure numbers/symbols
                continue
            if re.match(r'^#?\d+$', text):  # Skip store numbers like #2846
                continue
            if len(text) >= 4 and any(c.isalpha() for c in text):
                candidates.append((text, len(text), result['confidence']))
        if candidates:
            # Longer text is more likely to be the full business name.
            candidates.sort(key=lambda x: (x[1], x[2]), reverse=True)
            return candidates[0][0]
        return None

    def _extract_date(self, text):
        """Return the first date-like substring (numeric or "Mon DD, YYYY"), else None."""
        patterns = [
            r'\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b',  # MM/DD/YYYY or MM-DD-YYYY
            r'\b\d{4}[/-]\d{2}[/-]\d{2}\b',  # YYYY-MM-DD
            r'\b(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)[a-z]*\s+\d{1,2},?\s+\d{4}\b',  # Month DD, YYYY
        ]
        for pattern in patterns:
            matches = re.findall(pattern, text, re.IGNORECASE)
            if matches:
                return matches[0]
        return None

    def _extract_total(self, text):
        """Extract the total amount as a "%.2f" string, or None if no amount is found.

        Handles the OCR confusion of "S" for "$". Amounts may use comma
        thousands separators or none ("1,234.56" and "1234.56" both parse).
        """
        # Fix S -> $ in amounts (common OCR error)
        fixed_text = re.sub(r'\bS' + self._AMOUNT_RE + r'\b', r'$\1', text)
        all_amounts = re.findall(r'[\$S]' + self._AMOUNT_RE, fixed_text)
        all_amounts = [float(a.replace(',', '')) for a in all_amounts if a]
        if not all_amounts:
            # No currency marker: fall back to any amount with cents.
            all_amounts = re.findall(self._DECIMAL_AMOUNT_RE, fixed_text)
            all_amounts = [float(a.replace(',', '')) for a in all_amounts if a]
        if not all_amounts:
            return None
        # Prefer amounts on a keyword line or the 2 lines after it.
        lines = fixed_text.split('\n')
        for i, line in enumerate(lines):
            line_upper = line.upper()
            if any(keyword in line_upper for keyword in ['TOTAL', 'AMOUNT DUE', 'BALANCE DUE', 'DUE']):
                search_text = ' '.join(lines[i:min(i + 3, len(lines))])
                matches = re.findall(r'[\$S]' + self._AMOUNT_RE, search_text)
                if matches:
                    amounts_near_total = [float(m.replace(',', '')) for m in matches]
                    return f"{max(amounts_near_total):.2f}"
        # Fallback: return largest amount overall
        return f"{max(all_amounts):.2f}"

    def _extract_time(self, text):
        """Return the first HH:MM (optionally with AM/PM) substring, else None."""
        patterns = [
            r'\b(\d{1,2}):(\d{2})\s*(?:AM|PM)\b',
            r'\b(\d{1,2}):(\d{2})\b',
        ]
        for pattern in patterns:
            match = re.search(pattern, text, re.IGNORECASE)
            if match:
                return match.group(0)
        return None

    def _fix_ocr_text(self, text):
        """Fix common OCR errors: "S" read for "$" before amounts, "Subtolal" typo."""
        # e.g. S154.06 -> $154.06, S1234.56 -> $1234.56
        text = re.sub(r'\bS' + self._AMOUNT_RE + r'\b', r'$\1', text)
        text = re.sub(r'\bSubtolal\b', 'Subtotal', text, flags=re.IGNORECASE)
        return text
class LayoutLMFieldExtractor:
"""LayoutLMv3-based field extractor using fine-tuned weights if available."""
def __init__(self, model_path=None):
self.model_path = model_path or os.path.join(MODELS_DIR, 'layoutlm_extractor.pt')
self.id2label = {
0: 'O',
1: 'B-VENDOR', 2: 'I-VENDOR',
3: 'B-DATE', 4: 'I-DATE',
5: 'B-TOTAL', 6: 'I-TOTAL',
7: 'B-TIME', 8: 'I-TIME'
}
self.label2id = {v: k for k, v in self.id2label.items()}
self.processor = None
self.model = None
def load(self):
print("Loading LayoutLMv3 extractor...")
self.processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base")
self.model = LayoutLMv3ForTokenClassification.from_pretrained(
"microsoft/layoutlmv3-base",
num_labels=len(self.id2label),
id2label=self.id2label,
label2id=self.label2id,
)
if os.path.exists(self.model_path):
checkpoint = torch.load(self.model_path, map_location=DEVICE)
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
checkpoint = checkpoint['model_state_dict']
if isinstance(checkpoint, dict):
missing, unexpected = self.model.load_state_dict(checkpoint, strict=False)
print(f"Loaded LayoutLM weights; missing={len(missing)}, unexpected={len(unexpected)}")
self.model = self.model.to(DEVICE)
self.model.eval()
print("LayoutLMv3 ready")
return self
def _prepare_boxes(self, ocr_results, image_size):
"""Convert absolute pixel boxes to LayoutLM 0-1000 format."""
width, height = image_size
boxes = []
words = []
for r in ocr_results:
bbox = r.get("bbox", [0, 0, width, height])
x0, y0, x1, y1 = bbox
boxes.append([
int(1000 * x0 / width),
int(1000 * y0 / height),
int(1000 * x1 / width),
int(1000 * y1 / height),
])
words.append(r.get("text", ""))
return words, boxes
def predict_fields(self, image, ocr_results=None):
"""Predict fields with confidence scores and improved total extraction."""
if self.model is None:
self.load()
if not isinstance(image, Image.Image):
image = Image.fromarray(image)
image = image.convert("RGB")
if ocr_results:
words, boxes = self._prepare_boxes(ocr_results, image.size)
encoding = self.processor(
image,
words=words,
boxes=boxes,
return_tensors="pt",
truncation=True,
padding="max_length",
max_length=512,
)
else:
encoding = self.processor(image, return_tensors="pt")
encoding = {k: v.to(DEVICE) for k, v in encoding.items()}
with torch.no_grad():
outputs = self.model(**encoding)
logits = outputs.logits[0]
# Get softmax probabilities for confidence
probs = torch.softmax(logits, dim=-1)
preds = logits.argmax(-1).cpu().tolist()
probs_np = probs.cpu().numpy()
tokens = self.processor.tokenizer.convert_ids_to_tokens(encoding["input_ids"][0].cpu())
# Extract entities with confidence scores
entities = {"VENDOR": [], "DATE": [], "TOTAL": [], "TIME": []}
entity_confidences = {"VENDOR": [], "DATE": [], "TOTAL": [], "TIME": []}
entity_positions = {"VENDOR": [], "DATE": [], "TOTAL": [], "TIME": []}
current = {"label": None, "tokens": [], "start_idx": None}
for idx, (token, pred) in enumerate(zip(tokens, preds)):
label = self.id2label.get(pred, "O")
conf = float(probs_np[idx, pred])
if token in ["[PAD]", "[CLS]", "[SEP]"]:
continue
if label.startswith("B-"):
# Flush previous
if current["label"] and current["tokens"]:
entity_text = " ".join(current["tokens"]).replace("▁", " ").strip()
entities[current["label"]].append(entity_text)
entity_confidences[current["label"]].append(conf)
entity_positions[current["label"]].append(current["start_idx"])
current = {"label": label[2:], "tokens": [token], "start_idx": idx}
elif label.startswith("I-") and current["label"] == label[2:]:
current["tokens"].append(token)
else:
if current["label"] and current["tokens"]:
entity_text = " ".join(current["tokens"]).replace("▁", " ").strip()
entities[current["label"]].append(entity_text)
entity_confidences[current["label"]].append(conf)
entity_positions[current["label"]].append(current["start_idx"])
current = {"label": None, "tokens": [], "start_idx": None}
if current["label"] and current["tokens"]:
entity_text = " ".join(current["tokens"]).replace("▁", " ").strip()
entities[current["label"]].append(entity_text)
entity_confidences[current["label"]].append(conf)
entity_positions[current["label"]].append(current["start_idx"])
# Smart field selection with confidence and position awareness
result = {}
# Vendor: prefer first high-confidence result
if entities["VENDOR"]:
best_vendor_idx = max(range(len(entities["VENDOR"])),
key=lambda i: entity_confidences["VENDOR"][i])
if entity_confidences["VENDOR"][best_vendor_idx] > 0.3:
result["vendor"] = entities["VENDOR"][best_vendor_idx]
# Date: prefer first high-confidence result
if entities["DATE"]:
best_date_idx = max(range(len(entities["DATE"])),
key=lambda i: entity_confidences["DATE"][i])
if entity_confidences["DATE"][best_date_idx] > 0.3:
result["date"] = entities["DATE"][best_date_idx]
# Time: prefer first high-confidence result
if entities["TIME"]:
best_time_idx = max(range(len(entities["TIME"])),
key=lambda i: entity_confidences["TIME"][i])
if entity_confidences["TIME"][best_time_idx] > 0.3:
result["time"] = entities["TIME"][best_time_idx]
# Total: improved extraction - look for amounts near "TOTAL" keyword in OCR
if entities["TOTAL"]:
# Get all total candidates with confidence
total_candidates = [(entities["TOTAL"][i], entity_confidences["TOTAL"][i],
entity_positions["TOTAL"][i])
for i in range(len(entities["TOTAL"]))]
# If OCR results available, validate against OCR text
if ocr_results:
ocr_text = ' '.join([r['text'] for r in ocr_results]).upper()
ocr_lines = [r['text'] for r in ocr_results]
# Find amounts near "TOTAL" keyword
best_total = None
best_conf = 0
for total_val, conf, pos in total_candidates:
# Clean the total value
total_clean = str(total_val).replace('$', '').replace(',', '').replace('.', '').strip()
# Check if this total appears near "TOTAL" keyword in OCR
for i, line in enumerate(ocr_lines):
line_upper = line.upper()
if 'TOTAL' in line_upper or 'AMOUNT DUE' in line_upper:
# Check this line and next 2 lines for the amount
search_text = ' '.join(ocr_lines[i:min(i+3, len(ocr_lines))])
search_clean = search_text.replace('$', '').replace(',', '').replace('.', '')
if total_clean in search_clean:
# Found near TOTAL keyword - high confidence
if conf > best_conf:
best_total = total_val
best_conf = conf
break
if best_total:
result["total"] = best_total
else:
# Fallback: use highest confidence total
best_idx = max(range(len(total_candidates)), key=lambda i: total_candidates[i][1])
if total_candidates[best_idx][1] > 0.3:
result["total"] = total_candidates[best_idx][0]
else:
# No OCR, use highest confidence
best_idx = max(range(len(total_candidates)), key=lambda i: total_candidates[i][1])
if total_candidates[best_idx][1] > 0.3:
result["total"] = total_candidates[best_idx][0]
return result
# ============================================================================
# Anomaly Detection
# ============================================================================
class AnomalyDetector:
    """Flag suspicious receipts from their extracted fields.

    An ``IsolationForest`` is constructed, but nothing in this class ever
    fits it: ``model`` and ``is_fitted`` are currently placeholders, and
    :meth:`predict` always applies simple rules (high total, missing vendor,
    missing date).
    """

    # Totals strictly above this amount are reported as anomalous.
    HIGH_AMOUNT_THRESHOLD = 1000

    def __init__(self):
        self.model = IsolationForest(contamination=0.1, random_state=42)
        self.is_fitted = False

    @staticmethod
    def _parse_total(value):
        """Convert an extracted total to ``float``.

        Accepts numbers and strings such as ``"$1,234.50"`` or ``"12.00 USD"``;
        any character other than digits, ``.`` and ``-`` is stripped before
        conversion. Returns ``0.0`` for missing or unparseable values so a
        malformed total degrades to "no amount" instead of raising.
        """
        if not value:
            return 0.0
        try:
            return float(value)
        except (TypeError, ValueError):
            pass
        # Extracted totals may carry currency symbols / thousands separators
        # (the field extractor strips '$' and ',' when comparing totals).
        cleaned = re.sub(r'[^0-9.\-]', '', str(value))
        try:
            return float(cleaned)
        except ValueError:
            return 0.0

    def extract_features(self, fields):
        """Build a ``(1, 8)`` feature row from receipt ``fields``.

        Columns: total, log1p(total), len(vendor), has_date (0/1), then four
        placeholder values (item count 1, hour 12, amount-per-item = total,
        is_weekend 0).
        """
        fields = fields or {}
        total = self._parse_total(fields.get('total', 0))
        vendor = fields.get('vendor', '') or ''
        date = fields.get('date', '') or ''
        features = [
            total,
            np.log1p(total),
            len(vendor),
            1 if date else 0,
            1,  # num_items placeholder
            12,  # hour placeholder
            total,  # amount_per_item placeholder
            0,  # is_weekend placeholder
        ]
        return np.array(features).reshape(1, -1)

    def predict(self, fields):
        """Apply rule-based anomaly checks to receipt ``fields``.

        Returns a dict with ``is_anomaly`` (bool), ``score`` (-0.5 anomalous /
        0.5 normal), ``prediction`` ('ANOMALY' / 'NORMAL') and ``reasons``
        (human-readable list of triggered rules, empty when normal).
        """
        fields = fields or {}
        reasons = []
        total = self._parse_total(fields.get('total', 0))
        if total > self.HIGH_AMOUNT_THRESHOLD:
            reasons.append(f"High amount: ${total:.2f}")
        if not fields.get('vendor'):
            reasons.append("Missing vendor")
        if not fields.get('date'):
            reasons.append("Missing date")
        is_anomaly = len(reasons) > 0
        return {
            'is_anomaly': is_anomaly,
            'score': -0.5 if is_anomaly else 0.5,
            'prediction': 'ANOMALY' if is_anomaly else 'NORMAL',
            'reasons': reasons,
        }
# ============================================================================
# Initialize Models
# ============================================================================
_BANNER = "=" * 50

print("\n" + _BANNER)
print("Initializing models...")
print(_BANNER)

# Report which *.pt checkpoints are present; create the directory when absent.
if not os.path.exists(MODELS_DIR):
    print(f"Models directory not found: {MODELS_DIR}")
    os.makedirs(MODELS_DIR, exist_ok=True)
    model_files = []
else:
    model_files = [name for name in os.listdir(MODELS_DIR) if name.endswith('.pt')]
    print(f"Found model files: {model_files}")


def _build_component(factory, loader_name, description):
    """Create a component via ``factory()`` and call its ``loader_name`` method.

    Returns the loaded component, or ``None`` after printing a warning when
    either construction or loading raises, so the app can start degraded.
    """
    try:
        component = factory()
        getattr(component, loader_name)()
    except Exception as exc:
        print(f"Warning: Could not load {description}: {exc}")
        return None
    return component


# The lambdas defer each class lookup into the guarded call, exactly as the
# construction happened inside a try block before.
ensemble_classifier = _build_component(
    lambda: EnsembleDocumentClassifier(), 'load_models', 'ensemble classifier')
receipt_ocr = _build_component(lambda: ReceiptOCR(), 'load', 'OCR')
layoutlm_extractor = _build_component(
    lambda: LayoutLMFieldExtractor(), 'load', 'LayoutLMv3 extractor')
anomaly_detector = AnomalyDetector()

print("\n" + _BANNER)
print("Initialization complete!")
print(_BANNER + "\n")
# ============================================================================
# Helper Functions
# ============================================================================
def draw_ocr_boxes(image, ocr_results):
    """Return a copy of ``image`` with each OCR region outlined.

    The outline colour encodes the region's ``confidence`` (default 0.5):
    green above 0.8, yellow above 0.5, red otherwise. Regions whose ``bbox``
    has fewer than four values are skipped; the first four values are used
    as rectangle coordinates (presumably ``[x0, y0, x1, y1]`` — TODO confirm
    against the OCR producer). The input image itself is not modified.
    """
    annotated = image.copy()
    pen = ImageDraw.Draw(annotated)
    for region in ocr_results:
        score = region.get('confidence', 0.5)
        box = region.get('bbox', [])
        outline = (
            '#28a745' if score > 0.8        # Green: high confidence
            else '#ffc107' if score > 0.5   # Yellow: medium confidence
            else '#dc3545'                  # Red: low confidence
        )
        if len(box) < 4:
            continue
        x0, y0, x1, y1 = box[0], box[1], box[2], box[3]
        pen.rectangle([x0, y0, x1, y1], outline=outline, width=2)
    return annotated
def _escape_html(value):
    """Return ``str(value)`` escaped for safe interpolation into HTML markup."""
    import html  # local import keeps this helper self-contained

    return html.escape(str(value))


def _merge_fields(ocr_fields, layoutlm_fields, ocr_results):
    """Merge OCR-regex fields with LayoutLMv3 predictions into a new dict.

    Starts from ``ocr_fields`` and, for vendor/date/total/time:
      * uses the LayoutLM value when OCR found nothing;
      * for ``total``, keeps the LayoutLM value only if its digits (ignoring
        ``$ . ,``) occur in the raw OCR text, otherwise keeps the OCR value;
      * for other fields, keeps whichever value has the longer string form.
    """
    merged = dict(ocr_fields)
    if not layoutlm_fields:
        return merged
    raw_text = ' '.join(r['text'] for r in ocr_results)
    ocr_digits = raw_text.replace('$', '').replace('.', '').replace(',', '')
    for name in ('vendor', 'date', 'total', 'time'):
        ocr_val = ocr_fields.get(name)
        lm_val = layoutlm_fields.get(name)
        if not lm_val:
            continue  # keep whatever OCR produced (already in merged)
        if not ocr_val:
            merged[name] = lm_val
        elif name == 'total':
            lm_digits = str(lm_val).replace('$', '').replace('.', '').replace(',', '').strip()
            # An empty digit string would trivially "match"; require real digits.
            merged[name] = lm_val if lm_digits and lm_digits in ocr_digits else ocr_val
        elif len(str(lm_val)) > len(str(ocr_val)):
            merged[name] = lm_val
    return merged


def process_receipt(image):
    """Gradio handler: classify, OCR, extract fields, check anomalies, route.

    Args:
        image: PIL image or array-like from ``gr.Image``; ``None`` when empty.

    Returns:
        5-tuple ``(summary_html, ocr_image, ocr_text, results_json,
        summary_html)`` matching the ``process_btn.click`` outputs (the summary
        is repeated for the feedback panel). Every stage catches its own
        errors and reports them in its HTML card instead of raising.
    """
    if image is None:
        return (
            "<div style='padding: 20px; text-align: center;'>Upload an image to begin</div>",
            None, "", "", "<div style='padding: 40px; text-align: center; color: #6c757d;'>Upload an image to begin</div>"
        )
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image = image.convert('RGB')
    results = {}

    # 1. Classification
    classifier_html = ""
    try:
        if ensemble_classifier:
            class_result = ensemble_classifier.predict(image, return_individual=True)
        else:
            # No classifier loaded: assume a receipt with neutral confidence.
            class_result = {'is_receipt': True, 'confidence': 0.5, 'label': 'unknown'}
        conf = class_result['confidence']
        label = _escape_html(str(class_result['label']).upper())
        color = '#28a745' if class_result.get('is_receipt') else '#dc3545'
        bar_color = '#28a745' if conf > 0.8 else '#ffc107' if conf > 0.6 else '#dc3545'
        classifier_html = f"""
<div style="padding: 16px; background: #111827; color: #e5e7eb; border-radius: 12px; margin: 8px 0; border: 1px solid #1f2937;">
<h4 style="margin: 0 0 8px 0; color: #e5e7eb;">Classification</h4>
<div style="font-size: 20px; font-weight: bold; color: {color};">{label}</div>
<div style="margin-top: 8px; color: #e5e7eb;">
<span>Confidence: </span>
<div style="display: inline-block; width: 120px; height: 8px; background: #1f2937; border-radius: 4px;">
<div style="width: {conf*100}%; height: 100%; background: {bar_color}; border-radius: 4px;"></div>
</div>
<span style="margin-left: 8px;">{conf:.1%}</span>
</div>
</div>
"""
        results['classification'] = class_result
    except Exception as e:
        classifier_html = f"<div style='color: red;'>Classification error: {_escape_html(e)}</div>"

    # 2. OCR
    ocr_text = ""
    ocr_image = None
    ocr_results = []
    try:
        if receipt_ocr:
            # Fast pass first (use_ensemble=False).
            ocr_results = receipt_ocr.extract_with_positions(image, use_ensemble=False) or []
            # Retry with the full ensemble when the fast pass found few regions
            # (including none at all) or its mean confidence is low. The length
            # test short-circuits, so np.mean never sees an empty list.
            if len(ocr_results) < 5 or np.mean([r['confidence'] for r in ocr_results]) < 0.5:
                ocr_results = receipt_ocr.extract_with_positions(image, use_ensemble=True) or []
            ocr_image = draw_ocr_boxes(image, ocr_results)
            lines = [f"{i}. [{r['confidence']:.0%}] {r['text']}" for i, r in enumerate(ocr_results, start=1)]
            ocr_text = f"Detected {len(ocr_results)} text regions:\n\n" + "\n".join(lines)
            results['ocr'] = ocr_results
    except Exception as e:
        ocr_text = f"OCR error: {e}"

    # 3. Field Extraction (OCR-first, LayoutLM as fallback)
    fields = {}
    fields_html = ""
    try:
        ocr_fields = {}
        if receipt_ocr and ocr_results:
            ocr_fields = receipt_ocr.postprocess_receipt(ocr_results) or {}
        fields = dict(ocr_fields)
        layoutlm_fields = {}
        if layoutlm_extractor and ocr_results:
            try:
                layoutlm_fields = layoutlm_extractor.predict_fields(image, ocr_results) or {}
            except Exception as e:
                # LayoutLM is only a fallback: keep the OCR fields if it fails.
                print(f"Warning: LayoutLMv3 field extraction failed: {e}")
        fields = _merge_fields(ocr_fields, layoutlm_fields, ocr_results)

        total = fields.get('total')
        rows = [
            ('Vendor', fields.get('vendor')),
            ('Date', fields.get('date')),
            # Totals may already carry a '$'; avoid rendering "$$".
            ('Total', f"${str(total).lstrip('$')}" if total else None),
        ]
        fields_html = "<div style='padding: 16px; background: #111827; color: #e5e7eb; border-radius: 12px; border: 1px solid #1f2937;'><h4 style=\"color: #e5e7eb;\">Extracted Fields</h4>"
        for name, value in rows:
            # Field values come from OCR of an uploaded image: escape them.
            display = _escape_html(value) if value else '<span style="color: #9ca3af;">Not found</span>'
            fields_html += f"<div style='padding: 8px; background: #0f172a; color: #e5e7eb; border: 1px solid #1f2937; border-radius: 6px; margin: 4px 0;'><b>{name}:</b> {display}</div>"
        fields_html += "</div>"
        results['fields'] = fields
    except Exception as e:
        fields_html = f"<div style='color: red;'>Extraction error: {_escape_html(e)}</div>"

    # 4. Anomaly Detection
    anomaly_html = ""
    try:
        anomaly_result = anomaly_detector.predict(fields)
        status_color = '#dc3545' if anomaly_result['is_anomaly'] else '#28a745'
        status_text = anomaly_result['prediction']
        anomaly_html = f"""
<div style="padding: 16px; background: #111827; color: #e5e7eb; border-radius: 12px; margin: 8px 0; border: 1px solid #1f2937;">
<h4 style="margin: 0 0 8px 0; color: #e5e7eb;">Anomaly Detection</h4>
<div style="font-size: 18px; font-weight: bold; color: {status_color};">{status_text}</div>
"""
        if anomaly_result['reasons']:
            items = "".join(f"<li>{_escape_html(reason)}</li>" for reason in anomaly_result['reasons'])
            anomaly_html += f"<ul style='margin: 8px 0; padding-left: 20px;'>{items}</ul>"
        anomaly_html += "</div>"
        results['anomaly'] = anomaly_result
    except Exception as e:
        anomaly_html = f"<div style='color: red;'>Anomaly detection error: {_escape_html(e)}</div>"

    # 5. Final Decision (missing stages fall back to receipt / no anomaly / 0.5)
    is_receipt = results.get('classification', {}).get('is_receipt', True)
    is_anomaly = results.get('anomaly', {}).get('is_anomaly', False)
    conf = results.get('classification', {}).get('confidence', 0.5)
    if not is_receipt:
        decision, decision_color, reason = "REJECT", "#dc3545", "Not a receipt"
    elif is_anomaly:
        decision, decision_color, reason = "REVIEW", "#ffc107", "Anomaly detected"
    elif conf < 0.7:
        decision, decision_color, reason = "REVIEW", "#ffc107", "Low confidence"
    else:
        decision, decision_color, reason = "APPROVE", "#28a745", "All checks passed"
    summary_html = f"""
<div style="padding: 24px; background: linear-gradient(135deg, {decision_color}22, {decision_color}11);
border-left: 4px solid {decision_color}; border-radius: 12px; text-align: center;">
<div style="font-size: 32px; font-weight: bold; color: {decision_color};">{decision}</div>
<div style="color: #6c757d; margin-top: 8px;">{reason}</div>
</div>
{classifier_html}
{anomaly_html}
{fields_html}
"""
    safe_results = json.dumps(to_jsonable(results), indent=2)
    return summary_html, ocr_image, ocr_text, safe_results, summary_html
# ============================================================================
# Gradio Interface
# ============================================================================
# Dark-theme stylesheet passed to gr.Blocks(css=...) below. `.main-header`
# styles the banner <div> in the page-header Markdown; the remaining
# selectors restyle Gradio containers, buttons, accordions and textboxes.
# NOTE(review): the .gr-* / .gradio-* class names depend on the DOM that the
# installed Gradio version generates — confirm they still match after upgrades.
CUSTOM_CSS = """
.gradio-container { max-width: 1200px !important; background: #0b0c0e; color: #e5e7eb; }
.main-header { text-align: center; padding: 20px; background: linear-gradient(135deg, #0f172a 0%, #1f2937 100%);
border-radius: 12px; color: #e5e7eb; margin-bottom: 20px; border: 1px solid #1f2937; }
.gr-button { border-radius: 10px; background: #111827; color: #e5e7eb; border: 1px solid #1f2937; }
.gr-button-primary { background: #111827; border: 1px solid #22c55e; color: #e5e7eb; }
.gr-box { border: 1px solid #1f2937; background: #111827; color: #e5e7eb; }
.gradio-accordion { border: 1px solid #1f2937 !important; background: #0f172a !important; color: #e5e7eb !important; }
.gr-markdown { color: #e5e7eb; }
.gr-textbox textarea { background: #0f172a !important; color: #e5e7eb !important; border: 1px solid #1f2937 !important; }
.gr-radio { color: #e5e7eb !important; }
"""
# Page layout: header, upload/results row, collapsible OCR and raw-JSON panels,
# per-section feedback forms, and an overall feedback form. Every feedback
# button calls save_feedback (defined elsewhere in this file) with
# (assessment, notes, raw results JSON, section tag).
# NOTE(review): confirm save_feedback's signature matches these inputs.
with gr.Blocks(title="Receipt Processing Agent", theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
    gr.Markdown("""
<div class="main-header">
<h1>Receipt Processing Agent</h1>
<p>Ensemble classification, OCR, field extraction, and anomaly detection</p>
<p style="margin-top: 12px; font-size: 14px; color: #9ca3af;">Built by Emily, John, Luke, Michael and Raghu</p>
<p style="margin-top: 8px; font-size: 14px;">
<a href="https://github.com/RogueTex/StreamingDataforModelTraining#readme" target="_blank" style="color: #22c55e; text-decoration: none; border-bottom: 1px solid #22c55e;">Read more here →</a>
</p>
</div>
""")
    # NOTE(review): this description mentions only EasyOCR and regex patterns,
    # but process_receipt can also run an OCR ensemble and LayoutLMv3 field
    # extraction — consider updating the text.
    gr.Markdown("""
### How It Works
Upload a receipt image to automatically:
- **Classify** document type with ViT + ResNet ensemble
- **Extract text** with EasyOCR (with bounding boxes)
- **Extract fields** (vendor, date, total) using regex patterns
- **Detect anomalies** with rule-based checks
- **Route** to APPROVE / REVIEW / REJECT
---
""")
    # Left column: upload + trigger; right column: decision summary HTML.
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Upload Receipt")
            input_image = gr.Image(type="pil", label="Receipt Image", height=350)
            process_btn = gr.Button("Process Receipt", variant="primary", size="lg")
        with gr.Column(scale=1):
            agent_summary = gr.HTML(
                label="Results",
                value="<div style='padding: 40px; text-align: center; color: #6c757d;'>Upload an image to begin</div>"
            )
    with gr.Accordion("OCR Results", open=False):
        with gr.Row():
            ocr_image_output = gr.Image(label="Detected Text Regions", height=300)
            ocr_text_output = gr.Textbox(label="Extracted Text", lines=12)
    # results_json also serves as the payload submitted with every feedback form.
    with gr.Accordion("Raw Results (JSON)", open=False):
        results_json = gr.Textbox(label="Full Results", lines=15)
    # Per-section feedback controls
    # Each gr.State holds a constant section tag identifying which stage the
    # feedback refers to.
    with gr.Accordion("Classification Feedback", open=False):
        cls_assess = gr.Radio(choices=["Correct", "Incorrect"], label="Classification correct?", value=None)
        cls_notes = gr.Textbox(label="Notes (optional)", placeholder="What should be improved or fixed?", lines=2)
        cls_status = gr.Markdown(value="")
        cls_submit = gr.Button("Submit Classification Feedback", variant="primary")
        cls_submit.click(
            fn=save_feedback,
            inputs=[cls_assess, cls_notes, results_json, gr.State("classification")],
            outputs=cls_status
        )
    with gr.Accordion("OCR Feedback", open=False):
        ocr_assess = gr.Radio(choices=["Correct", "Incorrect"], label="OCR correct?", value=None)
        ocr_notes = gr.Textbox(label="Notes (optional)", placeholder="What should be improved or fixed?", lines=2)
        ocr_status = gr.Markdown(value="")
        ocr_submit = gr.Button("Submit OCR Feedback", variant="primary")
        ocr_submit.click(
            fn=save_feedback,
            inputs=[ocr_assess, ocr_notes, results_json, gr.State("ocr")],
            outputs=ocr_status
        )
    with gr.Accordion("Field Extraction Feedback", open=False):
        fld_assess = gr.Radio(choices=["Correct", "Incorrect"], label="Fields correct?", value=None)
        fld_notes = gr.Textbox(label="Notes (optional)", placeholder="What should be improved or fixed?", lines=2)
        fld_status = gr.Markdown(value="")
        fld_submit = gr.Button("Submit Fields Feedback", variant="primary")
        fld_submit.click(
            fn=save_feedback,
            inputs=[fld_assess, fld_notes, results_json, gr.State("fields")],
            outputs=fld_status
        )
    with gr.Accordion("Anomaly Feedback", open=False):
        an_assess = gr.Radio(choices=["Correct", "Incorrect"], label="Anomaly correct?", value=None)
        an_notes = gr.Textbox(label="Notes (optional)", placeholder="What should be improved or fixed?", lines=2)
        an_status = gr.Markdown(value="")
        an_submit = gr.Button("Submit Anomaly Feedback", variant="primary")
        an_submit.click(
            fn=save_feedback,
            inputs=[an_assess, an_notes, results_json, gr.State("anomaly")],
            outputs=an_status
        )
    # Overall feedback; feedback_summary mirrors the summary HTML that
    # process_receipt returns as its fifth output.
    with gr.Accordion("Feedback", open=True):
        gr.Markdown("Review the agent output below and submit a quick assessment.")
        feedback_summary = gr.HTML(label="Last Agent Response (read-only)")
        with gr.Row():
            feedback_assessment = gr.Radio(
                choices=["Correct", "Incorrect"],
                label="Is the response correct?",
                value=None
            )
            feedback_notes = gr.Textbox(
                label="Notes (optional)",
                placeholder="What should be improved or fixed?",
                lines=3
            )
        feedback_status = gr.Markdown(value="")
        submit_feedback = gr.Button("Submit Feedback", variant="primary")
        submit_feedback.click(
            fn=save_feedback,
            inputs=[feedback_assessment, feedback_notes, results_json, gr.State("overall")],
            outputs=feedback_status
        )
    # Output order must match process_receipt's 5-tuple return value.
    process_btn.click(
        fn=process_receipt,
        inputs=[input_image],
        outputs=[agent_summary, ocr_image_output, ocr_text_output, results_json, feedback_summary]
    )
# Script entry point: start the queued Gradio server (at most 8 pending
# requests). share=True was kept for Spaces deployments where localhost is
# blocked, per the original deployment note.
if __name__ == "__main__":
    launch_options = {
        "share": True,
        "server_name": "0.0.0.0",  # listen on all interfaces
        "server_port": 7860,
        "show_error": True,
        # The API page stays enabled: the gradio_client json_schema helpers are
        # monkeypatched at the top of this file (_safe_get_type /
        # _safe_json_schema_to_python_type) to survive boolean schema entries.
        "show_api": True,
    }
    demo.queue(max_size=8).launch(**launch_options)