# Spaces: Running
from typing import Optional, Tuple, Union

import gradio as gr
import cv2
import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras.models import load_model
from keras.utils.generic_utils import CustomObjectScope

# Import custom modules
from models.deeplab import Deeplabv3, relu6, DepthwiseConv2D, BilinearUpsampling
from utils.learning.metrics import dice_coef, precision, recall
from utils.io.data import normalize
class WoundSegmentationApp:
    """Load a trained wound-segmentation model and run it on single images.

    Attributes:
        input_dim_x: Width images are resized to before prediction (224).
        input_dim_y: Height images are resized to before prediction (224).
        model: The loaded Keras model, or ``None`` when no weight file
            could be loaded.
    """

    # Weight files tried in order, each paired with the label used in the
    # message printed when that file fails to load.
    _WEIGHT_CANDIDATES = (
        ('2025-08-07_12-30-43.hdf5', 'model'),  # most recent model
        ('2019-12-19 01%3A53%3A15.480800.hdf5', 'fallback model'),  # older model
    )

    def __init__(self):
        """Set the model input size and attempt to load the model."""
        self.input_dim_x = 224
        self.input_dim_y = 224
        self.model = None
        self.load_model()

    def load_model(self):
        """Load the trained wound segmentation model.

        Each file in ``_WEIGHT_CANDIDATES`` is tried in order and the first
        one that loads is kept. Failures are printed rather than raised; when
        every candidate fails, ``self.model`` is left as ``None``.
        """
        # Custom layers/metrics that must be supplied by name when loading.
        custom_objects = {
            'recall': recall,
            'precision': precision,
            'dice_coef': dice_coef,
            'relu6': relu6,
            'DepthwiseConv2D': DepthwiseConv2D,
            'BilinearUpsampling': BilinearUpsampling,
        }

        self.model = None
        for weight_file_name, label in self._WEIGHT_CANDIDATES:
            model_path = f'./training_history/{weight_file_name}'
            try:
                # The bare name resolves to the module-level Keras loader,
                # not to this method.
                self.model = load_model(model_path, custom_objects=custom_objects)
            except Exception as e:
                # Best effort: report the failure and move on to the next file.
                print(f"Error loading {label}: {e}")
                continue
            print(f"Model loaded successfully from {model_path}")
            return

    def preprocess_image(self, image: Optional[np.ndarray]) -> Optional[np.ndarray]:
        """Preprocess the uploaded image for model input.

        Args:
            image: Image array as uploaded, or ``None``.

        Returns:
            A float32 array scaled to [0, 1], resized to
            ``(input_dim_y, input_dim_x)`` with a leading batch axis, or
            ``None`` when ``image`` is ``None``.
        """
        if image is None:
            return None

        if image.ndim == 3 and image.shape[2] == 3:
            # Swap the first and third channels.
            # NOTE(review): Gradio delivers RGB, so this yields BGR --
            # presumably the channel order the model was trained on; confirm
            # against the training data loader before changing it.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Resize to the model input size, scale to [0, 1], add the batch axis.
        image = cv2.resize(image, (self.input_dim_x, self.input_dim_y))
        image = image.astype(np.float32) / 255.0
        return np.expand_dims(image, axis=0)

    def postprocess_prediction(self, prediction: np.ndarray) -> np.ndarray:
        """Turn a batched model prediction into a 3-channel binary mask.

        Args:
            prediction: Model output; only the first batch element is used.

        Returns:
            A uint8 image whose pixels are 255 where the prediction exceeds
            0.5 and 0 elsewhere, replicated across three channels.
        """
        # Remove batch dimension
        prediction = prediction[0]

        # Apply threshold to get binary mask
        threshold = 0.5
        binary_mask = (prediction > threshold).astype(np.uint8) * 255

        # Convert to 3-channel image for visualization
        return cv2.cvtColor(binary_mask, cv2.COLOR_GRAY2RGB)

    def segment_wound(
        self, input_image: Optional[np.ndarray]
    ) -> Tuple[Optional[np.ndarray], Union[np.ndarray, str]]:
        """Segment the wound in an uploaded image.

        Args:
            input_image: Image array as uploaded (RGB when it comes from the
                Gradio image input), or ``None``.

        Returns:
            ``(mask, overlay)`` on success, both RGB uint8 images at the model
            input size. On failure ``(None, message)``, where ``message`` is
            the error text to show to the user.
        """
        if self.model is None:
            return None, "Error: Model not loaded. Please check the model files."
        if input_image is None:
            return None, "Please upload an image."

        try:
            # Preprocess the image
            processed_image = self.preprocess_image(input_image)
            if processed_image is None:
                return None, "Error processing image."

            # Predict, then threshold into a displayable mask
            prediction = self.model.predict(processed_image, verbose=0)
            segmented_mask = self.postprocess_prediction(prediction)

            # Keep the overlay in RGB: the result is handed to a Gradio image
            # output, which displays arrays as RGB. Converting to BGR here
            # showed the photo with red/blue swapped and the mask in blue.
            original_resized = cv2.resize(
                input_image, (self.input_dim_x, self.input_dim_y)
            )

            # Paint the mask into the red channel (index 0 in RGB).
            mask_red = np.zeros_like(original_resized)
            mask_red[:, :, 0] = segmented_mask[:, :, 0]

            # Blend overlay with original image
            alpha = 0.6
            overlay = cv2.addWeighted(original_resized, 1 - alpha, mask_red, alpha, 0)
            return segmented_mask, overlay
        except Exception as e:
            # UI boundary: surface the failure as a status message, not a crash.
            return None, f"Error during segmentation: {str(e)}"
def create_gradio_interface():
    """Build the wound-segmentation Gradio Blocks UI and return it unlaunched."""
    # A single app instance (one model load) is shared by every callback.
    app = WoundSegmentationApp()

    def process_image(image):
        """Map ``app.segment_wound`` onto the (mask, overlay, status) outputs."""
        mask, overlay = app.segment_wound(image)
        if mask is not None:
            return mask, overlay, "Segmentation completed successfully!"
        # On failure the second element carries the error text, not an image.
        return None, None, overlay

    with gr.Blocks(title="Wound Segmentation Tool", theme=gr.themes.Soft()) as interface:
        # Page header and usage instructions
        gr.Markdown(
            """
            # 🩹 Wound Segmentation Tool
            Upload an image of a wound to get an automated segmentation mask.
            The model will identify and highlight the wound area in the image.
            **Instructions:**
            1. Upload an image of a wound
            2. Click "Segment Wound" to process the image
            3. View the segmentation mask and overlay results
            """
        )

        with gr.Row():
            # Left column: image upload and the manual trigger button
            with gr.Column():
                input_image = gr.Image(label="Upload Wound Image", type="numpy", height=400)
                segment_btn = gr.Button("🔍 Segment Wound", variant="primary", size="lg")

            # Right column: the two result images
            with gr.Column():
                mask_output = gr.Image(label="Segmentation Mask", height=400)
                overlay_output = gr.Image(label="Overlay Result", height=400)

        # Read-only status line under the images
        status_msg = gr.Textbox(
            label="Status",
            interactive=False,
            placeholder="Ready to process images...",
        )

        # Pointer to example images
        gr.Markdown("### 📸 Example Images")
        gr.Markdown("You can test the tool with wound images from the dataset.")

        # Wire the same handler to the button click first, then to any change
        # of the input image (auto-processing on upload).
        result_components = [mask_output, overlay_output, status_msg]
        for register_event in (segment_btn.click, input_image.change):
            register_event(
                fn=process_image,
                inputs=[input_image],
                outputs=result_components,
            )

    return interface
if __name__ == "__main__":
    # Build the UI, then serve it on every network interface at port 7860,
    # with a share link requested and errors shown in the browser.
    interface = create_gradio_interface()
    interface.launch(share=True, show_error=True, server_port=7860, server_name="0.0.0.0")