| """ | |
| Receipt Processing Pipeline - Hugging Face Spaces App | |
| Ensemble classification, OCR, field extraction, anomaly detection, and agentic routing. | |
| """ | |
| import os | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| import gradio as gr | |
| import gradio.routes as gr_routes | |
| import easyocr | |
| import json | |
| import re | |
| from PIL import Image, ImageDraw | |
| from datetime import datetime | |
| from torchvision import transforms, models | |
| from transformers import ( | |
| ViTForImageClassification, | |
| ViTImageProcessor, | |
| LayoutLMv3ForTokenClassification, | |
| LayoutLMv3Processor, | |
| ) | |
| from sklearn.ensemble import IsolationForest | |
| import warnings | |
| warnings.filterwarnings('ignore') | |
# ---------------------------------------------------------------------------
# Work around Gradio json_schema traversal crash:
# - guard bool schema entries
# ---------------------------------------------------------------------------
import gradio_client.utils as grc_utils

_orig_get_type = grc_utils.get_type
_orig_json_schema_to_python_type = grc_utils.json_schema_to_python_type

def _safe_get_type(schema):
    if isinstance(schema, bool):
        return "any"
    return _orig_get_type(schema)

def _safe_json_schema_to_python_type(schema, defs=None):
    if isinstance(schema, bool):
        return "any"
    try:
        return _orig_json_schema_to_python_type(schema, defs)
    except Exception:
        return "any"

grc_utils.get_type = _safe_get_type
grc_utils.json_schema_to_python_type = _safe_json_schema_to_python_type
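# Illustrative (not from a real run): JSON Schema allows bare booleans, e.g.
# {"additionalProperties": True}. Without the guard, gradio_client's get_type()
# expects a dict and crashes on such entries; with the patch in place both
# helpers simply report them as "any":
#     _safe_get_type(True)                     # -> "any"
#     _safe_json_schema_to_python_type(False)  # -> "any"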
# ---------------------------------------------------------------------------
# JSON sanitation helper (convert numpy types & PIL-friendly outputs)
# ---------------------------------------------------------------------------
def to_jsonable(obj):
    if isinstance(obj, dict):
        return {k: to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_jsonable(v) for v in obj]
    if isinstance(obj, (np.bool_, bool)):
        return bool(obj)
    if isinstance(obj, (np.integer,)):
        return int(obj)
    if isinstance(obj, (np.floating,)):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, Image.Image):
        return None  # avoid serializing images; skip in JSON
    return obj
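# Illustrative example (assumed values): numpy scalars are not JSON-serializable,
# so json.dumps({'conf': np.float32(0.9)}) raises TypeError, whereas
#     json.dumps(to_jsonable({'conf': np.float32(0.9), 'ok': np.bool_(True)}))
# succeeds, producing plain Python float/bool values.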
# ---------------------------------------------------------------------------
# Feedback persistence helper (CSV; optionally include section label)
# ---------------------------------------------------------------------------
def save_feedback(assessment, notes, results_json_str, section="overall"):
    try:
        parsed = json.loads(results_json_str) if results_json_str else {}
    except Exception:
        parsed = {"raw": results_json_str}
    entry = {
        "timestamp": datetime.utcnow().isoformat(),
        "section": section or "",
        "assessment": assessment or "",
        "notes": notes or "",
        "results": parsed,
    }
    import csv
    fieldnames = ["timestamp", "section", "assessment", "notes", "results"]
    file_exists = os.path.exists("feedback_logs.csv")
    with open("feedback_logs.csv", "a", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        if not file_exists:
            writer.writeheader()
        writer.writerow({
            "timestamp": entry["timestamp"],
            "section": entry.get("section", ""),
            "assessment": entry["assessment"],
            "notes": entry["notes"],
            "results": json.dumps(entry["results"]),
        })
    return "✅ Feedback saved. (Stored in feedback_logs.csv)"
# ============================================================================
# Configuration
# ============================================================================
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODELS_DIR = 'models'

print(f"Device: {DEVICE}")
print(f"Models directory: {MODELS_DIR}")

# ============================================================================
# Model Classes
# ============================================================================
class DocumentClassifier:
    """ViT-based document classifier (receipt vs other)."""

    def __init__(self, num_labels=2, model_path=None):
        self.num_labels = num_labels
        self.model = None
        self.processor = None
        self.model_path = model_path or os.path.join(MODELS_DIR, 'rvl_classifier.pt')
        self.pretrained = 'WinKawaks/vit-tiny-patch16-224'

    def load_model(self):
        try:
            self.processor = ViTImageProcessor.from_pretrained(self.pretrained)
        except Exception:
            # Fall back to the stock ViT processor if the tiny variant is unavailable
            self.processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
        self.model = ViTForImageClassification.from_pretrained(
            self.pretrained,
            num_labels=self.num_labels,
            ignore_mismatched_sizes=True
        )
        self.model = self.model.to(DEVICE)
        self.model.eval()
        return self.model

    def load_weights(self, path):
        if os.path.exists(path):
            checkpoint = torch.load(path, map_location=DEVICE)
            if isinstance(checkpoint, dict):
                if 'model_state_dict' in checkpoint:
                    self.model.load_state_dict(checkpoint['model_state_dict'], strict=False)
                elif 'state_dict' in checkpoint:
                    self.model.load_state_dict(checkpoint['state_dict'], strict=False)
                else:
                    self.model.load_state_dict(checkpoint, strict=False)
            else:
                self.model.load_state_dict(checkpoint, strict=False)
            print(f" Loaded ViT weights from {path}")

    def predict(self, image):
        if self.model is None:
            self.load_model()
        self.model.eval()
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        image = image.convert('RGB')
        inputs = self.processor(images=image, return_tensors="pt")
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs)
        probs = torch.softmax(outputs.logits, dim=-1)
        pred = torch.argmax(probs, dim=-1).item()
        conf = probs[0, pred].item()
        is_receipt = pred == 1
        label = "receipt" if is_receipt else "other"
        return {
            'is_receipt': is_receipt,
            'confidence': conf,
            'label': label,
            'probabilities': probs[0].cpu().numpy().tolist()
        }
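# Illustrative predict() output (values assumed, not from a real run):
#     {'is_receipt': True, 'confidence': 0.93, 'label': 'receipt',
#      'probabilities': [0.07, 0.93]}
# Index 1 of `probabilities` is treated as the "receipt" class throughout.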
class ResNetDocumentClassifier:
    """ResNet18-based document classifier."""

    def __init__(self, num_labels=2, model_path=None):
        self.num_labels = num_labels
        self.model = None
        self.model_path = model_path or os.path.join(MODELS_DIR, 'resnet18_rvlcdip.pt')
        self.use_class_mapping = False
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def load_model(self):
        self.model = models.resnet18(weights=None)
        self.model = self.model.to(DEVICE)
        self.model.eval()
        return self.model

    def load_weights(self, path):
        if not os.path.exists(path):
            return
        checkpoint = torch.load(path, map_location=DEVICE)
        if isinstance(checkpoint, dict):
            state_dict = checkpoint.get('model_state_dict', checkpoint.get('state_dict', checkpoint))
            id2label = checkpoint.get('id2label', None)
        else:
            state_dict = checkpoint
            id2label = None
        # Determine number of classes from checkpoint
        fc_weight_key = 'fc.weight'
        if fc_weight_key in state_dict:
            num_classes = state_dict[fc_weight_key].shape[0]
        else:
            num_classes = self.num_labels
        # Rebuild final layer if needed
        if num_classes != self.model.fc.out_features:
            self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
            self.model = self.model.to(DEVICE)
        self.model.load_state_dict(state_dict, strict=False)
        # Handle 16-class RVL-CDIP models
        if num_classes == 16:
            self.use_class_mapping = True
            self.receipt_class_idx = 11  # Receipt class in RVL-CDIP
        print(f" Loaded ResNet weights from {path} ({num_classes} classes)")

    def predict(self, image):
        if self.model is None:
            self.load_model()
        self.model.eval()
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        image = image.convert('RGB')
        input_tensor = self.transform(image).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            outputs = self.model(input_tensor)
        probs = torch.softmax(outputs, dim=-1)
        if self.use_class_mapping:
            receipt_prob = probs[0, self.receipt_class_idx].item()
            other_prob = 1.0 - receipt_prob
            is_receipt = receipt_prob > 0.5
            conf = receipt_prob if is_receipt else other_prob
            final_probs = [other_prob, receipt_prob]
        else:
            pred = torch.argmax(probs, dim=-1).item()
            conf = probs[0, pred].item()
            is_receipt = pred == 1
            final_probs = probs[0].cpu().numpy().tolist()
        return {
            'is_receipt': is_receipt,
            'confidence': conf,
            'label': "receipt" if is_receipt else "other",
            'probabilities': final_probs
        }
class EnsembleDocumentClassifier:
    """Ensemble of ViT and ResNet classifiers."""

    def __init__(self, model_configs=None, weights=None):
        self.model_configs = model_configs or [
            {'name': 'vit_base', 'path': os.path.join(MODELS_DIR, 'rvl_classifier.pt')},
            {'name': 'resnet18', 'path': os.path.join(MODELS_DIR, 'resnet18_rvlcdip.pt')},
        ]
        # Filter to existing models
        self.model_configs = [cfg for cfg in self.model_configs if os.path.exists(cfg['path'])]
        if not self.model_configs:
            print("Warning: No model files found, will use default ViT")
            self.model_configs = [{'name': 'vit_default', 'path': None}]
        self.weights = weights or [1.0 / len(self.model_configs)] * len(self.model_configs)
        self.classifiers = []
        self.processor = None

    def load_models(self):
        print(f"Loading ensemble with {len(self.model_configs)} models...")
        for cfg in self.model_configs:
            # `path` may be None (default ViT fallback), so guard before .lower()
            is_resnet = 'resnet' in cfg['name'].lower() or 'resnet' in (cfg.get('path') or '').lower()
            if is_resnet:
                classifier = ResNetDocumentClassifier(num_labels=2, model_path=cfg['path'])
            else:
                classifier = DocumentClassifier(num_labels=2, model_path=cfg['path'])
            classifier.load_model()
            if cfg['path'] and os.path.exists(cfg['path']):
                try:
                    classifier.load_weights(cfg['path'])
                except Exception as e:
                    print(f" Warning: Could not load {cfg['name']}: {e}")
            self.classifiers.append(classifier)
            if self.processor is None:
                if hasattr(classifier, 'processor'):
                    self.processor = classifier.processor
                elif hasattr(classifier, 'transform'):
                    self.processor = classifier.transform
        print(f"Ensemble ready with {len(self.classifiers)} models")
        return self

    def predict(self, image, return_individual=False):
        if not self.classifiers:
            self.load_models()
        all_probs = []
        individual_results = []
        for i, classifier in enumerate(self.classifiers):
            result = classifier.predict(image)
            probs = result.get('probabilities', [0.5, 0.5])
            if len(probs) < 2:
                probs = [1 - result['confidence'], result['confidence']]
            all_probs.append(probs)
            individual_results.append({
                'name': self.model_configs[i]['name'],
                'prediction': result['label'],
                'confidence': result['confidence'],
                'probabilities': probs
            })
        # Weighted average
        ensemble_probs = np.zeros(2)
        for i, probs in enumerate(all_probs):
            ensemble_probs += np.array(probs[:2]) * self.weights[i]
        pred = np.argmax(ensemble_probs)
        is_receipt = pred == 1
        conf = ensemble_probs[pred]
        result = {
            'is_receipt': is_receipt,
            'confidence': float(conf),
            'label': "receipt" if is_receipt else "other",
            'probabilities': ensemble_probs.tolist()
        }
        if return_individual:
            result['individual_results'] = individual_results
        return result
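# Worked example of the weighted average (assumed numbers): with two models
# weighted 0.5 each, per-model [other, receipt] probabilities [0.2, 0.8] and
# [0.4, 0.6] combine to [0.3, 0.7], so the ensemble predicts "receipt" with
# confidence 0.7.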
# ============================================================================
# OCR
# ============================================================================
class ReceiptOCR:
    """Enhanced OCR with EasyOCR + TrOCR + PaddleOCR + Tesseract ensemble."""

    def __init__(self):
        self.reader = None
        self.trocr_engine = None
        self.paddleocr_engine = None
        self.use_tesseract = False
        # Engine weights for ensemble
        self.engine_weights = {
            'trocr': 0.40,  # Highest weight - best quality
            'easyocr': 0.35,
            'paddleocr': 0.30,
            'tesseract': 0.20
        }
        # Try to initialize TrOCR
        try:
            from transformers import TrOCRProcessor, VisionEncoderDecoderModel
            self.trocr_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
            self.trocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
            self.trocr_model = self.trocr_model.to(DEVICE)
            self.trocr_model.eval()
            self.trocr_available = True
            print("TrOCR initialized")
        except Exception as e:
            self.trocr_available = False
            print(f"TrOCR not available: {e}")
        # Try to initialize PaddleOCR
        try:
            from paddleocr import PaddleOCR
            self.paddleocr_engine = PaddleOCR(use_angle_cls=True, lang='en', show_log=False)
            self.paddleocr_available = True
            print("PaddleOCR initialized")
        except Exception as e:
            self.paddleocr_available = False
            print(f"PaddleOCR not available: {e}")
        # Try to initialize Tesseract
        try:
            import pytesseract
            self.use_tesseract = True
        except ImportError:
            self.use_tesseract = False

    def load(self):
        if self.reader is None:
            print("Loading EasyOCR...")
            self.reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())
            print("EasyOCR ready")
        return self
    def _preprocess_image(self, image, method='enhance'):
        """Apply image preprocessing to improve OCR accuracy."""
        import cv2
        if isinstance(image, Image.Image):
            img_array = np.array(image)
        else:
            img_array = image.copy()
        if method == 'enhance':
            # Convert to grayscale if needed
            if len(img_array.shape) == 3:
                gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
            else:
                gray = img_array
            # Apply CLAHE (Contrast Limited Adaptive Histogram Equalization)
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
            enhanced = clahe.apply(gray)
            # Denoise
            denoised = cv2.fastNlMeansDenoising(enhanced, h=10)
            # Convert back to RGB for OCR engines
            return cv2.cvtColor(denoised, cv2.COLOR_GRAY2RGB)
        elif method == 'sharpen':
            # Sharpen the image
            kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])
            if len(img_array.shape) == 3:
                sharpened = cv2.filter2D(img_array, -1, kernel)
            else:
                gray = img_array
                sharpened = cv2.filter2D(gray, -1, kernel)
                sharpened = cv2.cvtColor(sharpened, cv2.COLOR_GRAY2RGB)
            return sharpened
        return img_array
    def _run_easyocr(self, image):
        """Run EasyOCR."""
        if self.reader is None:
            self.load()
        results = self.reader.readtext(image)
        extracted = []
        for bbox, text, conf in results:
            x_coords = [p[0] for p in bbox]
            y_coords = [p[1] for p in bbox]
            extracted.append({
                'text': text.strip(),
                'confidence': conf,
                'bbox': [min(x_coords), min(y_coords), max(x_coords), max(y_coords)],
                'engine': 'easyocr'
            })
        return extracted

    def _run_trocr(self, image, boxes):
        """Run TrOCR on detected text regions."""
        if not self.trocr_available:
            return []
        if isinstance(image, np.ndarray):
            pil_image = Image.fromarray(image).convert('RGB')
        else:
            pil_image = image.convert('RGB')
        results = []
        for box in boxes:
            try:
                if isinstance(box, list) and len(box) >= 4:
                    # Convert to [x1, y1, x2, y2]
                    if isinstance(box[0], list):
                        x1 = int(min(p[0] for p in box))
                        y1 = int(min(p[1] for p in box))
                        x2 = int(max(p[0] for p in box))
                        y2 = int(max(p[1] for p in box))
                    else:
                        x1, y1, x2, y2 = [int(b) for b in box[:4]]
                    # Crop and recognize
                    cropped = pil_image.crop((x1, y1, x2, y2))
                    # TrOCR recognition
                    pixel_values = self.trocr_processor(images=cropped, return_tensors="pt").pixel_values.to(DEVICE)
                    with torch.no_grad():
                        generated_ids = self.trocr_model.generate(
                            pixel_values,
                            max_length=128,
                            num_beams=4,
                            early_stopping=True
                        )
                    text = self.trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
                    if text.strip():
                        results.append({
                            'text': text.strip(),
                            'confidence': 0.9,  # TrOCR doesn't provide confidence, use high default
                            'bbox': [x1, y1, x2, y2],
                            'engine': 'trocr'
                        })
            except Exception:
                # Skip regions that fail to crop or decode
                continue
        return results
    def _run_paddleocr(self, image):
        """Run PaddleOCR."""
        if not self.paddleocr_available:
            return []
        try:
            result = self.paddleocr_engine.ocr(image, cls=True)
            if result is None or len(result) == 0 or result[0] is None:
                return []
            extracted = []
            for line in result[0]:
                if line is None:
                    continue
                bbox, (text, conf) = line
                x_coords = [p[0] for p in bbox]
                y_coords = [p[1] for p in bbox]
                extracted.append({
                    'text': text.strip(),
                    'confidence': conf,
                    'bbox': [min(x_coords), min(y_coords), max(x_coords), max(y_coords)],
                    'engine': 'paddleocr'
                })
            return extracted
        except Exception as e:
            print(f"PaddleOCR error: {e}")
            return []

    def _run_tesseract(self, image):
        """Run Tesseract OCR."""
        if not self.use_tesseract:
            return []
        try:
            import pytesseract
            if isinstance(image, Image.Image):
                pil_image = image.convert('RGB')
            else:
                pil_image = Image.fromarray(image).convert('RGB')
            data = pytesseract.image_to_data(pil_image, output_type=pytesseract.Output.DICT)
            results = []
            n_boxes = len(data['text'])
            for i in range(n_boxes):
                text = data['text'][i].strip()
                conf = int(data['conf'][i])
                if text and conf > 0:
                    x, y, w, h = data['left'][i], data['top'][i], data['width'][i], data['height'][i]
                    results.append({
                        'text': text,
                        'confidence': conf / 100.0,
                        'bbox': [x, y, x + w, y + h],
                        'engine': 'tesseract'
                    })
            return results
        except Exception as e:
            print(f"Tesseract OCR error: {e}")
            return []
    def _compute_iou(self, box1, box2):
        """Compute Intersection over Union for bounding boxes."""
        x1_1, y1_1, x2_1, y2_1 = box1
        x1_2, y1_2, x2_2, y2_2 = box2
        xi1 = max(x1_1, x1_2)
        yi1 = max(y1_1, y1_2)
        xi2 = min(x2_1, x2_2)
        yi2 = min(y2_1, y2_2)
        inter_area = max(0, xi2 - xi1) * max(0, yi2 - yi1)
        box1_area = (x2_1 - x1_1) * (y2_1 - y1_1)
        box2_area = (x2_2 - x1_2) * (y2_2 - y1_2)
        union_area = box1_area + box2_area - inter_area
        return inter_area / union_area if union_area > 0 else 0
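    # Worked IoU example (assumed boxes): box1=[0, 0, 10, 10] and
    # box2=[5, 5, 15, 15] overlap in a 5x5 patch, so
    # IoU = 25 / (100 + 100 - 25) ≈ 0.143 — below the 0.3 threshold used in
    # _merge_results, so the two would count as different text regions.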
    def _merge_results(self, all_results):
        """Merge results from multiple OCR engines using weighted voting."""
        if not all_results:
            return []
        # Use the engine with most detections as base
        base_engine = max(all_results.keys(), key=lambda k: len(all_results[k]))
        base_results = all_results[base_engine]
        merged = []
        for base_result in base_results:
            base_box = base_result['bbox']
            base_text = base_result['text']
            base_conf = base_result['confidence']
            # Find matching results from other engines
            matches = [(base_text, base_conf, self.engine_weights.get(base_engine, 0.3))]
            for engine_name, results in all_results.items():
                if engine_name == base_engine:
                    continue
                for result in results:
                    iou = self._compute_iou(base_box, result['bbox'])
                    if iou > 0.3:  # Same text region
                        weight = self.engine_weights.get(engine_name, 0.2)
                        matches.append((result['text'], result['confidence'], weight))
            # Vote on the best text
            if len(matches) == 1:
                final_text = base_text
                final_conf = base_conf
            else:
                # Weighted voting
                text_scores = {}
                for text, conf, weight in matches:
                    if text not in text_scores:
                        text_scores[text] = 0
                    text_scores[text] += conf * weight
                final_text = max(text_scores.keys(), key=lambda t: text_scores[t])
                total_weight = sum(w for _, _, w in matches)
                final_conf = min(0.99, text_scores[final_text] / total_weight if total_weight > 0 else 0.5)
            merged.append({
                'text': final_text,
                'confidence': final_conf,
                'bbox': base_box,
                'engines_used': len(matches)
            })
        return merged
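    # Illustrative vote (assumed numbers): if EasyOCR reads "TOTAL" at 0.9
    # (weight 0.35) and TrOCR reads "T0TAL" at 0.9 (weight 0.40) over the same
    # box, the scores are 0.315 vs 0.360, so "T0TAL" wins and its confidence
    # becomes min(0.99, 0.360 / 0.75) = 0.48.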
    def extract_with_positions(self, image, min_confidence=0.3, use_ensemble=False):
        """Extract text with positions using ensemble of OCR engines."""
        if isinstance(image, Image.Image):
            img_array = np.array(image)
        else:
            img_array = image.copy()
        all_results = {}
        # Run EasyOCR (always available)
        try:
            easyocr_results = self._run_easyocr(img_array)
            if easyocr_results:
                all_results['easyocr'] = easyocr_results
        except Exception as e:
            print(f"EasyOCR error: {e}")
        # Run PaddleOCR if available
        if self.paddleocr_available and use_ensemble:
            try:
                paddleocr_results = self._run_paddleocr(img_array)
                if paddleocr_results:
                    all_results['paddleocr'] = paddleocr_results
            except Exception as e:
                print(f"PaddleOCR error: {e}")
        # Run Tesseract if available
        if self.use_tesseract and use_ensemble:
            try:
                tesseract_results = self._run_tesseract(img_array)
                if tesseract_results:
                    all_results['tesseract'] = tesseract_results
            except Exception as e:
                print(f"Tesseract error: {e}")
        # Run TrOCR on detected boxes (needs boxes from other engines)
        if self.trocr_available and use_ensemble and all_results:
            try:
                # Get boxes from best available engine
                source_engine = max(all_results.keys(), key=lambda k: len(all_results[k]))
                boxes = [r['bbox'] for r in all_results[source_engine]]
                trocr_results = self._run_trocr(img_array, boxes)
                if trocr_results:
                    all_results['trocr'] = trocr_results
            except Exception as e:
                print(f"TrOCR error: {e}")
        # Merge results if ensemble, otherwise use EasyOCR only
        if use_ensemble and len(all_results) > 1:
            merged = self._merge_results(all_results)
        elif 'easyocr' in all_results:
            merged = all_results['easyocr']
        else:
            merged = []
        # Filter by confidence
        filtered = [r for r in merged if r['confidence'] >= min_confidence]
        # If results are poor, try with preprocessing
        avg_confidence = np.mean([r['confidence'] for r in filtered]) if filtered else 0
        if len(filtered) < 3 or avg_confidence < 0.4:
            try:
                preprocessed = self._preprocess_image(image, method='enhance')
                retry_results = self._run_easyocr(preprocessed)
                retry_filtered = [r for r in retry_results if r['confidence'] >= min_confidence]
                retry_avg = np.mean([r['confidence'] for r in retry_filtered]) if retry_filtered else 0
                if retry_avg > avg_confidence:
                    filtered = retry_filtered
            except Exception:
                pass
        # Sort by confidence (highest first)
        filtered.sort(key=lambda x: x['confidence'], reverse=True)
        return filtered
    def postprocess_receipt(self, ocr_results):
        """Extract structured fields from OCR results with improved patterns."""
        # Fix common OCR errors (S->$ in amounts)
        fixed_results = []
        for r in ocr_results:
            fixed_r = r.copy()
            fixed_r['text'] = self._fix_ocr_text(r['text'])
            fixed_results.append(fixed_r)
        # Join with newlines so _extract_total's per-line "TOTAL" search works
        full_text = '\n'.join([r['text'] for r in fixed_results])
        fields = {
            'vendor': self._extract_vendor(fixed_results),
            'date': self._extract_date(full_text),
            'total': self._extract_total(full_text),
            'time': self._extract_time(full_text)
        }
        return fields
    def _extract_vendor(self, ocr_results):
        """Extract vendor name - look for business name in top portion of receipt."""
        if not ocr_results:
            return None
        # Sort by vertical position (top of receipt first)
        sorted_results = sorted(
            ocr_results,
            key=lambda x: x['bbox'][1] if isinstance(x['bbox'], list) and len(x['bbox']) > 1 else 0
        )
        # Look in top 10 results for vendor name
        top_results = sorted_results[:10]
        # Skip words that are clearly not vendor names
        skip_words = {'TOTAL', 'DATE', 'TIME', 'RECEIPT', 'THANK', 'YOU', 'STORE', 'HOST',
                      'ORDER', 'TYPE', 'TOGO', 'DINE', 'IN', 'CHECK', 'CLOSED', 'AMEX',
                      'VISA', 'MASTERCARD', 'CASH', 'CHANGE', 'SUBTOTAL', 'TAX'}
        # Known vendor patterns (common stores)
        known_vendors = ['EINSTEIN', 'STARBUCKS', 'MCDONALDS', 'WALMART', 'TARGET',
                         'CHIPOTLE', 'PANERA', 'DUNKIN', 'SUBWAY', 'CHICK-FIL-A']
        # First, check if any known vendor is in the OCR results
        for result in top_results:
            text = result['text'].strip().upper()
            for vendor in known_vendors:
                if vendor in text:
                    return result['text'].strip()
        # Look for longest meaningful text (likely the business name)
        candidates = []
        for result in top_results:
            text = result['text'].strip()
            text_upper = text.upper()
            # Skip short texts, numbers, and common skip words
            if len(text) < 3:
                continue
            if text_upper in skip_words:
                continue
            if re.match(r'^[\d\s\-\/\.\$\,]+$', text):  # Skip pure numbers/symbols
                continue
            if re.match(r'^#?\d+$', text):  # Skip store numbers like #2846
                continue
            # Prefer texts with letters and reasonable length
            if len(text) >= 4 and any(c.isalpha() for c in text):
                candidates.append((text, len(text), result['confidence']))
        # Return the longest candidate with good confidence
        if candidates:
            # Sort by length (longer = more likely to be full vendor name)
            candidates.sort(key=lambda x: (x[1], x[2]), reverse=True)
            return candidates[0][0]
        return None
    def _extract_date(self, text):
        """Extract date with improved patterns."""
        patterns = [
            r'\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b',  # MM/DD/YYYY or MM-DD-YYYY
            r'\b\d{4}[/-]\d{2}[/-]\d{2}\b',  # YYYY-MM-DD
            r'\b(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)[a-z]*\s+\d{1,2},?\s+\d{4}\b',  # Month DD, YYYY
        ]
        for pattern in patterns:
            matches = re.findall(pattern, text, re.IGNORECASE)
            if matches:
                return matches[0]
        return None
    def _extract_total(self, text):
        """Extract total amount - handles S/$ OCR confusion."""
        # Fix S -> $ in amounts (common OCR error)
        fixed_text = re.sub(r'\bS(\d{1,3}(?:,\d{3})*(?:\.\d{2})?)\b', r'$\1', text)
        # Find all dollar amounts (now with fixed $ symbols)
        all_amounts = re.findall(r'[\$S](\d{1,3}(?:,\d{3})*(?:\.\d{2})?)', fixed_text)
        all_amounts = [float(a.replace(',', '')) for a in all_amounts if a]
        if not all_amounts:
            # Try finding any decimal amounts
            all_amounts = re.findall(r'(\d{1,3}(?:,\d{3})*\.\d{2})', fixed_text)
            all_amounts = [float(a.replace(',', '')) for a in all_amounts if a]
        if not all_amounts:
            return None
        # Look for "TOTAL", "AMOUNT DUE", "BALANCE" keywords and find amount near them
        lines = fixed_text.split('\n')
        for i, line in enumerate(lines):
            line_upper = line.upper()
            if any(keyword in line_upper for keyword in ['TOTAL', 'AMOUNT DUE', 'BALANCE DUE', 'DUE']):
                # Check this line and next 2 lines for amount
                search_text = ' '.join(lines[i:min(i + 3, len(lines))])
                # Match both $ and S followed by amounts
                matches = re.findall(r'[\$S](\d{1,3}(?:,\d{3})*(?:\.\d{2})?)', search_text)
                if matches:
                    amounts_near_total = [float(m.replace(',', '')) for m in matches]
                    return f"{max(amounts_near_total):.2f}"
        # Fallback: return largest amount overall
        return f"{max(all_amounts):.2f}"
    def _extract_time(self, text):
        """Extract time."""
        patterns = [
            r'\b(\d{1,2}):(\d{2})\s*(?:AM|PM)\b',
            r'\b(\d{1,2}):(\d{2})\b',
        ]
        for pattern in patterns:
            match = re.search(pattern, text, re.IGNORECASE)
            if match:
                return match.group(0)
        return None
    def _fix_ocr_text(self, text):
        """Fix common OCR errors like S->$ in amounts."""
        # Fix S followed by digits -> $ (e.g., S154.06 -> $154.06)
        text = re.sub(r'\bS(\d{1,3}(?:,\d{3})*(?:\.\d{2})?)\b', r'$\1', text)
        # Fix Subtolal -> Subtotal (common OCR error)
        text = re.sub(r'\bSubtolal\b', 'Subtotal', text, flags=re.IGNORECASE)
        return text
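# Illustrative behavior of the S->$ repair and keyword-guided total search
# (assumed strings, evaluated on a ReceiptOCR instance):
#     ocr._fix_ocr_text("Subtolal S12.49")          # -> "Subtotal $12.49"
#     ocr._extract_total("TAX $1.02\nTOTAL S13.51")  # -> "13.51" (near TOTAL)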
class LayoutLMFieldExtractor:
    """LayoutLMv3-based field extractor using fine-tuned weights if available."""

    def __init__(self, model_path=None):
        self.model_path = model_path or os.path.join(MODELS_DIR, 'layoutlm_extractor.pt')
        self.id2label = {
            0: 'O',
            1: 'B-VENDOR', 2: 'I-VENDOR',
            3: 'B-DATE', 4: 'I-DATE',
            5: 'B-TOTAL', 6: 'I-TOTAL',
            7: 'B-TIME', 8: 'I-TIME'
        }
        self.label2id = {v: k for k, v in self.id2label.items()}
        self.processor = None
        self.model = None

    def load(self):
        print("Loading LayoutLMv3 extractor...")
        # apply_ocr=False because we supply our own words/boxes; the default
        # processor runs its own OCR and rejects externally provided boxes
        self.processor = LayoutLMv3Processor.from_pretrained(
            "microsoft/layoutlmv3-base", apply_ocr=False
        )
        self.model = LayoutLMv3ForTokenClassification.from_pretrained(
            "microsoft/layoutlmv3-base",
            num_labels=len(self.id2label),
            id2label=self.id2label,
            label2id=self.label2id,
        )
        if os.path.exists(self.model_path):
            checkpoint = torch.load(self.model_path, map_location=DEVICE)
            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                checkpoint = checkpoint['model_state_dict']
            if isinstance(checkpoint, dict):
                missing, unexpected = self.model.load_state_dict(checkpoint, strict=False)
                print(f"Loaded LayoutLM weights; missing={len(missing)}, unexpected={len(unexpected)}")
        self.model = self.model.to(DEVICE)
        self.model.eval()
        print("LayoutLMv3 ready")
        return self
    def _prepare_boxes(self, ocr_results, image_size):
        """Convert absolute pixel boxes to LayoutLM 0-1000 format."""
        width, height = image_size
        boxes = []
        words = []
        for r in ocr_results:
            bbox = r.get("bbox", [0, 0, width, height])
            x0, y0, x1, y1 = bbox
            boxes.append([
                int(1000 * x0 / width),
                int(1000 * y0 / height),
                int(1000 * x1 / width),
                int(1000 * y1 / height),
            ])
            words.append(r.get("text", ""))
        return words, boxes
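    # Worked example of the 0-1000 normalization (assumed values): on a
    # 500x1000 px receipt, a word box [100, 50, 200, 80] becomes
    # [1000*100//500, 1000*50//1000, 1000*200//500, 1000*80//1000]
    # = [200, 50, 400, 80], the resolution-independent format LayoutLM expects.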
    def predict_fields(self, image, ocr_results=None):
        """Predict fields with confidence scores and improved total extraction."""
        if self.model is None:
            self.load()
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        image = image.convert("RGB")
        if ocr_results:
            words, boxes = self._prepare_boxes(ocr_results, image.size)
            encoding = self.processor(
                image,
                text=words,
                boxes=boxes,
                return_tensors="pt",
                truncation=True,
                padding="max_length",
                max_length=512,
            )
        else:
            # With apply_ocr=False the processor cannot generate its own
            # words/boxes, so there is nothing to tag without OCR input
            return {}
        encoding = {k: v.to(DEVICE) for k, v in encoding.items()}
        with torch.no_grad():
            outputs = self.model(**encoding)
        logits = outputs.logits[0]
        # Get softmax probabilities for confidence
        probs = torch.softmax(logits, dim=-1)
        preds = logits.argmax(-1).cpu().tolist()
        probs_np = probs.cpu().numpy()
        tokens = self.processor.tokenizer.convert_ids_to_tokens(encoding["input_ids"][0].cpu().tolist())
        # Extract entities with confidence scores
        entities = {"VENDOR": [], "DATE": [], "TOTAL": [], "TIME": []}
        entity_confidences = {"VENDOR": [], "DATE": [], "TOTAL": [], "TIME": []}
        entity_positions = {"VENDOR": [], "DATE": [], "TOTAL": [], "TIME": []}
        current = {"label": None, "tokens": [], "start_idx": None}
        for idx, (token, pred) in enumerate(zip(tokens, preds)):
            label = self.id2label.get(pred, "O")
            conf = float(probs_np[idx, pred])
            # LayoutLMv3 uses a RoBERTa-style tokenizer (<s>, </s>, <pad>), so
            # filter via the tokenizer's own special-token list
            if token in self.processor.tokenizer.all_special_tokens:
                continue
            if label.startswith("B-"):
                # Flush previous
                if current["label"] and current["tokens"]:
                    # Byte-level BPE marks word starts with "Ġ"
                    entity_text = "".join(current["tokens"]).replace("Ġ", " ").strip()
                    entities[current["label"]].append(entity_text)
                    entity_confidences[current["label"]].append(conf)
                    entity_positions[current["label"]].append(current["start_idx"])
                current = {"label": label[2:], "tokens": [token], "start_idx": idx}
            elif label.startswith("I-") and current["label"] == label[2:]:
                current["tokens"].append(token)
            else:
                if current["label"] and current["tokens"]:
                    entity_text = "".join(current["tokens"]).replace("Ġ", " ").strip()
                    entities[current["label"]].append(entity_text)
                    entity_confidences[current["label"]].append(conf)
                    entity_positions[current["label"]].append(current["start_idx"])
                current = {"label": None, "tokens": [], "start_idx": None}
        if current["label"] and current["tokens"]:
            entity_text = "".join(current["tokens"]).replace("Ġ", " ").strip()
            entities[current["label"]].append(entity_text)
            entity_confidences[current["label"]].append(conf)
            entity_positions[current["label"]].append(current["start_idx"])
        # Smart field selection with confidence and position awareness
        result = {}
        # Vendor: prefer the highest-confidence result
        if entities["VENDOR"]:
            best_vendor_idx = max(range(len(entities["VENDOR"])),
                                  key=lambda i: entity_confidences["VENDOR"][i])
            if entity_confidences["VENDOR"][best_vendor_idx] > 0.3:
                result["vendor"] = entities["VENDOR"][best_vendor_idx]
        # Date: prefer the highest-confidence result
        if entities["DATE"]:
            best_date_idx = max(range(len(entities["DATE"])),
                                key=lambda i: entity_confidences["DATE"][i])
            if entity_confidences["DATE"][best_date_idx] > 0.3:
                result["date"] = entities["DATE"][best_date_idx]
        # Time: prefer the highest-confidence result
        if entities["TIME"]:
            best_time_idx = max(range(len(entities["TIME"])),
                                key=lambda i: entity_confidences["TIME"][i])
            if entity_confidences["TIME"][best_time_idx] > 0.3:
                result["time"] = entities["TIME"][best_time_idx]
        # Total: improved extraction - look for amounts near "TOTAL" keyword in OCR
        if entities["TOTAL"]:
            # Get all total candidates with confidence
            total_candidates = [(entities["TOTAL"][i], entity_confidences["TOTAL"][i],
                                 entity_positions["TOTAL"][i])
                                for i in range(len(entities["TOTAL"]))]
            # If OCR results available, validate against OCR text
            if ocr_results:
                ocr_lines = [r['text'] for r in ocr_results]
                # Find amounts near "TOTAL" keyword
                best_total = None
                best_conf = 0
                for total_val, conf, pos in total_candidates:
                    # Clean the total value
                    total_clean = str(total_val).replace('$', '').replace(',', '').replace('.', '').strip()
                    # Check if this total appears near "TOTAL" keyword in OCR
                    for i, line in enumerate(ocr_lines):
                        line_upper = line.upper()
                        if 'TOTAL' in line_upper or 'AMOUNT DUE' in line_upper:
                            # Check this line and next 2 lines for the amount
                            search_text = ' '.join(ocr_lines[i:min(i + 3, len(ocr_lines))])
                            search_clean = search_text.replace('$', '').replace(',', '').replace('.', '')
                            if total_clean in search_clean:
                                # Found near TOTAL keyword - high confidence
                                if conf > best_conf:
                                    best_total = total_val
                                    best_conf = conf
                                break
                if best_total:
                    result["total"] = best_total
                else:
                    # Fallback: use highest confidence total
                    best_idx = max(range(len(total_candidates)), key=lambda i: total_candidates[i][1])
                    if total_candidates[best_idx][1] > 0.3:
                        result["total"] = total_candidates[best_idx][0]
            else:
                # No OCR, use highest confidence
                best_idx = max(range(len(total_candidates)), key=lambda i: total_candidates[i][1])
                if total_candidates[best_idx][1] > 0.3:
                    result["total"] = total_candidates[best_idx][0]
        return result
# ============================================================================
# Anomaly Detection
# ============================================================================
class AnomalyDetector:
    """Isolation Forest-based anomaly detection."""

    def __init__(self):
        self.model = IsolationForest(contamination=0.1, random_state=42)
        self.is_fitted = False

    def extract_features(self, fields):
        """Extract features from receipt fields."""
        total = 0
        try:
            total = float(fields.get('total', 0) or 0)
        except Exception:
            pass
        vendor = fields.get('vendor', '') or ''
        date = fields.get('date', '') or ''
        features = [
            total,
            np.log1p(total),
            len(vendor),
            1 if date else 0,
            1,  # num_items placeholder
            12,  # hour placeholder
            total,  # amount_per_item placeholder
            0  # is_weekend placeholder
        ]
        return np.array(features).reshape(1, -1)

    def predict(self, fields):
        features = self.extract_features(fields)
        # Simple rule-based detection if model not fitted
        reasons = []
        try:
            total = float(fields.get('total', 0) or 0)
        except Exception:
            # Guard against non-numeric OCR totals
            total = 0
        if total > 1000:
            reasons.append(f"High amount: ${total:.2f}")
        if not fields.get('vendor'):
            reasons.append("Missing vendor")
        if not fields.get('date'):
            reasons.append("Missing date")
        is_anomaly = len(reasons) > 0
        return {
            'is_anomaly': is_anomaly,
            'score': -0.5 if is_anomaly else 0.5,
            'prediction': 'ANOMALY' if is_anomaly else 'NORMAL',
            'reasons': reasons
        }
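# Illustrative rule-based outcome (assumed input): predict({'total': '1250.00',
# 'vendor': None, 'date': '01/02/2025'}) collects two reasons
# ("High amount: $1250.00", "Missing vendor") and returns
# {'is_anomaly': True, 'score': -0.5, 'prediction': 'ANOMALY', ...}.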
# ============================================================================
# Initialize Models
# ============================================================================
print("\n" + "=" * 50)
print("Initializing models...")
print("=" * 50)

# Check for model files
model_files = []
if os.path.exists(MODELS_DIR):
    model_files = [f for f in os.listdir(MODELS_DIR) if f.endswith('.pt')]
    print(f"Found model files: {model_files}")
else:
    print(f"Models directory not found: {MODELS_DIR}")
    os.makedirs(MODELS_DIR, exist_ok=True)

# Initialize components
try:
    ensemble_classifier = EnsembleDocumentClassifier()
    ensemble_classifier.load_models()
except Exception as e:
    print(f"Warning: Could not load ensemble classifier: {e}")
    ensemble_classifier = None

try:
    receipt_ocr = ReceiptOCR()
    receipt_ocr.load()
except Exception as e:
    print(f"Warning: Could not load OCR: {e}")
    receipt_ocr = None

try:
    layoutlm_extractor = LayoutLMFieldExtractor()
    layoutlm_extractor.load()
except Exception as e:
    print(f"Warning: Could not load LayoutLMv3 extractor: {e}")
    layoutlm_extractor = None

anomaly_detector = AnomalyDetector()

print("\n" + "=" * 50)
print("Initialization complete!")
print("=" * 50 + "\n")
# ============================================================================
# Helper Functions
# ============================================================================
def draw_ocr_boxes(image, ocr_results):
    """Draw bounding boxes on image."""
    img_copy = image.copy()
    draw = ImageDraw.Draw(img_copy)
    for r in ocr_results:
        conf = r.get('confidence', 0.5)
        bbox = r.get('bbox', [])
        if conf > 0.8:
            color = '#28a745'  # Green
        elif conf > 0.5:
            color = '#ffc107'  # Yellow
        else:
            color = '#dc3545'  # Red
        if len(bbox) >= 4:
            draw.rectangle([bbox[0], bbox[1], bbox[2], bbox[3]], outline=color, width=2)
    return img_copy
def process_receipt(image):
    """Main processing function for Gradio."""
    if image is None:
        return (
            "<div style='padding: 20px; text-align: center;'>Upload an image to begin</div>",
            None, "", "", "<div style='padding: 40px; text-align: center; color: #6c757d;'>Upload an image to begin</div>"
        )
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image = image.convert('RGB')
    results = {}
    # 1. Classification
    classifier_html = ""
    try:
        if ensemble_classifier:
            class_result = ensemble_classifier.predict(image, return_individual=True)
        else:
            class_result = {'is_receipt': True, 'confidence': 0.5, 'label': 'unknown'}
        conf = class_result['confidence']
        label = class_result['label'].upper()
        color = '#28a745' if class_result.get('is_receipt') else '#dc3545'
        bar_color = '#28a745' if conf > 0.8 else '#ffc107' if conf > 0.6 else '#dc3545'
        classifier_html = f"""
        <div style="padding: 16px; background: #111827; color: #e5e7eb; border-radius: 12px; margin: 8px 0; border: 1px solid #1f2937;">
            <h4 style="margin: 0 0 8px 0; color: #e5e7eb;">Classification</h4>
            <div style="font-size: 20px; font-weight: bold; color: {color};">{label}</div>
            <div style="margin-top: 8px; color: #e5e7eb;">
                <span>Confidence: </span>
                <div style="display: inline-block; width: 120px; height: 8px; background: #1f2937; border-radius: 4px;">
                    <div style="width: {conf*100}%; height: 100%; background: {bar_color}; border-radius: 4px;"></div>
                </div>
                <span style="margin-left: 8px;">{conf:.1%}</span>
            </div>
        </div>
        """
        results['classification'] = class_result
    except Exception as e:
        classifier_html = f"<div style='color: red;'>Classification error: {e}</div>"
    # 2. OCR
    ocr_text = ""
    ocr_image = None
    ocr_results = []
    try:
        if receipt_ocr:
            # Try fast OCR first (EasyOCR only)
            ocr_results = receipt_ocr.extract_with_positions(image, use_ensemble=False)
            # If confidence is low, try full ensemble
            if ocr_results:
                avg_conf = np.mean([r['confidence'] for r in ocr_results])
                if avg_conf < 0.5 or len(ocr_results) < 5:
                    # Low confidence or few results, try full ensemble
                    ocr_results = receipt_ocr.extract_with_positions(image, use_ensemble=True)
            ocr_image = draw_ocr_boxes(image, ocr_results)
            lines = [f"{i+1}. [{r['confidence']:.0%}] {r['text']}" for i, r in enumerate(ocr_results)]
            ocr_text = f"Detected {len(ocr_results)} text regions:\n\n" + "\n".join(lines)
            results['ocr'] = ocr_results
    except Exception as e:
        ocr_text = f"OCR error: {e}"
    # 3. Field Extraction (OCR-first, LayoutLM as fallback)
    fields = {}
    fields_html = ""
    try:
        # Try OCR regex first (faster and often more accurate for totals)
        ocr_fields = {}
        if receipt_ocr and ocr_results:
            ocr_fields = receipt_ocr.postprocess_receipt(ocr_results)
        fields = ocr_fields.copy()
        # Use LayoutLM only to fill in missing fields or validate
        if layoutlm_extractor and ocr_results:
            layoutlm_fields = layoutlm_extractor.predict_fields(image, ocr_results)
            # For each field, merge intelligently
            for field_name in ['vendor', 'date', 'total', 'time']:
                ocr_val = ocr_fields.get(field_name)
                layoutlm_val = layoutlm_fields.get(field_name)
                if not ocr_val and layoutlm_val:
                    # OCR didn't find it, use LayoutLM
                    fields[field_name] = layoutlm_val
                elif ocr_val and not layoutlm_val:
                    # LayoutLM didn't find it, but OCR did - use OCR
                    fields[field_name] = ocr_val
                elif ocr_val and layoutlm_val and field_name == 'total':
                    # For total: validate LayoutLM against OCR text
                    # (local name; `ocr_text` above holds the UI textbox output)
                    ocr_joined = ' '.join([r['text'] for r in ocr_results])
                    layoutlm_clean = str(layoutlm_val).replace('$', '').replace('.', '').replace(',', '').strip()
                    ocr_clean = ocr_joined.replace('$', '').replace('.', '').replace(',', '')
                    # Check if LayoutLM total appears in OCR text
                    if layoutlm_clean in ocr_clean:
                        # LayoutLM matches OCR, use it (might be more accurate)
                        fields['total'] = layoutlm_val
                    else:
                        # LayoutLM doesn't match OCR, trust OCR (more reliable)
                        fields['total'] = ocr_val
                elif ocr_val and layoutlm_val:
                    # For other fields, prefer LayoutLM if it's longer/more complete
                    if len(str(layoutlm_val)) > len(str(ocr_val)):
                        fields[field_name] = layoutlm_val
                    else:
                        fields[field_name] = ocr_val
        fields_html = "<div style='padding: 16px; background: #111827; color: #e5e7eb; border-radius: 12px; border: 1px solid #1f2937;'><h4 style=\"color: #e5e7eb;\">Extracted Fields</h4>"
        for name, value in [
            ('Vendor', fields.get('vendor')),
            ('Date', fields.get('date')),
            ('Total', f"${fields.get('total')}" if fields.get('total') else None),
        ]:
            display = value or '<span style="color: #9ca3af;">Not found</span>'
            fields_html += f"<div style='padding: 8px; background: #0f172a; color: #e5e7eb; border: 1px solid #1f2937; border-radius: 6px; margin: 4px 0;'><b>{name}:</b> {display}</div>"
        fields_html += "</div>"
        results['fields'] = fields
    except Exception as e:
        fields_html = f"<div style='color: red;'>Extraction error: {e}</div>"
    # 4. Anomaly Detection
    anomaly_html = ""
    try:
        anomaly_result = anomaly_detector.predict(fields)
        status_color = '#dc3545' if anomaly_result['is_anomaly'] else '#28a745'
        status_text = anomaly_result['prediction']
        anomaly_html = f"""
        <div style="padding: 16px; background: #111827; color: #e5e7eb; border-radius: 12px; margin: 8px 0; border: 1px solid #1f2937;">
            <h4 style="margin: 0 0 8px 0; color: #e5e7eb;">Anomaly Detection</h4>
            <div style="font-size: 18px; font-weight: bold; color: {status_color};">{status_text}</div>
        """
        if anomaly_result['reasons']:
            anomaly_html += "<ul style='margin: 8px 0; padding-left: 20px;'>"
            for reason in anomaly_result['reasons']:
                anomaly_html += f"<li>{reason}</li>"
            anomaly_html += "</ul>"
        anomaly_html += "</div>"
        results['anomaly'] = anomaly_result
    except Exception as e:
        anomaly_html = f"<div style='color: red;'>Anomaly detection error: {e}</div>"
    # 5. Final Decision
    is_receipt = results.get('classification', {}).get('is_receipt', True)
    is_anomaly = results.get('anomaly', {}).get('is_anomaly', False)
    conf = results.get('classification', {}).get('confidence', 0.5)
    if not is_receipt:
        decision = "REJECT"
        decision_color = "#dc3545"
        reason = "Not a receipt"
    elif is_anomaly:
        decision = "REVIEW"
        decision_color = "#ffc107"
        reason = "Anomaly detected"
    elif conf < 0.7:
        decision = "REVIEW"
        decision_color = "#ffc107"
        reason = "Low confidence"
    else:
        decision = "APPROVE"
        decision_color = "#28a745"
        reason = "All checks passed"
    summary_html = f"""
    <div style="padding: 24px; background: linear-gradient(135deg, {decision_color}22, {decision_color}11);
                border-left: 4px solid {decision_color}; border-radius: 12px; text-align: center;">
        <div style="font-size: 32px; font-weight: bold; color: {decision_color};">{decision}</div>
        <div style="color: #6c757d; margin-top: 8px;">{reason}</div>
    </div>
    {classifier_html}
    {anomaly_html}
    {fields_html}
    """
    safe_results = json.dumps(to_jsonable(results), indent=2)
    return summary_html, ocr_image, ocr_text, safe_results, summary_html
# ============================================================================
# Gradio Interface
# ============================================================================
CUSTOM_CSS = """
.gradio-container { max-width: 1200px !important; background: #0b0c0e; color: #e5e7eb; }
.main-header { text-align: center; padding: 20px; background: linear-gradient(135deg, #0f172a 0%, #1f2937 100%);
               border-radius: 12px; color: #e5e7eb; margin-bottom: 20px; border: 1px solid #1f2937; }
.gr-button { border-radius: 10px; background: #111827; color: #e5e7eb; border: 1px solid #1f2937; }
.gr-button-primary { background: #111827; border: 1px solid #22c55e; color: #e5e7eb; }
.gr-box { border: 1px solid #1f2937; background: #111827; color: #e5e7eb; }
.gradio-accordion { border: 1px solid #1f2937 !important; background: #0f172a !important; color: #e5e7eb !important; }
.gr-markdown { color: #e5e7eb; }
.gr-textbox textarea { background: #0f172a !important; color: #e5e7eb !important; border: 1px solid #1f2937 !important; }
.gr-radio { color: #e5e7eb !important; }
"""
with gr.Blocks(title="Receipt Processing Agent", theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
    gr.Markdown("""
    <div class="main-header">
        <h1>Receipt Processing Agent</h1>
        <p>Ensemble classification, OCR, field extraction, and anomaly detection</p>
        <p style="margin-top: 12px; font-size: 14px; color: #9ca3af;">Built by Emily, John, Luke, Michael and Raghu</p>
        <p style="margin-top: 8px; font-size: 14px;">
            <a href="https://github.com/RogueTex/StreamingDataforModelTraining#readme" target="_blank" style="color: #22c55e; text-decoration: none; border-bottom: 1px solid #22c55e;">Read more here →</a>
        </p>
    </div>
    """)
    gr.Markdown("""
    ### How It Works
    Upload a receipt image to automatically:
    - **Classify** document type with ViT + ResNet ensemble
    - **Extract text** with EasyOCR (with bounding boxes)
    - **Extract fields** (vendor, date, total) using regex patterns
    - **Detect anomalies** with rule-based checks
    - **Route** to APPROVE / REVIEW / REJECT

    ---
    """)
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Upload Receipt")
            input_image = gr.Image(type="pil", label="Receipt Image", height=350)
            process_btn = gr.Button("Process Receipt", variant="primary", size="lg")
        with gr.Column(scale=1):
            agent_summary = gr.HTML(
                label="Results",
                value="<div style='padding: 40px; text-align: center; color: #6c757d;'>Upload an image to begin</div>"
            )
    with gr.Accordion("OCR Results", open=False):
        with gr.Row():
            ocr_image_output = gr.Image(label="Detected Text Regions", height=300)
            ocr_text_output = gr.Textbox(label="Extracted Text", lines=12)
    with gr.Accordion("Raw Results (JSON)", open=False):
        results_json = gr.Textbox(label="Full Results", lines=15)
    # Per-section feedback controls
    with gr.Accordion("Classification Feedback", open=False):
        cls_assess = gr.Radio(choices=["Correct", "Incorrect"], label="Classification correct?", value=None)
        cls_notes = gr.Textbox(label="Notes (optional)", placeholder="What should be improved or fixed?", lines=2)
        cls_status = gr.Markdown(value="")
        cls_submit = gr.Button("Submit Classification Feedback", variant="primary")
        cls_submit.click(
            fn=save_feedback,
            inputs=[cls_assess, cls_notes, results_json, gr.State("classification")],
            outputs=cls_status
        )
    with gr.Accordion("OCR Feedback", open=False):
        ocr_assess = gr.Radio(choices=["Correct", "Incorrect"], label="OCR correct?", value=None)
        ocr_notes = gr.Textbox(label="Notes (optional)", placeholder="What should be improved or fixed?", lines=2)
        ocr_status = gr.Markdown(value="")
        ocr_submit = gr.Button("Submit OCR Feedback", variant="primary")
        ocr_submit.click(
            fn=save_feedback,
            inputs=[ocr_assess, ocr_notes, results_json, gr.State("ocr")],
            outputs=ocr_status
        )
    with gr.Accordion("Field Extraction Feedback", open=False):
        fld_assess = gr.Radio(choices=["Correct", "Incorrect"], label="Fields correct?", value=None)
        fld_notes = gr.Textbox(label="Notes (optional)", placeholder="What should be improved or fixed?", lines=2)
        fld_status = gr.Markdown(value="")
        fld_submit = gr.Button("Submit Fields Feedback", variant="primary")
        fld_submit.click(
            fn=save_feedback,
            inputs=[fld_assess, fld_notes, results_json, gr.State("fields")],
            outputs=fld_status
        )
    with gr.Accordion("Anomaly Feedback", open=False):
        an_assess = gr.Radio(choices=["Correct", "Incorrect"], label="Anomaly correct?", value=None)
        an_notes = gr.Textbox(label="Notes (optional)", placeholder="What should be improved or fixed?", lines=2)
        an_status = gr.Markdown(value="")
        an_submit = gr.Button("Submit Anomaly Feedback", variant="primary")
        an_submit.click(
            fn=save_feedback,
            inputs=[an_assess, an_notes, results_json, gr.State("anomaly")],
            outputs=an_status
        )
    with gr.Accordion("Feedback", open=True):
        gr.Markdown("Review the agent output below and submit a quick assessment.")
        feedback_summary = gr.HTML(label="Last Agent Response (read-only)")
        with gr.Row():
            feedback_assessment = gr.Radio(
                choices=["Correct", "Incorrect"],
                label="Is the response correct?",
                value=None
            )
            feedback_notes = gr.Textbox(
                label="Notes (optional)",
                placeholder="What should be improved or fixed?",
                lines=3
            )
        feedback_status = gr.Markdown(value="")
        submit_feedback = gr.Button("Submit Feedback", variant="primary")
        submit_feedback.click(
            fn=save_feedback,
            inputs=[feedback_assessment, feedback_notes, results_json, gr.State("overall")],
            outputs=feedback_status
        )
    process_btn.click(
        fn=process_receipt,
        inputs=[input_image],
        outputs=[agent_summary, ocr_image_output, ocr_text_output, results_json, feedback_summary]
    )

# Launch (Spaces needs share=True when localhost is blocked)
if __name__ == "__main__":
    demo.queue(max_size=8).launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        # keep API enabled; json_schema traversal is guarded by the gradio_client
        # monkeypatch above (_safe_get_type / _safe_json_schema_to_python_type)
        show_api=True,
    )