import gradio as gr
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2
from huggingface_hub import hf_hub_download
import io

# U-Net model definition (encoder-decoder with skip connections)
class UNET(nn.Module):
    def __init__(self, dropout_rate=0.1, ch=32):
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        def conv_block(in_channels, out_channels):
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.Dropout2d(p=dropout_rate),
                nn.Conv2d(out_channels, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.Dropout2d(p=dropout_rate)
            )
        
        self.encoder1 = conv_block(3, ch)
        self.encoder2 = conv_block(ch, ch*2)
        self.encoder3 = conv_block(ch*2, ch*4)
        self.encoder4 = conv_block(ch*4, ch*8)
        self.bottle_neck = conv_block(ch*8, ch*16)

        self.upsample1 = nn.ConvTranspose2d(ch*16, ch*8, kernel_size=2, stride=2)
        self.decoder1 = conv_block(ch*16, ch*8)
        self.upsample2 = nn.ConvTranspose2d(ch*8, ch*4, kernel_size=2, stride=2)
        self.decoder2 = conv_block(ch*8, ch*4)
        self.upsample3 = nn.ConvTranspose2d(ch*4, ch*2, kernel_size=2, stride=2)
        self.decoder3 = conv_block(ch*4, ch*2)
        self.upsample4 = nn.ConvTranspose2d(ch*2, ch, kernel_size=2, stride=2)
        self.decoder4 = conv_block(ch*2, ch)
        self.final = nn.Conv2d(ch, 1, kernel_size=1)

    def forward(self, x):
        c1 = self.encoder1(x)
        c2 = self.encoder2(self.pool(c1))
        c3 = self.encoder3(self.pool(c2))
        c4 = self.encoder4(self.pool(c3))
        c5 = self.bottle_neck(self.pool(c4))

        u6 = self.upsample1(c5)
        u6 = torch.cat([c4, u6], dim=1)
        c6 = self.decoder1(u6)
        u7 = self.upsample2(c6)
        u7 = torch.cat([c3, u7], dim=1)
        c7 = self.decoder2(u7)
        u8 = self.upsample3(c7)
        u8 = torch.cat([c2, u8], dim=1)
        c8 = self.decoder3(u8)
        u9 = self.upsample4(c8)
        u9 = torch.cat([c1, u9], dim=1)
        c9 = self.decoder4(u9)
        return self.final(c9)
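
# Shape sanity check (illustrative sketch, not executed here):
#   x = torch.randn(1, 3, 384, 384)
#   UNET(ch=32)(x).shape  -> torch.Size([1, 1, 384, 384])
# Input height and width must be divisible by 16 (four 2x2 poolings),
# otherwise the skip-connection concatenations in forward() will fail.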

# Global variables
model = None
device = torch.device('cpu')  # free-tier HF Spaces run on CPU
transform = A.Compose([
    A.Resize(384, 384),
    A.Normalize(mean=(0,0,0), std=(1,1,1), max_pixel_value=255),
    ToTensorV2()
])
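
# Note: with mean=(0,0,0), std=(1,1,1), max_pixel_value=255, Normalize is
# just a rescale to [0, 1] (pixel / 255); no ImageNet statistics are applied.
# This is assumed to match the preprocessing used at training time.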

def load_model():
    """Load model from your HF repository"""
    global model
    try:
        print("πŸ“₯ Downloading model from Hugging Face...")
        
        # Download your model from HF
        model_path = hf_hub_download(
            repo_id="ibrahim313/unet-adam-diceloss",
            filename="pytorch_model.bin"
        )
        
        # Load model
        model = UNET(ch=32)
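        # Note: newer PyTorch versions (>= 2.6) default torch.load to
        # weights_only=True; passing it explicitly is safer for state dicts.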
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
        
        print("βœ… Model loaded successfully!")
        return "βœ… Model loaded from ibrahim313/unet-adam-diceloss"
        
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        return f"❌ Error: {e}"

def predict_polyp(image, threshold=0.5):
    """Predict polyp in uploaded image"""
    if model is None:
        return None, "❌ Model not loaded! Please wait for model to load.", None
    
    if image is None:
        return None, "❌ Please upload an image first!", None
    
    try:
        # Convert image to numpy array
        if isinstance(image, Image.Image):
            original_image = np.array(image.convert('RGB'))
        else:
            original_image = np.array(image)
        
        # Preprocess image
        transformed = transform(image=original_image)
        input_tensor = transformed['image'].unsqueeze(0).float()
        
        # Make prediction
        with torch.no_grad():
            prediction = model(input_tensor)
            prediction = torch.sigmoid(prediction)
            prediction = (prediction > threshold).float()
        
        # Convert to numpy
        pred_mask = prediction.squeeze().cpu().numpy()
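        # pred_mask is now a 384x384 binary array: 1.0 where the sigmoid
        # probability exceeded `threshold`, 0.0 elsewhere.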
        
        # Calculate metrics
        polyp_pixels = np.sum(pred_mask)
        total_pixels = pred_mask.shape[0] * pred_mask.shape[1]
        polyp_percentage = (polyp_pixels / total_pixels) * 100
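        # Coverage is the fraction of the 384x384 = 147,456 mask pixels
        # flagged as polyp, e.g. ~1,475 positive pixels is ~1% coverage.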
        
        # Create visualization
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        
        # Original image
        axes[0].imshow(original_image)
        axes[0].set_title('πŸ–ΌοΈ Original Image', fontsize=14)
        axes[0].axis('off')
        
        # Predicted mask
        axes[1].imshow(pred_mask, cmap='gray')
        axes[1].set_title('🎭 Predicted Mask', fontsize=14)
        axes[1].axis('off')
        
        # Overlay
        axes[2].imshow(original_image)
        axes[2].imshow(pred_mask, cmap='Reds', alpha=0.6)
        axes[2].set_title('πŸ” Detection Overlay', fontsize=14)
        axes[2].axis('off')
        
        # Add main title with results
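        # Heuristic cutoff: <= 100 positive pixels (~0.07% of the frame) is
        # treated as noise rather than a detection.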
        if polyp_pixels > 100:
            main_title = f"🚨 POLYP DETECTED! Coverage: {polyp_percentage:.2f}%"
            title_color = 'red'
        else:
            main_title = f"βœ… No Polyp Detected - Coverage: {polyp_percentage:.2f}%"
            title_color = 'green'
        
        fig.suptitle(main_title, fontsize=16, fontweight='bold', color=title_color)
        plt.tight_layout()
        
        # Save plot to image
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
        buf.seek(0)
        result_image = Image.open(buf)
        plt.close()
        
        # Create detailed results text
        if polyp_pixels > 100:
            status_emoji = "🚨"
            status_text = "POLYP DETECTED"
            recommendation = "⚠️ **Recommendation:** Medical review recommended"
        else:
            status_emoji = "βœ…"
            status_text = "NO POLYP DETECTED"
            recommendation = "βœ… **Recommendation:** Continue routine monitoring"
        
        results_text = f"""
## {status_emoji} **{status_text}**

### πŸ“Š **Analysis Results:**
- **Polyp Coverage:** {polyp_percentage:.3f}%
- **Detected Pixels:** {int(polyp_pixels):,} / {total_pixels:,}
- **Detection Threshold:** {threshold}

### πŸ₯ **Clinical Assessment:**
{recommendation}

### πŸ”¬ **Technical Details:**
- **Model:** U-Net (32 channels)
- **Input Size:** 384Γ—384 pixels
- **Architecture:** Encoder-Decoder with skip connections
"""
        
        return result_image, results_text, pred_mask
        
    except Exception as e:
        error_msg = f"❌ **Error processing image:** {str(e)}"
        return None, error_msg, None

def load_example_image(image_num):
    """Load example images from your HF space"""
    try:
        if image_num == 1:
            # Image 1: cju0qoxqj9q6s0835b43399p4.jpg
            image_path = hf_hub_download(
                repo_id="ibrahim313/unet-adam-diceloss",
                filename="cju0qoxqj9q6s0835b43399p4.jpg",
                repo_type="space"
            )
        else:
            # Image 2: cju0roawvklrq0799vmjorwfv.jpg  
            image_path = hf_hub_download(
                repo_id="ibrahim313/unet-adam-diceloss", 
                filename="cju0roawvklrq0799vmjorwfv.jpg",
                repo_type="space"
            )
        
        # Load and return the image
        image = Image.open(image_path)
        return image
        
    except Exception as e:
        print(f"Error loading example image {image_num}: {e}")
        return None

# Load model when app starts
print("πŸš€ Starting Polyp Detection App...")
load_status = load_model()
print(load_status)

# Create Gradio Interface
with gr.Blocks(theme=gr.themes.Soft(), title="πŸ₯ Polyp Detection AI") as demo:
    
    # Header
    gr.HTML("""
    <div style="text-align: center; padding: 30px; background: linear-gradient(90deg, #667eea 0%, #764ba2 100%); color: white; border-radius: 10px; margin-bottom: 20px;">
        <h1 style="margin: 0; font-size: 2.5em;">πŸ₯ AI Polyp Detection System</h1>
        <p style="margin: 10px 0 0 0; font-size: 1.2em;">Advanced Medical Imaging with Deep Learning</p>
        <p style="margin: 5px 0 0 0; opacity: 0.9;">Upload colonoscopy images for intelligent polyp detection</p>
    </div>
    """)
    
    # Model info
    gr.HTML(f"""
    <div style="background: black; padding: 15px; border-radius: 8px; border-left: 4px solid #0ea5e9; margin-bottom: 20px;">
        <strong>πŸ”¬ Model:</strong> ibrahim313/unet-adam-diceloss<br>
        <strong>πŸ“ Architecture:</strong> U-Net with 32 base channels<br>
        <strong>🎯 Dataset:</strong> Trained on Kvasir-SEG (1000 polyp images)<br>
        <strong>πŸ“Έ Examples:</strong> 2 test colonoscopy images included<br>
        <strong>⚑ Status:</strong> {load_status}
    </div>
    """)
    
    # Main interface
    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML("<h3>πŸ“€ Upload Image</h3>")
            input_image = gr.Image(
                label="Drop colonoscopy image here",
                type="pil",
                height=300
            )
            
            threshold_slider = gr.Slider(
                minimum=0.1,
                maximum=0.9,
                value=0.5,
                step=0.1,
                label="🎯 Detection Sensitivity",
                info="Higher = more sensitive detection"
            )
            
            analyze_btn = gr.Button(
                "πŸ” Analyze for Polyps",
                variant="primary",
                size="lg"
            )
            
            gr.HTML("<br>")
            
            # Quick examples
            gr.HTML("<h4>πŸ“Έ Try Sample Images:</h4>")
            gr.HTML("<p style='font-size: 0.9em; color: #666; margin: 5px 0;'>Click to load colonoscopy test images</p>")
            with gr.Row():
                example1_btn = gr.Button("πŸ–ΌοΈ Test Image 1", size="sm", variant="secondary")
                example2_btn = gr.Button("πŸ–ΌοΈ Test Image 2", size="sm", variant="secondary")
        
        with gr.Column(scale=2):
            gr.HTML("<h3>πŸ“Š Detection Results</h3>")
            output_image = gr.Image(
                label="Analysis Results",
                height=400
            )
            
            results_text = gr.Markdown(
                value="Upload an image and click 'Analyze for Polyps' to see results.",
                label="Detailed Analysis"
            )
    
    # Event handlers
    analyze_btn.click(
        fn=predict_polyp,
        inputs=[input_image, threshold_slider],
        outputs=[output_image, results_text, gr.State()]  # third output (raw mask) is currently unused
    )
    
    # Example button handlers
    example1_btn.click(
        fn=lambda: load_example_image(1),
        inputs=[],
        outputs=[input_image]
    )
    
    example2_btn.click(
        fn=lambda: load_example_image(2), 
        inputs=[],
        outputs=[input_image]
    )
    
    # Footer
    gr.HTML("""
    <div style="text-align: center; padding: 20px; margin-top: 40px; border-top: 2px solid #e5e7eb; background: #f9fafb;">
        <p style="margin: 0; color: #dc2626; font-weight: bold;">
            ⚠️ MEDICAL DISCLAIMER
        </p>
        <p style="margin: 5px 0; color: #4b5563;">
            This AI system is for research and educational purposes only.<br>
            Always consult qualified medical professionals for clinical decisions.
        </p>
        <p style="margin: 10px 0 0 0; color: #6b7280; font-size: 0.9em;">
            πŸ”¬ Powered by PyTorch | πŸ€— Hosted on Hugging Face | πŸ“Š Gradio Interface
        </p>
    </div>
    """)

# Launch the app
if __name__ == "__main__":
    demo.launch()
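
# Local usage (assumed dependencies; this file does not pin versions):
#   pip install gradio torch albumentations huggingface_hub matplotlib pillow numpy
#   python app.py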