import yaml
import tempfile
import gradio as gr
import os
import shutil
import torch

# SPACE_ID is set automatically on Hugging Face Spaces; default to "" so local runs also work.
is_shared_ui = "fffiloni/Light-A-Video" in os.environ.get('SPACE_ID', '')
is_gpu_associated = torch.cuda.is_available()

import imageio
import argparse
from types import MethodType

import safetensors.torch as sf
import torch.nn.functional as F
from omegaconf import OmegaConf
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import (
    MotionAdapter,
    EulerAncestralDiscreteScheduler,
    AutoencoderKL,
    UNet2DConditionModel,
    DPMSolverMultistepScheduler,
)
from diffusers.models.attention_processor import AttnProcessor2_0
from torch.hub import download_url_to_file

from src.ic_light import BGSource
from src.animatediff_pipe import AnimateDiffVideoToVideoPipeline
from src.ic_light_pipe import StableDiffusionImg2ImgPipeline
from utils.tools import read_video, set_all_seed

from huggingface_hub import snapshot_download, hf_hub_download

# Download the checkpoints only on a duplicated Space with a GPU attached.
if not is_shared_ui and is_gpu_associated:
    hf_hub_download(
        repo_id='lllyasviel/ic-light',
        filename='iclight_sd15_fc.safetensors',
        local_dir='./models'
    )
    snapshot_download(
        repo_id="stablediffusionapi/realistic-vision-v51",
        local_dir="./models/stablediffusionapi/realistic-vision-v51"
    )
    snapshot_download(
        repo_id="guoyww/animatediff-motion-adapter-v1-5-3",
        local_dir="./models/guoyww/animatediff-motion-adapter-v1-5-3"
    )


def main(args):
    config = OmegaConf.load(args.config)
    device = torch.device('cuda')
    adopted_dtype = torch.float16
    set_all_seed(42)

    ## vdm model
    adapter = MotionAdapter.from_pretrained(args.motion_adapter_model)

    ## pipeline
    pipe = AnimateDiffVideoToVideoPipeline.from_pretrained(args.sd_model, motion_adapter=adapter)
    eul_scheduler = EulerAncestralDiscreteScheduler.from_pretrained(
        args.sd_model,
        subfolder="scheduler",
        beta_schedule="linear",
    )
    pipe.scheduler = eul_scheduler
    pipe.enable_vae_slicing()
    pipe = pipe.to(device=device, dtype=adopted_dtype)
    pipe.vae.requires_grad_(False)
    pipe.unet.requires_grad_(False)

    ## ic-light model
    tokenizer = CLIPTokenizer.from_pretrained(args.sd_model, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(args.sd_model, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(args.sd_model, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(args.sd_model, subfolder="unet")

    # Widen conv_in from 4 to 8 input channels so a conditioning latent can be concatenated.
    with torch.no_grad():
        new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size,
                                      unet.conv_in.stride, unet.conv_in.padding)
        new_conv_in.weight.zero_()  # torch.Size([320, 8, 3, 3])
        new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
        new_conv_in.bias = unet.conv_in.bias
        unet.conv_in = new_conv_in

    unet_original_forward = unet.forward

    def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
        # Concatenate the conditioning latents onto the noisy sample along the channel axis.
        c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample)
        c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0)
        new_sample = torch.cat([sample, c_concat], dim=1)
        kwargs['cross_attention_kwargs'] = {}
        return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)

    unet.forward = hooked_unet_forward

    ## ic-light model loader
    if not os.path.exists(args.ic_light_model):
        download_url_to_file(
            url='https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fc.safetensors',
            dst=args.ic_light_model
        )
    sd_offset = sf.load_file(args.ic_light_model)
    sd_origin = unet.state_dict()
    sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()}
    unet.load_state_dict(sd_merged, strict=True)
    del sd_offset, sd_origin, sd_merged
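    # Note: the IC-Light release stores its weights as an additive offset on top of the base
    # SD1.5 UNet (hence sd_origin + sd_offset above), and the extra 4 conv_in channels are
    # zero-initialized so the base model behaves unchanged until that offset is applied.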
    text_encoder = text_encoder.to(device=device, dtype=adopted_dtype)
    vae = vae.to(device=device, dtype=adopted_dtype)
    unet = unet.to(device=device, dtype=adopted_dtype)
    unet.set_attn_processor(AttnProcessor2_0())
    vae.set_attn_processor(AttnProcessor2_0())

    # Consistent light attention
    @torch.inference_mode()
    def custom_forward_CLA(self,
                           hidden_states,
                           gamma=config.get("gamma", 0.5),
                           encoder_hidden_states=None,
                           attention_mask=None,
                           cross_attention_kwargs=None):
        residual = hidden_states

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        batch_size, sequence_length, channel = hidden_states.shape

        if self.group_norm is not None:
            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        query = self.to_q(hidden_states)
        key = self.to_k(encoder_hidden_states)
        value = self.to_v(encoder_hidden_states)

        # The mask is prepared for completeness but is not applied in the attention calls below.
        if attention_mask is not None:
            if attention_mask.shape[-1] != query.shape[1]:
                target_length = query.shape[1]
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
                attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // self.heads

        query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)

        # Standard per-frame self-attention.
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
        )

        # Additional key and value: average across frames within each classifier-free-guidance half.
        shape = query.shape
        mean_key = key.reshape(2, -1, shape[1], shape[2], shape[3]).mean(dim=1, keepdim=True)
        mean_value = value.reshape(2, -1, shape[1], shape[2], shape[3]).mean(dim=1, keepdim=True)
        mean_key = mean_key.expand(-1, shape[0] // 2, -1, -1, -1).reshape(shape[0], shape[1], shape[2], shape[3])
        mean_value = mean_value.expand(-1, shape[0] // 2, -1, -1, -1).reshape(shape[0], shape[1], shape[2], shape[3])
        add_hidden_state = F.scaled_dot_product_attention(
            query, mean_key, mean_value, attn_mask=None, dropout_p=0.0, is_causal=False
        )

        # Mix the per-frame and frame-averaged attention outputs.
        hidden_states = (1 - gamma) * hidden_states + gamma * add_hidden_state

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        hidden_states = self.to_out[0](hidden_states)
        hidden_states = self.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if self.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / self.rescale_output_factor
        return hidden_states

    ### attention
    @torch.inference_mode()
    def prep_unet_self_attention(unet):
        # Patch the first self-attention layer (attn1) of up_blocks 0-2 with the CLA forward.
        for name, module in unet.named_modules():
            module_name = type(module).__name__
            name_split_list = name.split(".")
            cond_1 = name_split_list[0] == "up_blocks"
            cond_2 = name_split_list[-1] == "attn1"
            if "Attention" in module_name and cond_1 and cond_2:
                cond_3 = name_split_list[1]
                if cond_3 != "3":
                    module.forward = MethodType(custom_forward_CLA, module)
        return unet

    ## consistency light attention
    unet = prep_unet_self_attention(unet)

    ## ic-light-scheduler
    ic_light_scheduler = DPMSolverMultistepScheduler(
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        algorithm_type="sde-dpmsolver++",
        use_karras_sigmas=True,
        steps_offset=1
    )
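    # Consistent Light Attention (patched above) blends two attention results per query:
    #     out = (1 - gamma) * attn(q, k, v) + gamma * attn(q, mean_k, mean_v)
    # where mean_k / mean_v are averaged over the frames of the clip, so the UI default
    # gamma = 0.5 weights per-frame detail and temporally shared lighting equally. The img2img
    # pipeline below reuses the SD1.5 text encoder and VAE with the patched 8-channel UNet.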
    ic_light_pipe = StableDiffusionImg2ImgPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=ic_light_scheduler,
        safety_checker=None,
        requires_safety_checker=False,
        feature_extractor=None,
        image_encoder=None
    )
    ic_light_pipe = ic_light_pipe.to(device)

    ############################# params ######################################
    strength = config.get("strength", 0.5)
    num_step = config.get("num_step", 25)
    text_guide_scale = config.get("text_guide_scale", 2)
    seed = config.get("seed")
    image_width = config.get("width", 512)
    image_height = config.get("height", 512)
    n_prompt = config.get("n_prompt", "")
    relight_prompt = config.get("relight_prompt", "")
    video_path = config.get("video_path", "")
    bg_source = BGSource[config.get("bg_source")]
    save_path = config.get("save_path")

    ############################## infer #####################################
    generator = torch.manual_seed(seed)
    video_name = os.path.basename(video_path)
    video_list, video_name = read_video(video_path, image_width, image_height)

    print("################## begin ##################")
    with torch.no_grad():
        num_inference_steps = int(round(num_step / strength))
        output = pipe(
            ic_light_pipe=ic_light_pipe,
            relight_prompt=relight_prompt,
            bg_source=bg_source,
            video=video_list,
            prompt=relight_prompt,
            strength=strength,
            negative_prompt=n_prompt,
            guidance_scale=text_guide_scale,
            num_inference_steps=num_inference_steps,
            height=image_height,
            width=image_width,
            generator=generator,
        )
        frames = output.frames[0]
        results_path = f"{save_path}/relight_{video_name}"
        imageio.mimwrite(results_path, frames, fps=8)
        print(f"relight with bg generation! prompt: {relight_prompt}, light: {bg_source.value}, saved to {results_path}.")

    return results_path


def infer(n_prompt, relight_prompt, video_path, bg_source, width, height, strength, gamma,
          num_step, text_guide_scale, seed, progress=gr.Progress(track_tqdm=True)):
    save_path = "./output"
    # Ensure the output folder is empty
    if os.path.exists(save_path):
        shutil.rmtree(save_path)
    os.makedirs(save_path, exist_ok=True)

    config_data = {
        "n_prompt": n_prompt,
        "relight_prompt": relight_prompt,
        "video_path": video_path,
        "bg_source": bg_source,
        "save_path": save_path,
        "width": width,
        "height": height,
        "strength": strength,
        "gamma": gamma,
        "num_step": num_step,
        "text_guide_scale": text_guide_scale,
        "seed": seed
    }

    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".yaml")
    with open(temp_file.name, 'w') as file:
        yaml.dump(config_data, file, default_flow_style=False)
    config_path = temp_file.name

    class Args:
        def __init__(self):
            self.sd_model = "./models/stablediffusionapi/realistic-vision-v51"
            self.motion_adapter_model = "./models/guoyww/animatediff-motion-adapter-v1-5-3"
            self.ic_light_model = "./models/iclight_sd15_fc.safetensors"
            self.config = config_path

    args = Args()
    results_path = main(args)
    os.remove(config_path)

    return results_path
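# Note: infer() round-trips the UI settings through a temporary YAML file so that main() can
# keep reading an OmegaConf config file. With the defaults (num_step=25, strength=0.5) the
# AnimateDiff pipeline is scheduled for int(round(25 / 0.5)) = 50 steps, of which roughly the
# strength fraction is actually applied to the input video.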
css = """
div#col-container {
    margin: 0 auto;
    max-width: 1200px;
}
div#warning-duplicate {
    background-color: #ebf5ff;
    padding: 0 16px 16px;
    margin: 0px 0;
    color: #030303!important;
}
div#warning-duplicate > .gr-prose > h2, div#warning-duplicate > .gr-prose > p {
    color: #0f4592!important;
}
div#warning-duplicate strong {
    color: #0f4592;
}
p.actions {
    display: flex;
    align-items: center;
    margin: 20px 0;
}
div#warning-duplicate .actions a {
    display: inline-block;
    margin-right: 10px;
}
div#warning-setgpu {
    background-color: #fff4eb;
    padding: 0 16px 16px;
    margin: 0px 0;
    color: #030303!important;
}
div#warning-setgpu > .gr-prose > h2, div#warning-setgpu > .gr-prose > p {
    color: #92220f!important;
}
div#warning-setgpu a, div#warning-setgpu b {
    color: #91230f;
}
div#warning-setgpu p.actions > a {
    display: inline-block;
    background: #1f1f23;
    border-radius: 40px;
    padding: 6px 24px;
    color: antiquewhite;
    text-decoration: none;
    font-weight: 600;
    font-size: 1.2em;
}
div#warning-ready {
    background-color: #ecfdf5;
    padding: 0 16px 16px;
    margin: 0px 0;
    color: #030303!important;
}
div#warning-ready > .gr-prose > h2, div#warning-ready > .gr-prose > p {
    color: #057857!important;
}
.custom-color {
    color: #030303 !important;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Light-A-Video")
        gr.Markdown("Training-free Video Relighting via Progressive Light Fusion")
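        # Layout: the left column collects the input video, relight prompt, light direction and
        # advanced settings; the right column shows a duplicate/GPU status panel and the result video.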
        gr.HTML("""
        <p class="actions">
            <a href="https://huggingface.co/spaces/fffiloni/Light-A-Video?duplicate=true" target="_blank">Duplicate this Space</a>
            <a href="https://huggingface.co/fffiloni" target="_blank">Follow me on HF</a>
        </p>
        """)
        with gr.Row():
            with gr.Column():
                video_path = gr.Video(label="Video Path")
                with gr.Row():
                    relight_prompt = gr.Textbox(label="Relight Prompt", scale=3)
                    bg_source = gr.Dropdown(["NONE", "LEFT", "RIGHT", "BOTTOM", "TOP"], label="Background Source", scale=1)
                with gr.Accordion(label="Advanced Settings", open=False):
                    n_prompt = gr.Textbox(label="Negative Prompt", value="bad quality, worse quality")
                    with gr.Row():
                        width = gr.Number(label="Width", value=512)
                        height = gr.Number(label="Height", value=512)
                    with gr.Row():
                        strength = gr.Slider(minimum=0.0, maximum=1.0, label="Strength", value=0.5)
                        gamma = gr.Slider(minimum=0.0, maximum=1.0, label="Gamma", value=0.5)
                    with gr.Row():
                        num_step = gr.Number(label="Number of Steps", value=25)
                        text_guide_scale = gr.Number(label="Text Guide Scale", value=2)
                        seed = gr.Number(label="Seed", value=2060)
                submit = gr.Button("Run", interactive=not is_shared_ui)
                gr.Examples(
                    examples=[
                        ["./input_animatediff/bear.mp4", "a bear walking on the rock, nature lighting, key light", "TOP"],
                        ["./input_animatediff/boat.mp4", "a boat floating on the sea, sunset", "TOP"],
                        ["./input_animatediff/car.mp4", "a car driving on the street, neon light", "RIGHT"],
                        ["./input_animatediff/cat.mp4", "a cat, red and blue neon light", "LEFT"],
                        ["./input_animatediff/cow.mp4", "a cow drinking water in the river, sunset", "RIGHT"],
                        ["./input_animatediff/flowers.mp4", "A basket of flowers, sunshine, hard light", "LEFT"],
                        ["./input_animatediff/fox.mp4", "a fox, sunlight filtering through trees, dappled light", "LEFT"],
                        ["./input_animatediff/girl.mp4", "a girl, magic lit, sci-fi RGB glowing, key lighting", "BOTTOM"],
                        ["./input_animatediff/girl2.mp4", "an anime girl, neon light", "RIGHT"],
                        ["./input_animatediff/juice.mp4", "Pour juice into a glass, magic golden lit", "RIGHT"],
                        ["./input_animatediff/man2.mp4", "handsome man with glasses, shadow from window, sunshine", "RIGHT"],
                        ["./input_animatediff/man4.mp4", "handsome man with glasses, sunlight through the blinds", "LEFT"],
                        ["./input_animatediff/plane.mp4", "a plane on the runway, bottom neon light", "BOTTOM"],
                        ["./input_animatediff/toy.mp4", "a maneki-neko toy, cozy bedroom illumination", "RIGHT"],
                        ["./input_animatediff/woman.mp4", "a woman with curly hair, natural lighting, warm atmosphere", "LEFT"],
                    ],
                    inputs=[video_path, relight_prompt, bg_source],
                    examples_per_page=3
                )
            with gr.Column():
                if is_shared_ui:
                    top_description = gr.HTML(f'''
                        <div class="gr-prose">
                            <h2>Attention: this Space needs to be duplicated to work</h2>
                            <p>To make it work, duplicate the Space and run it on your own profile using a private GPU (L40S recommended).</p>
                            <p>An L40S costs US$1.80/h.</p>
                            <p class="actions">
                                <a href="https://huggingface.co/spaces/fffiloni/Light-A-Video?duplicate=true" target="_blank">Duplicate this Space</a>
                                to start experimenting with this demo
                            </p>
                        </div>
                    ''', elem_id="warning-duplicate")
                else:
                    if is_gpu_associated:
                        top_description = gr.HTML(f'''
                            <div class="gr-prose">
                                <h2>You have successfully associated a GPU to this Space 🎉</h2>
                                <p>You will be billed by the minute from when you activated the GPU until it is turned off.</p>
                            </div>
                        ''', elem_id="warning-ready")
                    else:
                        top_description = gr.HTML(f'''
                            <div class="gr-prose">
                                <h2>You have successfully duplicated the Light-A-Video Space 🎉</h2>
                                <p>There's only one step left before you can properly play with this demo: assign a GPU to it (via the Settings tab) and run the app below. You will be billed by the minute from when you activate the GPU until it is turned off.</p>
                                <p class="actions">
                                    <a href="https://huggingface.co/spaces/{os.environ.get("SPACE_ID", "")}/settings" target="_blank">🔥 &nbsp; Set recommended GPU</a>
                                </p>
                            </div>
                        ''', elem_id="warning-setgpu")
                output = gr.Video(label="Results Path")

    submit.click(
        fn=infer,
        inputs=[n_prompt, relight_prompt, video_path, bg_source, width, height,
                strength, gamma, num_step, text_guide_scale, seed],
        outputs=[output]
    )

demo.queue().launch(show_api=False, show_error=True, ssr_mode=False)
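# Sketch of a local run (assumes this script is saved as app.py at the repo root, with the src/
# and utils/ modules alongside it): `python app.py`, then open the printed Gradio URL. On the
# shared fffiloni/Light-A-Video Space the Run button stays disabled and the demo must be duplicated.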