import torch
import ptp_utils
import cv2
import os
import numpy as np
from diffusers import AutoencoderKL, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from model.unet import UNet2D, get_feature_dic, clear_feature_dic
from model.segment.transformer_decoder import seg_decorder
from model.segment.transformer_decoder_semantic import seg_decorder_open_word
from store import AttentionStore
import torch.nn.functional as F
from config import opt, classes
from scipy.special import softmax
from visualize import visualize_segmentation
from detectron2.structures import Boxes, ImageList, Instances, BitMasks
import gradio as gr
# Model loading
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
vae.to(device)
vae.eval()
unet = UNet2D.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
unet.to(device)
unet.eval()
text_encoder = CLIPTextModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="text_encoder")
text_encoder.to(device)
text_encoder.eval()
scheduler = DDPMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
# seg_model = seg_decorder().to(device)
num_classes = 80
num_queries = 100
seg_model = seg_decorder_open_word(num_classes=num_classes,
                                   num_queries=num_queries).to(device)
base_weights = torch.load(opt.get('grounding_ckpt'), map_location="cpu")
seg_model.load_state_dict(base_weights, strict=True)
def freeze_params(params):
    for param in params:
        param.requires_grad = False
# Freeze all pretrained components; the pipeline is used for inference only.
freeze_params(vae.parameters())
freeze_params(unet.parameters())
freeze_params(text_encoder.parameters())
freeze_params(seg_model.parameters())
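# Post-processing of the decoder outputs: semantic_inference fuses per-query class
# probabilities and masks into a per-class semantic map, while instance_inference keeps
# the top-k (query, class) pairs as individual instance masks with combined scores.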
def semantic_inference(mask_cls, mask_pred):
    mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
    mask_pred = mask_pred.sigmoid()
    semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
    # Suppress classes whose confident (>0.5) probability mass sums to less than 5000.
    for i in range(1, semseg.shape[0]):
        if (semseg[i] * (semseg[i] > 0.5)).sum() < 5000:
            semseg[i] = 0
    return semseg
def instance_inference(mask_cls, mask_pred, class_n=2, test_topk_per_image=20, query_n=100):
    # mask_pred is already processed to have the same shape as the original input
    image_size = mask_pred.shape[-2:]
    # [Q, K] class scores per query (the trailing "no object" logit is dropped)
    scores = F.softmax(mask_cls, dim=-1)[:, :-1]
    labels = torch.arange(class_n, device=mask_cls.device).unsqueeze(0).repeat(query_n, 1).flatten(0, 1)
    # scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
    scores_per_image, topk_indices = scores.flatten(0, 1).topk(test_topk_per_image, sorted=False)
    labels_per_image = labels[topk_indices]
    # recover the query index from the flattened (query, class) index
    topk_indices = topk_indices // class_n
    # mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
    mask_pred = mask_pred[topk_indices]
    result = Instances(image_size)
    # mask (before sigmoid)
    result.pred_masks = (mask_pred > 0).float()
    result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
    # Uncomment the following to get boxes from masks (this is slow)
    # result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes()
    # calculate average mask prob
    mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
    result.scores = scores_per_image * mask_scores_per_image
    result.pred_classes = labels_per_image
    return result
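# End-to-end pipeline for the Gradio demo: generate an image from the prompt with
# Stable Diffusion, reuse the cached U-Net features and attention maps to query the
# grounding decoder once per class, and fuse the per-class masks into one label map.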
def inference(prompt):
    prompts = [prompt]
    LOW_RESOURCE = False
    NUM_DIFFUSION_STEPS = 20
    outpath = opt.get('outdir')
    image_path = os.path.join(outpath, "Image")
    mask_path = os.path.join(outpath, "Mask")
    vis_path = os.path.join(outpath, "Vis")
    os.makedirs(image_path, exist_ok=True)
    os.makedirs(mask_path, exist_ok=True)
    os.makedirs(vis_path, exist_ok=True)
    controller = AttentionStore()
    ptp_utils.register_attention_control(unet, controller)
    # Per-class score maps: channel 0 is background, the remaining channels hold one map per class.
    full_arr = np.zeros((81, 512, 512), np.float32)
    full_arr[0] = 0.5
    seed = 100
    with torch.no_grad():
        clear_feature_dic()
        controller.reset()
        g_cpu = torch.Generator().manual_seed(seed)
        # Generate the image while recording cross-attention maps and U-Net features.
        image_out, x_t = ptp_utils.text2image(unet, vae, tokenizer, text_encoder, scheduler, prompts, controller, num_inference_steps=NUM_DIFFUSION_STEPS, guidance_scale=5, generator=g_cpu, low_resource=LOW_RESOURCE, Train=False)
        print("image_out ", image_out.shape)
        image_file = f"{image_path}/image.jpg"
        ptp_utils.save_images(image_out, out_put=image_file)
        # Query the segmentation decoder once per class name.
        for idxx in classes:
            if idxx == 0:
                continue
            class_name = classes[idxx]
            # if class_name not in classes_check[idx]:
            #     continue
            query_text = class_name
            text_input = tokenizer(
                query_text,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_embeddings = text_encoder(text_input.input_ids.to(unet.device))[0]
            class_embedding = text_embeddings
            if class_embedding.size()[1] > 1:
                class_embedding = torch.unsqueeze(class_embedding.mean(1), 1)
            seed = 1
            diffusion_features = get_feature_dic()
            outputs = seg_model(diffusion_features, controller, prompts, tokenizer, class_embedding)
            # outputs = seg_model(diffusion_features, controller, prompts, tokenizer)
            mask_cls_results = outputs["pred_logits"]
            mask_pred_results = outputs["pred_masks"]
            mask_pred_results = F.interpolate(mask_pred_results, size=(512, 512), mode="bilinear", align_corners=False)
            for mask_cls_result, mask_pred_result in zip(mask_cls_results, mask_pred_results):
                instance_r = instance_inference(mask_cls_result, mask_pred_result, class_n=80, test_topk_per_image=3, query_n=100)
                pred_masks = instance_r.pred_masks.cpu().numpy().astype(np.uint8)
                pred_boxes = instance_r.pred_boxes
                scores = instance_r.scores
                pred_classes = instance_r.pred_classes
                # keep only the highest-scoring instance for this class query
                import heapq
                topk_idx = heapq.nlargest(1, range(len(scores)), scores.__getitem__)
                mask_instance = (pred_masks[topk_idx[0]] > 0.5).astype(np.uint8)
                full_arr[idxx] = np.array(mask_instance)
    # Fuse the per-class masks into a single label map.
    full_arr = softmax(full_arr, axis=0)
    mask = np.argmax(full_arr, axis=0)
    print("mask ", mask.shape)
    mask_file = f"{mask_path}/mask.png"
    cv2.imwrite(mask_file, mask)
    vis_file = f"{vis_path}/visual.png"
    vis_map = visualize_segmentation(image_file, mask_file, vis_file)
    print("vis_map ", vis_map.shape)
    return image_out[0], vis_map
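# Gradio UI: a text prompt in, the generated image and its segmentation visualization out.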
iface = gr.Interface(
    fn=inference,
    inputs=gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
    outputs=[
        gr.Image(label="Generated Image"),
        gr.Image(label="Segmentation Map"),
    ],
    title="PromptPix: Image Generation & Segmentation",
    description="Enter a prompt to generate an image and its segmentation mask",
)
if __name__ == "__main__":
    iface.launch(share=True)