import os

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import gradio as gr
from diffusers import AutoencoderKL, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer

import ptp_utils
from config import opt
from model.unet import UNet2D, get_feature_dic, clear_feature_dic
from model.segment.transformer_decoder import seg_decorder
from model.segment.transformer_decoder_semantic import seg_decorder_open_word
from store import AttentionStore
from visualize import visualize_segmentation

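# Pipeline: Stable Diffusion v1-4 generates an image while an AttentionStore
# records the U-Net's cross-attention maps; a frozen open-vocabulary mask
# decoder (seg_decorder_open_word) then predicts a segmentation mask from the
# cached diffusion features.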
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Load the Stable Diffusion v1-4 components; all run frozen, in eval mode.
tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")

vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
vae.to(device)
vae.eval()

unet = UNet2D.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
unet.to(device)
unet.eval()

text_encoder = CLIPTextModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="text_encoder")
text_encoder.to(device)
text_encoder.eval()

# Schedulers load Hub configs via from_pretrained; from_config expects a config
# dict, not a repo id, in current diffusers.
scheduler = DDPMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
num_classes = 80
num_queries = 100
seg_model = seg_decorder_open_word(num_classes=num_classes, num_queries=num_queries).to(device)

# strict=True ensures the checkpoint keys match the decoder exactly.
base_weights = torch.load(opt.get('grounding_ckpt'), map_location="cpu")
seg_model.load_state_dict(base_weights, strict=True)
seg_model.eval()  # match the other frozen components


def freeze_params(params):
    for param in params:
        param.requires_grad = False


freeze_params(vae.parameters())
freeze_params(unet.parameters())
freeze_params(text_encoder.parameters())
freeze_params(seg_model.parameters())
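# With every module frozen, the script is inference-only; no gradient state is
# ever created.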


def semantic_inference(mask_cls, mask_pred):
    # Drop the trailing "no-object" logit, then fuse per-query class scores
    # with the matching mask predictions into per-class probability maps.
    mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
    mask_pred = mask_pred.sigmoid()
    semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
    # Suppress classes whose confident region (prob > 0.5) covers fewer than
    # 5000 pixels; channel 0 is left untouched.
    for i in range(1, semseg.shape[0]):
        if (semseg[i] * (semseg[i] > 0.5)).sum() < 5000:
            semseg[i] = 0
    return semseg
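

# Shape sketch for semantic_inference (hypothetical tensors; assumes
# num_queries query predictions over num_classes classes plus one
# "no-object" logit, and 512x512 masks):
#   dummy_cls = torch.randn(num_queries, num_classes + 1)
#   dummy_masks = torch.randn(num_queries, 512, 512)
#   semantic_inference(dummy_cls, dummy_masks).shape  # -> (num_classes, 512, 512)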


def inference(prompt):
    prompts = [prompt]
    LOW_RESOURCE = False
    NUM_DIFFUSION_STEPS = 20

    # Output directories for the generated image, raw mask, and visualization.
    outpath = opt.get('outdir')
    image_path = os.path.join(outpath, "Image")
    mask_path = os.path.join(outpath, "Mask")
    vis_path = os.path.join(outpath, "Vis")
    os.makedirs(image_path, exist_ok=True)
    os.makedirs(mask_path, exist_ok=True)
    os.makedirs(vis_path, exist_ok=True)

    # Hook the U-Net so cross-attention maps are recorded during sampling.
    controller = AttentionStore()
    ptp_utils.register_attention_control(unet, controller)

    # Class to ground; fixed to 'road' in this demo.
    query_text = 'road'
    text_input = tokenizer(
        query_text,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = text_encoder(text_input.input_ids.to(unet.device))[0]
    class_embedding = text_embeddings
    if class_embedding.size(1) > 1:
        # Mean-pool over the token dimension to get one embedding per class.
        class_embedding = class_embedding.mean(1, keepdim=True)

    seed = 1
    with torch.no_grad():
        clear_feature_dic()
        controller.reset()
        g_cpu = torch.Generator().manual_seed(seed)
        # Sample an image while caching U-Net features and attention maps.
        image_out, x_t = ptp_utils.text2image(
            unet, vae, tokenizer, text_encoder, scheduler, prompts, controller,
            num_inference_steps=NUM_DIFFUSION_STEPS, guidance_scale=5,
            generator=g_cpu, low_resource=LOW_RESOURCE, Train=False,
        )
        print("image_out ", image_out.shape)
        image_file = f"{image_path}/image.jpg"
        ptp_utils.save_images(image_out, out_put=image_file)

        # Decode masks from the cached diffusion features and attention maps.
        diffusion_features = get_feature_dic()
        outputs = seg_model(diffusion_features, controller, prompts, tokenizer, class_embedding)

        mask_cls_results = outputs["pred_logits"]
        mask_pred_results = outputs["pred_masks"]
        mask_pred_results = F.interpolate(
            mask_pred_results, size=(512, 512), mode="bilinear", align_corners=False
        )

        for mask_cls_result, mask_pred_result in zip(mask_cls_results, mask_pred_results):
            label_pred_prob = semantic_inference(mask_cls_result, mask_pred_result)
            label_pred_prob = torch.argmax(label_pred_prob, dim=0)
            label_pred_prob = label_pred_prob.cpu().numpy()
            print("mask ", label_pred_prob.shape)
            mask_file = f"{mask_path}/mask.png"
            # cv2.imwrite needs an unsigned integer image; 80 class ids fit in uint8.
            cv2.imwrite(mask_file, label_pred_prob.astype(np.uint8))
            vis_file = f"{vis_path}/visual.png"
            vis_map = visualize_segmentation(image_file, mask_file, vis_file)
            print("vis_map ", vis_map.shape)
            # Batch size is 1, so return after the first (only) item.
            return image_out[0], vis_map
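

# Example usage (hypothetical prompt; assumes opt['outdir'] and
# opt['grounding_ckpt'] point to a writable directory and a trained decoder
# checkpoint):
#   generated, vis = inference("a photo of a road winding through mountains")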


iface = gr.Interface(
    fn=inference,
    inputs=gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
    outputs=[
        gr.Image(label="Generated Image"),
        gr.Image(label="Segmentation Map"),
    ],
    title="PromptPix: Image Generation & Segmentation",
    description="Enter a prompt to generate an image and its segmentation mask.",
)
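
# launch() also accepts options such as share=True (temporary public URL) or
# server_port=... if the default port is busy.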

if __name__ == "__main__":
    iface.launch()