|
""" |
|
Simple Gradio Application for Anomaly Detection Testing |
|
Shows embedding analysis instead of reconstructed images |
|
""" |
|
|
|
import gradio as gr |
|
import torch |
|
import numpy as np |
|
from PIL import Image |
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
from scipy import stats |
|
import io |
|
import base64 |
|
from simple_anomaly_detector import SimpleAnomalyDetector |
|
from image_corruption_utils import corrupt_image |
|
|
|
|
|
|
|
# Selectable autoencoder checkpoints, keyed by the name shown in the model
# dropdown. Every checkpoint follows the same on-disk layout:
#     models/<name>/best_autoencoder_<name>.pth
models = {
    name: f"models/{name}/best_autoencoder_{name}.pth"
    for name in ("Daudon_MIX", "Daudon_SEC", "Daudon_SUR")
}

# Cache of the detector currently held in memory and the model name it was
# loaded under; both start empty and are (re)assigned by load_model().
current_detector = None
current_model_name = None
|
|
|
|
|
def load_model(model_name):
    """Make *model_name* the active detector, loading it only when needed.

    The loaded detector is cached in the module globals ``current_detector``
    and ``current_model_name``; selecting the model that is already active
    does not reload it.

    Args:
        model_name: Key into the ``models`` checkpoint registry.

    Returns:
        A human-readable status string: success messages start with "✅",
        failures with "❌". Errors raised while resolving or loading the
        checkpoint (unknown name, unreadable file, ...) are reported in the
        string rather than raised, and the previously active detector, if
        any, is left in place.
    """
    global current_detector, current_model_name

    if model_name == current_model_name:
        return f"✅ Model {model_name} already loaded"

    try:
        print(f"Loading model: {model_name}")
        model_path = models[model_name]
        detector = SimpleAnomalyDetector(model_path)
    except Exception as e:  # UI boundary: surface the error instead of crashing
        return f"❌ Error loading {model_name}: {e}"

    # Commit both globals together, and only after a successful load, so the
    # cache never pairs a model name with a detector that failed to load.
    current_detector = detector
    current_model_name = model_name
    return f"✅ Model {model_name} loaded!"
|
|
|
|
|
def _preprocess_for_detector(image, detector):
    """Turn a PIL *image* into a normalised ``(1, C, H, W)`` batch on ``detector.device``."""
    from torchvision import transforms

    import config

    size = config.IMAGE_SIZE
    # The configured size may be a scalar (square images) or a two-element
    # tuple/list; PIL's resize() needs a tuple.
    if isinstance(size, (tuple, list)):
        target_size = tuple(size)
    else:
        target_size = (size, size)

    image_pil = image.convert('RGB').resize(target_size, Image.LANCZOS)
    # Mean/std are the widely used ImageNet channel statistics.
    # NOTE(review): presumably this must match the preprocessing inside
    # SimpleAnomalyDetector.calculate_reconstruction_error — confirm.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return transform(image_pil).unsqueeze(0).to(detector.device)


def _embedding_statistics(values):
    """Summarise the flat embedding vector *values* as a dict of Python floats."""
    vmin = np.min(values)
    vmax = np.max(values)
    q25 = np.percentile(values, 25)
    q75 = np.percentile(values, 75)
    return {
        'mean': float(np.mean(values)),
        'median': float(np.median(values)),
        'std': float(np.std(values)),
        'min': float(vmin),
        'max': float(vmax),
        'q25': float(q25),
        'q75': float(q75),
        'skewness': float(stats.skew(values)),
        # scipy's default is Fisher (excess) kurtosis: 0.0 for a normal law.
        'kurtosis': float(stats.kurtosis(values)),
        'variance': float(np.var(values)),
        'range': float(vmax - vmin),
        'iqr': float(q75 - q25),
    }


def _render_embedding_plot(values):
    """Draw a 2x2 diagnostic figure of *values* and return it as a PIL image."""
    # Use the object-oriented Figure API instead of pyplot: Gradio calls
    # handlers from worker threads, and pyplot's global figure manager is
    # not thread-safe (GUI backends may even refuse to run off the main
    # thread). A bare Figure is ordinary garbage-collected state, so nothing
    # needs closing even when plotting raises.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(12, 10))
    axes = fig.subplots(2, 2)
    fig.suptitle(f'Embedding Analysis (Dimension: {len(values)})', fontsize=16)
    hist_ax, box_ax, qq_ax, seq_ax = axes.flat

    hist_ax.hist(values, bins=50, alpha=0.7, color='skyblue', edgecolor='black')
    hist_ax.set_title('Distribution Histogram')
    hist_ax.set_xlabel('Embedding Values')
    hist_ax.set_ylabel('Frequency')

    # Vertical is boxplot's default orientation; `vert=` is deprecated in
    # Matplotlib 3.10, so it is not passed explicitly.
    box_ax.boxplot(values)
    box_ax.set_title('Box Plot')
    box_ax.set_ylabel('Embedding Values')

    stats.probplot(values, dist="norm", plot=qq_ax)
    qq_ax.set_title('Q-Q Plot (Normal Distribution)')

    seq_ax.plot(values, alpha=0.7, color='red', linewidth=1)
    seq_ax.set_title('Embedding Values Sequence')
    seq_ax.set_xlabel('Dimension Index')
    seq_ax.set_ylabel('Value')

    for ax in axes.flat:
        ax.grid(True, alpha=0.3)

    fig.tight_layout()
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    buf.seek(0)
    return Image.open(buf)


def get_embedding_and_stats(image):
    """Encode *image* with the active autoencoder and analyse its embedding.

    Args:
        image: A PIL image (anything supporting ``.convert('RGB')``).

    Returns:
        ``(embedding, stats_dict, plot_image)``:

        - the embedding flattened to a 1-D NumPy array;
        - its summary statistics;
        - a PIL rendering of a 2x2 figure (histogram, box plot, normal Q-Q
          plot, values by index).

        On any failure the error is printed and ``(None, {}, None)`` is
        returned instead.
    """
    try:
        image_tensor = _preprocess_for_detector(image, current_detector)
        with torch.no_grad():
            # The model's forward pass yields a pair; only the second element
            # (the embedding) is analysed here.
            _, embedding = current_detector.model(image_tensor)
        embedding_np = embedding.squeeze(0).cpu().numpy().flatten()
        return (
            embedding_np,
            _embedding_statistics(embedding_np),
            _render_embedding_plot(embedding_np),
        )
    except Exception as e:
        print(f"Error in embedding analysis: {e}")
        return None, {}, None
|
|
|
|
|
def format_stats_text(stats_dict):
    """Render embedding statistics as a sectioned, human-readable report.

    Args:
        stats_dict: Mapping with float values for the keys ``mean``,
            ``median``, ``std``, ``variance``, ``range``, ``iqr``, ``min``,
            ``max``, ``q25``, ``q75``, ``skewness`` and ``kurtosis`` (as built
            by get_embedding_and_stats). An empty mapping means the analysis
            failed.

    Returns:
        The report, with every value printed to six decimal places, or a
        single error line when *stats_dict* is empty.
    """
    if not stats_dict:
        return "❌ Error calculating statistics"

    # (section heading, ((row label, stats_dict key), ...)) in display order.
    sections = (
        ("🎯 Central Tendency:", (("Mean", "mean"), ("Median", "median"))),
        ("📏 Spread:", (("Std Dev", "std"), ("Variance", "variance"),
                        ("Range", "range"), ("IQR", "iqr"))),
        ("📈 Extremes:", (("Min", "min"), ("Max", "max"),
                          ("Q25", "q25"), ("Q75", "q75"))),
        ("📐 Shape:", (("Skewness", "skewness"), ("Kurtosis", "kurtosis"))),
    )

    lines = ["📊 EMBEDDING STATISTICS", ""]
    for heading, rows in sections:
        lines.append(heading)
        lines.extend(f"{label}: {stats_dict[key]:.6f}" for label, key in rows)
        lines.append("")  # blank separator after every section
    return "\n".join(lines) + "\n"
|
|
|
|
|
def classify_image(reconstruction_error, threshold):
    """Classify an image as corrupted or clean from its reconstruction error.

    The image is flagged as corrupted when its error is strictly greater than
    *threshold*. The report also gives the signed distance of the error from
    the threshold as a percentage of the threshold. That distance is bucketed
    into a coarse confidence level: above 50 % is high, above 10 % is medium,
    and anything else is low.

    Args:
        reconstruction_error: Reconstruction error of the image.
        threshold: Decision threshold on the same scale.

    Returns:
        ``(classification_text, is_corrupted)`` where *is_corrupted* is a
        plain ``bool``.
    """
    import math

    is_corrupted = bool(reconstruction_error > threshold)

    if is_corrupted:
        classification = "🚨 CORRUPTED/ANOMALOUS"
        color_indicator = "🔴"
        relation = ">"
    else:
        classification = "✅ CLEAN/NORMAL"
        color_indicator = "🟢"
        relation = "≤"
    explanation = (
        f"Reconstruction error ({reconstruction_error:.6f}) {relation} "
        f"Threshold ({threshold:.6f})"
    )

    delta = reconstruction_error - threshold
    if threshold != 0:
        distance_pct = delta / threshold * 100
    else:
        # A zero threshold has no relative scale: report an infinite distance
        # signed like the difference, rather than dividing by zero.
        distance_pct = 0.0 if delta == 0 else math.copysign(math.inf, delta)

    abs_distance = abs(distance_pct)
    if abs_distance > 50:
        confidence_level = "High"
    elif abs_distance > 10:
        confidence_level = "Medium"
    else:
        confidence_level = "Low"

    lines = [
        "🎯 ANOMALY CLASSIFICATION",
        "",
        f"{color_indicator} Status: {classification}",
        "",
        "📊 Details:",
        f"Reconstruction Error: {reconstruction_error:.6f}",
        f"Threshold: {threshold:.6f}",
        f"Distance from Threshold: {distance_pct:+.2f}%",
        "",
        "📝 Explanation:",
        explanation,
        "",
        "💡 Confidence Indicator:",
        "• Distance > 50%: High confidence",
        "• Distance 10-50%: Medium confidence",
        "• Distance < 10%: Low confidence (near threshold)",
        "",
        f"🎚️ Current Distance: {abs_distance:.2f}% ({confidence_level} confidence)",
    ]
    return "\n".join(lines), is_corrupted
|
|
|
|
|
def process_image(model_name, image, corruption_type, intensity, threshold):
    """Run the full analysis pipeline for one Gradio request.

    The steps are:

    1. Make *model_name* the active detector.
    2. Optionally corrupt the uploaded image.
    3. Measure its reconstruction error.
    4. Analyse its embedding.
    5. Classify it against *threshold*.

    Returns:
        A 6-tuple matching the Gradio outputs: ``(corrupted_image,
        embedding_plot, status_text, reconstruction_error, stats_text,
        classification_text)``. On any failure both images are ``None``, the
        error is ``0.0``, both text panels are empty and *status_text* holds
        the error message.
    """
    try:
        load_status = load_model(model_name)
        # load_model() reports failure only through its message text, but it
        # assigns current_detector/current_model_name only after a successful
        # load. Checking that state is reliable however the message is worded
        # or encoded, unlike searching it for an emoji.
        if current_detector is None or current_model_name != model_name:
            return None, None, load_status, 0.0, "", ""

        if image is None:
            return None, None, "❌ Please upload an image", 0.0, "", ""

        if corruption_type == "none":
            corrupted_image = image.copy()
            corruption_info = "No corruption applied"
        else:
            corrupted_image = corrupt_image(image, corruption_type, intensity)
            corruption_info = f"Applied {corruption_type} corruption (intensity: {intensity})"

        # Coerce to a built-in float so gr.Number and the format specs below
        # accept it whatever scalar type the detector returns.
        error = float(current_detector.calculate_reconstruction_error(corrupted_image))

        embedding, stats_dict, plot_image = get_embedding_and_stats(corrupted_image)
        stats_text = format_stats_text(stats_dict)
        classification_text, is_corrupted = classify_image(error, threshold)

        embedding_dim = len(embedding) if embedding is not None else 'N/A'
        status = "\n".join([
            "✅ Processing complete!",
            f"📋 Model: {model_name}",
            f"🔧 {corruption_info}",
            f"📊 Reconstruction Error: {error:.6f}",
            f"🎚️ Threshold: {threshold:.6f}",
            f"🎯 Classification: {'CORRUPTED' if is_corrupted else 'CLEAN'}",
            f"🧠 Embedding Dimension: {embedding_dim}",
            "💡 Higher error = more anomalous",
        ])

        return corrupted_image, plot_image, status, error, stats_text, classification_text

    except Exception as e:
        # UI boundary: show the failure in the status box instead of letting
        # the exception escape to Gradio.
        return None, None, f"❌ Error: {e}", 0.0, "", ""
|
|
|
|
|
|
|
def create_interface():
    """Build the Gradio Blocks UI and wire the analyse button to process_image.

    Returns:
        The (not yet launched) ``gr.Blocks`` application.
    """
    # (path, caption) pairs for the example gallery. Only the path is passed
    # to gr.Examples: it has a single input component, and Gradio ignores
    # example columns beyond the number of inputs, so the captions serve as
    # documentation for readers of this file.
    example_images = [
        ("example_imgs/TypeIa_LaosNÂ°15_Image21-25.png", "Clean Daudon MIX-Subtype_Ia"),
        ("example_imgs/72222-SectionIVa+WK maj_0009-60.png", "Clean Daudon MIX-Subtype_IVa"),
        ("example_imgs/TypeIVa2_N47583_Notteb-11.png", "Clean Daudon MIX-Subtype_IVa2"),
        ("example_imgs/typIVc_IVbsectbis-43.png", "Clean Daudon MIX-Subtype_IVc"),
        ("example_imgs/TypeIVd_Sect_LC3373-65.png", "Clean Daudon MIX-Subtype_IVd"),
        ("example_imgs/Section_Va_72845-3-18.png", "Clean Daudon MIX-Subtype_Va"),
    ]

    with gr.Blocks(title="Anomaly Detection Tester") as demo:
        gr.Markdown("# 🔍 Federated Autoencoder for Kidney Stone Image Corruption Detection")
        gr.Markdown("Upload an image, analyze its latent representation, and classify it as corrupted or clean using a threshold.")

        with gr.Row():
            # Left column: every input control plus the run button.
            with gr.Column(scale=1):
                gr.Markdown("### ⚙️ Model & Corruption Settings")

                model_dropdown = gr.Dropdown(
                    choices=list(models),
                    value="Daudon_MIX",
                    label="🤖 Select Model",
                )

                corruption_dropdown = gr.Dropdown(
                    choices=["none", "noise", "blur", "brightness", "contrast", "saturation", "random"],
                    value="none",
                    label="🔧 Corruption Type",
                )

                intensity_slider = gr.Slider(
                    minimum=0.1,
                    maximum=3.0,
                    value=1.0,
                    step=0.1,
                    label="💪 Corruption Intensity",
                )

                gr.Markdown("### 🎚️ Classification Settings")

                threshold_slider = gr.Slider(
                    minimum=0.1,
                    maximum=3.0,
                    value=1.0,
                    step=0.1,
                    label="🎯 Anomaly Threshold (Reconstruction Error)",
                )

                gr.Markdown("### 📸 Image Input")

                image_input = gr.Image(type="pil", label="Upload Image")

                gr.Markdown("### 📁 Example Images")

                gr.Examples(
                    examples=[[path] for path, _caption in example_images],
                    inputs=image_input,
                    label="Daudon MIX Example Clean Images",
                    examples_per_page=6,
                    cache_examples=False,
                )

                process_btn = gr.Button("🚀 Analyze & Classify", variant="primary", size="lg")

            # Right column: status summary and the (possibly corrupted) image.
            with gr.Column(scale=1):
                gr.Markdown("### 📊 Results")

                status_output = gr.Textbox(label="📋 Status", lines=8)
                error_output = gr.Number(label="📊 Reconstruction Error", precision=6)

                corrupted_output = gr.Image(label="🔧 Input Image (Corrupted)")

        with gr.Row():
            embedding_plot = gr.Image(label="🧠 Embedding Analysis")

        with gr.Row():
            stats_output = gr.Textbox(label="📊 Embedding Statistics", lines=20)
            classification_output = gr.Textbox(label="🎯 Classification Result", lines=15)

        # Output order must match the tuple returned by process_image.
        process_btn.click(
            fn=process_image,
            inputs=[model_dropdown, image_input, corruption_dropdown, intensity_slider, threshold_slider],
            outputs=[corrupted_output, embedding_plot, status_output, error_output, stats_output, classification_output],
        )

    return demo
|
|
|
|
|
# Script entry point: build the UI and start the Gradio server.
if __name__ == "__main__":
    print("🚀 Starting Embedding Analysis App...")
    demo = create_interface()
    demo.launch()