File size: 1,380 Bytes
144449f
 
9f8712e
144449f
 
 
5eaa0c6
 
 
 
 
9f8712e
a8fadd0
5eaa0c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144449f
5eaa0c6
 
 
 
 
 
 
 
 
144449f
5eaa0c6
 
 
144449f
 
5eaa0c6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import gradio as gr
import tensorflow as tf
from tensorflow import keras
import numpy as np
from huggingface_hub import hf_hub_download

# Download the trained MNIST model file from the Hugging Face Hub, then load it
# as a Keras model used by predict_digit below.
model_path = hf_hub_download(
    filename="mnist_model.keras",
    repo_id="GiladtheFixer/my_mnist_model",
)
model = keras.models.load_model(model_path)


def predict_digit(sketch_data):
    """Classify a digit drawn on the sketchpad.

    Args:
        sketch_data: The value of the ``gr.Sketchpad`` input. Only its
            ``"composite"`` entry (the flattened drawing) is read. It is
            treated as an image array whose last axis holds colour channels;
            channel index 3 is read, so an RGBA image is assumed
            (NOTE(review): confirm the Sketchpad keeps producing RGBA).

    Returns:
        A dict mapping each class index as a string (``"0"``, ``"1"``, ...)
        to the model's score for that class, for display by ``gr.Label``.
        Returns ``None`` when there is nothing to classify: no input, no
        composite image, or a canvas with no drawn pixels.
    """
    if sketch_data is None:
        return None
    composite = sketch_data.get("composite")
    if composite is None:
        return None
    img = np.asarray(composite)

    # Use the alpha channel as the ink map. It is presumably non-zero only where
    # strokes were drawn (transparent background), which already matches the
    # MNIST polarity of a bright digit on a dark background, so no colour
    # inversion is applied.
    alpha_channel = img[..., 3]
    if not np.any(alpha_channel):
        # Empty canvas: a prediction here would be meaningless.
        return None
    ink = alpha_channel / 255.0  # scale 0..255 to [0, 1]

    # Resize to 28x28. tf.image.resize needs an explicit channel axis; the
    # reshape below drops it again.
    resized = tf.image.resize(tf.expand_dims(ink, -1), [28, 28], method="bilinear")

    # Batch of one 28x28 image for the model.
    input_data = resized.numpy().reshape(1, 28, 28)

    pred = model.predict(input_data, verbose=0)

    # One entry per model output, so the mapping follows the model's class count.
    return {str(i): float(score) for i, score in enumerate(pred[0])}


# Web UI: the user draws on a 400x400 sketchpad, predict_digit scores the
# drawing, and the label output shows the top three classes.
demo = gr.Interface(
    title="MNIST_by Gilad",
    description="draw some digit with brush or clear your board then click submit",
    fn=predict_digit,
    inputs=[
        gr.Sketchpad(
            interactive=True,
            brush=None,
            width=400,
            height=400,
            label="draw some digit",
        ),
    ],
    outputs=gr.Label(num_top_classes=3),
    allow_flagging="never",
)

# Start the app only when this file is executed directly, not when imported.
if __name__ == "__main__":
    demo.launch()