# KaushiGihan's picture
# Update app.py
# 29fe14e verified
import tempfile
from pathlib import Path

import gradio as gr
from PIL import Image
# Import your model classes (adjust import paths as needed)
from app.src.vit_load import VITDocumentClassifier
from app.src.vgg16_load import VGGDocumentClassifier
from app.src.constant import vit_model_path, vit_mlb_path, vgg_model_path, vgg_mlb_path
# Load both classifiers once, at import time, so the prediction callbacks
# below share these module-level instances. Each constructor is given a model
# path and an "mlb" path imported from app.src.constant (presumably a saved
# label binarizer — TODO confirm against the loader classes).
vit_model = VITDocumentClassifier(vit_model_path, vit_mlb_path)
vgg_model = VGGDocumentClassifier(vgg_model_path, vgg_mlb_path)
def predict_vit(image, cut_off):
    """Classify an uploaded image with the ViT model.

    Args:
        image: The uploaded picture as a PIL image (the UI component is
            declared with ``type="pil"``), or ``None`` when nothing has
            been uploaded.
        cut_off: Threshold value forwarded unchanged to
            ``vit_model.predict``.

    Returns:
        A user-facing string: a prompt to upload an image when ``image`` is
        ``None``, otherwise the model result prefixed with
        ``"ViT Prediction: "``.
    """
    if image is None:
        return "Please upload an image."

    # The classifier is called with a file path, so the image is written to
    # disk first. A per-call temporary directory (instead of a fixed file in
    # the working directory) keeps concurrent requests from overwriting each
    # other's upload and guarantees the file is removed afterwards.
    with tempfile.TemporaryDirectory() as tmp_dir:
        temp_path = Path(tmp_dir) / "temp_vit_image.png"
        image.save(temp_path)
        result = vit_model.predict(temp_path, cut_off)
        # Format while the file still exists, in case the result is lazy.
        return f"ViT Prediction: {result}"
def predict_vgg(image):
    """Classify an uploaded image with the VGG16 model.

    Args:
        image: The uploaded picture as a PIL image (the UI component is
            declared with ``type="pil"``), or ``None`` when nothing has
            been uploaded.

    Returns:
        A user-facing string: a prompt to upload an image when ``image`` is
        ``None``, otherwise the model result prefixed with
        ``"VGG16 Prediction: "``.
    """
    if image is None:
        return "Please upload an image."

    # The classifier is called with a file path, so the image is written to
    # disk first. A per-call temporary directory (instead of a fixed file in
    # the working directory) keeps concurrent requests from overwriting each
    # other's upload and guarantees the file is removed afterwards.
    with tempfile.TemporaryDirectory() as tmp_dir:
        temp_path = Path(tmp_dir) / "temp_vgg_image.png"
        image.save(temp_path)
        result = vgg_model.predict(temp_path)
        # Format while the file still exists, in case the result is lazy.
        return f"VGG16 Prediction: {result}"
# Build the UI: a header, an input column (image + ViT threshold) beside an
# output column (result text), then a row with one button per model.
# NOTE(review): nesting below is reconstructed from the usual Gradio layout;
# confirm against the deployed app if the exact arrangement matters.
with gr.Blocks() as demo:
    gr.Markdown(
        "# Document Classification Demo\nUpload an image and choose a model to classify it."
    )

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Upload Image", type="pil")
            cut_off = gr.Slider(
                minimum=0,
                maximum=1,
                value=0.5,
                label="ViT Cutoff Threshold",
            )
        with gr.Column():
            result_output = gr.Textbox(
                label="Prediction Result",
                interactive=False,
            )

    with gr.Row():
        vit_btn = gr.Button("Predict with ViT Model")
        vgg_btn = gr.Button("Predict with VGG16 Model")

    # Both buttons write into the same textbox; only the ViT callback also
    # receives the slider value.
    vit_btn.click(
        fn=predict_vit,
        inputs=[image_input, cut_off],
        outputs=result_output,
    )
    vgg_btn.click(
        fn=predict_vgg,
        inputs=[image_input],
        outputs=result_output,
    )
# Start the Gradio server only when this file is run as a script, not when it
# is imported. Launch is called with server-side rendering turned off
# (ssr_mode=False) and share=True.
if __name__ == "__main__":
    demo.launch(ssr_mode=False, share=True)