""" Simple Gradio Application for Anomaly Detection Testing Shows embedding analysis instead of reconstructed images """ import gradio as gr import torch import numpy as np from PIL import Image import matplotlib.pyplot as plt import seaborn as sns from scipy import stats import io import base64 from simple_anomaly_detector import SimpleAnomalyDetector from image_corruption_utils import corrupt_image # Global variables to store models models = { "Daudon_MIX": "models/Daudon_MIX/best_autoencoder_Daudon_MIX.pth", "Daudon_SEC": "models/Daudon_SEC/best_autoencoder_Daudon_SEC.pth", "Daudon_SUR": "models/Daudon_SUR/best_autoencoder_Daudon_SUR.pth" } current_detector = None current_model_name = None def load_model(model_name): """Load the selected model""" global current_detector, current_model_name try: if model_name != current_model_name: print(f"Loading model: {model_name}") model_path = models[model_name] current_detector = SimpleAnomalyDetector(model_path) current_model_name = model_name return f"βœ… Model {model_name} loaded!" return f"βœ… Model {model_name} already loaded" except Exception as e: return f"❌ Error loading {model_name}: {str(e)}" def get_embedding_and_stats(image): """Get embedding from autoencoder and calculate statistics""" try: from torchvision import transforms import config # Get image size if isinstance(config.IMAGE_SIZE, tuple): target_size = config.IMAGE_SIZE else: target_size = (config.IMAGE_SIZE, config.IMAGE_SIZE) # Preprocess image_pil = image.convert('RGB').resize(target_size, Image.LANCZOS) transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) image_tensor = transform(image_pil).unsqueeze(0).to(current_detector.device) # Get embedding (latent representation) with torch.no_grad(): _, embedding = current_detector.model(image_tensor) # Convert to numpy for analysis embedding_np = embedding.squeeze(0).cpu().numpy().flatten() # Calculate statistics stats_dict = { 'mean': float(np.mean(embedding_np)), 'median': float(np.median(embedding_np)), 'std': float(np.std(embedding_np)), 'min': float(np.min(embedding_np)), 'max': float(np.max(embedding_np)), 'q25': float(np.percentile(embedding_np, 25)), 'q75': float(np.percentile(embedding_np, 75)), 'skewness': float(stats.skew(embedding_np)), 'kurtosis': float(stats.kurtosis(embedding_np)), 'variance': float(np.var(embedding_np)), 'range': float(np.max(embedding_np) - np.min(embedding_np)), 'iqr': float(np.percentile(embedding_np, 75) - np.percentile(embedding_np, 25)) } # Create visualization fig, axes = plt.subplots(2, 2, figsize=(12, 10)) fig.suptitle(f'Embedding Analysis (Dimension: {len(embedding_np)})', fontsize=16) # Histogram axes[0, 0].hist(embedding_np, bins=50, alpha=0.7, color='skyblue', edgecolor='black') axes[0, 0].set_title('Distribution Histogram') axes[0, 0].set_xlabel('Embedding Values') axes[0, 0].set_ylabel('Frequency') axes[0, 0].grid(True, alpha=0.3) # Box plot axes[0, 1].boxplot(embedding_np, vert=True) axes[0, 1].set_title('Box Plot') axes[0, 1].set_ylabel('Embedding Values') axes[0, 1].grid(True, alpha=0.3) # Q-Q plot (normal distribution) stats.probplot(embedding_np, dist="norm", plot=axes[1, 0]) axes[1, 0].set_title('Q-Q Plot (Normal Distribution)') axes[1, 0].grid(True, alpha=0.3) # Embedding values plot axes[1, 1].plot(embedding_np, alpha=0.7, color='red', linewidth=1) axes[1, 1].set_title('Embedding Values Sequence') axes[1, 1].set_xlabel('Dimension Index') axes[1, 1].set_ylabel('Value') axes[1, 1].grid(True, alpha=0.3) plt.tight_layout() # Convert plot to image buf = io.BytesIO() plt.savefig(buf, format='png', dpi=100, bbox_inches='tight') buf.seek(0) plot_image = Image.open(buf) plt.close() return embedding_np, stats_dict, plot_image except Exception as e: print(f"Error in embedding analysis: {e}") return None, {}, None def format_stats_text(stats_dict): """Format statistics into readable text""" if not stats_dict: return "❌ Error calculating statistics" text = f"""πŸ“Š EMBEDDING STATISTICS 🎯 Central Tendency: Mean: {stats_dict['mean']:.6f} Median: {stats_dict['median']:.6f} πŸ“ Spread: Std Dev: {stats_dict['std']:.6f} Variance: {stats_dict['variance']:.6f} Range: {stats_dict['range']:.6f} IQR: {stats_dict['iqr']:.6f} πŸ“ˆ Extremes: Min: {stats_dict['min']:.6f} Max: {stats_dict['max']:.6f} Q25: {stats_dict['q25']:.6f} Q75: {stats_dict['q75']:.6f} πŸ”„ Shape: Skewness: {stats_dict['skewness']:.6f} Kurtosis: {stats_dict['kurtosis']:.6f} """ return text def classify_image(reconstruction_error, threshold): """Classify image as corrupted or clean based on threshold""" is_corrupted = reconstruction_error > threshold confidence = abs(reconstruction_error - threshold) / threshold * 100 if is_corrupted: classification = "🚨 CORRUPTED/ANOMALOUS" color_indicator = "πŸ”΄" explanation = f"Reconstruction error ({reconstruction_error:.6f}) > Threshold ({threshold:.6f})" else: classification = "βœ… CLEAN/NORMAL" color_indicator = "🟒" explanation = f"Reconstruction error ({reconstruction_error:.6f}) ≀ Threshold ({threshold:.6f})" # Calculate how far from threshold (as percentage) distance_pct = (reconstruction_error - threshold) / threshold * 100 classification_text = f"""🎯 ANOMALY CLASSIFICATION {color_indicator} Status: {classification} πŸ“Š Details: Reconstruction Error: {reconstruction_error:.6f} Threshold: {threshold:.6f} Distance from Threshold: {distance_pct:+.2f}% πŸ“ Explanation: {explanation} πŸ’‘ Confidence Indicator: β€’ Distance > 50%: High confidence β€’ Distance 10-50%: Medium confidence β€’ Distance < 10%: Low confidence (near threshold) 🎚️ Current Distance: {abs(distance_pct):.2f}% ({'High' if abs(distance_pct) > 50 else 'Medium' if abs(distance_pct) > 10 else 'Low'} confidence)""" return classification_text, is_corrupted def process_image(model_name, image, corruption_type, intensity, threshold): """Main processing function""" try: # Load model load_status = load_model(model_name) if "❌" in load_status: return None, None, load_status, 0.0, "", "" if image is None: return None, None, "❌ Please upload an image", 0.0, "", "" # Apply corruption if corruption_type == "none": corrupted_image = image.copy() corruption_info = "No corruption applied" else: corrupted_image = corrupt_image(image, corruption_type, intensity) corruption_info = f"Applied {corruption_type} corruption (intensity: {intensity})" # Calculate reconstruction error error = current_detector.calculate_reconstruction_error(corrupted_image) # Get embedding and statistics embedding, stats_dict, plot_image = get_embedding_and_stats(corrupted_image) # Format statistics text stats_text = format_stats_text(stats_dict) # Classify image based on threshold classification_text, is_corrupted = classify_image(error, threshold) # Status message status = f"""βœ… Processing complete! πŸ“Š Model: {model_name} πŸ”§ {corruption_info} πŸ“ˆ Reconstruction Error: {error:.6f} 🎚️ Threshold: {threshold:.6f} 🎯 Classification: {'CORRUPTED' if is_corrupted else 'CLEAN'} 🧠 Embedding Dimension: {len(embedding) if embedding is not None else 'N/A'} πŸ’‘ Higher error = more anomalous""" return corrupted_image, plot_image, status, error, stats_text, classification_text except Exception as e: error_msg = f"❌ Error: {str(e)}" return None, None, error_msg, 0.0, "", "" # Create interface def create_interface(): with gr.Blocks(title="Anomaly Detection Tester") as demo: gr.Markdown("# πŸ” Federated Autoencoder for Kidney Stone Image Corruption Detection") gr.Markdown("Upload an image, analyze its latent representation, and classify it as corrupted or clean using a threshold.") with gr.Row(): with gr.Column(scale=1): gr.Markdown("### βš™οΈ Model & Corruption Settings") model_dropdown = gr.Dropdown( choices=list(models.keys()), value="Daudon_MIX", label="πŸ€– Select Model" ) corruption_dropdown = gr.Dropdown( choices=["none", "noise", "blur", "brightness", "contrast", "saturation", "random"], value="none", label="πŸ”§ Corruption Type" ) intensity_slider = gr.Slider( minimum=0.1, maximum=3.0, value=1.0, step=0.1, label="πŸ’ͺ Corruption Intensity" ) gr.Markdown("### 🎚️ Classification Settings") threshold_slider = gr.Slider( minimum=0.1, maximum=3.0, value=1.0, step=0.1, label="🎯 Anomaly Threshold (Reconstruction Error)" ) gr.Markdown("### πŸ“Έ Image Input") image_input = gr.Image(type="pil", label="Upload Image") # Add examples section gr.Markdown("### πŸ“ Example Images") # You can specify your example image paths here example_images = [ ["example_imgs/TypeIa_LaosN°15_Image21-25.png", "Clean Daudon MIX-Subtype_Ia"], ["example_imgs/72222-SectionIVa+WK maj_0009-60.png", "Clean Daudon MIX-Subtype_IVa"], ["example_imgs/TypeIVa2_N47583_Notteb-11.png", "Clean Daudon MIX-Subtype_IVa2"], ["example_imgs/typIVc_IVbsectbis-43.png", "Clean Daudon MIX-Subtype_IVc"], ["example_imgs/TypeIVd_Sect_LC3373-65.png", "Clean Daudon MIX-Subtype_IVd"], ["example_imgs/Section_Va_72845-3-18.png", "Clean Daudon MIX-Subtype_Va"], ] examples_component = gr.Examples( examples=example_images, inputs=image_input, label="Daudon MIX Example Clean Images", examples_per_page=6, cache_examples=False ) process_btn = gr.Button("πŸš€ Analyze & Classify", variant="primary", size="lg") with gr.Column(scale=1): gr.Markdown("### πŸ“Š Results") status_output = gr.Textbox(label="πŸ“‹ Status", lines=8) error_output = gr.Number(label="πŸ“ˆ Reconstruction Error", precision=6) corrupted_output = gr.Image(label="πŸ”§ Input Image (Corrupted)") with gr.Row(): embedding_plot = gr.Image(label="🧠 Embedding Analysis") with gr.Row(): stats_output = gr.Textbox(label="πŸ“Š Embedding Statistics", lines=20) classification_output = gr.Textbox(label="🎯 Classification Result", lines=15) # Connect the button process_btn.click( fn=process_image, inputs=[model_dropdown, image_input, corruption_dropdown, intensity_slider, threshold_slider], outputs=[corrupted_output, embedding_plot, status_output, error_output, stats_output, classification_output] ) return demo if __name__ == "__main__": print("πŸš€ Starting Embedding Analysis App...") demo = create_interface() demo.launch()