import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Any, Dict, Optional, Tuple, Union

from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from diffusers.models.modeling_outputs import Transformer2DModelOutput

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@torch.no_grad()
def decode_latents(pipe, latents):
    """Denormalize latents and decode them into a numpy video with the pipeline's VAE."""
    has_latents_mean = hasattr(pipe.vae.config, "latents_mean") and pipe.vae.config.latents_mean is not None
    has_latents_std = hasattr(pipe.vae.config, "latents_std") and pipe.vae.config.latents_std is not None
    if has_latents_mean and has_latents_std:
        latents_mean = (
            torch.tensor(pipe.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
        )
        latents_std = (
            torch.tensor(pipe.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
        )
        latents = latents * latents_std / pipe.vae.config.scaling_factor + latents_mean
    else:
        latents = latents / pipe.vae.config.scaling_factor
    video = pipe.vae.decode(latents, return_dict=False)[0]
    video = pipe.video_processor.postprocess_video(video, output_type="np")
    return video


class RGBALoRAMochiAttnProcessor:
    """Mochi attention processor with trainable LoRA branches on the q/k/v/out projections,
    applied only to the second (alpha) half of the video token sequence."""

    def __init__(self, device, dtype, lora_rank=128, lora_alpha=1.0, latent_dim=3072):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "RGBALoRAMochiAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
            )

        # LoRA scaling hyperparameters
        self.lora_alpha = lora_alpha
        self.lora_rank = lora_rank

        # Helper function to create a LoRA layer: a low-rank down-projection followed by an up-projection.
        def create_lora_layer(in_dim, mid_dim, out_dim, device=device, dtype=dtype):
            lora_a = nn.Linear(in_dim, mid_dim, bias=False, device=device, dtype=dtype)
            lora_b = nn.Linear(mid_dim, out_dim, bias=False, device=device, dtype=dtype)

            # lora_a keeps the default Kaiming initialization; lora_b starts at zero so the
            # adapter is a no-op before training.
            nn.init.kaiming_uniform_(lora_a.weight, a=math.sqrt(5))
            nn.init.zeros_(lora_b.weight)

            lora_a.weight.requires_grad = True
            lora_b.weight.requires_grad = True

            # Combine the two projections into a sequential module
            return nn.Sequential(lora_a, lora_b)

        self.to_q_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)
        self.to_k_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)
        self.to_v_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)
        self.to_out_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)

    def _apply_lora(self, hidden_states, seq_len, query, key, value, scaling):
        """Applies LoRA updates to the second half of the query, key, and value tensors."""
        query_delta = self.to_q_lora(hidden_states).to(query.device)
        query[:, -seq_len // 2:, :] += query_delta[:, -seq_len // 2:, :] * scaling

        key_delta = self.to_k_lora(hidden_states).to(key.device)
        key[:, -seq_len // 2:, :] += key_delta[:, -seq_len // 2:, :] * scaling

        value_delta = self.to_v_lora(hidden_states).to(value.device)
        value[:, -seq_len // 2:, :] += value_delta[:, -seq_len // 2:, :] * scaling

        return query, key, value

    def __call__(
        self,
        attn,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[Dict[str, torch.Tensor]] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        scaling = self.lora_alpha / self.lora_rank
        sequence_length = query.size(1)
        query, key, value = self._apply_lora(hidden_states, sequence_length, query, key, value, scaling)
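
        # RGBA layout note (an inference from the slicing in `_apply_lora`, stated here as an
        # assumption): the video token sequence is taken to hold the RGB tokens in its first half
        # and the alpha tokens in its second half. The frozen Mochi projections run over the full
        # sequence, while the LoRA deltas above (and the `to_out_lora` delta further down) are
        # added only to the last `sequence_length // 2` tokens, so the RGB branch stays on the
        # pretrained path and only the alpha branch is adapted.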

        query = query.unflatten(2, (attn.heads, -1))
        key = key.unflatten(2, (attn.heads, -1))
        value = value.unflatten(2, (attn.heads, -1))

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        encoder_query = attn.add_q_proj(encoder_hidden_states)
        encoder_key = attn.add_k_proj(encoder_hidden_states)
        encoder_value = attn.add_v_proj(encoder_hidden_states)

        encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
        encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
        encoder_value = encoder_value.unflatten(2, (attn.heads, -1))

        if attn.norm_added_q is not None:
            encoder_query = attn.norm_added_q(encoder_query)
        if attn.norm_added_k is not None:
            encoder_key = attn.norm_added_k(encoder_key)

        if image_rotary_emb is not None:

            def apply_rotary_emb(x, freqs_cos, freqs_sin):
                x_even = x[..., 0::2].float()
                x_odd = x[..., 1::2].float()

                cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
                sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)

                return torch.stack([cos, sin], dim=-1).flatten(-2)

            # The same rotary embedding is applied to each half of the sequence.
            query[:, sequence_length // 2:] = apply_rotary_emb(query[:, sequence_length // 2:], *image_rotary_emb)
            query[:, :sequence_length // 2] = apply_rotary_emb(query[:, :sequence_length // 2], *image_rotary_emb)
            key[:, sequence_length // 2:] = apply_rotary_emb(key[:, sequence_length // 2:], *image_rotary_emb)
            key[:, :sequence_length // 2] = apply_rotary_emb(key[:, :sequence_length // 2], *image_rotary_emb)
            # query = apply_rotary_emb(query, *image_rotary_emb)
            # key = apply_rotary_emb(key, *image_rotary_emb)

        query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
        encoder_query, encoder_key, encoder_value = (
            encoder_query.transpose(1, 2),
            encoder_key.transpose(1, 2),
            encoder_value.transpose(1, 2),
        )

        sequence_length = query.size(2)
        encoder_sequence_length = encoder_query.size(2)
        total_length = sequence_length + encoder_sequence_length

        batch_size, heads, _, dim = query.shape

        attn_outputs = []
        # `attention_mask` carries two components: the prompt attention mask and the rectangular joint mask.
        prompt_attention_mask = attention_mask["prompt_attention_mask"]
        rect_attention_mask = attention_mask["rect_attention_mask"]
        for idx in range(batch_size):
            # Keep only the valid (non-padded) prompt tokens for this sample.
            mask = prompt_attention_mask[idx][None, :]
            valid_prompt_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()

            valid_encoder_query = encoder_query[idx : idx + 1, :, valid_prompt_token_indices, :]
            valid_encoder_key = encoder_key[idx : idx + 1, :, valid_prompt_token_indices, :]
            valid_encoder_value = encoder_value[idx : idx + 1, :, valid_prompt_token_indices, :]

            valid_query = torch.cat([query[idx : idx + 1], valid_encoder_query], dim=2)
            valid_key = torch.cat([key[idx : idx + 1], valid_encoder_key], dim=2)
            valid_value = torch.cat([value[idx : idx + 1], valid_encoder_value], dim=2)

            attn_output = F.scaled_dot_product_attention(
                valid_query, valid_key, valid_value, dropout_p=0.0, attn_mask=rect_attention_mask[idx], is_causal=False
            )
            # Pad the per-sample output back to the full (video + prompt) length before stacking.
            valid_sequence_length = attn_output.size(2)
            attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length))
            attn_outputs.append(attn_output)

        hidden_states = torch.cat(attn_outputs, dim=0)
        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)

        hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
            (sequence_length, encoder_sequence_length), dim=1
        )

        # Linear projection, with the LoRA delta added only to the second half of the sequence.
        original_hidden_states = attn.to_out[0](hidden_states)
        hidden_states_delta = self.to_out_lora(hidden_states).to(hidden_states.device)
        original_hidden_states[:, -sequence_length // 2:, :] += (
            hidden_states_delta[:, -sequence_length // 2:, :] * scaling
        )

        # Dropout
        hidden_states = attn.to_out[1](original_hidden_states)

        if hasattr(attn, "to_add_out"):
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        return hidden_states, encoder_hidden_states


def prepare_for_rgba_inference(
    model, device: torch.device, dtype: torch.dtype, lora_rank: int = 128, lora_alpha: float = 1.0
):
    """Attach RGBA LoRA attention processors to every transformer block and patch the model's forward."""

    def custom_forward(self):
        def forward(
            hidden_states: torch.Tensor,
            encoder_hidden_states: torch.Tensor,
            timestep: torch.LongTensor,
            encoder_attention_mask: Dict[str, torch.Tensor],
            attention_kwargs: Optional[Dict[str, Any]] = None,
            return_dict: bool = True,
        ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
            if attention_kwargs is not None:
                attention_kwargs = attention_kwargs.copy()
                lora_scale = attention_kwargs.pop("scale", 1.0)
            else:
                lora_scale = 1.0

            if USE_PEFT_BACKEND:
                # Weight the LoRA layers by setting `lora_scale` for each PEFT layer
                scale_lora_layers(self, lora_scale)
            else:
                if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                    logger.warning(
                        "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                    )

            batch_size, num_channels, num_frames, height, width = hidden_states.shape
            p = self.config.patch_size

            post_patch_height = height // p
            post_patch_width = width // p

            temb, encoder_hidden_states = self.time_embed(
                timestep,
                encoder_hidden_states,
                encoder_attention_mask["prompt_attention_mask"],
                hidden_dtype=hidden_states.dtype,
            )

            hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
            hidden_states = self.patch_embed(hidden_states)
            hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)

            image_rotary_emb = self.rope(
                self.pos_frequencies,
                num_frames // 2,  # Identical PE for the RGB and alpha halves
                post_patch_height,
                post_patch_width,
                device=hidden_states.device,
                dtype=torch.float32,
            )

            for i, block in enumerate(self.transformer_blocks):
                if torch.is_grad_enabled() and self.gradient_checkpointing:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs)

                        return custom_forward

                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        hidden_states,
                        encoder_hidden_states,
                        temb,
                        encoder_attention_mask,
                        image_rotary_emb,
                        **ckpt_kwargs,
                    )
                else:
                    hidden_states, encoder_hidden_states = block(
                        hidden_states=hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        temb=temb,
                        encoder_attention_mask=encoder_attention_mask,
                        image_rotary_emb=image_rotary_emb,
                    )

            hidden_states = self.norm_out(hidden_states, temb)
            hidden_states = self.proj_out(hidden_states)

            hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1)
            hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5)
            output = hidden_states.reshape(batch_size, -1, num_frames, height, width)

            if USE_PEFT_BACKEND:
                # Remove `lora_scale` from each PEFT layer
                unscale_lora_layers(self, lora_scale)

            if not return_dict:
                return (output,)

            return Transformer2DModelOutput(sample=output)

        return forward

    for block in model.transformer_blocks:
        attn_processor = RGBALoRAMochiAttnProcessor(
            device=device, dtype=dtype, lora_rank=lora_rank, lora_alpha=lora_alpha
        )
        # block.attn1.set_processor(attn_processor)
        block.attn1.processor = attn_processor

    model.forward = custom_forward(model)


def get_processor_state_dict(model):
    """Collect the trainable LoRA parameters of all attention processors into a state dict."""
    processor_state_dict = {}

    for index, block in enumerate(model.transformer_blocks):
        if hasattr(block.attn1, "processor"):
            processor = block.attn1.processor
            for attr_name in ["to_q_lora", "to_k_lora", "to_v_lora", "to_out_lora"]:
                if hasattr(processor, attr_name):
                    lora_layer = getattr(processor, attr_name)
                    for param_name, param in lora_layer.named_parameters():
                        key = f"block_{index}.{attr_name}.{param_name}"
                        processor_state_dict[key] = param.data.clone()

    # torch.save({"processor_state_dict": processor_state_dict}, checkpoint_path)
    # print(f"Processor state_dict saved to {checkpoint_path}")
    return processor_state_dict


def load_processor_state_dict(model, processor_state_dict):
    """Load trainable LoRA parameters of the attention processors from a saved state dict."""
    for index, block in enumerate(model.transformer_blocks):
        if hasattr(block.attn1, "processor"):
            processor = block.attn1.processor
            for attr_name in ["to_q_lora", "to_k_lora", "to_v_lora", "to_out_lora"]:
                if hasattr(processor, attr_name):
                    lora_layer = getattr(processor, attr_name)
                    for param_name, param in lora_layer.named_parameters():
                        key = f"block_{index}.{attr_name}.{param_name}"
                        if key in processor_state_dict:
                            param.data.copy_(processor_state_dict[key])
                        else:
                            raise KeyError(f"Missing key {key} in checkpoint.")


# Prepare training parameters
def get_processor_params(processor):
    """Return the trainable LoRA parameters of a single attention processor."""
    params = []
    for attr_name in ["to_q_lora", "to_k_lora", "to_v_lora", "to_out_lora"]:
        if hasattr(processor, attr_name):
            lora_layer = getattr(processor, attr_name)
            params.extend(p for p in lora_layer.parameters() if p.requires_grad)
    return params


def get_all_processor_params(transformer):
    """Gather the trainable LoRA parameters from every transformer block, e.g. to build an optimizer."""
    all_params = []
    for block in transformer.transformer_blocks:
        if hasattr(block.attn1, "processor"):
            processor = block.attn1.processor
            all_params.extend(get_processor_params(processor))
    return all_params
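

# Usage sketch (illustrative only): this assumes a Mochi checkpoint such as "genmo/mochi-1-preview"
# loaded through diffusers' MochiPipeline; the checkpoint id, learning rate, and file name below are
# placeholders, not values prescribed by this module. It shows how the helpers above fit together:
# attach the RGBA LoRA processors, build an optimizer over their parameters, and round-trip the
# adapter weights.
if __name__ == "__main__":
    from diffusers import MochiPipeline

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.bfloat16

    # Assumed checkpoint id; substitute the Mochi weights you actually use.
    pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=dtype)
    pipe.to(device)
    transformer = pipe.transformer

    # Attach the RGBA LoRA processors and patch the transformer's forward pass.
    prepare_for_rgba_inference(transformer, device=device, dtype=dtype, lora_rank=128, lora_alpha=1.0)

    # For training, only the LoRA parameters are optimized; the base weights stay frozen.
    lora_params = get_all_processor_params(transformer)
    optimizer = torch.optim.AdamW(lora_params, lr=1e-4)

    # Save and restore the adapter weights.
    torch.save(get_processor_state_dict(transformer), "rgba_lora.pt")
    load_processor_state_dict(transformer, torch.load("rgba_lora.pt"))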