# ocr-llm-test / app.py — Hugging Face Space by winamnd
# Commit 5d70ee9 ("Update app.py"), 3.96 kB.
import functools
import json
import os

import gradio as gr
import torch
import cv2
import numpy as np
import easyocr
import keras_ocr
from paddleocr import PaddleOCR
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch.nn.functional as F

from save_results import save_results_to_repo
# Paths
MODEL_PATH = "./distilbert_spam_model"
RESULTS_JSON = "ocr_results.json"

# Weight filenames that `save_pretrained` can produce. Recent transformers
# releases write safetensors by default; older ones write a PyTorch pickle.
# Checking only for "pytorch_model.bin" misses a safetensors checkpoint. The
# local model is then re-initialized from the base checkpoint and overwritten
# on every start.
MODEL_WEIGHT_FILES = (
    "model.safetensors",
    "model.safetensors.index.json",  # sharded safetensors checkpoint
    "pytorch_model.bin",
    "pytorch_model.bin.index.json",  # sharded PyTorch checkpoint
)


def _local_model_available(model_dir):
    """Return True if *model_dir* contains any known model weight file."""
    return any(
        os.path.exists(os.path.join(model_dir, weight_file))
        for weight_file in MODEL_WEIGHT_FILES
    )


# Ensure model exists
if not _local_model_available(MODEL_PATH):
    print(f"⚠️ Model not found in {MODEL_PATH}. Downloading from Hugging Face Hub...")
    # NOTE(review): this starts from the generic base checkpoint, so the
    # 2-label classification head is freshly initialized, not trained for
    # spam. Predictions are not meaningful until a fine-tuned model is placed
    # in MODEL_PATH.
    model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
    model.save_pretrained(MODEL_PATH)
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    tokenizer.save_pretrained(MODEL_PATH)
    print(f"✅ Model saved at {MODEL_PATH}.")
else:
    model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
    tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)

# Set the model to evaluation mode to disable dropout layers
model.eval()
# Load OCR Methods
def ocr_with_paddle(img):
ocr = PaddleOCR(lang='en', use_angle_cls=True)
result = ocr.ocr(img)
return ' '.join([item[1][0] for item in result[0]])
def ocr_with_keras(img):
pipeline = keras_ocr.pipeline.Pipeline()
images = [keras_ocr.tools.read(img)]
predictions = pipeline.recognize(images)
return ' '.join([text for text, _ in predictions[0]])
@functools.lru_cache(maxsize=1)
def _get_easyocr_reader():
    """Create the English easyocr Reader once and reuse it for every request."""
    return easyocr.Reader(['en'])


def _to_grayscale(img):
    """Return a single-channel version of the numpy image *img*.

    The caller builds *img* with ``np.array`` from the Gradio image input. That
    produces RGB(A) channel order, not OpenCV's BGR, so the RGB conversion
    codes are used. Already-grayscale inputs (2-D or one channel) are passed
    through instead of crashing ``cv2.cvtColor``.
    """
    if img.ndim == 2:
        return img
    channels = img.shape[2]
    if channels == 1:
        return img[:, :, 0]
    if channels == 4:
        return cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
    return cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)


def ocr_with_easy(img):
    """Extract text from *img* with easyocr.

    Returns the recognized strings joined by spaces.
    """
    gray_image = _to_grayscale(img)
    # detail=0 makes easyocr return only the recognized strings.
    results = _get_easyocr_reader().readtext(gray_image, detail=0)
    return ' '.join(results)
# OCR Function
def generate_ocr(method, img):
    """Run the selected OCR engine on *img* and classify the text as spam or not.

    Args:
        method: "PaddleOCR" or "EasyOCR" select those engines; any other value
            falls back to KerasOCR.
        img: The uploaded image. It is wrapped with ``np.array``, which does
            not reorder the color channels.

    Returns:
        ``(extracted_text, label)``, where label is "Spam" or "Not Spam".
        If the OCR output is empty after stripping, returns
        ``("No text detected!", "Cannot classify")`` without classifying.

    Raises:
        gr.Error: If no image was uploaded.
    """
    if img is None:
        raise gr.Error("Please upload an image!")

    pixels = np.array(img)

    engines = {
        "PaddleOCR": ocr_with_paddle,
        "EasyOCR": ocr_with_easy,
    }
    run_ocr = engines.get(method, ocr_with_keras)

    # Only surrounding whitespace is removed here; the tokenizer below does
    # any length truncation.
    extracted = run_ocr(pixels).strip()
    if not extracted:
        return "No text detected!", "Cannot classify"

    encoded = tokenizer(extracted, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        logits = model(**encoded).logits
    probabilities = F.softmax(logits, dim=1)
    predicted_class = torch.argmax(probabilities, dim=1).item()
    label = {0: "Not Spam", 1: "Spam"}[predicted_class]

    # Persist through the external helper before handing results to the UI.
    save_results_to_repo(extracted, label)
    return extracted, label
# Save results to JSON file
def save_to_json(text, label):
    """Write one OCR result and its classification to RESULTS_JSON.

    The file is opened in "w" mode, so each call replaces the previous
    contents with only the latest result.

    Args:
        text: The extracted text, stored under "extracted_text".
        label: The classification, stored under "classification".

    Returns:
        A status message for display.
    """
    payload = dict(extracted_text=text, classification=label)
    with open(RESULTS_JSON, "w") as out_file:
        json.dump(payload, out_file, indent=4)
    return "Results saved to JSON file!"
# Gradio Interface
# A single Blocks app, so the "Save to JSON" button reads the values the OCR
# step just displayed. Previously the save step lived in a second Interface.
# That Interface was never served, because demo.launch() blocks when run as a
# script. It also reused demo's output components as its inputs, so it could
# not see the OCR results anyway.
with gr.Blocks(title="OCR Spam Classifier") as demo:
    gr.Markdown(
        "# OCR Spam Classifier\n"
        "Upload an image, extract text, and classify it as Spam or Not Spam."
    )
    with gr.Row():
        with gr.Column():
            method_input = gr.Radio(
                ["PaddleOCR", "EasyOCR", "KerasOCR"],
                value="PaddleOCR",
                label="OCR Method",
            )
            image_input = gr.Image()
            run_button = gr.Button("Extract & Classify")
        with gr.Column():
            output_text = gr.Textbox(label="Extracted Text")
            output_label = gr.Textbox(label="Spam Classification")
            save_button = gr.Button("Save to JSON")
            save_output = gr.Textbox(label="Save Status")

    run_button.click(
        fn=generate_ocr,
        inputs=[method_input, image_input],
        outputs=[output_text, output_label],
    )
    # Save whatever is currently shown in the two output boxes.
    save_button.click(
        fn=save_to_json,
        inputs=[output_text, output_label],
        outputs=[save_output],
    )

# Launch a single server; guarded so importing this module does not start one.
if __name__ == "__main__":
    demo.launch()