#!/usr/bin/env python3
"""
ZO-1 Network Analysis Tool - Core Functions
Core analysis logic and classes for ZO-1 network quantification
"""
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle, Circle
import cv2
from PIL import Image
import pandas as pd
from io import StringIO, BytesIO
import base64
import traceback
import os
from pathlib import Path
# Cellpose imports
from cellpose import models
from skimage.segmentation import find_boundaries
# PyTorch for GPU detection
import torch
# AI validation imports
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
# Global variables for state management
global_state = {
'img_gray': None,
'masks': None,
'membrane_mask': None,
'quantifier': None,
'analysis_geometry': 'Circles (RIS - recommended)',
'image_basename': None
}
# Set matplotlib backend for non-interactive use
plt.switch_backend('Agg')
def _sanitize_hits_xy(hits_like):
"""Return a clean N×2 int32 array of hit coordinates from arbitrary nested inputs."""
try:
        # Fast path: input is already a proper 2D numeric array
        # (checked before coercing to dtype=object, which would always fail the dtype test)
        if isinstance(hits_like, np.ndarray) and hits_like.ndim == 2 and hits_like.shape[1] == 2 and hits_like.dtype != object:
            return hits_like.astype(np.int32, copy=False)
        arr = np.asarray(hits_like, dtype=object)
# Build list of valid coordinate pairs
cleaned = []
for item in arr:
try:
if isinstance(item, (list, tuple, np.ndarray)) and len(item) == 2:
y_val, x_val = item
if np.isscalar(y_val) and np.isscalar(x_val):
cleaned.append([int(y_val), int(x_val)])
except Exception:
continue
return np.asarray(cleaned, dtype=np.int32).reshape(-1, 2) if cleaned else np.zeros((0, 2), dtype=np.int32)
except Exception:
return np.zeros((0, 2), dtype=np.int32)
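# Hypothetical usage sketch for _sanitize_hits_xy (illustrative values, not from a real run):
#   _sanitize_hits_xy([[3, 4], (10, 20), "bad", [7]])
#   -> array([[ 3,  4],
#             [10, 20]], dtype=int32)   # malformed entries are silently dropped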
def create_visualization(img_gray, masks, quantifier, analysis_geometry, show_contours=False, show_rectangles=True, show_cross_sections=True):
"""Create visualization with overlays"""
if img_gray is None or masks is None:
return None
fig, ax = plt.subplots(1, 1, figsize=(10, 8))
# Main image with overlays - ensure same dimensions
if img_gray.shape != masks.shape:
img_gray_resized = cv2.resize(img_gray, (masks.shape[1], masks.shape[0]), interpolation=cv2.INTER_LINEAR)
else:
img_gray_resized = img_gray
ax.imshow(img_gray_resized, cmap='gray')
if analysis_geometry == "Circles (RIS - recommended)":
ax.set_title('ZO-1 Network with RIS Analysis Overlays', fontsize=14, fontweight='bold')
else:
ax.set_title('ZO-1 Network with TiJOR Analysis Overlays', fontsize=14, fontweight='bold')
ax.axis('off')
# Draw contours if requested
if show_contours and global_state['membrane_mask'] is not None:
validated_membrane_mask = global_state['membrane_mask']
# Ensure same dimensions
if validated_membrane_mask.shape != img_gray_resized.shape:
validated_membrane_mask = cv2.resize(validated_membrane_mask, (img_gray_resized.shape[1], img_gray_resized.shape[0]), interpolation=cv2.INTER_NEAREST)
# Contours with 2-pixel thickness
# Use yellow contours for both RIS and TiJOR
ax.contour(validated_membrane_mask, [0.5], colors='yellow', linewidths=2, alpha=0.7)
# Draw analysis overlays based on geometry type
if analysis_geometry == "Circles (RIS - recommended)" and quantifier and hasattr(quantifier, 'results'):
# Draw concentric circles and scatter hits for RIS analysis
if show_rectangles and 'radii' in quantifier.results:
            # Centre in mask coordinates (the displayed image is resized to the mask shape)
            center_x = img_gray_resized.shape[1] / 2
            center_y = img_gray_resized.shape[0] / 2
radii = quantifier.results['radii']
colors = plt.cm.Blues(np.linspace(0.3, 1, len(radii)))
for i, r in enumerate(radii):
circle = plt.Circle((center_x, center_y), r,
linewidth=2, edgecolor=colors[i],
facecolor='none', linestyle='--', alpha=0.7)
ax.add_patch(circle)
# Plot crossing points (hits) if requested
if show_cross_sections and 'hits_xy' in quantifier.results:
hits = _sanitize_hits_xy(quantifier.results['hits_xy'])
if hits.size > 0 and hits.ndim == 2 and hits.shape[1] == 2:
ax.scatter(hits[:, 1], hits[:, 0], # Note: y, x order for matplotlib
c='red', s=30, alpha=1.0, edgecolors='darkred', linewidth=2,
label=f'Crossings ({hits.shape[0]})')
ax.legend(loc='upper right', fontsize=10)
elif quantifier and hasattr(quantifier, 'results') and 'rectangle_sizes' in quantifier.results:
# Draw rectangles and cross-sections for TiJOR analysis
        # Centre in mask coordinates (the displayed image is resized to the mask shape)
        center_x = img_gray_resized.shape[1] / 2
        center_y = img_gray_resized.shape[0] / 2
colors = plt.cm.Reds(np.linspace(0.3, 1, len(quantifier.results['rectangle_sizes'])))
for i, size in enumerate(quantifier.results['rectangle_sizes']):
# Force scalar float to avoid inhomogeneous shape issues
try:
size_val = float(np.asarray(size).reshape(-1)[0])
except Exception:
continue
half_side = size_val / 2.0
if show_rectangles:
rect = Rectangle(
(float(center_x - half_side), float(center_y - half_side)),
float(size_val), float(size_val),
linewidth=2,
edgecolor=colors[i],
facecolor='none',
linestyle='--',
alpha=0.7
)
ax.add_patch(rect)
# Plot TiJOR cross-section points
if show_cross_sections and 'hits_xy' in quantifier.results:
hits = _sanitize_hits_xy(quantifier.results['hits_xy'])
if hits.size > 0 and hits.ndim == 2 and hits.shape[1] == 2:
ax.scatter(hits[:, 1], hits[:, 0], # Note: y, x order for matplotlib
c='red', s=25, alpha=1.0, edgecolors='darkred', linewidth=1.5,
label=f'Cross-sections ({hits.shape[0]})', zorder=10)
ax.legend(loc='upper right', fontsize=10)
# Convert plot to numpy array for Gradio
buf = BytesIO()
fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
buf.seek(0)
# Convert to PIL Image then to numpy array
pil_image = Image.open(buf)
numpy_array = np.array(pil_image)
plt.close(fig)
buf.close()
return numpy_array
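# Minimal usage sketch (assumes segmentation and analysis have already populated
# global_state and a quantifier; values are illustrative):
#   overlay = create_visualization(global_state['img_gray'], global_state['masks'],
#                                  global_state['quantifier'],
#                                  'Circles (RIS - recommended)',
#                                  show_contours=True, show_rectangles=True,
#                                  show_cross_sections=True)
#   # `overlay` is an RGBA numpy array decoded from the rendered PNG, suitable for Gradio display.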
def create_visualization_with_masks(img_gray, masks):
"""Create visualization showing the segmentation masks"""
if img_gray is None or masks is None:
return None
fig, ax = plt.subplots(1, 1, figsize=(10, 8))
# Show original image
ax.imshow(img_gray, cmap='gray')
# Overlay masks with different colors
if masks.max() > 0:
# Create colored mask overlay
colored_masks = np.zeros((*masks.shape, 3), dtype=np.uint8)
for i in range(1, int(masks.max()) + 1):
mask = (masks == i)
color = np.random.randint(0, 255, 3)
colored_masks[mask] = color
# Overlay with transparency
ax.imshow(colored_masks, alpha=0.3)
ax.set_title('ZO-1 Segmentation Results', fontsize=14, fontweight='bold')
ax.axis('off')
# Convert to numpy array for Gradio
buf = BytesIO()
fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
buf.seek(0)
# Convert to PIL Image then to numpy array
pil_image = Image.open(buf)
numpy_array = np.array(pil_image)
plt.close(fig)
buf.close()
return numpy_array
def validate_contours_with_ai(contours, image, method="K-means clustering", dilation_pixels=4):
"""Validate Cellpose contours using AI-powered methods"""
try:
print(f"[validate_contours_with_ai] method={method}, image.shape={image.shape}, image.dtype={image.dtype}")
except Exception:
pass
if method == "K-means clustering":
pixels = image.reshape(-1, 1).astype(np.float32)
kmeans = KMeans(n_clusters=2, n_init=10, max_iter=300, random_state=42)
labels = kmeans.fit_predict(pixels)
cluster_centers = kmeans.cluster_centers_.flatten()
foreground_cluster = np.argmax(cluster_centers)
ai_mask = (labels == foreground_cluster).reshape(image.shape).astype(np.uint8) * 255
elif method == "Gaussian Mixture Model (GMM)":
pixels = image.reshape(-1, 1).astype(np.float32)
gmm = GaussianMixture(n_components=2, n_init=10, max_iter=300, random_state=42)
labels = gmm.fit_predict(pixels)
cluster_centers = gmm.means_.flatten()
foreground_cluster = np.argmax(cluster_centers)
ai_mask = (labels == foreground_cluster).reshape(image.shape).astype(np.uint8) * 255
else: # Otsu
otsu_threshold, _ = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
ai_mask = (image > otsu_threshold).astype(np.uint8) * 255
# Dilate mask for tolerance
kernel = np.ones((dilation_pixels, dilation_pixels), np.uint8)
dilated_mask = cv2.dilate(ai_mask, kernel, iterations=1)
# Combine with contours
validated_mask = np.logical_and(contours > 0, dilated_mask > 0).astype(np.uint8)
return validated_mask
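# Usage sketch (hypothetical inputs): keep only Cellpose boundary pixels that also fall
# inside the brighter intensity class found by k-means, with 4 px of dilation tolerance.
#   boundaries = find_boundaries(masks, mode='inner')
#   membrane = validate_contours_with_ai(boundaries, img_gray,
#                                        method="K-means clustering", dilation_pixels=4)
#   # `membrane` is a binary (0/1) uint8 mask with the same shape as `boundaries`.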
def run_segmentation_only(img_gray, diam, scale, enable_contour_validation, validation_method):
"""Run only the AI-powered segmentation step"""
if img_gray is None:
return None, None, "No image provided"
try:
print(f"[run_segmentation_only] img_gray.shape={img_gray.shape}, dtype={img_gray.dtype}, diam={diam}, scale={scale}, enable_validation={enable_contour_validation}, method={validation_method}")
# Initialize Cellpose model
# Use GPU if available, fallback to CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cp_model = models.CellposeModel(gpu=(device=='cuda'), model_type='cyto')
print(f"🔧 Using device: {device}")
if device == 'cuda':
print(f"🚀 GPU: {torch.cuda.get_device_name(0)}")
else:
print("⚠️ Running on CPU - this will be slower")
# Downsample if requested
h, w = img_gray.shape
if scale < 1.0:
small = cv2.resize(img_gray, (int(w*scale), int(h*scale)), interpolation=cv2.INTER_AREA)
diam_small = max(1, int(diam*scale))
else:
small = img_gray
diam_small = diam
# Run segmentation
masks_small, flows, styles = cp_model.eval(
small,
diameter=diam_small,
channels=[0, 0],
flow_threshold=0.4,
batch_size=4,
resample=True,
augment=True
)
        # Upsample masks if needed (uint16 preserves label IDs when more than 255 cells are found)
        if scale < 1.0:
            masks = cv2.resize(masks_small.astype(np.uint16), (w, h), interpolation=cv2.INTER_NEAREST)
else:
masks = masks_small
# Create membrane mask for network analysis
contours = find_boundaries(masks, mode='inner')
# Apply AI-powered validation
if enable_contour_validation:
membrane_mask = validate_contours_with_ai(contours, img_gray, validation_method, 4)
else:
membrane_mask = contours.astype(np.uint8)
# Update global state
global_state['masks'] = masks
global_state['membrane_mask'] = membrane_mask
        guidance = (
            f"Segmentation complete! Found {int(masks.max())} cells.\n"
            "If you are not happy with the segmentation, adjust the Cell Diameter estimate and rerun.\n"
            "If you are satisfied, proceed to the Analysis tab."
        )
return masks, membrane_mask, guidance
except Exception as e:
return None, None, f"Segmentation failed: {str(e)}"
class ZO1TiJORQuantifier:
"""TiJOR quantifier for rectangular analysis"""
def __init__(self, initial_size=10, max_size=100, num_steps=10, min_distance=5):
self.initial_size = initial_size
self.max_size = max_size
self.num_steps = num_steps
self.min_distance = min_distance
self.results = {}
def analyze(self, membrane_mask):
"""Analyze membrane network using expanding rectangles"""
if membrane_mask is None:
return False
# Generate rectangle sizes as percentage of image size
img_size = min(membrane_mask.shape) # Use smaller dimension
min_size = img_size * self.initial_size / 100
max_size = img_size * self.max_size / 100
sizes = np.linspace(min_size, max_size, self.num_steps).astype(np.float32)
self.results['rectangle_sizes'] = sizes
try:
print(f"[TiJOR.analyze] img_size={img_size:.2f}, sizes.shape={sizes.shape}, sizes[:3]={sizes[:3] if sizes.size>=3 else sizes}")
except Exception:
pass
# Calculate cross-sections for each size
tijor_values = []
cross_section_counts = []
filtered_cross_section_counts = []
all_hits_xy = [] # Store all intersection points
center_y, center_x = np.array(membrane_mask.shape) / 2
for size in sizes:
half_side = size / 2
# Define rectangle boundaries
y1 = max(0, int(center_y - half_side))
y2 = min(membrane_mask.shape[0], int(center_y + half_side))
x1 = max(0, int(center_x - half_side))
x2 = min(membrane_mask.shape[1], int(center_x + half_side))
# Create rectangle boundary mask (2-pixel thickness like RIS) without reduce()
y, x = np.ogrid[:membrane_mask.shape[0], :membrane_mask.shape[1]]
tolerance = 2
# Outer rectangle (expanded by tolerance)
within_outer = (
(x >= (x1 - tolerance)) & (x <= (x2 + tolerance)) &
(y >= (y1 - tolerance)) & (y <= (y2 + tolerance))
)
# Inner rectangle (shrunk by tolerance)
within_inner = (
(x > (x1 + tolerance)) & (x < (x2 - tolerance)) &
(y > (y1 + tolerance)) & (y < (y2 - tolerance))
)
# Boundary is the ring between outer and inner
boundary_mask = within_outer & (~within_inner)
# Find intersections: membrane pixels near rectangle boundary (boolean-safe)
intersection_region = (membrane_mask > 0) & boundary_mask
# Apply minimum separation filtering (like RIS)
total_cross_sections = 0
if np.any(intersection_region):
y_coords, x_coords = np.where(intersection_region)
# Filter points with minimum separation
filtered_points = []
for y_coord, x_coord in zip(y_coords, x_coords):
# Check if this point is far enough from existing points
is_unique = True
for existing_y, existing_x in filtered_points:
distance = np.sqrt((y_coord - existing_y)**2 + (x_coord - existing_x)**2)
if distance < self.min_distance:
is_unique = False
break
if is_unique:
yi = int(y_coord)
xi = int(x_coord)
filtered_points.append([yi, xi])
all_hits_xy.append([yi, xi])
total_cross_sections = len(filtered_points)
cross_section_counts.append(total_cross_sections)
filtered_count = total_cross_sections
filtered_cross_section_counts.append(filtered_count)
# Calculate TiJOR (crossings per pixel length; rectangle perimeter)
perimeter = 4.0 * float(size)
tijor = (filtered_count / perimeter) if perimeter > 0 else 0.0
tijor_values.append(float(tijor))
self.results['tijor_values'] = np.asarray(tijor_values, dtype=np.float32)
self.results['cross_section_counts'] = np.asarray(cross_section_counts, dtype=np.int32)
self.results['filtered_cross_section_counts'] = np.asarray(filtered_cross_section_counts, dtype=np.int32)
try:
print(f"[TiJOR.analyze] crossings={self.results['cross_section_counts']}, tijor_values={self.results['tijor_values']}")
except Exception:
pass
# Safely create hits_xy array
if len(all_hits_xy) > 0:
try:
self.results['hits_xy'] = np.asarray(all_hits_xy, dtype=np.int32).reshape(-1, 2)
except Exception:
valid_pairs = []
for pair in all_hits_xy:
if isinstance(pair, (list, tuple, np.ndarray)) and len(pair) == 2:
y_val, x_val = pair
if np.isscalar(y_val) and np.isscalar(x_val):
valid_pairs.append([int(y_val), int(x_val)])
self.results['hits_xy'] = np.asarray(valid_pairs, dtype=np.int32).reshape(-1, 2) if len(valid_pairs) > 0 else np.zeros((0, 2), dtype=np.int32)
else:
self.results['hits_xy'] = np.zeros((0, 2), dtype=np.int32)
return True
def _filter_by_distance(self, coords, min_distance):
"""Filter coordinates by minimum distance"""
if len(coords) <= 1:
return coords
filtered = [coords[0]]
for coord in coords[1:]:
distances = [np.linalg.norm(coord - f) for f in filtered]
if min(distances) >= min_distance:
filtered.append(coord)
return np.array(filtered)
def get_summary_stats(self):
"""Get summary statistics"""
if not self.results:
return {}
tijor_values = self.results['tijor_values']
filtered_counts = self.results['filtered_cross_section_counts']
return {
'mean_tijor': np.mean(tijor_values),
'std_tijor': np.std(tijor_values),
'max_tijor': np.max(tijor_values),
'min_tijor': np.min(tijor_values),
'total_cross_sections': np.sum(filtered_counts),
'mean_cross_sections_per_rectangle': np.mean(filtered_counts)
}
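# Worked example of a single TiJOR value (numbers are illustrative): a square of side
# 200 px has perimeter 4 * 200 = 800 px, so 40 accepted crossings give
# TiJOR = 40 / 800 = 0.05 crossings per pixel of perimeter.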
class ZO1RISQuantifier:
"""Circular (Sholl-style) crossings-per-circumference for ZO-1 networks"""
def __init__(self, packing_factor=1.5, min_radius_percent=10, max_radius_percent=80, num_circles=15, min_separation=5):
self.kappa = float(packing_factor)
self.min_radius_percent = float(min_radius_percent)
self.max_radius_percent = float(max_radius_percent)
self.num_circles = int(num_circles)
self.min_separation = int(min_separation)
self.results = {}
def analyze(self, membrane_mask, d_eff_pixels, scale_factor=1.0):
"""Analyze membrane network using concentric circles with proper intersection detection"""
if membrane_mask is None:
return False
# Calculate reference density
d_ref = self.kappa / float(d_eff_pixels)
# Generate radii for concentric circles with percentage-based spacing
img_size = min(membrane_mask.shape) / 2 # Half of smaller dimension
min_radius = img_size * self.min_radius_percent / 100
max_radius = img_size * self.max_radius_percent / 100
radii = np.linspace(min_radius, max_radius, self.num_circles)
self.results['radii'] = radii
# Calculate crossings for each radius using proper intersection detection
crossings = []
hits_xy = []
center_y, center_x = np.array(membrane_mask.shape) / 2
for radius in radii:
            # Coordinate grids for distance tests against the circle boundary
            y, x = np.ogrid[:membrane_mask.shape[0], :membrane_mask.shape[1]]
# Find membrane pixels near the circle (within 2 pixels tolerance)
circle_boundary = np.logical_and(
(x - center_x)**2 + (y - center_y)**2 <= (radius + 2)**2,
(x - center_x)**2 + (y - center_y)**2 >= (radius - 2)**2
)
# Count intersections: membrane pixels near circle boundary (boolean-safe)
intersection_region = (membrane_mask > 0) & circle_boundary
# Apply minimum separation filtering to avoid counting clustered pixels
            if np.any(intersection_region):
                y_coords, x_coords = np.where(intersection_region)
# Filter points with minimum separation
filtered_points = []
for y_coord, x_coord in zip(y_coords, x_coords):
# Check if this point is far enough from existing points
is_unique = True
for existing_y, existing_x in filtered_points:
distance = np.sqrt((y_coord - existing_y)**2 + (x_coord - existing_x)**2)
if distance < self.min_separation:
is_unique = False
break
if is_unique:
filtered_points.append([y_coord, x_coord])
crossing_count = len(filtered_points)
crossings.append(crossing_count)
# Store filtered hit coordinates for visualization
hits_xy.extend(filtered_points)
else:
crossings.append(0)
self.results['crossings'] = np.array(crossings)
# Safely create hits_xy array
if len(hits_xy) > 0:
try:
self.results['hits_xy'] = np.asarray(hits_xy, dtype=np.int32).reshape(-1, 2)
except Exception:
# Fallback: filter only valid pairs
valid_pairs = [(int(y), int(x)) for y, x in hits_xy if isinstance(y, (int, np.integer)) and isinstance(x, (int, np.integer))]
self.results['hits_xy'] = np.asarray(valid_pairs, dtype=np.int32).reshape(-1, 2) if valid_pairs else np.zeros((0, 2), dtype=np.int32)
else:
self.results['hits_xy'] = np.zeros((0, 2), dtype=np.int32)
# Calculate RIS metrics as crossings per pixel length, averaged over all circles
radii = self.results['radii']
circle_lengths = 2.0 * np.pi * np.maximum(radii, 1e-6)
crossings_array = self.results['crossings'].astype(np.float32)
densities = crossings_array / circle_lengths # crossings per pixel length
self.results['crossings_per_length'] = densities
if len(densities) > 0:
d_mean = float(np.mean(densities))
d_peak = float(np.max(densities))
ris = (d_mean / d_ref) if d_ref > 0 else 0.0
ris_peak = (d_peak / d_ref) if d_ref > 0 else 0.0
self.results['RIS'] = float(ris)
self.results['RIS_peak'] = float(ris_peak)
self.results['d_mean'] = d_mean
self.results['d_peak'] = d_peak
self.results['d_ref'] = float(d_ref)
return True
def get_summary_stats(self):
"""Get summary statistics for RIS analysis"""
if not self.results:
return {}
crossings = self.results.get('crossings', np.array([]))
densities = self.results.get('crossings_per_length', np.array([]))
return {
'RIS': self.results.get('RIS', np.nan),
'RIS_peak': self.results.get('RIS_peak', np.nan),
'd_mean': self.results.get('d_mean', np.nan), # mean crossings per pixel length
'd_peak': self.results.get('d_peak', np.nan), # peak crossings per pixel length
'd_ref': self.results.get('d_ref', np.nan),
'total_crossings': float(np.sum(crossings)) if crossings.size else 0.0,
'mean_crossings_per_px': float(np.mean(densities)) if densities.size else np.nan,
'packing_factor': self.kappa
}
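# Worked example of the RIS computation in analyze() (numbers are illustrative):
# with packing_factor kappa = 1.5 and d_eff_pixels = 20, the reference density is
# d_ref = 1.5 / 20 = 0.075 crossings per px of circumference. A measured mean density
# d_mean = 0.06 then gives RIS = d_mean / d_ref = 0.06 / 0.075 = 0.8.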
def run_analysis(analysis_geometry, initial_size, max_size, num_steps, min_distance, packing_factor, min_radius_percent, max_radius_percent, num_circles, min_separation, show_contours=True, show_rectangles=True, show_cross_sections=True):
"""Run the selected analysis method"""
if global_state['membrane_mask'] is None:
return "No membrane mask available. Please run segmentation first.", None, None
try:
# Persist current analysis mode for downstream export/reporting
global_state['analysis_geometry'] = analysis_geometry
print(f"[run_analysis] mode={analysis_geometry}, init_size={initial_size}, max_size={max_size}, steps={num_steps}, min_dist={min_distance}, kappa={packing_factor}, minR%={min_radius_percent}, maxR%={max_radius_percent}, circles={num_circles}, min_sep={min_separation}")
if analysis_geometry == "Circles (RIS - recommended)":
# RIS analysis with user-configurable parameters
quantifier = ZO1RISQuantifier(
packing_factor=packing_factor,
min_radius_percent=min_radius_percent,
max_radius_percent=max_radius_percent,
num_circles=num_circles,
min_separation=min_separation
)
            # The d_eff_pixels argument is hard-coded to 20 px here
            success = quantifier.analyze(global_state['membrane_mask'], 20, scale_factor=1.0)
if success:
global_state['quantifier'] = quantifier
summary = quantifier.get_summary_stats()
print(f"[run_analysis][RIS] summary keys={list(summary.keys())}")
# Create results display
ris = summary.get('RIS', np.nan)
ris_peak = summary.get('RIS_peak', np.nan)
mean_crossings = summary.get('d_mean', np.nan)
total_crossings = summary.get('total_crossings', np.nan)
d_ref = summary.get('d_ref', np.nan)
results_text = f"""
🔵 **RIS Analysis Results**
RIS Score: {ris:.4f}
RIS Peak: {ris_peak:.4f}
Mean Crossings/px length: {mean_crossings:.4f}
Total Crossings: {total_crossings:.0f}
Reference Density (per px length): {d_ref:.4f}
Parameters: κ={packing_factor:.1f}, Min Radius={min_radius_percent}%, Max Radius={max_radius_percent}%, Circles={num_circles}, Min Sep={min_separation}px
"""
# Create visualization
viz = create_visualization(
global_state['img_gray'],
global_state['masks'],
quantifier,
analysis_geometry,
show_contours=show_contours,
show_rectangles=show_rectangles,
show_cross_sections=show_cross_sections
)
return results_text, viz, summary
else:
return "RIS analysis failed.", None, None
else:
# TiJOR analysis
quantifier = ZO1TiJORQuantifier(initial_size, max_size, num_steps, min_distance)
success = quantifier.analyze(global_state['membrane_mask'])
if success:
global_state['quantifier'] = quantifier
# Defensive normalization of results to avoid shape errors downstream
try:
if 'hits_xy' in quantifier.results:
quantifier.results['hits_xy'] = _sanitize_hits_xy(quantifier.results['hits_xy'])
if 'rectangle_sizes' in quantifier.results:
sizes_arr = np.asarray(quantifier.results['rectangle_sizes']).reshape(-1)
quantifier.results['rectangle_sizes'] = sizes_arr.astype(np.float32, copy=False)
except Exception:
quantifier.results['hits_xy'] = np.zeros((0, 2), dtype=np.int32)
summary = quantifier.get_summary_stats()
print(f"[run_analysis][TiJOR] sizes.shape={quantifier.results.get('rectangle_sizes', np.array([])).shape}, hits.shape={quantifier.results.get('hits_xy', np.zeros((0,2))).shape}, summary keys={list(summary.keys())}")
# Create results display
mean_tijor = float(summary.get('mean_tijor', np.nan))
total_cross_sections = float(summary.get('total_cross_sections', np.nan))
cells_detected = int(global_state['masks'].max()) if global_state['masks'] is not None else 0
results_text = f"""
📊 **TiJOR Analysis Results**
Mean TiJOR: {mean_tijor:.4f}
Total Cross-sections: {total_cross_sections:.0f}
Cells Detected: {cells_detected}
"""
# Create visualization
viz = create_visualization(
global_state['img_gray'],
global_state['masks'],
quantifier,
analysis_geometry,
show_contours=show_contours,
show_rectangles=show_rectangles,
show_cross_sections=show_cross_sections
)
return results_text, viz, summary
else:
return "TiJOR analysis failed.", None, None
except Exception as e:
tb = traceback.format_exc()
print(f"[run_analysis][EXCEPTION] {e}\n{tb}")
return f"Analysis failed: {str(e)}\n\nTraceback:\n{tb}", None, None
def process_image(image, cell_diameter, scale_factor, enable_validation, validation_method):
"""Process uploaded image and run segmentation"""
if image is None:
return "No image uploaded", None, None
try:
# Accept either numpy array or filepath; store basename for exports
if isinstance(image, str):
try:
global_state['image_basename'] = Path(image).stem
except Exception:
global_state['image_basename'] = 'results'
# Read with unchanged flag to preserve bit depth
img = cv2.imread(image, cv2.IMREAD_UNCHANGED)
if img is None:
return "Failed to read image from path", None, None
# Convert to grayscale
if len(img.shape) == 3:
img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
else:
img_gray = img
else:
# Numpy array path
global_state['image_basename'] = global_state.get('image_basename') or 'results'
# Handle different image formats and data types (especially TIFF)
if len(image.shape) == 3:
# Convert to RGB first, then grayscale
if image.shape[2] == 4: # RGBA
image = image[:, :, :3] # Remove alpha channel
img_gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
else:
img_gray = image
# Ensure image is 8-bit (TIFF files might be 16-bit or float)
if img_gray.dtype != np.uint8:
imin = float(img_gray.min())
imax = float(img_gray.max())
if imax > imin:
# Normalize to 0-255 range
img_gray = ((img_gray - imin) / (imax - imin) * 255).astype(np.uint8)
else:
# Flat image; return zeros to avoid NaNs
img_gray = np.zeros_like(img_gray, dtype=np.uint8)
global_state['img_gray'] = img_gray
# Run segmentation
masks, membrane_mask, message = run_segmentation_only(
img_gray, cell_diameter, scale_factor, enable_validation, validation_method
)
if masks is not None:
# Create initial visualization with masks
viz = create_visualization_with_masks(img_gray, masks)
return message, viz, masks
else:
return message, None, None
except Exception as e:
return f"Image processing failed: {str(e)}", None, None
def export_results(format_type):
"""Export analysis results; returns (text_content, file_path)."""
if global_state['quantifier'] is None:
return "No analysis results to export", None
try:
# Determine filename base and suffix
base = global_state.get('image_basename') or 'results'
suffix = '_RIS' if global_state.get('analysis_geometry') == "Circles (RIS - recommended)" else '_TiJOR'
if format_type == "CSV":
if global_state['analysis_geometry'] == "Circles (RIS - recommended)":
# Export RIS results
results_data = []
if hasattr(global_state['quantifier'], 'results') and 'crossings' in global_state['quantifier'].results:
for i, crossing in enumerate(global_state['quantifier'].results['crossings']):
results_data.append({
'Circle': i+1,
'Crossings': int(crossing)
})
df = pd.DataFrame(results_data)
csv_data = df.to_csv(index=False)
file_path = f"{base}{suffix}.csv"
with open(file_path, 'w', newline='') as f:
f.write(csv_data)
return csv_data, file_path
else:
# Export TiJOR results
results_data = []
if hasattr(global_state['quantifier'], 'results') and 'tijor_values' in global_state['quantifier'].results:
for i, (size, tijor) in enumerate(zip(
global_state['quantifier'].results['rectangle_sizes'],
global_state['quantifier'].results['tijor_values']
)):
results_data.append({
'Step': i+1,
'Size (px)': f'{size:.1f}',
'TiJOR': f'{tijor:.4f}'
})
df = pd.DataFrame(results_data)
csv_data = df.to_csv(index=False)
file_path = f"{base}{suffix}.csv"
with open(file_path, 'w', newline='') as f:
f.write(csv_data)
return csv_data, file_path
else:
# Export as text report
summary = global_state['quantifier'].get_summary_stats()
if global_state['analysis_geometry'] == "Circles (RIS - recommended)":
ris = summary.get('RIS', np.nan)
ris_peak = summary.get('RIS_peak', np.nan)
mean_crossings = summary.get('d_mean', np.nan)
total_crossings = summary.get('total_crossings', np.nan)
d_ref = summary.get('d_ref', np.nan)
report = f"""ZO-1 RIS Network Analysis Report
{'='*50}
RIS Score: {ris:.4f}
RIS Peak: {ris_peak:.4f}
Mean Crossings/px length: {mean_crossings:.4f}
Total Crossings: {total_crossings:.0f}
Reference Density (per px length): {d_ref:.4f}
"""
else:
mean_tijor = summary.get('mean_tijor', np.nan)
total_cross_sections = summary.get('total_cross_sections', np.nan)
cells_detected = int(global_state['masks'].max()) if global_state['masks'] is not None else 0
report = f"""ZO-1 TiJOR Network Analysis Report
{'='*50}
Mean TiJOR: {mean_tijor:.4f}
Total Cross-sections: {total_cross_sections:.0f}
Cells Detected: {cells_detected}
"""
file_path = f"{base}{suffix}.txt"
with open(file_path, 'w', encoding='utf-8') as f:
f.write(report)
return report, file_path
except Exception as e:
return f"Export failed: {str(e)}", None