# Hugging Face Spaces status: Runtime error
import inspect

import gradio as gr
import numpy as np
import tensorflow as tf
from huggingface_hub import hf_hub_download
from tensorflow import keras
# Download the trained MNIST model file from the Hugging Face Hub and load it
# with Keras once, at import time; the resulting local file path is what gets
# handed to ``keras.models.load_model``.
_MODEL_REPO_ID = "GiladtheFixer/my_mnist_model"
_MODEL_FILENAME = "mnist_model.keras"

model_path = hf_hub_download(repo_id=_MODEL_REPO_ID, filename=_MODEL_FILENAME)
model = keras.models.load_model(model_path)
def predict_digit(sketch_data):
    """Classify the digit drawn on the sketchpad with the loaded Keras model.

    Args:
        sketch_data: Value delivered by the ``gr.Sketchpad`` input. Only its
            ``"composite"`` entry is read; it is indexed as an ``H x W x C``
            array whose channel 3 is used as the drawing (assumed RGBA with
            8-bit values -- TODO confirm against the Sketchpad's image mode).

    Returns:
        A dict mapping each class label (``"0"``, ``"1"``, ...) to the model's
        score for that class, which is the format ``gr.Label`` displays; or
        ``None`` when no image was submitted.
    """
    # NOTE(review): an empty or cleared board may arrive as ``None`` (or with
    # no composite image). Return "no prediction" instead of raising a
    # TypeError on ``None[..., 3]``.
    if sketch_data is None or sketch_data.get("composite") is None:
        return None
    img = np.asarray(sketch_data["composite"])

    # Use the alpha (4th) channel as the grayscale image and divide by 255 to
    # get values in [0, 1]. Presumably the drawn strokes are opaque on a
    # transparent background, so they come out bright on dark, MNIST-style.
    img = img[..., 3] / 255.0

    # Resize to the model's 28x28 input size with bilinear interpolation.
    resized = tf.image.resize(tf.expand_dims(img, -1), [28, 28], method="bilinear")

    # Prepare the model input: a batch containing one 28x28 image. Reshaping
    # (28, 28, 1) straight to (1, 28, 28) keeps the same pixel order, so no
    # separate squeeze is needed.
    input_data = resized.numpy().reshape(1, 28, 28)

    # Predict. Calling the model directly avoids the per-call overhead that
    # ``model.predict`` adds for a single small batch.
    scores = np.asarray(model(input_data, training=False))[0]
    return {str(digit): float(score) for digit, score in enumerate(scores)}
# Newer Gradio releases replace ``gr.Interface``'s ``allow_flagging`` keyword
# with ``flagging_mode`` (deprecating, and possibly rejecting, the old name).
# Pass whichever keyword this installation's ``gr.Interface`` declares,
# falling back to the originally used ``allow_flagging``.
_FLAGGING_KWARG = (
    "flagging_mode"
    if "flagging_mode" in inspect.signature(gr.Interface.__init__).parameters
    else "allow_flagging"
)

# Web UI: a drawing canvas as input, the top-3 digit scores as output.
demo = gr.Interface(
    fn=predict_digit,
    inputs=[
        gr.Sketchpad(
            label="draw some digit",
            height=400,
            width=400,
            brush=None,
            interactive=True,
        )
    ],
    outputs=gr.Label(num_top_classes=3),
    title="MNIST_by Gilad",
    description="draw some digit with brush or clear your board then click submit",
    **{_FLAGGING_KWARG: "never"},
)

if __name__ == "__main__":
    # Start the Gradio server only when the file is run as a script.
    demo.launch()