import abc
import types

import torch
from diffusers.models.transformers.transformer_flux import (
    FluxSingleTransformerBlock, FluxTransformerBlock)


def joint_transformer_forward(self, controller, place_in_transformer):
    def forward(
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        temb: torch.FloatTensor,
        image_rotary_emb=None,
        joint_attention_kwargs=None,
    ):
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
            encoder_hidden_states, emb=temb
        )

        # Attention.
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )

        # Process attention outputs for the `hidden_states`.
        attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = hidden_states + attn_output

        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output

        # Change from the stock diffusers forward: let the controller modify
        # the image-stream feed-forward output before the residual add.
        if controller is not None:
            ff_output = controller(ff_output, place_in_transformer)

        hidden_states = hidden_states + ff_output

        # Process attention outputs for the `encoder_hidden_states`.
        context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
        encoder_hidden_states = encoder_hidden_states + context_attn_output

        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]

        context_ff_output = self.ff_context(norm_encoder_hidden_states)
        context_ff_output = c_gate_mlp.unsqueeze(1) * context_ff_output
        encoder_hidden_states = encoder_hidden_states + context_ff_output

        if encoder_hidden_states.dtype == torch.float16:
            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

        return encoder_hidden_states, hidden_states

    return forward


def single_transformer_forward(self, controller, place_in_transformer):
    def forward(
        hidden_states: torch.FloatTensor,
        temb: torch.FloatTensor,
        image_rotary_emb=None,
        joint_attention_kwargs=None,
    ):
        residual = hidden_states
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
        mlp_input = norm_hidden_states
        mlp_hidden_states = self.act_mlp(self.proj_mlp(mlp_input))

        attn_output = self.attn(
            hidden_states=norm_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )

        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
        gate = gate.unsqueeze(1)
        hidden_states = gate * self.proj_out(hidden_states)

        # Change from the stock diffusers forward: let the controller modify
        # the projected block output before the residual add.
        if controller is not None:
            hidden_states = controller(hidden_states, place_in_transformer)

        hidden_states = residual + hidden_states
        if hidden_states.dtype == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)

        return hidden_states

    return forward
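
# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original module): one way the
# closures above could be attached to a loaded Flux transformer. The helper
# name `register_controller`, the `controller(features, place)` interface,
# and the "joint_{i}" / "single_{i}" labels are illustrative assumptions;
# the `transformer_blocks` / `single_transformer_blocks` attributes are the
# block lists exposed by diffusers' FluxTransformer2DModel.
# ---------------------------------------------------------------------------
def register_controller(transformer, controller):
    """Replace every block's forward with a controller-aware closure.

    Assigning to `block.forward` shadows the bound method on the instance, so
    `nn.Module.__call__` will dispatch to the closure while `self` stays
    captured in the closure's enclosing scope.
    """
    for i, block in enumerate(transformer.transformer_blocks):
        assert isinstance(block, FluxTransformerBlock)
        block.forward = joint_transformer_forward(block, controller, f"joint_{i}")
    for i, block in enumerate(transformer.single_transformer_blocks):
        assert isinstance(block, FluxSingleTransformerBlock)
        block.forward = single_transformer_forward(block, controller, f"single_{i}")


class FeatureCache:
    """Minimal example controller (assumed interface): records the features it
    receives at each hook point and returns them unchanged."""

    def __init__(self):
        self.features = {}

    def __call__(self, features, place_in_transformer):
        self.features[place_in_transformer] = features.detach()
        return features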