# Iris_class / app.py
# Hugging Face Space by Rausda6 - last commit: "Update app.py" (5b7678f, verified)
import gradio as gr
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import torch
import numpy as np
import spaces
import logging
# Module-scoped logger for this app.
logger = logging.getLogger(__name__)
# Verbose logging: basicConfig sets the *root* logger to DEBUG, so debug
# output from third-party libraries is emitted too, not only this module's.
logging.basicConfig(level=logging.DEBUG)

# Hugging Face Hub repository holding the fine-tuned classifier and its
# preprocessing configuration.
MODEL_REPO = "Rausda6/autotrain-uo2t1-gvgzu"  # Replace with your actual model repo name
logger.debug(f"Loading model from: {MODEL_REPO}")

# Weights and image preprocessor are both fetched from the same repo so the
# preprocessing matches what the model was trained with.
# NOTE(review): the model is never moved off its default device in this file.
model = AutoModelForImageClassification.from_pretrained(MODEL_REPO)
processor = AutoImageProcessor.from_pretrained(MODEL_REPO)

# Class-index -> human-readable label mapping stored in the model config.
labels = model.config.id2label
@spaces.GPU
def classify_image(img: Image.Image):
    """Classify one uploaded image with the fine-tuned model.

    Args:
        img: The image delivered by ``gr.Image(type="pil")``. Gradio passes
            ``None`` when the user submits without uploading anything.

    Returns:
        A ``(top_label, probabilities)`` pair: the most likely label, and a
        ``{label: probability}`` dict over all classes ordered from most to
        least likely. Returns ``(None, None)`` when no image was supplied,
        which clears both output components instead of crashing.

    Raises:
        Any exception from preprocessing or the forward pass is logged with
        its traceback and re-raised unchanged.
    """
    logger.debug("Received image for classification.")
    if img is None:
        logger.debug("No image supplied; nothing to classify.")
        return None, None
    try:
        # Uploads may be RGBA (PNG with alpha), palette or grayscale, while the
        # processor normalizes with per-channel RGB statistics; force 3 channels.
        if img.mode != "RGB":
            img = img.convert("RGB")
        inputs = processor(images=img, return_tensors="pt")
        # Keep inputs on the same device as the model (a no-op on CPU).
        inputs = inputs.to(model.device)
        logger.debug(f"Processed inputs: {inputs}")
        with torch.no_grad():
            outputs = model(**inputs)
        logger.debug(f"Model outputs: {outputs}")
        # Batch of one image: take row 0 of the softmaxed logits.
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
        logger.debug(f"Probabilities: {probs}")
        # Build result dictionary with confidence values
        probs_dict = {labels[idx]: score for idx, score in enumerate(probs.tolist())}
        # Sort from most to least likely so the top entry is the prediction.
        sorted_probs = sorted(probs_dict.items(), key=lambda item: item[1], reverse=True)
        top_label, top_score = sorted_probs[0]
        logger.debug(f"Top prediction: {top_label} with confidence {top_score:.2%}")
        return top_label, dict(sorted_probs)
    except Exception:
        logger.exception("Error during classification")
        # Bare raise re-raises the active exception with its original traceback.
        raise
# Gradio interface
def _build_demo():
    """Assemble the Gradio UI that routes an uploaded image to ``classify_image``.

    The two output components receive, in order, the top label and the full
    label -> probability mapping returned by ``classify_image``.
    """
    image_input = gr.Image(type="pil")
    result_outputs = [
        gr.Label(label="Top Prediction"),
        gr.Label(num_top_classes=6, label="Class Probabilities"),
    ]
    return gr.Interface(
        fn=classify_image,
        inputs=image_input,
        outputs=result_outputs,
        title="Image Classification with AutoTrain Model",
        description="Upload a JPG image to classify it using the fine-tuned model.",
    )


demo = _build_demo()
demo.launch()