# src/transformer_utils/flux_transformer_forward.py
import abc
import types

import torch
from diffusers.models.transformers.transformer_flux import (
    FluxSingleTransformerBlock,
    FluxTransformerBlock,
)


def joint_transformer_forward(self, controller, place_in_transformer):
    """Build a replacement forward for a FluxTransformerBlock that routes the
    image-stream feed-forward output through `controller` before the residual add."""

    def forward(
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        temb: torch.FloatTensor,
        image_rotary_emb=None,
        joint_attention_kwargs=None,  # accepted but unused in this patched forward
    ):
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
            encoder_hidden_states, emb=temb
        )

        # Attention over the image and text streams.
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )

        # Process attention outputs for the `hidden_states` (image stream).
        attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = hidden_states + attn_output

        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output
        # Controller hook: lets the caller record or modify the feed-forward
        # features of this block before the residual add.
        if controller is not None:
            ff_output = controller(ff_output, place_in_transformer)
        hidden_states = hidden_states + ff_output

        # Process attention outputs for the `encoder_hidden_states` (text stream).
        context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
        encoder_hidden_states = encoder_hidden_states + context_attn_output

        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]

        context_ff_output = self.ff_context(norm_encoder_hidden_states)
        context_ff_output = c_gate_mlp.unsqueeze(1) * context_ff_output
        encoder_hidden_states = encoder_hidden_states + context_ff_output

        # Avoid overflow in half precision.
        if encoder_hidden_states.dtype == torch.float16:
            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

        return encoder_hidden_states, hidden_states

    return forward
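
# Usage sketch (illustrative, not part of the original module): the factory above
# returns a closure that already captures the block, so it can replace the block's
# bound `forward` directly. `pipe`, `controller`, the layer index, and the
# `place_in_transformer` label below are assumptions for illustration only.
#
#     from diffusers import FluxPipeline
#     pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
#     block = pipe.transformer.transformer_blocks[18]  # hypothetical choice of layer
#     block.forward = joint_transformer_forward(block, controller, place_in_transformer="joint_18")
#     # After this, every call to block(...) runs the patched forward, and the
#     # controller sees (and may edit) the image-stream feed-forward output there.
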
def single_transformer_forward(self, controller, place_in_transformer):
    """Build a replacement forward for a FluxSingleTransformerBlock that routes the
    block output through `controller` before the residual add."""

    def forward(
        hidden_states: torch.FloatTensor,
        temb: torch.FloatTensor,
        image_rotary_emb=None,
        joint_attention_kwargs=None,  # accepted but unused in this patched forward
    ):
        residual = hidden_states
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)

        # Parallel MLP and attention branches over the same normalized input.
        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
        attn_output = self.attn(
            hidden_states=norm_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )

        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
        gate = gate.unsqueeze(1)
        hidden_states = gate * self.proj_out(hidden_states)

        # Controller hook: lets the caller record or modify the block output
        # before it is added back to the residual stream.
        if controller is not None:
            hidden_states = controller(hidden_states, place_in_transformer)
        hidden_states = residual + hidden_states

        # Avoid overflow in half precision.
        if hidden_states.dtype == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)

        return hidden_states

    return forward
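
# Usage sketch (illustrative, not part of the original module): the single-stream
# blocks can be patched the same way, e.g. with one controller shared across layers.
# `pipe`, `controller`, and the label format are assumptions carried over from the
# sketch above.
#
#     for i, block in enumerate(pipe.transformer.single_transformer_blocks):
#         block.forward = single_transformer_forward(
#             block, controller, place_in_transformer=f"single_{i}"
#         )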