import torch
import ptp_utils
import cv2
import os
import numpy as np
from diffusers import AutoencoderKL, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from model.unet import UNet2D, get_feature_dic, clear_feature_dic
from model.segment.transformer_decoder import seg_decorder
from model.segment.transformer_decoder_semantic import seg_decorder_open_word
from store import AttentionStore
import torch.nn.functional as F
from config import opt
from visualize import visualize_segmentation
import gradio as gr

# Model loading
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")

vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
vae.to(device)
vae.eval()

unet = UNet2D.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
unet.to(device)
unet.eval()

text_encoder = CLIPTextModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="text_encoder")
text_encoder.to(device)
text_encoder.eval()

scheduler = DDPMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")

# Open-vocabulary segmentation decoder (the closed-vocabulary variant is kept below for reference).
# seg_model = seg_decorder().to(device)
num_classes = 80
num_queries = 100
seg_model = seg_decorder_open_word(num_classes=num_classes, num_queries=num_queries).to(device)
base_weights = torch.load(opt.get('grounding_ckpt'), map_location="cpu")
seg_model.load_state_dict(base_weights, strict=True)


def freeze_params(params):
    for param in params:
        param.requires_grad = False


freeze_params(vae.parameters())
freeze_params(unet.parameters())
freeze_params(text_encoder.parameters())
freeze_params(seg_model.parameters())


def semantic_inference(mask_cls, mask_pred):
    # Combine per-query class scores and mask logits into a per-class probability map.
    mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]  # drop the "no object" class
    mask_pred = mask_pred.sigmoid()
    semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
    # Suppress non-background classes whose confident (>0.5) area is too small.
    for i in range(1, semseg.shape[0]):
        if (semseg[i] * (semseg[i] > 0.5)).sum() < 5000:
            semseg[i] = 0
    return semseg


def inference(prompt):
    prompts = [prompt]
    LOW_RESOURCE = False
    NUM_DIFFUSION_STEPS = 20

    outpath = opt.get('outdir')
    image_path = os.path.join(outpath, "Image")
    mask_path = os.path.join(outpath, "Mask")
    vis_path = os.path.join(outpath, "Vis")
    os.makedirs(image_path, exist_ok=True)
    os.makedirs(mask_path, exist_ok=True)
    os.makedirs(vis_path, exist_ok=True)

    controller = AttentionStore()
    ptp_utils.register_attention_control(unet, controller)

    # Encode the query class name; its CLIP embedding conditions the open-vocabulary decoder.
    query_text = 'road'
    text_input = tokenizer(
        query_text,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = text_encoder(text_input.input_ids.to(unet.device))[0]
    class_embedding = text_embeddings
    if class_embedding.size()[1] > 1:
        # Average over tokens so the class is represented by a single embedding.
        class_embedding = torch.unsqueeze(class_embedding.mean(1), 1)

    seed = 1
    with torch.no_grad():
        clear_feature_dic()
        controller.reset()
        g_cpu = torch.Generator().manual_seed(seed)
        image_out, x_t = ptp_utils.text2image(
            unet, vae, tokenizer, text_encoder, scheduler, prompts, controller,
            num_inference_steps=NUM_DIFFUSION_STEPS, guidance_scale=5,
            generator=g_cpu, low_resource=LOW_RESOURCE, Train=False)
        print("image_out ", image_out.shape)
        image_file = f"{image_path}/image.jpg"
        ptp_utils.save_images(image_out, out_put=image_file)

        # Segment from the U-Net features and attention maps collected during generation.
        diffusion_features = get_feature_dic()
        outputs = seg_model(diffusion_features, controller, prompts, tokenizer, class_embedding)
        # outputs = seg_model(diffusion_features, controller, prompts, tokenizer)
        mask_cls_results = outputs["pred_logits"]
        mask_pred_results = outputs["pred_masks"]
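        # Expected shapes (Mask2Former-style outputs, as implied by semantic_inference above):
        # pred_logits: [batch, num_queries, num_classes + 1], where the last class is "no object";
        # pred_masks:  [batch, num_queries, H', W'] mask logits, upsampled to 512x512 below.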
        mask_pred_results = F.interpolate(
            mask_pred_results, size=(512, 512), mode="bilinear", align_corners=False)

        for mask_cls_result, mask_pred_result in zip(mask_cls_results, mask_pred_results):
            label_pred_prob = semantic_inference(mask_cls_result, mask_pred_result)
            label_pred_prob = torch.argmax(label_pred_prob, dim=0)
            label_pred_prob = label_pred_prob.cpu().numpy()
            print("mask ", label_pred_prob.shape)
            # Save the label map as an 8-bit image (class ids fit in uint8 for 80 classes).
            mask_file = f"{mask_path}/mask.png"
            cv2.imwrite(mask_file, label_pred_prob.astype(np.uint8))
            vis_file = f"{vis_path}/visual.png"
            vis_map = visualize_segmentation(image_file, mask_file, vis_file)
            print("vis_map ", vis_map.shape)

    return image_out[0], vis_map


iface = gr.Interface(
    fn=inference,
    inputs=gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
    outputs=[
        gr.Image(label="Generated Image"),
        gr.Image(label="Segmentation Map"),
    ],
    title="PromptPix: Image Generation & Segmentation",
    description="Enter a prompt to generate an image and its segmentation mask.",
)

if __name__ == "__main__":
    iface.launch()
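# Usage note: running this script starts the Gradio UI on a local URL; `inference(prompt)`
# can also be called directly from Python to obtain the generated image and the segmentation
# overlay without the UI (results are still written under the configured outdir).
# `iface.launch(share=True)` would additionally expose a temporary public link.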