Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3
"""
ZO-1 Network Analysis Tool - Core Functions
Core analysis logic and classes for ZO-1 network quantification
"""
import numpy as np | |
import matplotlib.pyplot as plt | |
from matplotlib.patches import Rectangle, Circle | |
import cv2 | |
from PIL import Image | |
import pandas as pd | |
from io import StringIO, BytesIO | |
import base64 | |
import traceback | |
import os | |
from pathlib import Path | |
# Cellpose imports | |
from cellpose import models | |
from skimage.segmentation import find_boundaries | |
# PyTorch for GPU detection | |
import torch | |
# AI validation imports | |
from sklearn.cluster import KMeans | |
from sklearn.mixture import GaussianMixture | |
# Module-wide mutable state shared by the segmentation, analysis and export
# functions in this file. Every slot starts empty except the default geometry.
global_state = dict(
    img_gray=None,        # grayscale image stored by process_image
    masks=None,           # label image stored by run_segmentation_only
    membrane_mask=None,   # membrane mask stored by run_segmentation_only
    quantifier=None,      # last successful quantifier stored by run_analysis
    analysis_geometry='Circles (RIS - recommended)',
    image_basename=None,  # file stem used by export_results for output names
)

# Switch matplotlib to the non-interactive Agg backend.
plt.switch_backend('Agg')
def _sanitize_hits_xy(hits_like):
    """Return a clean N×2 int32 array of hit coordinates from arbitrary nested inputs.

    Args:
        hits_like: Anything that may hold coordinate pairs: an ndarray, a list
            of pairs, or a ragged/mixed list containing junk entries.

    Returns:
        An ``(N, 2)`` ``int32`` array. Entries that are not a 2-element
        list/tuple/ndarray of scalars convertible with ``int()`` are dropped.
        Unusable input yields an empty ``(0, 2)`` array; the function never
        raises.
    """
    empty = np.zeros((0, 2), dtype=np.int32)

    # Fast path: input that NumPy already sees as a homogeneous N×2 integer
    # (or boolean) array needs no per-element inspection. The conversion is
    # attempted WITHOUT forcing dtype=object, otherwise this test can never
    # succeed. Ragged input raises here on recent NumPy versions.
    try:
        numeric = np.asarray(hits_like)
    except Exception:
        numeric = None
    if (
        numeric is not None
        and numeric.dtype.kind in "iub"
        and numeric.ndim == 2
        and numeric.shape[1] == 2
    ):
        return numeric.astype(np.int32, copy=False)

    # Slow path: walk the items one by one and keep only valid pairs.
    try:
        items = np.asarray(hits_like, dtype=object)
        cleaned = []
        for item in items:
            try:
                if isinstance(item, (list, tuple, np.ndarray)) and len(item) == 2:
                    y_val, x_val = item
                    if np.isscalar(y_val) and np.isscalar(x_val):
                        # int() truncates floats and raises for NaN/inf/text,
                        # which drops that pair via the except below.
                        cleaned.append([int(y_val), int(x_val)])
            except Exception:
                continue
        if not cleaned:
            return empty
        return np.asarray(cleaned, dtype=np.int32).reshape(-1, 2)
    except Exception:
        # e.g. a 0-d input that cannot be iterated
        return empty
def create_visualization(img_gray, masks, quantifier, analysis_geometry, show_contours=False, show_rectangles=True, show_cross_sections=True):
    """Render the grayscale image with RIS or TiJOR overlays.

    Args:
        img_gray: 2-D grayscale image, or None.
        masks: Segmentation label image; only its shape is used here, to
            decide the resolution at which the image is displayed.
        quantifier: Object exposing a ``results`` dict (``radii`` /
            ``rectangle_sizes`` / ``hits_xy``), or None.
        analysis_geometry: ``"Circles (RIS - recommended)"`` selects the RIS
            overlay; any other value selects the TiJOR overlay.
        show_contours: Draw ``global_state['membrane_mask']`` as yellow contours.
        show_rectangles: Draw the concentric circles (RIS) or squares (TiJOR).
        show_cross_sections: Scatter the stored hit coordinates.

    Returns:
        The rendered figure as a NumPy array decoded from a PNG, or None when
        ``img_gray`` or ``masks`` is None.
    """
    if img_gray is None or masks is None:
        return None

    # Show the image at the mask resolution so overlays and pixels line up.
    if img_gray.shape != masks.shape:
        img_gray_resized = cv2.resize(img_gray, (masks.shape[1], masks.shape[0]), interpolation=cv2.INTER_LINEAR)
    else:
        img_gray_resized = img_gray

    # Overlay centre in DISPLAYED pixel coordinates. Using the un-resized
    # image shape here would shift circles/rectangles whenever the image and
    # the masks differ in size.
    center_x = img_gray_resized.shape[1] / 2
    center_y = img_gray_resized.shape[0] / 2

    is_ris = analysis_geometry == "Circles (RIS - recommended)"
    has_results = bool(quantifier) and hasattr(quantifier, 'results')

    fig, ax = plt.subplots(1, 1, figsize=(10, 8))
    try:
        ax.imshow(img_gray_resized, cmap='gray')
        if is_ris:
            ax.set_title('ZO-1 Network with RIS Analysis Overlays', fontsize=14, fontweight='bold')
        else:
            ax.set_title('ZO-1 Network with TiJOR Analysis Overlays', fontsize=14, fontweight='bold')
        ax.axis('off')

        # Optional membrane contours (yellow for both geometries).
        if show_contours and global_state['membrane_mask'] is not None:
            validated_membrane_mask = global_state['membrane_mask']
            if validated_membrane_mask.shape != img_gray_resized.shape:
                validated_membrane_mask = cv2.resize(
                    validated_membrane_mask,
                    (img_gray_resized.shape[1], img_gray_resized.shape[0]),
                    interpolation=cv2.INTER_NEAREST,
                )
            ax.contour(validated_membrane_mask, [0.5], colors='yellow', linewidths=2, alpha=0.7)

        if is_ris and has_results:
            results = quantifier.results
            # Concentric dashed circles, one per analysed radius.
            if show_rectangles and 'radii' in results:
                radii = results['radii']
                colors = plt.cm.Blues(np.linspace(0.3, 1, len(radii)))
                for color, r in zip(colors, radii):
                    ax.add_patch(plt.Circle(
                        (center_x, center_y), r,
                        linewidth=2, edgecolor=color,
                        facecolor='none', linestyle='--', alpha=0.7,
                    ))
            # Crossing points; hits are stored as (y, x), scatter wants (x, y).
            if show_cross_sections and 'hits_xy' in results:
                hits = _sanitize_hits_xy(results['hits_xy'])
                if hits.size > 0 and hits.ndim == 2 and hits.shape[1] == 2:
                    ax.scatter(hits[:, 1], hits[:, 0],
                               c='red', s=30, alpha=1.0, edgecolors='darkred', linewidth=2,
                               label=f'Crossings ({hits.shape[0]})')
                    ax.legend(loc='upper right', fontsize=10)
        elif has_results and 'rectangle_sizes' in quantifier.results:
            results = quantifier.results
            sizes = results['rectangle_sizes']
            colors = plt.cm.Reds(np.linspace(0.3, 1, len(sizes)))
            for color, size in zip(colors, sizes):
                # Force a scalar float; skip entries that cannot be converted.
                try:
                    size_val = float(np.asarray(size).reshape(-1)[0])
                except Exception:
                    continue
                half_side = size_val / 2.0
                if show_rectangles:
                    ax.add_patch(Rectangle(
                        (float(center_x - half_side), float(center_y - half_side)),
                        float(size_val), float(size_val),
                        linewidth=2,
                        edgecolor=color,
                        facecolor='none',
                        linestyle='--',
                        alpha=0.7,
                    ))
            # Cross-section points; hits are stored as (y, x).
            if show_cross_sections and 'hits_xy' in results:
                hits = _sanitize_hits_xy(results['hits_xy'])
                if hits.size > 0 and hits.ndim == 2 and hits.shape[1] == 2:
                    ax.scatter(hits[:, 1], hits[:, 0],
                               c='red', s=25, alpha=1.0, edgecolors='darkred', linewidth=1.5,
                               label=f'Cross-sections ({hits.shape[0]})', zorder=10)
                    ax.legend(loc='upper right', fontsize=10)

        # Rasterise to PNG in memory and decode it into a NumPy array.
        with BytesIO() as buf:
            fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
            buf.seek(0)
            return np.array(Image.open(buf))
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)
def create_visualization_with_masks(img_gray, masks):
    """Render the grayscale image with a random-coloured overlay of the masks.

    Args:
        img_gray: Grayscale image to display, or None.
        masks: Label image where 0 is left uncoloured and every positive
            label gets its own random colour, or None.

    Returns:
        The rendered figure as a NumPy array decoded from a PNG, or None when
        either input is None.
    """
    if img_gray is None or masks is None:
        return None

    fig, ax = plt.subplots(1, 1, figsize=(10, 8))
    try:
        ax.imshow(img_gray, cmap='gray')

        if masks.max() > 0:
            n_labels = int(masks.max())
            # One random colour per label, applied with a single table lookup
            # instead of one full-image comparison per label.
            palette = np.random.randint(0, 255, size=(n_labels + 1, 3)).astype(np.uint8)
            palette[0] = 0  # label 0 stays black
            # Clip so unexpected negative values map to the black entry.
            label_index = np.clip(np.asarray(masks).astype(np.intp), 0, n_labels)
            colored_masks = palette[label_index]
            ax.imshow(colored_masks, alpha=0.3)

        ax.set_title('ZO-1 Segmentation Results', fontsize=14, fontweight='bold')
        ax.axis('off')

        # Rasterise to PNG in memory and decode it into a NumPy array.
        with BytesIO() as buf:
            fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
            buf.seek(0)
            return np.array(Image.open(buf))
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)
def validate_contours_with_ai(contours, image, method="K-means clustering", dilation_pixels=4):
    """Keep only the contour pixels that fall on the image foreground.

    The foreground is estimated from ``image`` with the chosen method, dilated
    by a square kernel of side ``dilation_pixels`` and intersected with
    ``contours``.

    Args:
        contours: Array whose positive entries mark contour pixels.
        image: Grayscale image the foreground estimate is computed from.
        method: ``"K-means clustering"``, ``"Gaussian Mixture Model (GMM)"``;
            any other value falls back to Otsu thresholding.
        dilation_pixels: Side length of the dilation kernel.

    Returns:
        A ``uint8`` array holding 1 where a contour pixel overlaps the dilated
        foreground and 0 elsewhere.
    """
    # Best-effort debug line; a failure to format it must not abort the run.
    try:
        print(f"[validate_contours_with_ai] method={method}, image.shape={image.shape}, image.dtype={image.dtype}")
    except Exception:
        pass

    if method in ("K-means clustering", "Gaussian Mixture Model (GMM)"):
        # Two-component clustering on raw intensities.
        flat = image.reshape(-1, 1).astype(np.float32)
        if method == "K-means clustering":
            model = KMeans(n_clusters=2, n_init=10, max_iter=300, random_state=42)
            assignments = model.fit_predict(flat)
            centers = model.cluster_centers_.flatten()
        else:
            model = GaussianMixture(n_components=2, n_init=10, max_iter=300, random_state=42)
            assignments = model.fit_predict(flat)
            centers = model.means_.flatten()
        # The cluster with the larger centre is treated as foreground.
        brightest = np.argmax(centers)
        ai_mask = (assignments == brightest).reshape(image.shape).astype(np.uint8) * 255
    else:
        # Otsu: only the threshold value is used; the thresholded image is ignored.
        threshold_value = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[0]
        ai_mask = (image > threshold_value).astype(np.uint8) * 255

    # Grow the foreground a little so contours just outside it still count.
    structuring = np.ones((dilation_pixels, dilation_pixels), np.uint8)
    grown = cv2.dilate(ai_mask, structuring, iterations=1)

    return ((contours > 0) & (grown > 0)).astype(np.uint8)
def run_segmentation_only(img_gray, diam, scale, enable_contour_validation, validation_method):
    """Segment cells with Cellpose and derive the membrane mask.

    Args:
        img_gray: 2-D grayscale image, or None.
        diam: Cell diameter estimate passed to Cellpose (scaled together with
            the image when ``scale < 1``).
        scale: Values below 1.0 downsample the image before segmentation; the
            masks are upsampled back afterwards. Other values leave the image
            unchanged.
        enable_contour_validation: When truthy, boundaries are filtered with
            ``validate_contours_with_ai``.
        validation_method: Method name forwarded to ``validate_contours_with_ai``.

    Returns:
        ``(masks, membrane_mask, message)``. On failure or missing input the
        first two entries are None and ``message`` explains why. On success
        ``global_state['masks']`` and ``global_state['membrane_mask']`` are
        updated as a side effect.
    """
    if img_gray is None:
        return None, None, "No image provided"
    try:
        print(f"[run_segmentation_only] img_gray.shape={img_gray.shape}, dtype={img_gray.dtype}, diam={diam}, scale={scale}, enable_validation={enable_contour_validation}, method={validation_method}")

        # Use the GPU when available, otherwise fall back to the CPU.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        cp_model = models.CellposeModel(gpu=(device == 'cuda'), model_type='cyto')
        print(f"🔧 Using device: {device}")
        if device == 'cuda':
            print(f"🚀 GPU: {torch.cuda.get_device_name(0)}")
        else:
            print("⚠️ Running on CPU - this will be slower")

        # Optional downsampling.
        h, w = img_gray.shape
        downsample = scale < 1.0
        if downsample:
            # Never ask for a zero-sized image, whatever the scale.
            small_w = max(1, int(w * scale))
            small_h = max(1, int(h * scale))
            small = cv2.resize(img_gray, (small_w, small_h), interpolation=cv2.INTER_AREA)
            diam_small = max(1, int(diam * scale))
        else:
            small = img_gray
            diam_small = diam

        masks_small, _flows, _styles = cp_model.eval(
            small,
            diameter=diam_small,
            channels=[0, 0],
            flow_threshold=0.4,
            batch_size=4,
            resample=True,
            augment=True
        )

        if downsample:
            # Upsample labels with nearest-neighbour sampling. Going through
            # float32 (exact for integers up to 2**24) instead of uint8 keeps
            # label ids above 255 intact; a uint8 cast would wrap them around
            # and turn label 256 into background.
            label_dtype = masks_small.dtype
            masks = cv2.resize(
                masks_small.astype(np.float32), (w, h), interpolation=cv2.INTER_NEAREST
            ).astype(label_dtype)
        else:
            masks = masks_small

        # Cell boundaries form the membrane network analysed later.
        contours = find_boundaries(masks, mode='inner')
        if enable_contour_validation:
            membrane_mask = validate_contours_with_ai(contours, img_gray, validation_method, 4)
        else:
            membrane_mask = contours.astype(np.uint8)

        global_state['masks'] = masks
        global_state['membrane_mask'] = membrane_mask

        guidance = (
            f"Segmentation complete! Found {int(masks.max())} cells.\n"
            "If not happy with segmentation, adjust the Cell Diameter estimate and rerun.\n"
            "If satisfied, proceed to the Analysis tab."
        )
        return masks, membrane_mask, guidance
    except Exception as e:
        # UI boundary: report the failure as a message instead of raising.
        return None, None, f"Segmentation failed: {str(e)}"
class ZO1TiJORQuantifier:
    """TiJOR quantifier for rectangular analysis.

    Counts where the membrane mask meets the outline of centred squares of
    growing size and reports crossings per pixel of perimeter.
    """

    def __init__(self, initial_size=10, max_size=100, num_steps=10, min_distance=5):
        """Store the analysis parameters.

        Args:
            initial_size: Side of the smallest square, in percent of the
                smaller image dimension.
            max_size: Side of the largest square, in percent of the smaller
                image dimension.
            num_steps: Number of squares between the two sizes.
            min_distance: Minimum pixel distance between two counted hits.
        """
        self.initial_size = initial_size
        self.max_size = max_size
        # np.linspace needs an integer sample count; UI sliders may pass floats.
        self.num_steps = int(num_steps)
        self.min_distance = min_distance
        self.results = {}

    def analyze(self, membrane_mask):
        """Analyze membrane network using expanding rectangles.

        Args:
            membrane_mask: 2-D array whose positive entries are membrane
                pixels, or None.

        Returns:
            False when ``membrane_mask`` is None, True otherwise. On success
            ``self.results`` holds ``rectangle_sizes``, ``tijor_values``,
            ``cross_section_counts``, ``filtered_cross_section_counts`` and
            ``hits_xy`` (an ``(N, 2)`` int32 array of ``(y, x)`` points).
        """
        if membrane_mask is None:
            return False

        height, width = membrane_mask.shape
        img_size = min(height, width)  # sizes are percentages of the smaller side
        min_size = img_size * self.initial_size / 100
        max_size = img_size * self.max_size / 100
        sizes = np.linspace(min_size, max_size, int(self.num_steps)).astype(np.float32)
        self.results['rectangle_sizes'] = sizes
        print(f"[TiJOR.analyze] img_size={img_size:.2f}, sizes.shape={sizes.shape}, sizes[:3]={sizes[:3] if sizes.size>=3 else sizes}")

        center_y, center_x = np.array(membrane_mask.shape) / 2
        # Loop-invariant work: coordinate grids and the boolean membrane map.
        rows, cols = np.ogrid[:height, :width]
        foreground = membrane_mask > 0
        tolerance = 2  # half-width, in pixels, of the band around each outline

        tijor_values = []
        counts = []
        hit_chunks = []
        for size in sizes:
            half_side = size / 2
            y1 = max(0, int(center_y - half_side))
            y2 = min(height, int(center_y + half_side))
            x1 = max(0, int(center_x - half_side))
            x2 = min(width, int(center_x + half_side))

            # Band around the outline: outer rectangle minus inner rectangle.
            within_outer = (
                (cols >= (x1 - tolerance)) & (cols <= (x2 + tolerance)) &
                (rows >= (y1 - tolerance)) & (rows <= (y2 + tolerance))
            )
            within_inner = (
                (cols > (x1 + tolerance)) & (cols < (x2 - tolerance)) &
                (rows > (y1 + tolerance)) & (rows < (y2 - tolerance))
            )
            boundary_mask = within_outer & (~within_inner)

            # Membrane pixels inside the band, thinned by minimum separation.
            ys, xs = np.nonzero(foreground & boundary_mask)
            n_crossings = 0
            if ys.size:
                kept = np.asarray(
                    self._filter_by_distance(np.column_stack((ys, xs)), self.min_distance)
                ).reshape(-1, 2)
                hit_chunks.append(kept)
                n_crossings = int(kept.shape[0])
            counts.append(n_crossings)

            # TiJOR: crossings per pixel of (nominal) square perimeter.
            perimeter = 4.0 * float(size)
            tijor_values.append(float(n_crossings / perimeter) if perimeter > 0 else 0.0)

        self.results['tijor_values'] = np.asarray(tijor_values, dtype=np.float32)
        self.results['cross_section_counts'] = np.asarray(counts, dtype=np.int32)
        # Kept as a separate key for callers; the values equal the counts above.
        self.results['filtered_cross_section_counts'] = np.asarray(counts, dtype=np.int32)
        print(f"[TiJOR.analyze] crossings={self.results['cross_section_counts']}, tijor_values={self.results['tijor_values']}")

        if hit_chunks:
            self.results['hits_xy'] = np.concatenate(hit_chunks).astype(np.int32).reshape(-1, 2)
        else:
            self.results['hits_xy'] = np.zeros((0, 2), dtype=np.int32)
        return True

    def _filter_by_distance(self, coords, min_distance):
        """Filter coordinates by minimum distance.

        Walks ``coords`` in order and keeps a point only when it lies at
        least ``min_distance`` away from every point kept so far.

        Args:
            coords: ``(N, 2)`` array of points.
            min_distance: Minimum Euclidean distance between kept points.

        Returns:
            ``coords`` unchanged when it has at most one entry, otherwise an
            array with the kept points in their original order.
        """
        if len(coords) <= 1:
            return coords
        points = np.asarray(coords)
        kept = np.empty_like(points)
        kept[0] = points[0]
        count = 1
        for point in points[1:]:
            # Distances to all kept points at once; float64 avoids any
            # wrap-around if the coordinates are unsigned integers.
            deltas = kept[:count].astype(np.float64) - point.astype(np.float64)
            distances = np.sqrt((deltas ** 2).sum(axis=1))
            if distances.min() >= min_distance:
                kept[count] = point
                count += 1
        return kept[:count].copy()

    def get_summary_stats(self):
        """Get summary statistics.

        Returns:
            An empty dict when no analysis results are available, otherwise a
            dict of plain Python numbers: ``mean_tijor``, ``std_tijor``,
            ``max_tijor``, ``min_tijor``, ``total_cross_sections`` and
            ``mean_cross_sections_per_rectangle``.
        """
        if not self.results:
            return {}
        tijor_values = np.asarray(self.results.get('tijor_values', ()))
        filtered_counts = np.asarray(self.results.get('filtered_cross_section_counts', ()))
        # Nothing to summarise (e.g. zero steps): avoid min/max of empty arrays.
        if tijor_values.size == 0 or filtered_counts.size == 0:
            return {}
        # Plain Python numbers, matching what the RIS quantifier returns.
        return {
            'mean_tijor': float(np.mean(tijor_values)),
            'std_tijor': float(np.std(tijor_values)),
            'max_tijor': float(np.max(tijor_values)),
            'min_tijor': float(np.min(tijor_values)),
            'total_cross_sections': int(np.sum(filtered_counts)),
            'mean_cross_sections_per_rectangle': float(np.mean(filtered_counts)),
        }
class ZO1RISQuantifier:
    """Circular (Sholl-style) crossings-per-circumference for ZO-1 networks."""

    def __init__(self, packing_factor=1.5, min_radius_percent=10, max_radius_percent=80, num_circles=15, min_separation=5):
        """Store the analysis parameters.

        Args:
            packing_factor: Numerator of the reference density
                (``d_ref = packing_factor / d_eff_pixels``).
            min_radius_percent: Smallest radius, in percent of half the
                smaller image dimension.
            max_radius_percent: Largest radius, same unit.
            num_circles: Number of concentric circles.
            min_separation: Minimum pixel distance between two counted hits.
        """
        self.kappa = float(packing_factor)
        self.min_radius_percent = float(min_radius_percent)
        self.max_radius_percent = float(max_radius_percent)
        self.num_circles = int(num_circles)
        self.min_separation = int(min_separation)
        self.results = {}

    def analyze(self, membrane_mask, d_eff_pixels, scale_factor=1.0):
        """Analyze membrane network using concentric circles.

        Args:
            membrane_mask: 2-D array whose positive entries are membrane
                pixels, or None.
            d_eff_pixels: Value the packing factor is divided by to obtain
                the reference density. Zero yields a reference density of 0
                (and RIS values of 0) instead of raising.
            scale_factor: Accepted for interface compatibility; not used.

        Returns:
            False when ``membrane_mask`` is None, True otherwise. On success
            ``self.results`` holds ``radii``, ``crossings``, ``hits_xy``,
            ``crossings_per_length`` and, when at least one circle was
            analysed, ``RIS``, ``RIS_peak``, ``d_mean``, ``d_peak``, ``d_ref``.
        """
        if membrane_mask is None:
            return False

        # Reference density; later code already treats d_ref <= 0 as "no
        # reference", so a zero divisor maps to 0.0 rather than crashing.
        d_eff = float(d_eff_pixels)
        d_ref = self.kappa / d_eff if d_eff != 0 else 0.0

        # Radii as percentages of half the smaller image dimension.
        img_size = min(membrane_mask.shape) / 2
        min_radius = img_size * self.min_radius_percent / 100
        max_radius = img_size * self.max_radius_percent / 100
        radii = np.linspace(min_radius, max_radius, self.num_circles)
        self.results['radii'] = radii

        center_y, center_x = np.array(membrane_mask.shape) / 2
        # Loop-invariant work: squared distance of every pixel to the centre.
        rows, cols = np.ogrid[:membrane_mask.shape[0], :membrane_mask.shape[1]]
        dist_sq = (cols - center_x) ** 2 + (rows - center_y) ** 2
        foreground = membrane_mask > 0
        tolerance = 2  # half-width, in pixels, of the band around each circle

        crossings = []
        hit_chunks = []
        for radius in radii:
            # Clamp the inner edge at 0: squaring a negative (radius - 2)
            # would wrongly exclude pixels close to the centre.
            inner_edge = max(radius - tolerance, 0.0)
            circle_boundary = (dist_sq <= (radius + tolerance) ** 2) & (dist_sq >= inner_edge ** 2)

            ys, xs = np.nonzero(foreground & circle_boundary)
            if ys.size:
                kept = self._filter_by_separation(ys, xs)
                crossings.append(int(kept.shape[0]))
                hit_chunks.append(kept)
            else:
                crossings.append(0)

        self.results['crossings'] = np.array(crossings)
        if hit_chunks:
            self.results['hits_xy'] = np.concatenate(hit_chunks).astype(np.int32).reshape(-1, 2)
        else:
            self.results['hits_xy'] = np.zeros((0, 2), dtype=np.int32)

        # Crossings per pixel of circumference, then mean/peak relative to d_ref.
        circle_lengths = 2.0 * np.pi * np.maximum(radii, 1e-6)
        densities = self.results['crossings'].astype(np.float32) / circle_lengths
        self.results['crossings_per_length'] = densities
        if len(densities) > 0:
            d_mean = float(np.mean(densities))
            d_peak = float(np.max(densities))
            self.results['RIS'] = float(d_mean / d_ref) if d_ref > 0 else 0.0
            self.results['RIS_peak'] = float(d_peak / d_ref) if d_ref > 0 else 0.0
            self.results['d_mean'] = d_mean
            self.results['d_peak'] = d_peak
            self.results['d_ref'] = float(d_ref)
        return True

    def _filter_by_separation(self, y_coords, x_coords):
        """Greedily thin candidate points so kept ones are well separated.

        Points are visited in the given order; a point is dropped when it is
        closer than ``self.min_separation`` to any point kept before it.

        Returns:
            ``(K, 2)`` integer array of kept ``(y, x)`` points.
        """
        candidates = np.column_stack((y_coords, x_coords))
        kept = np.empty_like(candidates)
        count = 0
        for point in candidates:
            if count:
                deltas = (kept[:count] - point).astype(np.float64)
                if np.any(np.sqrt((deltas ** 2).sum(axis=1)) < self.min_separation):
                    continue
            kept[count] = point
            count += 1
        return kept[:count]

    def get_summary_stats(self):
        """Get summary statistics for RIS analysis.

        Returns:
            An empty dict when no results exist, otherwise a dict with
            ``RIS``, ``RIS_peak``, ``d_mean``, ``d_peak``, ``d_ref`` (NaN when
            missing), ``total_crossings``, ``mean_crossings_per_px`` and
            ``packing_factor``.
        """
        if not self.results:
            return {}
        crossings = self.results.get('crossings', np.array([]))
        densities = self.results.get('crossings_per_length', np.array([]))
        return {
            'RIS': self.results.get('RIS', np.nan),
            'RIS_peak': self.results.get('RIS_peak', np.nan),
            'd_mean': self.results.get('d_mean', np.nan),  # mean crossings per pixel length
            'd_peak': self.results.get('d_peak', np.nan),  # peak crossings per pixel length
            'd_ref': self.results.get('d_ref', np.nan),
            'total_crossings': float(np.sum(crossings)) if crossings.size else 0.0,
            'mean_crossings_per_px': float(np.mean(densities)) if densities.size else np.nan,
            'packing_factor': self.kappa
        }
def run_analysis(analysis_geometry, initial_size, max_size, num_steps, min_distance, packing_factor, min_radius_percent, max_radius_percent, num_circles, min_separation, show_contours=True, show_rectangles=True, show_cross_sections=True):
    """Run the selected quantification on the stored membrane mask.

    ``"Circles (RIS - recommended)"`` runs the RIS quantifier; any other
    geometry runs the TiJOR quantifier. The chosen geometry and, on success,
    the quantifier are written to ``global_state``.

    Returns:
        ``(results_text, visualization, summary)``. When no membrane mask is
        stored, the analysis reports failure, or an exception occurs, the last
        two entries are None and the first is the explanatory message.
    """
    membrane_mask = global_state['membrane_mask']
    if membrane_mask is None:
        return "No membrane mask available. Please run segmentation first.", None, None

    def _render(active_quantifier):
        # Overlay figure built from the stored image and masks.
        return create_visualization(
            global_state['img_gray'],
            global_state['masks'],
            active_quantifier,
            analysis_geometry,
            show_contours=show_contours,
            show_rectangles=show_rectangles,
            show_cross_sections=show_cross_sections
        )

    try:
        # Remember the mode so export/reporting can pick the right format.
        global_state['analysis_geometry'] = analysis_geometry
        print(f"[run_analysis] mode={analysis_geometry}, init_size={initial_size}, max_size={max_size}, steps={num_steps}, min_dist={min_distance}, kappa={packing_factor}, minR%={min_radius_percent}, maxR%={max_radius_percent}, circles={num_circles}, min_sep={min_separation}")

        if analysis_geometry == "Circles (RIS - recommended)":
            quantifier = ZO1RISQuantifier(
                packing_factor=packing_factor,
                min_radius_percent=min_radius_percent,
                max_radius_percent=max_radius_percent,
                num_circles=num_circles,
                min_separation=min_separation
            )
            if not quantifier.analyze(global_state['membrane_mask'], 20, scale_factor=1.0):
                return "RIS analysis failed.", None, None

            global_state['quantifier'] = quantifier
            summary = quantifier.get_summary_stats()
            print(f"[run_analysis][RIS] summary keys={list(summary.keys())}")

            ris = summary.get('RIS', np.nan)
            ris_peak = summary.get('RIS_peak', np.nan)
            mean_crossings = summary.get('d_mean', np.nan)
            total_crossings = summary.get('total_crossings', np.nan)
            d_ref = summary.get('d_ref', np.nan)
            results_text = f"""
🔵 **RIS Analysis Results**
RIS Score: {ris:.4f}
RIS Peak: {ris_peak:.4f}
Mean Crossings/px length: {mean_crossings:.4f}
Total Crossings: {total_crossings:.0f}
Reference Density (per px length): {d_ref:.4f}
Parameters: κ={packing_factor:.1f}, Min Radius={min_radius_percent}%, Max Radius={max_radius_percent}%, Circles={num_circles}, Min Sep={min_separation}px
"""
            return results_text, _render(quantifier), summary

        # Any other geometry: TiJOR analysis.
        quantifier = ZO1TiJORQuantifier(initial_size, max_size, num_steps, min_distance)
        if not quantifier.analyze(global_state['membrane_mask']):
            return "TiJOR analysis failed.", None, None

        global_state['quantifier'] = quantifier
        # Defensive normalisation so downstream code sees well-shaped arrays.
        try:
            if 'hits_xy' in quantifier.results:
                quantifier.results['hits_xy'] = _sanitize_hits_xy(quantifier.results['hits_xy'])
            if 'rectangle_sizes' in quantifier.results:
                flat_sizes = np.asarray(quantifier.results['rectangle_sizes']).reshape(-1)
                quantifier.results['rectangle_sizes'] = flat_sizes.astype(np.float32, copy=False)
        except Exception:
            quantifier.results['hits_xy'] = np.zeros((0, 2), dtype=np.int32)

        summary = quantifier.get_summary_stats()
        sizes_shape = quantifier.results.get('rectangle_sizes', np.array([])).shape
        hits_shape = quantifier.results.get('hits_xy', np.zeros((0,2))).shape
        print(f"[run_analysis][TiJOR] sizes.shape={sizes_shape}, hits.shape={hits_shape}, summary keys={list(summary.keys())}")

        mean_tijor = float(summary.get('mean_tijor', np.nan))
        total_cross_sections = float(summary.get('total_cross_sections', np.nan))
        stored_masks = global_state['masks']
        cells_detected = int(stored_masks.max()) if stored_masks is not None else 0
        results_text = f"""
📊 **TiJOR Analysis Results**
Mean TiJOR: {mean_tijor:.4f}
Total Cross-sections: {total_cross_sections:.0f}
Cells Detected: {cells_detected}
"""
        return results_text, _render(quantifier), summary
    except Exception as e:
        # UI boundary: log and return the traceback instead of raising.
        tb = traceback.format_exc()
        print(f"[run_analysis][EXCEPTION] {e}\n{tb}")
        return f"Analysis failed: {str(e)}\n\nTraceback:\n{tb}", None, None
def process_image(image, cell_diameter, scale_factor, enable_validation, validation_method):
    """Process an uploaded image and run segmentation on it.

    Args:
        image: File path (str) or image array, or None.
        cell_diameter: Cell diameter estimate forwarded to segmentation.
        scale_factor: Downsampling factor forwarded to segmentation.
        enable_validation: Whether contour validation is applied.
        validation_method: Contour validation method name.

    Returns:
        ``(message, visualization, masks)``; the last two are None when the
        input is missing, cannot be read, or processing fails. The grayscale
        8-bit image is stored in ``global_state['img_gray']`` and the export
        base name in ``global_state['image_basename']``.
    """
    if image is None:
        return "No image uploaded", None, None
    try:
        if isinstance(image, str):
            # File path: remember the stem so exports can be named after it.
            try:
                global_state['image_basename'] = Path(image).stem
            except Exception:
                global_state['image_basename'] = 'results'
            # IMREAD_UNCHANGED preserves the bit depth of the file.
            img = cv2.imread(image, cv2.IMREAD_UNCHANGED)
            if img is None:
                return "Failed to read image from path", None, None
            img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if len(img.shape) == 3 else img
        else:
            # Array input: keep a previously stored base name, else 'results'.
            global_state['image_basename'] = global_state.get('image_basename') or 'results'
            if len(image.shape) == 3:
                if image.shape[2] == 1:
                    # Single trailing channel: already grayscale, and
                    # cvtColor would reject a 1-channel colour conversion.
                    img_gray = image[:, :, 0]
                else:
                    if image.shape[2] == 4:  # RGBA
                        image = image[:, :, :3]  # drop the alpha channel
                    # cvtColor only accepts 8-bit, 16-bit unsigned and
                    # float32 input; convert anything else (e.g. float64)
                    # to float32 first. Intensities are rescaled below.
                    if image.dtype not in (np.uint8, np.uint16, np.float32):
                        image = image.astype(np.float32)
                    img_gray = cv2.cvtColor(np.ascontiguousarray(image), cv2.COLOR_RGB2GRAY)
            else:
                img_gray = image

        # Bring non-8-bit data (16-bit or float TIFFs, ...) into 0-255.
        if img_gray.dtype != np.uint8:
            imin = float(img_gray.min())
            imax = float(img_gray.max())
            if imax > imin:
                img_gray = ((img_gray - imin) / (imax - imin) * 255).astype(np.uint8)
            else:
                # Flat image: return zeros to avoid dividing by zero.
                img_gray = np.zeros_like(img_gray, dtype=np.uint8)

        global_state['img_gray'] = img_gray

        masks, membrane_mask, message = run_segmentation_only(
            img_gray, cell_diameter, scale_factor, enable_validation, validation_method
        )
        if masks is None:
            return message, None, None
        # Initial visualisation with the coloured masks.
        viz = create_visualization_with_masks(img_gray, masks)
        return message, viz, masks
    except Exception as e:
        # UI boundary: report the failure as a message instead of raising.
        return f"Image processing failed: {str(e)}", None, None
def export_results(format_type):
    """Export analysis results; returns (text_content, file_path).

    ``"CSV"`` writes one row per circle (RIS) or per rectangle step (TiJOR);
    any other format writes a plain-text summary report. The file is named
    ``<image_basename or 'results'><_RIS|_TiJOR>.<csv|txt>`` and written
    relative to the current working directory. On failure, or when no
    analysis has been stored, the returned path is None and the text is the
    explanatory message.
    """
    quantifier = global_state['quantifier']
    if quantifier is None:
        return "No analysis results to export", None
    try:
        base = global_state.get('image_basename') or 'results'
        ris_mode = global_state.get('analysis_geometry') == "Circles (RIS - recommended)"
        suffix = '_RIS' if ris_mode else '_TiJOR'

        if format_type == "CSV":
            results_data = []
            if ris_mode:
                # One row per circle with its crossing count.
                if hasattr(quantifier, 'results') and 'crossings' in quantifier.results:
                    results_data = [
                        {'Circle': idx + 1, 'Crossings': int(crossing)}
                        for idx, crossing in enumerate(quantifier.results['crossings'])
                    ]
            else:
                # One row per rectangle step with its size and TiJOR value.
                if hasattr(quantifier, 'results') and 'tijor_values' in quantifier.results:
                    paired = zip(
                        quantifier.results['rectangle_sizes'],
                        quantifier.results['tijor_values']
                    )
                    results_data = [
                        {'Step': idx + 1, 'Size (px)': f'{size:.1f}', 'TiJOR': f'{tijor:.4f}'}
                        for idx, (size, tijor) in enumerate(paired)
                    ]
            csv_data = pd.DataFrame(results_data).to_csv(index=False)
            file_path = f"{base}{suffix}.csv"
            with open(file_path, 'w', newline='') as f:
                f.write(csv_data)
            return csv_data, file_path

        # Anything else: plain-text summary report.
        summary = quantifier.get_summary_stats()
        if ris_mode:
            ris = summary.get('RIS', np.nan)
            ris_peak = summary.get('RIS_peak', np.nan)
            mean_crossings = summary.get('d_mean', np.nan)
            total_crossings = summary.get('total_crossings', np.nan)
            d_ref = summary.get('d_ref', np.nan)
            report = f"""ZO-1 RIS Network Analysis Report
{'='*50}
RIS Score: {ris:.4f}
RIS Peak: {ris_peak:.4f}
Mean Crossings/px length: {mean_crossings:.4f}
Total Crossings: {total_crossings:.0f}
Reference Density (per px length): {d_ref:.4f}
"""
        else:
            mean_tijor = summary.get('mean_tijor', np.nan)
            total_cross_sections = summary.get('total_cross_sections', np.nan)
            stored_masks = global_state['masks']
            cells_detected = int(stored_masks.max()) if stored_masks is not None else 0
            report = f"""ZO-1 TiJOR Network Analysis Report
{'='*50}
Mean TiJOR: {mean_tijor:.4f}
Total Cross-sections: {total_cross_sections:.0f}
Cells Detected: {cells_detected}
"""
        file_path = f"{base}{suffix}.txt"
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(report)
        return report, file_path
    except Exception as e:
        # UI boundary: report the failure as a message instead of raising.
        return f"Export failed: {str(e)}", None