import os
from typing import List, Optional

import torch
import torchvision
from omegaconf import OmegaConf
from torchvision.utils import save_image

from videogen_hub import MODEL_PATH
from with_mask_sample import *  # expected to provide get_models, get_input, TextEmbedder, mask_generation_before, auto_inpainting, AutoencoderKL, create_diffusion, is_xformers_available


class SEINEPipeline:
    def __init__(self, seine_path: str = os.path.join(MODEL_PATH, "SEINE", "seine.pt"),
                 pretrained_model_path: str = os.path.join(MODEL_PATH, "SEINE", "stable-diffusion-v1-4"),
                 config_path: str = "src/videogen_hub/pipelines/seine/sample_i2v.yaml"):
"""
Load the configuration file and set the paths of models.
Args:
seine_path: The path of the downloaded seine pretrained model.
pretrained_model_path: The path of the downloaded stable diffusion pretrained model.
config_path: The path of the configuration file.
"""
        self.config = OmegaConf.load(config_path)
        self.config.ckpt = seine_path
        self.config.pretrained_model_path = pretrained_model_path

    def infer_one_video(self, input_image,
                        text_prompt: Optional[List] = None,
                        output_size: List = [240, 560],
                        num_frames: int = 16,
                        num_sampling_steps: int = 250,
                        seed: int = 42,
                        save_video: bool = False):
"""
Generate video based on provided input_image and text_prompt.
Args:
input_image: The input image to generate video.
text_prompt: The text prompt to generate video.
output_size: The size of the generated video. Defaults to [240, 560].
num_frames: number of frames of the generated video. Defaults to 16.
num_sampling_steps: number of sampling steps to generate the video. Defaults to 250.
seed: The random seed for video generation. Defaults to 42.
save_video: save the video to the path in config if it is True. Not save if it is False. Defaults to False.
Returns:
The generated video as tensor with shape (num_frames, channels, height, width).
"""
        self.config.image_size = output_size
        self.config.num_frames = num_frames
        self.config.num_sampling_steps = num_sampling_steps
        self.config.seed = seed
        self.config.text_prompt = text_prompt

        if isinstance(input_image, str):
            self.config.input_path = input_image
        else:
            # Expect a single RGB image tensor of shape (3, H, W).
            assert torch.is_tensor(input_image)
            assert len(input_image.shape) == 3
            assert input_image.shape[0] == 3
            save_image(input_image, "src/videogen_hub/pipelines/seine/input_image.png")
        args = self.config

        # Set up PyTorch:
        if args.seed is not None:
            torch.manual_seed(args.seed)
        torch.set_grad_enabled(False)
        device = "cuda" if torch.cuda.is_available() else "cpu"

        if args.ckpt is None:
            raise ValueError("Please specify a checkpoint path using --ckpt <path>")
        # Derive latent dimensions: the Stable Diffusion VAE downsamples
        # the spatial size by a factor of 8.
        latent_h = args.image_size[0] // 8
        latent_w = args.image_size[1] // 8
        args.image_h = args.image_size[0]
        args.image_w = args.image_size[1]
        args.latent_h = latent_h
        args.latent_w = latent_w

        # Load model:
        print('loading model')
        model = get_models(args).to(device)
        if args.enable_xformers_memory_efficient_attention:
            if is_xformers_available():
                model.enable_xformers_memory_efficient_attention()
            else:
                raise ValueError("xformers is not available. Make sure it is installed correctly")

        # Load the EMA weights from the SEINE checkpoint.
        ckpt_path = args.ckpt
        state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)['ema']
        model.load_state_dict(state_dict)
        print('loading succeeded')
        model.eval()
        pretrained_model_path = args.pretrained_model_path
        diffusion = create_diffusion(str(args.num_sampling_steps))
        vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
        text_encoder = TextEmbedder(pretrained_model_path).to(device)
        if args.use_fp16:
            print('Warning: using half precision for inference!')
            vae.to(dtype=torch.float16)
            model.to(dtype=torch.float16)
            text_encoder.to(dtype=torch.float16)
        # Prompt: fall back to the input file name when no text prompt is given.
        prompt = args.text_prompt
        if prompt is None or prompt == []:
            prompt = args.input_path.split('/')[-1].split('.')[0].replace('_', ' ')
        else:
            prompt = prompt[0]
        prompt_base = prompt.replace(' ', '_')
        prompt = prompt + args.additional_prompt

        if save_video and not os.path.exists(args.save_path):
            os.makedirs(args.save_path)
        # Build the conditioning clip and inpainting mask, then run diffusion inpainting.
        video_input, reserve_frames = get_input(args)  # (f, c, h, w)
        video_input = video_input.to(device).unsqueeze(0)  # (b, f, c, h, w)
        mask = mask_generation_before(args.mask_type, video_input.shape, video_input.dtype, device)  # (b, f, c, h, w)
        masked_video = video_input * (mask == 0)  # zero out the frames to be generated
        video_clip = auto_inpainting(args, video_input, masked_video, mask, prompt, vae, text_encoder,
                                     diffusion, model, device)

        # Map frames from [-1, 1] to uint8 [0, 255] and reorder to (f, h, w, c) for write_video.
        video_ = ((video_clip * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
        if save_video:
            save_video_path = os.path.join(args.save_path, prompt_base + '.mp4')
            torchvision.io.write_video(save_video_path, video_, fps=8)
            print(f'saved in {save_video_path}')

        return video_.permute(0, 3, 1, 2)  # back to (f, c, h, w)
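

# --- Usage sketch (illustrative, not part of the original pipeline) ---
# A minimal example of driving the pipeline end to end. It assumes the SEINE
# checkpoint and Stable Diffusion v1-4 weights are already downloaded under
# MODEL_PATH, the sample_i2v.yaml config exists at its default path, and
# "input_image.png" is a hypothetical conditioning frame on disk.
if __name__ == "__main__":
    pipeline = SEINEPipeline()
    video = pipeline.infer_one_video(
        input_image="input_image.png",            # str path or a (3, H, W) tensor
        text_prompt=["a sailboat drifting at sunset"],
        output_size=[240, 560],                   # (height, width)
        num_frames=16,
        num_sampling_steps=250,
        seed=42,
        save_video=False,
    )
    print(video.shape)  # expected: torch.Size([16, 3, 240, 560])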