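"""PromptPix Gradio demo: generate an image from a text prompt with Stable
Diffusion v1-4 and segment it with a query-based decoder that runs on the
cached diffusion features and cross-attention maps."""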
import torch
import ptp_utils
import cv2
import os
import heapq
import numpy as np
from diffusers import AutoencoderKL, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from model.unet import UNet2D, get_feature_dic, clear_feature_dic
from model.segment.transformer_decoder import seg_decorder
from model.segment.transformer_decoder_semantic import seg_decorder_open_word
from store import AttentionStore
import torch.nn.functional as F
from config import opt, classes
from scipy.special import softmax
from visualize import visualize_segmentation
from detectron2.structures import Boxes, ImageList, Instances, BitMasks
import gradio as gr

# Model loading
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
vae.to(device)
vae.eval()

unet = UNet2D.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
unet.to(device)
unet.eval()

text_encoder = CLIPTextModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="text_encoder")
text_encoder.to(device)
text_encoder.eval()

scheduler = DDPMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
# seg_model = seg_decorder().to(device)
num_classes = 80
num_queries = 100
seg_model = seg_decorder_open_word(num_classes=num_classes,
                                   num_queries=num_queries).to(device)

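# Load the pretrained grounding/segmentation decoder weights.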
base_weights = torch.load(opt.get('grounding_ckpt'), map_location="cpu")
seg_model.load_state_dict(base_weights, strict=True)

def freeze_params(params):
    for param in params:
        param.requires_grad = False

freeze_params(vae.parameters())
freeze_params(unet.parameters())
freeze_params(text_encoder.parameters())
freeze_params(seg_model.parameters())

def semantic_inference(mask_cls, mask_pred):
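    """Fuse per-query class scores and masks into a per-class semantic map.

    Drops the trailing "no object" logit, weights each query's sigmoid mask by
    its class probabilities, and zeroes out classes whose confident area is tiny.
    """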
    mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
    mask_pred = mask_pred.sigmoid()
    semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
    for i in range(1, semseg.shape[0]):
        if (semseg[i] * (semseg[i] > 0.5)).sum() < 5000:
            semseg[i] = 0
    return semseg

def instance_inference(mask_cls, mask_pred, class_n=2, test_topk_per_image=20, query_n=100):
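    """Turn per-query logits and masks into detectron2 Instances.

    Keeps the top-scoring (query, class) pairs, binarizes their masks, and
    rescores each instance by its mean in-mask probability.
    """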
    # mask_pred is already processed to have the same shape as original input
    image_size = mask_pred.shape[-2:]

    # [Q, K]
    scores = F.softmax(mask_cls, dim=-1)[:, :-1]
    labels = torch.arange(class_n, device=mask_cls.device).unsqueeze(0).repeat(query_n, 1).flatten(0, 1)
    # Top-k (query, class) pairs over the flattened score matrix.
    scores_per_image, topk_indices = scores.flatten(0, 1).topk(test_topk_per_image, sorted=False)
    labels_per_image = labels[topk_indices]

    # Map the flattened indices back to query indices and gather the matching masks.
    topk_indices = topk_indices // class_n
    mask_pred = mask_pred[topk_indices]

    result = Instances(image_size)
    # mask (before sigmoid)
    result.pred_masks = (mask_pred > 0).float()
    result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
    # Uncomment the following to get boxes from masks (this is slow)
    # result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes()

    # calculate average mask prob
    mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
    result.scores = scores_per_image * mask_scores_per_image
    result.pred_classes = labels_per_image
    return result

def inference(prompt):
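    """Generate an image for `prompt` and build a per-class segmentation map for it."""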
    prompts = [prompt]
    LOW_RESOURCE = False
    NUM_DIFFUSION_STEPS = 20

    outpath = opt.get('outdir')
    image_path = os.path.join(outpath, "Image")
    mask_path = os.path.join(outpath, "Mask")
    vis_path = os.path.join(outpath, "Vis")
    os.makedirs(image_path, exist_ok=True)
    os.makedirs(mask_path, exist_ok=True)
    os.makedirs(vis_path, exist_ok=True)

    controller = AttentionStore()
    ptp_utils.register_attention_control(unet, controller)
    # Per-class score volume: 80 classes plus a background channel at index 0.
    full_arr = np.zeros((81, 512, 512), np.float32)
    full_arr[0] = 0.5
    seed = 100
    with torch.no_grad():
        clear_feature_dic()
        controller.reset()
        g_cpu = torch.Generator().manual_seed(seed)
        image_out, x_t = ptp_utils.text2image(unet, vae, tokenizer, text_encoder, scheduler, prompts, controller, num_inference_steps=NUM_DIFFUSION_STEPS, guidance_scale=5, generator=g_cpu, low_resource=LOW_RESOURCE, Train=False)
        print("image_out ", image_out.shape)
        image_file = f"{image_path}/image.jpg"
        ptp_utils.save_images(image_out, out_put=image_file)
        # Query the decoder once per class name (index 0 is background).
        for idxx in classes:
            if idxx == 0:
                continue
            class_name = classes[idxx]
            query_text = class_name
            text_input = tokenizer(
                query_text,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_embeddings = text_encoder(text_input.input_ids.to(unet.device))[0]
            # Average the token embeddings into a single class embedding.
            class_embedding = text_embeddings
            if class_embedding.size()[1] > 1:
                class_embedding = torch.unsqueeze(class_embedding.mean(1), 1)
            diffusion_features = get_feature_dic()
            outputs = seg_model(diffusion_features, controller, prompts, tokenizer, class_embedding)
            # outputs = seg_model(diffusion_features, controller, prompts, tokenizer)
            mask_cls_results = outputs["pred_logits"]
            mask_pred_results = outputs["pred_masks"]
            mask_pred_results = F.interpolate(mask_pred_results, size=(512, 512), mode="bilinear", align_corners=False)

            for mask_cls_result, mask_pred_result in zip(mask_cls_results, mask_pred_results):
                instance_r = instance_inference(mask_cls_result, mask_pred_result, class_n=80, test_topk_per_image=3, query_n=100)
                pred_masks = instance_r.pred_masks.cpu().numpy().astype(np.uint8)
                pred_boxes = instance_r.pred_boxes
                scores = instance_r.scores
                pred_classes = instance_r.pred_classes
                # Keep only the highest-scoring instance for this class query.
                topk_idx = heapq.nlargest(1, range(len(scores)), scores.__getitem__)
                mask_instance = (pred_masks[topk_idx[0]] > 0.5).astype(np.uint8)
                full_arr[idxx] = np.array(mask_instance)
        # Softmax across the class axis, then take the per-pixel argmax as the label map.
        full_arr = softmax(full_arr, axis=0)
        mask = np.argmax(full_arr, axis=0).astype(np.uint8)
        print("mask ", mask.shape)
        mask_file = f"{mask_path}/mask.png"
        cv2.imwrite(mask_file, mask)
        vis_file = f"{vis_path}/visual.png"
        vis_map = visualize_segmentation(image_file, mask_file, vis_file)
        print("vis_map ", vis_map.shape)
        return image_out[0], vis_map

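# Gradio UI: a text prompt in, the generated image and its segmentation overlay out.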
iface = gr.Interface(
    fn=inference,
    inputs=gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
    outputs=[
        gr.Image(label="Generated Image"),
        gr.Image(label="Segmentation Map")
    ],
    title="PromptPix: Image Generation & Segmentation",
    description="Enter a prompt to generate an image and its segmentation mask"
)

if __name__ == "__main__":
    iface.launch(share=True)