import gradio as gr
import cv2
import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras.models import load_model
from keras.utils.generic_utils import CustomObjectScope

# Import custom modules
# NOTE(review): CustomObjectScope, Deeplabv3 and normalize are not used in this file.
from models.deeplab import Deeplabv3, relu6, DepthwiseConv2D, BilinearUpsampling
from utils.learning.metrics import dice_coef, precision, recall
from utils.io.data import normalize


class WoundSegmentationApp:
    """Load a trained wound-segmentation model and run it on single images.

    Images are received as numpy arrays (Gradio ``type="numpy"``, i.e. RGB)
    and fed to the model as a ``(1, input_dim_y, input_dim_x, 3)`` float32
    batch scaled to ``[0, 1]``.
    """

    # Directory that holds the trained ``.hdf5`` weight files.
    MODEL_DIR = './training_history'
    # Weight files tried in order; the first one that loads is used.
    # Each entry pairs a file name with the prefix printed if loading it fails.
    MODEL_CANDIDATES = (
        ('2025-08-07_12-30-43.hdf5', "Error loading model"),  # most recent model
        ('2019-12-19 01%3A53%3A15.480800.hdf5', "Error loading fallback model"),
    )

    def __init__(self):
        self.input_dim_x = 224
        self.input_dim_y = 224
        self.model = None
        self.load_model()

    def load_model(self):
        """Load trained weights into ``self.model``.

        Tries each entry of ``MODEL_CANDIDATES`` in order and stops at the
        first one that loads. A message is printed for every failed attempt
        and for the success; ``self.model`` stays ``None`` if all fail.
        """
        custom_objects = {
            'recall': recall,
            'precision': precision,
            'dice_coef': dice_coef,
            'relu6': relu6,
            'DepthwiseConv2D': DepthwiseConv2D,
            'BilinearUpsampling': BilinearUpsampling,
        }
        self.model = None
        for weight_file_name, error_prefix in self.MODEL_CANDIDATES:
            model_path = f'{self.MODEL_DIR}/{weight_file_name}'
            try:
                # Inside this method, ``load_model`` resolves to the
                # module-level keras function, not to this method.
                self.model = load_model(model_path, custom_objects=custom_objects)
            except Exception as e:
                print(f"{error_prefix}: {e}")
                continue
            print(f"Model loaded successfully from {model_path}")
            return

    @staticmethod
    def _to_rgb(image):
        """Return *image* as a 3-channel array, converting grayscale or RGBA input."""
        if image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):
            return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if image.ndim == 3 and image.shape[2] == 4:
            return cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
        return image

    def preprocess_image(self, image):
        """Turn an uploaded RGB image into a model-ready batch.

        Returns a ``(1, input_dim_y, input_dim_x, 3)`` float32 array in
        ``[0, 1]``, or ``None`` when *image* is ``None``.
        """
        if image is None:
            return None

        image = self._to_rgb(image)
        # Gradio supplies RGB; swap to BGR channel order before inference.
        # NOTE(review): presumably the model was trained on BGR images read
        # with cv2.imread -- confirm against the training pipeline.
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

        # cv2.resize takes (width, height).
        image = cv2.resize(image, (self.input_dim_x, self.input_dim_y))

        # Normalize the image to [0, 1].
        image = image.astype(np.float32) / 255.0

        # Add batch dimension.
        return np.expand_dims(image, axis=0)

    def postprocess_prediction(self, prediction, threshold=0.5):
        """Convert a raw model prediction into a 3-channel binary mask image.

        Args:
            prediction: model output for a batch of one image.
            threshold: probability above which a pixel counts as wound.

        Returns:
            A uint8 RGB image whose pixels are 255 (wound) or 0 (background).
        """
        # Remove batch dimension.
        mask = prediction[0]
        # Drop a trailing singleton channel axis so cvtColor sees a gray image.
        if mask.ndim == 3 and mask.shape[-1] == 1:
            mask = mask[..., 0]

        binary_mask = (mask > threshold).astype(np.uint8) * 255
        # Convert to a 3-channel image for visualization.
        return cv2.cvtColor(binary_mask, cv2.COLOR_GRAY2RGB)

    def _make_overlay(self, image, mask_rgb, alpha=0.6):
        """Blend a red highlight into *image* wherever *mask_rgb* is set.

        The result is an RGB array at the model's input resolution. Pixels
        outside the mask keep their original colour.
        """
        base = cv2.resize(self._to_rgb(image), (self.input_dim_x, self.input_dim_y))
        red = np.zeros_like(base)
        red[..., 0] = 255  # channel 0 is red in RGB
        blended = cv2.addWeighted(base, 1 - alpha, red, alpha, 0)

        wound = mask_rgb[..., 0] > 0
        overlay = base.copy()
        overlay[wound] = blended[wound]
        return overlay

    def segment_wound(self, input_image):
        """Segment the wound in *input_image*.

        Returns:
            ``(mask, overlay)`` on success, both RGB images. On failure,
            ``(None, message)`` where *message* describes the problem.
        """
        if self.model is None:
            return None, "Error: Model not loaded. Please check the model files."

        if input_image is None:
            return None, "Please upload an image."

        try:
            processed_image = self.preprocess_image(input_image)
            if processed_image is None:
                return None, "Error processing image."

            prediction = self.model.predict(processed_image, verbose=0)
            segmented_mask = self.postprocess_prediction(prediction)
            # Kept in RGB because Gradio displays numpy images as RGB.
            overlay = self._make_overlay(input_image, segmented_mask)
            return segmented_mask, overlay

        except Exception as e:
            return None, f"Error during segmentation: {str(e)}"


def create_gradio_interface():
    """Create and return the Gradio interface"""
    app = WoundSegmentationApp()

    with gr.Blocks(title="Wound Segmentation Tool", theme=gr.themes.Soft()) as interface:
        gr.Markdown(
            """
            # 🩹 Wound Segmentation Tool

            Upload an image of a wound to get an automated segmentation mask.
            The model will identify and highlight the wound area in the image.

            **Instructions:**
            1. Upload an image of a wound
            2. Click "Segment Wound" to process the image
            3. View the segmentation mask and overlay results
            """
        )

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    label="Upload Wound Image",
                    type="numpy",
                    height=400
                )

                segment_btn = gr.Button(
                    "🔍 Segment Wound",
                    variant="primary",
                    size="lg"
                )

            with gr.Column():
                mask_output = gr.Image(
                    label="Segmentation Mask",
                    height=400
                )

                overlay_output = gr.Image(
                    label="Overlay Result",
                    height=400
                )

        status_msg = gr.Textbox(
            label="Status",
            interactive=False,
            placeholder="Ready to process images..."
        )

        gr.Markdown("### 📸 Example Images")
        gr.Markdown("You can test the tool with wound images from the dataset.")

        def process_image(image):
            """Run segmentation and map the result onto the three outputs."""
            mask, overlay = app.segment_wound(image)
            if mask is None:
                # On failure segment_wound puts the error message in the 2nd slot.
                return None, None, overlay
            return mask, overlay, "Segmentation completed successfully!"

        segment_btn.click(
            fn=process_image,
            inputs=[input_image],
            outputs=[mask_output, overlay_output, status_msg]
        )

        # Auto-process when an image is uploaded (or cleared).
        input_image.change(
            fn=process_image,
            inputs=[input_image],
            outputs=[mask_output, overlay_output, status_msg]
        )

    return interface


if __name__ == "__main__":
    interface = create_gradio_interface()
    # NOTE(review): share=True requests a public Gradio share link and
    # 0.0.0.0 listens on all interfaces -- confirm this exposure is intended
    # for uploaded wound images.
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True
    )