import torch
from typing import Any, Dict, Optional, Tuple, Union, List, Callable

from diffusers.models.transformers.transformer_hunyuan_video import (
    HunyuanVideoSingleTransformerBlock,
    HunyuanVideoTransformerBlock,
    HunyuanVideoTransformer3DModel,
)
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers import HunyuanVideoPipeline
from diffusers.utils import (
    USE_PEFT_BACKEND,
    scale_lora_layers,
    unscale_lora_layers,
    logging,
    is_torch_xla_available,
)
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.pipelines.hunyuan_video.pipeline_output import HunyuanVideoPipelineOutput
from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import retrieve_timesteps
import numpy as np

logger = logging.get_logger(__name__)

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


DEFAULT_PROMPT_TEMPLATE = {
    "template": (
        "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
        "1. The main content and theme of the video."
        "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
        "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
        "4. background environment, light, style and atmosphere."
        "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
    ),
    "crop_start": 95,
}


class HunyuanVideoSingleTransformerBlockSparse(HunyuanVideoSingleTransformerBlock):
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        timestep: Optional[torch.Tensor] = None,
        numeral_timestep: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        text_seq_length = encoder_hidden_states.shape[1]
        hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)

        residual = hidden_states

        # 1. Input normalization
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))

        norm_hidden_states, norm_encoder_hidden_states = (
            norm_hidden_states[:, :-text_seq_length, :],
            norm_hidden_states[:, -text_seq_length:, :],
        )

        # 2. Attention
        # `timestep` and the integer step index `numeral_timestep` are forwarded to the attention
        # processor so the sparse variant can adapt its behavior per denoising step.
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
            timestep=timestep,
            numeral_timestep=numeral_timestep,
        )
        attn_output = torch.cat([attn_output, context_attn_output], dim=1)

        # 3. Modulation and residual connection
        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
        hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states)
        hidden_states = hidden_states + residual

        hidden_states, encoder_hidden_states = (
            hidden_states[:, :-text_seq_length, :],
            hidden_states[:, -text_seq_length:, :],
        )
        return hidden_states, encoder_hidden_states


class HunyuanVideoTransformerBlockSparse(HunyuanVideoTransformerBlock):
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        timestep: Optional[torch.Tensor] = None,
        numeral_timestep: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # 1. Input normalization
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
            encoder_hidden_states, emb=temb
        )

        # 2. Joint attention
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            attention_mask=attention_mask,
            image_rotary_emb=freqs_cis,
            timestep=timestep,
            numeral_timestep=numeral_timestep,
        )

        # 3. Modulation and residual connection
        hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)
        encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)

        norm_hidden_states = self.norm2(hidden_states)
        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]

        # 4. Feed-forward
        ff_output = self.ff(norm_hidden_states)
        context_ff_output = self.ff_context(norm_encoder_hidden_states)

        hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output

        return hidden_states, encoder_hidden_states


class HunyuanVideoTransformer3DModelSparse(HunyuanVideoTransformer3DModel):
    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        pooled_projections: torch.Tensor,
        guidance: torch.Tensor = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
        numeral_timestep: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p, p_t = self.config.patch_size, self.config.patch_size_t
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p
        post_patch_width = width // p
        first_frame_num_tokens = 1 * post_patch_height * post_patch_width

        # 1. RoPE
        image_rotary_emb = self.rope(hidden_states)

        # 2. Conditional embeddings
        temb, token_replace_emb = self.time_text_embed(timestep, pooled_projections, guidance)
        hidden_states = self.x_embedder(hidden_states)
        encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
        # 3. Attention mask preparation
        latent_sequence_length = hidden_states.shape[1]
        condition_sequence_length = encoder_hidden_states.shape[1]
        sequence_length = latent_sequence_length + condition_sequence_length
        attention_mask = torch.ones(
            batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
        )  # [B, N]
        effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int)  # [B,]
        effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
        indices = torch.arange(sequence_length, device=hidden_states.device).unsqueeze(0)  # [1, N]
        mask_indices = indices >= effective_sequence_length.unsqueeze(1)  # [B, N]
        attention_mask = attention_mask.masked_fill(mask_indices, False)
        attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, N]

        # 4. Transformer blocks
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for block in self.transformer_blocks:
                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    attention_mask,
                    image_rotary_emb,
                    timestep,
                    numeral_timestep,
                    token_replace_emb,
                    first_frame_num_tokens,
                )

            for block in self.single_transformer_blocks:
                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    attention_mask,
                    image_rotary_emb,
                    timestep,
                    numeral_timestep,
                    token_replace_emb,
                    first_frame_num_tokens,
                )
        else:
            for block in self.transformer_blocks:
                hidden_states, encoder_hidden_states = block(
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    attention_mask,
                    image_rotary_emb,
                    timestep,
                    numeral_timestep,
                    token_replace_emb,
                    first_frame_num_tokens,
                )

            for block in self.single_transformer_blocks:
                hidden_states, encoder_hidden_states = block(
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    attention_mask,
                    image_rotary_emb,
                    timestep,
                    numeral_timestep,
                    token_replace_emb,
                    first_frame_num_tokens,
                )
        # 5. Output projection
        hidden_states = self.norm_out(hidden_states, temb)
        hidden_states = self.proj_out(hidden_states)

        hidden_states = hidden_states.reshape(
            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p
        )
        hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
        hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (hidden_states,)

        return Transformer2DModelOutput(sample=hidden_states)


class HunyuanVideoPipelineSparse(HunyuanVideoPipeline):
    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Union[str, List[str]] = None,
        negative_prompt: Union[str, List[str]] = None,
        negative_prompt_2: Union[str, List[str]] = None,
        height: int = 720,
        width: int = 1280,
        num_frames: int = 129,
        num_inference_steps: int = 50,
        sigmas: List[float] = None,
        true_cfg_scale: float = 1.0,
        guidance_scale: float = 6.0,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        pooled_prompt_embeds: Optional[torch.Tensor] = None,
        prompt_attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
        max_sequence_length: int = 256,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
                instead.
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` will
                be used instead.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the video generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale`
                is not greater than `1`).
            negative_prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the video generation to be sent to `tokenizer_2` and
                `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
            height (`int`, defaults to `720`):
                The height in pixels of the generated video.
            width (`int`, defaults to `1280`):
                The width in pixels of the generated video.
            num_frames (`int`, defaults to `129`):
                The number of frames in the generated video.
            num_inference_steps (`int`, defaults to `50`):
                The number of denoising steps. More denoising steps usually lead to a higher quality video at the
                expense of slower inference.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used.
            true_cfg_scale (`float`, *optional*, defaults to 1.0):
                When > 1.0 and a `negative_prompt` is provided, true classifier-free guidance is enabled.
            guidance_scale (`float`, defaults to `6.0`):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages generating videos that are closely linked to the text
                `prompt`, usually at the expense of lower video quality. Note that the only available HunyuanVideo
                model is CFG-distilled, which means that traditional guidance between unconditional and conditional
                latents is not applied.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of videos to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, pooled text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input
                argument.
            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, pooled `negative_prompt_embeds` are generated from the `negative_prompt`
                input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated video. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end
                of each denoising step during inference, with the following arguments: `callback_on_step_end(self:
                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`.
                `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as the `callback_kwargs` argument. You will only be able to include variables listed
                in the `._callback_tensor_inputs` attribute of your pipeline class.

        Examples:

        Returns:
            [`~HunyuanVideoPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated frames.
        """

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            prompt_embeds,
            callback_on_step_end_tensor_inputs,
            prompt_template,
        )

        has_neg_prompt = negative_prompt is not None or (
            negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
        )
        do_true_cfg = true_cfg_scale > 1 and has_neg_prompt

        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs
        self._current_timestep = None
        self._interrupt = False

        device = self._execution_device

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # 3. Encode input prompt
        transformer_dtype = self.transformer.dtype
        prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_template=prompt_template,
            num_videos_per_prompt=num_videos_per_prompt,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            device=device,
            max_sequence_length=max_sequence_length,
        )
        prompt_embeds = prompt_embeds.to(transformer_dtype)
        prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
        pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)

        if do_true_cfg:
            negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt(
                prompt=negative_prompt,
                prompt_2=negative_prompt_2,
                prompt_template=prompt_template,
                num_videos_per_prompt=num_videos_per_prompt,
                prompt_embeds=negative_prompt_embeds,
                pooled_prompt_embeds=negative_pooled_prompt_embeds,
                prompt_attention_mask=negative_prompt_attention_mask,
                device=device,
                max_sequence_length=max_sequence_length,
            )
            negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
            negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype)
            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype)

        # 4. Prepare timesteps
        sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)

        # 5. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_channels_latents,
            height,
            width,
            num_frames,
            torch.float32,
            device,
            generator,
            latents,
        )
        # 6. Prepare guidance condition
        guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                self._current_timestep = t
                latent_model_input = latents.to(transformer_dtype)
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0]).to(latents.dtype)

                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    encoder_attention_mask=prompt_attention_mask,
                    pooled_projections=pooled_prompt_embeds,
                    guidance=guidance,
                    attention_kwargs=attention_kwargs,
                    return_dict=False,
                    numeral_timestep=i,
                )[0]

                if do_true_cfg:
                    neg_noise_pred = self.transformer(
                        hidden_states=latent_model_input,
                        timestep=timestep,
                        encoder_hidden_states=negative_prompt_embeds,
                        encoder_attention_mask=negative_prompt_attention_mask,
                        pooled_projections=negative_pooled_prompt_embeds,
                        guidance=guidance,
                        attention_kwargs=attention_kwargs,
                        return_dict=False,
                        numeral_timestep=i,
                    )[0]
                    noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        self._current_timestep = None

        if not output_type == "latent":
            latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
            video = self.vae.decode(latents, return_dict=False)[0]
            video = self.video_processor.postprocess_video(video, output_type=output_type)
        else:
            video = latents

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (video,)

        return HunyuanVideoPipelineOutput(frames=video)


def replace_sparse_forward():
    """Monkey-patch the stock diffusers HunyuanVideo classes with the sparse variants defined above."""
    HunyuanVideoSingleTransformerBlock.forward = HunyuanVideoSingleTransformerBlockSparse.forward
    HunyuanVideoTransformerBlock.forward = HunyuanVideoTransformerBlockSparse.forward
    HunyuanVideoTransformer3DModel.forward = HunyuanVideoTransformer3DModelSparse.forward
    HunyuanVideoPipeline.__call__ = HunyuanVideoPipelineSparse.__call__
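

# Example usage: a minimal sketch, not part of the original module. It assumes the diffusers-format
# checkpoint "hunyuanvideo-community/HunyuanVideo" is available and that a sparse attention processor
# consuming the extra `timestep`/`numeral_timestep` kwargs is installed on the transformer elsewhere;
# with the stock processors those kwargs are dropped with a warning. Resolution, frame count, and
# dtypes below are illustrative only.
if __name__ == "__main__":
    from diffusers.utils import export_to_video

    replace_sparse_forward()  # patch the diffusers classes before building the pipeline

    model_id = "hunyuanvideo-community/HunyuanVideo"  # assumed checkpoint id
    transformer = HunyuanVideoTransformer3DModel.from_pretrained(
        model_id, subfolder="transformer", torch_dtype=torch.bfloat16
    )
    pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)
    pipe.vae.enable_tiling()
    pipe.to("cuda")

    output = pipe(
        prompt="A cat walks on the grass, realistic style.",
        height=320,
        width=512,
        num_frames=61,
        num_inference_steps=30,
    ).frames[0]
    export_to_video(output, "output.mp4", fps=15)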