# ReFlex/src/transformer_utils/transformer_utils.py
import abc

import torch
from diffusers.models.transformers.transformer_flux import (
    FluxSingleTransformerBlock, FluxTransformerBlock)

from .flux_transformer_forward import (joint_transformer_forward,
                                       single_transformer_forward)


class FeatureCollector:
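    """Wires a FeatureControl controller into a Flux transformer by
    monkey-patching the forward pass of every joint and single
    transformer block."""
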
    def __init__(self, transformer, controller, layer_list=None):
        self.transformer = transformer
        self.controller = controller
        # Avoid a mutable default argument: each instance gets its own list.
        self.layer_list = layer_list if layer_list is not None else []

    def register_transformer_control(self):
        # Patch the joint (double-stream) blocks first, then the
        # single-stream blocks, numbering all blocks with one continuous
        # index so the controller can locate itself in the full stack.
        index = 0
        for joint_transformer in self.transformer.transformer_blocks:
            place_in_transformer = f'joint_{index}'
            joint_transformer.forward = joint_transformer_forward(
                joint_transformer, self.controller, place_in_transformer)
            index += 1
        for single_transformer in self.transformer.single_transformer_blocks:
            place_in_transformer = f'single_{index}'
            single_transformer.forward = single_transformer_forward(
                single_transformer, self.controller, place_in_transformer)
            index += 1
        self.controller.num_layers = index

    def restore_orig_transformer(self):
        # Re-patch every block without a controller so each block runs
        # its original, uncontrolled forward pass.
        place_in_transformer = ''
        for joint_transformer in self.transformer.transformer_blocks:
            joint_transformer.forward = joint_transformer_forward(
                joint_transformer, None, place_in_transformer)
        for single_transformer in self.transformer.single_transformer_blocks:
            single_transformer.forward = single_transformer_forward(
                single_transformer, None, place_in_transformer)


class FeatureControl(abc.ABC):
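    """Abstract controller called by the patched transformer blocks.
    Tracks the current denoising step and layer index, and routes each
    block's hidden states through forward()."""
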
    def __init__(self):
        self.cur_step = 0
        self.num_layers = -1
        self.cur_layer = 0

    def step_callback(self, x_t):
        return x_t

    def between_steps(self):
        return

    @abc.abstractmethod
    def forward(self, hidden_states, place_in_transformer: str):
        raise NotImplementedError

    @torch.no_grad()
    def __call__(self, hidden_states, place_in_transformer: str):
        hidden_states = self.forward(hidden_states, place_in_transformer)
        self.cur_layer += 1
        # After the last layer of a step, wrap around and advance the step.
        if self.cur_layer == self.num_layers:
            self.cur_layer = 0
            self.cur_step += 1
            self.between_steps()
        return hidden_states

    def reset(self):
        self.cur_step = 0
        self.cur_layer = 0


class FeatureReplace(FeatureControl):
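    """Controller that copies the source branch's image features onto the
    target branch in the selected layers during the first `feature_steps`
    denoising steps (batch index 0 is treated as the source)."""
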
    def __init__(self, layer_list=None, feature_steps=7):
        super().__init__()
        # Avoid a mutable default argument for the shared layer list.
        self.layer_list = layer_list if layer_list is not None else []
        self.feature_steps = feature_steps

    def forward(self, hidden_states, place_in_transformer):
        layer_index = int(place_in_transformer.split('_')[-1])
        # Only act in the selected layers and during the first
        # `feature_steps` denoising steps.
        if layer_index not in self.layer_list or self.cur_step >= self.feature_steps:
            return hidden_states
        hs_dim = hidden_states.shape[1]
        t5_dim = 512       # number of T5 text tokens
        latent_dim = 4096  # number of image latent tokens
        attn_dim = t5_dim + latent_dim
        index_all = torch.arange(attn_dim)
        t5_index, latent_index = index_all.split([t5_dim, latent_dim])
        mask = torch.ones(hs_dim).to(device=hidden_states.device, dtype=hidden_states.dtype)
        if 'single' in place_in_transformer:
            # Single blocks operate on the concatenated [text, image]
            # sequence: zero the text positions so only image latents
            # are replaced.
            mask[t5_index] = 0
        mask = mask[None, :, None]
        # Copy the source features (batch index 0) into every target
        # sample wherever the mask is 1.
        source_hs = hidden_states[:1]
        target_hs = hidden_states[1:]
        target_hs = source_hs * mask + target_hs * (1 - mask)
        hidden_states[1:] = target_hs
        return hidden_states
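

if __name__ == '__main__':
    # A minimal usage sketch, not part of the original module: it assumes a
    # stock diffusers FluxPipeline (the checkpoint name and the layer
    # choices below are illustrative only) and a batch laid out with the
    # source prompt at index 0 and the target prompt at index 1, which is
    # what FeatureReplace.forward expects.
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained(
        'black-forest-labs/FLUX.1-dev', torch_dtype=torch.bfloat16
    ).to('cuda')

    controller = FeatureReplace(layer_list=[20, 21, 22], feature_steps=7)
    collector = FeatureCollector(pipe.transformer, controller)
    collector.register_transformer_control()

    images = pipe(prompt=['a photo of a cat', 'a photo of a dog']).images

    controller.reset()
    collector.restore_orig_transformer()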