import glob
import gradio as gr
import matplotlib
import numpy as np
from PIL import Image
import torch
import tempfile
from gradio_imageslider import ImageSlider
import plotly.graph_objects as go
import plotly.express as px
import open3d as o3d
from depth_anything_v2.dpt import DepthAnythingV2
import os
import cv2
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image as keras_image
import base64
from io import BytesIO
import gdown
import spaces
# Define path and file ID
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
model_file = os.path.join(checkpoint_dir, "depth_anything_v2_vitl.pth")
gdrive_url = "https://drive.google.com/uc?id=141Mhq2jonkUBcVBnNqNSeyIZYtH5l4K5"
# Download if not already present
if not os.path.exists(model_file):
print("Downloading model from Google Drive...")
gdown.download(gdrive_url, model_file, quiet=False)
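# Optional sanity check (not strictly required): if the Drive download failed
# (quota limits, network errors), fail fast with a clear message instead of a
# confusing torch.load error further down.
if not os.path.exists(model_file):
    raise FileNotFoundError(
        f"Depth checkpoint not found at '{model_file}'. "
        f"Download depth_anything_v2_vitl.pth into '{checkpoint_dir}' manually."
    )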
# --- TensorFlow: Check GPU Availability ---
gpus = tf.config.list_physical_devices('GPU')
if gpus:
print("TensorFlow is using GPU")
else:
print("TensorFlow is using CPU")
# --- Load Wound Classification Model and Class Labels ---
wound_model = load_model("/home/user/app/keras_model.h5")
with open("/home/user/app/labels.txt", "r") as f:
class_labels = [line.strip().split(maxsplit=1)[1] for line in f]
# --- PyTorch: Set Device and Load Depth Model ---
map_device = torch.device("cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu")
print(f"Using PyTorch device: {map_device}")
model_configs = {
'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
}
encoder = 'vitl'
depth_model = DepthAnythingV2(**model_configs[encoder])
state_dict = torch.load(model_file, map_location=map_device)  # same file the Drive download above targets
depth_model.load_state_dict(state_dict)
depth_model = depth_model.to(map_device).eval()
# --- Custom CSS for unified dark theme ---
css = """
.gradio-container {
font-family: 'Segoe UI', sans-serif;
background-color: #121212;
color: #ffffff;
padding: 20px;
}
.gr-button {
background-color: #2c3e50;
color: white;
border-radius: 10px;
}
.gr-button:hover {
background-color: #34495e;
}
.gr-html, .gr-html div {
white-space: normal !important;
overflow: visible !important;
text-overflow: unset !important;
word-break: break-word !important;
}
#img-display-container {
max-height: 100vh;
}
#img-display-input {
max-height: 80vh;
}
#img-display-output {
max-height: 80vh;
}
#download {
height: 62px;
}
h1 {
text-align: center;
font-size: 3rem;
font-weight: bold;
margin: 2rem 0;
color: #ffffff;
}
h2 {
color: #ffffff;
text-align: center;
margin: 1rem 0;
}
.gr-tabs {
background-color: #1e1e1e;
border-radius: 10px;
padding: 10px;
}
.gr-tab-nav {
background-color: #2c3e50;
border-radius: 8px;
}
.gr-tab-nav button {
color: #ffffff !important;
}
.gr-tab-nav button.selected {
background-color: #34495e !important;
}
"""
# --- Wound Classification Functions ---
def preprocess_input(img):
img = img.resize((224, 224))
arr = keras_image.img_to_array(img)
arr = arr / 255.0
return np.expand_dims(arr, axis=0)
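# Shape check (illustrative): any RGB PIL image, e.g. 640x480, comes back as a
# (1, 224, 224, 3) float array with values in [0, 1], which is what
# wound_model.predict() receives below.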
def get_reasoning_from_gemini(img, prediction):
try:
# For now, return a simple explanation without Gemini API to avoid typing issues
# In production, you would implement the proper Gemini API call here
explanations = {
"Abrasion": "This appears to be an abrasion wound, characterized by superficial damage to the skin surface. The wound shows typical signs of friction or scraping injury.",
"Burn": "This wound exhibits characteristics consistent with a burn injury, showing tissue damage from heat, chemicals, or radiation exposure.",
"Laceration": "This wound displays the irregular edges and tissue tearing typical of a laceration, likely caused by blunt force trauma.",
"Puncture": "This wound shows a small, deep entry point characteristic of puncture wounds, often caused by sharp, pointed objects.",
"Ulcer": "This wound exhibits the characteristics of an ulcer, showing tissue breakdown and potential underlying vascular or pressure issues."
}
return explanations.get(prediction, f"This wound has been classified as {prediction}. Please consult with a healthcare professional for detailed assessment.")
except Exception as e:
return f"(Reasoning unavailable: {str(e)})"
@spaces.GPU
def classify_wound_image(img):
if img is None:
return "<div style='color:#ff5252; font-size:18px;'>No image provided</div>", ""
img_array = preprocess_input(img)
predictions = wound_model.predict(img_array, verbose=0)[0]
pred_idx = int(np.argmax(predictions))
pred_class = class_labels[pred_idx]
# Get reasoning from Gemini
reasoning_text = get_reasoning_from_gemini(img, pred_class)
# Prediction Card
predicted_card = f"""
<div style='padding: 20px; background-color: #1e1e1e; border-radius: 12px;
box-shadow: 0 0 10px rgba(0,0,0,0.5);'>
<div style='font-size: 22px; font-weight: bold; color: orange; margin-bottom: 10px;'>
Predicted Wound Type
</div>
<div style='font-size: 26px; color: white;'>
{pred_class}
</div>
</div>
"""
# Reasoning Card
reasoning_card = f"""
<div style='padding: 20px; background-color: #1e1e1e; border-radius: 12px;
box-shadow: 0 0 10px rgba(0,0,0,0.5);'>
<div style='font-size: 22px; font-weight: bold; color: orange; margin-bottom: 10px;'>
Reasoning
</div>
<div style='font-size: 16px; color: white; min-height: 80px;'>
{reasoning_text}
</div>
</div>
"""
return predicted_card, reasoning_card
# --- Wound Severity Estimation Functions ---
@spaces.GPU
def compute_depth_area_statistics(depth_map, mask, pixel_spacing_mm=0.5):
"""Compute area statistics for different depth regions"""
pixel_area_cm2 = (pixel_spacing_mm / 10.0) ** 2
# Extract only wound region
wound_mask = (mask > 127)
wound_depths = depth_map[wound_mask]
total_area = np.sum(wound_mask) * pixel_area_cm2
# Categorize depth regions
shallow = wound_depths < 3
moderate = (wound_depths >= 3) & (wound_depths < 6)
deep = wound_depths >= 6
shallow_area = np.sum(shallow) * pixel_area_cm2
moderate_area = np.sum(moderate) * pixel_area_cm2
deep_area = np.sum(deep) * pixel_area_cm2
deep_ratio = deep_area / total_area if total_area > 0 else 0
return {
'total_area_cm2': total_area,
'shallow_area_cm2': shallow_area,
'moderate_area_cm2': moderate_area,
'deep_area_cm2': deep_area,
'deep_ratio': deep_ratio,
'max_depth': np.max(wound_depths) if len(wound_depths) > 0 else 0
}
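# Worked example with assumed numbers: at pixel_spacing_mm=0.5 each pixel covers
# (0.5 / 10)^2 = 0.0025 cm^2, so a 10,000-pixel wound mask yields
# total_area_cm2 = 25.0; if 2,000 of those pixels fall in the deep band
# (depth values >= 6), deep_area_cm2 = 5.0 and deep_ratio = 0.2.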
def classify_wound_severity_by_area(depth_stats):
"""Classify wound severity based on area and depth distribution"""
total = depth_stats['total_area_cm2']
deep = depth_stats['deep_area_cm2']
moderate = depth_stats['moderate_area_cm2']
if total == 0:
return "Unknown"
# Severity classification rules
if deep > 2 or (deep / total) > 0.3:
return "Severe"
elif moderate > 1.5 or (moderate / total) > 0.4:
return "Moderate"
else:
return "Mild"
def analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing_mm=0.5):
"""Analyze wound severity from depth map and wound mask"""
if image is None or depth_map is None or wound_mask is None:
return "❌ Please upload image, depth map, and wound mask."
# Convert wound mask to grayscale if needed
if len(wound_mask.shape) == 3:
wound_mask = np.mean(wound_mask, axis=2)
# Ensure depth map and mask have same dimensions
if depth_map.shape[:2] != wound_mask.shape[:2]:
# Resize mask to match depth map
mask_pil = Image.fromarray(wound_mask.astype(np.uint8))
mask_pil = mask_pil.resize((depth_map.shape[1], depth_map.shape[0]))
wound_mask = np.array(mask_pil)
# Compute statistics
stats = compute_depth_area_statistics(depth_map, wound_mask, pixel_spacing_mm)
severity = classify_wound_severity_by_area(stats)
# Create severity report with color coding
severity_color = {
"Mild": "#4CAF50", # Green
"Moderate": "#FF9800", # Orange
"Severe": "#F44336" # Red
}.get(severity, "#9E9E9E") # Gray for unknown
report = f"""
<div style='padding: 20px; background-color: #1e1e1e; border-radius: 12px; box-shadow: 0 0 10px rgba(0,0,0,0.5);'>
<div style='font-size: 24px; font-weight: bold; color: {severity_color}; margin-bottom: 15px;'>
🩹 Wound Severity Analysis
</div>
<div style='display: grid; grid-template-columns: 1fr 1fr; gap: 15px; margin-bottom: 20px;'>
<div style='background-color: #2c2c2c; padding: 15px; border-radius: 8px;'>
<div style='font-size: 18px; font-weight: bold; color: #ffffff; margin-bottom: 10px;'>
📏 Area Measurements
</div>
<div style='color: #cccccc; line-height: 1.6;'>
<div>🟒 <b>Total Area:</b> {stats['total_area_cm2']:.2f} cm²</div>
<div>🟩 <b>Shallow (0-3mm):</b> {stats['shallow_area_cm2']:.2f} cm²</div>
<div>🟨 <b>Moderate (3-6mm):</b> {stats['moderate_area_cm2']:.2f} cm²</div>
<div>🟥 <b>Deep (>6mm):</b> {stats['deep_area_cm2']:.2f} cm²</div>
</div>
</div>
<div style='background-color: #2c2c2c; padding: 15px; border-radius: 8px;'>
<div style='font-size: 18px; font-weight: bold; color: #ffffff; margin-bottom: 10px;'>
📊 Depth Analysis
</div>
<div style='color: #cccccc; line-height: 1.6;'>
<div>🔥 <b>Deep Coverage:</b> {stats['deep_ratio']*100:.1f}%</div>
<div>📏 <b>Max Depth:</b> {stats['max_depth']:.1f} mm</div>
<div>⚡ <b>Pixel Spacing:</b> {pixel_spacing_mm} mm</div>
</div>
</div>
</div>
<div style='text-align: center; padding: 15px; background-color: #2c2c2c; border-radius: 8px; border-left: 4px solid {severity_color};'>
<div style='font-size: 20px; font-weight: bold; color: {severity_color};'>
🎯 Predicted Severity: {severity}
</div>
<div style='font-size: 14px; color: #cccccc; margin-top: 5px;'>
{get_severity_description(severity)}
</div>
</div>
</div>
"""
return report
def get_severity_description(severity):
"""Get description for severity level"""
descriptions = {
"Mild": "Superficial wound with minimal tissue damage. Usually heals well with basic care.",
"Moderate": "Moderate tissue involvement requiring careful monitoring and proper treatment.",
"Severe": "Deep tissue damage requiring immediate medical attention and specialized care.",
"Unknown": "Unable to determine severity due to insufficient data."
}
return descriptions.get(severity, "Severity assessment unavailable.")
def create_sample_wound_mask(image_shape, center=None, radius=50):
"""Create a sample circular wound mask for testing"""
if center is None:
center = (image_shape[1] // 2, image_shape[0] // 2)
mask = np.zeros(image_shape[:2], dtype=np.uint8)
y, x = np.ogrid[:image_shape[0], :image_shape[1]]
# Create circular mask
dist_from_center = np.sqrt((x - center[0])**2 + (y - center[1])**2)
mask[dist_from_center <= radius] = 255
return mask
def create_realistic_wound_mask(image_shape, method='elliptical'):
"""Create a more realistic wound mask with irregular shapes"""
h, w = image_shape[:2]
mask = np.zeros((h, w), dtype=np.uint8)
if method == 'elliptical':
# Create elliptical wound mask
center = (w // 2, h // 2)
radius_x = min(w, h) // 3
radius_y = min(w, h) // 4
y, x = np.ogrid[:h, :w]
# Add some irregularity to make it more realistic
ellipse = ((x - center[0])**2 / (radius_x**2) +
(y - center[1])**2 / (radius_y**2)) <= 1
# Add some noise and irregularity
noise = np.random.random((h, w)) > 0.8
mask = (ellipse | noise).astype(np.uint8) * 255
elif method == 'irregular':
# Create irregular wound mask
center = (w // 2, h // 2)
radius = min(w, h) // 4
y, x = np.ogrid[:h, :w]
base_circle = np.sqrt((x - center[0])**2 + (y - center[1])**2) <= radius
# Add irregular extensions
extensions = np.zeros_like(base_circle)
for i in range(3):
angle = i * 2 * np.pi / 3
ext_x = int(center[0] + radius * 0.8 * np.cos(angle))
ext_y = int(center[1] + radius * 0.8 * np.sin(angle))
ext_radius = radius // 3
ext_circle = np.sqrt((x - ext_x)**2 + (y - ext_y)**2) <= ext_radius
extensions = extensions | ext_circle
mask = (base_circle | extensions).astype(np.uint8) * 255
# Apply morphological operations to smooth the mask
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
return mask
# --- Depth Estimation Functions ---
@spaces.GPU
def predict_depth(image):
return depth_model.infer_image(image)
def calculate_max_points(image):
"""Calculate maximum points based on image dimensions (3x pixel count)"""
if image is None:
return 10000 # Default value
h, w = image.shape[:2]
max_points = h * w * 3
# Ensure minimum and reasonable maximum values
return max(1000, min(max_points, 300000))
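# Example (illustrative): a 640x480 image has 307,200 pixels, so 3x that is
# 921,600 and the clamp above returns 300,000; a tiny 10x10 crop would instead
# be raised to the 1,000-point floor.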
def update_slider_on_image_upload(image):
"""Update the points slider when an image is uploaded"""
max_points = calculate_max_points(image)
default_value = min(10000, max_points // 10) # 10% of max points as default
return gr.Slider(minimum=1000, maximum=max_points, value=default_value, step=1000,
label=f"Number of 3D points (max: {max_points:,})")
@spaces.GPU
def create_point_cloud(image, depth_map, focal_length_x=470.4, focal_length_y=470.4, max_points=30000):
"""Create a point cloud from depth map using camera intrinsics with high detail"""
h, w = depth_map.shape
# Use smaller step for higher detail (reduced downsampling)
step = max(1, int(np.sqrt(h * w / max_points) * 0.5)) # Reduce step size for more detail
# Create mesh grid for camera coordinates
y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]
# Convert to camera coordinates (normalized by focal length)
x_cam = (x_coords - w / 2) / focal_length_x
y_cam = (y_coords - h / 2) / focal_length_y
# Get depth values
depth_values = depth_map[::step, ::step]
# Calculate 3D points: (x_cam * depth, y_cam * depth, depth)
x_3d = x_cam * depth_values
y_3d = y_cam * depth_values
z_3d = depth_values
# Flatten arrays
points = np.stack([x_3d.flatten(), y_3d.flatten(), z_3d.flatten()], axis=1)
# Get corresponding image colors
image_colors = image[::step, ::step, :]
colors = image_colors.reshape(-1, 3) / 255.0
# Create Open3D point cloud
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(points)
pcd.colors = o3d.utility.Vector3dVector(colors)
return pcd
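# Note on the geometry above: this is the standard pinhole back-projection
# X = (u - cx) * Z / fx, Y = (v - cy) * Z / fy, with the principal point
# (cx, cy) assumed to sit at the image centre and fx, fy taken from the
# focal-length sliders. The 470.4-pixel default is only a nominal value, so the
# coordinates are in relative units unless the intrinsics match the real camera.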
@spaces.GPU
def reconstruct_surface_mesh_from_point_cloud(pcd):
"""Convert point cloud to a mesh using Poisson reconstruction with very high detail."""
# Estimate and orient normals with high precision
pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.005, max_nn=50))
pcd.orient_normals_consistent_tangent_plane(k=50)
# Create surface mesh with maximum detail (depth=12 for very high resolution)
mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=12)
# Return mesh without filtering low-density vertices
return mesh
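# Note: the "depth" argument to create_from_point_cloud_poisson is the octree
# depth of the reconstruction (higher values give a finer surface at the cost
# of memory and time), not a physical wound depth.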
@spaces.GPU
def create_enhanced_3d_visualization(image, depth_map, max_points=10000):
"""Create an enhanced 3D visualization using proper camera projection"""
h, w = depth_map.shape
# Downsample to avoid too many points for performance
step = max(1, int(np.sqrt(h * w / max_points)))
# Create mesh grid for camera coordinates
y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]
# Convert to camera coordinates (normalized by focal length)
focal_length = 470.4 # Default focal length
x_cam = (x_coords - w / 2) / focal_length
y_cam = (y_coords - h / 2) / focal_length
# Get depth values
depth_values = depth_map[::step, ::step]
# Calculate 3D points: (x_cam * depth, y_cam * depth, depth)
x_3d = x_cam * depth_values
y_3d = y_cam * depth_values
z_3d = depth_values
# Flatten arrays
x_flat = x_3d.flatten()
y_flat = y_3d.flatten()
z_flat = z_3d.flatten()
# Get corresponding image colors
image_colors = image[::step, ::step, :]
colors_flat = image_colors.reshape(-1, 3)
# Create 3D scatter plot with proper camera projection
fig = go.Figure(data=[go.Scatter3d(
x=x_flat,
y=y_flat,
z=z_flat,
mode='markers',
marker=dict(
size=1.5,
color=colors_flat,
opacity=0.9
),
hovertemplate='<b>3D Position:</b> (%{x:.3f}, %{y:.3f}, %{z:.3f})<br>' +
'<b>Depth:</b> %{z:.2f}<br>' +
'<extra></extra>'
)])
fig.update_layout(
title="3D Point Cloud Visualization (Camera Projection)",
scene=dict(
xaxis_title="X (meters)",
yaxis_title="Y (meters)",
zaxis_title="Z (meters)",
camera=dict(
eye=dict(x=2.0, y=2.0, z=2.0),
center=dict(x=0, y=0, z=0),
up=dict(x=0, y=0, z=1)
),
aspectmode='data'
),
width=700,
height=600
)
return fig
def on_depth_submit(image, num_points, focal_x, focal_y):
original_image = image.copy()
h, w = image.shape[:2]
# Predict depth using the model
depth = predict_depth(image[:, :, ::-1])  # Depth Anything V2's infer_image expects an OpenCV-style BGR array, so reverse Gradio's RGB channel order
# Save raw 16-bit depth
raw_depth = Image.fromarray(depth.astype('uint16'))
tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
raw_depth.save(tmp_raw_depth.name)
# Normalize and convert to grayscale for display
norm_depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
norm_depth = norm_depth.astype(np.uint8)
colored_depth = (matplotlib.colormaps.get_cmap('Spectral_r')(norm_depth)[:, :, :3] * 255).astype(np.uint8)
gray_depth = Image.fromarray(norm_depth)
tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
gray_depth.save(tmp_gray_depth.name)
# Create point cloud
pcd = create_point_cloud(original_image, norm_depth, focal_x, focal_y, max_points=num_points)
# Reconstruct mesh from point cloud
mesh = reconstruct_surface_mesh_from_point_cloud(pcd)
# Save mesh with faces as .ply
tmp_pointcloud = tempfile.NamedTemporaryFile(suffix='.ply', delete=False)
o3d.io.write_triangle_mesh(tmp_pointcloud.name, mesh)
# Create enhanced 3D scatter plot visualization
depth_3d = create_enhanced_3d_visualization(original_image, norm_depth, max_points=num_points)
return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name, tmp_pointcloud.name, depth_3d]
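# The list returned above is ordered to match the Gradio outputs wired up below:
# the (input, colored depth) pair for the ImageSlider, the grayscale and 16-bit
# raw depth PNGs for download, the reconstructed .ply mesh, and the Plotly 3D
# figure; on_depth_submit_with_state later appends the depth map used by the
# severity tab.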
# --- Automatic Wound Mask Generation Functions ---
def create_automatic_wound_mask(image, method='adaptive'):
"""
Automatically generate wound mask from image using various segmentation methods
Args:
image: Input image (numpy array)
method: Segmentation method ('adaptive', 'otsu', 'color', 'combined')
Returns:
mask: Binary wound mask
"""
if image is None:
return None
# Convert to grayscale if needed
if len(image.shape) == 3:
gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
else:
gray = image.copy()
# Apply different segmentation methods
if method == 'adaptive':
mask = adaptive_threshold_segmentation(gray)
elif method == 'otsu':
mask = otsu_threshold_segmentation(gray)
elif method == 'color':
mask = color_based_segmentation(image)
elif method == 'combined':
mask = combined_segmentation(image, gray)
else:
mask = adaptive_threshold_segmentation(gray)
return mask
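# Typical call sequence (illustrative): create_automatic_wound_mask(rgb, method='combined')
# followed by post_process_wound_mask(mask, min_area=500), which is what
# analyze_wound_severity_auto() and the "Auto-Analyze Severity" button below
# wire together.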
def adaptive_threshold_segmentation(gray):
"""Use adaptive thresholding for wound segmentation"""
# Apply Gaussian blur to reduce noise
blurred = cv2.GaussianBlur(gray, (15, 15), 0)
# Adaptive thresholding with larger block size
thresh = cv2.adaptiveThreshold(
blurred, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 25, 5
)
# Morphological operations to clean up the mask
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
mask = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel)
mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
# Find contours and keep only the largest ones
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# Create a new mask with only large contours
mask_clean = np.zeros_like(mask)
for contour in contours:
area = cv2.contourArea(contour)
if area > 1000: # Minimum area threshold
cv2.fillPoly(mask_clean, [contour], 255)
return mask_clean
def otsu_threshold_segmentation(gray):
"""Use Otsu's thresholding for wound segmentation"""
# Apply Gaussian blur
blurred = cv2.GaussianBlur(gray, (15, 15), 0)
# Otsu's thresholding
_, thresh = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
# Morphological operations
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
mask = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel)
mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
# Find contours and keep only the largest ones
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# Create a new mask with only large contours
mask_clean = np.zeros_like(mask)
for contour in contours:
area = cv2.contourArea(contour)
if area > 800: # Minimum area threshold
cv2.fillPoly(mask_clean, [contour], 255)
return mask_clean
def color_based_segmentation(image):
"""Use color-based segmentation for wound detection"""
# Convert to different color spaces
hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
# Create masks for different color ranges (wound-like colors)
# Reddish/brownish wound colors in HSV - broader ranges
lower_red1 = np.array([0, 30, 30])
upper_red1 = np.array([15, 255, 255])
lower_red2 = np.array([160, 30, 30])
upper_red2 = np.array([180, 255, 255])
mask1 = cv2.inRange(hsv, lower_red1, upper_red1)
mask2 = cv2.inRange(hsv, lower_red2, upper_red2)
red_mask = cv2.bitwise_or(mask1, mask2)
# Yellowish wound colors - broader range
lower_yellow = np.array([15, 30, 30])
upper_yellow = np.array([35, 255, 255])
yellow_mask = cv2.inRange(hsv, lower_yellow, upper_yellow)
# Brownish wound colors
lower_brown = np.array([10, 50, 20])
upper_brown = np.array([20, 255, 200])
brown_mask = cv2.inRange(hsv, lower_brown, upper_brown)
# Combine color masks
color_mask = cv2.bitwise_or(cv2.bitwise_or(red_mask, yellow_mask), brown_mask)  # OR instead of + avoids uint8 overflow where the hue ranges overlap
# Clean up the mask with larger kernels
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
color_mask = cv2.morphologyEx(color_mask, cv2.MORPH_CLOSE, kernel)
color_mask = cv2.morphologyEx(color_mask, cv2.MORPH_OPEN, kernel)
# Find contours and keep only the largest ones
contours, _ = cv2.findContours(color_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# Create a new mask with only large contours
mask_clean = np.zeros_like(color_mask)
for contour in contours:
area = cv2.contourArea(contour)
if area > 600: # Minimum area threshold
cv2.fillPoly(mask_clean, [contour], 255)
return mask_clean
def combined_segmentation(image, gray):
"""Combine multiple segmentation methods for better results"""
# Get masks from different methods
adaptive_mask = adaptive_threshold_segmentation(gray)
otsu_mask = otsu_threshold_segmentation(gray)
color_mask = color_based_segmentation(image)
# Combine masks (union)
combined_mask = cv2.bitwise_or(adaptive_mask, otsu_mask)
combined_mask = cv2.bitwise_or(combined_mask, color_mask)
# Apply additional morphological operations to clean up
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (20, 20))
combined_mask = cv2.morphologyEx(combined_mask, cv2.MORPH_CLOSE, kernel)
# Find contours and keep only the largest ones
contours, _ = cv2.findContours(combined_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# Create a new mask with only large contours
mask_clean = np.zeros_like(combined_mask)
for contour in contours:
area = cv2.contourArea(contour)
if area > 500: # Minimum area threshold
cv2.fillPoly(mask_clean, [contour], 255)
# If no large contours found, create a realistic wound mask
if np.sum(mask_clean) == 0:
mask_clean = create_realistic_wound_mask(combined_mask.shape, method='elliptical')
return mask_clean
def post_process_wound_mask(mask, min_area=100):
"""Post-process the wound mask to remove noise and small objects"""
if mask is None:
return None
# Convert to binary if needed
if mask.dtype != np.uint8:
mask = mask.astype(np.uint8)
# Apply morphological operations to clean up
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
# Remove small objects using OpenCV
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
mask_clean = np.zeros_like(mask)
for contour in contours:
area = cv2.contourArea(contour)
if area >= min_area:
cv2.fillPoly(mask_clean, [contour], 255)
# Fill holes
mask_clean = cv2.morphologyEx(mask_clean, cv2.MORPH_CLOSE, kernel)
return mask_clean
def analyze_wound_severity_auto(image, depth_map, pixel_spacing_mm=0.5, segmentation_method='combined'):
"""Analyze wound severity with automatic mask generation"""
if image is None or depth_map is None:
return "❌ Please provide both image and depth map."
# Generate automatic wound mask
auto_mask = create_automatic_wound_mask(image, method=segmentation_method)
if auto_mask is None:
return "❌ Failed to generate automatic wound mask."
# Post-process the mask
processed_mask = post_process_wound_mask(auto_mask, min_area=500)
if processed_mask is None or np.sum(processed_mask > 0) == 0:
return "❌ No wound region detected. Try adjusting segmentation parameters or upload a manual mask."
# Analyze severity using the automatic mask
return analyze_wound_severity(image, depth_map, processed_mask, pixel_spacing_mm)
# --- Main Gradio Interface ---
with gr.Blocks(css=css, title="Wound Analysis & Depth Estimation") as demo:
gr.HTML("<h1>Wound Analysis & Depth Estimation System</h1>")
gr.Markdown("### Comprehensive wound analysis with classification and 3D depth mapping capabilities")
# Shared image state
shared_image = gr.State()
with gr.Tabs():
# Tab 1: Wound Classification
with gr.Tab("1. Wound Classification"):
gr.Markdown("### Step 1: Upload and classify your wound image")
gr.Markdown("This module analyzes wound images and provides classification with AI-powered reasoning.")
with gr.Row():
with gr.Column(scale=1):
wound_image_input = gr.Image(label="Upload Wound Image", type="pil", height=350)
with gr.Column(scale=1):
wound_prediction_box = gr.HTML()
wound_reasoning_box = gr.HTML()
# Button to pass image to depth estimation
with gr.Row():
pass_to_depth_btn = gr.Button("📊 Pass Image to Depth Analysis", variant="secondary", size="lg")
pass_status = gr.HTML("")
wound_image_input.change(fn=classify_wound_image, inputs=wound_image_input,
outputs=[wound_prediction_box, wound_reasoning_box])
# Store image when uploaded for classification
wound_image_input.change(
fn=lambda img: img,
inputs=[wound_image_input],
outputs=[shared_image]
)
# Tab 2: Depth Estimation
with gr.Tab("2. Depth Estimation & 3D Visualization"):
gr.Markdown("### Step 2: Generate depth maps and 3D visualizations")
gr.Markdown("This module creates depth maps and 3D point clouds from your images.")
with gr.Row():
depth_input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output')
with gr.Row():
depth_submit = gr.Button(value="Compute Depth", variant="primary")
load_shared_btn = gr.Button("🔄 Load Image from Classification", variant="secondary")
points_slider = gr.Slider(minimum=1000, maximum=10000, value=10000, step=1000,
label="Number of 3D points (upload image to update max)")
with gr.Row():
focal_length_x = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
label="Focal Length X (pixels)")
focal_length_y = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
label="Focal Length Y (pixels)")
with gr.Row():
gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download")
raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download")
point_cloud_file = gr.File(label="Point Cloud (.ply)", elem_id="download")
# 3D Visualization
gr.Markdown("### 3D Point Cloud Visualization")
gr.Markdown("Enhanced 3D visualization using proper camera projection. Hover over points to see 3D coordinates.")
depth_3d_plot = gr.Plot(label="3D Point Cloud")
# Store depth map for severity analysis
depth_map_state = gr.State()
# Tab 3: Wound Severity Analysis
with gr.Tab("3. 🩹 Wound Severity Analysis"):
gr.Markdown("### Step 3: Analyze wound severity using depth maps")
gr.Markdown("This module analyzes wound severity based on depth distribution and area measurements.")
with gr.Row():
severity_input_image = gr.Image(label="Original Image", type='numpy')
severity_depth_map = gr.Image(label="Depth Map (from Tab 2)", type='numpy')
with gr.Row():
wound_mask_input = gr.Image(label="Wound Mask (Optional)", type='numpy')
severity_output = gr.HTML(label="Severity Analysis Report")
gr.Markdown("**Note:** You can either upload a manual mask or use automatic mask generation.")
with gr.Row():
auto_severity_button = gr.Button("🤖 Auto-Analyze Severity", variant="primary", size="lg")
manual_severity_button = gr.Button("🔍 Manual Mask Analysis", variant="secondary", size="lg")
pixel_spacing_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.5, step=0.1,
label="Pixel Spacing (mm/pixel)")
gr.Markdown("**Pixel Spacing:** Adjust based on your camera calibration. Default is 0.5 mm/pixel.")
with gr.Row():
segmentation_method = gr.Dropdown(
choices=["combined", "adaptive", "otsu", "color"],
value="combined",
label="Segmentation Method",
info="Choose automatic segmentation method"
)
min_area_slider = gr.Slider(minimum=100, maximum=2000, value=500, step=100,
label="Minimum Area (pixels)",
info="Minimum wound area to detect")
with gr.Row():
# Load depth map from previous tab
load_depth_btn = gr.Button("🔄 Load Depth Map from Tab 2", variant="secondary")
sample_mask_btn = gr.Button("🎯 Generate Sample Mask", variant="secondary")
realistic_mask_btn = gr.Button("🏥 Generate Realistic Mask", variant="secondary")
preview_mask_btn = gr.Button("👁️ Preview Auto Mask", variant="secondary")
gr.Markdown("**Options:** Load depth map, generate sample mask, or preview automatic segmentation.")
# Generate sample mask function
def generate_sample_mask(image):
if image is None:
return None, "❌ Please load an image first."
sample_mask = create_sample_wound_mask(image.shape)
return sample_mask, "✅ Sample circular wound mask generated!"
# Generate realistic mask function
def generate_realistic_mask(image):
if image is None:
return None, "❌ Please load an image first."
realistic_mask = create_realistic_wound_mask(image.shape, method='elliptical')
return realistic_mask, "✅ Realistic elliptical wound mask generated!"
sample_mask_btn.click(
fn=generate_sample_mask,
inputs=[severity_input_image],
outputs=[wound_mask_input, gr.HTML()]
)
realistic_mask_btn.click(
fn=generate_realistic_mask,
inputs=[severity_input_image],
outputs=[wound_mask_input, gr.HTML()]
)
# Update slider when image is uploaded
depth_input_image.change(
fn=update_slider_on_image_upload,
inputs=[depth_input_image],
outputs=[points_slider]
)
# Modified depth submit function to store depth map
def on_depth_submit_with_state(image, num_points, focal_x, focal_y):
results = on_depth_submit(image, num_points, focal_x, focal_y)
# Extract depth map from results for severity analysis
depth_map = None
if image is not None:
depth = predict_depth(image[:, :, ::-1])  # infer_image expects BGR, so reverse the RGB channel order
# Normalize depth for severity analysis
norm_depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
depth_map = norm_depth.astype(np.uint8)
return results + [depth_map]
depth_submit.click(on_depth_submit_with_state,
inputs=[depth_input_image, points_slider, focal_length_x, focal_length_y],
outputs=[depth_image_slider, gray_depth_file, raw_file, point_cloud_file, depth_3d_plot, depth_map_state])
# Load depth map to severity tab
def load_depth_to_severity(depth_map, original_image):
if depth_map is None:
return None, None, "❌ No depth map available. Please compute depth in Tab 2 first."
return depth_map, original_image, "✅ Depth map loaded successfully!"
load_depth_btn.click(
fn=load_depth_to_severity,
inputs=[depth_map_state, depth_input_image],
outputs=[severity_depth_map, severity_input_image, gr.HTML()]
)
# Automatic severity analysis function
def run_auto_severity_analysis(image, depth_map, pixel_spacing, seg_method, min_area):
if depth_map is None:
return "❌ Please load depth map from Tab 2 first."
# Update post-processing with user-defined minimum area
def post_process_with_area(mask):
return post_process_wound_mask(mask, min_area=min_area)
# Generate automatic wound mask
auto_mask = create_automatic_wound_mask(image, method=seg_method)
if auto_mask is None:
return "❌ Failed to generate automatic wound mask."
# Post-process the mask
processed_mask = post_process_with_area(auto_mask)
if processed_mask is None or np.sum(processed_mask > 0) == 0:
return "❌ No wound region detected. Try adjusting segmentation parameters or use manual mask."
# Analyze severity using the automatic mask
return analyze_wound_severity(image, depth_map, processed_mask, pixel_spacing)
# Manual severity analysis function
def run_manual_severity_analysis(image, depth_map, wound_mask, pixel_spacing):
if depth_map is None:
return "❌ Please load depth map from Tab 2 first."
if wound_mask is None:
return "❌ Please upload a wound mask (binary image where white pixels represent the wound area)."
return analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing)
# Preview automatic mask function
def preview_auto_mask(image, seg_method, min_area):
if image is None:
return None, "❌ Please load an image first."
# Generate automatic wound mask
auto_mask = create_automatic_wound_mask(image, method=seg_method)
if auto_mask is None:
return None, "❌ Failed to generate automatic wound mask."
# Post-process the mask
processed_mask = post_process_wound_mask(auto_mask, min_area=min_area)
if processed_mask is None or np.sum(processed_mask > 0) == 0:
return None, "❌ No wound region detected. Try adjusting parameters."
return processed_mask, f"✅ Auto mask generated using {seg_method} method!"
# Connect event handlers
auto_severity_button.click(
fn=run_auto_severity_analysis,
inputs=[severity_input_image, severity_depth_map, pixel_spacing_slider,
segmentation_method, min_area_slider],
outputs=[severity_output]
)
manual_severity_button.click(
fn=run_manual_severity_analysis,
inputs=[severity_input_image, severity_depth_map, wound_mask_input, pixel_spacing_slider],
outputs=[severity_output]
)
preview_mask_btn.click(
fn=preview_auto_mask,
inputs=[severity_input_image, segmentation_method, min_area_slider],
outputs=[wound_mask_input, gr.HTML()]
)
# Load shared image from classification tab
def load_shared_image(shared_img):
if shared_img is None:
return gr.Image(), "❌ No image available from classification tab"
# Convert PIL image to numpy array for depth estimation
if hasattr(shared_img, 'convert'):
# It's a PIL image, convert to numpy
img_array = np.array(shared_img)
return img_array, "✅ Image loaded from classification tab"
else:
# Already numpy array
return shared_img, "✅ Image loaded from classification tab"
load_shared_btn.click(
fn=load_shared_image,
inputs=[shared_image],
outputs=[depth_input_image, gr.HTML()]
)
# Pass image to depth tab function
def pass_image_to_depth(img):
if img is None:
return "❌ No image uploaded in classification tab"
return "βœ… Image ready for depth analysis! Switch to tab 2 and click 'Load Image from Classification'"
pass_to_depth_btn.click(
fn=pass_image_to_depth,
inputs=[shared_image],
outputs=[pass_status]
)
if __name__ == '__main__':
demo.queue().launch(
server_name="0.0.0.0",
server_port=7860,
share=True
)