# Keep the original import order: cv2/torch native-library load order can matter.
import torch
import ptp_utils
import cv2
import os
import numpy as np
from diffusers import AutoencoderKL, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from model.unet import UNet2D, get_feature_dic, clear_feature_dic
from model.segment.transformer_decoder import seg_decorder
from model.segment.transformer_decoder_semantic import seg_decorder_open_word
from store import AttentionStore
import torch.nn.functional as F
from config import opt, classes
from scipy.special import softmax
from visualize import visualize_segmentation
from detectron2.structures import Boxes, ImageList, Instances, BitMasks
import gradio as gr

# ---------------------------------------------------------------------------
# Model loading (runs at import time)
# ---------------------------------------------------------------------------
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
SD_MODEL_ID = "CompVis/stable-diffusion-v1-4"

tokenizer = CLIPTokenizer.from_pretrained(SD_MODEL_ID, subfolder="tokenizer")

vae = AutoencoderKL.from_pretrained(SD_MODEL_ID, subfolder="vae")
vae.to(device)
vae.eval()

unet = UNet2D.from_pretrained(SD_MODEL_ID, subfolder="unet")
unet.to(device)
unet.eval()

text_encoder = CLIPTextModel.from_pretrained(SD_MODEL_ID, subfolder="text_encoder")
text_encoder.to(device)
text_encoder.eval()

scheduler = DDPMScheduler.from_config(SD_MODEL_ID, subfolder="scheduler")

# Open-vocabulary segmentation head; these sizes must match the checkpoint.
num_classes = 80
num_queries = 100
seg_model = seg_decorder_open_word(num_classes=num_classes, num_queries=num_queries).to(device)
base_weights = torch.load(opt.get('grounding_ckpt'), map_location="cpu")
seg_model.load_state_dict(base_weights, strict=True)


def freeze_params(params):
    """Disable gradient tracking on every parameter in ``params``."""
    for param in params:
        param.requires_grad = False


freeze_params(vae.parameters())
freeze_params(unet.parameters())
freeze_params(text_encoder.parameters())
freeze_params(seg_model.parameters())


def semantic_inference(mask_cls, mask_pred):
    """Combine per-query class scores and masks into a per-class score map.

    Args:
        mask_cls: 2-D classification logits ``[Q, C + 1]``; the last column is
            dropped after the softmax (presumably a "no object" logit — confirm).
        mask_pred: 3-D mask logits ``[Q, H, W]``.

    Returns:
        ``[C, H, W]`` tensor. Every channel except channel 0 whose total
        above-0.5 mass is below 5000 is zeroed out.
    """
    mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
    mask_pred = mask_pred.sigmoid()
    semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
    for i in range(1, semseg.shape[0]):
        if (semseg[i] * (semseg[i] > 0.5)).sum() < 5000:
            semseg[i] = 0
    return semseg


def instance_inference(mask_cls, mask_pred, class_n=2, test_topk_per_image=20, query_n=100):
    """Select the top-k (query, class) pairs and build detectron2 Instances.

    Args:
        mask_cls: ``[Q, K + 1]`` classification logits; the last column is
            dropped after the softmax.
        mask_pred: per-query mask logits, already resized to the output image;
            its last two dims are used as the instance image size.
        class_n: number of real classes ``K``.
        test_topk_per_image: number of (query, class) pairs to keep.
        query_n: number of queries ``Q``.

    Returns:
        ``Instances`` with binarised ``pred_masks``, all-zero ``pred_boxes``,
        ``scores`` (class probability times mean in-mask probability) and
        ``pred_classes``.
    """
    image_size = mask_pred.shape[-2:]

    scores = F.softmax(mask_cls, dim=-1)[:, :-1]  # [Q, K]
    labels = torch.arange(class_n, device=mask_cls.device).unsqueeze(0).repeat(query_n, 1).flatten(0, 1)
    scores_per_image, topk_indices = scores.flatten(0, 1).topk(test_topk_per_image, sorted=False)
    labels_per_image = labels[topk_indices]
    # Flattened index is query * class_n + class; recover the query index.
    # torch.div avoids the tensor `//` deprecation warning of torch 1.x.
    topk_indices = torch.div(topk_indices, class_n, rounding_mode="floor")
    mask_pred = mask_pred[topk_indices]

    result = Instances(image_size)
    # mask_pred holds logits (before sigmoid), so 0 is the 0.5-probability cut.
    result.pred_masks = (mask_pred > 0).float()
    # Boxes are left as zeros. BitMasks(mask_pred > 0).get_bounding_boxes()
    # would compute real ones but is slow.
    result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4, device=mask_pred.device))

    # Average sigmoid probability inside each predicted mask.
    flat_masks = result.pred_masks.flatten(1)
    mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * flat_masks).sum(1) / (flat_masks.sum(1) + 1e-6)
    result.scores = scores_per_image * mask_scores_per_image
    result.pred_classes = labels_per_image
    return result


def inference(prompt):
    """Generate an image for ``prompt``, segment every class in it, and save outputs.

    Writes ``Image/image.jpg``, ``Mask/mask.png`` and ``Vis/visual.png`` under
    ``opt['outdir']``.

    Returns:
        ``(generated image, visualisation map)`` for the two Gradio outputs.

    Raises:
        OSError: if the label mask cannot be written.
    """
    prompts = [prompt]
    low_resource = False
    num_diffusion_steps = 20
    seed = 100

    outpath = opt.get('outdir')
    image_path = os.path.join(outpath, "Image")
    mask_path = os.path.join(outpath, "Mask")
    vis_path = os.path.join(outpath, "Vis")
    for path in (image_path, mask_path, vis_path):
        os.makedirs(path, exist_ok=True)

    controller = AttentionStore()
    ptp_utils.register_attention_control(unet, controller)

    # Channel 0 is background with a constant 0.5 score; channel c holds the
    # binary (0/1) mask for class c. A per-pixel argmax then yields the
    # lowest-index class whose mask covers the pixel, else background.
    full_arr = np.zeros((num_classes + 1, 512, 512), np.float32)
    full_arr[0] = 0.5

    with torch.no_grad():
        clear_feature_dic()
        controller.reset()
        g_cpu = torch.Generator().manual_seed(seed)
        image_out, _ = ptp_utils.text2image(
            unet, vae, tokenizer, text_encoder, scheduler, prompts, controller,
            num_inference_steps=num_diffusion_steps, guidance_scale=5,
            generator=g_cpu, low_resource=low_resource, Train=False,
        )
        print("image_out ", image_out.shape)
        image_file = f"{image_path}/image.jpg"
        ptp_utils.save_images(image_out, out_put=image_file)

        for idxx in classes:
            if idxx == 0:
                continue
            query_text = classes[idxx]
            text_input = tokenizer(
                query_text,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            class_embedding = text_encoder(text_input.input_ids.to(unet.device))[0]
            # Collapse the token sequence to one embedding by mean pooling.
            if class_embedding.size()[1] > 1:
                class_embedding = torch.unsqueeze(class_embedding.mean(1), 1)

            diffusion_features = get_feature_dic()
            outputs = seg_model(diffusion_features, controller, prompts, tokenizer, class_embedding)
            mask_cls_results = outputs["pred_logits"]
            mask_pred_results = F.interpolate(
                outputs["pred_masks"], size=(512, 512), mode="bilinear", align_corners=False
            )

            for mask_cls_result, mask_pred_result in zip(mask_cls_results, mask_pred_results):
                instance_r = instance_inference(
                    mask_cls_result, mask_pred_result,
                    class_n=num_classes, test_topk_per_image=3, query_n=num_queries,
                )
                pred_masks = instance_r.pred_masks.cpu().numpy().astype(np.uint8)
                # Keep only the highest-scoring instance (argmax returns the
                # first maximum, as heapq.nlargest(1, ...) did).
                best = int(instance_r.scores.argmax())
                full_arr[idxx] = (pred_masks[best] > 0.5).astype(np.uint8)

    # softmax is strictly monotonic along axis 0, so it never changes the
    # argmax; skip it. Cast to uint8: np.argmax returns int64 on Linux, which
    # cv2.imwrite rejects, and class ids 0..num_classes fit in a byte.
    mask = np.argmax(full_arr, axis=0).astype(np.uint8)
    print("mask ", mask.shape)
    mask_file = f"{mask_path}/mask.png"
    if not cv2.imwrite(mask_file, mask):
        raise OSError(f"failed to write segmentation mask to {mask_file}")

    vis_file = f"{vis_path}/visual.png"
    vis_map = visualize_segmentation(image_file, mask_file, vis_file)
    print("vis_map ", vis_map.shape)
    return image_out[0], vis_map


iface = gr.Interface(
    fn=inference,
    inputs=gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
    outputs=[
        gr.Image(label="Generated Image"),
        gr.Image(label="Segmentation Map"),
    ],
    title="PromptPix: Image Generation & Segmentation",
    description="Enter a prompt to generate an image and its Segmask",
)

if __name__ == "__main__":
    iface.launch(share=True)