""" Simple Anomaly Detector using Reconstruction Error A minimal implementation for testing corruption intensity using autoencoder reconstruction error """ import torch import torch.nn as nn import numpy as np from PIL import Image import torchvision.transforms as transforms from typing import Union import random from models import Autoencoder from utils.data_utils import ImageCorruption import config def apply_corruption(image_tensor: torch.Tensor, corruption_type: str = 'random') -> torch.Tensor: """ Simple function to apply corruption to an image tensor Args: image_tensor: Input image tensor (C, H, W) corruption_type: Type of corruption ('noise', 'blur', 'brightness', 'contrast', 'random') Returns: Corrupted image tensor """ # Create corruption object with 100% probability to ensure corruption is applied corruptor = ImageCorruption(corruption_prob=1.0) if corruption_type == 'noise': return corruptor.gaussian_noise(image_tensor.clone()) elif corruption_type == 'blur': return corruptor.blur(image_tensor.clone()) elif corruption_type == 'brightness': return corruptor.brightness_change(image_tensor.clone()) elif corruption_type == 'contrast': return corruptor.contrast_change(image_tensor.clone()) elif corruption_type == 'random': return corruptor.apply_random_corruption(image_tensor.clone()) else: raise ValueError(f"Unknown corruption type: {corruption_type}") class SimpleAnomalyDetector: """Simple anomaly detector based on reconstruction error""" def __init__(self, model_path: str): """ Initialize the detector with a trained autoencoder Args: model_path: Path to the trained autoencoder (.pth file) """ self.device = torch.device(config.DEVICE) self.model = self._load_model(model_path) self.criterion = nn.MSELoss() # Image preprocessing - simplified and more robust self.transform = transforms.Compose([ transforms.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) print(f"โœ… Anomaly detector ready! Using device: {self.device}") print(f"๐Ÿ“ Image size: {config.IMAGE_SIZE}x{config.IMAGE_SIZE}") def _load_model(self, model_path: str) -> Autoencoder: """Load the trained autoencoder model""" print(f"๐Ÿ“ฅ Loading model from {model_path}") # Load checkpoint (weights_only=False for compatibility with saved metadata) checkpoint = torch.load(model_path, map_location=self.device, weights_only=False) # Create model with same architecture model = Autoencoder( input_channels=config.CHANNELS, latent_dim=config.LATENT_DIM ) # Load trained weights model.load_state_dict(checkpoint['model_state_dict']) model.to(self.device) model.eval() return model def calculate_reconstruction_error(self, image: Union[str, Image.Image, torch.Tensor]) -> float: """ Calculate reconstruction error for a single image Args: image: Can be: - String path to image file - PIL Image object - PyTorch tensor (C, H, W) or (1, C, H, W) Returns: Reconstruction error as a float (higher = more anomalous) """ # Get image size - handle both tuple and integer formats if isinstance(config.IMAGE_SIZE, tuple): target_size = config.IMAGE_SIZE # (256, 256) else: target_size = (config.IMAGE_SIZE, config.IMAGE_SIZE) # Convert input to tensor if isinstance(image, str): # Load from file path try: image_pil = Image.open(image).convert('RGB') # Resize the image properly image_pil = image_pil.resize(target_size, Image.LANCZOS) image_tensor = transforms.ToTensor()(image_pil) # Apply normalization normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) image_tensor = normalize(image_tensor).unsqueeze(0) # Add batch dimension except Exception as e: raise ValueError(f"Error loading image from {image}: {e}") elif isinstance(image, Image.Image): # PIL Image try: image_pil = image.convert('RGB') image_pil = image_pil.resize(target_size, Image.LANCZOS) image_tensor = transforms.ToTensor()(image_pil) normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) image_tensor = normalize(image_tensor).unsqueeze(0) except Exception as e: raise ValueError(f"Error processing PIL Image: {e}") elif isinstance(image, torch.Tensor): # PyTorch tensor if image.dim() == 3: # (C, H, W) image_tensor = image.unsqueeze(0) # Add batch dimension elif image.dim() == 4: # (1, C, H, W) image_tensor = image else: raise ValueError(f"Unexpected tensor dimensions: {image.shape}") else: raise ValueError(f"Unsupported image type: {type(image)}") # Move to device image_tensor = image_tensor.to(self.device) # Calculate reconstruction error with torch.no_grad(): reconstructed, _ = self.model(image_tensor) error = self.criterion(reconstructed, image_tensor) return error.item() def test_detector_example(): """Example usage of the simple anomaly detector""" # You need to specify the path to your trained model model_path = "models/All_Datasets_MIX/best_autoencoder_All_Datasets_MIX.pth" # Change this! try: # Initialize detector detector = SimpleAnomalyDetector(model_path) # Test with some images from your dataset from utils.data_utils import create_global_test_loader # Get a test loader test_loader = create_global_test_loader( datasets=["Michel Daudon (w256 1k v1)", "Jonathan El-Beze (w256 1k v1)"], subversions=["MIX"] ) print("\n๐Ÿงช Testing reconstruction errors:") print("=" * 50) # Test a few images for i, (images, labels) in enumerate(test_loader): if i >= 3: # Test only first 3 batches break for j in range(min(2, images.size(0))): # Test 2 images per batch clean_image = images[j] # Test clean image clean_error = detector.calculate_reconstruction_error(clean_image) # Test corrupted versions corrupted_noise = apply_corruption(clean_image, 'noise') corrupted_blur = apply_corruption(clean_image, 'blur') noise_error = detector.calculate_reconstruction_error(corrupted_noise) blur_error = detector.calculate_reconstruction_error(corrupted_blur) print(f"\nImage {i*2 + j + 1} (Class: {labels[j]}):") print(f" Clean: {clean_error:.6f}") print(f" Noise corrupted: {noise_error:.6f} (x{noise_error/clean_error:.2f})") print(f" Blur corrupted: {blur_error:.6f} (x{blur_error/clean_error:.2f})") print(f"\n๐Ÿ’ก Usage tip: Higher reconstruction error = more anomalous/corrupted") print(f" You can set a threshold (e.g., 0.01) above which images are considered anomalous") except FileNotFoundError: print(f"โŒ Model file not found: {model_path}") print(" Please update the model_path variable with your actual model file") except Exception as e: print(f"โŒ Error: {e}") if __name__ == "__main__": test_detector_example()