# Wound-Analysis-LE/temp_files/segmentation_app.py
import gradio as gr
import cv2
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import load_model
# Import custom modules
from models.deeplab import Deeplabv3, relu6, DepthwiseConv2D, BilinearUpsampling
from utils.learning.metrics import dice_coef, precision, recall
from utils.io.data import normalize


class WoundSegmentationApp:
    def __init__(self):
        self.input_dim_x = 224
        self.input_dim_y = 224
        self.model = None
        self.load_model()

    def load_model(self):
        """Load the trained wound segmentation model"""
        try:
            # Load the model with custom objects
            weight_file_name = '2025-08-07_12-30-43.hdf5'  # Use the most recent model
            model_path = f'./training_history/{weight_file_name}'
            self.model = load_model(model_path,
                                    custom_objects={
                                        'recall': recall,
                                        'precision': precision,
                                        'dice_coef': dice_coef,
                                        'relu6': relu6,
                                        'DepthwiseConv2D': DepthwiseConv2D,
                                        'BilinearUpsampling': BilinearUpsampling
                                    })
            print(f"Model loaded successfully from {model_path}")
        except Exception as e:
            print(f"Error loading model: {e}")
            # Fall back to the older model if the newer one fails
            try:
                weight_file_name = '2019-12-19 01%3A53%3A15.480800.hdf5'
                model_path = f'./training_history/{weight_file_name}'
                self.model = load_model(model_path,
                                        custom_objects={
                                            'recall': recall,
                                            'precision': precision,
                                            'dice_coef': dice_coef,
                                            'relu6': relu6,
                                            'DepthwiseConv2D': DepthwiseConv2D,
                                            'BilinearUpsampling': BilinearUpsampling
                                        })
                print(f"Model loaded successfully from {model_path}")
            except Exception as e2:
                print(f"Error loading fallback model: {e2}")
                self.model = None
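
    # A possible refinement (sketch only): instead of hard-coding the checkpoint file
    # name, pick the most recently modified .hdf5 in ./training_history, e.g.
    #     import glob, os
    #     candidates = sorted(glob.glob('./training_history/*.hdf5'), key=os.path.getmtime)
    #     weight_file_name = os.path.basename(candidates[-1]) if candidates else None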

    def preprocess_image(self, image):
        """Preprocess the uploaded image for model input"""
        if image is None:
            return None
        # Gradio supplies images as RGB numpy arrays, so no color conversion is needed
        # Resize to model input size
        image = cv2.resize(image, (self.input_dim_x, self.input_dim_y))
        # Normalize the image to [0, 1]
        image = image.astype(np.float32) / 255.0
        # Add batch dimension
        image = np.expand_dims(image, axis=0)
        return image

    def postprocess_prediction(self, prediction):
        """Postprocess the model prediction"""
        # Remove batch dimension
        prediction = prediction[0]
        # Apply threshold to get binary mask
        threshold = 0.5
        binary_mask = (prediction > threshold).astype(np.uint8) * 255
        # Convert to 3-channel image for visualization
        mask_rgb = cv2.cvtColor(binary_mask, cv2.COLOR_GRAY2RGB)
        return mask_rgb

    def segment_wound(self, input_image):
        """Main function to segment wound from uploaded image"""
        if self.model is None:
            return None, "Error: Model not loaded. Please check the model files."
        if input_image is None:
            return None, "Please upload an image."
        try:
            # Preprocess the image
            processed_image = self.preprocess_image(input_image)
            if processed_image is None:
                return None, "Error processing image."
            # Make prediction
            prediction = self.model.predict(processed_image, verbose=0)
            # Postprocess the prediction
            segmented_mask = self.postprocess_prediction(prediction)
            # Create overlay image (original image with segmentation overlay);
            # the image stays in RGB, which is what Gradio expects for display
            original_resized = cv2.resize(input_image, (self.input_dim_x, self.input_dim_y))
            # Create overlay with red segmentation
            overlay = original_resized.copy()
            mask_red = np.zeros_like(original_resized)
            mask_red[:, :, 0] = segmented_mask[:, :, 0]  # Red channel (RGB)
            # Blend overlay with original image
            alpha = 0.6
            overlay = cv2.addWeighted(overlay, 1 - alpha, mask_red, alpha, 0)
            return segmented_mask, overlay
        except Exception as e:
            return None, f"Error during segmentation: {str(e)}"


def create_gradio_interface():
    """Create and return the Gradio interface"""
    # Initialize the app
    app = WoundSegmentationApp()

    # Define the interface
    with gr.Blocks(title="Wound Segmentation Tool", theme=gr.themes.Soft()) as interface:
        gr.Markdown(
            """
            # 🩹 Wound Segmentation Tool

            Upload an image of a wound to get an automated segmentation mask.
            The model will identify and highlight the wound area in the image.

            **Instructions:**
            1. Upload an image of a wound
            2. Click "Segment Wound" to process the image
            3. View the segmentation mask and overlay results
            """
        )

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    label="Upload Wound Image",
                    type="numpy",
                    height=400
                )
                segment_btn = gr.Button(
                    "🔍 Segment Wound",
                    variant="primary",
                    size="lg"
                )
            with gr.Column():
                mask_output = gr.Image(
                    label="Segmentation Mask",
                    height=400
                )
                overlay_output = gr.Image(
                    label="Overlay Result",
                    height=400
                )

        # Status message
        status_msg = gr.Textbox(
            label="Status",
            interactive=False,
            placeholder="Ready to process images..."
        )

        # Example images
        gr.Markdown("### 📸 Example Images")
        gr.Markdown("You can test the tool with wound images from the dataset.")
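
        # One way to make the example images clickable is gr.Examples (sketch only;
        # the paths below are placeholders and would need to point at real dataset files):
        # gr.Examples(
        #     examples=["./data/example_wound_1.png", "./data/example_wound_2.png"],
        #     inputs=input_image,
        # )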

        # Connect the button to the segmentation function
        def process_image(image):
            mask, overlay = app.segment_wound(image)
            if mask is None:
                return None, None, overlay  # overlay contains error message
            return mask, overlay, "Segmentation completed successfully!"

        segment_btn.click(
            fn=process_image,
            inputs=[input_image],
            outputs=[mask_output, overlay_output, status_msg]
        )

        # Auto-process when image is uploaded
        input_image.change(
            fn=process_image,
            inputs=[input_image],
            outputs=[mask_output, overlay_output, status_msg]
        )

    return interface


if __name__ == "__main__":
    # Create and launch the interface
    interface = create_gradio_interface()
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True
    )
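
# A minimal sketch of running the segmenter without the web UI (the image path is
# only an illustration and would need to exist locally):
#     app = WoundSegmentationApp()
#     img = cv2.cvtColor(cv2.imread("./data/sample_wound.png"), cv2.COLOR_BGR2RGB)
#     mask, overlay = app.segment_wound(img)
#     if mask is not None:
#         cv2.imwrite("mask.png", mask)
#         cv2.imwrite("overlay.png", cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))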