# Fed-AE-Kidney-Stone-Corruption-Detection / image_corruption_utils.py
# Ivanrs's picture
# Upload image_corruption_utils.py
# fa3612e verified
"""
Image Corruption Utilities for Anomaly Detection Testing
Simple functions to corrupt images and test anomaly detection
"""
import torch
import numpy as np
from PIL import Image, ImageFilter, ImageEnhance
import matplotlib.pyplot as plt
import random
from torchvision import transforms
def corrupt_image(image_path, corruption_type='random', intensity=1.0, save_path=None):
    """
    Apply a single corruption to an image and return the corrupted copy.

    Args:
        image_path: Path to the input image (str or path-like) or a PIL Image object
        corruption_type: Type of corruption ('noise', 'blur', 'brightness', 'contrast',
            'saturation', 'random'); 'random' picks one of the other five
        intensity: Intensity of corruption (0.1 = light, 1.0 = normal, 2.0 = heavy)
        save_path: Optional path to save corrupted image

    Returns:
        PIL Image: Corrupted image in RGB mode

    Raises:
        ValueError: If corruption_type is not one of the supported names
    """
    known_types = ('noise', 'blur', 'brightness', 'contrast', 'saturation')

    # Resolve and validate the corruption type before any file I/O so that a
    # typo fails immediately instead of after the image has been decoded.
    if corruption_type == 'random':
        corruption_type = random.choice(known_types)
        print(f"Applied random corruption: {corruption_type}")
    elif corruption_type not in known_types:
        raise ValueError(f"Unknown corruption type: {corruption_type}")

    # Load image. Anything that is not already a PIL Image is handed to
    # Image.open, so pathlib.Path inputs work as well as plain strings.
    if isinstance(image_path, Image.Image):
        image = image_path.convert('RGB')
    else:
        # convert() returns an independent image, so the file can be closed.
        with Image.open(image_path) as source:
            image = source.convert('RGB')

    if corruption_type == 'noise':
        # Add zero-mean Gaussian noise in the [0, 1] range.
        img_array = np.array(image).astype(np.float32) / 255.0
        noise = np.random.normal(0, 0.1 * intensity, img_array.shape)
        corrupted_array = np.clip(img_array + noise, 0, 1)
        # Round rather than truncate so the conversion back to 8-bit does not
        # bias every pixel downwards.
        corrupted_image = Image.fromarray(np.rint(corrupted_array * 255).astype(np.uint8))
    elif corruption_type == 'blur':
        # Apply Gaussian blur
        corrupted_image = image.filter(ImageFilter.GaussianBlur(radius=1.0 * intensity))
    else:
        enhancer_cls, gain = {
            'brightness': (ImageEnhance.Brightness, 0.5),
            'contrast': (ImageEnhance.Contrast, 0.7),
            'saturation': (ImageEnhance.Color, 0.8),
        }[corruption_type]
        # An ImageEnhance factor of 1.0 reproduces the input unchanged, so the
        # factor is offset from 1.0: intensity 0 is a no-op and the change
        # grows with intensity. (The previous "base + gain * intensity" form
        # evaluated to exactly 1.0 at the default intensity, i.e. no
        # corruption at all.)
        factor = 1.0 + gain * intensity
        corrupted_image = enhancer_cls(image).enhance(factor)

    # Save if path provided
    if save_path:
        corrupted_image.save(save_path)
        print(f"Corrupted image saved to: {save_path}")

    return corrupted_image
def compare_images(original_path, corruption_types=('noise', 'blur', 'brightness'), intensity=1.0):
    """
    Show the original image next to one corrupted version per corruption type.

    Args:
        original_path: Path to original image, or a PIL Image object
        corruption_types: Iterable of corruption types to apply
        intensity: Corruption intensity
    """
    # Load the original once, in RGB, so it is displayed in the same mode as
    # the corrupted versions and the file is not reopened for every panel.
    if isinstance(original_path, Image.Image):
        original = original_path.convert('RGB')
    else:
        with Image.open(original_path) as source:
            original = source.convert('RGB')

    corruption_types = list(corruption_types)
    panel_count = len(corruption_types) + 1

    # squeeze=False always yields a 2-D array of axes, so the single-panel
    # case (no corruption types) needs no special handling.
    _, axes_grid = plt.subplots(1, panel_count, figsize=(4 * panel_count, 4), squeeze=False)
    axes = axes_grid[0]

    # Show original
    axes[0].imshow(original)
    axes[0].set_title('Original')
    axes[0].axis('off')

    # Show corrupted versions
    for ax, corruption_type in zip(axes[1:], corruption_types):
        corrupted = corrupt_image(original, corruption_type, intensity)
        ax.imshow(corrupted)
        ax.set_title(f'{corruption_type.title()} (intensity: {intensity})')
        ax.axis('off')

    plt.tight_layout()
    plt.show()
def test_corruption_detection(detector, image_path, corruption_types=('noise', 'blur', 'brightness'), intensity=1.0):
    """
    Print the detector's reconstruction error for an image and its corrupted versions.

    Args:
        detector: Object exposing calculate_reconstruction_error(image), called
            with the path for the original and a PIL Image for each corruption
        image_path: Path to test image
        corruption_types: Iterable of corruption types to test
        intensity: Corruption intensity
    """
    # Emoji are written as escapes so the source stays ASCII and cannot be
    # mangled by a wrong file encoding again.
    print(f"\U0001F50D Testing anomaly detection on: {image_path}")
    print("=" * 60)

    # Test original image
    original_error = detector.calculate_reconstruction_error(image_path)
    print(f"Original image error: {original_error:.6f}")

    # Test corrupted versions
    for corruption_type in corruption_types:
        corrupted_image = corrupt_image(image_path, corruption_type, intensity)
        corrupted_error = detector.calculate_reconstruction_error(corrupted_image)
        # Avoid dividing by zero when the original reconstructs perfectly.
        multiplier = corrupted_error / original_error if original_error > 0 else float('inf')
        print(f"{corruption_type.title():12} error: {corrupted_error:.6f} ({multiplier:.2f}x original)")

    print("\n\U0001F4A1 Higher reconstruction error = more anomalous/corrupted")


# The "test_" prefix matches pytest's collection pattern; mark the helper so a
# test module that imports it does not try to run it as a test case.
test_corruption_detection.__test__ = False
def quick_corruption_test(image_path, save_corrupted=False):
    """
    Apply every corruption type once to a single image at intensity 1.0.

    Args:
        image_path: Path to test image
        save_corrupted: Whether to save each result as
            "corrupted_<type>.png" in the current working directory
    """
    corruption_types = ('noise', 'blur', 'brightness', 'contrast', 'saturation')

    # Emoji are written as escapes so the source stays ASCII and cannot be
    # mangled by a wrong file encoding again.
    print(f"\U0001F9EA Testing all corruption types on: {image_path}")
    print("=" * 50)

    for corruption_type in corruption_types:
        save_path = f"corrupted_{corruption_type}.png" if save_corrupted else None
        # The returned image is not needed here; the call is made for its
        # optional save side effect and to confirm the corruption runs.
        corrupt_image(image_path, corruption_type, intensity=1.0, save_path=save_path)
        print(f"\u2705 {corruption_type.title()} corruption applied")

    print("\n\U0001F4A1 Use compare_images() to visualize all corruptions side by side")
# When run as a script, print a short summary of the public helpers.
if __name__ == "__main__":
    # The check mark is written as an escape so the source stays ASCII and
    # cannot be mangled by a wrong file encoding again.
    print("\u2705 Image corruption utilities loaded!")
    print("\nAvailable functions:")
    print("- corrupt_image(image_path, corruption_type, intensity)")
    print("- compare_images(image_path, corruption_types, intensity)")
    print("- test_corruption_detection(detector, image_path, corruption_types, intensity)")
    print("- quick_corruption_test(image_path, save_corrupted)")