import os

import cv2
import easyocr
import gradio as gr
import keras_ocr
import numpy as np
import pytesseract
import torch
import torch.nn.functional as F
from paddleocr import PaddleOCR
from PIL import Image
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer

# Import save function
from save_results import save_results_to_repo

# Paths
MODEL_PATH = "./distilbert_spam_model"

# Ensure the classifier exists locally. Newer versions of transformers save
# weights as model.safetensors instead of pytorch_model.bin, so check for both.
WEIGHT_FILES = ("pytorch_model.bin", "model.safetensors")
if not any(os.path.exists(os.path.join(MODEL_PATH, f)) for f in WEIGHT_FILES):
    print(f"⚠️ Model not found in {MODEL_PATH}. Downloading from Hugging Face Hub...")
    # Fall back to the base checkpoint. Its classification head is randomly
    # initialized, so predictions are unreliable until the model is fine-tuned.
    model = DistilBertForSequenceClassification.from_pretrained(
        "distilbert-base-uncased", num_labels=2
    )
    model.save_pretrained(MODEL_PATH)
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    tokenizer.save_pretrained(MODEL_PATH)
    print(f"✅ Model saved at {MODEL_PATH}.")
else:
    model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
    tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)

# Put the model in evaluation mode (disables dropout).
model.eval()


# Function to prepare an image for OCR
def preprocess_image(image: Image.Image) -> np.ndarray:
    """Convert a PIL image to an OpenCV-compatible NumPy array (RGB)."""
    return np.array(image)


# OCR functions (same as ocr-api)
def ocr_with_paddle(img):
    ocr = PaddleOCR(lang='en', use_angle_cls=True)
    result = ocr.ocr(img)
    extracted_text, confidences = [], []
    # PaddleOCR returns None for a page when no text is detected.
    if not result or result[0] is None:
        return extracted_text, confidences
    for line in result[0]:
        text, confidence = line[1]
        extracted_text.append(text)
        confidences.append(confidence)
    return extracted_text, confidences


def ocr_with_keras(img):
    pipeline = keras_ocr.pipeline.Pipeline()
    images = [keras_ocr.tools.read(img)]
    predictions = pipeline.recognize(images)
    # keras-ocr yields (word, bounding_box) pairs; it does not report
    # confidence scores, so use placeholders.
    extracted_text = [word for word, box in predictions[0]]
    confidences = [1.0] * len(extracted_text)
    return extracted_text, confidences


def ocr_with_easy(img):
    # The array comes from a PIL image, so its channel order is RGB, not BGR.
    gray_image = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    reader = easyocr.Reader(['en'])
    results = reader.readtext(gray_image)
    extracted_text = [text for _, text, confidence in results]
    confidences = [confidence for _, text, confidence in results]
    return extracted_text, confidences


def ocr_with_tesseract(img):
    gray_image = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    extracted_text = pytesseract.image_to_string(gray_image).split("\n")
    extracted_text = [line.strip() for line in extracted_text if line.strip()]
    # image_to_string doesn't report confidence scores, so use placeholders.
    confidences = [1.0] * len(extracted_text)
    return extracted_text, confidences


# OCR & classification function
def generate_ocr(method, image):
    if image is None:
        raise gr.Error("Please upload an image!")

    # Convert the PIL image to an OpenCV-compatible array.
    img_cv = preprocess_image(image)

    # Select the OCR method.
    if method == "PaddleOCR":
        extracted_text, confidences = ocr_with_paddle(img_cv)
    elif method == "EasyOCR":
        extracted_text, confidences = ocr_with_easy(img_cv)
    elif method == "KerasOCR":
        extracted_text, confidences = ocr_with_keras(img_cv)
    elif method == "TesseractOCR":
        extracted_text, confidences = ocr_with_tesseract(img_cv)
    else:
        return "Invalid OCR method", "N/A"

    # Join the extracted lines into a single string.
    text_output = " ".join(extracted_text).strip()

    # If no text was detected, return early.
    if len(text_output) == 0:
        return "No text detected!", "Cannot classify"

    # Tokenize the text for classification.
    inputs = tokenizer(
        text_output, return_tensors="pt", truncation=True, padding=True, max_length=512
    )

    # Perform inference.
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
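    # NOTE: the two logits are unnormalized scores for the classes
    # (index 0 = Not Spam, index 1 = Spam). The softmax below rescales
    # them into probabilities that sum to 1; e.g. logits [1.2, 3.4]
    # become probabilities of roughly [0.10, 0.90].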
print(f"Raw logits: {logits}") # Convert logits to probabilities using softmax probs = F.softmax(logits, dim=1) # Extract probability values not_spam_prob = probs[0, 0].item() spam_prob = probs[0, 1].item() # Print probability values for debugging print(f"Not Spam Probability: {not_spam_prob}, Spam Probability: {spam_prob}") # Ensure correct label mapping predicted_class = torch.argmax(probs, dim=1).item() # Get predicted class index print(f"Predicted Class Index: {predicted_class}") # Debugging output # Check if the labels are flipped if predicted_class == 1: label = "Spam" else: label = "Not Spam" # Save results save_results_to_repo(text_output, label) return text_output, label # Gradio Interface image_input = gr.Image() method_input = gr.Radio(["PaddleOCR", "EasyOCR", "KerasOCR", "TesseractOCR"], value="PaddleOCR") output_text = gr.Textbox(label="Extracted Text") output_label = gr.Textbox(label="Spam Classification") demo = gr.Interface( generate_ocr, inputs=[method_input, image_input], outputs=[output_text, output_label], title="OCR Spam Classifier", description="Upload an image, extract text using OCR, and classify it as Spam or Not Spam.", theme="compact", ) # Launch App if __name__ == "__main__": demo.launch()