import os
import torch
import timm
import numpy as np
import gradio as gr
from PIL import Image, ImageDraw
import cv2

import torch.nn as nn
import torchvision.models as models
from torchvision import transforms
from torchvision.models import ResNet18_Weights
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
from torchvision.transforms import functional as F

import segmentation_models_pytorch as smp

from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

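# Run all models on CPU (no GPU assumed for this demo)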
device = torch.device("cpu")

# ResNet18Classifier definition - just in case I decide to use it
class ResNet18Classifier(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.model = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        in_features = self.model.fc.in_features
        self.model.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        return self.model(x)

# For classification
class_names = {
    0: 'Casual dresses',
    1: 'Evening dresses',
    2: 'Jersey dresses',
    3: 'Knitted dresses',
    4: 'Maxi dresses',
    5: 'Midi dresses',
    6: 'Mini dresses',
    7: 'Occasion dresses',
    8: 'Shirt dresses',
    9: 'Skater dresses'
}

# Global models - will be loaded once
detection_model, segmentation_model, classification_model, gradcam = (None,) * 4


def load_models():
    global detection_model, segmentation_model, classification_model, gradcam

    try:
        print("Loading human detection model...")
        detection_model = fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.COCO_V1)
        detection_model.to(device)
        detection_model.eval()
        print("โœ“ Detection model loaded")

        print("Loading segmentation model...")
        segmentation_model = smp.Unet(
            encoder_name="resnet34",
            encoder_weights="imagenet",
            in_channels=3,
            classes=1,
            activation=None
        ).to(device)

        if os.path.exists('best_model.pth'):
            segmentation_model.load_state_dict(torch.load('best_model.pth', map_location=device))
            print("✓ Loaded custom segmentation weights")
        else:
            print("⚠️ Using ImageNet pre-trained weights for segmentation (custom weights not found)")

        segmentation_model.eval()
        print("✓ Segmentation model loaded")

        # --- Classification model ---
        print("Loading classification model...")

        # Lighter alternative (unused): ResNet18 with a remapped state dict
        # classification_model = ResNet18Classifier(num_classes=10)
        # model_path_class_model = "class_model.pth"
        # if os.path.exists(model_path_class_model):
        #     try:
        #         state_dict = torch.load(model_path_class_model, map_location=torch.device('cpu'))
        #         new_state_dict = {}
        #         for key, value in state_dict.items():
        #             new_key = f"model.{key}"
        #             new_state_dict[new_key] = value
        #         classification_model.load_state_dict(new_state_dict)
        #         print("✅ ResNet18 classification model weights loaded")
        #     except Exception as e:
        #         print(f"⚠️ Error loading ResNet18 classification weights: {e}")
        # else:
        #     print("⚠️ ResNet18 classification model weights not found, using pretrained initialization")
        #
        # classification_model = classification_model.to(device)
        # classification_model.eval()

        # Default: the heavier ConvNeXt V2 Base with fine-tuned weights
        classification_model = timm.create_model('convnextv2_base', pretrained=False)
        classification_model.reset_classifier(num_classes=10)
        classification_model.load_state_dict(torch.load("ConvNeXt.pth", map_location=torch.device('cpu')))
        classification_model = classification_model.to(device)
        classification_model.eval()

        print("✓ ConvNeXt V2 classification model loaded")

        print("Setting up Grad-CAM for the classification model...")
        target_layer = [m for m in classification_model.modules() if isinstance(m, torch.nn.Conv2d)][-1]
        gradcam = GradCAM(model=classification_model, target_layers=[target_layer])
        print("✓ Grad-CAM initialized")

        print("🎉 All models loaded successfully!")

    except Exception as e:
        print(f"โŒ Error loading models: {e}")
        raise e
    
def detect_human_boxes(image):
    """Detect human bounding boxes using Faster R-CNN"""
    image_tensor = F.to_tensor(image).unsqueeze(0).to(device)
    width, height = image.size

    with torch.no_grad():
        output = detection_model(image_tensor)[0]

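    # COCO label 1 is "person"; keep only detections with score > 0.5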
    person_indices = [
        i for i, label in enumerate(output['labels'])
        if label == 1 and output['scores'][i] > 0.5
    ]

    if not person_indices:
        return None, None

    top_idx = max(person_indices, key=lambda i: output['scores'][i])
    box = output['boxes'][top_idx].cpu().tolist()
    score = output['scores'][top_idx].item()

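    # Clamp the box coordinates to the image bounds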
    x1 = max(0, min(box[0], width))
    y1 = max(0, min(box[1], height))
    x2 = max(0, min(box[2], width))
    y2 = max(0, min(box[3], height))

    return [x1, y1, x2, y2], score

def preprocess_for_classification(image, mask, target_size=224):
    """Preprocess masked image for classification"""
    image_np = np.array(image) / 255.0 if not isinstance(image, torch.Tensor) else image.permute(1, 2, 0).cpu().numpy()
    mask_np = np.array(mask) if not isinstance(mask, torch.Tensor) else mask.cpu().numpy()

    if len(mask_np.shape) == 2:
        mask_np = mask_np[:, :, np.newaxis]

    masked_image = image_np * mask_np
    pil_img = Image.fromarray((masked_image * 255).astype(np.uint8))

    width, height = pil_img.size
    ratio = min(target_size / width, target_size / height)
    new_size = (int(width * ratio), int(height * ratio))
    resized_img = pil_img.resize(new_size, Image.BILINEAR)

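    # Letterbox onto a square canvas filled with (roughly) the ImageNet mean color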
    result = Image.new("RGB", (target_size, target_size), color=(123, 117, 104))
    result.paste(resized_img, ((target_size - new_size[0]) // 2, (target_size - new_size[1]) // 2))

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    return transform(result)

def process_image_pipeline(input_image):
    """Complete pipeline for processing an uploaded image"""
    try:
        if input_image is None:
            return None, None, "Please upload an image", None

        image = Image.fromarray(input_image).convert("RGB") if isinstance(input_image, np.ndarray) else input_image.convert("RGB")

        bbox, confidence = detect_human_boxes(image)
        if bbox is None:
            return image, None, None, "No human detected in the image"

        image_with_box = image.copy()
        draw = ImageDraw.Draw(image_with_box)
        draw.rectangle(bbox, outline="red", width=8)
        draw.text((bbox[0], bbox[1] - 20), f"Human: {confidence:.2f}", fill="red")

        x1, y1, x2, y2 = bbox
        cropped_image = image.crop((x1, y1, x2, y2))
        cropped_resized = cropped_image.resize((256, 256))

        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        input_tensor = transform(cropped_resized).unsqueeze(0).to(device)

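        # Segment on the 256x256 crop; sigmoid + 0.5 threshold gives a binary dress mask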
        with torch.no_grad():
            segmentation_output = segmentation_model(input_tensor)
            mask = (torch.sigmoid(segmentation_output) > 0.5).float()

        mask_np = mask[0, 0].cpu().numpy()
        mask_resized_original = cv2.resize(mask_np, (cropped_image.width, cropped_image.height))
        mask_image = Image.fromarray((mask_resized_original * 255).astype(np.uint8)).convert("RGB")

        processed_tensor = preprocess_for_classification(cropped_resized, mask_np).unsqueeze(0).to(device)

        with torch.no_grad():
            classification_output = classification_model(processed_tensor)
            predicted_class = torch.argmax(classification_output, dim=1).item()
            confidence_scores = torch.softmax(classification_output, dim=1)
            max_confidence = confidence_scores[0, predicted_class].item()

        predicted_category = f"{class_names[predicted_class]} (Confidence: {max_confidence:.2f})"

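        # Grad-CAM for the predicted class, restricted to the dress mask and renormalized to [0, 1]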
        grayscale_cam = gradcam(input_tensor=processed_tensor, targets=[ClassifierOutputTarget(predicted_class)])[0]

        mask_for_gradcam = cv2.resize(mask_np.astype(np.float32), grayscale_cam.shape[::-1])
        grayscale_cam *= mask_for_gradcam
        grayscale_cam = (grayscale_cam - grayscale_cam.min()) / (grayscale_cam.max() - grayscale_cam.min() + 1e-8)

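        # Re-project mask and heatmap to the original crop resolution for the overlay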
        cropped_np_original = np.array(cropped_image) / 255.0
        mask_for_original = cv2.resize(mask_np.astype(np.float32), (cropped_image.width, cropped_image.height))
        masked_image_np_original = cropped_np_original * mask_for_original[:, :, np.newaxis]

        gradcam_resized_original = cv2.resize(grayscale_cam, (cropped_image.width, cropped_image.height))

        heatmap_on_image = show_cam_on_image(
            masked_image_np_original,
            gradcam_resized_original,
            use_rgb=True,
            image_weight=0.6
        )

        heatmap_image = Image.fromarray(heatmap_on_image)

        return image_with_box, mask_image, heatmap_image, predicted_category

    except Exception as e:
        return None, None, f"Error processing image: {str(e)}", None

def create_interface():
    """Create Gradio interface"""
    with gr.Blocks(title="Dress Analysis Pipeline (DL Assignment 3)", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
            ## Dress Analysis Pipeline (DL Assignment 3)
            ### Author: Roman Rakov, University of Tรผbingen

            Upload an image to run it through the pipeline:

            1. **Human Detection** โ€“ Bounding box around detected person (*Faster R-CNN with ResNet-50 backbone*)
            2. **Dress Segmentation** โ€“ Extracted dress/clothing region (*U-Net with ResNet-34 encoder*)
            3. **Grad-CAM Heatmap** โ€“ Model attention visualization (*Grad-CAM for attention visualization*)
            4. **Classification** โ€“ Predicted clothing category (*ConvNeXt v2 Base*)
        """)

        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(label="Upload Image", type="pil", height=500)
                analyze_btn = gr.Button("Analyze Image", variant="primary", size="lg")

                gr.Markdown("### Some (well-behaved) example images to start with..")
                example_images = []
                for i in range(1, 11):
                    example_path = f"examples/example{i}.jpg"
                    if os.path.exists(example_path):
                        example_images.append([example_path])

                if example_images:
                    gr.Examples(examples=example_images, inputs=input_image)
                else:
                    gr.Markdown("*No example images found.*")

            with gr.Column(scale=2):
                with gr.Row():
                    output_detection = gr.Image(label="Human Detection", height=500)
                    output_segmentation = gr.Image(label="Dress Segmentation", height=500)

                with gr.Row():
                    output_gradcam = gr.Image(label="Grad-CAM Heatmap", height=500)
                    output_classification = gr.Textbox(label="Predicted Category", lines=2, max_lines=3, scale=1)

        analyze_btn.click(
            fn=process_image_pipeline,
            inputs=[input_image],
            outputs=[output_detection, output_segmentation, output_gradcam, output_classification]
        )

        input_image.change(
            fn=process_image_pipeline,
            inputs=[input_image],
            outputs=[output_detection, output_segmentation, output_gradcam, output_classification]
        )

    return demo


# Load models when the app starts
print("Loading models...")
load_models()
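
# Programmatic use (sketch): the pipeline can also be called without the UI, e.g.
#   boxed, mask, cam, label = process_image_pipeline(Image.open("examples/example1.jpg"))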

# Create and launch the interface
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()
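    # Optional: demo.launch(share=True) additionally creates a temporary public link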