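# Gradio demo: generate an image with Stable Diffusion v1-4 and produce an
# open-vocabulary semantic segmentation of it from the UNet's diffusion
# features and cross-attention maps.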
import torch
import ptp_utils
import cv2
import os
import numpy as np
from diffusers import AutoencoderKL, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from model.unet import UNet2D, get_feature_dic, clear_feature_dic
from model.segment.transformer_decoder import seg_decorder
from model.segment.transformer_decoder_semantic import seg_decorder_open_word
from store import AttentionStore
import torch.nn.functional as F
from config import opt
from visualize import visualize_segmentation
import gradio as gr

# Model loading: Stable Diffusion v1-4 components and the pretrained segmentation decoder
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
vae.to(device)
vae.eval()

unet = UNet2D.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
unet.to(device)
unet.eval()

text_encoder = CLIPTextModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="text_encoder")
text_encoder.to(device)
text_encoder.eval()

scheduler = DDPMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
# seg_model = seg_decorder().to(device)
num_classes = 80
num_queries = 100
seg_model = seg_decorder_open_word(num_classes=num_classes,
                                   num_queries=num_queries).to(device)

# Load the pretrained segmentation-decoder weights given in the config.
base_weights = torch.load(opt.get('grounding_ckpt'), map_location="cpu")
seg_model.load_state_dict(base_weights, strict=True)

def freeze_params(params):
    for param in params:
        param.requires_grad = False

freeze_params(vae.parameters())
freeze_params(unet.parameters())
freeze_params(text_encoder.parameters())
freeze_params(seg_model.parameters())

def semantic_inference(mask_cls, mask_pred):
    # Per-query class probabilities, dropping the trailing "no-object" class.
    mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
    mask_pred = mask_pred.sigmoid()
    # Weight each query's mask by its class probabilities to get a per-class map.
    semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
    # Suppress non-background classes whose confident region (> 0.5) covers fewer than 5000 pixels.
    for i in range(1, semseg.shape[0]):
        if (semseg[i] * (semseg[i] > 0.5)).sum() < 5000:
            semseg[i] = 0
    return semseg

def inference(prompt):
    # Gradio callback: generate an image for the prompt, segment it, and return
    # both the generated image and the segmentation visualization.
    prompts = [prompt]
    LOW_RESOURCE = False
    NUM_DIFFUSION_STEPS = 20

    outpath = opt.get('outdir')
    image_path = os.path.join(outpath, "Image")
    mask_path = os.path.join(outpath, "Mask")
    vis_path = os.path.join(outpath, "Vis")
    os.makedirs(image_path, exist_ok=True)
    os.makedirs(mask_path, exist_ok=True)
    os.makedirs(vis_path, exist_ok=True)

    # Capture cross-attention maps from the UNet during sampling.
    controller = AttentionStore()
    ptp_utils.register_attention_control(unet, controller)

    # Hardcoded query class; its text embedding conditions the open-vocabulary decoder.
    query_text = 'road'
    text_input = tokenizer(
        query_text,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = text_encoder(text_input.input_ids.to(unet.device))[0]
    # Mean-pool the token embeddings into a single class embedding.
    class_embedding = text_embeddings
    if class_embedding.size()[1] > 1:
        class_embedding = torch.unsqueeze(class_embedding.mean(1), 1)
    seed = 1
    with torch.no_grad():
        # Clear cached UNet features and stored attention maps from any previous run.
        clear_feature_dic()
        controller.reset()
        g_cpu = torch.Generator().manual_seed(seed)
        # Sample the image; the controller records cross-attention maps along the way.
        image_out, x_t = ptp_utils.text2image(
            unet, vae, tokenizer, text_encoder, scheduler, prompts, controller,
            num_inference_steps=NUM_DIFFUSION_STEPS, guidance_scale=5,
            generator=g_cpu, low_resource=LOW_RESOURCE, Train=False)
        print("image_out ", image_out.shape)
        image_file = f"{image_path}/image.jpg"
        ptp_utils.save_images(image_out, out_put=image_file)

        # Decode segmentation from the cached multi-scale diffusion features
        # and the stored cross-attention maps.
        diffusion_features = get_feature_dic()
        outputs = seg_model(diffusion_features, controller, prompts, tokenizer, class_embedding)
        # outputs = seg_model(diffusion_features, controller, prompts, tokenizer)
        mask_cls_results = outputs["pred_logits"]
        mask_pred_results = outputs["pred_masks"]
        # Upsample predicted masks to the output image resolution.
        mask_pred_results = F.interpolate(mask_pred_results, size=(512, 512), mode="bilinear", align_corners=False)

        for mask_cls_result, mask_pred_result in zip(mask_cls_results, mask_pred_results):
            # Combine class and mask predictions, then take the per-pixel argmax.
            label_pred_prob = semantic_inference(mask_cls_result, mask_pred_result)
            label_pred_prob = torch.argmax(label_pred_prob, dim=0)
            label_pred_prob = label_pred_prob.cpu().numpy().astype(np.uint8)
            print("mask ", label_pred_prob.shape)
            mask_file = f"{mask_path}/mask.png"
            cv2.imwrite(mask_file, label_pred_prob)
            vis_file = f"{vis_path}/visual.png"
            vis_map = visualize_segmentation(image_file, mask_file, vis_file)
        print("vis_map ", vis_map.shape)
        return image_out[0], vis_map

iface = gr.Interface(
    fn=inference,
    inputs=gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
    outputs=[
        gr.Image(label="Generated Image"),
        gr.Image(label="Segmentation Map")
    ],
    title="PromptPix: Image Generation & Segmentation",
    description="Enter a prompt to generate an image and its Segmask"
)

if __name__ == "__main__":
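    # Launch the Gradio app locally; passing share=True would also create a public link.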
    iface.launch()