|
import tempfile
from pathlib import Path

import gradio as gr
from PIL import Image

from app.src.vit_load import VITDocumentClassifier
from app.src.vgg16_load import VGGDocumentClassifier
from app.src.constant import vit_model_path, vit_mlb_path, vgg_model_path, vgg_mlb_path
|
|
|
|
|
# Build both classifiers once, when the module is imported; the prediction
# handlers below reuse these shared instances for every request.
# The ViT classifier is constructed first, then the VGG16 one.
vit_model, vgg_model = (
    VITDocumentClassifier(vit_model_path, vit_mlb_path),
    VGGDocumentClassifier(vgg_model_path, vgg_mlb_path),
)
|
|
|
def predict_vit(image, cut_off):
    """Classify an uploaded document image with the ViT model.

    Args:
        image: The image delivered by the Gradio ``Image`` component
            (configured with ``type="pil"``), or ``None`` when nothing has
            been uploaded.
        cut_off: Threshold value from the UI slider, forwarded unchanged to
            ``vit_model.predict``.

    Returns:
        A user-facing string: an upload prompt when ``image`` is ``None``,
        otherwise the formatted ViT prediction.
    """
    if image is None:
        return "Please upload an image."

    # Save into a private per-call temporary directory rather than a fixed
    # file in the working directory. With a shared path, requests that run at
    # the same time could overwrite each other's image, and the uploaded file
    # was left on disk indefinitely. The directory and its contents are removed
    # when the ``with`` block exits, after prediction has finished.
    with tempfile.TemporaryDirectory() as tmp_dir:
        temp_path = Path(tmp_dir) / "temp_vit_image.png"  # .png suffix selects the format
        image.save(temp_path)
        result = vit_model.predict(temp_path, cut_off)

    return f"ViT Prediction: {result}"
|
|
|
def predict_vgg(image):
    """Classify an uploaded document image with the VGG16 model.

    Args:
        image: The image delivered by the Gradio ``Image`` component
            (configured with ``type="pil"``), or ``None`` when nothing has
            been uploaded.

    Returns:
        A user-facing string: an upload prompt when ``image`` is ``None``,
        otherwise the formatted VGG16 prediction.
    """
    if image is None:
        return "Please upload an image."

    # Use a private per-call temporary directory instead of a fixed file in the
    # working directory, so requests that run at the same time cannot clobber
    # each other's image and the upload is deleted once prediction completes.
    with tempfile.TemporaryDirectory() as tmp_dir:
        temp_path = Path(tmp_dir) / "temp_vgg_image.png"  # .png suffix selects the format
        image.save(temp_path)
        result = vgg_model.predict(temp_path)

    return f"VGG16 Prediction: {result}"
|
|
|
# Assemble the Gradio UI.
# - The first row holds two columns: the image upload and the ViT cutoff
#   slider, then a read-only result textbox.
# - A second row holds one button per model.
# Component creation order inside each context manager determines the layout.
with gr.Blocks() as demo:
    gr.Markdown("# Document Classification Demo\nUpload an image and choose a model to classify it.")

    with gr.Row():
        with gr.Column():
            # type="pil" requests a PIL image (or None) for the handlers,
            # which call image.save() on it.
            image_input = gr.Image(type="pil", label="Upload Image")
            # Range [0, 1], default 0.5; only predict_vit receives this value.
            cut_off = gr.Slider(0, 1, value=0.5, label="ViT Cutoff Threshold")
        with gr.Column():
            # Both handlers return a string that is displayed here.
            result_output = gr.Textbox(label="Prediction Result", interactive=False)

    with gr.Row():
        vit_btn = gr.Button("Predict with ViT Model")
        vgg_btn = gr.Button("Predict with VGG16 Model")

    # Wire each button to its handler. The VGG handler takes only the image,
    # so the cutoff slider is not passed to it.
    vit_btn.click(fn=predict_vit, inputs=[image_input, cut_off], outputs=result_output)
    vgg_btn.click(fn=predict_vgg, inputs=image_input, outputs=result_output)
|
|
|
if __name__ == "__main__":
    # Start the server only when executed as a script, not on import.
    # Server-side rendering is disabled; share=True asks Gradio for a public link.
    demo.launch(share=True, ssr_mode=False)