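# Gradio demo for ShowMaker: text-to-video generation and first-frame video
# prediction with optional reference-image (IP) conditioning, built on the
# Vlogger codebase (masked video diffusion with spatial/temporal text guidance).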
import os
import sys
import argparse

import torch
import torchvision
import gradio as gr
from gradio.themes.utils import colors, fonts, sizes
from omegaconf import OmegaConf
from einops import rearrange
from PIL import Image
from torchvision import transforms
from diffusers.models import AutoencoderKL
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from transformers.image_transforms import convert_to_rgb

try:
    import utils
    from diffusion import create_diffusion
except ImportError:
    # Fall back to the repository root when the script is run from a subdirectory.
    sys.path.append(os.path.split(sys.path[0])[0])
    import utils
    from diffusion import create_diffusion

from models import get_models
from models.clip import TextEmbedder
from vlogger.STEB.model_transform import tca_transform_model, ip_scale_set, ip_transform_model
sys.path.append("..")
from datasets import video_transforms
from utils import mask_generation_before
from backend import auto_inpainting  # shadowed by the local definition below
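# auto_inpainting: single-guidance sampling path. The latent, mask, and masked
# video are duplicated (batch of 2: conditional / unconditional) so one DDIM run
# produces both branches for classifier-free guidance with a single text scale
# (cfg_scale); the optional CLIP image embedding is batched the same way.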
def auto_inpainting(video_input, masked_video, mask, prompt, image, vae, text_encoder, image_encoder, diffusion, model, device, cfg_scale, img_cfg_scale, negative_prompt=""):
    global use_fp16
    image_prompt_embeds = None
    if prompt is None:
        prompt = ""
    if image is not None:
        # Encode the reference image with CLIP and stack conditional / unconditional embeddings.
        clip_image = clip_image_processor(images=image, return_tensors="pt").pixel_values
        clip_image_embeds = image_encoder(clip_image.to(device)).image_embeds
        uncond_clip_image_embeds = torch.zeros_like(clip_image_embeds).to(device)
        image_prompt_embeds = torch.cat([clip_image_embeds, uncond_clip_image_embeds], dim=0)
        image_prompt_embeds = rearrange(image_prompt_embeds, '(b n) c -> b n c', b=2).contiguous()
        model = ip_scale_set(model, img_cfg_scale)
        if use_fp16:
            image_prompt_embeds = image_prompt_embeds.to(dtype=torch.float16)
    b, f, c, h, w = video_input.shape
    latent_h = video_input.shape[-2] // 8
    latent_w = video_input.shape[-1] // 8
    if use_fp16:
        z = torch.randn(1, 4, 16, latent_h, latent_w, dtype=torch.float16, device=device)  # b, c, f, h, w
        masked_video = masked_video.to(dtype=torch.float16)
        mask = mask.to(dtype=torch.float16)
    else:
        z = torch.randn(1, 4, 16, latent_h, latent_w, device=device)  # b, c, f, h, w

    # Encode the masked video into VAE latent space and downsample the mask to match.
    masked_video = rearrange(masked_video, 'b f c h w -> (b f) c h w').contiguous()
    masked_video = vae.encode(masked_video).latent_dist.sample().mul_(0.18215)
    masked_video = rearrange(masked_video, '(b f) c h w -> b c f h w', b=b).contiguous()
    mask = torch.nn.functional.interpolate(mask[:, :, 0, :], size=(latent_h, latent_w)).unsqueeze(1)

    # Duplicate everything for classifier-free guidance (conditional + unconditional).
    masked_video = torch.cat([masked_video] * 2)
    mask = torch.cat([mask] * 2)
    z = torch.cat([z] * 2)
    prompt_all = [prompt] + [negative_prompt]
    text_prompt = text_encoder(text_prompts=prompt_all, train=False)
    model_kwargs = dict(encoder_hidden_states=text_prompt,
                        class_labels=None,
                        cfg_scale=cfg_scale,
                        use_fp16=use_fp16,
                        ip_hidden_states=image_prompt_embeds)
    # Sample:
    samples = diffusion.ddim_sample_loop(
        model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs,
        progress=True, device=device, mask=mask, x_start=masked_video, use_concat=True
    )
    samples, _ = samples.chunk(2, dim=0)  # keep the conditional branch: [1, 4, 16, latent_h, latent_w]
    if use_fp16:
        samples = samples.to(dtype=torch.float16)
    video_clip = samples[0].permute(1, 0, 2, 3).contiguous()  # [16, 4, latent_h, latent_w]
    video_clip = vae.decode(video_clip / 0.18215).sample      # [16, 3, H, W]
    return video_clip
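# auto_inpainting_temp_split: variant with separate spatial (scfg_scale) and
# temporal (tcfg_scale) text guidance. Inputs are tripled instead of doubled
# (spatially conditioned / temporally conditioned / unconditional), and a second
# text embedding (encoder_temporal_hidden_states) carries the temporal prompt.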
def auto_inpainting_temp_split(video_input, masked_video, mask, prompt, image, vae, text_encoder, image_encoder, diffusion, model, device, scfg_scale, tcfg_scale, img_cfg_scale, negative_prompt=""):
    global use_fp16
    image_prompt_embeds = None
    if prompt is None:
        prompt = ""
    if image is not None:
        clip_image = clip_image_processor(images=image, return_tensors="pt").pixel_values
        clip_image_embeds = image_encoder(clip_image.to(device)).image_embeds
        uncond_clip_image_embeds = torch.zeros_like(clip_image_embeds).to(device)
        image_prompt_embeds = torch.cat([clip_image_embeds, clip_image_embeds, uncond_clip_image_embeds], dim=0)
        image_prompt_embeds = rearrange(image_prompt_embeds, '(b n) c -> b n c', b=3).contiguous()
        model = ip_scale_set(model, img_cfg_scale)
        if use_fp16:
            image_prompt_embeds = image_prompt_embeds.to(dtype=torch.float16)
    b, f, c, h, w = video_input.shape
    latent_h = video_input.shape[-2] // 8
    latent_w = video_input.shape[-1] // 8
    if use_fp16:
        z = torch.randn(1, 4, 16, latent_h, latent_w, dtype=torch.float16, device=device)  # b, c, f, h, w
        masked_video = masked_video.to(dtype=torch.float16)
        mask = mask.to(dtype=torch.float16)
    else:
        z = torch.randn(1, 4, 16, latent_h, latent_w, device=device)  # b, c, f, h, w

    masked_video = rearrange(masked_video, 'b f c h w -> (b f) c h w').contiguous()
    masked_video = vae.encode(masked_video).latent_dist.sample().mul_(0.18215)
    masked_video = rearrange(masked_video, '(b f) c h w -> b c f h w', b=b).contiguous()
    mask = torch.nn.functional.interpolate(mask[:, :, 0, :], size=(latent_h, latent_w)).unsqueeze(1)

    # Triplicate inputs: spatially conditioned, temporally conditioned, unconditional.
    masked_video = torch.cat([masked_video] * 3)
    mask = torch.cat([mask] * 3)
    z = torch.cat([z] * 3)
    prompt_all = [prompt] + [prompt] + [negative_prompt]
    prompt_temp = [prompt] + [""] + [""]
    text_prompt = text_encoder(text_prompts=prompt_all, train=False)
    temporal_text_prompt = text_encoder(text_prompts=prompt_temp, train=False)
    model_kwargs = dict(encoder_hidden_states=text_prompt,
                        class_labels=None,
                        scfg_scale=scfg_scale,
                        tcfg_scale=tcfg_scale,
                        use_fp16=use_fp16,
                        ip_hidden_states=image_prompt_embeds,
                        encoder_temporal_hidden_states=temporal_text_prompt)
    # Sample:
    samples = diffusion.ddim_sample_loop(
        model.forward_with_cfg_temp_split, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs,
        progress=True, device=device, mask=mask, x_start=masked_video, use_concat=True
    )
    samples, _ = samples.chunk(2, dim=0)  # keep the leading (conditional) part of the batch
    if use_fp16:
        samples = samples.to(dtype=torch.float16)
    video_clip = samples[0].permute(1, 0, 2, 3).contiguous()  # [16, 4, latent_h, latent_w]
    video_clip = vae.decode(video_clip / 0.18215).sample      # [16, 3, H, W]
    return video_clip
# ========================================
# Model Initialization
# ========================================
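# Module-level globals shared by the generation functions below; they are
# populated once by init_model() when the demo starts.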
device = None
output_path = None
use_fp16 = False
model = None
vae = None
text_encoder = None
image_encoder = None
clip_image_processor = None
def init_model():
    global device
    global output_path
    global use_fp16
    global model
    global diffusion
    global vae
    global text_encoder
    global image_encoder
    global clip_image_processor
    print('Initializing ShowMaker', flush=True)
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./configs/with_mask_ref_sample.yaml")
    args = parser.parse_args()
    args = OmegaConf.load(args.config)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    output_path = args.save_img_path

    # Load model:
    latent_h = args.image_size[0] // 8
    latent_w = args.image_size[1] // 8
    args.image_h = args.image_size[0]
    args.image_w = args.image_size[1]
    args.latent_h = latent_h
    args.latent_w = latent_w
    print('loading model')
    model = get_models(True, args).to(device)
    model = tca_transform_model(model).to(device)
    model = ip_transform_model(model).to(device)
    if args.use_compile:
        model = torch.compile(model)
    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            model.enable_xformers_memory_efficient_attention()
            print("Using xformers memory-efficient attention")
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")
    ckpt_path = args.ckpt
    state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)['ema']
    model.load_state_dict(state_dict)
    print('loading succeeded')
    model.eval()  # important!

    pretrained_model_path = args.pretrained_model_path
    vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
    text_encoder = TextEmbedder(tokenizer_path=os.path.join(pretrained_model_path, "tokenizer"),
                                encoder_path=os.path.join(pretrained_model_path, "text_encoder")).to(device)
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path).to(device)
    clip_image_processor = CLIPImageProcessor()
    if args.use_fp16:
        print('Warning: using half precision for inference!')
        vae.to(dtype=torch.float16)
        model.to(dtype=torch.float16)
        text_encoder.to(dtype=torch.float16)
        image_encoder.to(dtype=torch.float16)
        use_fp16 = True
    print('Initialization Finished')


init_model()
# ========================================
# Video Generation
# ========================================
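# Text(-and-reference-image)-to-video: start from an all-zero 16-frame clip,
# mask every frame ("all"), and let the diffusion model inpaint the whole video
# from the text prompt and the optional reference image.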
def video_generation(text, image, scfg_scale, tcfg_scale, img_cfg_scale, diffusion):
    with torch.no_grad():
        print("begin generation", flush=True)
        transform_video = transforms.Compose([
            video_transforms.ToTensorVideo(),  # TCHW
            video_transforms.ResizeVideo((320, 512)),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
        ])
        video_frames = torch.zeros(16, 3, 320, 512, dtype=torch.uint8)
        video_frames = transform_video(video_frames)
        video_input = video_frames.to(device).unsqueeze(0)  # b, f, c, h, w
        mask = mask_generation_before("all", video_input.shape, video_input.dtype, device)
        masked_video = video_input * (mask == 0)
        if image is not None:
            print(image.shape, flush=True)
            # image = Image.open(image)
        if scfg_scale == tcfg_scale:
            video_clip = auto_inpainting(video_input, masked_video, mask, text, image, vae, text_encoder, image_encoder, diffusion, model, device, scfg_scale, img_cfg_scale)
        else:
            video_clip = auto_inpainting_temp_split(video_input, masked_video, mask, text, image, vae, text_encoder, image_encoder, diffusion, model, device, scfg_scale, tcfg_scale, img_cfg_scale)
        # Map [-1, 1] back to uint8 frames in THWC order for writing.
        video_clip = ((video_clip * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
        os.makedirs(output_path, exist_ok=True)  # make sure the save directory exists
        video_path = os.path.join(output_path, 'video.mp4')
        torchvision.io.write_video(video_path, video_clip, fps=8)
        return video_path
# ========================================
# Video Prediction
# ========================================
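# First-frame-conditioned prediction: the uploaded frame is placed at t=0, the
# remaining 15 frames are zeros, and the "first1" mask tells the model to keep
# frame 0 and synthesize the rest. The frame is rescaled to fit 320x512 first.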
def video_prediction(text, image, scfg_scale, tcfg_scale, img_cfg_scale, preframe, diffusion):
    with torch.no_grad():
        print("begin generation", flush=True)
        transform_video = transforms.Compose([
            video_transforms.ToTensorVideo(),  # TCHW
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
        ])
        preframe = torch.as_tensor(convert_to_rgb(preframe)).unsqueeze(0)
        zeros = torch.zeros_like(preframe)
        video_frames = torch.cat([preframe] + [zeros] * 15, dim=0).permute(0, 3, 1, 2)
        # Uniformly rescale so the clip fits inside the 320x512 training resolution.
        H_scale = 320 / video_frames.shape[2]
        W_scale = 512 / video_frames.shape[3]
        scale_ = H_scale
        if W_scale < H_scale:
            scale_ = W_scale
        video_frames = torch.nn.functional.interpolate(video_frames, scale_factor=scale_, mode="bilinear", align_corners=False)
        video_frames = transform_video(video_frames)
        video_input = video_frames.to(device).unsqueeze(0)  # b, f, c, h, w
        mask = mask_generation_before("first1", video_input.shape, video_input.dtype, device)
        masked_video = video_input * (mask == 0)
        if image is not None:
            print(image.shape, flush=True)
        if scfg_scale == tcfg_scale:
            video_clip = auto_inpainting(video_input, masked_video, mask, text, image, vae, text_encoder, image_encoder, diffusion, model, device, scfg_scale, img_cfg_scale)
        else:
            video_clip = auto_inpainting_temp_split(video_input, masked_video, mask, text, image, vae, text_encoder, image_encoder, diffusion, model, device, scfg_scale, tcfg_scale, img_cfg_scale)
        video_clip = ((video_clip * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
        os.makedirs(output_path, exist_ok=True)
        video_path = os.path.join(output_path, 'video.mp4')
        torchvision.io.write_video(video_path, video_clip, fps=8)
        return video_path
# ========================================
# Choose Generation or Prediction
# ========================================
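# Dispatch helper used by the UI: snap the requested number of diffusion steps
# to the nearest supported value, build the sampler, and route to prediction if
# a first frame was uploaded, otherwise to plain generation.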
def gen_or_pre(text_input, image_input, scfg_scale, tcfg_scale, img_cfg_scale, preframe_input, diffusion_step):
    default_step = [25, 40, 50, 100, 125, 200, 250]
    difference = [abs(item - diffusion_step) for item in default_step]
    diffusion_step = default_step[difference.index(min(difference))]
    diffusion = create_diffusion(str(diffusion_step))
    if preframe_input is None:
        return video_generation(text_input, image_input, scfg_scale, tcfg_scale, img_cfg_scale, diffusion)
    else:
        return video_prediction(text_input, image_input, scfg_scale, tcfg_scale, img_cfg_scale, preframe_input, diffusion)
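# A minimal usage sketch outside the UI (hypothetical prompt and values;
# assumes init_model() has already run, as it does at import time above):
#
#     clip_path = gen_or_pre(
#         text_input="a corgi running on the beach",
#         image_input=None,        # optional reference image as an RGB array
#         scfg_scale=8.0,
#         tcfg_scale=6.5,          # differs from scfg_scale -> temp-split sampling
#         img_cfg_scale=0.3,
#         preframe_input=None,     # pass a first frame here to get prediction
#         diffusion_step=100,
#     )
#     print(clip_path)             # e.g. <save_img_path>/video.mp4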
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(visible=True) as input_raws:
            with gr.Row():
                with gr.Column(scale=1.0):
                    text_input = gr.Textbox(show_label=True, interactive=True, label="Text prompt").style(container=False)
            with gr.Row():
                with gr.Column(scale=0.5):
                    image_input = gr.Image(show_label=True, interactive=True, label="Reference image").style(container=False)
                with gr.Column(scale=0.5):
                    preframe_input = gr.Image(show_label=True, interactive=True, label="First frame").style(container=False)
            with gr.Row():
                with gr.Column(scale=1.0):
                    scfg_scale = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=8,
                        step=0.1,
                        interactive=True,
                        label="Spatial Text Guidance Scale",
                    )
            with gr.Row():
                with gr.Column(scale=1.0):
                    tcfg_scale = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=6.5,
                        step=0.1,
                        interactive=True,
                        label="Temporal Text Guidance Scale",
                    )
            with gr.Row():
                with gr.Column(scale=1.0):
                    img_cfg_scale = gr.Slider(
                        minimum=0,
                        maximum=1,
                        value=0.3,
                        step=0.005,
                        interactive=True,
                        label="Image Guidance Scale",
                    )
            with gr.Row():
                with gr.Column(scale=1.0):
                    diffusion_step = gr.Slider(
                        minimum=20,
                        maximum=250,
                        value=100,
                        step=1,
                        interactive=True,
                        label="Diffusion Step",
                    )
            with gr.Row():
                with gr.Column(scale=0.5, min_width=0):
                    run = gr.Button("💭Send")
                with gr.Column(scale=0.5, min_width=0):
                    clear = gr.Button("🔄Clear")
        with gr.Column(scale=0.5, visible=True) as video_upload:
            output_video = gr.Video(interactive=False, include_audio=True, elem_id="输出的视频")  # .style(height=360)
            # with gr.Column(elem_id="image", scale=0.5) as img_part:
            #     with gr.Tab("Video", elem_id='video_tab'):
            #     with gr.Tab("Image", elem_id='image_tab'):
            #         up_image = gr.Image(type="pil", interactive=True, elem_id="image_upload").style(height=360)
            # upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
            clear = gr.Button("Restart")

    run.click(gen_or_pre, [text_input, image_input, scfg_scale, tcfg_scale, img_cfg_scale, preframe_input, diffusion_step], [output_video])

demo.launch(share=True, enable_queue=True)
# demo.launch(server_name="0.0.0.0", server_port=10034, enable_queue=True)