cellpose / app.py
mouseland's picture
Update app.py
7a96d4b verified
import numpy as np
import gradio as gr
import spaces
import cv2
from cellpose import models
from matplotlib.colors import hsv_to_rgb
import matplotlib.pyplot as plt
import os, io, base64
from PIL import Image
from cellpose.io import imread, imsave
import glob
from huggingface_hub import hf_hub_download
img = np.zeros((96, 128), dtype = np.uint8)
fp0 = Image.fromarray(img)
#fp0 = "0.png"
#imsave(fp0, img)
# data retrieval
def download_weights():
return hf_hub_download(repo_id="mouseland/cellpose-sam", filename="cpsam")
#os.system("wget -q https://huggingface.co/mouseland/cellpose-sam/resolve/main/cpsam")
def download_weights_old():
import os, requests
fname = ['cpsam']
url = ["https://osf.io/d7c8e/download"]
for j in range(len(url)):
if not os.path.isfile(fname[j]):
ntries = 0
while ntries<10:
try:
r = requests.get(url[j])
except:
print("!!! Failed to download data !!!")
ntries += 1
print(ntries)
if r.status_code != requests.codes.ok:
print("!!! Failed to download data !!!")
else:
with open(fname[j], "wb") as fid:
fid.write(r.content)
try:
fpath = download_weights()
model = models.CellposeModel(gpu=True, pretrained_model = fpath)
except Exception as e:
print(f"Error loading model: {e}")
exit(1)
def plot_flows(y):
Y = (np.clip(normalize99(y[0][0]),0,1) - 0.5) * 2
X = (np.clip(normalize99(y[1][0]),0,1) - 0.5) * 2
H = (np.arctan2(Y, X) + np.pi) / (2*np.pi)
S = normalize99(y[0][0]**2 + y[1][0]**2)
HSV = np.concatenate((H[:,:,np.newaxis], S[:,:,np.newaxis], S[:,:,np.newaxis]), axis=-1)
HSV = np.clip(HSV, 0.0, 1.0)
flow = (hsv_to_rgb(HSV) * 255).astype(np.uint8)
return flow
def plot_outlines(img, masks):
img = normalize99(img)
img = np.clip(img, 0, 1)
outpix = []
contours, hierarchy = cv2.findContours(masks.astype(np.int32), mode=cv2.RETR_FLOODFILL, method=cv2.CHAIN_APPROX_SIMPLE)
for c in range(len(contours)):
pix = contours[c].astype(int).squeeze()
if len(pix)>4:
peri = cv2.arcLength(contours[c], True)
approx = cv2.approxPolyDP(contours[c], 0.001, True)[:,0,:]
outpix.append(approx)
figsize = (6,6)
if img.shape[0]>img.shape[1]:
figsize = (6*img.shape[1]/img.shape[0], 6)
else:
figsize = (6, 6*img.shape[0]/img.shape[1])
fig = plt.figure(figsize=figsize, facecolor='k')
ax = fig.add_axes([0.0,0.0,1,1])
ax.set_xlim([0,img.shape[1]])
ax.set_ylim([0,img.shape[0]])
ax.imshow(img[::-1], origin='upper', aspect = 'auto')
if outpix is not None:
for o in outpix:
ax.plot(o[:,0], img.shape[0]-o[:,1], color=[1,0,0], lw=1)
ax.axis('off')
#bytes_image = io.BytesIO()
#plt.savefig(bytes_image, format='png', facecolor=fig.get_facecolor(), edgecolor='none')
#bytes_image.seek(0)
#img_arr = np.frombuffer(bytes_image.getvalue(), dtype=np.uint8)
#bytes_image.close()
#img = cv2.imdecode(img_arr, 1)
#img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
#del bytes_image
#fig.clf()
#plt.close(fig)
buf = io.BytesIO()
fig.savefig(buf, bbox_inches='tight')
buf.seek(0)
pil_img = Image.open(buf)
plt.close(fig)
return pil_img
def plot_overlay(img, masks):
if img.ndim>2:
img_gray = img.astype(np.float32).mean(axis=-1)
else:
img_gray = img.astype(np.float32)
img = normalize99(img_gray)
#img = np.clip(img, 0, 1)
HSV = np.zeros((img.shape[0], img.shape[1], 3), np.float32)
HSV[:,:,2] = np.clip(img*1.5, 0, 1.0)
for n in range(int(masks.max())):
ipix = (masks==n+1).nonzero()
HSV[ipix[0],ipix[1],0] = np.random.rand()
HSV[ipix[0],ipix[1],1] = 1.0
RGB = (hsv_to_rgb(HSV) * 255).astype(np.uint8)
return RGB
def normalize99(img):
X = img.copy()
X = (X - np.percentile(X, 1)) / (1e-10 + np.percentile(X, 99) - np.percentile(X, 1))
return X
def image_resize(img, resize=400):
ny,nx = img.shape[:2]
if np.array(img.shape).max() > resize:
if ny>nx:
nx = int(nx/ny * resize)
ny = resize
else:
ny = int(ny/nx * resize)
nx = resize
shape = (nx,ny)
img = cv2.resize(img, shape)
img = img.astype(np.uint8)
return img
@spaces.GPU(duration=10)
def run_model_gpu(img, max_iter, flow_threshold, cellprob_threshold):
masks, flows, _ = model.eval(img, niter = max_iter, flow_threshold = flow_threshold, cellprob_threshold = cellprob_threshold)
return masks, flows
@spaces.GPU(duration=60)
def run_model_gpu60(img, max_iter, flow_threshold, cellprob_threshold):
masks, flows, _ = model.eval(img, niter = max_iter, flow_threshold = flow_threshold, cellprob_threshold = cellprob_threshold)
return masks, flows
@spaces.GPU(duration=240)
def run_model_gpu240(img, max_iter, flow_threshold, cellprob_threshold):
masks, flows, _ = model.eval(img, niter = max_iter, flow_threshold = flow_threshold, cellprob_threshold = cellprob_threshold)
return masks, flows
import datetime
from zipfile import ZipFile
def cellpose_segment(filepath, resize = 1000,max_iter = 250, flow_threshold= 0.4, cellprob_threshold = 0):
zip_path = os.path.splitext(filepath[-1])[0]+"_masks.zip"
#zip_path = 'masks.zip'
with ZipFile(zip_path, 'w') as myzip:
for j in range((len(filepath))):
now = datetime.datetime.now()
formatted_now = now.strftime("%Y-%m-%d %H:%M:%S")
img_input = imread(filepath[j])
#img_input = np.array(img_pil)
img = image_resize(img_input, resize = resize)
maxsize = np.max(img.shape)
if maxsize<=1000:
masks, flows = run_model_gpu(img, max_iter, flow_threshold, cellprob_threshold)
elif maxsize < 5000:
masks, flows = run_model_gpu60(img, max_iter, flow_threshold, cellprob_threshold)
elif maxsize < 20000:
masks, flows = run_model_gpu240(img, max_iter, flow_threshold, cellprob_threshold)
else:
raise ValueError("Image size must be less than 20,000")
print(formatted_now, j, masks.max(), os.path.split(filepath[j])[-1])
target_size = (img_input.shape[1], img_input.shape[0])
if (target_size[0]!=img.shape[1] or target_size[1]!=img.shape[0]):
# scale it back to keep the orignal size
masks_rsz = cv2.resize(masks.astype('uint16'), target_size, interpolation=cv2.INTER_NEAREST).astype('uint16')
else:
masks_rsz = masks.copy()
fname_masks = os.path.splitext(filepath[j])[0]+"_masks.tif"
imsave(fname_masks, masks_rsz)
myzip.write(fname_masks, arcname = os.path.split(fname_masks)[-1])
#masks, flows, _ = model.eval(img, channels=[0,0])
flows = flows[0]
# masks = np.zeros(img.shape[:2])
# flows = np.zeros_like(img)
outpix = plot_outlines(img, masks)
#overlay = plot_overlay(img, masks)
#crand = .2 + .8 * np.random.rand(np.max(masks.flatten()).astype('int')+1,).astype('float32')
#crand[0] = 0
#overlay = Image.fromarray(overlay)
flows = Image.fromarray(flows)
Ly, Lx = img.shape[:2]
outpix = outpix.resize((Lx, Ly), resample = Image.BICUBIC)
#overlay = overlay.resize((Lx, Ly), resample = Image.BICUBIC)
flows = flows.resize((Lx, Ly), resample = Image.BICUBIC)
fname_out = os.path.splitext(filepath[-1])[0]+"_outlines.png"
outpix.save(fname_out) #"outlines.png")
#fname_flows = os.path.splitext(filepath[-1])[0]+"_flows.png"
#flows.save(fname_flows) #"outlines.png")
if len(filepath)>1:
b1 = gr.DownloadButton(visible=True, value = zip_path)
else:
b1 = gr.DownloadButton(visible=True, value = fname_masks)
b2 = gr.DownloadButton(visible=True, value = fname_out) #"outlines.png")
return outpix, flows, b1, b2
def download_function():
b1 = gr.DownloadButton("Download masks as TIFF", visible=False)
b2 = gr.DownloadButton("Download outline image as PNG", visible=False)
return b1, b2
def tif_view(filepath):
fpath, fext = os.path.splitext(filepath)
if fext in ['tiff', 'tif']:
img = imread(filepath[-1])
if img.ndim==2:
img = np.tile(img[:,:,np.newxis], [1,1,3])
elif img.ndim==3:
imin = np.argmin(img.shape)
if imin<2:
img = np.tranpose(img, [2, imin])
else:
raise ValueError("TIF cannot have more than three dimensions")
Ly, Lx, nchan = img.shape
imgi = np.zeros((Ly, Lx, 3))
nn = np.minimum(3, img.shape[-1])
imgi[:,:,:nn] = img[:,:,:nn]
#filepath = fpath+'.png'
imsave(filepath, imgi)
return filepath
def norm_path(filepath):
img = imread(filepath)
img = normalize99(img)
img = np.clip(img, 0, 1)
fpath, fext = os.path.splitext(filepath)
filepath = fpath +'.png'
pil_image = Image.fromarray((255. * img).astype(np.uint8))
pil_image.save(filepath)
#imsave(filepath, pil_image)
return filepath
def update_image(filepath):
for f in filepath:
f = tif_view(f)
filepath_show = norm_path(filepath[-1])
return filepath_show, filepath, fp0, fp0
def update_button(filepath):
filepath = tif_view(filepath)
filepath_show = norm_path(filepath)
return filepath_show, [filepath], fp0, fp0
with gr.Blocks(title = "Hello",
css=".gradio-container {background:purple;}") as demo:
#filepath = ""
with gr.Row():
with gr.Column(scale=2):
gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:20pt; font-weight:bold; text-align:center; color:white;">Cellpose-SAM for cellular
segmentation <a style="color:#cfe7fe; font-size:14pt;" href="https://www.biorxiv.org/content/10.1101/2025.04.28.651001v1" target="_blank">[paper]</a>
<a style="color:white; font-size:14pt;" href="https://github.com/MouseLand/cellpose" target="_blank">[github]</a>
<a style="color:white; font-size:14pt;" href="https://www.youtube.com/watch?v=KIdYXgQemcI" target="_blank">[talk]</a>
</div>""")
gr.HTML("""<h4 style="color:white;">You may need to login/refresh for 5 minutes of free GPU compute per day (enough to process hundreds of images). </h4>""")
input_image = gr.Image(label = "Input", type = "filepath")
with gr.Row():
with gr.Column(scale=1):
with gr.Row():
resize = gr.Number(label = 'max resize', value = 1000)
max_iter = gr.Number(label = 'max iterations', value = 250)
flow_threshold = gr.Number(label = 'flow threshold', value = 0.4)
cellprob_threshold = gr.Number(label = 'cellprob threshold', value = 0)
up_btn = gr.UploadButton("Multi-file upload (png, jpg, tif etc)", visible=True, file_count = "multiple")
#gr.HTML("""<h4 style="color:white;"> Note2: Only the first image of a tif will display the segmentations, but you can download segmentations for all planes. </h4>""")
with gr.Column(scale=1):
send_btn = gr.Button("Run Cellpose-SAM")
down_btn = gr.DownloadButton("Download masks (TIF)", visible=False)
down_btn2 = gr.DownloadButton("Download outlines (PNG)", visible=False)
with gr.Column(scale=2):
outlines = gr.Image(label = "Outlines", type = "pil", format = 'png', value = fp0) #, width = "50vw", height = "20vw")
#img_overlay = gr.Image(label = "Overlay", type = "pil", format = 'png') #, width = "50vw", height = "20vw")
flows = gr.Image(label = "Cellpose flows", type = "pil", format = 'png', value = fp0) #, width = "50vw", height = "20vw")
sample_list = glob.glob("samples/*.png")
#sample_list = []
#for j in range(23):
# sample_list.append("samples/img%0.2d.png"%j)
gr.Examples(sample_list, fn = update_button, inputs=input_image, outputs = [input_image, up_btn, outlines, flows], examples_per_page=50, label = "Click on an example to try it")
input_image.upload(update_button, input_image, [input_image, up_btn, outlines, flows])
up_btn.upload(update_image, up_btn, [input_image, up_btn, outlines, flows])
send_btn.click(cellpose_segment, [up_btn, resize, max_iter, flow_threshold, cellprob_threshold], [outlines, flows, down_btn, down_btn2])
#down_btn.click(download_function, None, [down_btn, down_btn2])
gr.HTML("""<h4 style="color:white;"> Notes:<br>
<li>you can load and process 2D, multi-channel tifs.
<li>the smallest dimension of a tif --> channels
<li>you can upload multiple files and download a zip of the segmentations
<li>install Cellpose-SAM locally for full functionality.
</h4>""")
demo.launch()