import os
from typing import List

import torch
import torchvision
from omegaconf import OmegaConf
from torchvision.utils import save_image

from videogen_hub import MODEL_PATH
from with_mask_sample import *


class SEINEPipeline:
    def __init__(self, seine_path: str = os.path.join(MODEL_PATH, "SEINE", "seine.pt"),
                 pretrained_model_path: str = os.path.join(MODEL_PATH, "SEINE", "stable-diffusion-v1-4"),
                 config_path: str = "src/videogen_hub/pipelines/seine/sample_i2v.yaml"):
        """
        Load the configuration file and set the model paths.
        Args:
            seine_path: Path to the downloaded SEINE pretrained checkpoint.
            pretrained_model_path: Path to the downloaded Stable Diffusion pretrained model.
            config_path: Path to the configuration file.
        """
        self.config = OmegaConf.load(config_path)
        self.config.ckpt = seine_path
        self.config.pretrained_model_path = pretrained_model_path

    def infer_one_video(self, input_image,
                        text_prompt: List[str] = [],
                        output_size: List[int] = [240, 560],
                        num_frames: int = 16,
                        num_sampling_steps: int = 250,
                        seed: int = 42,
                        save_video: bool = False):
        """
        Generate a video conditioned on the provided input_image and text_prompt.
        Args:
            input_image: The input image, given either as a file path or as a (3, H, W) tensor.
            text_prompt: The text prompt guiding the video generation.
            output_size: The [height, width] of the generated video. Defaults to [240, 560].
            num_frames: Number of frames in the generated video. Defaults to 16.
            num_sampling_steps: Number of diffusion sampling steps. Defaults to 250.
            seed: The random seed for video generation. Defaults to 42.
            save_video: If True, save the video to the path given in the config; otherwise do not save it. Defaults to False.

        Returns:
            The generated video as a tensor with shape (num_frames, channels, height, width).

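        Example (a minimal sketch; the image path and prompt are illustrative):
            >>> pipeline = SEINEPipeline()
            >>> video = pipeline.infer_one_video(
            ...     "path/to/input_image.png",
            ...     text_prompt=["a sailboat drifting at sunset"],
            ... )
            >>> video.shape  # torch.Size([16, 3, 240, 560])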
        """

        self.config.image_size = output_size
        self.config.num_frames = num_frames
        self.config.num_sampling_steps = num_sampling_steps
        self.config.seed = seed
        self.config.text_prompt = text_prompt
        if isinstance(input_image, str):
            self.config.input_path = input_image
        else:
            assert torch.is_tensor(input_image)
            assert input_image.ndim == 3
            assert input_image.shape[0] == 3
            # Save the tensor to disk and point the config at it so get_input() can read it.
            input_path = "src/videogen_hub/pipelines/seine/input_image.png"
            save_image(input_image, input_path)
            self.config.input_path = input_path

        args = self.config

        # Setup PyTorch:
        if args.seed:
            torch.manual_seed(args.seed)
        torch.set_grad_enabled(False)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # device = "cpu"

        if args.ckpt is None:
            raise ValueError("Please specify a checkpoint path using --ckpt <path>")

        # Load model:
        latent_h = args.image_size[0] // 8
        latent_w = args.image_size[1] // 8
        args.image_h = args.image_size[0]
        args.image_w = args.image_size[1]
        args.latent_h = latent_h
        args.latent_w = latent_w
        print('loading model')
        model = get_models(args).to(device)

        if args.enable_xformers_memory_efficient_attention:
            if is_xformers_available():
                model.enable_xformers_memory_efficient_attention()
            else:
                raise ValueError("xformers is not available. Make sure it is installed correctly")

        # Load the EMA weights from the SEINE checkpoint.
        ckpt_path = args.ckpt
        state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)['ema']
        model.load_state_dict(state_dict)
        print('loading succeeded')

        model.eval()
        pretrained_model_path = args.pretrained_model_path
        diffusion = create_diffusion(str(args.num_sampling_steps))
        vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
        text_encoder = TextEmbedder(pretrained_model_path).to(device)
        if args.use_fp16:
            print('Warning: using half precision for inference!')
            vae.to(dtype=torch.float16)
            model.to(dtype=torch.float16)
            text_encoder.to(dtype=torch.float16)

        # prompt:
        prompt = args.text_prompt
        if prompt is None or prompt == []:
            prompt = args.input_path.split('/')[-1].split('.')[0].replace('_', ' ')
        else:
            prompt = prompt[0]
        prompt_base = prompt.replace(' ', '_')
        prompt = prompt + args.additional_prompt

        if save_video:
            os.makedirs(args.save_path, exist_ok=True)

        video_input, reserve_frames = get_input(args)  # f, c, h, w
        video_input = video_input.to(device).unsqueeze(0)  # b, f, c, h, w

        mask = mask_generation_before(args.mask_type, video_input.shape, video_input.dtype, device)  # b, f, c, h, w
        masked_video = video_input * (mask == 0)  # zero out the regions the model should inpaint

        video_clip = auto_inpainting(args, video_input, masked_video, mask, prompt,
                                     vae, text_encoder, diffusion, model, device)
        # Map from [-1, 1] to [0, 255] uint8 and reorder to (f, h, w, c) for video writing.
        video_ = ((video_clip * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)

        if save_video:
            save_video_path = os.path.join(args.save_path, prompt_base + '.mp4')
            torchvision.io.write_video(save_video_path, video_, fps=8)
            print(f'saved to {save_video_path}')

        return video_.permute(0, 3, 1, 2)  # (num_frames, channels, height, width)
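

if __name__ == "__main__":
    # Minimal usage sketch, assuming the default SEINE and Stable Diffusion v1-4
    # checkpoints are already downloaded under MODEL_PATH; the image path and
    # prompt below are placeholders, not files shipped with the repository.
    pipeline = SEINEPipeline()
    video = pipeline.infer_one_video(
        input_image="path/to/input_image.png",
        text_prompt=["a boat sailing on a calm lake"],
        num_frames=16,
        num_sampling_steps=250,
        seed=42,
        save_video=True,
    )
    print(video.shape)  # expected: torch.Size([16, 3, 240, 560])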