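# Gradio demo for joint image generation and segmentation: a Stable Diffusion v1-4
# pipeline generates an image from the prompt, and a segmentation decoder predicts
# a mask from the UNet features and attention maps cached during sampling.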
import torch
import ptp_utils
import cv2
import os
import numpy as np
from diffusers import AutoencoderKL, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from model.unet import UNet2D, get_feature_dic, clear_feature_dic
from model.segment.transformer_decoder import seg_decorder
from model.segment.transformer_decoder_semantic import seg_decorder_open_word
from store import AttentionStore
import torch.nn.functional as F
from config import opt
from visualize import visualize_segmentation
import gradio as gr
# Model loading
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
vae.to(device)
vae.eval()
unet = UNet2D.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
unet.to(device)
unet.eval()
text_encoder = CLIPTextModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="text_encoder")
text_encoder.to(device)
text_encoder.eval()
scheduler = DDPMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
# seg_model = seg_decorder().to(device)
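# Open-vocabulary segmentation decoder: 80 classes, 100 mask queries,
# initialised from the grounding checkpoint given in the config.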
num_classes = 80
num_queries = 100
seg_model = seg_decorder_open_word(num_classes=num_classes,
                                   num_queries=num_queries).to(device)
base_weights = torch.load(opt.get('grounding_ckpt'), map_location="cpu")
seg_model.load_state_dict(base_weights, strict=True)
def freeze_params(params):
    for param in params:
        param.requires_grad = False
freeze_params(vae.parameters())
freeze_params(unet.parameters())
freeze_params(text_encoder.parameters())
freeze_params(seg_model.parameters())
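# Fuse per-query class probabilities with the predicted masks into a single
# per-class semantic map; classes with only a small confident area are zeroed out.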
def semantic_inference(mask_cls, mask_pred):
    # Per-query class probabilities, dropping the trailing "no object" logit
    mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
    mask_pred = mask_pred.sigmoid()
    # Weight each query's mask by its class probabilities: (Q, C) x (Q, H, W) -> (C, H, W)
    semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
    # Suppress classes whose confident (> 0.5) area is too small to be a real region
    for i in range(1, semseg.shape[0]):
        if (semseg[i] * (semseg[i] > 0.5)).sum() < 5000:
            semseg[i] = 0
    return semseg
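# Full demo pipeline for one prompt: sample an image with Stable Diffusion,
# collect the UNet features and attention maps gathered during sampling, then
# decode a semantic mask and a visualisation for the Gradio UI.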
def inference(prompt):
    prompts = [prompt]
    LOW_RESOURCE = False
    NUM_DIFFUSION_STEPS = 20
    outpath = opt.get('outdir')
    image_path = os.path.join(outpath, "Image")
    mask_path = os.path.join(outpath, "Mask")
    vis_path = os.path.join(outpath, "Vis")
    os.makedirs(image_path, exist_ok=True)
    os.makedirs(mask_path, exist_ok=True)
    os.makedirs(vis_path, exist_ok=True)
    # Hook the UNet's attention layers so attention maps are stored during sampling
    controller = AttentionStore()
    ptp_utils.register_attention_control(unet, controller)
    # Encode the query class name ('road' in this demo) and average its token
    # embeddings into a single class embedding for the open-vocabulary decoder
    query_text = 'road'
    text_input = tokenizer(
        query_text,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = text_encoder(text_input.input_ids.to(unet.device))[0]
    class_embedding = text_embeddings
    if class_embedding.size()[1] > 1:
        class_embedding = torch.unsqueeze(class_embedding.mean(1), 1)
    seed = 1
    with torch.no_grad():
        # Reset cached UNet features and attention maps from any previous run
        clear_feature_dic()
        controller.reset()
        g_cpu = torch.Generator().manual_seed(seed)
        # Generate the image while caching intermediate diffusion features
        image_out, x_t = ptp_utils.text2image(
            unet, vae, tokenizer, text_encoder, scheduler, prompts, controller,
            num_inference_steps=NUM_DIFFUSION_STEPS, guidance_scale=5,
            generator=g_cpu, low_resource=LOW_RESOURCE, Train=False)
        print("image_out ", image_out.shape)
        image_file = f"{image_path}/image.jpg"
        ptp_utils.save_images(image_out, out_put=image_file)
        # Decode segmentation masks from the cached diffusion features
        diffusion_features = get_feature_dic()
        outputs = seg_model(diffusion_features, controller, prompts, tokenizer, class_embedding)
        # outputs = seg_model(diffusion_features, controller, prompts, tokenizer)
        mask_cls_results = outputs["pred_logits"]
        mask_pred_results = outputs["pred_masks"]
        mask_pred_results = F.interpolate(mask_pred_results, size=(512, 512), mode="bilinear", align_corners=False)
        for mask_cls_result, mask_pred_result in zip(mask_cls_results, mask_pred_results):
            label_pred_prob = semantic_inference(mask_cls_result, mask_pred_result)
            label_pred_prob = torch.argmax(label_pred_prob, dim=0)
            # Cast the label map to uint8 so OpenCV can write it as a PNG
            label_pred_prob = label_pred_prob.cpu().numpy().astype(np.uint8)
            print("mask ", label_pred_prob.shape)
            mask_file = f"{mask_path}/mask.png"
            cv2.imwrite(mask_file, label_pred_prob)
            vis_file = f"{vis_path}/visual.png"
            vis_map = visualize_segmentation(image_file, mask_file, vis_file)
            print("vis_map ", vis_map.shape)
    return image_out[0], vis_map
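# Minimal Gradio UI: prompt in, generated image and segmentation visualisation out.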
iface = gr.Interface(
    fn=inference,
    inputs=gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
    outputs=[
        gr.Image(label="Generated Image"),
        gr.Image(label="Segmentation Map")
    ],
    title="PromptPix: Image Generation & Segmentation",
    description="Enter a prompt to generate an image and its segmentation mask"
)
if __name__ == "__main__":
    iface.launch()