---
license: apache-2.0
---

<img src="pixcell_256_cell_controlnet_banner.png" alt="pixcell_256_cell_controlnet_banner" width="500"/>

# PixCell: A generative foundation model for digital histopathology images

[[📄 arXiv]](https://arxiv.org/abs/2506.05127) [[🔬 PixCell-1024]](https://huggingface.co/StonyBrook-CVLab/PixCell-1024) [[🔬 PixCell-256]](https://huggingface.co/StonyBrook-CVLab/PixCell-256) [[🔬 PixCell-256-Cell-ControlNet]](https://huggingface.co/StonyBrook-CVLab/PixCell-256-Cell-ControlNet) [[💾 Synthetic SBU-1M]](https://huggingface.co/datasets/StonyBrook-CVLab/Synthetic-SBU-1M)


### Load PixCell-256-Cell-ControlNet model

```python
import torch
from diffusers import AutoencoderKL, DiffusionPipeline

device = torch.device('cuda')

# We do not host the weights of the SD3 VAE; load it from StabilityAI instead.
# This repository is gated on the Hugging Face Hub: accept its license and log in
# (e.g. `huggingface-cli login`) before downloading.
sd3_vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-3.5-large", subfolder="vae")

# The PixCell ControlNet pipeline is loaded as custom (remote) code
pipeline = DiffusionPipeline.from_pretrained(
    "StonyBrook-CVLab/PixCell-256-Cell-ControlNet",
    vae=sd3_vae,
    custom_pipeline="StonyBrook-CVLab/PixCell-pipeline-ControlNet",
    trust_remote_code=True,
)
pipeline.to(device)
```

### Load [[UNI2-h]](https://huggingface.co/MahmoodLab/UNI2-h) for conditioning

```python
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

# UNI2-h architecture settings passed to timm when building the model
timm_kwargs = {
    'img_size': 224,
    'patch_size': 14,
    'depth': 24,
    'num_heads': 24,
    'init_values': 1e-5,
    'embed_dim': 1536,
    'mlp_ratio': 2.66667*2,
    'num_classes': 0,
    'no_embed_class': True,
    'mlp_layer': timm.layers.SwiGLUPacked,
    'act_layer': torch.nn.SiLU,
    'reg_tokens': 8,
    'dynamic_img_size': True,
}

# UNI2-h is a gated model: request access on its Hub page and log in before loading
uni_model = timm.create_model("hf-hub:MahmoodLab/UNI2-h", pretrained=True, **timm_kwargs)
uni_transforms = create_transform(**resolve_data_config(uni_model.pretrained_cfg, model=uni_model))
uni_model.eval()
uni_model.to(device)
```
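With `num_classes=0`, timm builds UNI2-h without a classification head, so the model should return one pooled feature vector of size `embed_dim` (1536) per image. The generation step passes these to the pipeline as `(bs, 1, D)`. The NumPy sketch below only spells out that shape contract; `to_pipeline_embedding` is a hypothetical helper, not part of the PixCell or UNI APIs, and in the PyTorch code the same reshape is simply `uni_emb.unsqueeze(1)`.

```python
import numpy as np

# Hypothetical helper for illustration; not part of the PixCell or UNI APIs.
UNI2H_EMBED_DIM = 1536  # matches 'embed_dim' in timm_kwargs


def to_pipeline_embedding(uni_emb: np.ndarray) -> np.ndarray:
    """Return UNI2-h features in the (bs, 1, D) layout passed as `uni_embeds`.

    Accepts a single feature vector of shape (D,) or a batch of shape (bs, D).
    """
    if uni_emb.ndim == 1:
        uni_emb = uni_emb[None, :]  # (D,) -> (1, D)
    if uni_emb.ndim != 2 or uni_emb.shape[1] != UNI2H_EMBED_DIM:
        raise ValueError(f"expected (bs, {UNI2H_EMBED_DIM}) features, got {uni_emb.shape}")
    return uni_emb[:, None, :]  # insert the token axis: (bs, D) -> (bs, 1, D)


# Two images' worth of pooled features -> one batched conditioning array
feats = np.random.default_rng(0).standard_normal((2, UNI2H_EMBED_DIM)).astype(np.float32)
print(to_pipeline_embedding(feats).shape)  # (2, 1, 1536)
```

Validating the feature width up front gives a clear error if a different encoder (or one with a classification head) is swapped in by mistake.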


### Mask-conditioned generation
```python
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download

# Download the example image/mask pair provided with this model
image_path = hf_hub_download(repo_id="StonyBrook-CVLab/PixCell-256-Cell-ControlNet", filename="test_image.png")
mask_path = hf_hub_download(repo_id="StonyBrook-CVLab/PixCell-256-Cell-ControlNet", filename="test_mask.png")
image = Image.open(image_path).convert("RGB")
mask = np.asarray(Image.open(mask_path).convert("RGB"))

# Extract the UNI embedding from the reference image
uni_inp = uni_transforms(image).unsqueeze(dim=0)
with torch.inference_mode():
    uni_emb = uni_model(uni_inp.to(device))

# Add a token dimension: (bs, D) -> (bs, 1, D)
uni_emb = uni_emb.unsqueeze(1)
print("Extracted UNI:", uni_emb.shape)

# Get the unconditional embedding used for classifier-free guidance
uncond = pipeline.get_unconditional_embedding(uni_emb.shape[0])

# Generate new samples that follow the given cell mask
samples = pipeline(
    uni_embeds=uni_emb,
    controlnet_input=mask,
    negative_uni_embeds=uncond,
    guidance_scale=2.5,
    num_images_per_prompt=1,
).images
```
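`guidance_scale` and `negative_uni_embeds` control classifier-free guidance. Assuming the custom pipeline follows the usual formulation (its internals are not shown here), each denoising step evaluates the model with both the UNI embedding and the unconditional embedding and extrapolates as `pred = pred_uncond + s * (pred_cond - pred_uncond)`. The NumPy sketch below shows that combination with a toy stand-in for the denoiser; `cfg_prediction` and `toy_model` are hypothetical names, and the latent shape is an illustrative assumption.

```python
import numpy as np


def cfg_prediction(model_fn, latents, cond_emb, uncond_emb, guidance_scale):
    """Classifier-free guidance for one denoising step (hypothetical helper).

    Both branches are evaluated in a single batched call, as many diffusion
    pipelines do, then combined as uncond + scale * (cond - uncond).
    """
    latent_in = np.concatenate([latents, latents], axis=0)   # (2*bs, C, H, W)
    emb_in = np.concatenate([uncond_emb, cond_emb], axis=0)  # (2*bs, 1, D)
    pred = model_fn(latent_in, emb_in)
    pred_uncond, pred_cond = np.split(pred, 2, axis=0)
    return pred_uncond + guidance_scale * (pred_cond - pred_uncond)


def toy_model(latents, emb):
    # Stand-in "denoiser": ignores the latents and returns the mean of each
    # sample's embedding, broadcast to the latent shape
    return np.zeros_like(latents) + emb.mean(axis=(1, 2))[:, None, None, None]


# Toy shapes (assumed): 16 latent channels at 32x32 for one 256x256 image
latents = np.zeros((1, 16, 32, 32), dtype=np.float32)
cond = np.ones((1, 1, 1536), dtype=np.float32)
uncond = np.zeros((1, 1, 1536), dtype=np.float32)

out = cfg_prediction(toy_model, latents, cond, uncond, guidance_scale=2.5)
print(out.shape, float(out[0, 0, 0, 0]))  # (1, 16, 32, 32) 2.5
```

With `guidance_scale=1.0` the result equals the conditional prediction; larger values such as the 2.5 used above push each step further away from the unconditional branch, typically trading sample diversity for closer adherence to the UNI conditioning.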