import os
from typing import List, Optional, Union

import torch
import torchvision
# NOTE(review): this private torch helper is not referenced anywhere in this module;
# kept only in case other code relies on the import side effect — consider removing.
from torch import _validate_compressed_sparse_indices
from torchvision.utils import save_image

from videogen_hub import MODEL_PATH
from with_mask_sample import *

# File a tensor ``input_image`` is written to so that ``get_input`` can read it from disk.
_TENSOR_INPUT_IMAGE_PATH = "src/videogen_hub/pipelines/seine/input_image.png"


class SEINEPipeline:
    """Image-to-video generation pipeline wrapping the SEINE model."""

    def __init__(self,
                 seine_path: str = os.path.join(MODEL_PATH, "SEINE", "seine.pt"),
                 pretrained_model_path: str = os.path.join(MODEL_PATH, "SEINE", "stable-diffusion-v1-4"),
                 config_path: str = "src/videogen_hub/pipelines/seine/sample_i2v.yaml"):
        """
        Load the configuration file and set the paths of models.

        Args:
            seine_path: The path of the downloaded seine pretrained model.
            pretrained_model_path: The path of the downloaded stable diffusion pretrained model.
            config_path: The path of the configuration file.
        """
        self.config = OmegaConf.load(config_path)
        self.config.ckpt = seine_path
        self.config.pretrained_model_path = pretrained_model_path

    def infer_one_video(self,
                        input_image,
                        text_prompt: Optional[Union[str, List[str]]] = None,
                        output_size: Optional[List[int]] = None,
                        num_frames: int = 16,
                        num_sampling_steps: int = 250,
                        seed: Optional[int] = 42,
                        save_video: bool = False):
        """
        Generate video based on provided input_image and text_prompt.

        Note: this method overwrites fields of ``self.config`` on every call and
        reloads all model weights from disk each time it is called.

        Args:
            input_image: Either a path (str / os.PathLike) to the input image, or a
                3-D tensor whose first dimension is 3 (channels first). A tensor is
                saved to disk first and that file is used as the input.
            text_prompt: The text prompt to generate video. A string is used as-is;
                for a list only the first element is used. When empty/None, the prompt
                is derived from the input file name (underscores become spaces).
            output_size: [height, width] of the generated video. Defaults to [240, 560].
                The latent resolution is ``size // 8`` in each dimension.
            num_frames: number of frames of the generated video. Defaults to 16.
            num_sampling_steps: number of sampling steps to generate the video. Defaults to 250.
            seed: The random seed for video generation. Defaults to 42. ``None`` skips seeding.
            save_video: save the video as ``<save_path>/<prompt>.mp4`` (``save_path`` from
                the config) if True. Not saved if False. Defaults to False.

        Returns:
            The generated video as a uint8 tensor with shape (num_frames, channels, height, width).

        Raises:
            TypeError: if ``input_image`` is neither a path nor a tensor.
            ValueError: if a tensor ``input_image`` is not shaped (3, H, W), if no
                checkpoint path is configured, or if xformers is requested but unavailable.
        """
        # None sentinels avoid shared mutable defaults; effective defaults are unchanged.
        if output_size is None:
            output_size = [240, 560]
        if text_prompt is None:
            text_prompt = []

        self.config.image_size = list(output_size)
        self.config.num_frames = num_frames
        self.config.num_sampling_steps = num_sampling_steps
        self.config.seed = seed
        self.config.text_prompt = text_prompt
        # Always point input_path at the image for THIS call; previously a tensor input
        # was saved but the stale input_path from the YAML/previous call was used instead.
        self.config.input_path = self._resolve_input_path(input_image)

        args = self.config

        if args.ckpt is None:
            raise ValueError("Please specify a checkpoint path via `seine_path`")

        # `is not None` so that seed=0 is honoured (a truthiness test would skip it).
        if args.seed is not None:
            torch.manual_seed(args.seed)

        device = "cuda" if torch.cuda.is_available() else "cpu"

        args.image_h = args.image_size[0]
        args.image_w = args.image_size[1]
        args.latent_h = args.image_h // 8
        args.latent_w = args.image_w // 8

        # Scoped instead of torch.set_grad_enabled(False), which would leave autograd
        # disabled process-wide after this call returns.
        with torch.no_grad():
            model, vae, text_encoder, diffusion = self._load_models(args, device)
            prompt, prompt_base = self._resolve_prompt(args)

            if save_video:
                os.makedirs(args.save_path, exist_ok=True)

            video_input, _ = get_input(args)  # f,c,h,w
            video_input = video_input.to(device).unsqueeze(0)  # b,f,c,h,w
            mask = mask_generation_before(args.mask_type, video_input.shape, video_input.dtype, device)  # b,f,c,h,w
            # Zero out the positions the mask marks for generation (mask != 0).
            masked_video = video_input * (mask == 0)

            video_clip = auto_inpainting(args, video_input, masked_video, mask, prompt, vae, text_encoder,
                                         diffusion, model, device)

        # Map to uint8 pixels; assumes video_clip values lie in [-1, 1] (TODO confirm).
        # permute(0, 2, 3, 1) gives the channels-last layout write_video expects.
        video_ = ((video_clip * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(
            0, 2, 3, 1)

        if save_video:
            save_video_path = os.path.join(args.save_path, prompt_base + '.mp4')
            torchvision.io.write_video(save_video_path, video_, fps=8)
            print(f'save in {save_video_path}')

        # Back to channels-first for the caller.
        return video_.permute(0, 3, 1, 2)

    @staticmethod
    def _resolve_input_path(input_image) -> str:
        """Return an image file path for ``get_input``, saving tensor inputs to disk first."""
        if isinstance(input_image, (str, os.PathLike)):
            return os.fspath(input_image)
        if not torch.is_tensor(input_image):
            raise TypeError(
                f"input_image must be a path or a torch.Tensor, got {type(input_image).__name__}")
        if input_image.dim() != 3 or input_image.shape[0] != 3:
            raise ValueError(
                f"input_image tensor must have shape (3, H, W), got {tuple(input_image.shape)}")
        os.makedirs(os.path.dirname(_TENSOR_INPUT_IMAGE_PATH), exist_ok=True)
        save_image(input_image, _TENSOR_INPUT_IMAGE_PATH)
        return _TENSOR_INPUT_IMAGE_PATH

    @staticmethod
    def _load_models(args, device):
        """Build the SEINE model, VAE, text encoder and diffusion sampler described by ``args``."""
        print('loading model')
        model = get_models(args).to(device)

        if args.enable_xformers_memory_efficient_attention:
            if not is_xformers_available():
                raise ValueError("xformers is not available. Make sure it is installed correctly")
            model.enable_xformers_memory_efficient_attention()

        # map_location keeps tensors where they were saved-to (CPU storage) before copying
        # into the already-placed model; only the 'ema' weights are used.
        state_dict = torch.load(args.ckpt, map_location=lambda storage, loc: storage)['ema']
        model.load_state_dict(state_dict)
        print('loading succeed')
        model.eval()

        diffusion = create_diffusion(str(args.num_sampling_steps))
        vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").to(device)
        text_encoder = TextEmbedder(args.pretrained_model_path).to(device)

        if args.use_fp16:
            print('Warnning: using half percision for inferencing!')
            vae.to(dtype=torch.float16)
            model.to(dtype=torch.float16)
            text_encoder.to(dtype=torch.float16)

        return model, vae, text_encoder, diffusion

    @staticmethod
    def _resolve_prompt(args):
        """Return ``(prompt + additional_prompt, prompt with spaces replaced by '_')``."""
        raw = args.text_prompt
        if isinstance(raw, str):
            text = raw
        elif raw:
            # Only the first prompt of a list is used.
            text = raw[0]
        else:
            text = ""

        if not text:
            # Fall back to the input file's stem, e.g. "a_red_car.png" -> "a red car".
            stem = os.path.splitext(os.path.basename(args.input_path))[0]
            text = stem.replace('_', ' ')

        prompt_base = text.replace(' ', '_')
        return text + args.additional_prompt, prompt_base