import os
import torch
import timm
import numpy as np
import gradio as gr
from PIL import Image, ImageDraw
import cv2
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms
from torchvision.models import ResNet18_Weights
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
from torchvision.transforms import functional as F
import segmentation_models_pytorch as smp
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

device = torch.device("cpu")
# ResNet18Classifier definition - kept around in case I decide to use it
class ResNet18Classifier(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.model = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        in_features = self.model.fc.in_features
        self.model.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        return self.model(x)
# For classification
class_names = {
    0: 'Casual dresses',
    1: 'Evening dresses',
    2: 'Jersey dresses',
    3: 'Knitted dresses',
    4: 'Maxi dresses',
    5: 'Midi dresses',
    6: 'Mini dresses',
    7: 'Occasion dresses',
    8: 'Shirt dresses',
    9: 'Skater dresses'
}
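
# Illustrative sketch (not part of the original pipeline): a defensive label
# lookup in case the classifier head and this mapping ever drift apart.
# `UNKNOWN_LABEL` and `label_for` are hypothetical helpers added here purely
# as an example; nothing below calls them.
UNKNOWN_LABEL = "Unknown dress type"

def label_for(idx: int) -> str:
    return class_names.get(idx, UNKNOWN_LABEL)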
# Global models - will be loaded once
detection_model, segmentation_model, classification_model, gradcam = (None,) * 4
def load_models():
    global detection_model, segmentation_model, classification_model, gradcam
    try:
        print("Loading human detection model...")
        detection_model = fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.COCO_V1)
        detection_model.to(device)
        detection_model.eval()
        print("✅ Detection model loaded")

        print("Loading segmentation model...")
        segmentation_model = smp.Unet(
            encoder_name="resnet34",
            encoder_weights="imagenet",
            in_channels=3,
            classes=1,
            activation=None
        ).to(device)
        if os.path.exists('best_model.pth'):
            segmentation_model.load_state_dict(torch.load('best_model.pth', map_location=device))
            print("✅ Loaded custom segmentation weights")
        else:
            print("⚠️ Using ImageNet pre-trained weights for segmentation (custom weights not found)")
        segmentation_model.eval()
        print("✅ Segmentation model loaded")

        # --- Classification model ---
        print("Loading ConvNeXt V2 classification model...")
        # Lighter ResNet18 alternative, kept for reference:
        # classification_model = ResNet18Classifier(num_classes=10)
        # model_path_class_model = "class_model.pth"
        # if os.path.exists(model_path_class_model):
        #     try:
        #         state_dict = torch.load(model_path_class_model, map_location=torch.device('cpu'))
        #         # Checkpoint keys lack the "model." prefix of the wrapper class, so remap them
        #         new_state_dict = {f"model.{key}": value for key, value in state_dict.items()}
        #         classification_model.load_state_dict(new_state_dict)
        #         print("✅ ResNet18 classification model weights loaded")
        #     except Exception as e:
        #         print(f"⚠️ Error loading ResNet18 classification weights: {e}")
        # else:
        #     print("⚠️ ResNet18 classification model weights not found, using pretrained initialization")
        # classification_model = classification_model.to(device)
        # classification_model.eval()

        # Alternatively: load the heavier ConvNeXt V2 model
        classification_model = timm.create_model('convnextv2_base', pretrained=False)
        classification_model.reset_classifier(num_classes=10)
        classification_model.load_state_dict(torch.load("ConvNeXt.pth", map_location=torch.device('cpu')))
        classification_model = classification_model.to(device)
        classification_model.eval()
        print("✅ ConvNeXt V2 classification model loaded")

        print("Setting up Grad-CAM for the classification model...")
        # Target the last Conv2d in module order for the CAM
        target_layer = [m for m in classification_model.modules() if isinstance(m, torch.nn.Conv2d)][-1]
        gradcam = GradCAM(model=classification_model, target_layers=[target_layer])
        print("✅ Grad-CAM initialized")

        print("🎉 All models loaded successfully!")
    except Exception as e:
        print(f"❌ Error loading models: {e}")
        raise e
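
# Optional warm-up sketch (an assumption, not part of the original app): one
# dummy forward pass after load_models() can hide the slow first request on a
# CPU Space. Kept commented out so startup behavior is unchanged.
#   load_models()
#   with torch.no_grad():
#       detection_model([torch.zeros(3, 224, 224)])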
def detect_human_boxes(image):
    """Detect human bounding boxes using Faster R-CNN."""
    image_tensor = F.to_tensor(image).unsqueeze(0).to(device)
    width, height = image.size
    with torch.no_grad():
        output = detection_model(image_tensor)[0]
    # COCO label 1 is "person"; keep only confident detections
    person_indices = [
        i for i, label in enumerate(output['labels'])
        if label == 1 and output['scores'][i] > 0.5
    ]
    if not person_indices:
        return None, None
    top_idx = max(person_indices, key=lambda i: output['scores'][i])
    box = output['boxes'][top_idx].cpu().tolist()
    score = output['scores'][top_idx].item()
    # Clamp the box to the image bounds
    x1 = max(0, min(box[0], width))
    y1 = max(0, min(box[1], height))
    x2 = max(0, min(box[2], width))
    y2 = max(0, min(box[3], height))
    return [x1, y1, x2, y2], score
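
# Debugging sketch (an addition, not used by the app): list every "person"
# detection above a score threshold, using the same output format that
# detect_human_boxes relies on. Handy when the top-scoring box looks off.
def debug_person_detections(image, score_threshold=0.5):
    image_tensor = F.to_tensor(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = detection_model(image_tensor)[0]
    for label, score, box in zip(output['labels'], output['scores'], output['boxes']):
        if label.item() == 1 and score.item() > score_threshold:
            print(f"person {score.item():.2f} at {[round(v, 1) for v in box.tolist()]}")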
def preprocess_for_classification(image, mask, target_size=224):
    """Mask the crop, letterbox it to a square, and normalize for classification."""
    if isinstance(image, torch.Tensor):
        image_np = image.permute(1, 2, 0).cpu().numpy()
    else:
        image_np = np.array(image) / 255.0
    mask_np = mask.cpu().numpy() if isinstance(mask, torch.Tensor) else np.array(mask)
    if len(mask_np.shape) == 2:
        mask_np = mask_np[:, :, np.newaxis]
    masked_image = image_np * mask_np
    pil_img = Image.fromarray((masked_image * 255).astype(np.uint8))
    # Letterbox: scale to fit, then center on a square canvas filled with the ImageNet mean color
    width, height = pil_img.size
    ratio = min(target_size / width, target_size / height)
    new_size = (int(width * ratio), int(height * ratio))
    resized_img = pil_img.resize(new_size, Image.BILINEAR)
    result = Image.new("RGB", (target_size, target_size), color=(123, 117, 104))
    result.paste(resized_img, ((target_size - new_size[0]) // 2, (target_size - new_size[1]) // 2))
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return transform(result)
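
# Quick shape check (illustrative; the dummy arrays below are assumptions made
# up for this example). The letterboxing means the output is always
# 3 x target_size x target_size regardless of the input aspect ratio:
#   dummy = Image.fromarray(np.zeros((256, 192, 3), dtype=np.uint8))
#   dummy_mask = np.ones((256, 192), dtype=np.float32)
#   assert preprocess_for_classification(dummy, dummy_mask).shape == (3, 224, 224)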
def process_image_pipeline(input_image):
    """Complete pipeline for processing an uploaded image."""
    try:
        if input_image is None:
            return None, None, None, "Please upload an image"
        image = Image.fromarray(input_image).convert("RGB") if isinstance(input_image, np.ndarray) else input_image.convert("RGB")

        # Step 1: human detection
        bbox, confidence = detect_human_boxes(image)
        if bbox is None:
            return image, None, None, "No human detected in the image"
        image_with_box = image.copy()
        draw = ImageDraw.Draw(image_with_box)
        draw.rectangle(bbox, outline="red", width=8)
        draw.text((bbox[0], bbox[1] - 20), f"Human: {confidence:.2f}", fill="red")

        # Step 2: dress segmentation on the cropped person
        x1, y1, x2, y2 = bbox
        cropped_image = image.crop((x1, y1, x2, y2))
        cropped_resized = cropped_image.resize((256, 256))
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        input_tensor = transform(cropped_resized).unsqueeze(0).to(device)
        with torch.no_grad():
            segmentation_output = segmentation_model(input_tensor)
            mask = (torch.sigmoid(segmentation_output) > 0.5).float()
        mask_np = mask[0, 0].cpu().numpy()
        mask_resized_original = cv2.resize(mask_np, (cropped_image.width, cropped_image.height))
        mask_image = Image.fromarray((mask_resized_original * 255).astype(np.uint8)).convert("RGB")

        # Step 3: classification on the masked crop
        processed_tensor = preprocess_for_classification(cropped_resized, mask_np).unsqueeze(0).to(device)
        with torch.no_grad():
            classification_output = classification_model(processed_tensor)
            predicted_class = torch.argmax(classification_output, dim=1).item()
            confidence_scores = torch.softmax(classification_output, dim=1)
            max_confidence = confidence_scores[0, predicted_class].item()
        predicted_category = f"{class_names[predicted_class]} (Confidence: {max_confidence:.2f})"

        # Step 4: Grad-CAM, restricted to the segmented region and renormalized
        grayscale_cam = gradcam(input_tensor=processed_tensor, targets=[ClassifierOutputTarget(predicted_class)])[0]
        mask_for_gradcam = cv2.resize(mask_np.astype(np.float32), grayscale_cam.shape[::-1])
        grayscale_cam *= mask_for_gradcam
        grayscale_cam = (grayscale_cam - grayscale_cam.min()) / (grayscale_cam.max() - grayscale_cam.min() + 1e-8)
        cropped_np_original = np.array(cropped_image) / 255.0
        mask_for_original = cv2.resize(mask_np.astype(np.float32), (cropped_image.width, cropped_image.height))
        masked_image_np_original = cropped_np_original * mask_for_original[:, :, np.newaxis]
        gradcam_resized_original = cv2.resize(grayscale_cam, (cropped_image.width, cropped_image.height))
        heatmap_on_image = show_cam_on_image(
            masked_image_np_original,
            gradcam_resized_original,
            use_rgb=True,
            image_weight=0.6
        )
        heatmap_image = Image.fromarray(heatmap_on_image)

        return image_with_box, mask_image, heatmap_image, predicted_category
    except Exception as e:
        # Route error text to the fourth output (the textbox), not an image slot
        return None, None, None, f"Error processing image: {str(e)}"
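
# Headless smoke test (illustrative; assumes load_models() has already run and
# that the example image below exists, as the interface code also checks):
#   result = process_image_pipeline(Image.open("examples/example1.jpg"))
#   print(result[-1])  # predicted category string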
def create_interface():
    """Create the Gradio interface."""
    with gr.Blocks(title="Dress Analysis Pipeline (DL Assignment 3)", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        ## Dress Analysis Pipeline (DL Assignment 3)
        ### Author: Roman Rakov, University of Tübingen
        Upload an image to run it through the pipeline:
        1. **Human Detection** → Bounding box around the detected person (*Faster R-CNN with ResNet-50 backbone*)
        2. **Dress Segmentation** → Extracted dress/clothing region (*U-Net with ResNet-34 encoder*)
        3. **Grad-CAM Heatmap** → Model attention visualization (*Grad-CAM on the classifier*)
        4. **Classification** → Predicted clothing category (*ConvNeXt V2 Base*)
        """)
        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(label="Upload Image", type="pil", height=500)
                analyze_btn = gr.Button("Analyze Image", variant="primary", size="lg")
                gr.Markdown("### Some (well-behaved) example images to start with...")
                example_images = []
                for i in range(1, 11):
                    example_path = f"examples/example{i}.jpg"
                    if os.path.exists(example_path):
                        example_images.append([example_path])
                if example_images:
                    gr.Examples(examples=example_images, inputs=input_image)
                else:
                    gr.Markdown("*No example images found.*")
            with gr.Column(scale=2):
                with gr.Row():
                    output_detection = gr.Image(label="Human Detection", height=500)
                    output_segmentation = gr.Image(label="Dress Segmentation", height=500)
                with gr.Row():
                    output_gradcam = gr.Image(label="Grad-CAM Heatmap", height=500)
                    output_classification = gr.Textbox(label="Predicted Category", lines=2, max_lines=3, scale=1)
        analyze_btn.click(
            fn=process_image_pipeline,
            inputs=[input_image],
            outputs=[output_detection, output_segmentation, output_gradcam, output_classification]
        )
        input_image.change(
            fn=process_image_pipeline,
            inputs=[input_image],
            outputs=[output_detection, output_segmentation, output_gradcam, output_classification]
        )
    return demo
# Load models when the app starts
print("Loading models...")
load_models()

# Create and launch the interface
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()
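    # Optional launch tweaks (assumptions, not from the original deployment):
    #   demo.queue().launch()     # queue requests for slow CPU inference
    #   demo.launch(share=True)   # temporary public link when running locally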