import glob import gradio as gr import matplotlib import numpy as np from PIL import Image import torch import tempfile from gradio_imageslider import ImageSlider import plotly.graph_objects as go import plotly.express as px import open3d as o3d from depth_anything_v2.dpt import DepthAnythingV2 import os import tensorflow as tf from tensorflow.keras.models import load_model # Classification imports from transformers import AutoImageProcessor, AutoModelForImageClassification import google.generativeai as genai import gdown import spaces import cv2 # Import actual segmentation model components from models.deeplab import Deeplabv3, relu6, DepthwiseConv2D, BilinearUpsampling from utils.learning.metrics import dice_coef, precision, recall from utils.io.data import normalize # --- Classification Model Setup --- # Load classification model and processor classification_processor = AutoImageProcessor.from_pretrained("Hemg/Wound-classification") classification_model = AutoModelForImageClassification.from_pretrained("Hemg/Wound-classification") # Configure Gemini AI try: # Try to get API key from Hugging Face secrets gemini_api_key = os.getenv("GOOGLE_API_KEY") if not gemini_api_key: raise ValueError("GEMINI_API_KEY not found in environment variables") genai.configure(api_key=gemini_api_key) gemini_model = genai.GenerativeModel("gemini-2.5-pro") print("✅ Gemini AI configured successfully with API key from secrets") except Exception as e: print(f"❌ Error configuring Gemini AI: {e}") print("Please make sure GEMINI_API_KEY is set in your Hugging Face Space secrets") gemini_model = None # --- Classification Functions --- def analyze_wound_with_gemini(image, predicted_label): """ Analyze wound image using Gemini AI with classification context Args: image: PIL Image predicted_label: The predicted wound type from classification model Returns: str: Gemini AI analysis """ if image is None: return "No image provided for analysis." if gemini_model is None: return "Gemini AI is not available. Please check that GEMINI_API_KEY is properly configured in your Hugging Face Space secrets." try: # Ensure image is in RGB format if image.mode != 'RGB': image = image.convert('RGB') # Create prompt that includes the classification result prompt = f"""You are assisting in a medical education and research task. Based on the wound classification model, this image has been identified as: {predicted_label} Please provide an educational analysis of this wound image focusing on: 1. Visible characteristics of the wound (size, color, texture, edges, surrounding tissue) 2. Educational explanation about this type of wound based on the classification: {predicted_label} 3. General wound healing stages if applicable 4. Key features that are typically associated with this wound type Important guidelines: - This is for educational and research purposes only - Do not provide medical advice or diagnosis - Keep the analysis objective and educational - Focus on visible features and general wound characteristics - Do not recommend treatments or medical interventions Please provide a comprehensive educational analysis.""" response = gemini_model.generate_content([prompt, image]) return response.text except Exception as e: return f"Error analyzing image with Gemini: {str(e)}" def analyze_wound_depth_with_gemini(image, depth_map, depth_stats): """ Analyze wound depth and severity using Gemini AI with depth analysis context Args: image: Original wound image (PIL Image or numpy array) depth_map: Depth map (numpy array) depth_stats: Dictionary containing depth analysis statistics Returns: str: Gemini AI medical assessment based on depth analysis """ if image is None or depth_map is None: return "No image or depth map provided for analysis." if gemini_model is None: return "Gemini AI is not available. Please check that GEMINI_API_KEY is properly configured in your Hugging Face Space secrets." try: # Convert numpy array to PIL Image if needed if isinstance(image, np.ndarray): image = Image.fromarray(image) # Ensure image is in RGB format if image.mode != 'RGB': image = image.convert('RGB') # Convert depth map to PIL Image for Gemini if isinstance(depth_map, np.ndarray): # Normalize depth map for visualization norm_depth = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min()) * 255.0 depth_image = Image.fromarray(norm_depth.astype(np.uint8)) else: depth_image = depth_map # Create detailed prompt with depth statistics prompt = f"""You are a medical AI assistant specializing in wound assessment. Analyze this wound using both the original image and depth map data. DEPTH ANALYSIS DATA PROVIDED: - Total Wound Area: {depth_stats['total_area_cm2']:.2f} cm² - Mean Depth: {depth_stats['mean_depth_mm']:.1f} mm - Maximum Depth: {depth_stats['max_depth_mm']:.1f} mm - Depth Standard Deviation: {depth_stats['depth_std_mm']:.1f} mm - Wound Volume: {depth_stats['wound_volume_cm3']:.2f} cm³ - Deep Tissue Involvement: {depth_stats['deep_ratio']*100:.1f}% - Analysis Quality: {depth_stats['analysis_quality']} - Depth Consistency: {depth_stats['depth_consistency']} TISSUE DEPTH DISTRIBUTION: - Superficial Areas (0-2mm): {depth_stats['superficial_area_cm2']:.2f} cm² - Partial Thickness (2-4mm): {depth_stats['partial_thickness_area_cm2']:.2f} cm² - Full Thickness (4-6mm): {depth_stats['full_thickness_area_cm2']:.2f} cm² - Deep Areas (>6mm): {depth_stats['deep_area_cm2']:.2f} cm² STATISTICAL DEPTH ANALYSIS: - 25th Percentile Depth: {depth_stats['depth_percentiles']['25']:.1f} mm - Median Depth: {depth_stats['depth_percentiles']['50']:.1f} mm - 75th Percentile Depth: {depth_stats['depth_percentiles']['75']:.1f} mm Please provide a comprehensive medical assessment focusing on: 1. **WOUND CHARACTERISTICS ANALYSIS** - Visible wound features from the original image - Correlation between visual appearance and depth measurements - Tissue quality assessment based on color, texture, and depth data 2. **DEPTH-BASED SEVERITY ASSESSMENT** - Clinical significance of the measured depths - Tissue layer involvement based on depth measurements - Risk assessment based on deep tissue involvement percentage 3. **HEALING PROGNOSIS** - Expected healing timeline based on depth and area measurements - Factors that may affect healing based on depth distribution - Complexity assessment based on wound volume and depth variation 4. **CLINICAL CONSIDERATIONS** - Significance of depth consistency/inconsistency - Areas of particular concern based on depth analysis - Educational insights about this type of wound presentation 5. **MEASUREMENT INTERPRETATION** - Clinical relevance of the statistical depth measurements - What the depth distribution tells us about wound progression - Comparison to typical wound depth classifications IMPORTANT GUIDELINES: - This is for educational and research purposes only - Do not provide specific medical advice or treatment recommendations - Focus on objective analysis of the provided measurements - Correlate visual findings with quantitative depth data - Maintain educational and clinical terminology - Emphasize the relationship between depth measurements and clinical significance Provide a detailed, structured medical assessment that integrates both visual and quantitative depth analysis.""" # Send both images to Gemini for analysis response = gemini_model.generate_content([prompt, image, depth_image]) return response.text except Exception as e: return f"Error analyzing wound with Gemini AI: {str(e)}" def classify_wound(image): """ Classify wound type from uploaded image Args: image: PIL Image or numpy array Returns: dict: Classification results with confidence scores """ if image is None: return "Please upload an image" # Convert to PIL Image if needed if isinstance(image, np.ndarray): image = Image.fromarray(image) # Ensure image is in RGB format if image.mode != 'RGB': image = image.convert('RGB') try: # Process the image inputs = classification_processor(images=image, return_tensors="pt") # Get model predictions with torch.no_grad(): outputs = classification_model(**inputs) predictions = torch.nn.functional.softmax(outputs.logits[0], dim=-1) # Get the predicted class labels and confidence scores confidence_scores = predictions.numpy() # Create results dictionary results = {} for i, score in enumerate(confidence_scores): # Get class name from model config class_name = classification_model.config.id2label[i] if hasattr(classification_model.config, 'id2label') else f"Class {i}" results[class_name] = float(score) return results except Exception as e: return f"Error processing image: {str(e)}" def classify_and_analyze_wound(image): """ Combined function to classify wound and get Gemini analysis Args: image: PIL Image or numpy array Returns: tuple: (classification_results, gemini_analysis) """ if image is None: return "Please upload an image", "Please upload an image for analysis" # Get classification results classification_results = classify_wound(image) # Get the top predicted label for Gemini analysis if isinstance(classification_results, dict) and classification_results: # Get the label with highest confidence top_label = max(classification_results.items(), key=lambda x: x[1])[0] # Get Gemini analysis gemini_analysis = analyze_wound_with_gemini(image, top_label) else: top_label = "Unknown" gemini_analysis = "Unable to analyze due to classification error" return classification_results, gemini_analysis def format_gemini_analysis(analysis): """Format Gemini analysis as properly structured HTML""" if not analysis or "Error" in analysis: return f"""

Analysis Error

{analysis}

""" # Parse the markdown-style response and convert to HTML formatted_analysis = parse_markdown_to_html(analysis) return f"""

Initial Wound Analysis

{formatted_analysis}
""" def format_gemini_depth_analysis(analysis): """Format Gemini depth analysis as properly structured HTML for medical assessment""" if not analysis or "Error" in analysis: return f"""
❌ AI Analysis Error
{analysis}
""" # Parse the markdown-style response and convert to HTML formatted_analysis = parse_markdown_to_html(analysis) return f"""
🤖 AI-Powered Medical Assessment
{formatted_analysis}
""" def parse_markdown_to_html(text): """Convert markdown-style text to HTML""" import re # Replace markdown headers text = re.sub(r'^### \*\*(.*?)\*\*$', r'

\1

', text, flags=re.MULTILINE) text = re.sub(r'^#### \*\*(.*?)\*\*$', r'
\1
', text, flags=re.MULTILINE) text = re.sub(r'^### (.*?)$', r'

\1

', text, flags=re.MULTILINE) text = re.sub(r'^#### (.*?)$', r'
\1
', text, flags=re.MULTILINE) # Replace bold text text = re.sub(r'\*\*(.*?)\*\*', r'\1', text) # Replace italic text text = re.sub(r'\*(.*?)\*', r'\1', text) # Replace bullet points text = re.sub(r'^\* (.*?)$', r'
  • \1
  • ', text, flags=re.MULTILINE) text = re.sub(r'^ \* (.*?)$', r'
  • \1
  • ', text, flags=re.MULTILINE) # Wrap consecutive list items in ul tags text = re.sub(r'((?:\s*)*)', r'', text, flags=re.DOTALL) # Replace numbered lists text = re.sub(r'^(\d+)\.\s+(.*?)$', r'
    \1. \2
    ', text, flags=re.MULTILINE) # Convert paragraphs (double newlines) paragraphs = text.split('\n\n') formatted_paragraphs = [] for para in paragraphs: para = para.strip() if para: # Skip if it's already wrapped in HTML tags if not (para.startswith('<') or para.endswith('>')): para = f'

    {para}

    ' formatted_paragraphs.append(para) return '\n'.join(formatted_paragraphs) def combined_analysis(image): """Combined function for UI that returns both outputs""" classification, gemini_analysis = classify_and_analyze_wound(image) formatted_analysis = format_gemini_analysis(gemini_analysis) return classification, formatted_analysis # Define path and file ID checkpoint_dir = "checkpoints" os.makedirs(checkpoint_dir, exist_ok=True) model_file = os.path.join(checkpoint_dir, "depth_anything_v2_vitl.pth") gdrive_url = "https://drive.google.com/uc?id=141Mhq2jonkUBcVBnNqNSeyIZYtH5l4K5" # Download if not already present if not os.path.exists(model_file): print("Downloading model from Google Drive...") gdown.download(gdrive_url, model_file, quiet=False) # --- TensorFlow: Check GPU Availability --- gpus = tf.config.list_physical_devices('GPU') if gpus: print("TensorFlow is using GPU") else: print("TensorFlow is using CPU") # --- Load Actual Wound Segmentation Model --- class WoundSegmentationModel: def __init__(self): self.input_dim_x = 224 self.input_dim_y = 224 self.model = None self.load_model() def load_model(self): """Load the trained wound segmentation model""" try: # Try to load the most recent model weight_file_name = '2025-08-07_16-25-27.hdf5' model_path = f'./training_history/{weight_file_name}' self.model = load_model(model_path, custom_objects={ 'recall': recall, 'precision': precision, 'dice_coef': dice_coef, 'relu6': relu6, 'DepthwiseConv2D': DepthwiseConv2D, 'BilinearUpsampling': BilinearUpsampling }) print(f"Segmentation model loaded successfully from {model_path}") except Exception as e: print(f"Error loading segmentation model: {e}") # Fallback to the older model try: weight_file_name = '2019-12-19 01%3A53%3A15.480800.hdf5' model_path = f'./training_history/{weight_file_name}' self.model = load_model(model_path, custom_objects={ 'recall': recall, 'precision': precision, 'dice_coef': dice_coef, 'relu6': relu6, 'DepthwiseConv2D': DepthwiseConv2D, 'BilinearUpsampling': BilinearUpsampling }) print(f"Segmentation model loaded successfully from {model_path}") except Exception as e2: print(f"Error loading fallback segmentation model: {e2}") self.model = None def preprocess_image(self, image): """Preprocess the uploaded image for model input""" if image is None: return None # Convert to RGB if needed if len(image.shape) == 3 and image.shape[2] == 3: # Convert BGR to RGB if needed image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # Resize to model input size image = cv2.resize(image, (self.input_dim_x, self.input_dim_y)) # Normalize the image image = image.astype(np.float32) / 255.0 # Add batch dimension image = np.expand_dims(image, axis=0) return image def postprocess_prediction(self, prediction): """Postprocess the model prediction""" # Remove batch dimension prediction = prediction[0] # Apply threshold to get binary mask threshold = 0.5 binary_mask = (prediction > threshold).astype(np.uint8) * 255 return binary_mask def segment_wound(self, input_image): """Main function to segment wound from uploaded image""" if self.model is None: return None, "Error: Segmentation model not loaded. Please check the model files." if input_image is None: return None, "Please upload an image." try: # Preprocess the image processed_image = self.preprocess_image(input_image) if processed_image is None: return None, "Error processing image." # Make prediction prediction = self.model.predict(processed_image, verbose=0) # Postprocess the prediction segmented_mask = self.postprocess_prediction(prediction) return segmented_mask, "Segmentation completed successfully!" except Exception as e: return None, f"Error during segmentation: {str(e)}" # Initialize the segmentation model segmentation_model = WoundSegmentationModel() # --- PyTorch: Set Device and Load Depth Model --- map_device = torch.device("cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu") print(f"Using PyTorch device: {map_device}") model_configs = { 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}, 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]}, 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}, 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]} } encoder = 'vitl' depth_model = DepthAnythingV2(**model_configs[encoder]) state_dict = torch.load( f'checkpoints/depth_anything_v2_{encoder}.pth', map_location=map_device ) depth_model.load_state_dict(state_dict) depth_model = depth_model.to(map_device).eval() # --- Custom CSS for unified dark theme --- css = """ .gradio-container { font-family: 'Segoe UI', sans-serif; background-color: #121212; color: #ffffff; padding: 20px; } .gr-button { background-color: #2c3e50; color: white; border-radius: 10px; } .gr-button:hover { background-color: #34495e; } .gr-html, .gr-html div { white-space: normal !important; overflow: visible !important; text-overflow: unset !important; word-break: break-word !important; } #img-display-container { max-height: 100vh; } #img-display-input { max-height: 80vh; } #img-display-output { max-height: 80vh; } #download { height: 62px; } h1 { text-align: center; font-size: 3rem; font-weight: bold; margin: 2rem 0; color: #ffffff; } h2 { color: #ffffff; text-align: center; margin: 1rem 0; } .gr-tabs { background-color: #1e1e1e; border-radius: 10px; padding: 10px; } .gr-tab-nav { background-color: #2c3e50; border-radius: 8px; } .gr-tab-nav button { color: #ffffff !important; } .gr-tab-nav button.selected { background-color: #34495e !important; } /* Card styling for consistent heights */ .wound-card { min-height: 200px !important; display: flex !important; flex-direction: column !important; justify-content: space-between !important; } .wound-card-content { flex-grow: 1 !important; display: flex !important; flex-direction: column !important; justify-content: center !important; } /* Loading animation */ .loading-spinner { display: inline-block; width: 20px; height: 20px; border: 3px solid #f3f3f3; border-top: 3px solid #3498db; border-radius: 50%; animation: spin 1s linear infinite; } @keyframes spin { 0% { transform: rotate(0deg); } 100% { transform: rotate(360deg); } } """ # --- Enhanced Wound Severity Estimation Functions --- def compute_enhanced_depth_statistics(depth_map, mask, pixel_spacing_mm=0.5, depth_calibration_mm=15.0): """ Enhanced depth analysis with proper calibration and medical standards Based on wound depth classification standards: - Superficial: 0-2mm (epidermis only) - Partial thickness: 2-4mm (epidermis + partial dermis) - Full thickness: 4-6mm (epidermis + full dermis) - Deep: >6mm (involving subcutaneous tissue) """ # Convert pixel spacing to mm pixel_spacing_mm = float(pixel_spacing_mm) # Calculate pixel area in cm² pixel_area_cm2 = (pixel_spacing_mm / 10.0) ** 2 # Extract wound region (binary mask) wound_mask = (mask > 127).astype(np.uint8) # Apply morphological operations to clean the mask kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)) wound_mask = cv2.morphologyEx(wound_mask, cv2.MORPH_CLOSE, kernel) # Get depth values only for wound region wound_depths = depth_map[wound_mask > 0] if len(wound_depths) == 0: return { 'total_area_cm2': 0, 'superficial_area_cm2': 0, 'partial_thickness_area_cm2': 0, 'full_thickness_area_cm2': 0, 'deep_area_cm2': 0, 'mean_depth_mm': 0, 'max_depth_mm': 0, 'depth_std_mm': 0, 'deep_ratio': 0, 'wound_volume_cm3': 0, 'depth_percentiles': {'25': 0, '50': 0, '75': 0} } # Normalize depth relative to nearest point in wound area normalized_depth_map, nearest_point_coords, max_relative_depth = normalize_depth_relative_to_nearest_point(depth_map, wound_mask) # Calibrate the normalized depth map for more accurate measurements calibrated_depth_map = calibrate_depth_map(normalized_depth_map, reference_depth_mm=depth_calibration_mm) # Get calibrated depth values for wound region wound_depths_mm = calibrated_depth_map[wound_mask > 0] # Medical depth classification superficial_mask = wound_depths_mm < 2.0 partial_thickness_mask = (wound_depths_mm >= 2.0) & (wound_depths_mm < 4.0) full_thickness_mask = (wound_depths_mm >= 4.0) & (wound_depths_mm < 6.0) deep_mask = wound_depths_mm >= 6.0 # Calculate areas total_pixels = np.sum(wound_mask > 0) total_area_cm2 = total_pixels * pixel_area_cm2 superficial_area_cm2 = np.sum(superficial_mask) * pixel_area_cm2 partial_thickness_area_cm2 = np.sum(partial_thickness_mask) * pixel_area_cm2 full_thickness_area_cm2 = np.sum(full_thickness_mask) * pixel_area_cm2 deep_area_cm2 = np.sum(deep_mask) * pixel_area_cm2 # Calculate depth statistics mean_depth_mm = np.mean(wound_depths_mm) max_depth_mm = np.max(wound_depths_mm) depth_std_mm = np.std(wound_depths_mm) # Calculate depth percentiles depth_percentiles = { '25': np.percentile(wound_depths_mm, 25), '50': np.percentile(wound_depths_mm, 50), '75': np.percentile(wound_depths_mm, 75) } # Calculate depth distribution statistics depth_distribution = { 'shallow_ratio': np.sum(wound_depths_mm < 2.0) / len(wound_depths_mm) if len(wound_depths_mm) > 0 else 0, 'moderate_ratio': np.sum((wound_depths_mm >= 2.0) & (wound_depths_mm < 5.0)) / len(wound_depths_mm) if len(wound_depths_mm) > 0 else 0, 'deep_ratio': np.sum(wound_depths_mm >= 5.0) / len(wound_depths_mm) if len(wound_depths_mm) > 0 else 0 } # Calculate wound volume (approximate) # Volume = area * average depth wound_volume_cm3 = total_area_cm2 * (mean_depth_mm / 10.0) # Deep tissue ratio deep_ratio = deep_area_cm2 / total_area_cm2 if total_area_cm2 > 0 else 0 # Calculate analysis quality metrics wound_pixel_count = len(wound_depths_mm) analysis_quality = "High" if wound_pixel_count > 1000 else "Medium" if wound_pixel_count > 500 else "Low" # Calculate depth consistency (lower std dev = more consistent) depth_consistency = "High" if depth_std_mm < 2.0 else "Medium" if depth_std_mm < 4.0 else "Low" return { 'total_area_cm2': total_area_cm2, 'superficial_area_cm2': superficial_area_cm2, 'partial_thickness_area_cm2': partial_thickness_area_cm2, 'full_thickness_area_cm2': full_thickness_area_cm2, 'deep_area_cm2': deep_area_cm2, 'mean_depth_mm': mean_depth_mm, 'max_depth_mm': max_depth_mm, 'depth_std_mm': depth_std_mm, 'deep_ratio': deep_ratio, 'wound_volume_cm3': wound_volume_cm3, 'depth_percentiles': depth_percentiles, 'depth_distribution': depth_distribution, 'analysis_quality': analysis_quality, 'depth_consistency': depth_consistency, 'wound_pixel_count': wound_pixel_count, 'nearest_point_coords': nearest_point_coords, 'max_relative_depth': max_relative_depth, 'normalized_depth_map': normalized_depth_map } def classify_wound_severity_by_enhanced_metrics(depth_stats): """ Enhanced wound severity classification based on medical standards Uses multiple criteria: depth, area, volume, and tissue involvement """ if depth_stats['total_area_cm2'] == 0: return "Unknown" # Extract key metrics total_area = depth_stats['total_area_cm2'] deep_area = depth_stats['deep_area_cm2'] full_thickness_area = depth_stats['full_thickness_area_cm2'] mean_depth = depth_stats['mean_depth_mm'] max_depth = depth_stats['max_depth_mm'] wound_volume = depth_stats['wound_volume_cm3'] deep_ratio = depth_stats['deep_ratio'] # Medical severity classification criteria severity_score = 0 # Criterion 1: Maximum depth if max_depth >= 10.0: severity_score += 3 # Very severe elif max_depth >= 6.0: severity_score += 2 # Severe elif max_depth >= 4.0: severity_score += 1 # Moderate # Criterion 2: Mean depth if mean_depth >= 5.0: severity_score += 2 elif mean_depth >= 3.0: severity_score += 1 # Criterion 3: Deep tissue involvement ratio if deep_ratio >= 0.5: severity_score += 3 # More than 50% deep tissue elif deep_ratio >= 0.25: severity_score += 2 # 25-50% deep tissue elif deep_ratio >= 0.1: severity_score += 1 # 10-25% deep tissue # Criterion 4: Total wound area if total_area >= 10.0: severity_score += 2 # Large wound (>10 cm²) elif total_area >= 5.0: severity_score += 1 # Medium wound (5-10 cm²) # Criterion 5: Wound volume if wound_volume >= 5.0: severity_score += 2 # High volume elif wound_volume >= 2.0: severity_score += 1 # Medium volume # Determine severity based on total score if severity_score >= 8: return "Very Severe" elif severity_score >= 6: return "Severe" elif severity_score >= 4: return "Moderate" elif severity_score >= 2: return "Mild" else: return "Superficial" def analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing_mm=0.5, depth_calibration_mm=15.0): """Enhanced wound severity analysis based on depth measurements""" if image is None or depth_map is None or wound_mask is None: return "❌ Please upload image, depth map, and wound mask." # Convert wound mask to grayscale if needed if len(wound_mask.shape) == 3: wound_mask = np.mean(wound_mask, axis=2) # Ensure depth map and mask have same dimensions if depth_map.shape[:2] != wound_mask.shape[:2]: # Resize mask to match depth map from PIL import Image mask_pil = Image.fromarray(wound_mask.astype(np.uint8)) mask_pil = mask_pil.resize((depth_map.shape[1], depth_map.shape[0])) wound_mask = np.array(mask_pil) # Compute enhanced statistics with relative depth normalization stats = compute_enhanced_depth_statistics(depth_map, wound_mask, pixel_spacing_mm, depth_calibration_mm) # Get severity based on enhanced metrics severity_level = classify_wound_severity_by_enhanced_metrics(stats) severity_description = get_enhanced_severity_description(severity_level) # Get Gemini AI analysis based on depth data gemini_analysis = analyze_wound_depth_with_gemini(image, depth_map, stats) # Format Gemini analysis for display formatted_gemini_analysis = format_gemini_depth_analysis(gemini_analysis) # Create depth analysis visualization depth_visualization = create_depth_analysis_visualization( stats['normalized_depth_map'], wound_mask, stats['nearest_point_coords'], stats['max_relative_depth'] ) # Enhanced severity color coding severity_color = { "Superficial": "#4CAF50", # Green "Mild": "#8BC34A", # Light Green "Moderate": "#FF9800", # Orange "Severe": "#F44336", # Red "Very Severe": "#9C27B0" # Purple }.get(severity_level, "#9E9E9E") # Gray for unknown # Create comprehensive medical report report = f"""
    🩹 Enhanced Wound Severity Analysis
    📊 Depth & Quality Analysis
    � Basic Measurements
    �📏 Mean Relative Depth: {stats['mean_depth_mm']:.1f} mm
    📐 Max Relative Depth: {stats['max_depth_mm']:.1f} mm
    📊 Depth Std Dev: {stats['depth_std_mm']:.1f} mm
    📦 Wound Volume: {stats['wound_volume_cm3']:.2f} cm³
    🔥 Deep Tissue Ratio: {stats['deep_ratio']*100:.1f}%
    📈 Statistical Analysis
    25th Percentile: {stats['depth_percentiles']['25']:.1f} mm
    📊 Median (50th): {stats['depth_percentiles']['50']:.1f} mm
    📊 75th Percentile: {stats['depth_percentiles']['75']:.1f} mm
    📊 Shallow Areas: {stats['depth_distribution']['shallow_ratio']*100:.1f}%
    📊 Moderate Areas: {stats['depth_distribution']['moderate_ratio']*100:.1f}%
    🔍 Quality Metrics
    🔍 Analysis Quality: {stats['analysis_quality']}
    📏 Depth Consistency: {stats['depth_consistency']}
    📊 Data Points: {stats['wound_pixel_count']:,}
    📊 Deep Areas: {stats['depth_distribution']['deep_ratio']*100:.1f}%
    🎯 Reference Point: Nearest to camera
    📊 Medical Assessment Based on Depth Analysis
    {formatted_gemini_analysis}
    """ return report def normalize_depth_relative_to_nearest_point(depth_map, wound_mask): """ Normalize depth map relative to the nearest point in the wound area This assumes a top-down camera perspective where the closest point to camera = 0 depth Args: depth_map: Raw depth map wound_mask: Binary mask of wound region Returns: normalized_depth: Depth values relative to nearest point (0 = nearest, positive = deeper) nearest_point_coords: Coordinates of the nearest point max_relative_depth: Maximum relative depth in the wound """ if depth_map is None or wound_mask is None: return depth_map, None, 0 # Convert mask to binary binary_mask = (wound_mask > 127).astype(np.uint8) # Find wound region coordinates wound_coords = np.where(binary_mask > 0) if len(wound_coords[0]) == 0: return depth_map, None, 0 # Get depth values only for wound region wound_depths = depth_map[wound_coords] # Find the nearest point (minimum depth value in wound region) nearest_depth = np.min(wound_depths) nearest_indices = np.where(wound_depths == nearest_depth) # Get coordinates of the nearest point(s) nearest_point_coords = (wound_coords[0][nearest_indices[0][0]], wound_coords[1][nearest_indices[0][0]]) # Create normalized depth map (relative to nearest point) normalized_depth = depth_map.copy() normalized_depth = normalized_depth - nearest_depth # Ensure all values are non-negative (nearest point = 0, others = positive) normalized_depth = np.maximum(normalized_depth, 0) # Calculate maximum relative depth in wound region wound_normalized_depths = normalized_depth[wound_coords] max_relative_depth = np.max(wound_normalized_depths) return normalized_depth, nearest_point_coords, max_relative_depth def calibrate_depth_map(depth_map, reference_depth_mm=10.0): """ Calibrate depth map to real-world measurements using reference depth This helps convert normalized depth values to actual millimeters """ if depth_map is None: return depth_map # Find the maximum depth value in the depth map max_depth_value = np.max(depth_map) min_depth_value = np.min(depth_map) if max_depth_value == min_depth_value: return depth_map # Apply calibration to convert to millimeters # Assuming the maximum depth in the map corresponds to reference_depth_mm calibrated_depth = (depth_map - min_depth_value) / (max_depth_value - min_depth_value) * reference_depth_mm return calibrated_depth def create_depth_analysis_visualization(depth_map, wound_mask, nearest_point_coords, max_relative_depth): """ Create a visualization showing the depth analysis with nearest point and deepest point highlighted """ if depth_map is None or wound_mask is None: return None # Create a copy of the depth map for visualization vis_depth = depth_map.copy() # Apply colormap for better visualization normalized_depth = (vis_depth - np.min(vis_depth)) / (np.max(vis_depth) - np.min(vis_depth)) colored_depth = (matplotlib.colormaps.get_cmap('Spectral_r')(normalized_depth)[:, :, :3] * 255).astype(np.uint8) # Convert to RGB if grayscale if len(colored_depth.shape) == 3 and colored_depth.shape[2] == 1: colored_depth = cv2.cvtColor(colored_depth, cv2.COLOR_GRAY2RGB) # Highlight the nearest point (reference point) with a red circle if nearest_point_coords is not None: y, x = nearest_point_coords cv2.circle(colored_depth, (x, y), 10, (255, 0, 0), 2) # Red circle for nearest point cv2.putText(colored_depth, "REF", (x+15, y-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1) # Find and highlight the deepest point binary_mask = (wound_mask > 127).astype(np.uint8) wound_coords = np.where(binary_mask > 0) if len(wound_coords[0]) > 0: # Get depth values for wound region wound_depths = vis_depth[wound_coords] max_depth_idx = np.argmax(wound_depths) deepest_point_coords = (wound_coords[0][max_depth_idx], wound_coords[1][max_depth_idx]) # Highlight the deepest point with a blue circle y, x = deepest_point_coords cv2.circle(colored_depth, (x, y), 12, (0, 0, 255), 3) # Blue circle for deepest point cv2.putText(colored_depth, "DEEP", (x+15, y+5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1) # Overlay wound mask outline contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) cv2.drawContours(colored_depth, contours, -1, (0, 255, 0), 2) # Green outline for wound boundary return colored_depth def get_enhanced_severity_description(severity): """Get comprehensive medical description for severity level""" descriptions = { "Superficial": "Epidermis-only damage. Minimal tissue loss, typically heals within 1-2 weeks with basic wound care.", "Mild": "Superficial to partial thickness wound. Limited tissue involvement, good healing potential with proper care.", "Moderate": "Partial to full thickness involvement. Requires careful monitoring and may need advanced wound care techniques.", "Severe": "Full thickness with deep tissue involvement. High risk of complications, requires immediate medical attention.", "Very Severe": "Extensive deep tissue damage. Critical condition requiring immediate surgical intervention and specialized care.", "Unknown": "Unable to determine severity due to insufficient data or poor image quality." } return descriptions.get(severity, "Severity assessment unavailable.") def create_sample_wound_mask(image_shape, center=None, radius=50): """Create a sample circular wound mask for testing""" if center is None: center = (image_shape[1] // 2, image_shape[0] // 2) mask = np.zeros(image_shape[:2], dtype=np.uint8) y, x = np.ogrid[:image_shape[0], :image_shape[1]] # Create circular mask dist_from_center = np.sqrt((x - center[0])**2 + (y - center[1])**2) mask[dist_from_center <= radius] = 255 return mask def create_realistic_wound_mask(image_shape, method='elliptical'): """Create a more realistic wound mask with irregular shapes""" h, w = image_shape[:2] mask = np.zeros((h, w), dtype=np.uint8) if method == 'elliptical': # Create elliptical wound mask center = (w // 2, h // 2) radius_x = min(w, h) // 3 radius_y = min(w, h) // 4 y, x = np.ogrid[:h, :w] # Add some irregularity to make it more realistic ellipse = ((x - center[0])**2 / (radius_x**2) + (y - center[1])**2 / (radius_y**2)) <= 1 # Add some noise and irregularity noise = np.random.random((h, w)) > 0.8 mask = (ellipse | noise).astype(np.uint8) * 255 elif method == 'irregular': # Create irregular wound mask center = (w // 2, h // 2) radius = min(w, h) // 4 y, x = np.ogrid[:h, :w] base_circle = np.sqrt((x - center[0])**2 + (y - center[1])**2) <= radius # Add irregular extensions extensions = np.zeros_like(base_circle) for i in range(3): angle = i * 2 * np.pi / 3 ext_x = int(center[0] + radius * 0.8 * np.cos(angle)) ext_y = int(center[1] + radius * 0.8 * np.sin(angle)) ext_radius = radius // 3 ext_circle = np.sqrt((x - ext_x)**2 + (y - ext_y)**2) <= ext_radius extensions = extensions | ext_circle mask = (base_circle | extensions).astype(np.uint8) * 255 # Apply morphological operations to smooth the mask kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel) return mask # --- Depth Estimation Functions --- def predict_depth(image): return depth_model.infer_image(image) def calculate_max_points(image): """Calculate maximum points based on image dimensions (3x pixel count)""" if image is None: return 10000 # Default value h, w = image.shape[:2] max_points = h * w * 3 # Ensure minimum and reasonable maximum values return max(1000, min(max_points, 300000)) def update_slider_on_image_upload(image): """Update the points slider when an image is uploaded""" max_points = calculate_max_points(image) default_value = min(10000, max_points // 10) # 10% of max points as default return gr.Slider(minimum=1000, maximum=max_points, value=default_value, step=1000, label=f"Number of 3D points (max: {max_points:,})") def create_point_cloud(image, depth_map, focal_length_x=470.4, focal_length_y=470.4, max_points=30000): """Create a point cloud from depth map using camera intrinsics with high detail""" h, w = depth_map.shape # Use smaller step for higher detail (reduced downsampling) step = max(1, int(np.sqrt(h * w / max_points) * 0.5)) # Reduce step size for more detail # Create mesh grid for camera coordinates y_coords, x_coords = np.mgrid[0:h:step, 0:w:step] # Convert to camera coordinates (normalized by focal length) x_cam = (x_coords - w / 2) / focal_length_x y_cam = (y_coords - h / 2) / focal_length_y # Get depth values depth_values = depth_map[::step, ::step] # Calculate 3D points: (x_cam * depth, y_cam * depth, depth) x_3d = x_cam * depth_values y_3d = y_cam * depth_values z_3d = depth_values # Flatten arrays points = np.stack([x_3d.flatten(), y_3d.flatten(), z_3d.flatten()], axis=1) # Get corresponding image colors image_colors = image[::step, ::step, :] colors = image_colors.reshape(-1, 3) / 255.0 # Create Open3D point cloud pcd = o3d.geometry.PointCloud() pcd.points = o3d.utility.Vector3dVector(points) pcd.colors = o3d.utility.Vector3dVector(colors) return pcd def reconstruct_surface_mesh_from_point_cloud(pcd): """Convert point cloud to a mesh using Poisson reconstruction with very high detail.""" # Estimate and orient normals with high precision pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.005, max_nn=50)) pcd.orient_normals_consistent_tangent_plane(k=50) # Create surface mesh with maximum detail (depth=12 for very high resolution) mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=12) # Return mesh without filtering low-density vertices return mesh def create_enhanced_3d_visualization(image, depth_map, max_points=10000): """Create an enhanced 3D visualization using proper camera projection""" h, w = depth_map.shape # Downsample to avoid too many points for performance step = max(1, int(np.sqrt(h * w / max_points))) # Create mesh grid for camera coordinates y_coords, x_coords = np.mgrid[0:h:step, 0:w:step] # Convert to camera coordinates (normalized by focal length) focal_length = 470.4 # Default focal length x_cam = (x_coords - w / 2) / focal_length y_cam = (y_coords - h / 2) / focal_length # Get depth values depth_values = depth_map[::step, ::step] # Calculate 3D points: (x_cam * depth, y_cam * depth, depth) x_3d = x_cam * depth_values y_3d = y_cam * depth_values z_3d = depth_values # Flatten arrays x_flat = x_3d.flatten() y_flat = y_3d.flatten() z_flat = z_3d.flatten() # Get corresponding image colors image_colors = image[::step, ::step, :] colors_flat = image_colors.reshape(-1, 3) # Create 3D scatter plot with proper camera projection fig = go.Figure(data=[go.Scatter3d( x=x_flat, y=y_flat, z=z_flat, mode='markers', marker=dict( size=1.5, color=colors_flat, opacity=0.9 ), hovertemplate='3D Position: (%{x:.3f}, %{y:.3f}, %{z:.3f})
    ' + 'Depth: %{z:.2f}
    ' + '' )]) fig.update_layout( title="3D Point Cloud Visualization (Camera Projection)", scene=dict( xaxis_title="X (meters)", yaxis_title="Y (meters)", zaxis_title="Z (meters)", camera=dict( eye=dict(x=2.0, y=2.0, z=2.0), center=dict(x=0, y=0, z=0), up=dict(x=0, y=0, z=1) ), aspectmode='data' ), width=700, height=600 ) return fig def on_depth_submit(image, num_points, focal_x, focal_y): original_image = image.copy() h, w = image.shape[:2] # Predict depth using the model depth = predict_depth(image[:, :, ::-1]) # RGB to BGR if needed # Save raw 16-bit depth raw_depth = Image.fromarray(depth.astype('uint16')) tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False) raw_depth.save(tmp_raw_depth.name) # Normalize and convert to grayscale for display norm_depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0 norm_depth = norm_depth.astype(np.uint8) colored_depth = (matplotlib.colormaps.get_cmap('Spectral_r')(norm_depth)[:, :, :3] * 255).astype(np.uint8) gray_depth = Image.fromarray(norm_depth) tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False) gray_depth.save(tmp_gray_depth.name) # Create point cloud pcd = create_point_cloud(original_image, norm_depth, focal_x, focal_y, max_points=num_points) # Reconstruct mesh from point cloud mesh = reconstruct_surface_mesh_from_point_cloud(pcd) # Save mesh with faces as .ply tmp_pointcloud = tempfile.NamedTemporaryFile(suffix='.ply', delete=False) o3d.io.write_triangle_mesh(tmp_pointcloud.name, mesh) # Create enhanced 3D scatter plot visualization depth_3d = create_enhanced_3d_visualization(original_image, norm_depth, max_points=num_points) return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name, tmp_pointcloud.name, depth_3d] # --- Actual Wound Segmentation Functions --- def create_automatic_wound_mask(image, method='deep_learning'): """ Automatically generate wound mask from image using the actual deep learning model Args: image: Input image (numpy array) method: Segmentation method (currently only 'deep_learning' supported) Returns: mask: Binary wound mask """ if image is None: return None # Use the actual deep learning model for segmentation if method == 'deep_learning': mask, _ = segmentation_model.segment_wound(image) return mask else: # Fallback to deep learning if method not recognized mask, _ = segmentation_model.segment_wound(image) return mask def post_process_wound_mask(mask, min_area=100): """Post-process the wound mask to remove noise and small objects""" if mask is None: return None # Convert to binary if needed if mask.dtype != np.uint8: mask = mask.astype(np.uint8) # Apply morphological operations to clean up kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10)) mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel) mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel) # Remove small objects using OpenCV contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) mask_clean = np.zeros_like(mask) for contour in contours: area = cv2.contourArea(contour) if area >= min_area: cv2.fillPoly(mask_clean, [contour], 255) # Fill holes mask_clean = cv2.morphologyEx(mask_clean, cv2.MORPH_CLOSE, kernel) return mask_clean def analyze_wound_severity_auto(image, depth_map, pixel_spacing_mm=0.5, segmentation_method='deep_learning'): """Analyze wound severity with automatic mask generation using actual segmentation model""" if image is None or depth_map is None: return "❌ Please provide both image and depth map." # Generate automatic wound mask using the actual model auto_mask = create_automatic_wound_mask(image, method=segmentation_method) if auto_mask is None: return "❌ Failed to generate automatic wound mask. Please check if the segmentation model is loaded." # Post-process the mask processed_mask = post_process_wound_mask(auto_mask, min_area=500) if processed_mask is None or np.sum(processed_mask > 0) == 0: return "❌ No wound region detected by the segmentation model. Try uploading a different image or use manual mask." # Analyze severity using the automatic mask return analyze_wound_severity(image, depth_map, processed_mask, pixel_spacing_mm) # --- Main Gradio Interface --- with gr.Blocks(css=css, title="Wound Analysis System") as demo: gr.HTML("

    Wound Analysis System

    ") #gr.Markdown("### Complete workflow: Classification → Depth Estimation → Wound Severity Analysis") # Shared states shared_image = gr.State() shared_depth_map = gr.State() with gr.Tabs(): # Tab 1: Wound Classification with gr.Tab("1. 🔍 Wound Classification & Initial Analysis"): gr.Markdown("### Step 1: Classify wound type and get initial AI analysis") #gr.Markdown("Upload an image to identify the wound type and receive detailed analysis from our Vision AI.") with gr.Row(): # Left Column - Image Upload with gr.Column(scale=1): gr.HTML('

    Upload Wound Image

    ') classification_image_input = gr.Image( label="", type="pil", height=400 ) # Place Clear and Analyse buttons side by side with gr.Row(): classify_clear_btn = gr.Button( "Clear", variant="secondary", size="lg", scale=1 ) analyse_btn = gr.Button( "Analyse", variant="primary", size="lg", scale=1 ) # Right Column - Classification Results with gr.Column(scale=1): gr.HTML('

    Classification Results

    ') classification_output = gr.Label( label="", num_top_classes=5, show_label=False ) # Second Row - Full Width AI Analysis with gr.Row(): with gr.Column(scale=1): gr.HTML('

    Wound Visual Analysis

    ') gemini_output = gr.HTML( value="""
    Upload an image to get AI-powered wound analysis
    """ ) # Event handlers for classification tab classify_clear_btn.click( fn=lambda: (None, None, """
    Upload an image to get AI-powered wound analysis
    """), inputs=None, outputs=[classification_image_input, classification_output, gemini_output] ) # Only run classification on image upload def classify_and_store(image): result = classify_wound(image) return result classification_image_input.change( fn=classify_and_store, inputs=classification_image_input, outputs=classification_output ) # Store image in shared state for next tabs def store_shared_image(image): return image classification_image_input.change( fn=store_shared_image, inputs=classification_image_input, outputs=shared_image ) # Run Gemini analysis only when Analyse button is clicked def run_gemini_on_click(image, classification): # Get top label if isinstance(classification, dict) and classification: top_label = max(classification.items(), key=lambda x: x[1])[0] else: top_label = "Unknown" gemini_analysis = analyze_wound_with_gemini(image, top_label) formatted_analysis = format_gemini_analysis(gemini_analysis) return formatted_analysis analyse_btn.click( fn=run_gemini_on_click, inputs=[classification_image_input, classification_output], outputs=gemini_output ) # Tab 2: Depth Estimation with gr.Tab("2. 📏 Depth Estimation & 3D Visualization"): gr.Markdown("### Step 2: Generate depth maps and 3D visualizations") gr.Markdown("This module creates depth maps and 3D point clouds from your images.") with gr.Row(): load_from_classification_btn = gr.Button("🔄 Load Image from Classification Tab", variant="secondary") with gr.Row(): depth_input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input') depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output') with gr.Row(): depth_submit = gr.Button(value="Compute Depth", variant="primary") points_slider = gr.Slider(minimum=1000, maximum=10000, value=10000, step=1000, label="Number of 3D points (upload image to update max)") with gr.Row(): focal_length_x = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10, label="Focal Length X (pixels)") focal_length_y = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10, label="Focal Length Y (pixels)") # Reorganized layout: 2 columns - 3D visualization on left, file outputs stacked on right with gr.Row(): with gr.Column(scale=2): # 3D Visualization gr.Markdown("### 3D Point Cloud Visualization") gr.Markdown("Enhanced 3D visualization using proper camera projection. Hover over points to see 3D coordinates.") depth_3d_plot = gr.Plot(label="3D Point Cloud") with gr.Column(scale=1): gr.Markdown("### Download Files") gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download") raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download") point_cloud_file = gr.File(label="Point Cloud (.ply)", elem_id="download") # Tab 3: Wound Severity Analysis with gr.Tab("3. 🩹 Wound Severity Analysis"): gr.Markdown("### Step 3: Analyze wound severity using depth maps") gr.Markdown("This module analyzes wound severity based on depth distribution and area measurements.") with gr.Row(): # Load depth map from previous tab load_depth_btn = gr.Button("🔄 Load Depth Map from Tab 2", variant="secondary") with gr.Row(): severity_input_image = gr.Image(label="Original Image", type='numpy') severity_depth_map = gr.Image(label="Depth Map (from Tab 2)", type='numpy') with gr.Row(): wound_mask_input = gr.Image(label="Auto-Generated Wound Mask", type='numpy') with gr.Row(): severity_output = gr.HTML( label="🤖 AI-Powered Medical Assessment", value="""
    🩹 Wound Severity Analysis
    ⏳ Waiting for Input...
    Please upload an image and depth map, then click "🤖 Analyze Severity with Auto-Generated Mask" to begin AI-powered medical assessment.
    """ ) gr.Markdown("**Note:** The deep learning segmentation model will automatically generate a wound mask when you upload an image or load a depth map.") with gr.Row(): auto_severity_button = gr.Button("🤖 Analyze Severity with Auto-Generated Mask", variant="primary", size="lg") pixel_spacing_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.5, step=0.1, label="Pixel Spacing (mm/pixel)") depth_calibration_slider = gr.Slider(minimum=5.0, maximum=30.0, value=15.0, step=1.0, label="Depth Calibration (mm)", info="Adjust based on expected maximum wound depth") #gr.Markdown("**Pixel Spacing:** Adjust based on your camera calibration. Default is 0.5 mm/pixel.") #gr.Markdown("**Depth Calibration:** Adjust the maximum expected wound depth to improve measurement accuracy. For shallow wounds use 5-10mm, for deep wounds use 15-30mm.") #gr.Markdown("**Note:** When you load a depth map or upload an image, the segmentation model will automatically generate a wound mask.") # Update slider when image is uploaded depth_input_image.change( fn=update_slider_on_image_upload, inputs=[depth_input_image], outputs=[points_slider] ) # Modified depth submit function to store depth map def on_depth_submit_with_state(image, num_points, focal_x, focal_y): results = on_depth_submit(image, num_points, focal_x, focal_y) # Extract depth map from results for severity analysis depth_map = None if image is not None: depth = predict_depth(image[:, :, ::-1]) # RGB to BGR if needed # Normalize depth for severity analysis norm_depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0 depth_map = norm_depth.astype(np.uint8) return results + [depth_map] depth_submit.click(on_depth_submit_with_state, inputs=[depth_input_image, points_slider, focal_length_x, focal_length_y], outputs=[depth_image_slider, gray_depth_file, raw_file, point_cloud_file, depth_3d_plot, shared_depth_map]) # Function to load image from classification to depth tab def load_image_from_classification(shared_img): if shared_img is None: return None, "❌ No image available from classification tab. Please upload an image in Tab 1 first." # Convert PIL image to numpy array for depth estimation if hasattr(shared_img, 'convert'): # It's a PIL image, convert to numpy img_array = np.array(shared_img) return img_array, "✅ Image loaded from classification tab successfully!" else: # Already numpy array return shared_img, "✅ Image loaded from classification tab successfully!" # Connect the load button load_from_classification_btn.click( fn=load_image_from_classification, inputs=shared_image, outputs=[depth_input_image, gr.HTML()] ) # Load depth map to severity tab and auto-generate mask def load_depth_to_severity(depth_map, original_image): if depth_map is None: return None, None, None, "❌ No depth map available. Please compute depth in Tab 2 first." # Auto-generate wound mask using segmentation model if original_image is not None: auto_mask, _ = segmentation_model.segment_wound(original_image) if auto_mask is not None: # Post-process the mask processed_mask = post_process_wound_mask(auto_mask, min_area=500) if processed_mask is not None and np.sum(processed_mask > 0) > 0: return depth_map, original_image, processed_mask, "✅ Depth map loaded and wound mask auto-generated!" else: return depth_map, original_image, None, "✅ Depth map loaded but no wound detected. Try uploading a different image." else: return depth_map, original_image, None, "✅ Depth map loaded but segmentation failed. Try uploading a different image." else: return depth_map, original_image, None, "✅ Depth map loaded successfully!" load_depth_btn.click( fn=load_depth_to_severity, inputs=[shared_depth_map, depth_input_image], outputs=[severity_depth_map, severity_input_image, wound_mask_input, gr.HTML()] ) # Loading state function def show_loading_state(): return """
    🩹 Wound Severity Analysis
    🔄 AI Analysis in Progress...
    • Generating wound mask with deep learning model
    • Computing depth measurements and statistics
    • Analyzing wound characteristics with Gemini AI
    • Preparing comprehensive medical assessment
    """ # Automatic severity analysis function def run_auto_severity_analysis(image, depth_map, pixel_spacing, depth_calibration): if depth_map is None: return """
    ❌ Error
    Please load depth map from Tab 1 first.
    """ # Generate automatic wound mask using the actual model auto_mask = create_automatic_wound_mask(image, method='deep_learning') if auto_mask is None: return """
    ❌ Error
    Failed to generate automatic wound mask. Please check if the segmentation model is loaded.
    """ # Post-process the mask with fixed minimum area processed_mask = post_process_wound_mask(auto_mask, min_area=500) if processed_mask is None or np.sum(processed_mask > 0) == 0: return """
    ⚠️ No Wound Detected
    No wound region detected by the segmentation model. Try uploading a different image or use manual mask.
    """ # Analyze severity using the automatic mask return analyze_wound_severity(image, depth_map, processed_mask, pixel_spacing, depth_calibration) # Connect event handler with loading state auto_severity_button.click( fn=show_loading_state, inputs=[], outputs=[severity_output] ).then( fn=run_auto_severity_analysis, inputs=[severity_input_image, severity_depth_map, pixel_spacing_slider, depth_calibration_slider], outputs=[severity_output] ) # Auto-generate mask when image is uploaded def auto_generate_mask_on_image_upload(image): if image is None: return None, "❌ No image uploaded." # Generate automatic wound mask using segmentation model auto_mask, _ = segmentation_model.segment_wound(image) if auto_mask is not None: # Post-process the mask processed_mask = post_process_wound_mask(auto_mask, min_area=500) if processed_mask is not None and np.sum(processed_mask > 0) > 0: return processed_mask, "✅ Wound mask auto-generated using deep learning model!" else: return None, "✅ Image uploaded but no wound detected. Try uploading a different image." else: return None, "✅ Image uploaded but segmentation failed. Try uploading a different image." # Load shared image from classification tab def load_shared_image(shared_img): if shared_img is None: return gr.Image(), "❌ No image available from classification tab" # Convert PIL image to numpy array for depth estimation if hasattr(shared_img, 'convert'): # It's a PIL image, convert to numpy img_array = np.array(shared_img) return img_array, "✅ Image loaded from classification tab" else: # Already numpy array return shared_img, "✅ Image loaded from classification tab" # Auto-generate mask when image is uploaded to severity tab severity_input_image.change( fn=auto_generate_mask_on_image_upload, inputs=[severity_input_image], outputs=[wound_mask_input, gr.HTML()] ) if __name__ == '__main__': demo.queue().launch( server_name="0.0.0.0", server_port=7860, share=True )