Rakhi-2025's picture
Update app.py
70224c0 verified
import glob
import gradio as gr
import matplotlib
import numpy as np
from PIL import Image
import torch
import tempfile
from gradio_imageslider import ImageSlider
import plotly.graph_objects as go
import plotly.express as px
import open3d as o3d
from depth_anything_v2.dpt import DepthAnythingV2
import os
import tensorflow as tf
from tensorflow.keras.models import load_model
# Classification imports
from transformers import AutoImageProcessor, AutoModelForImageClassification
import google.generativeai as genai
import gdown
import spaces
import cv2
# Import actual segmentation model components
from models.deeplab import Deeplabv3, relu6, DepthwiseConv2D, BilinearUpsampling
from utils.learning.metrics import dice_coef, precision, recall
from utils.io.data import normalize
# --- Classification Model Setup ---
# Load classification model and processor (downloaded from the Hugging Face Hub at import time)
classification_processor = AutoImageProcessor.from_pretrained("Hemg/Wound-classification")
classification_model = AutoModelForImageClassification.from_pretrained("Hemg/Wound-classification")
# Configure Gemini AI.
# The key is read from GOOGLE_API_KEY and, as a fallback, from GEMINI_API_KEY --
# the name that the user-facing messages throughout this app refer to. Any
# failure leaves gemini_model = None, which the analysis functions check.
try:
    gemini_api_key = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
    if not gemini_api_key:
        raise ValueError("Neither GOOGLE_API_KEY nor GEMINI_API_KEY found in environment variables")
    genai.configure(api_key=gemini_api_key)
    gemini_model = genai.GenerativeModel("gemini-2.5-pro")
    print("βœ… Gemini AI configured successfully with API key from secrets")
except Exception as e:
    print(f"❌ Error configuring Gemini AI: {e}")
    print("Please make sure GOOGLE_API_KEY (or GEMINI_API_KEY) is set in your Hugging Face Space secrets")
    gemini_model = None
# --- Classification Functions ---
def analyze_wound_with_gemini(image, predicted_label):
    """
    Ask Gemini for an educational analysis of a wound image, given the classifier's label.

    Args:
        image: PIL Image or numpy array. Numpy input is converted to a PIL
            Image (the UI passes the raw upload through classify_and_analyze_wound).
        predicted_label: The predicted wound type from the classification model.

    Returns:
        str: Gemini's response text, or a human-readable message when no image
        is given, Gemini is not configured, or the request fails (failure
        messages start with "Error").
    """
    if image is None:
        return "No image provided for analysis."
    if gemini_model is None:
        return "Gemini AI is not available. Please check that GEMINI_API_KEY is properly configured in your Hugging Face Space secrets."
    try:
        # Numpy arrays have no `.mode`; without this conversion a valid
        # numpy upload would end up in the except branch below.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        # Ensure image is in RGB format
        if image.mode != 'RGB':
            image = image.convert('RGB')
        # Create prompt that includes the classification result
        prompt = f"""You are assisting in a medical education and research task.
Based on the wound classification model, this image has been identified as: {predicted_label}
Please provide an educational analysis of this wound image focusing on:
1. Visible characteristics of the wound (size, color, texture, edges, surrounding tissue)
2. Educational explanation about this type of wound based on the classification: {predicted_label}
3. General wound healing stages if applicable
4. Key features that are typically associated with this wound type
Important guidelines:
- This is for educational and research purposes only
- Do not provide medical advice or diagnosis
- Keep the analysis objective and educational
- Focus on visible features and general wound characteristics
- Do not recommend treatments or medical interventions
Please provide a comprehensive educational analysis."""
        response = gemini_model.generate_content([prompt, image])
        return response.text
    except Exception as e:
        return f"Error analyzing image with Gemini: {str(e)}"
def analyze_wound_depth_with_gemini(image, depth_map, depth_stats):
    """
    Ask Gemini for a wound assessment using the image, the depth map and depth statistics.

    Args:
        image: Original wound image (PIL Image or numpy array).
        depth_map: Depth map (numpy array) or an already-built PIL Image.
        depth_stats: Dictionary from compute_enhanced_depth_statistics; the
            prompt reads its area, depth, percentile, quality and consistency keys.

    Returns:
        str: Gemini's response text, or a human-readable message when inputs
        are missing, Gemini is not configured, or the request fails (failure
        messages start with "Error").
    """
    if image is None or depth_map is None:
        return "No image or depth map provided for analysis."
    if gemini_model is None:
        return "Gemini AI is not available. Please check that GEMINI_API_KEY is properly configured in your Hugging Face Space secrets."
    try:
        # Convert numpy array to PIL Image if needed
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        # Ensure image is in RGB format
        if image.mode != 'RGB':
            image = image.convert('RGB')
        # Convert depth map to an 8-bit PIL Image for Gemini
        if isinstance(depth_map, np.ndarray):
            depth_min = float(np.min(depth_map))
            depth_range = float(np.max(depth_map)) - depth_min
            if depth_range > 0:
                norm_depth = (depth_map - depth_min) / depth_range * 255.0
            else:
                # A flat (or NaN-containing) map has no usable range; avoid
                # 0/0 -> NaN, which would turn into arbitrary uint8 values.
                norm_depth = np.zeros_like(depth_map, dtype=np.float64)
            depth_image = Image.fromarray(norm_depth.astype(np.uint8))
        else:
            depth_image = depth_map
        # Create detailed prompt with depth statistics
        prompt = f"""You are a medical AI assistant specializing in wound assessment. Analyze this wound using both the original image and depth map data.
DEPTH ANALYSIS DATA PROVIDED:
- Total Wound Area: {depth_stats['total_area_cm2']:.2f} cmΒ²
- Mean Depth: {depth_stats['mean_depth_mm']:.1f} mm
- Maximum Depth: {depth_stats['max_depth_mm']:.1f} mm
- Depth Standard Deviation: {depth_stats['depth_std_mm']:.1f} mm
- Wound Volume: {depth_stats['wound_volume_cm3']:.2f} cmΒ³
- Deep Tissue Involvement: {depth_stats['deep_ratio']*100:.1f}%
- Analysis Quality: {depth_stats['analysis_quality']}
- Depth Consistency: {depth_stats['depth_consistency']}
TISSUE DEPTH DISTRIBUTION:
- Superficial Areas (0-2mm): {depth_stats['superficial_area_cm2']:.2f} cmΒ²
- Partial Thickness (2-4mm): {depth_stats['partial_thickness_area_cm2']:.2f} cmΒ²
- Full Thickness (4-6mm): {depth_stats['full_thickness_area_cm2']:.2f} cmΒ²
- Deep Areas (>6mm): {depth_stats['deep_area_cm2']:.2f} cmΒ²
STATISTICAL DEPTH ANALYSIS:
- 25th Percentile Depth: {depth_stats['depth_percentiles']['25']:.1f} mm
- Median Depth: {depth_stats['depth_percentiles']['50']:.1f} mm
- 75th Percentile Depth: {depth_stats['depth_percentiles']['75']:.1f} mm
Please provide a comprehensive medical assessment focusing on:
1. **WOUND CHARACTERISTICS ANALYSIS**
- Visible wound features from the original image
- Correlation between visual appearance and depth measurements
- Tissue quality assessment based on color, texture, and depth data
2. **DEPTH-BASED SEVERITY ASSESSMENT**
- Clinical significance of the measured depths
- Tissue layer involvement based on depth measurements
- Risk assessment based on deep tissue involvement percentage
3. **HEALING PROGNOSIS**
- Expected healing timeline based on depth and area measurements
- Factors that may affect healing based on depth distribution
- Complexity assessment based on wound volume and depth variation
4. **CLINICAL CONSIDERATIONS**
- Significance of depth consistency/inconsistency
- Areas of particular concern based on depth analysis
- Educational insights about this type of wound presentation
5. **MEASUREMENT INTERPRETATION**
- Clinical relevance of the statistical depth measurements
- What the depth distribution tells us about wound progression
- Comparison to typical wound depth classifications
IMPORTANT GUIDELINES:
- This is for educational and research purposes only
- Do not provide specific medical advice or treatment recommendations
- Focus on objective analysis of the provided measurements
- Correlate visual findings with quantitative depth data
- Maintain educational and clinical terminology
- Emphasize the relationship between depth measurements and clinical significance
Provide a detailed, structured medical assessment that integrates both visual and quantitative depth analysis."""
        # Send both images to Gemini for analysis
        response = gemini_model.generate_content([prompt, image, depth_image])
        return response.text
    except Exception as e:
        return f"Error analyzing wound with Gemini AI: {str(e)}"
def classify_wound(image):
    """
    Run the Hugging Face wound classifier on one image.

    Args:
        image: PIL Image or numpy array; converted to an RGB PIL Image first.

    Returns:
        dict mapping each class label (from the model config's id2label, or
        "Class <i>" when the config has none) to its softmax probability, or
        a str describing the problem (missing image / processing error).
    """
    if image is None:
        return "Please upload an image"
    pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    if pil_image.mode != 'RGB':
        pil_image = pil_image.convert('RGB')
    try:
        model_inputs = classification_processor(images=pil_image, return_tensors="pt")
        with torch.no_grad():
            logits = classification_model(**model_inputs).logits[0]
            probabilities = torch.nn.functional.softmax(logits, dim=-1)
        config = classification_model.config
        has_labels = hasattr(config, 'id2label')
        return {
            (config.id2label[idx] if has_labels else f"Class {idx}"): float(prob)
            for idx, prob in enumerate(probabilities.numpy())
        }
    except Exception as e:
        return f"Error processing image: {str(e)}"
def classify_and_analyze_wound(image):
    """
    Classify a wound image and, on success, ask Gemini about the top-scoring class.

    Args:
        image: PIL Image or numpy array.

    Returns:
        tuple (classification_results, gemini_analysis): the classify_wound
        output (dict or error str) and Gemini's text or a fallback message.
    """
    if image is None:
        return "Please upload an image", "Please upload an image for analysis"
    classification_results = classify_wound(image)
    # classify_wound signals failure with a str, so only a non-empty dict
    # has a label worth sending to Gemini.
    if not (isinstance(classification_results, dict) and classification_results):
        return classification_results, "Unable to analyze due to classification error"
    top_label = max(classification_results, key=classification_results.get)
    return classification_results, analyze_wound_with_gemini(image, top_label)
def format_gemini_analysis(analysis):
    """
    Wrap a Gemini wound analysis in styled HTML for display.

    Empty input and messages starting with "Error" (the prefix the Gemini
    helpers use for failures) are shown in a red error card. The message is
    HTML-escaped there because it may embed an exception string. Any other
    text is converted from markdown with parse_markdown_to_html.

    Args:
        analysis: str returned by analyze_wound_with_gemini (or None).

    Returns:
        str: HTML fragment.
    """
    import html

    # Only a leading "Error" marks a failure; a substring test would also
    # reject a legitimate analysis that merely mentions the word.
    if not analysis or str(analysis).startswith("Error"):
        safe_message = html.escape(str(analysis)) if analysis else ""
        return f"""
<div style="
background-color: #fee2e2;
border-radius: 12px;
padding: 16px;
box-shadow: 0 4px 12px rgba(0,0,0,0.1);
font-family: Arial, sans-serif;
min-height: 300px;
border-left: 4px solid #ef4444;
">
<h4 style="color: #dc2626; margin-top: 0;">Analysis Error</h4>
<p style="color: #991b1b;">{safe_message}</p>
</div>
"""
    # Parse the markdown-style response and convert to HTML
    formatted_analysis = parse_markdown_to_html(analysis)
    return f"""
<div style="
border-radius: 12px;
padding: 25px;
box-shadow: 0 4px 12px rgba(0,0,0,0.1);
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
min-height: 300px;
border-left: 4px solid #d97706;
max-height: 600px;
overflow-y: auto;
">
<h3 style="color: #d97706; margin-top: 0; margin-bottom: 20px; display: flex; align-items: center; gap: 8px;">
Initial Wound Analysis
</h3>
<div style="color: white; line-height: 1.7;">
{formatted_analysis}
</div>
</div>
"""
def format_gemini_depth_analysis(analysis):
    """
    Wrap a Gemini depth-based assessment in styled HTML for the severity report.

    Empty input and messages starting with "Error" (the prefix
    analyze_wound_depth_with_gemini uses for failures) get an error header.
    The message is HTML-escaped there because it may embed an exception
    string. Any other text is converted from markdown with parse_markdown_to_html.

    Args:
        analysis: str returned by analyze_wound_depth_with_gemini (or None).

    Returns:
        str: HTML fragment.
    """
    import html

    # Only a leading "Error" marks a failure; a substring test would also
    # reject a legitimate assessment that merely mentions the word.
    if not analysis or str(analysis).startswith("Error"):
        safe_message = html.escape(str(analysis)) if analysis else ""
        return f"""
<div style="color: #ffffff; line-height: 1.6;">
<div style="font-size: 16px; font-weight: bold; margin-bottom: 10px; color: #f44336;">
❌ AI Analysis Error
</div>
<div style="color: #cccccc;">
{safe_message}
</div>
</div>
"""
    # Parse the markdown-style response and convert to HTML
    formatted_analysis = parse_markdown_to_html(analysis)
    return f"""
<div style="color: #ffffff; line-height: 1.6;">
<div style="font-size: 16px; font-weight: bold; margin-bottom: 15px; color: #4CAF50;">
πŸ€– AI-Powered Medical Assessment
</div>
<div style="color: #cccccc; max-height: 400px; overflow-y: auto; padding-right: 10px;">
{formatted_analysis}
</div>
</div>
"""
def parse_markdown_to_html(text):
    """
    Convert the markdown subset Gemini typically emits into styled HTML.

    Handles ###/#### headers (optionally wrapped in **), **bold**, *italic*,
    "* " bullets (indented bullets get an extra left margin), "1." numbered
    lines and blank-line separated paragraphs.

    The input is HTML-escaped first, so literal "<", ">" and "&" in the model
    output (e.g. "<2mm") render as text instead of being parsed as markup.

    Args:
        text: markdown-ish str.

    Returns:
        str: HTML fragments joined by newlines.
    """
    import html
    import re

    text = html.escape(text, quote=False)
    # Replace markdown headers
    text = re.sub(r'^### \*\*(.*?)\*\*$', r'<h4 style="color: #d97706; margin: 20px 0 10px 0; font-weight: bold;">\1</h4>', text, flags=re.MULTILINE)
    text = re.sub(r'^#### \*\*(.*?)\*\*$', r'<h5 style="color: #f59e0b; margin: 15px 0 8px 0; font-weight: bold;">\1</h5>', text, flags=re.MULTILINE)
    text = re.sub(r'^### (.*?)$', r'<h4 style="color: #d97706; margin: 20px 0 10px 0; font-weight: bold;">\1</h4>', text, flags=re.MULTILINE)
    text = re.sub(r'^#### (.*?)$', r'<h5 style="color: #f59e0b; margin: 15px 0 8px 0; font-weight: bold;">\1</h5>', text, flags=re.MULTILINE)
    # Bold must run before italic so "**x**" is not consumed by the single-* rule.
    text = re.sub(r'\*\*(.*?)\*\*', r'<strong style="color: #fbbf24;">\1</strong>', text)
    # Bullets must run before italic: otherwise a bullet's leading "* " is
    # paired with a later "*" on the same line and turned into <em>.
    text = re.sub(r'^\* (.*?)$', r'<li style="margin: 5px 0; color: white;">\1</li>', text, flags=re.MULTILINE)
    text = re.sub(r'^[ \t]+\* (.*?)$', r'<li style="margin: 3px 0; margin-left: 20px; color: white;">\1</li>', text, flags=re.MULTILINE)
    # Replace italic text
    text = re.sub(r'\*(.*?)\*', r'<em style="color: #fde68a;">\1</em>', text)
    # Wrap consecutive list items in ul tags
    text = re.sub(r'(<li.*?</li>(?:\s*<li.*?</li>)*)', r'<ul style="margin: 10px 0; padding-left: 20px;">\1</ul>', text, flags=re.DOTALL)
    # Replace numbered lists
    text = re.sub(r'^(\d+)\.\s+(.*?)$', r'<div style="margin: 8px 0; color: white;"><strong style="color: #d97706;">\1.</strong> \2</div>', text, flags=re.MULTILINE)
    # Convert paragraphs (double newlines)
    formatted_paragraphs = []
    for para in text.split('\n\n'):
        para = para.strip()
        if not para:
            continue
        # Leave chunks that already start/end with generated markup untouched.
        if not (para.startswith('<') or para.endswith('>')):
            para = f'<p style="margin: 12px 0; color: white; text-align: justify;">{para}</p>'
        formatted_paragraphs.append(para)
    return '\n'.join(formatted_paragraphs)
def combined_analysis(image):
    """
    UI entry point: classify the wound and render Gemini's analysis as HTML.

    Returns:
        tuple (classification, analysis_html) for the two output components.
    """
    classification, raw_analysis = classify_and_analyze_wound(image)
    return classification, format_gemini_analysis(raw_analysis)
# --- Depth Anything V2 (ViT-L) checkpoint: local path and Google Drive source ---
checkpoint_dir = "checkpoints"
gdrive_url = "https://drive.google.com/uc?id=141Mhq2jonkUBcVBnNqNSeyIZYtH5l4K5"
model_file = os.path.join(checkpoint_dir, "depth_anything_v2_vitl.pth")
os.makedirs(checkpoint_dir, exist_ok=True)
# Fetch only on first run; later runs reuse the file already on disk
if not os.path.exists(model_file):
    print("Downloading model from Google Drive...")
    gdown.download(gdrive_url, model_file, quiet=False)
# --- TensorFlow: report whether any GPU is visible ---
gpus = tf.config.list_physical_devices('GPU')
print("TensorFlow is using GPU" if gpus else "TensorFlow is using CPU")
# --- Load Actual Wound Segmentation Model ---
class WoundSegmentationModel:
    """
    Wrapper around the Keras wound-segmentation model.

    On construction it tries each checkpoint in ``weight_file_candidates``
    (in order) under ``weights_dir`` and keeps the first that loads. If none
    load, ``self.model`` stays ``None`` and ``segment_wound`` reports an error
    instead of raising.
    """

    # Checkpoints tried in order; the first that loads is used.
    weight_file_candidates = (
        '2025-08-07_16-25-27.hdf5',
        '2019-12-19 01%3A53%3A15.480800.hdf5',
    )
    weights_dir = './training_history'

    def __init__(self):
        self.input_dim_x = 224
        self.input_dim_y = 224
        self.model = None
        self.load_model()

    def load_model(self):
        """Load the first checkpoint in ``weight_file_candidates`` that succeeds."""
        # Custom layers/metrics the saved model references by name.
        custom_objects = {
            'recall': recall,
            'precision': precision,
            'dice_coef': dice_coef,
            'relu6': relu6,
            'DepthwiseConv2D': DepthwiseConv2D,
            'BilinearUpsampling': BilinearUpsampling
        }
        for attempt, weight_file_name in enumerate(self.weight_file_candidates):
            model_path = f'{self.weights_dir}/{weight_file_name}'
            try:
                self.model = load_model(model_path, custom_objects=custom_objects)
            except Exception as e:
                label = "segmentation model" if attempt == 0 else "fallback segmentation model"
                print(f"Error loading {label}: {e}")
                continue
            print(f"Segmentation model loaded successfully from {model_path}")
            return
        self.model = None

    def preprocess_image(self, image):
        """
        Turn an image array into a model batch.

        The image is converted to 3 channels, resized to (input_dim_y, input_dim_x),
        scaled to [0, 1] as float32 and given a leading batch axis.
        Grayscale (2-D) and 4-channel inputs are reduced to 3 channels with the same
        channel order as 3-channel input. Returns None when image is None.
        """
        if image is None:
            return None
        # NOTE(review): 3-channel input is channel-swapped (RGB <-> BGR),
        # presumably because the model was trained on cv2.imread (BGR) images --
        # confirm. The 2-D and 4-channel branches produce the same BGR order.
        if image.ndim == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        elif image.ndim == 3 and image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
        elif image.ndim == 3 and image.shape[2] == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # cv2.resize takes (width, height)
        image = cv2.resize(image, (self.input_dim_x, self.input_dim_y))
        image = image.astype(np.float32) / 255.0
        return np.expand_dims(image, axis=0)

    def postprocess_prediction(self, prediction):
        """Drop the batch axis and threshold at 0.5 into a 0/255 uint8 mask."""
        prediction = prediction[0]
        threshold = 0.5
        return (prediction > threshold).astype(np.uint8) * 255

    def segment_wound(self, input_image):
        """
        Segment the wound in an uploaded image array.

        Returns:
            tuple (mask, message): the 0/255 mask at the model's input
            resolution (not resized back to the input image), or None plus an
            error message.
        """
        if self.model is None:
            return None, "Error: Segmentation model not loaded. Please check the model files."
        if input_image is None:
            return None, "Please upload an image."
        try:
            processed_image = self.preprocess_image(input_image)
            if processed_image is None:
                return None, "Error processing image."
            prediction = self.model.predict(processed_image, verbose=0)
            segmented_mask = self.postprocess_prediction(prediction)
            return segmented_mask, "Segmentation completed successfully!"
        except Exception as e:
            return None, f"Error during segmentation: {str(e)}"
# Initialize the segmentation model
segmentation_model = WoundSegmentationModel()
# --- PyTorch: Set Device and Load Depth Model ---
# torch.cuda.is_available() already implies at least one visible CUDA device.
map_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using PyTorch device: {map_device}")
model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
}
encoder = 'vitl'
depth_model = DepthAnythingV2(**model_configs[encoder])
# The checkpoint is downloaded from Google Drive at startup. It is a plain
# state dict, so weights_only=True loads it while refusing to unpickle
# arbitrary objects from a tampered file (requires torch >= 1.13).
state_dict = torch.load(
    os.path.join(checkpoint_dir, f'depth_anything_v2_{encoder}.pth'),
    map_location=map_device,
    weights_only=True,
)
depth_model.load_state_dict(state_dict)
depth_model = depth_model.to(map_device).eval()
# --- Custom CSS for unified dark theme ---
# Stylesheet string: dark container/button/tab colours, HTML-output wrapping
# overrides, image display sizing (ids #img-display-*, #download), card
# layout classes (.wound-card*) and a .loading-spinner keyframe animation.
css = """
.gradio-container {
font-family: 'Segoe UI', sans-serif;
background-color: #121212;
color: #ffffff;
padding: 20px;
}
.gr-button {
background-color: #2c3e50;
color: white;
border-radius: 10px;
}
.gr-button:hover {
background-color: #34495e;
}
.gr-html, .gr-html div {
white-space: normal !important;
overflow: visible !important;
text-overflow: unset !important;
word-break: break-word !important;
}
#img-display-container {
max-height: 100vh;
}
#img-display-input {
max-height: 80vh;
}
#img-display-output {
max-height: 80vh;
}
#download {
height: 62px;
}
h1 {
text-align: center;
font-size: 3rem;
font-weight: bold;
margin: 2rem 0;
color: #ffffff;
}
h2 {
color: #ffffff;
text-align: center;
margin: 1rem 0;
}
.gr-tabs {
background-color: #1e1e1e;
border-radius: 10px;
padding: 10px;
}
.gr-tab-nav {
background-color: #2c3e50;
border-radius: 8px;
}
.gr-tab-nav button {
color: #ffffff !important;
}
.gr-tab-nav button.selected {
background-color: #34495e !important;
}
/* Card styling for consistent heights */
.wound-card {
min-height: 200px !important;
display: flex !important;
flex-direction: column !important;
justify-content: space-between !important;
}
.wound-card-content {
flex-grow: 1 !important;
display: flex !important;
flex-direction: column !important;
justify-content: center !important;
}
/* Loading animation */
.loading-spinner {
display: inline-block;
width: 20px;
height: 20px;
border: 3px solid #f3f3f3;
border-top: 3px solid #3498db;
border-radius: 50%;
animation: spin 1s linear infinite;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
"""
# --- Enhanced Wound Severity Estimation Functions ---
def compute_enhanced_depth_statistics(depth_map, mask, pixel_spacing_mm=0.5, depth_calibration_mm=15.0):
    """
    Compute area, depth, volume and quality statistics for the wound region.

    Depth bands used for the per-band areas:
    - Superficial: 0-2mm (epidermis only)
    - Partial thickness: 2-4mm (epidermis + partial dermis)
    - Full thickness: 4-6mm (epidermis + full dermis)
    - Deep: >6mm (involving subcutaneous tissue)

    The depth map is first shifted so the minimum depth inside the wound is 0
    (normalize_depth_relative_to_nearest_point). It is then min-max scaled to
    [0, depth_calibration_mm] over the whole map (calibrate_depth_map).

    Args:
        depth_map: 2-D depth array aligned with ``mask``. NOTE(review): the
            helpers treat the *minimum* value as closest to the camera; confirm
            this matches the convention of the depth source being passed in.
        mask: wound mask; pixels > 127 are wound.
        pixel_spacing_mm: physical size of one pixel edge in mm.
        depth_calibration_mm: depth in mm assigned to the largest normalized value.

    Returns:
        dict of statistics. The same keys are present when the wound region
        is empty, so callers can index any field.
    """
    pixel_spacing_mm = float(pixel_spacing_mm)
    # Area of one pixel in cm^2 (mm -> cm, squared)
    pixel_area_cm2 = (pixel_spacing_mm / 10.0) ** 2
    # Binary 0/1 wound mask, with small gaps closed morphologically
    wound_mask = (mask > 127).astype(np.uint8)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    wound_mask = cv2.morphologyEx(wound_mask, cv2.MORPH_CLOSE, kernel)
    wound_region = wound_mask > 0
    wound_pixel_count = int(np.count_nonzero(wound_region))
    if wound_pixel_count == 0:
        return {
            'total_area_cm2': 0,
            'superficial_area_cm2': 0,
            'partial_thickness_area_cm2': 0,
            'full_thickness_area_cm2': 0,
            'deep_area_cm2': 0,
            'mean_depth_mm': 0,
            'max_depth_mm': 0,
            'depth_std_mm': 0,
            'deep_ratio': 0,
            'wound_volume_cm3': 0,
            'depth_percentiles': {'25': 0, '50': 0, '75': 0},
            'depth_distribution': {'shallow_ratio': 0, 'moderate_ratio': 0, 'deep_ratio': 0},
            'analysis_quality': "Low",
            'depth_consistency': "N/A",
            'wound_pixel_count': 0,
            'nearest_point_coords': None,
            'max_relative_depth': 0,
            'normalized_depth_map': depth_map,
        }
    # normalize_depth_relative_to_nearest_point thresholds its mask at > 127,
    # so the cleaned 0/1 mask must be scaled to 0/255. Passed as-is, it would be
    # read as empty and the normalization silently skipped.
    normalized_depth_map, nearest_point_coords, max_relative_depth = normalize_depth_relative_to_nearest_point(
        depth_map, wound_mask * 255
    )
    calibrated_depth_map = calibrate_depth_map(normalized_depth_map, reference_depth_mm=depth_calibration_mm)
    wound_depths_mm = calibrated_depth_map[wound_region]
    # Medical depth classification
    superficial_mask = wound_depths_mm < 2.0
    partial_thickness_mask = (wound_depths_mm >= 2.0) & (wound_depths_mm < 4.0)
    full_thickness_mask = (wound_depths_mm >= 4.0) & (wound_depths_mm < 6.0)
    deep_mask = wound_depths_mm >= 6.0
    # Areas
    total_area_cm2 = wound_pixel_count * pixel_area_cm2
    superficial_area_cm2 = np.sum(superficial_mask) * pixel_area_cm2
    partial_thickness_area_cm2 = np.sum(partial_thickness_mask) * pixel_area_cm2
    full_thickness_area_cm2 = np.sum(full_thickness_mask) * pixel_area_cm2
    deep_area_cm2 = np.sum(deep_mask) * pixel_area_cm2
    # Depth statistics
    mean_depth_mm = np.mean(wound_depths_mm)
    max_depth_mm = np.max(wound_depths_mm)
    depth_std_mm = np.std(wound_depths_mm)
    depth_percentiles = {
        '25': np.percentile(wound_depths_mm, 25),
        '50': np.percentile(wound_depths_mm, 50),
        '75': np.percentile(wound_depths_mm, 75)
    }
    # NOTE(review): this distribution uses 5 mm as its "deep" cut-off while the
    # bands above use 6 mm; kept as-is because the report displays these
    # ratios -- confirm which threshold is intended.
    depth_distribution = {
        'shallow_ratio': np.sum(wound_depths_mm < 2.0) / wound_pixel_count,
        'moderate_ratio': np.sum((wound_depths_mm >= 2.0) & (wound_depths_mm < 5.0)) / wound_pixel_count,
        'deep_ratio': np.sum(wound_depths_mm >= 5.0) / wound_pixel_count
    }
    # Approximate volume = area * mean depth (mm -> cm)
    wound_volume_cm3 = total_area_cm2 * (mean_depth_mm / 10.0)
    deep_ratio = deep_area_cm2 / total_area_cm2 if total_area_cm2 > 0 else 0
    analysis_quality = "High" if wound_pixel_count > 1000 else "Medium" if wound_pixel_count > 500 else "Low"
    # Lower standard deviation = more consistent depth
    depth_consistency = "High" if depth_std_mm < 2.0 else "Medium" if depth_std_mm < 4.0 else "Low"
    return {
        'total_area_cm2': total_area_cm2,
        'superficial_area_cm2': superficial_area_cm2,
        'partial_thickness_area_cm2': partial_thickness_area_cm2,
        'full_thickness_area_cm2': full_thickness_area_cm2,
        'deep_area_cm2': deep_area_cm2,
        'mean_depth_mm': mean_depth_mm,
        'max_depth_mm': max_depth_mm,
        'depth_std_mm': depth_std_mm,
        'deep_ratio': deep_ratio,
        'wound_volume_cm3': wound_volume_cm3,
        'depth_percentiles': depth_percentiles,
        'depth_distribution': depth_distribution,
        'analysis_quality': analysis_quality,
        'depth_consistency': depth_consistency,
        'wound_pixel_count': wound_pixel_count,
        'nearest_point_coords': nearest_point_coords,
        'max_relative_depth': max_relative_depth,
        'normalized_depth_map': normalized_depth_map
    }
def classify_wound_severity_by_enhanced_metrics(depth_stats):
    """
    Map depth statistics to a severity label using an additive point score.

    Five metrics each earn points from a descending threshold table: max
    depth, mean depth, deep-tissue ratio, total area and volume. The summed
    score selects the label.

    Args:
        depth_stats: dict produced by compute_enhanced_depth_statistics.

    Returns:
        str: "Unknown" (zero wound area), "Superficial", "Mild", "Moderate",
        "Severe" or "Very Severe".
    """
    if depth_stats['total_area_cm2'] == 0:
        return "Unknown"

    def points(value, tiers):
        # tiers: (threshold, points) pairs ordered from highest threshold down
        for threshold, awarded in tiers:
            if value >= threshold:
                return awarded
        return 0

    # The deep / full-thickness areas are looked up but do not feed the score.
    (total_area, _deep_area, _full_thickness_area, mean_depth, max_depth,
     wound_volume, deep_ratio) = (
        depth_stats[key] for key in (
            'total_area_cm2', 'deep_area_cm2', 'full_thickness_area_cm2',
            'mean_depth_mm', 'max_depth_mm', 'wound_volume_cm3', 'deep_ratio',
        )
    )
    score = (
        points(max_depth, ((10.0, 3), (6.0, 2), (4.0, 1)))
        + points(mean_depth, ((5.0, 2), (3.0, 1)))
        + points(deep_ratio, ((0.5, 3), (0.25, 2), (0.1, 1)))
        + points(total_area, ((10.0, 2), (5.0, 1)))
        + points(wound_volume, ((5.0, 2), (2.0, 1)))
    )
    for minimum_score, label in ((8, "Very Severe"), (6, "Severe"), (4, "Moderate"), (2, "Mild")):
        if score >= minimum_score:
            return label
    return "Superficial"
def analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing_mm=0.5, depth_calibration_mm=15.0):
    """
    Build the HTML severity report for a wound from its image, depth map and mask.

    Args:
        image: original wound image (PIL Image or numpy array); forwarded to Gemini.
        depth_map: numpy depth array. A multi-channel map is averaged over its
            first three channels (assumes the channels replicate one grayscale
            depth image -- confirm for the caller's input).
        wound_mask: numpy mask. Multi-channel masks are averaged to one
            channel, and the mask is resized to the depth map when sizes differ.
        pixel_spacing_mm: physical pixel size, passed to compute_enhanced_depth_statistics.
        depth_calibration_mm: depth (mm) assigned to the largest normalized depth.

    Returns:
        str: HTML report, or an error message if any input is missing.
    """
    if image is None or depth_map is None or wound_mask is None:
        return "❌ Please upload image, depth map, and wound mask."
    # A per-channel depth map would make every per-pixel count 3x too large.
    if depth_map.ndim == 3:
        depth_map = depth_map[..., :3].mean(axis=2)
    # Convert wound mask to grayscale if needed
    if wound_mask.ndim == 3:
        wound_mask = np.mean(wound_mask, axis=2)
    # Ensure depth map and mask have same dimensions
    if depth_map.shape[:2] != wound_mask.shape[:2]:
        mask_pil = Image.fromarray(wound_mask.astype(np.uint8))
        mask_pil = mask_pil.resize((depth_map.shape[1], depth_map.shape[0]))
        wound_mask = np.array(mask_pil)
    # Compute enhanced statistics with relative depth normalization
    stats = compute_enhanced_depth_statistics(depth_map, wound_mask, pixel_spacing_mm, depth_calibration_mm)
    # Get severity based on enhanced metrics
    severity_level = classify_wound_severity_by_enhanced_metrics(stats)
    # NOTE(review): severity_description is not currently included in the report.
    severity_description = get_enhanced_severity_description(severity_level)
    # Get Gemini AI analysis based on depth data and format it for display
    gemini_analysis = analyze_wound_depth_with_gemini(image, depth_map, stats)
    formatted_gemini_analysis = format_gemini_depth_analysis(gemini_analysis)
    # Enhanced severity color coding
    severity_color = {
        "Superficial": "#4CAF50",  # Green
        "Mild": "#8BC34A",  # Light Green
        "Moderate": "#FF9800",  # Orange
        "Severe": "#F44336",  # Red
        "Very Severe": "#9C27B0"  # Purple
    }.get(severity_level, "#9E9E9E")  # Gray for unknown
    # Create comprehensive medical report
    report = f"""
<div style='padding: 20px; background-color: #1e1e1e; border-radius: 12px; box-shadow: 0 0 10px rgba(0,0,0,0.5);'>
<div style='font-size: 24px; font-weight: bold; color: {severity_color}; margin-bottom: 15px;'>
🩹 Enhanced Wound Severity Analysis
</div>
<div style='background-color: #2c2c2c; padding: 15px; border-radius: 8px; margin-bottom: 20px;'>
<div style='font-size: 18px; font-weight: bold; color: #ffffff; margin-bottom: 15px; text-align: center;'>
πŸ“Š Depth & Quality Analysis
</div>
<div style='color: #cccccc; line-height: 1.6; display: grid; grid-template-columns: 1fr 1fr 1fr; gap: 20px;'>
<div>
<div style='font-size: 16px; font-weight: bold; color: #ff9800; margin-bottom: 8px;'>οΏ½ Basic Measurements</div>
<div>οΏ½πŸ“ <b>Mean Relative Depth:</b> {stats['mean_depth_mm']:.1f} mm</div>
<div>πŸ“ <b>Max Relative Depth:</b> {stats['max_depth_mm']:.1f} mm</div>
<div>πŸ“Š <b>Depth Std Dev:</b> {stats['depth_std_mm']:.1f} mm</div>
<div>πŸ“¦ <b>Wound Volume:</b> {stats['wound_volume_cm3']:.2f} cmΒ³</div>
<div>πŸ”₯ <b>Deep Tissue Ratio:</b> {stats['deep_ratio']*100:.1f}%</div>
</div>
<div>
<div style='font-size: 16px; font-weight: bold; color: #4CAF50; margin-bottom: 8px;'>πŸ“ˆ Statistical Analysis</div>
<div>οΏ½ <b>25th Percentile:</b> {stats['depth_percentiles']['25']:.1f} mm</div>
<div>πŸ“Š <b>Median (50th):</b> {stats['depth_percentiles']['50']:.1f} mm</div>
<div>πŸ“Š <b>75th Percentile:</b> {stats['depth_percentiles']['75']:.1f} mm</div>
<div>πŸ“Š <b>Shallow Areas:</b> {stats['depth_distribution']['shallow_ratio']*100:.1f}%</div>
<div>πŸ“Š <b>Moderate Areas:</b> {stats['depth_distribution']['moderate_ratio']*100:.1f}%</div>
</div>
<div>
<div style='font-size: 16px; font-weight: bold; color: #2196F3; margin-bottom: 8px;'>πŸ” Quality Metrics</div>
<div>πŸ” <b>Analysis Quality:</b> {stats['analysis_quality']}</div>
<div>πŸ“ <b>Depth Consistency:</b> {stats['depth_consistency']}</div>
<div>πŸ“Š <b>Data Points:</b> {stats['wound_pixel_count']:,}</div>
<div>πŸ“Š <b>Deep Areas:</b> {stats['depth_distribution']['deep_ratio']*100:.1f}%</div>
<div>🎯 <b>Reference Point:</b> Nearest to camera</div>
</div>
</div>
</div>
<div style='background-color: #2c2c2c; padding: 15px; border-radius: 8px; margin-bottom: 20px; border-left: 4px solid {severity_color};'>
<div style='font-size: 18px; font-weight: bold; color: {severity_color}; margin-bottom: 10px;'>
πŸ“Š Medical Assessment Based on Depth Analysis
</div>
{formatted_gemini_analysis}
</div>
</div>
"""
    return report
def normalize_depth_relative_to_nearest_point(depth_map, wound_mask):
    """
    Re-express a depth map relative to the smallest depth value inside the wound.

    Pixels of ``wound_mask`` above 127 form the wound. The smallest depth value
    among them becomes 0, every pixel of the map is shifted by the same amount,
    and negative results (pixels outside the wound lying below that minimum)
    are clipped to 0. Integer maps are promoted to floating point first, so the
    subtraction cannot wrap around.

    NOTE(review): calling the minimum value "nearest to the camera" assumes
    larger values mean farther away -- confirm against the depth source.

    Args:
        depth_map: depth array indexed with the same (row, col) as wound_mask.
        wound_mask: mask array; values > 127 mark the wound.

    Returns:
        normalized_depth: shifted map (0 = reference point, positive = deeper).
        nearest_point_coords: (row, col) of the first minimum in row-major order.
        max_relative_depth: largest shifted value inside the wound.
        If either input is None or the wound is empty, returns (depth_map, None, 0).
    """
    if depth_map is None or wound_mask is None:
        return depth_map, None, 0
    wound_coords = np.where((wound_mask > 127).astype(np.uint8) > 0)
    if len(wound_coords[0]) == 0:
        return depth_map, None, 0
    wound_depths = depth_map[wound_coords]
    nearest_depth = np.min(wound_depths)
    nearest_indices = np.where(wound_depths == nearest_depth)
    nearest_point_coords = (wound_coords[0][nearest_indices[0][0]],
                            wound_coords[1][nearest_indices[0][0]])
    # Promote integer maps (e.g. uint8 images) to float: depth - nearest
    # would otherwise wrap around for pixels below the wound minimum.
    work_dtype = np.result_type(depth_map.dtype, np.float32)
    normalized_depth = depth_map.astype(work_dtype) - nearest_depth
    normalized_depth = np.maximum(normalized_depth, 0)
    max_relative_depth = np.max(normalized_depth[wound_coords])
    return normalized_depth, nearest_point_coords, max_relative_depth
def calibrate_depth_map(depth_map, reference_depth_mm=10.0):
    """
    Linearly rescale a depth map so its range spans [0, reference_depth_mm].

    The map's minimum maps to 0 and its maximum to ``reference_depth_mm``.
    A ``None`` map or a constant map (zero range) is returned unchanged.

    Args:
        depth_map: numpy depth array (typically already normalized).
        reference_depth_mm: depth in mm assigned to the map's maximum value.

    Returns:
        The rescaled array, or the input itself in the unchanged cases.
    """
    if depth_map is None:
        return depth_map
    highest = np.max(depth_map)
    lowest = np.min(depth_map)
    if highest == lowest:
        return depth_map
    span = highest - lowest
    return (depth_map - lowest) / span * reference_depth_mm
def create_depth_analysis_visualization(depth_map, wound_mask, nearest_point_coords, max_relative_depth):
    """
    Render a colour-mapped depth image annotated with wound landmarks.

    Drawn on the RGB image:
      * a red circle and "REF" label at ``nearest_point_coords`` (if given),
      * a blue circle and "DEEP" label at the wound pixel with the largest
        depth value,
      * a green outline of the wound region.

    Args:
        depth_map: 2-D array of depth values.
        wound_mask: 2-D mask; pixels with value > 127 are treated as wound.
        nearest_point_coords: (row, col) of the reference point, or None.
        max_relative_depth: Not used by this function; kept so existing
            callers keep working.

    Returns:
        HxWx3 uint8 RGB image, or None if ``depth_map`` or ``wound_mask`` is None.
    """
    if depth_map is None or wound_mask is None:
        return None

    vis_depth = np.asarray(depth_map, dtype=float)

    # Min-max normalise for the colormap. A flat map would divide 0/0 (NaN),
    # so it is rendered with a single colour instead.
    depth_min = np.min(vis_depth)
    depth_range = np.max(vis_depth) - depth_min
    if depth_range > 0:
        normalized_depth = (vis_depth - depth_min) / depth_range
    else:
        normalized_depth = np.zeros_like(vis_depth)
    # The colormap returns RGBA in [0, 1]; keep RGB and scale to uint8.
    colored_depth = (matplotlib.colormaps.get_cmap('Spectral_r')(normalized_depth)[:, :, :3] * 255).astype(np.uint8)

    # Reference (nearest) point: red circle. Coordinates are (row, col); cv2 wants (x, y).
    if nearest_point_coords is not None:
        y, x = nearest_point_coords
        cv2.circle(colored_depth, (x, y), 10, (255, 0, 0), 2)
        cv2.putText(colored_depth, "REF", (x + 15, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)

    binary_mask = (wound_mask > 127).astype(np.uint8)
    wound_coords = np.where(binary_mask > 0)
    if len(wound_coords[0]) > 0:
        # Deepest point = largest depth value inside the wound: blue circle.
        max_depth_idx = np.argmax(vis_depth[wound_coords])
        y, x = wound_coords[0][max_depth_idx], wound_coords[1][max_depth_idx]
        cv2.circle(colored_depth, (x, y), 12, (0, 0, 255), 3)
        cv2.putText(colored_depth, "DEEP", (x + 15, y + 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)

        # Green outline of the wound boundary.
        contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(colored_depth, contours, -1, (0, 255, 0), 2)

    return colored_depth
def get_enhanced_severity_description(severity):
    """
    Look up the clinical description text for a severity label.

    Args:
        severity: One of "Superficial", "Mild", "Moderate", "Severe",
            "Very Severe" or "Unknown".

    Returns:
        The matching description, or a generic "unavailable" message for
        any other label.
    """
    fallback = "Severity assessment unavailable."
    catalogue = dict(
        [
            ("Superficial", "Epidermis-only damage. Minimal tissue loss, typically heals within 1-2 weeks with basic wound care."),
            ("Mild", "Superficial to partial thickness wound. Limited tissue involvement, good healing potential with proper care."),
            ("Moderate", "Partial to full thickness involvement. Requires careful monitoring and may need advanced wound care techniques."),
            ("Severe", "Full thickness with deep tissue involvement. High risk of complications, requires immediate medical attention."),
            ("Very Severe", "Extensive deep tissue damage. Critical condition requiring immediate surgical intervention and specialized care."),
            ("Unknown", "Unable to determine severity due to insufficient data or poor image quality."),
        ]
    )
    return catalogue.get(severity, fallback)
def create_sample_wound_mask(image_shape, center=None, radius=50):
    """
    Build a filled circular test mask.

    Args:
        image_shape: Shape tuple; only the first two entries (height, width) are used.
        center: (x, y) centre of the circle; defaults to the image centre.
        radius: Circle radius in pixels (pixels at exactly this distance are included).

    Returns:
        uint8 array of shape (height, width) with 255 inside the circle, 0 elsewhere.
    """
    height, width = image_shape[0], image_shape[1]
    if center is None:
        cx, cy = width // 2, height // 2
    else:
        cx, cy = center[0], center[1]
    rows, cols = np.ogrid[:height, :width]
    inside = np.sqrt((cols - cx) ** 2 + (rows - cy) ** 2) <= radius
    return np.where(inside, 255, 0).astype(np.uint8)
def create_realistic_wound_mask(image_shape, method='elliptical'):
    """
    Build a synthetic, irregular-looking wound mask for testing.

    Args:
        image_shape: Shape tuple; the first two entries are (height, width).
        method: 'elliptical' is an ellipse centred in the image, united with
            random speckle (about 20% of all pixels, drawn from NumPy's global
            RNG). 'irregular' is a centred circle plus three smaller lobes at
            120-degree intervals. Any other value gives an empty mask.

    Returns:
        uint8 mask (0/255) after a 5x5 elliptical morphological close.
    """
    height, width = image_shape[:2]
    cx, cy = width // 2, height // 2
    rows, cols = np.ogrid[:height, :width]

    if method == 'elliptical':
        semi_x = min(width, height) // 3
        semi_y = min(width, height) // 4
        region = ((cols - cx) ** 2 / semi_x ** 2 + (rows - cy) ** 2 / semi_y ** 2) <= 1
        # Random speckle across the whole frame makes the outline less regular.
        region = region | (np.random.random((height, width)) > 0.8)
    elif method == 'irregular':
        radius = min(width, height) // 4
        region = np.sqrt((cols - cx) ** 2 + (rows - cy) ** 2) <= radius
        lobes = np.zeros_like(region)
        for k in range(3):
            theta = k * 2 * np.pi / 3
            lobe_x = int(cx + radius * 0.8 * np.cos(theta))
            lobe_y = int(cy + radius * 0.8 * np.sin(theta))
            lobes |= np.sqrt((cols - lobe_x) ** 2 + (rows - lobe_y) ** 2) <= radius // 3
        region = region | lobes
    else:
        region = np.zeros((height, width), dtype=bool)

    mask = region.astype(np.uint8) * 255
    # Smooth the outline with a small elliptical closing.
    smoothing = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    return cv2.morphologyEx(mask, cv2.MORPH_CLOSE, smoothing)
# --- Depth Estimation Functions ---
def predict_depth(image):
    """Run the module-level ``depth_model`` on ``image`` and return its ``infer_image`` output unchanged."""
    model = depth_model
    prediction = model.infer_image(image)
    return prediction
def calculate_max_points(image):
    """
    Compute the upper bound for the "number of 3D points" slider.

    The bound is three times the image's pixel count, clamped to the range
    [1000, 300000]. When no image is given, 10000 is returned.
    """
    if image is None:
        return 10000  # Default value
    height, width = image.shape[0], image.shape[1]
    return min(max(height * width * 3, 1000), 300000)
def update_slider_on_image_upload(image):
    """
    Build a replacement "Number of 3D points" slider for a newly uploaded image.

    The maximum comes from ``calculate_max_points(image)``. The default is 10%
    of that maximum, capped at 10,000 and never below the slider's minimum.

    Returns:
        gr.Slider configured for the image.
    """
    slider_min = 1000
    max_points = calculate_max_points(image)
    # 10% of the maximum, capped at 10k. Clamp to the minimum: for small images
    # (max_points == 1000) 10% would be 100, an out-of-range slider value.
    default_value = max(slider_min, min(10000, max_points // 10))
    return gr.Slider(minimum=slider_min, maximum=max_points, value=default_value, step=1000,
                     label=f"Number of 3D points (max: {max_points:,})")
def create_point_cloud(image, depth_map, focal_length_x=470.4, focal_length_y=470.4, max_points=30000):
    """
    Back-project a depth map into a coloured Open3D point cloud (pinhole model).

    A sampled pixel (u, v) with depth d becomes the point
    ((u - w/2) / fx * d, (v - h/2) / fy * d, d), coloured by the matching
    image pixel. The sampling stride is half of what ``max_points`` alone
    would give, so the cloud deliberately holds several times ``max_points``
    points (roughly 4x or more) for extra detail.

    Args:
        image: HxW or HxWxC image aligned with ``depth_map``. Only the first
            three channels are used; grayscale is replicated to RGB. Colours
            are divided by 255.
        depth_map: HxW array of depth values. Point coordinates inherit its units.
        focal_length_x, focal_length_y: Focal lengths in pixels.
        max_points: Point budget used to choose the sampling stride.

    Returns:
        open3d.geometry.PointCloud with points and colours set.
    """
    h, w = depth_map.shape
    # Guard against a zero/negative budget (division by zero -> OverflowError in int()).
    budget = max(1, max_points)
    step = max(1, int(np.sqrt(h * w / budget) * 0.5))

    y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]
    depth_values = depth_map[::step, ::step]

    # Pixel -> normalised camera coordinates, then scale by depth.
    x_3d = (x_coords - w / 2) / focal_length_x * depth_values
    y_3d = (y_coords - h / 2) / focal_length_y * depth_values
    z_3d = depth_values
    points = np.stack([x_3d.flatten(), y_3d.flatten(), z_3d.flatten()], axis=1)

    image = np.asarray(image)
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    # Drop any alpha channel: reshape(-1, 3) on RGBA would give a colour count
    # that no longer matches the number of points.
    colors = image[::step, ::step, :3].reshape(-1, 3) / 255.0

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    pcd.colors = o3d.utility.Vector3dVector(colors)
    return pcd
def reconstruct_surface_mesh_from_point_cloud(pcd, normal_radius=None, poisson_depth=12):
    """
    Reconstruct a triangle mesh from a point cloud with Poisson surface reconstruction.

    Normals are estimated with a hybrid radius/KNN search (at most 50
    neighbours) and oriented by consistent tangent-plane propagation (k=50).
    Poisson reconstruction is then run on the result. Low-density vertices are
    not filtered out.

    Args:
        pcd: open3d.geometry.PointCloud. Its normals are (re)computed in place.
        normal_radius: Neighbourhood radius for normal estimation, in the
            cloud's own coordinate units. When None (default), it is derived
            from the mean nearest-neighbour spacing, and is never smaller than
            the previous fixed 0.005.
        poisson_depth: Octree depth for Poisson reconstruction (12 is very fine and slow).

    Returns:
        open3d.geometry.TriangleMesh.
    """
    if normal_radius is None:
        # A fixed radius of 0.005 only suits clouds in small (metre-like) units.
        # Clouds built from 0-255 normalised depth maps have point spacing far
        # larger than that, leaving the radius search effectively empty.
        spacing = np.asarray(pcd.compute_nearest_neighbor_distance())
        mean_spacing = float(spacing.mean()) if spacing.size else 0.0
        normal_radius = max(0.005, 3.0 * mean_spacing)

    pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=normal_radius, max_nn=50))
    pcd.orient_normals_consistent_tangent_plane(k=50)

    mesh, _densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=poisson_depth)
    return mesh
def create_enhanced_3d_visualization(image, depth_map, max_points=10000,
                                     focal_length_x=470.4, focal_length_y=470.4):
    """
    Build an interactive Plotly 3D scatter of the depth map back-projected
    through a pinhole camera, coloured by the source image.

    Args:
        image: HxW or HxWxC image aligned with ``depth_map``. Only the first
            three channels are used; grayscale is replicated to RGB.
        depth_map: HxW array of depth values. Plot coordinates inherit its units.
        max_points: Approximate number of plotted points (sets the sampling stride).
        focal_length_x, focal_length_y: Focal lengths in pixels. They default
            to 470.4, the value previously hard-coded.

    Returns:
        plotly.graph_objects.Figure
    """
    h, w = depth_map.shape
    # Downsample for browser performance; guard against a zero/negative budget.
    step = max(1, int(np.sqrt(h * w / max(1, max_points))))

    y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]
    depth_values = depth_map[::step, ::step]

    # Pixel -> normalised camera coordinates, scaled by depth.
    x_flat = ((x_coords - w / 2) / focal_length_x * depth_values).flatten()
    y_flat = ((y_coords - h / 2) / focal_length_y * depth_values).flatten()
    z_flat = depth_values.flatten()

    image = np.asarray(image)
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    rgb = np.clip(image[::step, ::step, :3], 0, 255).astype(int).reshape(-1, 3)
    # Plotly expects one colour per marker. A raw Nx3 numeric array is not
    # interpreted as per-point RGB triplets, so build 'rgb(r,g,b)' strings.
    colors_flat = [f"rgb({r},{g},{b})" for r, g, b in rgb.tolist()]

    fig = go.Figure(data=[go.Scatter3d(
        x=x_flat,
        y=y_flat,
        z=z_flat,
        mode='markers',
        marker=dict(
            size=1.5,
            color=colors_flat,
            opacity=0.9
        ),
        hovertemplate='<b>3D Position:</b> (%{x:.3f}, %{y:.3f}, %{z:.3f})<br>' +
                      '<b>Depth:</b> %{z:.2f}<br>' +
                      '<extra></extra>'
    )])
    fig.update_layout(
        title="3D Point Cloud Visualization (Camera Projection)",
        scene=dict(
            # Coordinates are in the depth map's units. The app passes a 0-255
            # normalised map, so they are not metres.
            xaxis_title="X (depth-map units)",
            yaxis_title="Y (depth-map units)",
            zaxis_title="Z (depth-map units)",
            camera=dict(
                eye=dict(x=2.0, y=2.0, z=2.0),
                center=dict(x=0, y=0, z=0),
                up=dict(x=0, y=0, z=1)
            ),
            aspectmode='data'
        ),
        width=700,
        height=600
    )
    return fig
def _new_temp_file_path(suffix):
    """Create a persistent temporary file and return its path with the handle already closed."""
    # delete=False keeps the file for Gradio to serve. Closing the handle
    # avoids leaking one file descriptor per request and lets the file be
    # reopened by name on every platform.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        return tmp.name


def on_depth_submit(image, num_points, focal_x, focal_y):
    """
    Estimate depth for ``image`` and produce all outputs of the depth tab.

    Args:
        image: HxWx3 numpy image from the Gradio input.
        num_points: Point budget for the point cloud and the 3D plot.
        focal_x, focal_y: Focal lengths in pixels used for the point cloud.

    Returns:
        list: [(original_image, colored_depth), gray_depth_png_path,
               raw_uint16_depth_png_path, mesh_ply_path, plotly_figure].
        The gray PNG holds the 0-255 normalised depth map.
    """
    original_image = image.copy()
    # Channels are reversed before inference ("RGB to BGR if needed").
    # NOTE(review): confirm depth_model.infer_image expects BGR input.
    depth = predict_depth(image[:, :, ::-1])

    # Raw prediction cast (not rescaled) to 16-bit integers.
    raw_depth_path = _new_temp_file_path('.png')
    Image.fromarray(depth.astype('uint16')).save(raw_depth_path)

    # Min-max normalise to 0-255. A flat prediction would otherwise divide 0/0.
    depth_min = depth.min()
    depth_range = depth.max() - depth_min
    if depth_range > 0:
        norm_depth = ((depth - depth_min) / depth_range * 255.0).astype(np.uint8)
    else:
        norm_depth = np.zeros(depth.shape, dtype=np.uint8)
    colored_depth = (matplotlib.colormaps.get_cmap('Spectral_r')(norm_depth)[:, :, :3] * 255).astype(np.uint8)

    gray_depth_path = _new_temp_file_path('.png')
    Image.fromarray(norm_depth).save(gray_depth_path)

    # Point cloud -> Poisson mesh, saved with faces as .ply
    pcd = create_point_cloud(original_image, norm_depth, focal_x, focal_y, max_points=num_points)
    mesh = reconstruct_surface_mesh_from_point_cloud(pcd)
    mesh_path = _new_temp_file_path('.ply')
    o3d.io.write_triangle_mesh(mesh_path, mesh)

    depth_3d = create_enhanced_3d_visualization(original_image, norm_depth, max_points=num_points)
    return [(original_image, colored_depth), gray_depth_path, raw_depth_path, mesh_path, depth_3d]
# --- Actual Wound Segmentation Functions ---
def create_automatic_wound_mask(image, method='deep_learning'):
    """
    Segment the wound in ``image`` with the global deep-learning model.

    Args:
        image: Input image (numpy array), or None.
        method: Segmentation method name. Only the deep-learning segmenter
            exists, and every value falls back to it, so this argument does
            not currently change the result.

    Returns:
        The mask returned by ``segmentation_model.segment_wound``, or None
        when ``image`` is None.
    """
    if image is None:
        return None
    mask, _unused = segmentation_model.segment_wound(image)
    return mask
def post_process_wound_mask(mask, min_area=100):
    """
    Clean up a raw wound mask and drop small blobs.

    Steps:
      1. cast to uint8 if needed (values are not thresholded; any non-zero
         pixel counts as foreground for contour detection);
      2. morphological close then open with a 10x10 elliptical kernel;
      3. keep only outer contours whose area is at least ``min_area`` pixels,
         filled with 255;
      4. a final close with the same kernel.

    Returns:
        uint8 mask with values 0/255, or None when ``mask`` is None.
    """
    if mask is None:
        return None

    cleaned = mask if mask.dtype == np.uint8 else mask.astype(np.uint8)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
    for operation in (cv2.MORPH_CLOSE, cv2.MORPH_OPEN):
        cleaned = cv2.morphologyEx(cleaned, operation, kernel)

    outlines, _ = cv2.findContours(cleaned, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    kept = np.zeros_like(cleaned)
    for outline in (c for c in outlines if cv2.contourArea(c) >= min_area):
        cv2.fillPoly(kept, [outline], 255)

    return cv2.morphologyEx(kept, cv2.MORPH_CLOSE, kernel)
def analyze_wound_severity_auto(image, depth_map, pixel_spacing_mm=0.5, segmentation_method='deep_learning'):
    """
    Segment the wound automatically, then run ``analyze_wound_severity`` on it.

    Returns an error string when an input is missing, segmentation yields
    nothing, or no wound region survives post-processing. Otherwise it returns
    whatever ``analyze_wound_severity`` returns.

    NOTE(review): the severity tab calls ``analyze_wound_severity`` with an
    extra depth-calibration argument; this helper passes only four arguments
    and so relies on that parameter's default. Confirm this is intended.
    """
    if image is None or depth_map is None:
        return "❌ Please provide both image and depth map."

    raw_mask = create_automatic_wound_mask(image, method=segmentation_method)
    if raw_mask is None:
        return "❌ Failed to generate automatic wound mask. Please check if the segmentation model is loaded."

    cleaned_mask = post_process_wound_mask(raw_mask, min_area=500)
    wound_found = cleaned_mask is not None and np.any(cleaned_mask > 0)
    if not wound_found:
        return "❌ No wound region detected by the segmentation model. Try uploading a different image or use manual mask."

    return analyze_wound_severity(image, depth_map, cleaned_mask, pixel_spacing_mm)
# --- Main Gradio Interface ---
with gr.Blocks(css=css, title="Wound Analysis System") as demo:
    # Three-step workflow: classification -> depth estimation -> severity analysis.
    # Data is handed between tabs through the two gr.State holders below.
    gr.HTML("<h1>Wound Analysis System</h1>")

    # Shared states
    shared_image = gr.State()      # image uploaded in Tab 1 (PIL, per that input's type)
    shared_depth_map = gr.State()  # uint8 0-255 normalised depth map computed in Tab 2

    # Placeholder for the Gemini panel. Shown initially and restored by "Clear".
    gemini_placeholder_html = """
<div style="
border-radius: 12px;
padding: 20px;
box-shadow: 0 4px 12px rgba(0,0,0,0.1);
font-family: Arial, sans-serif;
min-height: 200px;
display: flex;
align-items: center;
justify-content: center;
color: white;
width: 100%;
border-left: 4px solid #d97706;
font-weight: bold;
">
Upload an image to get AI-powered wound analysis
</div>
"""

    def severity_message_card(title, title_color, message):
        """Render the dark status card used by the severity tab for errors and warnings."""
        return f"""
<div style='padding: 30px; background-color: #1e1e1e; border-radius: 12px; box-shadow: 0 0 10px rgba(0,0,0,0.5); text-align: center;'>
<div style='font-size: 24px; font-weight: bold; color: {title_color}; margin-bottom: 15px;'>
{title}
</div>
<div style='font-size: 16px; color: #cccccc;'>
{message}
</div>
</div>
"""

    with gr.Tabs():
        # Tab 1: Wound Classification
        with gr.Tab("1. 🔍 Wound Classification & Initial Analysis"):
            gr.Markdown("### Step 1: Classify wound type and get initial AI analysis")
            with gr.Row():
                # Left Column - Image Upload
                with gr.Column(scale=1):
                    gr.HTML('<h2 style="text-align: left; color: #d97706; margin-top: 0; font-weight: bold; font-size: 1.8rem;">Upload Wound Image</h2>')
                    classification_image_input = gr.Image(
                        label="",
                        type="pil",
                        height=400
                    )
                    # Place Clear and Analyse buttons side by side
                    with gr.Row():
                        classify_clear_btn = gr.Button(
                            "Clear",
                            variant="secondary",
                            size="lg",
                            scale=1
                        )
                        analyse_btn = gr.Button(
                            "Analyse",
                            variant="primary",
                            size="lg",
                            scale=1
                        )
                # Right Column - Classification Results
                with gr.Column(scale=1):
                    gr.HTML('<h2 style="text-align: left; color: #d97706; margin-top: 0; font-weight: bold; font-size: 1.8rem;">Classification Results</h2>')
                    classification_output = gr.Label(
                        label="",
                        num_top_classes=5,
                        show_label=False
                    )
            # Second Row - Full Width AI Analysis
            with gr.Row():
                with gr.Column(scale=1):
                    gr.HTML('<h2 style="text-align: left; color: #d97706; margin-top: 2rem; margin-bottom: 1rem; font-weight: bold; font-size: 1.8rem;">Wound Visual Analysis</h2>')
                    gemini_output = gr.HTML(value=gemini_placeholder_html)

        # Tab 2: Depth Estimation
        with gr.Tab("2. 📏 Depth Estimation & 3D Visualization"):
            gr.Markdown("### Step 2: Generate depth maps and 3D visualizations")
            gr.Markdown("This module creates depth maps and 3D point clouds from your images.")
            with gr.Row():
                load_from_classification_btn = gr.Button("🔄 Load Image from Classification Tab", variant="secondary")
            # Status line for the load button. It lives in this tab so the message
            # is visible where the button was clicked.
            depth_load_status = gr.HTML()
            with gr.Row():
                depth_input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
                depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output')
            with gr.Row():
                depth_submit = gr.Button(value="Compute Depth", variant="primary")
                points_slider = gr.Slider(minimum=1000, maximum=10000, value=10000, step=1000,
                                          label="Number of 3D points (upload image to update max)")
            with gr.Row():
                focal_length_x = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
                                           label="Focal Length X (pixels)")
                focal_length_y = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
                                           label="Focal Length Y (pixels)")
            # 2 columns: 3D visualization on the left, file outputs stacked on the right
            with gr.Row():
                with gr.Column(scale=2):
                    gr.Markdown("### 3D Point Cloud Visualization")
                    gr.Markdown("Enhanced 3D visualization using proper camera projection. Hover over points to see 3D coordinates.")
                    depth_3d_plot = gr.Plot(label="3D Point Cloud")
                with gr.Column(scale=1):
                    gr.Markdown("### Download Files")
                    gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download")
                    raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download")
                    point_cloud_file = gr.File(label="Point Cloud (.ply)", elem_id="download")

        # Tab 3: Wound Severity Analysis
        with gr.Tab("3. 🩹 Wound Severity Analysis"):
            gr.Markdown("### Step 3: Analyze wound severity using depth maps")
            gr.Markdown("This module analyzes wound severity based on depth distribution and area measurements.")
            with gr.Row():
                load_depth_btn = gr.Button("🔄 Load Depth Map from Tab 2", variant="secondary")
            severity_load_status = gr.HTML()
            with gr.Row():
                severity_input_image = gr.Image(label="Original Image", type='numpy')
                severity_depth_map = gr.Image(label="Depth Map (from Tab 2)", type='numpy')
            with gr.Row():
                wound_mask_input = gr.Image(label="Auto-Generated Wound Mask", type='numpy')
            mask_status = gr.HTML()
            with gr.Row():
                severity_output = gr.HTML(
                    label="🤖 AI-Powered Medical Assessment",
                    value="""
<div style='padding: 30px; background-color: #1e1e1e; border-radius: 12px; box-shadow: 0 0 10px rgba(0,0,0,0.5); text-align: center;'>
<div style='font-size: 24px; font-weight: bold; color: #ff9800; margin-bottom: 15px;'>
🩹 Wound Severity Analysis
</div>
<div style='font-size: 18px; color: #cccccc; margin-bottom: 20px;'>
⏳ Waiting for Input...
</div>
<div style='color: #888888; font-size: 14px;'>
Please upload an image and depth map, then click "🤖 Analyze Severity with Auto-Generated Mask" to begin AI-powered medical assessment.
</div>
</div>
"""
                )
            gr.Markdown("**Note:** The deep learning segmentation model will automatically generate a wound mask when you upload an image or load a depth map.")
            with gr.Row():
                auto_severity_button = gr.Button("🤖 Analyze Severity with Auto-Generated Mask", variant="primary", size="lg")
                pixel_spacing_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.5, step=0.1,
                                                 label="Pixel Spacing (mm/pixel)")
                depth_calibration_slider = gr.Slider(minimum=5.0, maximum=30.0, value=15.0, step=1.0,
                                                     label="Depth Calibration (mm)",
                                                     info="Adjust based on expected maximum wound depth")

    # ---------------- Tab 1 events ----------------
    classify_clear_btn.click(
        fn=lambda: (None, None, gemini_placeholder_html),
        inputs=None,
        outputs=[classification_image_input, classification_output, gemini_output]
    )

    # Classification runs on every image change; Gemini runs only on "Analyse".
    classification_image_input.change(
        fn=classify_wound,
        inputs=classification_image_input,
        outputs=classification_output
    )

    # Store the image in shared state so later tabs can load it.
    classification_image_input.change(
        fn=lambda image: image,
        inputs=classification_image_input,
        outputs=shared_image
    )

    def run_gemini_on_click(image, classification):
        """Run Gemini analysis using the top classification label as context."""
        if isinstance(classification, dict) and classification:
            top_label = max(classification.items(), key=lambda x: x[1])[0]
        else:
            top_label = "Unknown"
        gemini_analysis = analyze_wound_with_gemini(image, top_label)
        return format_gemini_analysis(gemini_analysis)

    analyse_btn.click(
        fn=run_gemini_on_click,
        inputs=[classification_image_input, classification_output],
        outputs=gemini_output
    )

    # ---------------- Tab 2 events ----------------
    depth_input_image.change(
        fn=update_slider_on_image_upload,
        inputs=[depth_input_image],
        outputs=[points_slider]
    )

    def on_depth_submit_with_state(image, num_points, focal_x, focal_y):
        """Run depth estimation and also return the normalised depth map for Tab 3."""
        results = on_depth_submit(image, num_points, focal_x, focal_y)
        depth_map = None
        if image is not None:
            # results[1] is the saved 8-bit normalised depth PNG. Reading it back
            # gives that uint8 map without running the depth model a second time.
            with Image.open(results[1]) as gray_png:
                depth_map = np.array(gray_png)
        return results + [depth_map]

    depth_submit.click(
        on_depth_submit_with_state,
        inputs=[depth_input_image, points_slider, focal_length_x, focal_length_y],
        outputs=[depth_image_slider, gray_depth_file, raw_file, point_cloud_file, depth_3d_plot, shared_depth_map]
    )

    def load_image_from_classification(shared_img):
        """Copy the Tab 1 image into the depth tab, converting PIL to numpy."""
        if shared_img is None:
            return None, "❌ No image available from classification tab. Please upload an image in Tab 1 first."
        img_array = np.array(shared_img) if hasattr(shared_img, 'convert') else shared_img
        return img_array, "✅ Image loaded from classification tab successfully!"

    load_from_classification_btn.click(
        fn=load_image_from_classification,
        inputs=shared_image,
        outputs=[depth_input_image, depth_load_status]
    )

    # ---------------- Tab 3 events ----------------
    def load_depth_to_severity(depth_map, original_image):
        """Load the Tab 2 depth map and image into Tab 3 and auto-generate a wound mask."""
        if depth_map is None:
            return None, None, None, "❌ No depth map available. Please compute depth in Tab 2 first."
        if original_image is None:
            return depth_map, original_image, None, "✅ Depth map loaded successfully!"

        auto_mask, _ = segmentation_model.segment_wound(original_image)
        if auto_mask is None:
            return depth_map, original_image, None, "✅ Depth map loaded but segmentation failed. Try uploading a different image."
        processed_mask = post_process_wound_mask(auto_mask, min_area=500)
        if processed_mask is not None and np.sum(processed_mask > 0) > 0:
            return depth_map, original_image, processed_mask, "✅ Depth map loaded and wound mask auto-generated!"
        return depth_map, original_image, None, "✅ Depth map loaded but no wound detected. Try uploading a different image."

    load_depth_btn.click(
        fn=load_depth_to_severity,
        inputs=[shared_depth_map, depth_input_image],
        outputs=[severity_depth_map, severity_input_image, wound_mask_input, severity_load_status]
    )

    def show_loading_state():
        """HTML shown while the severity analysis runs."""
        return """
<div style='padding: 30px; background-color: #1e1e1e; border-radius: 12px; box-shadow: 0 0 10px rgba(0,0,0,0.5); text-align: center;'>
<div style='font-size: 24px; font-weight: bold; color: #ff9800; margin-bottom: 15px;'>
🩹 Wound Severity Analysis
</div>
<div style='font-size: 18px; color: #4CAF50; margin-bottom: 20px;'>
🔄 AI Analysis in Progress...
</div>
<div style='color: #cccccc; font-size: 14px; margin-bottom: 15px;'>
• Generating wound mask with deep learning model<br>
• Computing depth measurements and statistics<br>
• Analyzing wound characteristics with Gemini AI<br>
• Preparing comprehensive medical assessment
</div>
<div style='display: inline-block; width: 30px; height: 30px; border: 3px solid #f3f3f3; border-top: 3px solid #4CAF50; border-radius: 50%; animation: spin 1s linear infinite;'></div>
<style>
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
</style>
</div>
"""

    def run_auto_severity_analysis(image, depth_map, pixel_spacing, depth_calibration):
        """Segment the wound automatically and run the full severity analysis."""
        # The "Load Depth Map" button fills both the image and the depth map,
        # so a missing image is reported the same way as a missing depth map.
        if image is None or depth_map is None:
            return severity_message_card("❌ Error", "#f44336",
                                         "Please load depth map from Tab 2 first.")

        auto_mask = create_automatic_wound_mask(image, method='deep_learning')
        if auto_mask is None:
            return severity_message_card(
                "❌ Error", "#f44336",
                "Failed to generate automatic wound mask. Please check if the segmentation model is loaded.")

        processed_mask = post_process_wound_mask(auto_mask, min_area=500)
        if processed_mask is None or np.sum(processed_mask > 0) == 0:
            return severity_message_card(
                "⚠️ No Wound Detected", "#ff9800",
                "No wound region detected by the segmentation model. Try uploading a different image or use manual mask.")

        return analyze_wound_severity(image, depth_map, processed_mask, pixel_spacing, depth_calibration)

    # Show the loading card first, then replace it with the analysis result.
    auto_severity_button.click(
        fn=show_loading_state,
        inputs=[],
        outputs=[severity_output]
    ).then(
        fn=run_auto_severity_analysis,
        inputs=[severity_input_image, severity_depth_map, pixel_spacing_slider, depth_calibration_slider],
        outputs=[severity_output]
    )

    def auto_generate_mask_on_image_upload(image):
        """Regenerate the wound mask whenever the severity-tab image changes."""
        if image is None:
            return None, "❌ No image uploaded."
        auto_mask, _ = segmentation_model.segment_wound(image)
        if auto_mask is None:
            return None, "✅ Image uploaded but segmentation failed. Try uploading a different image."
        processed_mask = post_process_wound_mask(auto_mask, min_area=500)
        if processed_mask is not None and np.sum(processed_mask > 0) > 0:
            return processed_mask, "✅ Wound mask auto-generated using deep learning model!"
        return None, "✅ Image uploaded but no wound detected. Try uploading a different image."

    severity_input_image.change(
        fn=auto_generate_mask_on_image_upload,
        inputs=[severity_input_image],
        outputs=[wound_mask_input, mask_status]
    )
if __name__ == '__main__':
    # Serve on all interfaces, port 7860, with request queueing enabled.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    queued_app = demo.queue()
    queued_app.launch(**launch_options)