import os
import json
import torch
import timm
import numpy as np
import gradio as gr
from PIL import Image, ImageDraw
import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from io import BytesIO
import base64
import torchvision
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
from torchvision.transforms import functional as F
from torch.nn.functional import interpolate
import segmentation_models_pytorch as smp
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

device = torch.device("cpu")

# Category id torchvision's COCO-trained Faster R-CNN uses for "person".
_COCO_PERSON_LABEL = 1

# ImageNet normalisation shared by the segmentation and classification inputs.
# Built once at import time instead of on every request.
_IMAGENET_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


# ResNet18Classifier definition - kept as an optional, lighter alternative to ConvNeXt V2.
class ResNet18Classifier(nn.Module):
    """ResNet-18 backbone with ImageNet weights and a fresh linear head of ``num_classes`` outputs."""

    def __init__(self, num_classes=10):
        super().__init__()
        # ``ResNet18_Weights`` is not imported at file level; reaching it through
        # ``models`` avoids the NameError the bare name raised.
        self.model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        in_features = self.model.fc.in_features
        self.model.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        return self.model(x)


# For classification
class_names = {
    0: 'Casual dresses',
    1: 'Evening dresses',
    2: 'Jersey dresses',
    3: 'Knitted dresses',
    4: 'Maxi dresses',
    5: 'Midi dresses',
    6: 'Mini dresses',
    7: 'Occasion dresses',
    8: 'Shirt dresses',
    9: 'Skater dresses',
}

# Global models - will be loaded once
detection_model, segmentation_model, classification_model, gradcam = (None,) * 4


def _load_resnet18_classifier(weights_path="class_model.pth"):
    """Build the optional ResNet18Classifier and load fine-tuned weights when available.

    Not called by default (``load_models`` uses ConvNeXt V2). Every checkpoint key is
    prefixed with ``model.`` so it matches the wrapper's ``self.model`` attribute;
    this presumes the checkpoint stores bare torchvision ResNet-18 keys - confirm
    against the training script.
    """
    model = ResNet18Classifier(num_classes=len(class_names))
    if os.path.exists(weights_path):
        try:
            state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
            model.load_state_dict({f"model.{key}": value for key, value in state_dict.items()})
            print("✅ ResNet18 classification model weights loaded")
        except Exception as e:
            # Best effort: fall back to the ImageNet-initialised backbone.
            print(f"⚠️ Error loading ResNet18 classification weights: {e}")
    else:
        print("⚠️ ResNet18 classification model weights not found, using pretrained initialization")
    return model.to(device).eval()


def load_models():
    """Load detection, segmentation and classification models plus Grad-CAM into module globals.

    Any failure (e.g. a missing ``ConvNeXt.pth``) is printed and re-raised so the app
    fails at start-up rather than on the first request.
    """
    global detection_model, segmentation_model, classification_model, gradcam

    try:
        print("Loading human detection model...")
        detection_model = fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.COCO_V1)
        detection_model.to(device)
        detection_model.eval()
        print("✓ Detection model loaded")

        print("Loading segmentation model...")
        segmentation_model = smp.Unet(
            encoder_name="resnet34",
            encoder_weights="imagenet",
            in_channels=3,
            classes=1,
            activation=None,
        ).to(device)
        if os.path.exists('best_model.pth'):
            segmentation_model.load_state_dict(torch.load('best_model.pth', map_location=device))
            print("✓ Loaded custom segmentation weights")
        else:
            print("⚠️ Using ImageNet pre-trained weights for segmentation (custom weights not found)")
        segmentation_model.eval()
        print("✓ Segmentation model loaded")

        # --- Classification model: ConvNeXt V2 Base (the heavier one). ---
        # Swap in ``_load_resnet18_classifier()`` to use the lighter ResNet-18 instead.
        print("Loading ResNet18 or ConvNext V2 classification model...")
        classification_model = timm.create_model('convnextv2_base', pretrained=False)
        classification_model.reset_classifier(num_classes=len(class_names))
        classification_model.load_state_dict(torch.load("ConvNeXt.pth", map_location=torch.device('cpu')))
        classification_model = classification_model.to(device)
        classification_model.eval()
        print("✓ ResNet18 (or ConvNext V2) classification model loaded")

        print("Setting up Grad-CAM for the classification model...")
        # The last Conv2d module in the network is used as the Grad-CAM target layer.
        target_layer = [m for m in classification_model.modules() if isinstance(m, torch.nn.Conv2d)][-1]
        gradcam = GradCAM(model=classification_model, target_layers=[target_layer])
        print("✓ Grad-CAM initialized")

        print("🎉 All models loaded successfully!")
    except Exception as e:
        print(f"❌ Error loading models: {e}")
        raise


def detect_human_boxes(image, score_threshold=0.5):
    """Detect the most confident person in ``image`` with Faster R-CNN.

    Args:
        image: PIL image (``image.size`` is read as ``(width, height)``).
        score_threshold: minimum detection score for a person box to be considered.

    Returns:
        ``([x1, y1, x2, y2], score)`` with coordinates clamped to the image bounds, or
        ``(None, None)`` when no person scores above ``score_threshold``.
    """
    image_tensor = F.to_tensor(image).unsqueeze(0).to(device)
    width, height = image.size

    with torch.no_grad():
        output = detection_model(image_tensor)[0]

    scores = output['scores']
    keep = (output['labels'] == _COCO_PERSON_LABEL) & (scores > score_threshold)
    if not bool(keep.any()):
        return None, None

    candidate_indices = torch.nonzero(keep).flatten()
    # argmax returns the first maximum, matching the previous max()-over-indices tie-breaking.
    top_idx = int(candidate_indices[scores[candidate_indices].argmax()])
    box = output['boxes'][top_idx].cpu().tolist()
    score = scores[top_idx].item()

    x1 = max(0, min(box[0], width))
    y1 = max(0, min(box[1], height))
    x2 = max(0, min(box[2], width))
    y2 = max(0, min(box[3], height))
    return [x1, y1, x2, y2], score


def preprocess_for_classification(image, mask, target_size=224):
    """Mask an image, letterbox it into a square and normalise it for the classifier.

    Args:
        image: PIL image with 0-255 pixels, or a CHW tensor (presumably already in
            [0, 1] - confirm for tensor callers).
        mask: (H, W) or (H, W, 1) array/PIL image/tensor multiplied pixel-wise into ``image``.
        target_size: side length of the square output canvas.

    Returns:
        Normalised float tensor of shape (3, target_size, target_size).
    """
    if isinstance(image, torch.Tensor):
        image_np = image.permute(1, 2, 0).cpu().numpy()
    else:
        image_np = np.array(image) / 255.0
    mask_np = mask.cpu().numpy() if isinstance(mask, torch.Tensor) else np.array(mask)
    if mask_np.ndim == 2:
        mask_np = mask_np[:, :, np.newaxis]

    masked_image = image_np * mask_np
    # Clip before the uint8 cast so out-of-range values saturate instead of wrapping around.
    pil_img = Image.fromarray((np.clip(masked_image, 0.0, 1.0) * 255).astype(np.uint8))

    width, height = pil_img.size
    ratio = min(target_size / width, target_size / height)
    # max(1, ...) keeps very elongated inputs from producing a zero-sized side.
    new_size = (max(1, int(width * ratio)), max(1, int(height * ratio)))
    resized_img = pil_img.resize(new_size, Image.BILINEAR)

    # Centre on a square canvas filled with roughly the ImageNet mean colour (mean * 255).
    result = Image.new("RGB", (target_size, target_size), color=(123, 117, 104))
    result.paste(resized_img, ((target_size - new_size[0]) // 2, (target_size - new_size[1]) // 2))

    return _IMAGENET_TRANSFORM(result)


def _status_outputs(message, detection_image=None):
    """Return the UI 4-tuple for an early exit, routing ``message`` to the text box.

    Order matches the Gradio outputs wired in ``create_interface``:
    (detection image, segmentation image, Grad-CAM image, category text).
    """
    return detection_image, None, None, message


def process_image_pipeline(input_image):
    """Run detection -> segmentation -> classification -> Grad-CAM on one uploaded image.

    Args:
        input_image: PIL image or numpy array from the Gradio image input, or ``None``.

    Returns:
        ``(image_with_box, mask_image, heatmap_image, predicted_category)`` in the order of
        the Gradio outputs. On early exit or error, images that could not be produced are
        ``None`` and the status message is returned in the category (text) slot.
    """
    try:
        if input_image is None:
            return _status_outputs("Please upload an image")

        if isinstance(input_image, np.ndarray):
            image = Image.fromarray(input_image).convert("RGB")
        else:
            image = input_image.convert("RGB")

        # 1. Human detection.
        bbox, confidence = detect_human_boxes(image)
        if bbox is None:
            return _status_outputs("No human detected in the image", detection_image=image)

        image_with_box = image.copy()
        draw = ImageDraw.Draw(image_with_box)
        draw.rectangle(bbox, outline="red", width=8)
        # Clamp the label's y so it stays visible when the box touches the top edge.
        draw.text((bbox[0], max(0, bbox[1] - 20)), f"Human: {confidence:.2f}", fill="red")

        x1, y1, x2, y2 = bbox
        cropped_image = image.crop((x1, y1, x2, y2))
        crop_size = (cropped_image.width, cropped_image.height)  # cv2.resize expects (width, height)

        # 2. Dress segmentation on a fixed 256x256 version of the crop.
        cropped_resized = cropped_image.resize((256, 256))
        input_tensor = _IMAGENET_TRANSFORM(cropped_resized).unsqueeze(0).to(device)
        with torch.no_grad():
            segmentation_output = segmentation_model(input_tensor)
        mask = (torch.sigmoid(segmentation_output) > 0.5).float()
        mask_np = mask[0, 0].cpu().numpy()

        # The crop-sized mask is used both for display and for masking the heatmap background.
        mask_original = cv2.resize(mask_np, crop_size)
        mask_image = Image.fromarray((mask_original * 255).astype(np.uint8)).convert("RGB")

        # 3. Classification on the masked, letterboxed crop.
        processed_tensor = preprocess_for_classification(cropped_resized, mask_np).unsqueeze(0).to(device)
        with torch.no_grad():
            classification_output = classification_model(processed_tensor)
        predicted_class = torch.argmax(classification_output, dim=1).item()
        confidence_scores = torch.softmax(classification_output, dim=1)
        max_confidence = confidence_scores[0, predicted_class].item()
        predicted_category = f"{class_names[predicted_class]} (Confidence: {max_confidence:.2f})"

        # 4. Grad-CAM for the predicted class (runs outside no_grad because it needs
        #    gradients), restricted to the segmented region and rescaled to [0, 1].
        grayscale_cam = gradcam(
            input_tensor=processed_tensor,
            targets=[ClassifierOutputTarget(predicted_class)],
        )[0]
        mask_for_gradcam = cv2.resize(mask_np.astype(np.float32), grayscale_cam.shape[::-1])
        grayscale_cam *= mask_for_gradcam
        grayscale_cam = (grayscale_cam - grayscale_cam.min()) / (grayscale_cam.max() - grayscale_cam.min() + 1e-8)

        masked_crop = (np.array(cropped_image) / 255.0) * mask_original[:, :, np.newaxis]
        gradcam_original = cv2.resize(grayscale_cam, crop_size)
        heatmap_on_image = show_cam_on_image(
            masked_crop,
            gradcam_original,
            use_rgb=True,
            image_weight=0.6,
        )
        heatmap_image = Image.fromarray(heatmap_on_image)

        return image_with_box, mask_image, heatmap_image, predicted_category
    except Exception as e:
        # Best-effort UI: log for the server console and show the failure in the text box.
        print(f"❌ Error processing image: {e}")
        return _status_outputs(f"Error processing image: {str(e)}")


def create_interface():
    """Create the Gradio Blocks interface wired to ``process_image_pipeline``."""
    with gr.Blocks(title="Dress Analysis Pipeline (DL Assignment 3)", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        ## Dress Analysis Pipeline (DL Assignment 3)
        ### Author: Roman Rakov, University of Tübingen

        Upload an image to run it through the pipeline:

        1. **Human Detection** – Bounding box around detected person (*Faster R-CNN with ResNet-50 backbone*)
        2. **Dress Segmentation** – Extracted dress/clothing region (*U-Net with ResNet-34 encoder*)
        3. **Grad-CAM Heatmap** – Model attention visualization (*Grad-CAM for attention visualization*)
        4. **Classification** – Predicted clothing category (*ConvNeXt v2 Base*)
        """)

        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(label="Upload Image", type="pil", height=500)
                analyze_btn = gr.Button("Analyze Image", variant="primary", size="lg")

                gr.Markdown("### Some (well-behaved) example images to start with..")
                candidate_paths = (f"examples/example{i}.jpg" for i in range(1, 11))
                example_images = [[path] for path in candidate_paths if os.path.exists(path)]
                if example_images:
                    gr.Examples(examples=example_images, inputs=input_image)
                else:
                    gr.Markdown("*No example images found.*")

            with gr.Column(scale=2):
                with gr.Row():
                    output_detection = gr.Image(label="Human Detection", height=500)
                    output_segmentation = gr.Image(label="Dress Segmentation", height=500)
                with gr.Row():
                    output_gradcam = gr.Image(label="Grad-CAM Heatmap", height=500)
                    output_classification = gr.Textbox(label="Predicted Category", lines=2, max_lines=3, scale=1)

        # Order must match the 4-tuple returned by process_image_pipeline.
        pipeline_outputs = [output_detection, output_segmentation, output_gradcam, output_classification]
        analyze_btn.click(
            fn=process_image_pipeline,
            inputs=[input_image],
            outputs=pipeline_outputs,
        )
        input_image.change(
            fn=process_image_pipeline,
            inputs=[input_image],
            outputs=pipeline_outputs,
        )

    return demo


# Load models when the app starts
print("Loading models...")
load_models()

# Create and launch the interface
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()