"""Gradio demo comparing two Food101 image classifiers on accuracy and speed."""

import io
import time

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd  # noqa: F401  (kept: file-level import; the chart no longer needs it)
import torch
from PIL import Image
from transformers import pipeline

# Figures in this module are only ever rendered to an in-memory PNG (see
# ``_figure_to_image``), so select the non-interactive Agg backend explicitly
# instead of relying on whatever GUI backend the host would pick by default.
plt.switch_backend("Agg")

# --- 1. Model Configuration & Metadata ---
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# Display name -> model metadata.
#   model_id           : identifier passed to ``transformers.pipeline``.
#   benchmark_accuracy : accuracy in percent, plotted on the "Accuracy (%)" axis.
#   pipeline           : cached pipeline object; ``None`` until first use.
MODEL_INFO = {
    "ViT (eslamxm/vit-base-food101)": {
        "model_id": "eslamxm/vit-base-food101",
        "benchmark_accuracy": 90.68,
        "pipeline": None,
    },
    "Swin (aspis/swin-finetuned-food101)": {
        "model_id": "aspis/swin-finetuned-food101",
        "benchmark_accuracy": 93.81,
        "pipeline": None,
    },
}

# Display names in declaration order; used for the dropdown and the examples.
MODEL_NAMES = list(MODEL_INFO)

# Bar colours for the comparison chart.
_COLOR_OTHER_MODEL = "#4c72b0"
_COLOR_SELECTED_MODEL = "#2ca02c"
_COLOR_INFERENCE_TIME = "#d62728"


# --- 2. Lazy Loading of Models ---
def load_pipeline(model_name):
    """Return the classification pipeline for *model_name*, loading it on first use.

    The loaded pipeline is cached in ``MODEL_INFO[model_name]["pipeline"]`` so
    later calls return the same object without reloading.

    Args:
        model_name: A key of ``MODEL_INFO``.

    Returns:
        The ``transformers`` image-classification pipeline placed on ``DEVICE``.

    Raises:
        KeyError: If *model_name* is not a key of ``MODEL_INFO``.
    """
    entry = MODEL_INFO[model_name]
    if entry["pipeline"] is None:
        print(f"Loading model: {model_name}...")
        entry["pipeline"] = pipeline(
            task="image-classification",
            model=entry["model_id"],
            device=DEVICE,
        )
        print(f"Model '{model_name}' loaded on {DEVICE}.")
    return entry["pipeline"]


# --- 3. Function to Generate Comparison Chart ---
def _draw_bars(ax, labels, values, colors, value_format):
    """Draw one bar per label on *ax* and write each bar's value above it.

    Args:
        ax: Matplotlib axes to draw on.
        labels: Tick label for each bar.
        values: Height of each bar.
        colors: Face colour for each bar (one entry per bar).
        value_format: ``str.format`` template applied to each bar height.
    """
    positions = range(len(labels))
    bars = ax.bar(positions, values, width=0.5, color=colors)
    ax.set_xticks(list(positions))
    ax.set_xticklabels(labels)
    # Half a unit of padding on each side so a single bar does not fill the axes.
    ax.set_xlim(-0.5, len(labels) - 0.5)
    for bar in bars:
        height = bar.get_height()
        ax.annotate(
            value_format.format(height),
            (bar.get_x() + bar.get_width() / 2.0, height),
            ha="center",
            va="center",
            xytext=(0, 9),
            textcoords="offset points",
        )


def create_comparison_chart(selected_model_name, current_inference_time):
    """Build a two-panel figure: benchmark accuracy and current inference time.

    The left panel shows the benchmark accuracy of every model in
    ``MODEL_INFO`` with the selected model's bar highlighted in green. The
    right panel shows the inference time measured for the selected model.

    Bars are drawn with ``Axes.bar`` so that each bar can have its own colour;
    a colour list passed to ``DataFrame.plot`` is applied per column instead,
    which left every accuracy bar in the same colour.

    Args:
        selected_model_name: Display name of the model that was just run.
        current_inference_time: Measured inference time in seconds.

    Returns:
        The Matplotlib figure. The caller is responsible for closing it.
    """
    fig, (acc_ax, time_ax) = plt.subplots(1, 2, figsize=(12, 5))
    fig.suptitle("Model Performance Comparison", fontsize=16)

    names = list(MODEL_INFO)
    accuracies = [MODEL_INFO[name]["benchmark_accuracy"] for name in names]
    accuracy_colors = [
        _COLOR_SELECTED_MODEL if name == selected_model_name else _COLOR_OTHER_MODEL
        for name in names
    ]
    _draw_bars(acc_ax, names, accuracies, accuracy_colors, "{:.2f}%")
    acc_ax.set_title("Benchmark Accuracy")
    acc_ax.set_ylabel("Accuracy (%)")
    acc_ax.set_xlabel("")
    acc_ax.set_ylim(0, 100)
    acc_ax.tick_params(axis="x", rotation=10)

    _draw_bars(
        time_ax,
        [selected_model_name],
        [current_inference_time],
        [_COLOR_INFERENCE_TIME],
        "{:.4f}s",
    )
    time_ax.set_title("Inference Time for This Image")
    time_ax.set_ylabel("Time (seconds)")
    time_ax.set_xlabel("")
    time_ax.tick_params(axis="x", rotation=0)
    # Leave headroom above the bar for its value label.
    time_ax.margins(y=0.15)

    # Use the figure's own method rather than pyplot's "current figure".
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    return fig


def _figure_to_image(fig):
    """Render *fig* to a PNG in memory and return it as a PIL image.

    The figure is closed whether or not rendering succeeds, so it is not left
    registered with pyplot.
    """
    buffer = io.BytesIO()
    try:
        fig.savefig(buffer, format="png", bbox_inches="tight")
    finally:
        plt.close(fig)
    buffer.seek(0)
    rendered = Image.open(buffer)
    # Decode now; Image.open is lazy and would otherwise read from the buffer later.
    rendered.load()
    return rendered


# --- 4. The Core Classification Function ---
def classify_image(image, model_name):
    """Classify *image* with the chosen model and build the comparison chart.

    Args:
        image: The uploaded picture as an array (the input component is
            ``gr.Image(type="numpy")``), or ``None`` if nothing was uploaded.
        model_name: A key of ``MODEL_INFO``.

    Returns:
        A 4-tuple matching the click handler's outputs:
        ``(label -> score dict, timing text, chart image or None, chart caption)``.
    """
    if image is None:
        return (
            {},
            "Please upload an image first.",
            None,
            "Please upload an image to see a comparison.",
        )

    pipe = load_pipeline(model_name)
    # Convert before starting the clock so only the pipeline call is timed.
    pil_image = Image.fromarray(image)

    # perf_counter is monotonic and intended for measuring short durations.
    start_time = time.perf_counter()
    predictions = pipe(pil_image)
    inference_time = time.perf_counter() - start_time

    top_5_preds = {
        p["label"].replace("_", " ").title(): p["score"] for p in predictions[:5]
    }

    comparison_img = _figure_to_image(
        create_comparison_chart(model_name, inference_time)
    )

    return (
        top_5_preds,
        f"Inference Time: {inference_time:.4f} seconds",
        comparison_img,
        f"Chart shows accuracy for all models and the inference time for the **{model_name}** model on this specific image.",
    )


# --- 5. Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
    gr.Markdown("# 🍔 Food Classifier: Accuracy vs. Speed")
    gr.Markdown(
        "Compare two different models for classifying food images from the Food101 dataset. "
        "Notice the trade-off: the **Swin** model is more accurate but might be slower, while the **ViT** model is faster but slightly less accurate."
    )

    with gr.Row(variant="panel"):
        # Left column: inputs and example images.
        with gr.Column(scale=1):
            image_input = gr.Image(type="numpy", label="Upload a food picture")
            model_dropdown = gr.Dropdown(
                choices=MODEL_NAMES,
                value=MODEL_NAMES[0],
                label="Choose a Model",
            )
            classify_button = gr.Button("Classify Image", variant="primary")
            gr.Examples(
                examples=[
                    ["examples/sushi.jpg", MODEL_NAMES[1]],
                    ["examples/pizza.jpg", MODEL_NAMES[0]],
                    ["examples/apple_pie.jpg", MODEL_NAMES[1]],
                ],
                inputs=[image_input, model_dropdown],
            )

        # Right column: predictions, timing and the comparison chart.
        with gr.Column(scale=2):
            output_label = gr.Label(num_top_classes=5, label="Top 5 Predictions")
            output_time = gr.Textbox(label="Performance")
            output_chart = gr.Image(type="pil", label="Model Comparison Chart")
            chart_info = gr.Markdown()

    classify_button.click(
        fn=classify_image,
        inputs=[image_input, model_dropdown],
        outputs=[output_label, output_time, output_chart, chart_info],
    )

if __name__ == "__main__":
    demo.launch()