"""
Simple Gradio Application for Anomaly Detection Testing
Shows embedding analysis instead of reconstructed images
"""
import gradio as gr
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from scipy import stats
import io
from simple_anomaly_detector import SimpleAnomalyDetector
from image_corruption_utils import corrupt_image
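# SimpleAnomalyDetector and corrupt_image are project-local modules that are expected
# to live next to app.py (the detector wraps the trained autoencoder; the utils apply
# the synthetic corruptions).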
# Model checkpoint paths; the active detector is loaded lazily into the globals below
models = {
"Daudon_MIX": "models/Daudon_MIX/best_autoencoder_Daudon_MIX.pth",
"Daudon_SEC": "models/Daudon_SEC/best_autoencoder_Daudon_SEC.pth",
"Daudon_SUR": "models/Daudon_SUR/best_autoencoder_Daudon_SUR.pth"
}
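# Checkpoint paths are relative to the app directory and must exist for loading to succeed.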
current_detector = None
current_model_name = None
def load_model(model_name):
"""Load the selected model"""
global current_detector, current_model_name
try:
if model_name != current_model_name:
print(f"Loading model: {model_name}")
model_path = models[model_name]
current_detector = SimpleAnomalyDetector(model_path)
current_model_name = model_name
return f"βœ… Model {model_name} loaded!"
return f"βœ… Model {model_name} already loaded"
except Exception as e:
return f"❌ Error loading {model_name}: {str(e)}"
def get_embedding_and_stats(image):
"""Get embedding from autoencoder and calculate statistics"""
try:
from torchvision import transforms
import config
# Get image size
if isinstance(config.IMAGE_SIZE, tuple):
target_size = config.IMAGE_SIZE
else:
target_size = (config.IMAGE_SIZE, config.IMAGE_SIZE)
# Preprocess
image_pil = image.convert('RGB').resize(target_size, Image.LANCZOS)
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
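        # The mean/std above are the standard ImageNet statistics; they are assumed to
        # match the normalization used when the autoencoder was trained.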
image_tensor = transform(image_pil).unsqueeze(0).to(current_detector.device)
# Get embedding (latent representation)
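        # The model's forward pass is assumed to return (reconstruction, latent embedding);
        # only the latent vector is analyzed here.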
with torch.no_grad():
_, embedding = current_detector.model(image_tensor)
# Convert to numpy for analysis
embedding_np = embedding.squeeze(0).cpu().numpy().flatten()
# Calculate statistics
stats_dict = {
'mean': float(np.mean(embedding_np)),
'median': float(np.median(embedding_np)),
'std': float(np.std(embedding_np)),
'min': float(np.min(embedding_np)),
'max': float(np.max(embedding_np)),
'q25': float(np.percentile(embedding_np, 25)),
'q75': float(np.percentile(embedding_np, 75)),
'skewness': float(stats.skew(embedding_np)),
'kurtosis': float(stats.kurtosis(embedding_np)),
'variance': float(np.var(embedding_np)),
'range': float(np.max(embedding_np) - np.min(embedding_np)),
'iqr': float(np.percentile(embedding_np, 75) - np.percentile(embedding_np, 25))
}
# Create visualization
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
fig.suptitle(f'Embedding Analysis (Dimension: {len(embedding_np)})', fontsize=16)
# Histogram
axes[0, 0].hist(embedding_np, bins=50, alpha=0.7, color='skyblue', edgecolor='black')
axes[0, 0].set_title('Distribution Histogram')
axes[0, 0].set_xlabel('Embedding Values')
axes[0, 0].set_ylabel('Frequency')
axes[0, 0].grid(True, alpha=0.3)
# Box plot
axes[0, 1].boxplot(embedding_np, vert=True)
axes[0, 1].set_title('Box Plot')
axes[0, 1].set_ylabel('Embedding Values')
axes[0, 1].grid(True, alpha=0.3)
# Q-Q plot (normal distribution)
stats.probplot(embedding_np, dist="norm", plot=axes[1, 0])
axes[1, 0].set_title('Q-Q Plot (Normal Distribution)')
axes[1, 0].grid(True, alpha=0.3)
# Embedding values plot
axes[1, 1].plot(embedding_np, alpha=0.7, color='red', linewidth=1)
axes[1, 1].set_title('Embedding Values Sequence')
axes[1, 1].set_xlabel('Dimension Index')
axes[1, 1].set_ylabel('Value')
axes[1, 1].grid(True, alpha=0.3)
plt.tight_layout()
# Convert plot to image
buf = io.BytesIO()
plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
buf.seek(0)
        plot_image = Image.open(buf).copy()  # copy so the image no longer depends on the open buffer
        plt.close(fig)
return embedding_np, stats_dict, plot_image
except Exception as e:
print(f"Error in embedding analysis: {e}")
return None, {}, None
def format_stats_text(stats_dict):
"""Format statistics into readable text"""
if not stats_dict:
return "❌ Error calculating statistics"
text = f"""πŸ“Š EMBEDDING STATISTICS
🎯 Central Tendency:
Mean: {stats_dict['mean']:.6f}
Median: {stats_dict['median']:.6f}
📏 Spread:
Std Dev: {stats_dict['std']:.6f}
Variance: {stats_dict['variance']:.6f}
Range: {stats_dict['range']:.6f}
IQR: {stats_dict['iqr']:.6f}
📈 Extremes:
Min: {stats_dict['min']:.6f}
Max: {stats_dict['max']:.6f}
Q25: {stats_dict['q25']:.6f}
Q75: {stats_dict['q75']:.6f}
🔄 Shape:
Skewness: {stats_dict['skewness']:.6f}
Kurtosis: {stats_dict['kurtosis']:.6f}
"""
return text
def classify_image(reconstruction_error, threshold):
"""Classify image as corrupted or clean based on threshold"""
    is_corrupted = reconstruction_error > threshold
if is_corrupted:
classification = "🚨 CORRUPTED/ANOMALOUS"
        color_indicator = "🔴"
explanation = f"Reconstruction error ({reconstruction_error:.6f}) > Threshold ({threshold:.6f})"
else:
classification = "βœ… CLEAN/NORMAL"
color_indicator = "🟒"
explanation = f"Reconstruction error ({reconstruction_error:.6f}) ≀ Threshold ({threshold:.6f})"
# Calculate how far from threshold (as percentage)
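    # Positive distance: error above the threshold; negative distance: below it.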
distance_pct = (reconstruction_error - threshold) / threshold * 100
classification_text = f"""🎯 ANOMALY CLASSIFICATION
{color_indicator} Status: {classification}
📊 Details:
Reconstruction Error: {reconstruction_error:.6f}
Threshold: {threshold:.6f}
Distance from Threshold: {distance_pct:+.2f}%
📝 Explanation:
{explanation}
💡 Confidence Indicator:
• Distance > 50%: High confidence
• Distance 10-50%: Medium confidence
• Distance < 10%: Low confidence (near threshold)
🎚️ Current Distance: {abs(distance_pct):.2f}% ({'High' if abs(distance_pct) > 50 else 'Medium' if abs(distance_pct) > 10 else 'Low'} confidence)"""
return classification_text, is_corrupted
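# Worked example of the threshold logic above: classify_image(0.8, 1.0) reports CLEAN at a
# distance of -20% (medium confidence), while classify_image(1.6, 1.0) reports CORRUPTED
# at +60% (high confidence).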
def process_image(model_name, image, corruption_type, intensity, threshold):
"""Main processing function"""
try:
# Load model
load_status = load_model(model_name)
if "❌" in load_status:
return None, None, load_status, 0.0, "", ""
if image is None:
return None, None, "❌ Please upload an image", 0.0, "", ""
# Apply corruption
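        # corrupt_image (from image_corruption_utils) is assumed to take a PIL image, a
        # corruption name and a float intensity, and to return the corrupted PIL image.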
if corruption_type == "none":
corrupted_image = image.copy()
corruption_info = "No corruption applied"
else:
corrupted_image = corrupt_image(image, corruption_type, intensity)
corruption_info = f"Applied {corruption_type} corruption (intensity: {intensity})"
# Calculate reconstruction error
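        # calculate_reconstruction_error is assumed to return a scalar per-image error
        # (e.g. the mean squared error between the input and its reconstruction).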
error = current_detector.calculate_reconstruction_error(corrupted_image)
# Get embedding and statistics
embedding, stats_dict, plot_image = get_embedding_and_stats(corrupted_image)
# Format statistics text
stats_text = format_stats_text(stats_dict)
# Classify image based on threshold
classification_text, is_corrupted = classify_image(error, threshold)
# Status message
status = f"""βœ… Processing complete!
πŸ“Š Model: {model_name}
πŸ”§ {corruption_info}
πŸ“ˆ Reconstruction Error: {error:.6f}
🎚️ Threshold: {threshold:.6f}
🎯 Classification: {'CORRUPTED' if is_corrupted else 'CLEAN'}
🧠 Embedding Dimension: {len(embedding) if embedding is not None else 'N/A'}
πŸ’‘ Higher error = more anomalous"""
return corrupted_image, plot_image, status, error, stats_text, classification_text
except Exception as e:
error_msg = f"❌ Error: {str(e)}"
return None, None, error_msg, 0.0, "", ""
# Create interface
def create_interface():
with gr.Blocks(title="Anomaly Detection Tester") as demo:
gr.Markdown("# πŸ” Federated Autoencoder for Kidney Stone Image Corruption Detection")
gr.Markdown("Upload an image, analyze its latent representation, and classify it as corrupted or clean using a threshold.")
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### βš™οΈ Model & Corruption Settings")
model_dropdown = gr.Dropdown(
choices=list(models.keys()),
value="Daudon_MIX",
label="πŸ€– Select Model"
)
corruption_dropdown = gr.Dropdown(
choices=["none", "noise", "blur", "brightness", "contrast", "saturation", "random"],
value="none",
label="πŸ”§ Corruption Type"
)
intensity_slider = gr.Slider(
minimum=0.1,
maximum=3.0,
value=1.0,
step=0.1,
label="πŸ’ͺ Corruption Intensity"
)
gr.Markdown("### 🎚️ Classification Settings")
threshold_slider = gr.Slider(
minimum=0.1,
maximum=3.0,
value=1.0,
step=0.1,
label="🎯 Anomaly Threshold (Reconstruction Error)"
)
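                # The default threshold of 1.0 is only a starting point; in practice it
                # should be calibrated on clean validation images (e.g. a high percentile
                # of their reconstruction errors).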
gr.Markdown("### πŸ“Έ Image Input")
image_input = gr.Image(type="pil", label="Upload Image")
# Add examples section
gr.Markdown("### πŸ“ Example Images")
# You can specify your example image paths here
example_images = [
["example_imgs/TypeIa_LaosN°15_Image21-25.png", "Clean Daudon MIX-Subtype_Ia"],
["example_imgs/72222-SectionIVa+WK maj_0009-60.png", "Clean Daudon MIX-Subtype_IVa"],
["example_imgs/TypeIVa2_N47583_Notteb-11.png", "Clean Daudon MIX-Subtype_IVa2"],
["example_imgs/typIVc_IVbsectbis-43.png", "Clean Daudon MIX-Subtype_IVc"],
["example_imgs/TypeIVd_Sect_LC3373-65.png", "Clean Daudon MIX-Subtype_IVd"],
["example_imgs/Section_Va_72845-3-18.png", "Clean Daudon MIX-Subtype_Va"],
]
examples_component = gr.Examples(
examples=example_images,
inputs=image_input,
label="Daudon MIX Example Clean Images",
examples_per_page=6,
cache_examples=False
)
                process_btn = gr.Button("🚀 Analyze & Classify", variant="primary", size="lg")
with gr.Column(scale=1):
gr.Markdown("### πŸ“Š Results")
status_output = gr.Textbox(label="πŸ“‹ Status", lines=8)
error_output = gr.Number(label="πŸ“ˆ Reconstruction Error", precision=6)
corrupted_output = gr.Image(label="πŸ”§ Input Image (Corrupted)")
with gr.Row():
                    embedding_plot = gr.Image(label="🧠 Embedding Analysis")
with gr.Row():
                    stats_output = gr.Textbox(label="📊 Embedding Statistics", lines=20)
                    classification_output = gr.Textbox(label="🎯 Classification Result", lines=15)
# Connect the button
process_btn.click(
fn=process_image,
inputs=[model_dropdown, image_input, corruption_dropdown, intensity_slider, threshold_slider],
outputs=[corrupted_output, embedding_plot, status_output, error_output, stats_output, classification_output]
)
return demo
if __name__ == "__main__":
print("πŸš€ Starting Embedding Analysis App...")
demo = create_interface()
demo.launch()
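    # demo.launch() uses Gradio defaults; for a public link or a fixed port, standard
    # options such as demo.launch(share=True) or
    # demo.launch(server_name="0.0.0.0", server_port=7860) can be used instead.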