import os
import heapq

import torch
import torch.nn.functional as F
import cv2
import numpy as np
from scipy.special import softmax
from diffusers import AutoencoderKL, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from detectron2.structures import Boxes, ImageList, Instances, BitMasks
import gradio as gr

import ptp_utils
from model.unet import UNet2D, get_feature_dic, clear_feature_dic
from model.segment.transformer_decoder import seg_decorder
from model.segment.transformer_decoder_semantic import seg_decorder_open_word
from store import AttentionStore
from config import opt, classes
from visualize import visualize_segmentation
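
# Gradio demo: generate an image with Stable Diffusion v1-4 and, reusing the
# diffusion features and cross-attention maps cached during generation, predict
# a segmentation mask for every class name in `classes`.
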
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")

vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
vae.to(device)
vae.eval()

unet = UNet2D.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
unet.to(device)
unet.eval()

text_encoder = CLIPTextModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="text_encoder")
text_encoder.to(device)
text_encoder.eval()

# Noise scheduler from the same pretrained checkpoint.
scheduler = DDPMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")

num_classes = 80
num_queries = 100
seg_model = seg_decorder_open_word(num_classes=num_classes, num_queries=num_queries).to(device)

base_weights = torch.load(opt.get('grounding_ckpt'), map_location="cpu")
seg_model.load_state_dict(base_weights, strict=True)


def freeze_params(params):
    for param in params:
        param.requires_grad = False


freeze_params(vae.parameters())
freeze_params(unet.parameters())
freeze_params(text_encoder.parameters())
freeze_params(seg_model.parameters())
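# All four networks above stay frozen; this script only runs inference.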


def semantic_inference(mask_cls, mask_pred):
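    """Fuse per-query class scores and mask logits into a per-class map.

    Expects mask_cls of shape (num_queries, num_classes + 1), where the last
    channel is the "no object" logit and is dropped, and mask_pred of shape
    (num_queries, H, W). Returns a (num_classes, H, W) map; low-confidence
    channels (above-0.5 probability mass summing to less than 5000) are zeroed.
    """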
    mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
    mask_pred = mask_pred.sigmoid()
    semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
    for i in range(1, semseg.shape[0]):
        if (semseg[i] * (semseg[i] > 0.5)).sum() < 5000:
            semseg[i] = 0
    return semseg


def instance_inference(mask_cls, mask_pred, class_n=2, test_topk_per_image=20, query_n=100):
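    """Instance inference in the style of Mask2Former.

    Scores every (query, class) pair, keeps the `test_topk_per_image` highest
    scoring pairs, and packs their binarized masks into a detectron2
    `Instances` object with per-instance scores and class labels. Assumes
    mask_cls has shape (query_n, class_n + 1) and mask_pred (query_n, H, W).
    """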
    image_size = mask_pred.shape[-2:]

    scores = F.softmax(mask_cls, dim=-1)[:, :-1]
    labels = torch.arange(class_n, device=mask_cls.device).unsqueeze(0).repeat(query_n, 1).flatten(0, 1)

    scores_per_image, topk_indices = scores.flatten(0, 1).topk(test_topk_per_image, sorted=False)
    labels_per_image = labels[topk_indices]

    # Map flattened (query, class) indices back to query indices.
    topk_indices = topk_indices // class_n
    mask_pred = mask_pred[topk_indices]

    result = Instances(image_size)
    result.pred_masks = (mask_pred > 0).float()
    result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))

    # Calibrate each classification score by the mean mask confidence inside the predicted mask.
    mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
    result.scores = scores_per_image * mask_scores_per_image
    result.pred_classes = labels_per_image
    return result


def inference(prompt):
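    """Generate an image from `prompt` with Stable Diffusion, then segment it.

    Cross-attention maps and U-Net features cached while the image is generated
    are fed to the grounding decoder once per class name in `classes`; the
    highest-scoring instance mask for each class is merged into a single label
    map. Returns the generated image and a visualization of the segmentation.
    """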
    prompts = [prompt]
    LOW_RESOURCE = False
    NUM_DIFFUSION_STEPS = 20

    outpath = opt.get('outdir')
    image_path = os.path.join(outpath, "Image")
    mask_path = os.path.join(outpath, "Mask")
    vis_path = os.path.join(outpath, "Vis")
    os.makedirs(image_path, exist_ok=True)
    os.makedirs(mask_path, exist_ok=True)
    os.makedirs(vis_path, exist_ok=True)

    controller = AttentionStore()
    ptp_utils.register_attention_control(unet, controller)

    # Per-class score maps; channel 0 holds a constant background score of 0.5,
    # so a class wins the final argmax only where its binary mask is 1.
    full_arr = np.zeros((81, 512, 512), np.float32)
    full_arr[0] = 0.5

    seed = 100
    with torch.no_grad():
        clear_feature_dic()
        controller.reset()

        # Generate the image while caching U-Net features and cross-attention maps.
        g_cpu = torch.Generator().manual_seed(seed)
        image_out, x_t = ptp_utils.text2image(unet, vae, tokenizer, text_encoder, scheduler, prompts, controller, num_inference_steps=NUM_DIFFUSION_STEPS, guidance_scale=5, generator=g_cpu, low_resource=LOW_RESOURCE, Train=False)
        print("image_out ", image_out.shape)
        image_file = f"{image_path}/image.jpg"
        ptp_utils.save_images(image_out, out_put=image_file)

        for idxx in classes:
            if idxx == 0:
                continue
            class_name = classes[idxx]

            # Embed the class name and average its token embeddings into a single query vector.
            query_text = class_name
            text_input = tokenizer(
                query_text,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_embeddings = text_encoder(text_input.input_ids.to(unet.device))[0]
            class_embedding = text_embeddings
            if class_embedding.size()[1] > 1:
                class_embedding = torch.unsqueeze(class_embedding.mean(1), 1)

            # Run the grounding decoder on the cached diffusion features.
            diffusion_features = get_feature_dic()
            outputs = seg_model(diffusion_features, controller, prompts, tokenizer, class_embedding)

            mask_cls_results = outputs["pred_logits"]
            mask_pred_results = outputs["pred_masks"]
            mask_pred_results = F.interpolate(mask_pred_results, size=(512, 512), mode="bilinear", align_corners=False)

            for mask_cls_result, mask_pred_result in zip(mask_cls_results, mask_pred_results):
                instance_r = instance_inference(mask_cls_result, mask_pred_result, class_n=80, test_topk_per_image=3, query_n=100)
                pred_masks = instance_r.pred_masks.cpu().numpy().astype(np.uint8)
                pred_boxes = instance_r.pred_boxes
                scores = instance_r.scores
                pred_classes = instance_r.pred_classes

                # Keep the mask of the highest-scoring instance for this class.
                topk_idx = heapq.nlargest(1, range(len(scores)), scores.__getitem__)
                mask_instance = (pred_masks[topk_idx[0]] > 0.5).astype(np.uint8)
                full_arr[idxx] = np.array(mask_instance)

    # Per-pixel argmax over the class channels gives the final label map.
    full_arr = softmax(full_arr, axis=0)
    mask = np.argmax(full_arr, axis=0)
    print("mask ", mask.shape)
    mask_file = f"{mask_path}/mask.png"
    cv2.imwrite(mask_file, mask)

    vis_file = f"{vis_path}/visual.png"
    vis_map = visualize_segmentation(image_file, mask_file, vis_file)
    print("vis_map ", vis_map.shape)
    return image_out[0], vis_map


iface = gr.Interface(
    fn=inference,
    inputs=gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
    outputs=[
        gr.Image(label="Generated Image"),
        gr.Image(label="Segmentation Map"),
    ],
    title="PromptPix: Image Generation & Segmentation",
    description="Enter a prompt to generate an image and its segmentation mask.",
)
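
# Launching with share=True serves the demo locally and also creates a temporary public Gradio link.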
if __name__ == "__main__":
    iface.launch(share=True)