# promptpix/app2.py
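"""Gradio demo: generate an image with Stable Diffusion v1-4 and segment it.

The script wires together the frozen SD components (tokenizer, text encoder,
VAE, UNet, scheduler), an AttentionStore that records cross-attention during
sampling, and an open-vocabulary segmentation decoder (seg_decorder_open_word)
that predicts masks from the stored diffusion features, one text query per class.
"""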
import torch
import ptp_utils
import cv2
import heapq
import os
import numpy as np
from diffusers import AutoencoderKL, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from model.unet import UNet2D, get_feature_dic, clear_feature_dic
from model.segment.transformer_decoder import seg_decorder
from model.segment.transformer_decoder_semantic import seg_decorder_open_word
from store import AttentionStore
import torch.nn.functional as F
from config import opt, classes
from scipy.special import softmax
from visualize import visualize_segmentation
from detectron2.structures import Boxes, ImageList, Instances, BitMasks
import gradio as gr
# Model loading
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
vae.to(device)
vae.eval()
unet = UNet2D.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
unet.to(device)
unet.eval()
text_encoder = CLIPTextModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="text_encoder")
text_encoder.to(device)
text_encoder.eval()
scheduler = DDPMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
# seg_model = seg_decorder().to(device)
num_classes = 80
num_queries = 100
seg_model = seg_decorder_open_word(num_classes=num_classes,
                                   num_queries=num_queries).to(device)
base_weights = torch.load(opt.get('grounding_ckpt'), map_location="cpu")
seg_model.load_state_dict(base_weights, strict=True)
def freeze_params(params):
for param in params:
param.requires_grad = False
freeze_params(vae.parameters())
freeze_params(unet.parameters())
freeze_params(text_encoder.parameters())
freeze_params(seg_model.parameters())
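# Collapse per-query class probabilities and mask predictions into a dense
# [num_classes, H, W] semantic map, zeroing classes whose confident (>0.5) area is small.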
def semantic_inference(mask_cls, mask_pred):
mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
mask_pred = mask_pred.sigmoid()
semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
for i in range(1, semseg.shape[0]):
if (semseg[i] * (semseg[i] > 0.5)).sum() < 5000:
semseg[i] = 0
return semseg
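# Instance read-out: pick the top-k (query, class) pairs by classification score,
# recover their masks, and weight each score by the mean mask probability inside
# the predicted binary mask (post-processing in the Mask2Former style).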
def instance_inference(mask_cls, mask_pred, class_n=2, test_topk_per_image=20, query_n=100):
# mask_pred is already processed to have the same shape as original input
image_size = mask_pred.shape[-2:]
# [Q, K]
scores = F.softmax(mask_cls, dim=-1)[:, :-1]
labels = torch.arange(class_n, device=mask_cls.device).unsqueeze(0).repeat(query_n, 1).flatten(0, 1)
# scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
scores_per_image, topk_indices = scores.flatten(0, 1).topk(test_topk_per_image, sorted=False)
labels_per_image = labels[topk_indices]
# Map the flattened (query, class) indices back to query indices.
topk_indices = topk_indices // class_n
# mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
# print(topk_indices)
mask_pred = mask_pred[topk_indices]
result = Instances(image_size)
# mask (before sigmoid)
result.pred_masks = (mask_pred > 0).float()
result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
# Uncomment the following to get boxes from masks (this is slow)
# result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes()
# calculate average mask prob
mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
result.scores = scores_per_image * mask_scores_per_image
result.pred_classes = labels_per_image
return result
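# Gradio handler: generate one image for the prompt, query the segmentation
# decoder once per class name, and merge the winning instance masks into a
# single label map that is saved and visualized alongside the image.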
def inference(prompt):
prompts = [prompt]
LOW_RESOURCE = False
NUM_DIFFUSION_STEPS = 20
outpath = opt.get('outdir')
image_path = os.path.join(outpath, "Image")
mask_path = os.path.join(outpath, "Mask")
vis_path = os.path.join(outpath, "Vis")
os.makedirs(image_path, exist_ok=True)
os.makedirs(mask_path, exist_ok=True)
os.makedirs(vis_path, exist_ok=True)
controller = AttentionStore()
ptp_utils.register_attention_control(unet, controller)
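# Registering the controller hooks the UNet attention layers so the cross-attention
# maps collected during sampling are available to the segmentation decoder.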
# Score volume over 80 classes plus background: channel 0 holds a 0.5 background prior.
full_arr = np.zeros((81, 512, 512), np.float32)
full_arr[0] = 0.5
seed = 100
with torch.no_grad():
clear_feature_dic()
controller.reset()
g_cpu = torch.Generator().manual_seed(seed)
image_out, x_t = ptp_utils.text2image(unet, vae, tokenizer, text_encoder, scheduler, prompts, controller, num_inference_steps=NUM_DIFFUSION_STEPS, guidance_scale=5, generator=g_cpu, low_resource=LOW_RESOURCE, Train=False)
print("image_out ", image_out.shape)
image_file = f"{image_path}/image.jpg"
ptp_utils.save_images(image_out, out_put=image_file)
# Query the decoder once per class name; index 0 is the background channel.
for idxx in classes:
if idxx == 0:
continue
class_name = classes[idxx]
# if class_name not in classes_check[idx]:
# continue
query_text = class_name
text_input = tokenizer(
query_text,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_embeddings = text_encoder(text_input.input_ids.to(unet.device))[0]
# Pool the token embeddings into a single per-class text embedding.
class_embedding = text_embeddings
if class_embedding.size()[1] > 1:
class_embedding = torch.unsqueeze(class_embedding.mean(1), 1)
seed = 1
diffusion_features = get_feature_dic()
outputs = seg_model(diffusion_features, controller, prompts, tokenizer, class_embedding)
# outputs = seg_model(diffusion_features, controller, prompts, tokenizer)
mask_cls_results = outputs["pred_logits"]
mask_pred_results = outputs["pred_masks"]
mask_pred_results = F.interpolate(mask_pred_results, size=(512, 512), mode="bilinear", align_corners=False)
for mask_cls_result, mask_pred_result in zip(mask_cls_results, mask_pred_results):
instance_r = instance_inference(mask_cls_result, mask_pred_result, class_n=80, test_topk_per_image=3, query_n=100)
pred_masks = instance_r.pred_masks.cpu().numpy().astype(np.uint8)
pred_boxes = instance_r.pred_boxes
scores = instance_r.scores
pred_classes = instance_r.pred_classes
# Keep only the highest-scoring query for this class and binarize its mask.
topk_idx = heapq.nlargest(1, range(len(scores)), scores.__getitem__)
mask_instance = (pred_masks[topk_idx[0]] > 0.5).astype(np.uint8)
full_arr[idxx] = np.array(mask_instance)
# Per-pixel softmax over the class axis, then take the argmax as the label map.
full_arr = softmax(full_arr, axis=0)
mask = np.argmax(full_arr, axis=0)
print("mask ", mask.shape)
mask_file = f"{mask_path}/mask.png"
cv2.imwrite(mask_file, mask.astype(np.uint8))
vis_file = f"{vis_path}/visual.png"
vis_map = visualize_segmentation(image_file, mask_file, vis_file)
print("vis_map ", vis_map.shape)
return image_out[0], vis_map
iface = gr.Interface(
fn=inference,
inputs=gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
outputs=[
gr.Image(label="Generated Image"),
gr.Image(label="Segmentation Map")
],
title="PromptPix: Image Generation & Segmentation",
description="Enter a prompt to generate an image and its Segmask"
)
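# share=True also requests a temporary public Gradio link; remove it to serve locally only.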
if __name__ == "__main__":
iface.launch(share=True)