# testMnist / app.py
# Last commit: 5eaa0c6 by GiladtheFixer — "Update app.py" (verified)
import gradio as gr
import tensorflow as tf
from tensorflow import keras
import numpy as np
from huggingface_hub import hf_hub_download
# Download the trained MNIST model file from the Hugging Face Hub, then load it
# once at import time so every request reuses the same Keras model object.
model_path = hf_hub_download(
    filename="mnist_model.keras",
    repo_id="GiladtheFixer/my_mnist_model",
)
model = keras.models.load_model(model_path)
def predict_digit(sketch_data):
    """Classify a digit drawn on the sketchpad with the loaded MNIST model.

    Args:
        sketch_data: Value from the ``gr.Sketchpad`` input, or ``None`` when
            nothing was submitted. Only its ``"composite"`` entry is read; it
            must be an RGBA image array of shape (height, width, 4).

    Returns:
        dict[str, float]: The model's score for each class, keyed by the digit
        as a string (``"0"``, ``"1"``, ...), the form ``gr.Label`` displays.
        Empty when there is no image to classify.

    Raises:
        ValueError: If the composite image is not a 4-channel (RGBA) array.
    """
    img = None if sketch_data is None else sketch_data.get("composite")
    if img is None:
        # Nothing to classify; show an empty label instead of crashing.
        return {}

    img = np.asarray(img)
    if img.ndim != 3 or img.shape[-1] != 4:
        raise ValueError(
            f"expected an RGBA sketch of shape (H, W, 4), got shape {img.shape}"
        )

    # Use the alpha channel as the ink intensity (no colour inversion needed).
    # NOTE(review): this relies on the canvas background being transparent
    # (alpha 0) and strokes opaque, giving light-on-dark digits like MNIST —
    # confirm for the installed Gradio version.
    ink = img[..., 3].astype(np.float32) / 255.0

    # Shrink to the model's 28x28 input. antialias=True filters over the whole
    # downsampling footprint; plain bilinear sampling of a canvas hundreds of
    # pixels wide reads only a few source pixels per output pixel, so thin
    # brush strokes can break up or disappear.
    resized = tf.image.resize(
        ink[..., np.newaxis],
        [28, 28],
        method="bilinear",
        antialias=True,
    )

    # (28, 28, 1) -> (1, 28, 28): drop the channel axis, add a batch axis.
    input_data = resized.numpy().reshape(1, 28, 28)

    pred = model.predict(input_data, verbose=0)
    return {str(digit): float(score) for digit, score in enumerate(pred[0])}
# Web UI: the user draws a digit on a 400x400 sketchpad, predict_digit scores
# it, and the three highest-scoring classes are shown in a label.
demo = gr.Interface(
    fn=predict_digit,
    inputs=[
        gr.Sketchpad(
            interactive=True,
            brush=None,
            width=400,
            height=400,
            label="draw some digit",
        )
    ],
    outputs=gr.Label(num_top_classes=3),
    title="MNIST_by Gilad",
    description="draw some digit with brush or clear your board then click submit",
    # Flagging is switched off for this demo.
    allow_flagging="never",
)
if __name__ == "__main__":
    # Only start the Gradio server when this file is run as a script,
    # not when it is imported.
    demo.launch()