"""Gradio demo app that classifies an uploaded image with the SAG-ViT model."""

import gradio as gr
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, AutoImageProcessor

# Load model and processor with custom code enabled.
# NOTE(security): trust_remote_code=True executes Python code shipped in the
# "shravvvv/SAG-ViT" model repository; only keep it enabled if that repository
# is trusted (consider pinning a `revision`).
model = AutoModelForImageClassification.from_pretrained(
    "shravvvv/SAG-ViT", trust_remote_code=True
)
processor = AutoImageProcessor.from_pretrained(
    "shravvvv/SAG-ViT", trust_remote_code=True
)

# Define CIFAR-10 class labels.
# NOTE(review): assumes the model's output indices follow this order — confirm
# against the model card.
class_labels = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]


def predict(image: Image.Image) -> str:
    """Return the predicted class label for *image*.

    Args:
        image: The PIL image supplied by the Gradio image component, or
            ``None`` when the user submits without uploading anything.

    Returns:
        The entry of ``class_labels`` at the index of the highest logit.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Gradio passes None for an empty input; surface a readable message
        # in the UI instead of failing inside the processor.
        raise gr.Error("Please upload an image before submitting.")

    # Uploaded files may be RGBA, palette or greyscale; normalise to three
    # channels before preprocessing (assumes the model expects RGB input).
    image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    # Inference only: skip gradient tracking.
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    predicted_class_idx = logits.argmax(-1).item()
    return class_labels[predicted_class_idx]


# Create Gradio interface.
# `gr.inputs.*` / `gr.outputs.*` were removed in Gradio 4; the top-level
# component classes are used instead.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(),
    title="SAG-ViT Image Classifier",
    description="Upload an image to classify it using the SAG-ViT model.",
)

if __name__ == "__main__":
    iface.launch()