import abc
import types

import torch
from diffusers.models.transformers.transformer_flux import (
    FluxSingleTransformerBlock, FluxTransformerBlock)
def joint_transformer_forward(self, controller, place_in_transformer):
    # Returns a patched FluxTransformerBlock.forward that closes over `self`,
    # `controller`, and `place_in_transformer`, so it can be assigned directly
    # to an individual block instance.
    def forward(
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        temb: torch.FloatTensor,
        image_rotary_emb=None,
        joint_attention_kwargs=None,  # accepted for signature compatibility, not forwarded
    ):
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
            encoder_hidden_states, emb=temb
        )

        # Attention.
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )

        # Process attention outputs for the `hidden_states`.
        attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = hidden_states + attn_output

        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output

        # Change from the stock diffusers forward: let the controller edit the
        # image-stream feed-forward output before it is added to the residual.
        if controller is not None:
            ff_output = controller(ff_output, place_in_transformer)

        hidden_states = hidden_states + ff_output

        # Process attention outputs for the `encoder_hidden_states`.
        context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
        encoder_hidden_states = encoder_hidden_states + context_attn_output

        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]

        context_ff_output = self.ff_context(norm_encoder_hidden_states)
        context_ff_output = c_gate_mlp.unsqueeze(1) * context_ff_output
        encoder_hidden_states = encoder_hidden_states + context_ff_output

        # Clamp to avoid fp16 overflow.
        if encoder_hidden_states.dtype == torch.float16:
            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

        return encoder_hidden_states, hidden_states

    return forward
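

# Illustration only, not part of the original file: a minimal sketch of the
# controller interface the patched forwards above and below assume. The
# controller is simply a callable taking (feature_map, place_in_transformer)
# and returning a tensor of the same shape; `FeatureController` and
# `NoOpController` are hypothetical names, though the `abc` import suggests
# the full file defines something along these lines.
class FeatureController(abc.ABC):
    @abc.abstractmethod
    def __call__(self, feature_map: torch.Tensor, place_in_transformer) -> torch.Tensor:
        # `place_in_transformer` is whatever tag the registration code passes
        # in (for example, a block name or index).
        raise NotImplementedError


class NoOpController(FeatureController):
    # Passes features through unchanged; useful as a sanity-check baseline.
    def __call__(self, feature_map: torch.Tensor, place_in_transformer) -> torch.Tensor:
        return feature_map
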
def single_transformer_forward(self, controller, place_in_transformer):
    # Same closure pattern for FluxSingleTransformerBlock: the returned forward
    # captures `self`, `controller`, and `place_in_transformer`.
    def forward(
        hidden_states: torch.FloatTensor,
        temb: torch.FloatTensor,
        image_rotary_emb=None,
        joint_attention_kwargs=None,  # accepted for signature compatibility, not forwarded
    ):
        residual = hidden_states
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
        mlp_input = norm_hidden_states
        mlp_hidden_states = self.act_mlp(self.proj_mlp(mlp_input))

        attn_output = self.attn(
            hidden_states=norm_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )

        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
        gate = gate.unsqueeze(1)
        hidden_states = gate * self.proj_out(hidden_states)

        # Change from the stock diffusers forward: let the controller edit the
        # block output before the residual connection.
        if controller is not None:
            hidden_states = controller(hidden_states, place_in_transformer)

        hidden_states = residual + hidden_states

        # Clamp to avoid fp16 overflow.
        if hidden_states.dtype == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)

        return hidden_states

    return forward
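

# Illustration only, not part of the original file: a hedged sketch of how the
# two forward factories above could be attached to a FLUX transformer.
# `register_controller` and the "joint_{i}" / "single_{i}" tags are assumed
# names. Because `self` is captured explicitly by the closures, plain
# attribute assignment is enough here; `types.MethodType` is not required.
def register_controller(transformer, controller):
    for i, block in enumerate(transformer.transformer_blocks):
        if isinstance(block, FluxTransformerBlock):
            block.forward = joint_transformer_forward(block, controller, f"joint_{i}")
    for i, block in enumerate(transformer.single_transformer_blocks):
        if isinstance(block, FluxSingleTransformerBlock):
            block.forward = single_transformer_forward(block, controller, f"single_{i}")


# Usage sketch, assuming a loaded FluxPipeline named `pipe`:
#     register_controller(pipe.transformer, NoOpController())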