OrganoidCounter / Organoid_Analyzer_AI_zstack2_HF.py
ReyaLabColumbia's picture
Upload Organoid_Analyzer_AI_zstack2_HF.py
8f635c1 verified
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Mar 20 14:23:27 2025
@author: mattc
"""
import os
import cv2
#this is the huggingface version
from PIL import Image
def pad(img_np, tw=2048, th=1536):
"""
Pads a numpy image (grayscale or RGB) to 2048x1536 (width x height) with white pixels.
Pads at the bottom and right as needed.
"""
height, width = img_np.shape[:2]
pad_bottom = max(0, th - height)
pad_right = max(0, tw - width)
# Padding: (top, bottom, left, right)
if img_np.ndim == 3:
# Color image (H, W, 3)
border_value = [255, 255, 255]
else:
# Grayscale image (H, W)
border_value = 255
padded = cv2.copyMakeBorder(
img_np,
top=0, bottom=pad_bottom,
left=0, right=pad_right,
borderType=cv2.BORDER_CONSTANT,
value=border_value
)
return padded
#this is the huggingface version
import numpy as np
from PIL import Image
def cut_img(img, patch_size=512):
img_map = {}
width, height = img.size
i_num = height // patch_size
j_num = width // patch_size
count = 1
for i in range(i_num):
for j in range(j_num):
cropped_img = img.crop((
patch_size * j,
patch_size * i,
patch_size * (j + 1),
patch_size * (i + 1)
))
img_map[count] = cropped_img
count += 1
return img_map, i_num, j_num # Return rows and cols for stitching
import numpy as np
import numpy as np
from PIL import Image
def stitch(img_map, i_num, j_num, min_width=2048, min_height=1536):
tiles = []
count = 1
for i in range(i_num):
row_tiles = []
for j in range(j_num):
tile = np.array(img_map[count])
row_tiles.append(tile)
count += 1
row_img = np.hstack(row_tiles)
tiles.append(row_img)
stitched = np.vstack(tiles)
# Pad the stitched image if it's less than min_width/min_height
h, w = stitched.shape[:2]
pad_h = max(0, min_height - h)
pad_w = max(0, min_width - w)
if pad_h > 0 or pad_w > 0:
# Pad as (top, bottom), (left, right), (channels)
if stitched.ndim == 3:
stitched = np.pad(stitched, ((0, pad_h), (0, pad_w), (0, 0)), 'constant')
else:
stitched = np.pad(stitched, ((0, pad_h), (0, pad_w)), 'constant')
return stitched
from PIL import Image
import matplotlib.pyplot as plt
def visualize_segmentation(mask, image=0):
plt.figure(figsize=(10, 5))
if(not np.isscalar(image)):
# Show original image if it is entered
plt.subplot(1, 2, 1)
plt.imshow(image)
plt.title("Original Image")
plt.axis("off")
# Show segmentation mask
plt.subplot(1, 2, 2)
plt.imshow(mask, cmap="gray") # Show as grayscale
plt.title("Segmentation Mask")
plt.axis("off")
plt.show()
import torch
from transformers import SegformerForSemanticSegmentation
# Load fine-tuned model
#ReyaLabColumbia/Segformer_Colony_Counter
#ReyaLabColumbia/OrganoidCounter
model = SegformerForSemanticSegmentation.from_pretrained("ReyaLabColumbia/Segformer_Organoid_Counter_GP") # Adjust path
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval() # Set to evaluation mode
# Load image processor
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
image_processor = SegformerImageProcessor.from_pretrained("nvidia/segformer-b3-finetuned-cityscapes-1024-1024")
def preprocess_image(image):
image = image.convert("RGB") # Open and convert to RGB
inputs = image_processor(image, return_tensors="pt") # Preprocess for model
return image, inputs["pixel_values"]
def postprocess_mask(logits):
mask = torch.argmax(logits, dim=1) # Take argmax across the class dimension
return mask.squeeze().cpu().numpy() # Convert to NumPy array
def eval_img(image):
# Load and preprocess image
image, pixel_values = preprocess_image(image)
pixel_values = pixel_values.to(device)
with torch.no_grad(): # No gradient calculation for inference
outputs = model(pixel_values=pixel_values) # Run model
logits = outputs.logits
# Convert logits to segmentation mask
segmentation_mask = postprocess_mask(logits)
#visualize_segmentation(segmentation_mask,image)
segmentation_mask = cv2.resize(segmentation_mask, (512, 512), interpolation=cv2.INTER_LINEAR_EXACT)
return(segmentation_mask)
def find_colonies(mask, size_cutoff, circ_cutoff):
binary_mask = np.where(mask == 1, 255, 0).astype(np.uint8)
contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
contoursf = []
areas = []
for x in contours:
area = cv2.contourArea(x)
if (area < size_cutoff):
continue
perimeter = cv2.arcLength(x, True)
# Avoid division by zero
if perimeter == 0:
continue
# Calculate circularity
circularity = (4 * np.pi * area) / (perimeter ** 2)
if circularity >= circ_cutoff:
contoursf.append(x)
areas.append(area)
return(contoursf, areas)
def find_necrosis(mask):
binary_mask = np.where(mask == 2, 255, 0).astype(np.uint8)
contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
return(contours)
# contour_image = np.zeros_like(p)
# contours = find_necrosis(p)
# cv2.drawContours(contour_image, contours, -1, (255), 2)
# visualize_segmentation(contour_image)
import pandas as pd
def compute_centroid(contour):
M = cv2.moments(contour)
if M["m00"] == 0: # Avoid division by zero
return None
cx = int(M["m10"] / M["m00"])
cy = int(M["m01"] / M["m00"])
return (cx, cy)
def contours_overlap_using_mask(contour1, contour2, image_shape=(1536, 2048)):
"""Check if two contours overlap using a bitwise AND mask."""
import numpy as np
import cv2
mask1 = np.zeros(image_shape, dtype=np.uint8)
mask2 = np.zeros(image_shape, dtype=np.uint8)
# Draw each contour as a white shape on its respective mask
cv2.drawContours(mask1, [contour1], -1, 255, thickness=cv2.FILLED)
cv2.drawContours(mask2, [contour2], -1, 255, thickness=cv2.FILLED)
# Compute bitwise AND to find overlapping regions
overlap = cv2.bitwise_and(mask1, mask2)
return np.any(overlap)
def analyze_colonies(mask, size_cutoff, circ_cutoff, img):
colonies,areas = find_colonies(mask, size_cutoff, circ_cutoff)
necrosis = find_necrosis(mask)
data = []
for x in range(len(colonies)):
colony = colonies[x]
colony_area = areas[x]
centroid = compute_centroid(colony)
mask = np.zeros(img.shape, np.uint8)
cv2.drawContours(mask, [colony], -1, 255, cv2.FILLED)
pix = img[mask == 255]
# Check if any necrosis contour is inside the colony
necrosis_area = 0
nec_list =[]
for nec in necrosis:
# Check if the first point of the necrosis contour is inside the colony
if contours_overlap_using_mask(colony, nec):
nec_area = cv2.contourArea(nec)
necrosis_area += nec_area
nec_list.append(nec)
data.append({
"organoid_area": colony_area,
"necrotic_area": necrosis_area,
"centroid": centroid,
"percent_necrotic": necrosis_area/colony_area,
"contour": colony,
"nec_contours": nec_list,
'mean_pixel_value':np.mean(pix)
})
# Convert results to a DataFrame
df = pd.DataFrame(data)
df.index = range(1,len(df.index)+1)
return(df)
def contour_overlap(contour1, contour2, centroid1, centroid2, area1, area2, centroid_thresh=30, area_thresh = .4, img_shape = (1536, 2048)):
"""
Determines the overlap between two contours.
Returns:
0: No overlap
1: Overlap but does not meet strict conditions
2: Overlap >= 80% of the larger contour and centroids are close
"""
# Create blank images
img1 = np.zeros(img_shape, dtype=np.uint8)
img2 = np.zeros(img_shape, dtype=np.uint8)
# Draw filled contours
cv2.drawContours(img1, [contour1], -1, 255, thickness=cv2.FILLED)
cv2.drawContours(img2, [contour2], -1, 255, thickness=cv2.FILLED)
# Compute overlap
intersection = cv2.bitwise_and(img1, img2)
intersection_area = np.count_nonzero(intersection)
if intersection_area == 0:
return 0 # No overlap
# Compute centroid distance
centroid_distance = float(np.sqrt(abs(centroid1[0]-centroid2[0])**2 + abs(centroid1[1]-centroid2[1])**2))
# Check percentage overlap relative to the larger contour
overlap_ratio = intersection_area/max(area1, area2)
if overlap_ratio >= area_thresh and centroid_distance <= centroid_thresh:
if area1 > area2:
return(2)
else:
return(3)
else:
return 1 # Some overlap but not meeting strict criteria
def compare_frames(frame1, frame2, centroid_dist=30, overlap_area=.4):
for i in range(1, len(frame1)+1):
if frame1.loc[i,"exclude"] == True:
continue
for j in range(1, len(frame2)+1):
if frame2.loc[j,"exclude"] == True:
continue
temp = contour_overlap(frame1.loc[i, "contour"], frame2.loc[j, "contour"], frame1.loc[i, "centroid"], frame2.loc[j, "centroid"], frame1.loc[i, "organoid_area"], frame2.loc[j, "organoid_area"],centroid_dist, overlap_area)
if temp ==2:
frame2.loc[j,"exclude"] = True
elif temp ==3:
frame1.loc[i, "exclude"] = True
break
frame1 = frame1[frame1["exclude"]==False]
frame2 = frame2[frame2["exclude"]==False]
df = pd.concat([frame1, frame2], axis=0)
df.index = range(1,len(df.index)+1)
return(df)
def main(args):
min_size = args[1]
min_circ = args[2]
do_necrosis = args[5]
colonies = {}
files = args[0]
for idx,x in enumerate(files):
img_map, i_num, j_num = cut_img(Image.fromarray(pad(np.array(files[idx]),512,512)))
for z in img_map:
img_map[z] = eval_img(img_map[z])
del z
p = stitch(img_map, i_num, j_num)
frame = analyze_colonies(p, min_size, min_circ, np.array(files[idx]))
frame["source"] = idx
frame["exclude"] = False
if isinstance(colonies, dict):
colonies = frame
else:
colonies = compare_frames(frame, colonies, args[3], args[4])
if len(colonies) <=0:
img = pad(np.array(files[0]))
if img.ndim == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
caption = np.ones((150, 2048, 3), dtype=np.uint8) * 255 # Multiply by 255 to make it white
cv2.putText(caption, 'No organoids detected.', (40, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 3)
cv2.imwrite('results.png', np.vstack((img, caption)))
colonies = pd.DataFrame({"organoid_number":[], 'organoid_volume':[], "organoid_area":[],'mean_pixel_value':[], "centroid":[], "necrotic_area":[],"percent_necrotic":[]})
with pd.ExcelWriter('results.xlsx') as writer:
colonies.to_excel(writer, sheet_name="Organoid data", index=False)
return(np.vstack((img, caption)), 'results.png', 'results.xlsx')
counts = {}
for x in range(len(files)):
counts[x] = list(colonies["source"]).count(x)
best = [x, counts[x]]
del x
for x in counts:
if counts[x] > best[1]:
best[0] = x
best[1] = counts[x]
del x, counts
best = best[0]
img = pad(np.array(files[best]))
for x in range(len(files)):
if x == best:
continue
#mask = np.zeros_like(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY))
mask = np.zeros_like(img)
contours = colonies[colonies["source"]==x]
contours = list(contours["contour"])
cv2.drawContours(mask, contours, -1, 255, thickness=cv2.FILLED)
# Extract all ROIs from the source image at once
src_image = pad(np.array(files[x]))
roi = cv2.bitwise_and(src_image, src_image, mask=mask)
# Paste the extracted regions onto the destination image
#print("ROI Shape:", roi.shape)
#print("img Shape:", img.shape)
#print("mask Shape:", mask.shape)
#print("mask Shape fixed:", mask.shape)
np.copyto(img, roi, where=(mask== 255))
try:
del x, mask, src_image, roi, best, contours
except:
pass
img = cv2.copyMakeBorder(img,top=0, bottom=10,left=0,right=10, borderType=cv2.BORDER_CONSTANT, value=[255, 255, 255])
colonies = colonies.sort_values(by=["organoid_area"], ascending=False)
colonies = colonies[colonies["organoid_area"]>= min_size]
colonies.index = range(1,len(colonies.index)+1)
#nearby is a boolean list of whether a colony has overlapping colonies. If so, labelling positions change
nearby = [False]*len(colonies)
areas = list(colonies["organoid_area"])
if img.ndim == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
for i in range(len(colonies)):
cv2.drawContours(img, [list(colonies["contour"])[i]], -1, (0, 255, 0), 2)
if do_necrosis == True:
cv2.drawContours(img, list(colonies['nec_contours'])[i], -1, (0, 0, 255), 2)
coords = list(list(colonies["centroid"])[i])
if coords[0] > 1950:
#if a colony is too close to the right edge, makes the label move to left
coords[0] = 1950
for j in range(len(colonies)):
if j == i:
continue
coords2 = list(list(colonies["centroid"])[j])
if ((abs(coords[0] - coords2[0]) + abs(coords[1] - coords2[1])) <= 40):
nearby[i] = True
break
if nearby[i] ==True:
#If the colony has nearby colonies, this adjusts the labels so they are smaller and are positioned based on the approximate radius of the colony
# a random number is generated, and based on that, the label is put at the top or bottom, left or right
radius= int(np.sqrt(areas[i]/3.1415)*.9)
n = np.random.random()
if n >.75:
new_x = min(coords[0] + radius, 2000)
new_y = min(coords[1] + radius, 1480)
elif n >.5:
new_x = min(coords[0] + radius, 2000)
new_y = max(coords[1] - radius, 50)
elif n >.25:
new_x = max(coords[0] - radius, 0)
new_y = min(coords[1] + radius, 1480)
else:
new_x = max(coords[0] - radius, 0)
new_y = max(coords[1] - radius, 50)
cv2.putText(img, str(colonies.index[i]), (new_x,new_y), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)
del n, radius, new_x, new_y
else:
cv2.putText(img, str(colonies.index[i]), coords, cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 0), 2)
del nearby, areas
colonies = colonies.drop('contour', axis=1)
colonies = colonies.drop('nec_contours', axis=1)
colonies = colonies.drop('exclude', axis=1)
img = cv2.copyMakeBorder(img,top=10, bottom=0,left=10,right=0, borderType=cv2.BORDER_CONSTANT, value=[255, 255, 255])
colonies.insert(loc=0, column="organoid_number", value=[str(x) for x in range(1, len(colonies)+1)])
total_area_dark = sum(colonies['necrotic_area'])
total_area_light = sum(colonies['organoid_area'])
ratio = total_area_dark/(abs(total_area_light)+1)
radii = [np.sqrt(x/3.1415) for x in list(colonies['organoid_area'])]
volumes = [4.189*(x**3) for x in radii]
colonies['organoid_volume'] = volumes
del radii, volumes
meanpix = sum(colonies['mean_pixel_value'] * colonies['organoid_area'])/total_area_light
colonies = colonies[["organoid_number", 'organoid_volume', "organoid_area",'mean_pixel_value', "centroid", "necrotic_area","percent_necrotic", "source"]]
colonies.loc[len(colonies)+1] = ["Total", sum(colonies['organoid_volume']), total_area_light, meanpix, None, total_area_dark, ratio, None]
del meanpix
Parameters = pd.DataFrame({"Minimum organoid size in pixels":[min_size], "Minimum organoid circularity":[min_circ]})
if do_necrosis == False:
colonies = colonies.drop('necrotic_area', axis=1)
colonies = colonies.drop('percent_necrotic', axis=1)
with pd.ExcelWriter("Group_analysis_results.xlsx") as writer:
colonies.to_excel(writer, sheet_name="Organoid data", index=False)
Parameters.to_excel(writer, sheet_name="Parameters", index=False)
caption = np.ones((150, 2068, 3), dtype=np.uint8) * 255 # Multiply by 255 to make it white
if do_necrosis == True:
cv2.putText(caption, "Total area necrotic: "+str(total_area_dark)+ ", Total area living: "+str(total_area_light)+", Ratio: "+str(ratio), (40, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 3)
else:
cv2.putText(caption, "Total area: "+str(total_area_light), (40, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 3)
cv2.putText(caption, "Total number of organoids: "+str(len(colonies)-1), (40, 110), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 3)
print('img ndim: ' +str(img.ndim))
print('caption ndim: ' +str(caption.ndim))
cv2.imwrite('Group_analysis_results.png', np.vstack((img, caption)))
return(np.vstack((img, caption)), 'Group_analysis_results.png', 'Group_analysis_results.xlsx')