# Hugging Face Space upload by PraneshJs
# Commit fff6782 (verified): "Added files to hf space"
import gradio as gr
import numpy as np
import cv2
from PIL import Image
import tensorflow as tf
# Relative path of the saved Keras model (resolved against the working directory)
MODEL_PATH = 'mnist_model.h5'

# Restore the trained network once at import time; predictions reuse this object
model = tf.keras.models.load_model(MODEL_PATH)
def cnn_predict_digit(image):
    """Classify a hand-drawn digit image with the loaded CNN.

    Args:
        image: A NumPy image array, or the dictionary emitted by Gradio's
            Sketchpad, in which case its ``'composite'`` entry is used.
            ``None`` (nothing drawn) is accepted.

    Returns:
        The index of the highest-scoring model output as a string, or
        ``None`` when there is no image to classify.
    """
    # Handle Gradio Sketchpad dictionary input
    if isinstance(image, dict):
        image = image.get('composite')

    # An empty canvas arrives as None; bail out instead of failing on `.ndim`.
    if image is None:
        return None

    image = np.asarray(image)

    # Collapse colour channels to a single grayscale plane
    if image.ndim == 3:
        channels = image.shape[2]
        if channels == 4:
            # Without this branch an RGBA array reaches reshape() with four
            # times too many values. NOTE(review): alpha is discarded here --
            # confirm that suits the sketchpad's background handling.
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
        elif channels == 3:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    # Invert colors (white background → black background)
    image = 255 - image

    # Resize to 28x28
    image = cv2.resize(image, (28, 28))

    # Normalize to [0, 1] and reshape to a single-sample, single-channel batch
    image = image.astype('float32') / 255.0
    image = image.reshape(1, 28, 28, 1)

    # Predict and pick the class with the highest score
    prediction = model.predict(image)
    pred_label = np.argmax(prediction, axis=1)[0]
    return str(pred_label)
# Build the web UI: a sketchpad for drawing, a label for the prediction,
# and buttons to run the classifier or reset both components.
with gr.Blocks() as interface:
    gr.Markdown(
        value="""
## ✍️ Digit Classification using Convolutional Neural Network
Draw a digit in the sketchpad below (0 to 9), then click **Submit** to see the prediction.
"""
    )

    # Drawing area and prediction are placed in the same row
    with gr.Row():
        sketchpad = gr.Sketchpad(image_mode='L')
        output = gr.Label()

    # Submit sends the current drawing to the classifier and shows the result
    submit_button = gr.Button("Submit")
    submit_button.click(fn=cnn_predict_digit, inputs=sketchpad, outputs=output)

    # Clear resets both the drawing and the displayed prediction
    gr.ClearButton(components=[sketchpad, output])

# Start the Gradio server
interface.launch()