import gradio as gr
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2
from huggingface_hub import hf_hub_download
import io
import requests


def _double_conv(in_channels, out_channels, dropout_rate):
    """Build two (3x3 conv -> BatchNorm -> ReLU -> Dropout2d) stages.

    ``padding=1`` keeps the spatial size unchanged. The layer order inside the
    returned ``nn.Sequential`` determines its state-dict keys (``0``..``7``),
    so it must stay exactly as is for checkpoints saved with this layout to
    load.
    """
    layers = []
    # First stage maps in_channels -> out_channels, second out -> out.
    for stage_in in (in_channels, out_channels):
        layers += [
            nn.Conv2d(stage_in, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=dropout_rate),
        ]
    return nn.Sequential(*layers)


# U-Net model definition
class UNET(nn.Module):
    """Four-level U-Net that outputs a single-channel logit map.

    Args:
        dropout_rate: drop probability for every ``nn.Dropout2d`` layer.
        ch: feature width of the first encoder level; each deeper level
            doubles it (ch, 2ch, 4ch, 8ch, bottleneck 16ch).

    The first encoder expects 3 input channels. Because the input is pooled
    four times by 2x and upsampled back, height and width must be divisible
    by 16 for the skip-connection concatenations to line up.
    """

    def __init__(self, dropout_rate=0.1, ch=32):
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder path: encoder1..encoder4, widths ch * 2**level.
        in_ch = 3
        for level in range(4):
            out_ch = ch * 2 ** level
            setattr(self, f"encoder{level + 1}",
                    _double_conv(in_ch, out_ch, dropout_rate))
            in_ch = out_ch

        self.bottle_neck = _double_conv(in_ch, in_ch * 2, dropout_rate)
        in_ch *= 2

        # Decoder path: each upsample halves the width; the decoder then sees
        # the upsampled features concatenated with the matching skip, i.e.
        # twice the output width.
        for step in range(4):
            out_ch = in_ch // 2
            setattr(self, f"upsample{step + 1}",
                    nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2))
            setattr(self, f"decoder{step + 1}",
                    _double_conv(in_ch, out_ch, dropout_rate))
            in_ch = out_ch

        self.final = nn.Conv2d(in_ch, 1, kernel_size=1)

    def forward(self, x):
        """Return raw (pre-sigmoid) logits with one channel."""
        encoders = (self.encoder1, self.encoder2, self.encoder3, self.encoder4)
        skips = []
        feat = x
        for idx, encoder in enumerate(encoders):
            # The first encoder sees the raw input; later ones a pooled map.
            feat = encoder(feat if idx == 0 else self.pool(feat))
            skips.append(feat)

        feat = self.bottle_neck(self.pool(feat))

        stages = (
            (self.upsample1, self.decoder1),
            (self.upsample2, self.decoder2),
            (self.upsample3, self.decoder3),
            (self.upsample4, self.decoder4),
        )
        # Deepest skip pairs with the first decoder stage.
        for (upsample, decoder), skip in zip(stages, reversed(skips)):
            feat = decoder(torch.cat([skip, upsample(feat)], dim=1))

        return self.final(feat)
# Global variables
model = None
device = torch.device('cpu')  # HF Spaces use CPU

# Minimum number of foreground pixels (counted at model resolution) for an
# image to be reported as containing a polyp.
MIN_POLYP_PIXELS = 100
# Square spatial size the model is run at.
MODEL_INPUT_SIZE = 384

# Resize to the model input size and scale uint8 pixels to [0, 1]
# (mean 0 / std 1 with max_pixel_value=255 divides by 255).
transform = A.Compose([
    A.Resize(MODEL_INPUT_SIZE, MODEL_INPUT_SIZE),
    A.Normalize(mean=(0, 0, 0), std=(1, 1, 1), max_pixel_value=255),
    ToTensorV2(),
])


def load_model():
    """Download the UNET weights from the Hugging Face Hub and load them.

    On success the module-level ``model`` is replaced by the loaded network
    in eval mode. On failure ``model`` is left untouched, so a half-built,
    randomly initialised network is never exposed to ``predict_polyp``.

    Returns:
        A human-readable status string for the UI; failures start with
        "❌ Error:".
    """
    global model
    try:
        print("📥 Downloading model from Hugging Face...")
        model_path = hf_hub_download(
            repo_id="ibrahim313/unet-adam-diceloss",
            filename="pytorch_model.bin",
        )

        net = UNET(ch=32)
        # weights_only=True restricts unpickling to tensors and plain
        # containers, so the downloaded checkpoint cannot run arbitrary code.
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
        net.load_state_dict(state_dict)
        net.to(device)
        net.eval()

        # Publish only after everything above succeeded.
        model = net
        print("✅ Model loaded successfully!")
        return "✅ Model loaded from ibrahim313/unet-adam-diceloss"
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        return f"❌ Error: {e}"


def _to_rgb_array(image):
    """Return *image* (PIL image or array-like) as an (H, W, 3) numpy array."""
    if isinstance(image, Image.Image):
        return np.array(image.convert('RGB'))
    arr = np.array(image)
    if arr.ndim == 2:
        # Grayscale: replicate into three channels for the RGB model.
        arr = np.stack([arr] * 3, axis=-1)
    elif arr.ndim == 3 and arr.shape[2] == 1:
        arr = np.repeat(arr, 3, axis=2)
    elif arr.ndim == 3 and arr.shape[2] == 4:
        # Drop the alpha channel.
        arr = arr[..., :3]
    return arr


def _render_figure(original_image, pred_mask, polyp_percentage, detected):
    """Draw original image, predicted mask and overlay side by side.

    Args:
        original_image: (H, W, 3) array at the uploaded resolution.
        pred_mask: binary 2-D mask at model resolution.
        polyp_percentage: foreground coverage, in percent, for the title.
        detected: whether the image is reported as containing a polyp.

    Returns:
        The rendered figure as a PNG-backed ``PIL.Image``.
    """
    height, width = original_image.shape[:2]
    # pred_mask is at model resolution; upscale it with nearest-neighbour
    # (keeps it binary) so the overlay lines up with the image underneath.
    mask_img = Image.fromarray((pred_mask > 0).astype(np.uint8) * 255)
    full_res_mask = np.array(mask_img.resize((width, height), Image.NEAREST)) > 0
    # Mask out background so only predicted polyp pixels get tinted.
    overlay = np.ma.masked_where(~full_res_mask, full_res_mask.astype(float))

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        # Original image
        axes[0].imshow(original_image)
        axes[0].set_title('🖼️ Original Image', fontsize=14)
        axes[0].axis('off')

        # Predicted mask. vmin/vmax are pinned so an all-0 or all-1 mask is
        # not auto-rescaled into a misleading image.
        axes[1].imshow(pred_mask, cmap='gray', vmin=0, vmax=1)
        axes[1].set_title('🎭 Predicted Mask', fontsize=14)
        axes[1].axis('off')

        # Overlay
        axes[2].imshow(original_image)
        axes[2].imshow(overlay, cmap='Reds', alpha=0.6, vmin=0, vmax=1)
        axes[2].set_title('🔍 Detection Overlay', fontsize=14)
        axes[2].axis('off')

        if detected:
            main_title = f"🚨 POLYP DETECTED! Coverage: {polyp_percentage:.2f}%"
            title_color = 'red'
        else:
            main_title = f"✅ No Polyp Detected - Coverage: {polyp_percentage:.2f}%"
            title_color = 'green'
        fig.suptitle(main_title, fontsize=16, fontweight='bold', color=title_color)
        fig.tight_layout()

        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
    finally:
        # Close this specific figure (not pyplot's "current" one, which a
        # concurrent request may own) and do so even if rendering failed.
        plt.close(fig)

    buf.seek(0)
    result_image = Image.open(buf)
    result_image.load()
    return result_image


def _format_report(detected, polyp_percentage, polyp_pixels, total_pixels, threshold):
    """Build the Markdown summary shown next to the result image."""
    if detected:
        status_emoji = "🚨"
        status_text = "POLYP DETECTED"
        recommendation = "⚠️ **Recommendation:** Medical review recommended"
    else:
        status_emoji = "✅"
        status_text = "NO POLYP DETECTED"
        recommendation = "✅ **Recommendation:** Continue routine monitoring"

    return "\n".join([
        "",
        f"## {status_emoji} **{status_text}**",
        "",
        "### 📊 **Analysis Results:**",
        f"- **Polyp Coverage:** {polyp_percentage:.3f}%",
        f"- **Detected Pixels:** {int(polyp_pixels):,} / {total_pixels:,}",
        # :g avoids slider float artifacts such as 0.30000000000000004.
        f"- **Detection Threshold:** {float(threshold):g}",
        "",
        "### 🏥 **Clinical Assessment:**",
        recommendation,
        "",
        "### 🔬 **Technical Details:**",
        "- **Model:** U-Net (32 channels)",
        f"- **Input Size:** {MODEL_INPUT_SIZE}×{MODEL_INPUT_SIZE} pixels",
        "- **Architecture:** Encoder-Decoder with skip connections",
        "",
    ])


def predict_polyp(image, threshold=0.5):
    """Segment polyps in an uploaded image.

    Args:
        image: PIL image (as delivered by ``gr.Image(type="pil")``) or an
            array-like image.
        threshold: cutoff applied to the sigmoid output; a pixel is
            foreground when its probability is strictly greater. Lower values
            flag more pixels.

    Returns:
        ``(result_image, markdown_text, pred_mask)``. On failure the image and
        mask are ``None`` and the text describes the error.
    """
    if model is None:
        return None, "❌ Model not loaded! Please wait for model to load.", None
    if image is None:
        return None, "❌ Please upload an image first!", None

    try:
        original_image = _to_rgb_array(image)

        # Preprocess image
        transformed = transform(image=original_image)
        input_tensor = transformed['image'].unsqueeze(0).float().to(device)

        # Make prediction
        with torch.no_grad():
            probabilities = torch.sigmoid(model(input_tensor))
            prediction = (probabilities > threshold).float()

        pred_mask = prediction.squeeze().cpu().numpy()

        # Metrics are computed at model resolution.
        polyp_pixels = np.sum(pred_mask)
        total_pixels = pred_mask.size
        polyp_percentage = (polyp_pixels / total_pixels) * 100
        detected = polyp_pixels > MIN_POLYP_PIXELS

        result_image = _render_figure(original_image, pred_mask, polyp_percentage, detected)
        report = _format_report(detected, polyp_percentage, polyp_pixels,
                                total_pixels, threshold)
        return result_image, report, pred_mask

    except Exception as e:
        # UI boundary: log the full traceback, show a short message.
        import traceback
        traceback.print_exc()
        error_msg = f"❌ **Error processing image:** {str(e)}"
        return None, error_msg, None


# Example images stored in the Space repository, keyed by button number.
EXAMPLE_IMAGES = {
    1: "cju0qoxqj9q6s0835b43399p4.jpg",
    2: "cju0roawvklrq0799vmjorwfv.jpg",
}


def load_example_image(image_num):
    """Download example image *image_num* from the Space and return it.

    Any number other than 1 falls back to example 2. Returns ``None`` (and
    logs the error) if the download or decoding fails.
    """
    try:
        filename = EXAMPLE_IMAGES.get(image_num, EXAMPLE_IMAGES[2])
        image_path = hf_hub_download(
            repo_id="ibrahim313/unet-adam-diceloss",
            filename=filename,
            repo_type="space",
        )
        # Decode fully and close the file instead of returning a lazy handle.
        with Image.open(image_path) as img:
            return img.copy()
    except Exception as e:
        print(f"Error loading example image {image_num}: {e}")
        return None


# Load model when app starts
print("🚀 Starting Polyp Detection App...")
load_status = load_model()
print(load_status)

# Create Gradio Interface
with gr.Blocks(theme=gr.themes.Soft(), title="🏥 Polyp Detection AI") as demo:
    # Header
    gr.HTML("""
<div style="text-align: center;">
  <h1>🏥 AI Polyp Detection System</h1>
  <h3>Advanced Medical Imaging with Deep Learning</h3>
  <p>Upload colonoscopy images for intelligent polyp detection</p>
</div>
""")

    # Model info
    gr.HTML(f"""
<div>
  <p>🔬 Model: ibrahim313/unet-adam-diceloss</p>
  <p>📏 Architecture: U-Net with 32 base channels</p>
  <p>🎯 Dataset: Trained on Kvasir-SEG (1000 polyp images)</p>
  <p>📸 Examples: 2 test colonoscopy images included</p>
  <p>⚡ Status: {load_status}</p>
</div>
""")

    # Main interface
    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML("<h3>📤 Upload Image</h3>")
            input_image = gr.Image(
                label="Drop colonoscopy image here",
                type="pil",
                height=300,
            )
            # A pixel counts as polyp when its probability exceeds this value,
            # so LOWER values make detection more sensitive.
            threshold_slider = gr.Slider(
                minimum=0.1,
                maximum=0.9,
                value=0.5,
                step=0.1,
                label="🎯 Detection Threshold",
                info="Lower = more sensitive detection",
            )
            analyze_btn = gr.Button(
                "🔍 Analyze for Polyps",
                variant="primary",
                size="lg",
            )
            gr.HTML("<br>")

            # Quick examples
            gr.HTML("<h4>📸 Try Sample Images:</h4>")
            gr.HTML("<p>Click to load colonoscopy test images</p>")
            with gr.Row():
                example1_btn = gr.Button("🖼️ Test Image 1", size="sm", variant="secondary")
                example2_btn = gr.Button("🖼️ Test Image 2", size="sm", variant="secondary")

        with gr.Column(scale=2):
            gr.HTML("<h3>📊 Detection Results</h3>")
            output_image = gr.Image(
                label="Analysis Results",
                height=400,
            )
            results_text = gr.Markdown(
                value="Upload an image and click 'Analyze for Polyps' to see results.",
                label="Detailed Analysis",
            )

    # Per-session holder for the raw mask returned by predict_polyp.
    mask_state = gr.State()

    # Event handlers
    analyze_btn.click(
        fn=predict_polyp,
        inputs=[input_image, threshold_slider],
        outputs=[output_image, results_text, mask_state],
    )

    # Example button handlers
    example1_btn.click(
        fn=lambda: load_example_image(1),
        inputs=[],
        outputs=[input_image],
    )
    example2_btn.click(
        fn=lambda: load_example_image(2),
        inputs=[],
        outputs=[input_image],
    )

    # Footer
    gr.HTML("""
<div style="text-align: center;">
  <h4>⚠️ MEDICAL DISCLAIMER</h4>
  <p>This AI system is for research and educational purposes only.<br>
  Always consult qualified medical professionals for clinical decisions.</p>
  <p>🔬 Powered by PyTorch | 🤗 Hosted on Hugging Face | 📊 Gradio Interface</p>
</div>
""")

# Launch the app
if __name__ == "__main__":
    demo.launch()