import gradio as gr
import numpy as np
import cv2
from PIL import Image
import tensorflow as tf

# Load the trained model once at import time so every request reuses it.
model = tf.keras.models.load_model('mnist_model.h5')


def _to_grayscale(image):
    """Return a single-channel (H, W) version of *image*.

    2-D input is returned unchanged. (H, W, 1) input is squeezed to (H, W).
    3-channel input is treated as RGB. 4-channel input is treated as RGBA and
    first composited onto a white background, so transparent pixels read as
    white, like the white background the inversion step in
    ``cnn_predict_digit`` expects.
    """
    # NOTE(review): assumes 8-bit channel values in [0, 255], the same range the
    # ``255 - image`` inversion in cnn_predict_digit relies on.
    if image.ndim == 3 and image.shape[2] == 4:
        alpha = image[..., 3:4].astype('float32') / 255.0
        rgb = image[..., :3].astype('float32') * alpha + 255.0 * (1.0 - alpha)
        image = rgb.astype('uint8')
    if image.ndim == 3 and image.shape[2] == 3:
        image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    elif image.ndim == 3 and image.shape[2] == 1:
        image = image[..., 0]
    return image


def cnn_predict_digit(image):
    """Classify a hand-drawn digit and return the predicted label as a string.

    Args:
        image: Either the dict produced by Gradio's Sketchpad (the drawing is
            read from its ``'composite'`` key) or an image array directly.

    Returns:
        The argmax class index of the model's prediction as a string
        (e.g. ``"7"``), or ``None`` when there is nothing to classify: no image
        was submitted, or the canvas is a single uniform colour.
    """
    # Handle Gradio Sketchpad dictionary input; a missing/None composite means
    # nothing was drawn.
    if isinstance(image, dict):
        image = image.get('composite')
    if image is None:
        return None

    image = _to_grayscale(np.asarray(image))

    # A uniform canvas (nothing drawn) would only produce an arbitrary guess.
    if image.size == 0 or image.min() == image.max():
        return None

    # Invert colors (white background -> black background), presumably to match
    # the light-on-dark digits of the MNIST data the model was trained on.
    image = 255 - image

    # INTER_AREA averages source pixels when shrinking, so thin strokes on a
    # large canvas survive the downscale instead of being skipped by sampling.
    image = cv2.resize(image, (28, 28), interpolation=cv2.INTER_AREA)

    # Normalize to [0, 1] and add batch and channel axes: (1, 28, 28, 1).
    image = image.astype('float32') / 255.0
    image = image.reshape(1, 28, 28, 1)

    # verbose=0 keeps Keras from printing a progress bar on every request.
    prediction = model.predict(image, verbose=0)
    pred_label = np.argmax(prediction, axis=1)[0]
    return str(pred_label)


with gr.Blocks() as interface:
    gr.Markdown(
        """
## ✍️ Digit Classification using Convolutional Neural Network
Draw a digit in the sketchpad below (0 to 9), then click **Submit** to see the prediction.
"""
    )
    with gr.Row():
        sketchpad = gr.Sketchpad(image_mode='L')
        output = gr.Label()
    gr.Button("Submit").click(cnn_predict_digit, inputs=sketchpad, outputs=output)
    gr.ClearButton([sketchpad, output])


if __name__ == "__main__":
    interface.launch()