File size: 3,589 Bytes
53bee95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d0d58d
53bee95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import gradio as gr
import cv2
import numpy as np
from PIL import Image
import pickle
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import img_to_array
import easyocr

# === Load Model and Label Encoder ===
model_path = "MobileNetBest_Model.h5"
label_path = "MobileNet_Label_Encoder.pkl"

# Loaded once at import time. This call is deliberately not wrapped in a try:
# without the classifier nothing below can work, so a failure here should stop
# the app immediately.
model = load_model(model_path)
print("Model loaded.")

# Load label encoder.
# The pickle is expected to hold a {label: index} mapping; it is inverted so a
# predicted index can be turned back into a label name.
# NOTE(security): pickle.load can execute arbitrary code contained in the file,
# so the label file must come from a trusted source.
try:
    with open(label_path, 'rb') as f:
        label_map = pickle.load(f)
    index_to_label = {v: k for k, v in label_map.items()}
    print("Label encoder loaded:", index_to_label)
except FileNotFoundError:
    index_to_label = {0: "Handwritten", 1: "Computerized"}
    print("Label encoder not found. Using default:", index_to_label)
except Exception as exc:
    # Best-effort fallback for a corrupt file or an object that is not a
    # mapping. Unlike a bare `except:`, this lets KeyboardInterrupt/SystemExit
    # propagate and reports the real reason instead of claiming "not found".
    index_to_label = {0: "Handwritten", 1: "Computerized"}
    print(f"Label encoder could not be loaded ({exc!r}). Using default:", index_to_label)

# === Initialize EasyOCR Reader Once (with GPU) ===
# Built once at module level so every request reuses the same reader.
# NOTE(review): gpu=True only requests a GPU; confirm how the installed EasyOCR
# behaves when none is available before trusting the message printed below.
reader = easyocr.Reader(['en'], gpu=True)
print("EasyOCR Reader initialized with GPU.")

# === Classify Region ===
def classify_text_region(region_img):
    """Classify one cropped text region with the module-level ``model``.

    The crop is resized to 224x224, converted to float32, divided by 255 and
    fed to the model as a batch of one.

    Args:
        region_img: Image crop as an array accepted by ``cv2.resize``.

    Returns:
        A label string. For a single-output model the score is thresholded at
        0.5 ("Computerized" above, "Handwritten" otherwise); for a multi-output
        model the arg-max index is looked up in ``index_to_label``. "Unknown"
        is returned when the index has no label or any step raises.
    """
    try:
        resized = cv2.resize(region_img, (224, 224))
        scaled = resized.astype("float32") / 255.0
        batch = np.expand_dims(img_to_array(scaled), axis=0)

        scores = model.predict(batch)

        # Multi-output head: translate the winning index through the encoder.
        if scores.shape[-1] != 1:
            return index_to_label.get(np.argmax(scores[0]), "Unknown")

        # Single-output head: fixed label names, independent of index_to_label.
        # NOTE(review): confirm these names agree with the saved label encoder.
        if scores[0][0] > 0.5:
            return "Computerized"
        return "Handwritten"
    except Exception as err:
        # Deliberate best-effort: one bad crop must not abort the whole image.
        print("Classification error:", err)
        return "Unknown"

# === OCR + Annotation ===
def AnnotatedTextDetection_EasyOCR_from_array(img):
    """Detect text in an image, classify each region, and draw the results.

    Args:
        img: Image as a NumPy array. The result is produced with
            ``cv2.COLOR_BGR2RGB``, so the array is treated as BGR. The input
            array is not modified; drawing happens on a copy.

    Returns:
        A tuple ``(annotated_image, text)`` where ``annotated_image`` is the
        annotated copy converted from BGR to RGB and ``text`` contains one
        ``"<detected text> -> <label>"`` line per kept detection.
    """
    results = reader.readtext(img)
    height, width = img.shape[:2]

    # Draw on a copy so that crops handed to the classifier never contain
    # rectangles or captions painted for earlier (possibly overlapping) boxes.
    canvas = img.copy()
    annotated_results = []

    for bbox, text, conf in results[:50]:  # Only the first 50 detections are processed
        cleaned = text.strip()
        if conf < 0.3 or not cleaned:
            continue

        # Use the axis-aligned rectangle enclosing every corner point, clamped
        # to the image. Relying on corners 0 and 2 alone breaks for rotated
        # boxes, and negative coordinates would wrap around when slicing.
        xs = [int(point[0]) for point in bbox]
        ys = [int(point[1]) for point in bbox]
        x1, x2 = max(min(xs), 0), min(max(xs), width)
        y1, y2 = max(min(ys), 0), min(max(ys), height)
        if x2 <= x1 or y2 <= y1:
            continue

        label = classify_text_region(img[y1:y2, x1:x2])
        annotated_results.append(f"{cleaned} -> {label}")

        color = (0, 255, 0) if label == "Computerized" else (255, 0, 0)
        cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 2)
        # Keep the caption baseline inside the image for boxes near the top edge.
        caption_origin = (x1, max(y1 - 10, 15))
        cv2.putText(canvas, label, caption_origin, cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 1)

    return cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB), "\n".join(annotated_results)

# === Gradio Wrapper ===
def infer(image):
    """Run text detection and classification on an uploaded image.

    Args:
        image: The uploaded picture as a PIL image (the interface input uses
            ``type="pil"``); any array accepted by ``Image.fromarray`` also
            works. ``None`` when nothing was uploaded.

    Returns:
        A tuple ``(annotated_image, result_text)`` with the annotated picture
        as a PIL image and the per-region results as text. When ``image`` is
        ``None`` the tuple is ``(None, "No image provided.")``.
    """
    if image is None:
        # Submitting without an upload used to fail with an IndexError.
        return None, "No image provided."

    pil_image = image if isinstance(image, Image.Image) else Image.fromarray(np.asarray(image))
    # Normalise grayscale/palette/RGBA input to three channels, then switch to
    # BGR: AnnotatedTextDetection_EasyOCR_from_array finishes with
    # COLOR_BGR2RGB, so handing it RGB would swap red and blue in the output.
    img = cv2.cvtColor(np.array(pil_image.convert("RGB")), cv2.COLOR_RGB2BGR)

    # Resize if image is too large
    max_dim = 1000
    height, width = img.shape[:2]
    longest_side = max(height, width)
    if longest_side > max_dim:
        scale = max_dim / longest_side
        # max(1, ...) keeps very elongated images from collapsing to a
        # zero-sized dimension, which cv2.resize rejects.
        new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
        img = cv2.resize(img, new_size)

    annotated_img, result_text = AnnotatedTextDetection_EasyOCR_from_array(img)
    return Image.fromarray(annotated_img), result_text

# === Custom CSS ===
# Stylesheet handed to gr.Interface through its `css` argument: a light-blue page
# background, a rounded blue-bordered main container, and blue-bordered rounded
# input/output panels.
# NOTE(review): the `.gr-input` / `.gr-output` selectors depend on the class
# names emitted by the installed Gradio version - confirm they still match.
custom_css = """
body {
    background-color: #e6f2ff;
}
.gradio-container {
    border-radius: 12px;
    padding: 20px;
    border: 2px solid #007acc;
}
.gr-input, .gr-output {
    border: 1px solid #007acc;
    border-radius: 10px;
}
"""

# === Launch Interface ===
# Components are created first (input, then the two outputs) and then wired to
# `infer`; the interface is started as soon as it has been built.
image_input = gr.Image(type="pil", label="Upload Image")
result_outputs = [
    gr.Image(type="pil", label="Annotated Image"),
    gr.Textbox(label="Detected Text and Classification"),
]
app_description = "This application detects text using EasyOCR and classifies each text region as Handwritten or Computerized using a MobileNet model."

demo = gr.Interface(
    fn=infer,
    inputs=image_input,
    outputs=result_outputs,
    title="Text Detection and Classification",
    description=app_description,
    theme="soft",
    css=custom_css,
)
demo.launch()