import glob
import gradio as gr
import matplotlib
import numpy as np
from PIL import Image
import torch
import tempfile
from gradio_imageslider import ImageSlider
import plotly.graph_objects as go
import plotly.express as px
import open3d as o3d
from depth_anything_v2.dpt import DepthAnythingV2
import os
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image as keras_image
import base64
from io import BytesIO
import gdown
import spaces
import cv2

# Import actual segmentation model components
from models.deeplab import Deeplabv3, relu6, DepthwiseConv2D, BilinearUpsampling
from utils.learning.metrics import dice_coef, precision, recall
from utils.io.data import normalize
# Define checkpoint path and Google Drive file ID
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
model_file = os.path.join(checkpoint_dir, "depth_anything_v2_vitl.pth")
gdrive_url = "https://drive.google.com/uc?id=141Mhq2jonkUBcVBnNqNSeyIZYtH5l4K5"

# Download the checkpoint if it is not already present
if not os.path.exists(model_file):
    print("Downloading model from Google Drive...")
    gdown.download(gdrive_url, model_file, quiet=False)
# --- TensorFlow: Check GPU Availability ---
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    print("TensorFlow is using GPU")
else:
    print("TensorFlow is using CPU")
# --- Load Wound Classification Model and Class Labels ---
wound_model = load_model("keras_model.h5")
with open("labels.txt", "r") as f:
    class_labels = [line.strip().split(maxsplit=1)[1] for line in f]
# --- Load Actual Wound Segmentation Model ---
class WoundSegmentationModel:
    def __init__(self):
        self.input_dim_x = 224
        self.input_dim_y = 224
        self.model = None
        self.load_model()

    def load_model(self):
        """Load the trained wound segmentation model"""
        try:
            # Try to load the most recent model
            weight_file_name = '2025-08-07_16-25-27.hdf5'
            model_path = f'./training_history/{weight_file_name}'
            self.model = load_model(model_path,
                                    custom_objects={
                                        'recall': recall,
                                        'precision': precision,
                                        'dice_coef': dice_coef,
                                        'relu6': relu6,
                                        'DepthwiseConv2D': DepthwiseConv2D,
                                        'BilinearUpsampling': BilinearUpsampling
                                    })
            print(f"Segmentation model loaded successfully from {model_path}")
        except Exception as e:
            print(f"Error loading segmentation model: {e}")
            # Fall back to the older model
            try:
                weight_file_name = '2019-12-19 01%3A53%3A15.480800.hdf5'
                model_path = f'./training_history/{weight_file_name}'
                self.model = load_model(model_path,
                                        custom_objects={
                                            'recall': recall,
                                            'precision': precision,
                                            'dice_coef': dice_coef,
                                            'relu6': relu6,
                                            'DepthwiseConv2D': DepthwiseConv2D,
                                            'BilinearUpsampling': BilinearUpsampling
                                        })
                print(f"Segmentation model loaded successfully from {model_path}")
            except Exception as e2:
                print(f"Error loading fallback segmentation model: {e2}")
                self.model = None
    def preprocess_image(self, image):
        """Preprocess the uploaded image for model input"""
        if image is None:
            return None
        # Convert 3-channel input from BGR to RGB (assumes OpenCV-style BGR input)
        if len(image.shape) == 3 and image.shape[2] == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Resize to model input size
        image = cv2.resize(image, (self.input_dim_x, self.input_dim_y))
        # Normalize pixel values to [0, 1]
        image = image.astype(np.float32) / 255.0
        # Add batch dimension
        image = np.expand_dims(image, axis=0)
        return image
    def postprocess_prediction(self, prediction):
        """Postprocess the model prediction"""
        # Remove batch dimension
        prediction = prediction[0]
        # Apply a threshold to get a binary mask
        threshold = 0.5
        binary_mask = (prediction > threshold).astype(np.uint8) * 255
        return binary_mask

    def segment_wound(self, input_image):
        """Main function to segment a wound from an uploaded image"""
        if self.model is None:
            return None, "Error: Segmentation model not loaded. Please check the model files."
        if input_image is None:
            return None, "Please upload an image."
        try:
            # Preprocess the image
            processed_image = self.preprocess_image(input_image)
            if processed_image is None:
                return None, "Error processing image."
            # Make prediction
            prediction = self.model.predict(processed_image, verbose=0)
            # Postprocess the prediction
            segmented_mask = self.postprocess_prediction(prediction)
            return segmented_mask, "Segmentation completed successfully!"
        except Exception as e:
            return None, f"Error during segmentation: {str(e)}"

# Initialize the segmentation model
segmentation_model = WoundSegmentationModel()
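
# Illustrative sketch (not called by the app): how segment_wound behaves on a
# hypothetical input. The 224x224 output size follows from input_dim_x/y above;
# the zero image here is dummy test data.
def _demo_segment_wound():
    dummy = np.zeros((480, 640, 3), dtype=np.uint8)  # hypothetical RGB frame
    mask, status = segmentation_model.segment_wound(dummy)
    # mask is None on failure, otherwise a uint8 array with values in {0, 255};
    # status is a human-readable message.
    print(status, None if mask is None else mask.shape)
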
# --- PyTorch: Set Device and Load Depth Model ---
# torch.cuda.is_available() already implies a usable device, so no extra
# device_count() check is needed.
map_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using PyTorch device: {map_device}")

model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
}
encoder = 'vitl'
depth_model = DepthAnythingV2(**model_configs[encoder])
state_dict = torch.load(
    f'checkpoints/depth_anything_v2_{encoder}.pth',
    map_location=map_device
)
depth_model.load_state_dict(state_dict)
depth_model = depth_model.to(map_device).eval()
# --- Custom CSS for unified dark theme ---
css = """
.gradio-container {
    font-family: 'Segoe UI', sans-serif;
    background-color: #121212;
    color: #ffffff;
    padding: 20px;
}
.gr-button {
    background-color: #2c3e50;
    color: white;
    border-radius: 10px;
}
.gr-button:hover {
    background-color: #34495e;
}
.gr-html, .gr-html div {
    white-space: normal !important;
    overflow: visible !important;
    text-overflow: unset !important;
    word-break: break-word !important;
}
#img-display-container {
    max-height: 100vh;
}
#img-display-input {
    max-height: 80vh;
}
#img-display-output {
    max-height: 80vh;
}
#download {
    height: 62px;
}
h1 {
    text-align: center;
    font-size: 3rem;
    font-weight: bold;
    margin: 2rem 0;
    color: #ffffff;
}
h2 {
    color: #ffffff;
    text-align: center;
    margin: 1rem 0;
}
.gr-tabs {
    background-color: #1e1e1e;
    border-radius: 10px;
    padding: 10px;
}
.gr-tab-nav {
    background-color: #2c3e50;
    border-radius: 8px;
}
.gr-tab-nav button {
    color: #ffffff !important;
}
.gr-tab-nav button.selected {
    background-color: #34495e !important;
}
"""
# --- Wound Classification Functions ---
def preprocess_input(img):
    img = img.resize((224, 224))
    arr = keras_image.img_to_array(img)
    arr = arr / 255.0
    return np.expand_dims(arr, axis=0)

def get_reasoning_from_gemini(img, prediction):
    try:
        # For now, return a canned explanation instead of calling the Gemini API;
        # in production, implement the actual Gemini API call here.
        explanations = {
            "Abrasion": "This appears to be an abrasion wound, characterized by superficial damage to the skin surface. The wound shows typical signs of friction or scraping injury.",
            "Burn": "This wound exhibits characteristics consistent with a burn injury, showing tissue damage from heat, chemicals, or radiation exposure.",
            "Laceration": "This wound displays the irregular edges and tissue tearing typical of a laceration, likely caused by blunt force trauma.",
            "Puncture": "This wound shows a small, deep entry point characteristic of puncture wounds, often caused by sharp, pointed objects.",
            "Ulcer": "This wound exhibits the characteristics of an ulcer, showing tissue breakdown and potential underlying vascular or pressure issues."
        }
        return explanations.get(prediction, f"This wound has been classified as {prediction}. Please consult with a healthcare professional for detailed assessment.")
    except Exception as e:
        return f"(Reasoning unavailable: {str(e)})"
@spaces.GPU
def classify_wound_image(img):
    if img is None:
        return "<div style='color:#ff5252; font-size:18px;'>No image provided</div>", ""
    img_array = preprocess_input(img)
    predictions = wound_model.predict(img_array, verbose=0)[0]
    pred_idx = int(np.argmax(predictions))
    pred_class = class_labels[pred_idx]
    # Get reasoning for the prediction
    reasoning_text = get_reasoning_from_gemini(img, pred_class)
    # Prediction card
    predicted_card = f"""
    <div style='padding: 20px; background-color: #1e1e1e; border-radius: 12px;
                box-shadow: 0 0 10px rgba(0,0,0,0.5);'>
        <div style='font-size: 22px; font-weight: bold; color: orange; margin-bottom: 10px;'>
            Predicted Wound Type
        </div>
        <div style='font-size: 26px; color: white;'>
            {pred_class}
        </div>
    </div>
    """
    # Reasoning card
    reasoning_card = f"""
    <div style='padding: 20px; background-color: #1e1e1e; border-radius: 12px;
                box-shadow: 0 0 10px rgba(0,0,0,0.5);'>
        <div style='font-size: 22px; font-weight: bold; color: orange; margin-bottom: 10px;'>
            Reasoning
        </div>
        <div style='font-size: 16px; color: white; min-height: 80px;'>
            {reasoning_text}
        </div>
    </div>
    """
    return predicted_card, reasoning_card
# --- Enhanced Wound Severity Estimation Functions ---
@spaces.GPU
def compute_enhanced_depth_statistics(depth_map, mask, pixel_spacing_mm=0.5, depth_calibration_mm=15.0):
    """
    Enhanced depth analysis with proper calibration and medical standards.

    Based on wound depth classification standards:
    - Superficial: 0-2mm (epidermis only)
    - Partial thickness: 2-4mm (epidermis + partial dermis)
    - Full thickness: 4-6mm (epidermis + full dermis)
    - Deep: >6mm (involving subcutaneous tissue)
    """
    # Convert pixel spacing to mm
    pixel_spacing_mm = float(pixel_spacing_mm)
    # Calculate pixel area in cm²
    pixel_area_cm2 = (pixel_spacing_mm / 10.0) ** 2
    # Extract wound region (binary mask)
    wound_mask = (mask > 127).astype(np.uint8)
    # Apply morphological operations to clean the mask
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    wound_mask = cv2.morphologyEx(wound_mask, cv2.MORPH_CLOSE, kernel)
    # Get depth values only for the wound region
    wound_depths = depth_map[wound_mask > 0]
    if len(wound_depths) == 0:
        return {
            'total_area_cm2': 0,
            'superficial_area_cm2': 0,
            'partial_thickness_area_cm2': 0,
            'full_thickness_area_cm2': 0,
            'deep_area_cm2': 0,
            'mean_depth_mm': 0,
            'max_depth_mm': 0,
            'depth_std_mm': 0,
            'deep_ratio': 0,
            'wound_volume_cm3': 0,
            'depth_percentiles': {'25': 0, '50': 0, '75': 0},
            # The report builder also reads these keys; include them so an
            # empty mask does not raise a KeyError downstream.
            'analysis_quality': 'Low',
            'depth_consistency': 'Low',
            'wound_pixel_count': 0
        }
    # Calibrate the depth map for more accurate measurements
    calibrated_depth_map = calibrate_depth_map(depth_map, reference_depth_mm=depth_calibration_mm)
    # Get calibrated depth values for the wound region
    wound_depths_mm = calibrated_depth_map[wound_mask > 0]
    # Medical depth classification
    superficial_mask = wound_depths_mm < 2.0
    partial_thickness_mask = (wound_depths_mm >= 2.0) & (wound_depths_mm < 4.0)
    full_thickness_mask = (wound_depths_mm >= 4.0) & (wound_depths_mm < 6.0)
    deep_mask = wound_depths_mm >= 6.0
    # Calculate areas
    total_pixels = np.sum(wound_mask > 0)
    total_area_cm2 = total_pixels * pixel_area_cm2
    superficial_area_cm2 = np.sum(superficial_mask) * pixel_area_cm2
    partial_thickness_area_cm2 = np.sum(partial_thickness_mask) * pixel_area_cm2
    full_thickness_area_cm2 = np.sum(full_thickness_mask) * pixel_area_cm2
    deep_area_cm2 = np.sum(deep_mask) * pixel_area_cm2
    # Calculate depth statistics
    mean_depth_mm = np.mean(wound_depths_mm)
    max_depth_mm = np.max(wound_depths_mm)
    depth_std_mm = np.std(wound_depths_mm)
    # Calculate depth percentiles
    depth_percentiles = {
        '25': np.percentile(wound_depths_mm, 25),
        '50': np.percentile(wound_depths_mm, 50),
        '75': np.percentile(wound_depths_mm, 75)
    }
    # Approximate wound volume: area * average depth
    wound_volume_cm3 = total_area_cm2 * (mean_depth_mm / 10.0)
    # Deep tissue ratio
    deep_ratio = deep_area_cm2 / total_area_cm2 if total_area_cm2 > 0 else 0
    # Analysis quality metrics
    wound_pixel_count = len(wound_depths_mm)
    analysis_quality = "High" if wound_pixel_count > 1000 else "Medium" if wound_pixel_count > 500 else "Low"
    # Depth consistency (lower std dev = more consistent)
    depth_consistency = "High" if depth_std_mm < 2.0 else "Medium" if depth_std_mm < 4.0 else "Low"
    return {
        'total_area_cm2': total_area_cm2,
        'superficial_area_cm2': superficial_area_cm2,
        'partial_thickness_area_cm2': partial_thickness_area_cm2,
        'full_thickness_area_cm2': full_thickness_area_cm2,
        'deep_area_cm2': deep_area_cm2,
        'mean_depth_mm': mean_depth_mm,
        'max_depth_mm': max_depth_mm,
        'depth_std_mm': depth_std_mm,
        'deep_ratio': deep_ratio,
        'wound_volume_cm3': wound_volume_cm3,
        'depth_percentiles': depth_percentiles,
        'analysis_quality': analysis_quality,
        'depth_consistency': depth_consistency,
        'wound_pixel_count': wound_pixel_count
    }
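
# Worked example of the arithmetic above (hypothetical numbers): at 0.5 mm/pixel,
# each pixel covers (0.05 cm)^2 = 0.0025 cm², so a 4,000-pixel wound spans
# 10 cm²; at a mean calibrated depth of 3 mm, its approximate volume is
# 10 * 0.3 = 3 cm³.
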
def classify_wound_severity_by_enhanced_metrics(depth_stats):
    """
    Enhanced wound severity classification based on medical standards.
    Uses multiple criteria: depth, area, volume, and tissue involvement.
    """
    if depth_stats['total_area_cm2'] == 0:
        return "Unknown"
    # Extract key metrics
    total_area = depth_stats['total_area_cm2']
    deep_area = depth_stats['deep_area_cm2']
    full_thickness_area = depth_stats['full_thickness_area_cm2']
    mean_depth = depth_stats['mean_depth_mm']
    max_depth = depth_stats['max_depth_mm']
    wound_volume = depth_stats['wound_volume_cm3']
    deep_ratio = depth_stats['deep_ratio']
    # Medical severity classification criteria
    severity_score = 0
    # Criterion 1: Maximum depth
    if max_depth >= 10.0:
        severity_score += 3  # Very severe
    elif max_depth >= 6.0:
        severity_score += 2  # Severe
    elif max_depth >= 4.0:
        severity_score += 1  # Moderate
    # Criterion 2: Mean depth
    if mean_depth >= 5.0:
        severity_score += 2
    elif mean_depth >= 3.0:
        severity_score += 1
    # Criterion 3: Deep tissue involvement ratio
    if deep_ratio >= 0.5:
        severity_score += 3  # More than 50% deep tissue
    elif deep_ratio >= 0.25:
        severity_score += 2  # 25-50% deep tissue
    elif deep_ratio >= 0.1:
        severity_score += 1  # 10-25% deep tissue
    # Criterion 4: Total wound area
    if total_area >= 10.0:
        severity_score += 2  # Large wound (>10 cm²)
    elif total_area >= 5.0:
        severity_score += 1  # Medium wound (5-10 cm²)
    # Criterion 5: Wound volume
    if wound_volume >= 5.0:
        severity_score += 2  # High volume
    elif wound_volume >= 2.0:
        severity_score += 1  # Medium volume
    # Determine severity based on the total score
    if severity_score >= 8:
        return "Very Severe"
    elif severity_score >= 6:
        return "Severe"
    elif severity_score >= 4:
        return "Moderate"
    elif severity_score >= 2:
        return "Mild"
    else:
        return "Superficial"
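
# Worked example of the scoring above (hypothetical values): max_depth = 7 mm (+2),
# mean_depth = 3.5 mm (+1), deep_ratio = 0.3 (+2), total_area = 6 cm² (+1), and
# wound_volume = 1 cm³ (+0) give a score of 6, which maps to "Severe".
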
def analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing_mm=0.5, depth_calibration_mm=15.0):
    """Enhanced wound severity analysis with medical-grade metrics"""
    if image is None or depth_map is None or wound_mask is None:
        return "❌ Please upload image, depth map, and wound mask."
    # Convert the wound mask to grayscale if needed
    if len(wound_mask.shape) == 3:
        wound_mask = np.mean(wound_mask, axis=2)
    # Ensure the depth map and mask have the same dimensions
    if depth_map.shape[:2] != wound_mask.shape[:2]:
        # Resize the mask to match the depth map (PIL.Image is imported at the top of the file)
        mask_pil = Image.fromarray(wound_mask.astype(np.uint8))
        mask_pil = mask_pil.resize((depth_map.shape[1], depth_map.shape[0]))
        wound_mask = np.array(mask_pil)
    # Compute enhanced statistics
    stats = compute_enhanced_depth_statistics(depth_map, wound_mask, pixel_spacing_mm, depth_calibration_mm)
    severity = classify_wound_severity_by_enhanced_metrics(stats)
    # Enhanced severity color coding
    severity_color = {
        "Superficial": "#4CAF50",  # Green
        "Mild": "#8BC34A",         # Light Green
        "Moderate": "#FF9800",     # Orange
        "Severe": "#F44336",       # Red
        "Very Severe": "#9C27B0"   # Purple
    }.get(severity, "#9E9E9E")     # Gray for unknown
    # Create a comprehensive medical report
    report = f"""
    <div style='padding: 20px; background-color: #1e1e1e; border-radius: 12px; box-shadow: 0 0 10px rgba(0,0,0,0.5);'>
        <div style='font-size: 24px; font-weight: bold; color: {severity_color}; margin-bottom: 15px;'>
            Enhanced Wound Severity Analysis
        </div>
        <div style='display: grid; grid-template-columns: 1fr 1fr; gap: 15px; margin-bottom: 20px;'>
            <div style='background-color: #2c2c2c; padding: 15px; border-radius: 8px;'>
                <div style='font-size: 18px; font-weight: bold; color: #ffffff; margin-bottom: 10px;'>
                    Tissue Involvement Analysis
                </div>
                <div style='color: #cccccc; line-height: 1.6;'>
                    <div><b>Superficial (0-2mm):</b> {stats['superficial_area_cm2']:.2f} cm²</div>
                    <div><b>Partial Thickness (2-4mm):</b> {stats['partial_thickness_area_cm2']:.2f} cm²</div>
                    <div><b>Full Thickness (4-6mm):</b> {stats['full_thickness_area_cm2']:.2f} cm²</div>
                    <div><b>Deep (>6mm):</b> {stats['deep_area_cm2']:.2f} cm²</div>
                    <div><b>Total Area:</b> {stats['total_area_cm2']:.2f} cm²</div>
                </div>
            </div>
            <div style='background-color: #2c2c2c; padding: 15px; border-radius: 8px;'>
                <div style='font-size: 18px; font-weight: bold; color: #ffffff; margin-bottom: 10px;'>
                    Depth Statistics
                </div>
                <div style='color: #cccccc; line-height: 1.6;'>
                    <div><b>Mean Depth:</b> {stats['mean_depth_mm']:.1f} mm</div>
                    <div><b>Max Depth:</b> {stats['max_depth_mm']:.1f} mm</div>
                    <div><b>Depth Std Dev:</b> {stats['depth_std_mm']:.1f} mm</div>
                    <div><b>Wound Volume:</b> {stats['wound_volume_cm3']:.2f} cm³</div>
                    <div><b>Deep Tissue Ratio:</b> {stats['deep_ratio']*100:.1f}%</div>
                </div>
            </div>
        </div>
        <div style='background-color: #2c2c2c; padding: 15px; border-radius: 8px; margin-bottom: 20px;'>
            <div style='font-size: 18px; font-weight: bold; color: #ffffff; margin-bottom: 10px;'>
                Depth Percentiles & Quality Metrics
            </div>
            <div style='color: #cccccc; line-height: 1.6; display: grid; grid-template-columns: 1fr 1fr; gap: 15px;'>
                <div>
                    <div><b>25th Percentile:</b> {stats['depth_percentiles']['25']:.1f} mm</div>
                    <div><b>Median (50th):</b> {stats['depth_percentiles']['50']:.1f} mm</div>
                    <div><b>75th Percentile:</b> {stats['depth_percentiles']['75']:.1f} mm</div>
                </div>
                <div>
                    <div><b>Analysis Quality:</b> {stats['analysis_quality']}</div>
                    <div><b>Depth Consistency:</b> {stats['depth_consistency']}</div>
                    <div><b>Data Points:</b> {stats['wound_pixel_count']:,}</div>
                </div>
            </div>
        </div>
        <div style='text-align: center; padding: 15px; background-color: #2c2c2c; border-radius: 8px; border-left: 4px solid {severity_color};'>
            <div style='font-size: 20px; font-weight: bold; color: {severity_color};'>
                Medical Severity Assessment: {severity}
            </div>
            <div style='font-size: 14px; color: #cccccc; margin-top: 5px;'>
                {get_enhanced_severity_description(severity)}
            </div>
        </div>
    </div>
    """
    return report
def calibrate_depth_map(depth_map, reference_depth_mm=10.0):
    """
    Calibrate a depth map to real-world measurements using a reference depth.
    This converts normalized depth values to approximate millimeters.
    """
    if depth_map is None:
        return depth_map
    # Find the depth range of the map
    max_depth_value = np.max(depth_map)
    min_depth_value = np.min(depth_map)
    if max_depth_value == min_depth_value:
        return depth_map
    # Linearly rescale to millimeters, assuming the maximum depth in the map
    # corresponds to reference_depth_mm
    calibrated_depth = (depth_map - min_depth_value) / (max_depth_value - min_depth_value) * reference_depth_mm
    return calibrated_depth
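
# Illustrative mapping performed above: with reference_depth_mm = 15, raw values
# are rescaled linearly so the minimum becomes 0 mm and the maximum 15 mm; a raw
# value halfway between them becomes 7.5 mm. This assumes the deepest point in
# the map really is reference_depth_mm deep, which is why the UI exposes a
# calibration slider.
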
def get_enhanced_severity_description(severity):
    """Get a comprehensive medical description for a severity level"""
    descriptions = {
        "Superficial": "Epidermis-only damage. Minimal tissue loss, typically heals within 1-2 weeks with basic wound care.",
        "Mild": "Superficial to partial thickness wound. Limited tissue involvement, good healing potential with proper care.",
        "Moderate": "Partial to full thickness involvement. Requires careful monitoring and may need advanced wound care techniques.",
        "Severe": "Full thickness with deep tissue involvement. High risk of complications, requires immediate medical attention.",
        "Very Severe": "Extensive deep tissue damage. Critical condition requiring immediate surgical intervention and specialized care.",
        "Unknown": "Unable to determine severity due to insufficient data or poor image quality."
    }
    return descriptions.get(severity, "Severity assessment unavailable.")
def create_sample_wound_mask(image_shape, center=None, radius=50):
    """Create a sample circular wound mask for testing"""
    if center is None:
        center = (image_shape[1] // 2, image_shape[0] // 2)
    mask = np.zeros(image_shape[:2], dtype=np.uint8)
    y, x = np.ogrid[:image_shape[0], :image_shape[1]]
    # Create a circular mask
    dist_from_center = np.sqrt((x - center[0])**2 + (y - center[1])**2)
    mask[dist_from_center <= radius] = 255
    return mask
def create_realistic_wound_mask(image_shape, method='elliptical'):
    """Create a more realistic wound mask with irregular shapes"""
    h, w = image_shape[:2]
    mask = np.zeros((h, w), dtype=np.uint8)
    if method == 'elliptical':
        # Create an elliptical wound mask
        center = (w // 2, h // 2)
        radius_x = min(w, h) // 3
        radius_y = min(w, h) // 4
        y, x = np.ogrid[:h, :w]
        # Base ellipse
        ellipse = ((x - center[0])**2 / (radius_x**2) +
                   (y - center[1])**2 / (radius_y**2)) <= 1
        # Add some noise and irregularity to make it more realistic
        noise = np.random.random((h, w)) > 0.8
        mask = (ellipse | noise).astype(np.uint8) * 255
    elif method == 'irregular':
        # Create an irregular wound mask
        center = (w // 2, h // 2)
        radius = min(w, h) // 4
        y, x = np.ogrid[:h, :w]
        base_circle = np.sqrt((x - center[0])**2 + (y - center[1])**2) <= radius
        # Add irregular extensions
        extensions = np.zeros_like(base_circle)
        for i in range(3):
            angle = i * 2 * np.pi / 3
            ext_x = int(center[0] + radius * 0.8 * np.cos(angle))
            ext_y = int(center[1] + radius * 0.8 * np.sin(angle))
            ext_radius = radius // 3
            ext_circle = np.sqrt((x - ext_x)**2 + (y - ext_y)**2) <= ext_radius
            extensions = extensions | ext_circle
        mask = (base_circle | extensions).astype(np.uint8) * 255
    # Apply a morphological closing to smooth the mask
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    return mask
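
# Hypothetical offline usage of the mask helpers above (they are not wired into
# the Gradio app; useful for exercising the severity pipeline without the
# segmentation model):
# test_mask = create_realistic_wound_mask((480, 640, 3), method='irregular')
# stats = compute_enhanced_depth_statistics(some_depth_map, test_mask)
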
# --- Depth Estimation Functions ---
@spaces.GPU
def predict_depth(image):
    return depth_model.infer_image(image)

def calculate_max_points(image):
    """Calculate the maximum number of 3D points from the image dimensions (3x the pixel count)"""
    if image is None:
        return 10000  # Default value
    h, w = image.shape[:2]
    max_points = h * w * 3
    # Clamp to a sensible range
    return max(1000, min(max_points, 300000))
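
# Example: a 1280x720 upload yields 1280 * 720 * 3 = 2,764,800 candidate points,
# which the clamp above caps at 300,000.
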
def update_slider_on_image_upload(image):
    """Update the points slider when an image is uploaded"""
    max_points = calculate_max_points(image)
    default_value = min(10000, max_points // 10)  # 10% of max points as default
    return gr.Slider(minimum=1000, maximum=max_points, value=default_value, step=1000,
                     label=f"Number of 3D points (max: {max_points:,})")
@spaces.GPU
def create_point_cloud(image, depth_map, focal_length_x=470.4, focal_length_y=470.4, max_points=30000):
    """Create a point cloud from a depth map using camera intrinsics with high detail"""
    h, w = depth_map.shape
    # Use a smaller step for higher detail (reduced downsampling)
    step = max(1, int(np.sqrt(h * w / max_points) * 0.5))
    # Create a mesh grid for camera coordinates
    y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]
    # Convert to camera coordinates (normalized by focal length)
    x_cam = (x_coords - w / 2) / focal_length_x
    y_cam = (y_coords - h / 2) / focal_length_y
    # Get depth values
    depth_values = depth_map[::step, ::step]
    # Calculate 3D points: (x_cam * depth, y_cam * depth, depth)
    x_3d = x_cam * depth_values
    y_3d = y_cam * depth_values
    z_3d = depth_values
    # Flatten into an (N, 3) array
    points = np.stack([x_3d.flatten(), y_3d.flatten(), z_3d.flatten()], axis=1)
    # Get corresponding image colors
    image_colors = image[::step, ::step, :]
    colors = image_colors.reshape(-1, 3) / 255.0
    # Create an Open3D point cloud
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    pcd.colors = o3d.utility.Vector3dVector(colors)
    return pcd
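
# Minimal sketch of the pinhole back-projection used above (illustrative only,
# not called by the app): a pixel (u, v) with depth d maps to camera coordinates
# ((u - cx) / fx * d, (v - cy) / fy * d, d), where the principal point (cx, cy)
# is assumed to be the image center.
def _backproject_pixel(u, v, d, fx=470.4, fy=470.4, cx=0.0, cy=0.0):
    return ((u - cx) / fx * d, (v - cy) / fy * d, d)
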
@spaces.GPU
def reconstruct_surface_mesh_from_point_cloud(pcd):
    """Convert a point cloud to a mesh using Poisson reconstruction with very high detail."""
    # Estimate and orient normals with high precision
    pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.005, max_nn=50))
    pcd.orient_normals_consistent_tangent_plane(k=50)
    # Create the surface mesh with maximum detail (depth=12 for very high resolution)
    mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=12)
    # Return the mesh without filtering low-density vertices
    return mesh
@spaces.GPU
def create_enhanced_3d_visualization(image, depth_map, max_points=10000):
    """Create an enhanced 3D visualization using proper camera projection"""
    h, w = depth_map.shape
    # Downsample to keep the point count manageable for performance
    step = max(1, int(np.sqrt(h * w / max_points)))
    # Create a mesh grid for camera coordinates
    y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]
    # Convert to camera coordinates (normalized by focal length)
    focal_length = 470.4  # Default focal length
    x_cam = (x_coords - w / 2) / focal_length
    y_cam = (y_coords - h / 2) / focal_length
    # Get depth values
    depth_values = depth_map[::step, ::step]
    # Calculate 3D points: (x_cam * depth, y_cam * depth, depth)
    x_3d = x_cam * depth_values
    y_3d = y_cam * depth_values
    z_3d = depth_values
    # Flatten arrays
    x_flat = x_3d.flatten()
    y_flat = y_3d.flatten()
    z_flat = z_3d.flatten()
    # Get corresponding image colors
    image_colors = image[::step, ::step, :]
    colors_flat = image_colors.reshape(-1, 3)
    # Create a 3D scatter plot with proper camera projection
    fig = go.Figure(data=[go.Scatter3d(
        x=x_flat,
        y=y_flat,
        z=z_flat,
        mode='markers',
        marker=dict(
            size=1.5,
            color=colors_flat,
            opacity=0.9
        ),
        hovertemplate='<b>3D Position:</b> (%{x:.3f}, %{y:.3f}, %{z:.3f})<br>' +
                      '<b>Depth:</b> %{z:.2f}<br>' +
                      '<extra></extra>'
    )])
    fig.update_layout(
        title="3D Point Cloud Visualization (Camera Projection)",
        scene=dict(
            xaxis_title="X (meters)",
            yaxis_title="Y (meters)",
            zaxis_title="Z (meters)",
            camera=dict(
                eye=dict(x=2.0, y=2.0, z=2.0),
                center=dict(x=0, y=0, z=0),
                up=dict(x=0, y=0, z=1)
            ),
            aspectmode='data'
        ),
        width=700,
        height=600
    )
    return fig
def on_depth_submit(image, num_points, focal_x, focal_y):
    original_image = image.copy()
    h, w = image.shape[:2]
    # Predict depth using the model
    depth = predict_depth(image[:, :, ::-1])  # RGB to BGR if needed
    # Save raw 16-bit depth
    raw_depth = Image.fromarray(depth.astype('uint16'))
    tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    raw_depth.save(tmp_raw_depth.name)
    # Normalize and convert to grayscale for display
    # (epsilon guards against division by zero on a constant depth map)
    norm_depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8) * 255.0
    norm_depth = norm_depth.astype(np.uint8)
    colored_depth = (matplotlib.colormaps.get_cmap('Spectral_r')(norm_depth)[:, :, :3] * 255).astype(np.uint8)
    gray_depth = Image.fromarray(norm_depth)
    tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    gray_depth.save(tmp_gray_depth.name)
    # Create point cloud
    pcd = create_point_cloud(original_image, norm_depth, focal_x, focal_y, max_points=num_points)
    # Reconstruct mesh from the point cloud
    mesh = reconstruct_surface_mesh_from_point_cloud(pcd)
    # Save the mesh with faces as .ply
    tmp_pointcloud = tempfile.NamedTemporaryFile(suffix='.ply', delete=False)
    o3d.io.write_triangle_mesh(tmp_pointcloud.name, mesh)
    # Create the enhanced 3D scatter plot visualization
    depth_3d = create_enhanced_3d_visualization(original_image, norm_depth, max_points=num_points)
    return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name, tmp_pointcloud.name, depth_3d]
# --- Actual Wound Segmentation Functions ---
def create_automatic_wound_mask(image, method='deep_learning'):
    """
    Automatically generate a wound mask from an image using the deep learning model.

    Args:
        image: Input image (numpy array)
        method: Segmentation method (currently only 'deep_learning' is supported)

    Returns:
        mask: Binary wound mask
    """
    if image is None:
        return None
    # Only the deep learning model is available; unrecognized methods fall back to it.
    mask, _ = segmentation_model.segment_wound(image)
    return mask
def post_process_wound_mask(mask, min_area=100):
    """Post-process the wound mask to remove noise and small objects"""
    if mask is None:
        return None
    # Convert to uint8 if needed
    if mask.dtype != np.uint8:
        mask = mask.astype(np.uint8)
    # Apply morphological operations to clean up
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
    # Remove small objects using OpenCV contours
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    mask_clean = np.zeros_like(mask)
    for contour in contours:
        area = cv2.contourArea(contour)
        if area >= min_area:
            cv2.fillPoly(mask_clean, [contour], 255)
    # Fill holes
    mask_clean = cv2.morphologyEx(mask_clean, cv2.MORPH_CLOSE, kernel)
    return mask_clean
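
# Effect of the post-processing above, step by step: closing bridges small gaps
# inside the predicted wound region, opening removes isolated speckles, the
# contour-area filter drops connected components under min_area pixels
# (the call sites below use min_area=500), and a final closing fills holes.
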
def analyze_wound_severity_auto(image, depth_map, pixel_spacing_mm=0.5, segmentation_method='deep_learning'):
    """Analyze wound severity with automatic mask generation using the segmentation model"""
    if image is None or depth_map is None:
        return "❌ Please provide both image and depth map."
    # Generate an automatic wound mask using the model
    auto_mask = create_automatic_wound_mask(image, method=segmentation_method)
    if auto_mask is None:
        return "❌ Failed to generate automatic wound mask. Please check if the segmentation model is loaded."
    # Post-process the mask
    processed_mask = post_process_wound_mask(auto_mask, min_area=500)
    if processed_mask is None or np.sum(processed_mask > 0) == 0:
        return "❌ No wound region detected by the segmentation model. Try uploading a different image or use a manual mask."
    # Analyze severity using the automatic mask
    return analyze_wound_severity(image, depth_map, processed_mask, pixel_spacing_mm)
# --- Main Gradio Interface ---
with gr.Blocks(css=css, title="Wound Analysis & Depth Estimation") as demo:
    gr.HTML("<h1>Wound Analysis & Depth Estimation System</h1>")
    gr.Markdown("### Comprehensive wound analysis with classification and 3D depth mapping capabilities")

    # Shared image state
    shared_image = gr.State()

    with gr.Tabs():
        # Tab 1: Wound Classification
        with gr.Tab("1. Wound Classification"):
            gr.Markdown("### Step 1: Upload and classify your wound image")
            gr.Markdown("This module analyzes wound images and provides classification with AI-powered reasoning.")
            with gr.Row():
                with gr.Column(scale=1):
                    wound_image_input = gr.Image(label="Upload Wound Image", type="pil", height=350)
                with gr.Column(scale=1):
                    wound_prediction_box = gr.HTML()
                    wound_reasoning_box = gr.HTML()
            # Button to pass the image to depth estimation
            with gr.Row():
                pass_to_depth_btn = gr.Button("Pass Image to Depth Analysis", variant="secondary", size="lg")
                pass_status = gr.HTML("")
            wound_image_input.change(fn=classify_wound_image, inputs=wound_image_input,
                                     outputs=[wound_prediction_box, wound_reasoning_box])
            # Store the image when it is uploaded for classification
            wound_image_input.change(
                fn=lambda img: img,
                inputs=[wound_image_input],
                outputs=[shared_image]
            )
        # Tab 2: Depth Estimation
        with gr.Tab("2. Depth Estimation & 3D Visualization"):
            gr.Markdown("### Step 2: Generate depth maps and 3D visualizations")
            gr.Markdown("This module creates depth maps and 3D point clouds from your images.")
            with gr.Row():
                depth_input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
                depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output')
            with gr.Row():
                depth_submit = gr.Button(value="Compute Depth", variant="primary")
                load_shared_btn = gr.Button("Load Image from Classification", variant="secondary")
            points_slider = gr.Slider(minimum=1000, maximum=10000, value=10000, step=1000,
                                      label="Number of 3D points (upload image to update max)")
            with gr.Row():
                focal_length_x = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
                                           label="Focal Length X (pixels)")
                focal_length_y = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
                                           label="Focal Length Y (pixels)")
            with gr.Row():
                gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download")
                raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download")
                point_cloud_file = gr.File(label="Point Cloud (.ply)", elem_id="download")
            # 3D Visualization
            gr.Markdown("### 3D Point Cloud Visualization")
            gr.Markdown("Enhanced 3D visualization using proper camera projection. Hover over points to see 3D coordinates.")
            depth_3d_plot = gr.Plot(label="3D Point Cloud")
            # Store the depth map for severity analysis
            depth_map_state = gr.State()
        # Tab 3: Wound Severity Analysis
        with gr.Tab("3. Wound Severity Analysis"):
            gr.Markdown("### Step 3: Analyze wound severity using depth maps")
            gr.Markdown("This module analyzes wound severity based on depth distribution and area measurements.")
            with gr.Row():
                severity_input_image = gr.Image(label="Original Image", type='numpy')
                severity_depth_map = gr.Image(label="Depth Map (from Tab 2)", type='numpy')
            with gr.Row():
                wound_mask_input = gr.Image(label="Auto-Generated Wound Mask", type='numpy')
                severity_output = gr.HTML(label="Severity Analysis Report")
            gr.Markdown("**Note:** The deep learning segmentation model will automatically generate a wound mask when you upload an image or load a depth map.")
            with gr.Row():
                auto_severity_button = gr.Button("Analyze Severity with Auto-Generated Mask", variant="primary", size="lg")
                manual_severity_button = gr.Button("Manual Mask Analysis", variant="secondary", size="lg")
            pixel_spacing_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.5, step=0.1,
                                             label="Pixel Spacing (mm/pixel)")
            depth_calibration_slider = gr.Slider(minimum=5.0, maximum=30.0, value=15.0, step=1.0,
                                                 label="Depth Calibration (mm)",
                                                 info="Adjust based on expected maximum wound depth")
            gr.Markdown("**Pixel Spacing:** Adjust based on your camera calibration. The default is 0.5 mm/pixel.")
            gr.Markdown("**Depth Calibration:** Adjust the maximum expected wound depth to improve measurement accuracy. For shallow wounds use 5-10mm; for deep wounds use 15-30mm.")
            with gr.Row():
                # Load the depth map from the previous tab
                load_depth_btn = gr.Button("Load Depth Map from Tab 2", variant="secondary")
            gr.Markdown("**Note:** When you load a depth map or upload an image, the segmentation model will automatically generate a wound mask.")
    # Update the points slider when an image is uploaded
    depth_input_image.change(
        fn=update_slider_on_image_upload,
        inputs=[depth_input_image],
        outputs=[points_slider]
    )

    # Modified depth submit function that also stores the depth map
    def on_depth_submit_with_state(image, num_points, focal_x, focal_y):
        results = on_depth_submit(image, num_points, focal_x, focal_y)
        # Recompute the depth map for severity analysis
        depth_map = None
        if image is not None:
            depth = predict_depth(image[:, :, ::-1])  # RGB to BGR if needed
            # Normalize depth for severity analysis
            # (epsilon guards against division by zero on a constant depth map)
            norm_depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8) * 255.0
            depth_map = norm_depth.astype(np.uint8)
        return results + [depth_map]

    depth_submit.click(on_depth_submit_with_state,
                       inputs=[depth_input_image, points_slider, focal_length_x, focal_length_y],
                       outputs=[depth_image_slider, gray_depth_file, raw_file, point_cloud_file, depth_3d_plot, depth_map_state])
    # Load the depth map into the severity tab and auto-generate a mask
    def load_depth_to_severity(depth_map, original_image):
        if depth_map is None:
            return None, None, None, "❌ No depth map available. Please compute depth in Tab 2 first."
        # Auto-generate a wound mask using the segmentation model
        if original_image is not None:
            auto_mask, _ = segmentation_model.segment_wound(original_image)
            if auto_mask is not None:
                # Post-process the mask
                processed_mask = post_process_wound_mask(auto_mask, min_area=500)
                if processed_mask is not None and np.sum(processed_mask > 0) > 0:
                    return depth_map, original_image, processed_mask, "✅ Depth map loaded and wound mask auto-generated!"
                else:
                    return depth_map, original_image, None, "⚠️ Depth map loaded but no wound detected. Try uploading a different image."
            else:
                return depth_map, original_image, None, "⚠️ Depth map loaded but segmentation failed. Try uploading a different image."
        else:
            return depth_map, original_image, None, "✅ Depth map loaded successfully!"

    load_depth_btn.click(
        fn=load_depth_to_severity,
        inputs=[depth_map_state, depth_input_image],
        outputs=[severity_depth_map, severity_input_image, wound_mask_input, gr.HTML()]
    )
    # Automatic severity analysis
    def run_auto_severity_analysis(image, depth_map, pixel_spacing, depth_calibration):
        if depth_map is None:
            return "❌ Please load a depth map from Tab 2 first."
        # Generate an automatic wound mask using the model
        auto_mask = create_automatic_wound_mask(image, method='deep_learning')
        if auto_mask is None:
            return "❌ Failed to generate automatic wound mask. Please check if the segmentation model is loaded."
        # Post-process the mask with a fixed minimum area
        processed_mask = post_process_wound_mask(auto_mask, min_area=500)
        if processed_mask is None or np.sum(processed_mask > 0) == 0:
            return "❌ No wound region detected by the segmentation model. Try uploading a different image or use a manual mask."
        # Analyze severity using the automatic mask
        return analyze_wound_severity(image, depth_map, processed_mask, pixel_spacing, depth_calibration)

    # Manual severity analysis
    def run_manual_severity_analysis(image, depth_map, wound_mask, pixel_spacing, depth_calibration):
        if depth_map is None:
            return "❌ Please load a depth map from Tab 2 first."
        if wound_mask is None:
            return "❌ Please upload a wound mask (a binary image where white pixels represent the wound area)."
        return analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing, depth_calibration)

    # Connect event handlers
    auto_severity_button.click(
        fn=run_auto_severity_analysis,
        inputs=[severity_input_image, severity_depth_map, pixel_spacing_slider, depth_calibration_slider],
        outputs=[severity_output]
    )
    manual_severity_button.click(
        fn=run_manual_severity_analysis,
        inputs=[severity_input_image, severity_depth_map, wound_mask_input, pixel_spacing_slider, depth_calibration_slider],
        outputs=[severity_output]
    )
    # Auto-generate a mask when an image is uploaded
    def auto_generate_mask_on_image_upload(image):
        if image is None:
            return None, "❌ No image uploaded."
        # Generate an automatic wound mask using the segmentation model
        auto_mask, _ = segmentation_model.segment_wound(image)
        if auto_mask is not None:
            # Post-process the mask
            processed_mask = post_process_wound_mask(auto_mask, min_area=500)
            if processed_mask is not None and np.sum(processed_mask > 0) > 0:
                return processed_mask, "✅ Wound mask auto-generated using the deep learning model!"
            else:
                return None, "⚠️ Image uploaded but no wound detected. Try uploading a different image."
        else:
            return None, "⚠️ Image uploaded but segmentation failed. Try uploading a different image."
    # Load the shared image from the classification tab
    def load_shared_image(shared_img):
        if shared_img is None:
            return gr.Image(), "❌ No image available from the classification tab."
        # Convert a PIL image to a numpy array for depth estimation
        if hasattr(shared_img, 'convert'):
            # It's a PIL image; convert to numpy
            img_array = np.array(shared_img)
            return img_array, "✅ Image loaded from the classification tab."
        else:
            # Already a numpy array
            return shared_img, "✅ Image loaded from the classification tab."

    # Auto-generate a mask when an image is uploaded to the severity tab
    severity_input_image.change(
        fn=auto_generate_mask_on_image_upload,
        inputs=[severity_input_image],
        outputs=[wound_mask_input, gr.HTML()]
    )
    load_shared_btn.click(
        fn=load_shared_image,
        inputs=[shared_image],
        outputs=[depth_input_image, gr.HTML()]
    )
    # Report whether the shared image is ready for the depth tab
    # (the image itself is carried by the shared_image state)
    def pass_image_to_depth(img):
        if img is None:
            return "❌ No image uploaded in the classification tab."
        return "✅ Image ready for depth analysis! Switch to Tab 2 and click 'Load Image from Classification'."

    pass_to_depth_btn.click(
        fn=pass_image_to_depth,
        inputs=[shared_image],
        outputs=[pass_status]
    )
if __name__ == '__main__':
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )