File size: 13,688 Bytes
49c6db7
 
8786ac3
49c6db7
 
 
f0821bf
 
44c9541
e16c5c4
e426c5a
3e66137
47accba
 
6f40774
980a5f7
 
 
f672193
1cffc39
47accba
266f336
47accba
 
465186b
 
3e66137
 
d874e72
3e66137
 
 
 
 
2fefa26
 
 
 
 
 
 
d51d5c4
2fefa26
 
 
 
 
 
3e66137
49c6db7
2949f09
 
49c6db7
 
 
 
3e66137
 
 
49c6db7
 
 
 
 
 
 
 
 
 
f0821bf
fd26ead
255910e
49c6db7
 
 
 
 
 
 
 
f0821bf
 
 
 
 
 
 
 
 
 
1c5339e
f0821bf
 
 
 
c034f55
 
 
 
 
 
 
 
 
 
 
 
 
736285e
c034f55
ea7f537
 
5e93cff
 
ea7f537
49c6db7
 
bf308b6
 
 
 
 
 
255910e
49c6db7
 
 
 
 
 
 
 
 
 
 
5e93cff
49c6db7
 
32b0ba4
49c6db7
 
 
 
 
 
 
 
 
 
 
 
 
eeb07d9
ded4361
7a96d4b
 
a887322
 
c08aa90
7a96d4b
 
c08aa90
 
2bae9a9
7a96d4b
 
c08aa90
 
7090db4
f9b1b13
7a96d4b
5ba71fa
aabd05b
 
4bc8e38
 
7090db4
 
 
4bc8e38
 
 
 
8ebcd09
f7e0b43
7a96d4b
8ebcd09
7a96d4b
8ebcd09
7a96d4b
4bc8e38
 
d9286e0
8bb1bfa
7090db4
4bc8e38
 
 
7c8c045
0079cc8
 
 
4bc8e38
7c8c045
5ba71fa
218cd58
4bc8e38
5ba71fa
a887322
49c6db7
 
 
18b4441
8a8ccfd
5ba71fa
958ea27
5ba71fa
18b4441
 
 
4c3c584
a08adb1
eecd9f2
7902217
 
 
5ba71fa
7902217
 
5ba71fa
e16c5c4
980a5f7
255910e
 
93bdd17
4bc8e38
 
 
 
e16c5c4
4c3c584
980a5f7
49c6db7
fd1e2f9
25641bf
 
 
 
f62e231
 
 
 
f453514
 
 
 
 
 
 
ecf2bd9
f453514
 
 
 
 
 
e3ba99d
f453514
f62e231
e3ba99d
 
 
 
64935e2
e3ba99d
64935e2
692738d
7b4fea1
 
e3ba99d
f62e231
a7916aa
f62e231
 
e3ba99d
 
9e13cfa
f62e231
 
ac8a141
e3ba99d
2d8800e
2c27168
 
 
c77528c
2c27168
c3abe48
95242a5
50f6c2b
95242a5
47ad03f
95242a5
03ae964
 
2e35a3d
1fe72e5
 
22bab81
d23bb1f
 
9a7b38a
7a96d4b
 
9a7b38a
 
f43ddd6
daeff40
e19ebf3
16b034a
f43ddd6
c0e39ef
 
03ae964
f43ddd6
980a5f7
f43ddd6
980a5f7
4d163cf
 
e426c5a
 
 
 
 
e3ba99d
44d9c6b
7beb8b9
 
e3ba99d
7a96d4b
a3dc09f
25641bf
2c27168
157bb22
3e49020
2949f09
3e49020
157bb22
 
2c27168
fe58ba1
2c27168
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
import numpy as np
import gradio as gr
import spaces
import cv2
from cellpose import models
from matplotlib.colors import hsv_to_rgb
import matplotlib.pyplot as plt
import os, io, base64
from PIL import Image 
from cellpose.io import imread, imsave
import glob 

from huggingface_hub import hf_hub_download

# Blank 96x128 uint8 frame used as the initial/reset value of the output image panes.
img = np.zeros(shape=(96, 128), dtype=np.uint8)
fp0 = Image.fromarray(img)

# data  retrieval
def download_weights():
    """Fetch the Cellpose-SAM ``cpsam`` checkpoint from the Hugging Face Hub.

    Returns:
        The local file path returned by ``hf_hub_download`` for
        ``mouseland/cellpose-sam`` / ``cpsam``.
    """
    weights_path = hf_hub_download(repo_id="mouseland/cellpose-sam", filename="cpsam")
    return weights_path

def download_weights_old():
    """Legacy downloader: fetch the ``cpsam`` weights from OSF into the current directory.

    Not called anywhere in this file; it appears superseded by ``download_weights``.
    Files that already exist are skipped without touching the network. Each missing
    file is requested up to 10 times on network errors; if every attempt fails or the
    response status is not OK, a message is printed and nothing is written.

    Returns:
        None.
    """
    import os

    fname = ['cpsam']
    url = ["https://osf.io/d7c8e/download"]

    for name, link in zip(fname, url):
        if os.path.isfile(name):
            continue
        # Imported lazily so an already-present weights file needs no network stack.
        import requests

        response = None
        for ntries in range(1, 11):
            try:
                response = requests.get(link)
                break  # success: stop retrying (the old loop never exited here)
            except requests.RequestException:
                print("!!! Failed to download data !!!")
                print(ntries)

        if response is None or response.status_code != requests.codes.ok:
            print("!!! Failed to download data !!!")
        else:
            with open(name, "wb") as fid:
                fid.write(response.content)

# Load the Cellpose-SAM weights once at import time; the GPU wrappers below use this global `model`.
# Any failure is fatal: report it and terminate with exit status 1.
try:
    fpath = download_weights()
    model = models.CellposeModel(gpu=True, pretrained_model = fpath)
except Exception as e:
    print(f"Error loading model: {e}")
    # SystemExit rather than the site-provided `exit()`, which is meant for interactive use only.
    raise SystemExit(1) from e



            
def plot_flows(y):
    """Render a flow field as an RGB uint8 image.

    Hue encodes the flow direction (arctan2 of the normalized Y/X components).
    Saturation and value both encode the normalized squared magnitude.

    Args:
        y: indexable so that ``y[0][0]`` and ``y[1][0]`` are the Y and X flow
            components. Presumably 2-D arrays of equal shape; TODO confirm
            (no caller in this file).
    """
    dy, dx = y[0][0], y[1][0]
    # Map each component to [-1, 1] after robust percentile normalization.
    unit_y = 2 * (np.clip(normalize99(dy), 0, 1) - 0.5)
    unit_x = 2 * (np.clip(normalize99(dx), 0, 1) - 0.5)
    hue = (np.arctan2(unit_y, unit_x) + np.pi) / (2 * np.pi)
    magnitude = normalize99(dy**2 + dx**2)
    hsv = np.concatenate(
        (hue[:, :, np.newaxis], magnitude[:, :, np.newaxis], magnitude[:, :, np.newaxis]),
        axis=-1,
    )
    rgb = hsv_to_rgb(np.clip(hsv, 0.0, 1.0))
    return (rgb * 255).astype(np.uint8)

def plot_outlines(img, masks):
    """Draw red outlines of every labelled region over ``img`` and return the rendering.

    Args:
        img: image array; normalized with normalize99 and clipped to [0, 1] for display.
        masks: label image (non-zero labels mark regions). Its height/width are assumed
            to match ``img``; TODO confirm for all callers.

    Returns:
        PIL.Image.Image decoded from a PNG rendering of the figure. The figure has a
        black facecolor and its long side is 6 inches.
    """
    img = np.clip(normalize99(img), 0, 1)
    Ly, Lx = img.shape[:2]

    # Labels are passed as int32, the label-image type OpenCV accepts for RETR_FLOODFILL.
    contours, _ = cv2.findContours(masks.astype(np.int32), mode=cv2.RETR_FLOODFILL,
                                   method=cv2.CHAIN_APPROX_SIMPLE)
    # Keep contours with more than 4 points and simplify them with a 0.001 px tolerance.
    outpix = [cv2.approxPolyDP(c, 0.001, True)[:, 0, :] for c in contours if len(c) > 4]

    # Long side 6 inches, aspect ratio matching the image.
    if Ly > Lx:
        figsize = (6 * Lx / Ly, 6)
    else:
        figsize = (6, 6 * Ly / Lx)

    fig = plt.figure(figsize=figsize, facecolor='k')
    try:
        ax = fig.add_axes([0.0, 0.0, 1, 1])
        ax.set_xlim([0, Lx])
        ax.set_ylim([0, Ly])
        # The row-flipped image and the `Ly - y` contour coordinates use the same
        # upward y-axis, so outlines stay aligned with their cells.
        ax.imshow(img[::-1], origin='upper', aspect='auto')
        for o in outpix:
            ax.plot(o[:, 0], Ly - o[:, 1], color=[1, 0, 0], lw=1)
        ax.axis('off')

        buf = io.BytesIO()
        fig.savefig(buf, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing/saving fails.
        plt.close(fig)

    buf.seek(0)
    pil_img = Image.open(buf)
    pil_img.load()  # decode now instead of lazily on first use
    return pil_img

def plot_overlay(img, masks):
    """Colour each mask label with a random hue over a grayscale version of ``img``.

    Brightness (HSV value) is the normalized intensity times 1.5, clipped to 1.
    Every label from 1 to ``int(masks.max())`` gets a hue drawn from ``np.random.rand()``
    at full saturation; background pixels stay gray.

    Returns:
        RGB uint8 array of shape (rows, cols, 3).
    """
    gray = img.astype(np.float32)
    if img.ndim > 2:
        gray = gray.mean(axis=-1)  # collapse channels by averaging

    img = normalize99(gray)
    hsv = np.zeros((img.shape[0], img.shape[1], 3), np.float32)
    hsv[:, :, 2] = np.clip(img * 1.5, 0, 1.0)

    for label in range(1, int(masks.max()) + 1):
        rows, cols = (masks == label).nonzero()
        hsv[rows, cols, 0] = np.random.rand()
        hsv[rows, cols, 1] = 1.0

    return (hsv_to_rgb(hsv) * 255).astype(np.uint8)

def normalize99(img, lower=1, upper=99):
    """Linearly rescale ``img`` so its ``lower`` percentile maps to 0 and its ``upper`` to 1.

    Values outside that percentile range are not clipped, so the result can fall below 0
    or above 1; callers clip where they need to. A 1e-10 term in the denominator avoids
    division by zero, so a constant input maps to all zeros.

    Args:
        img: numeric array (or array-like); it is not modified.
        lower: percentile (0-100) mapped to 0. The default of 1 keeps the historical behaviour.
        upper: percentile (0-100) mapped to 1. The default of 99 keeps the historical behaviour.

    Returns:
        Floating-point ndarray with the same shape as ``img``.
    """
    X = np.asarray(img)
    # One call computes both percentiles (the old code computed the lower one twice).
    lo, hi = np.percentile(X, [lower, upper])
    return (X - lo) / (1e-10 + hi - lo)

def image_resize(img, resize=400):
    """Downscale ``img`` so no dimension exceeds ``resize`` pixels, and return it as uint8.

    If any dimension of ``img.shape`` exceeds ``resize``, the longer spatial side becomes
    ``resize``. The other side is scaled proportionally, truncated to an int, with a
    minimum of 1. Otherwise the size is kept.

    Conversion to uint8:
        - Data whose values already lie in [0, 255] is cast directly, as before.
        - Anything else (e.g. 16-bit microscopy data or negative values) is linearly
          stretched from [min, max] onto [0, 255]. A raw cast would wrap such values
          modulo 256.
        - A constant out-of-range image becomes all zeros.
    """
    resize = int(resize)  # cv2 sizes must be ints; gr.Number may hand over a float
    ny, nx = img.shape[:2]
    if max(img.shape) > resize:
        if ny > nx:
            nx = max(1, int(nx / ny * resize))
            ny = resize
        else:
            ny = max(1, int(ny / nx * resize))
            nx = resize
        img = cv2.resize(img, (nx, ny))  # cv2 takes (width, height)

    if img.dtype != np.uint8 and img.size:
        lo, hi = float(img.min()), float(img.max())
        if lo < 0 or hi > 255:
            span = hi - lo
            if span > 0:
                img = (img.astype(np.float64) - lo) / span * 255.0
            else:
                img = np.zeros(img.shape)
    return img.astype(np.uint8)

    
def _eval_cellpose(img, max_iter, flow_threshold, cellprob_threshold):
    """Run the global Cellpose-SAM ``model`` on one image; return (masks, flows)."""
    masks, flows, _ = model.eval(img, niter=max_iter, flow_threshold=flow_threshold,
                                 cellprob_threshold=cellprob_threshold)
    return masks, flows


@spaces.GPU(duration=10)
def run_model_gpu(img, max_iter, flow_threshold, cellprob_threshold):
    """Segment ``img`` under a GPU allocation with duration=10."""
    return _eval_cellpose(img, max_iter, flow_threshold, cellprob_threshold)


@spaces.GPU(duration=60)
def run_model_gpu60(img, max_iter, flow_threshold, cellprob_threshold):
    """Segment ``img`` under a GPU allocation with duration=60."""
    return _eval_cellpose(img, max_iter, flow_threshold, cellprob_threshold)


@spaces.GPU(duration=240)
def run_model_gpu240(img, max_iter, flow_threshold, cellprob_threshold):
    """Segment ``img`` under a GPU allocation with duration=240."""
    return _eval_cellpose(img, max_iter, flow_threshold, cellprob_threshold)

import datetime
from zipfile import ZipFile
def _run_model_for_size(img, max_iter, flow_threshold, cellprob_threshold):
    """Dispatch to the GPU wrapper whose time budget matches the image's largest dimension."""
    maxsize = np.max(img.shape)
    if maxsize <= 1000:
        runner = run_model_gpu
    elif maxsize < 5000:
        runner = run_model_gpu60
    elif maxsize < 20000:
        runner = run_model_gpu240
    else:
        raise ValueError("Image size must be less than 20,000")
    return runner(img, max_iter, flow_threshold, cellprob_threshold)


def _segment_and_save(path, index, resize, max_iter, flow_threshold, cellprob_threshold):
    """Segment one file and save its masks as ``<name>_masks.tif``.

    The saved masks are at the input's original size.
    Returns (resized img, masks, flows, mask path).
    """
    formatted_now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    img_input = imread(path)
    img = image_resize(img_input, resize=resize)

    masks, flows = _run_model_for_size(img, max_iter, flow_threshold, cellprob_threshold)
    print(formatted_now, index, masks.max(), os.path.split(path)[-1])

    # Scale masks back to the original size; nearest-neighbour keeps label values intact.
    target_size = (img_input.shape[1], img_input.shape[0])  # cv2 wants (width, height)
    if target_size != (img.shape[1], img.shape[0]):
        masks_rsz = cv2.resize(masks.astype('uint16'), target_size,
                               interpolation=cv2.INTER_NEAREST).astype('uint16')
    else:
        masks_rsz = masks.copy()

    fname_masks = os.path.splitext(path)[0] + "_masks.tif"
    imsave(fname_masks, masks_rsz)
    return img, masks, flows, fname_masks


def cellpose_segment(filepath, resize=1000, max_iter=250, flow_threshold=0.4, cellprob_threshold=0):
    """Segment every uploaded image with Cellpose-SAM and build the UI outputs.

    For each path:
        - the masks are saved as ``<name>_masks.tif`` next to the input;
        - they are also collected into ``<last name>_masks.zip``.
    The outline and flow previews are rendered for the LAST image only.

    Args:
        filepath: list of image paths (the multi-file upload button's value).
        resize: max size passed to image_resize before inference.
        max_iter: passed to ``model.eval`` as ``niter``.
        flow_threshold: passed through to ``model.eval``.
        cellprob_threshold: passed through to ``model.eval``.

    Returns:
        (outlines PIL image, flows PIL image, masks DownloadButton, outlines DownloadButton).
        The masks button points to the zip when several files were given, otherwise to
        the single TIFF.

    Raises:
        gr.Error: if nothing has been uploaded.
        ValueError: if a resized image's largest dimension is 20,000 or more.
    """
    if not filepath:
        # Previously this crashed with a TypeError on filepath[-1].
        raise gr.Error("Please upload at least one image before running Cellpose-SAM.")

    zip_path = os.path.splitext(filepath[-1])[0] + "_masks.zip"
    with ZipFile(zip_path, 'w') as myzip:
        for j, path in enumerate(filepath):
            img, masks, flows, fname_masks = _segment_and_save(
                path, j, resize, max_iter, flow_threshold, cellprob_threshold)
            myzip.write(fname_masks, arcname=os.path.split(fname_masks)[-1])

    # Previews for the last processed image, shown at the resized image's size.
    Ly, Lx = img.shape[:2]
    outpix = plot_outlines(img, masks).resize((Lx, Ly), resample=Image.BICUBIC)
    flow_img = Image.fromarray(flows[0]).resize((Lx, Ly), resample=Image.BICUBIC)

    fname_out = os.path.splitext(filepath[-1])[0] + "_outlines.png"
    outpix.save(fname_out)

    masks_download = zip_path if len(filepath) > 1 else fname_masks
    b1 = gr.DownloadButton(visible=True, value=masks_download)
    b2 = gr.DownloadButton(visible=True, value=fname_out)
    return outpix, flow_img, b1, b2

def download_function():
    """Return the masks and outlines download buttons, both hidden (visible=False)."""
    hidden = dict(visible=False)
    masks_button = gr.DownloadButton("Download masks as TIFF", **hidden)
    outlines_button = gr.DownloadButton("Download outline image as PNG", **hidden)
    return masks_button, outlines_button

def tif_view(filepath):
    """Rewrite a TIFF in place as a (Ly, Lx, 3) image; leave other files untouched.

    For ``.tif``/``.tiff`` paths (case-insensitive):
        - a 2-D image is tiled into 3 identical channels;
        - for a 3-D image, the smallest axis is treated as the channel axis and moved last;
        - at most the first 3 channels are kept, and missing ones are zero-filled.
    The result keeps the input dtype and overwrites ``filepath``.

    Previously this branch never ran: the extension was compared without its dot. It also
    read ``filepath[-1]`` (the last character) and contained typos (``np.newxis``,
    ``np.tranpose``).

    Returns:
        ``filepath`` (unchanged).

    Raises:
        ValueError: if the TIFF is not 2-D or 3-D.
    """
    fext = os.path.splitext(filepath)[1].lower()
    if fext not in ('.tif', '.tiff'):
        return filepath

    img = imread(filepath)
    if img.ndim == 2:
        img = np.tile(img[:, :, np.newaxis], [1, 1, 3])
    elif img.ndim == 3:
        # Smallest dimension -> channels, moved to the last axis.
        img = np.moveaxis(img, int(np.argmin(img.shape)), -1)
    else:
        raise ValueError("TIF cannot have more than three dimensions")

    Ly, Lx, nchan = img.shape
    # Keep the source dtype so e.g. 16-bit data is not silently upcast to float64.
    imgi = np.zeros((Ly, Lx, 3), dtype=img.dtype)
    nn = min(3, nchan)
    imgi[:, :, :nn] = img[:, :, :nn]

    imsave(filepath, imgi)
    return filepath

def norm_path(filepath):
    """Write a contrast-normalized 8-bit PNG preview of ``filepath`` and return its path.

    The image is passed through normalize99, clipped to [0, 1] and scaled to 0-255.

    The preview is written to a ``cellpose_display`` folder under the system temp
    directory, so the source file is never overwritten. The old code wrote
    ``<name>.png`` next to the input, which replaced PNG inputs (bundled samples, and
    uploads that are later segmented) with their 8-bit preview.

    The preview name includes a short digest of the source's absolute path, so
    different inputs that share a basename do not collide.
    """
    import hashlib
    import tempfile

    img = np.clip(normalize99(imread(filepath)), 0, 1)

    stem = os.path.splitext(os.path.basename(filepath))[0]
    out_dir = os.path.join(tempfile.gettempdir(), "cellpose_display")
    os.makedirs(out_dir, exist_ok=True)
    tag = hashlib.sha1(os.path.abspath(filepath).encode("utf-8")).hexdigest()[:12]
    out_path = os.path.join(out_dir, f"{tag}_{stem}.png")

    Image.fromarray((255. * img).astype(np.uint8)).save(out_path)
    return out_path
    
def update_image(filepath):
    """Handle a multi-file upload: prepare every file and preview the last one.

    Returns:
        (preview path of the last file, the uploaded path list, fp0, fp0). The two fp0
        values reset the outline and flow panes.
    """
    for path in filepath:
        tif_view(path)  # converts TIFFs in place; the path it returns is the input path
    return norm_path(filepath[-1]), filepath, fp0, fp0

def update_button(filepath):
    """Handle a single image (upload or example): prepare it and build a preview.

    Returns:
        (preview path, one-element list holding the file path, fp0, fp0). The list is
        what the multi-file upload button holds; the two fp0 values reset the output panes.
    """
    checked_path = tif_view(filepath)
    preview_path = norm_path(checked_path)
    return preview_path, [checked_path], fp0, fp0
    
# Build the Gradio UI.
# Left column: title/links, the input preview, model parameters, upload/run/download buttons.
# Right column: the outline and flow result panes.
with gr.Blocks(title = "Hello", 
               css=".gradio-container {background:purple;}") as demo:

    #filepath = ""
    with gr.Row():
        with gr.Column(scale=2):
            gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:20pt; font-weight:bold; text-align:center; color:white;">Cellpose-SAM for cellular 
            segmentation <a style="color:#cfe7fe; font-size:14pt;" href="https://www.biorxiv.org/content/10.1101/2025.04.28.651001v1" target="_blank">[paper]</a> 
            <a style="color:white; font-size:14pt;" href="https://github.com/MouseLand/cellpose" target="_blank">[github]</a>
            <a style="color:white; font-size:14pt;" href="https://www.youtube.com/watch?v=KIdYXgQemcI" target="_blank">[talk]</a>                        
            </div>""")
            gr.HTML("""<h4 style="color:white;">You may need to login/refresh for 5 minutes of free GPU compute per day (enough to process hundreds of images). </h4>""")
            
            # Preview of the current image; update_button/update_image set it to a normalized PNG path.
            input_image = gr.Image(label = "Input", type = "filepath")

            with gr.Row():
                with gr.Column(scale=1):                    
                    with gr.Row():
                        # Passed straight to cellpose_segment (resize, max_iter, flow/cellprob thresholds).
                        resize = gr.Number(label = 'max resize', value = 1000)
                        max_iter = gr.Number(label = 'max iterations', value = 250)
                        flow_threshold = gr.Number(label = 'flow threshold', value = 0.4)
                        cellprob_threshold = gr.Number(label = 'cellprob threshold', value = 0)
                        
                    # Holds the list of files that "Run Cellpose-SAM" actually segments.
                    up_btn = gr.UploadButton("Multi-file upload (png, jpg, tif etc)", visible=True, file_count = "multiple")                        
                    
                    #gr.HTML("""<h4 style="color:white;"> Note2: Only the first image of a tif will display the segmentations, but you can download segmentations for all planes. </h4>""")
                    
                with gr.Column(scale=1):
                    send_btn = gr.Button("Run Cellpose-SAM")
                    # Hidden until cellpose_segment returns buttons with visible=True and a file value.
                    down_btn = gr.DownloadButton("Download masks (TIF)", visible=False)            
                    down_btn2 = gr.DownloadButton("Download outlines (PNG)", visible=False)  
                    
        with gr.Column(scale=2):     
            # Both panes start (and are reset) to the blank placeholder fp0.
            outlines = gr.Image(label = "Outlines", type = "pil", format = 'png', value = fp0) #, width = "50vw", height = "20vw")
            #img_overlay = gr.Image(label = "Overlay", type = "pil", format = 'png') #, width = "50vw", height = "20vw")
            flows = gr.Image(label = "Cellpose flows", type = "pil", format = 'png', value = fp0) #, width = "50vw", height = "20vw")

            
    
    # Example images are discovered once, when the UI is built.
    sample_list = glob.glob("samples/*.png")
    #sample_list = []
    #for j in range(23):
    #    sample_list.append("samples/img%0.2d.png"%j)
        
    # NOTE(review): whether clicking an example also runs update_button (and so fills up_btn)
    # depends on gr.Examples' run_on_click / cache_examples defaults in the installed Gradio
    # version; confirm.
    gr.Examples(sample_list, fn = update_button, inputs=input_image, outputs = [input_image, up_btn, outlines, flows], examples_per_page=50, label = "Click on an example to try it")
    # Uploads through either control refresh the preview and store the file list in up_btn.
    input_image.upload(update_button, input_image, [input_image, up_btn, outlines, flows])
    up_btn.upload(update_image, up_btn, [input_image, up_btn, outlines, flows])
    
    # Segmentation reads the files held by up_btn, not input_image.
    send_btn.click(cellpose_segment, [up_btn, resize, max_iter, flow_threshold, cellprob_threshold], [outlines, flows, down_btn, down_btn2])

    #down_btn.click(download_function, None, [down_btn, down_btn2])
        
    gr.HTML("""<h4 style="color:white;"> Notes:<br> 
                    <li>you can load and process 2D, multi-channel tifs.
                    <li>the smallest dimension of a tif --> channels
                    <li>you can upload multiple files and download a zip of the segmentations
                    <li>install Cellpose-SAM locally for full functionality.
                    </h4>""")
    
                    
# Start the server unconditionally when this module runs (there is no __main__ guard).
demo.launch()