Spaces:
Running
on
Zero
Running
on
Zero
import abc | |
import types | |
import torch | |
from diffusers.models.transformers.transformer_flux import ( | |
FluxSingleTransformerBlock, FluxTransformerBlock) | |
from .flux_transformer_forward import (joint_transformer_forward, | |
single_transformer_forward) | |
class FeatureCollector:
    """Route every block of a Flux transformer through a controller.

    ``register_transformer_control`` replaces ``forward`` on each block in
    ``transformer.transformer_blocks`` and
    ``transformer.single_transformer_blocks``. The replacement is the callable
    returned by ``joint_transformer_forward`` / ``single_transformer_forward``.
    ``restore_orig_transformer`` rebuilds those callables with a ``None``
    controller.
    """

    def __init__(self, transformer, controller, layer_list=None):
        """Store the model, the controller and an optional layer list.

        Args:
            transformer: object exposing ``transformer_blocks`` and
                ``single_transformer_blocks`` (presumably a diffusers Flux
                transformer model -- confirm against the caller).
            controller: object passed to the forward factories. Its
                ``num_layers`` attribute is set by
                ``register_transformer_control``.
            layer_list: layer indices kept on the collector. Defaults to a
                fresh empty list. A literal ``[]`` default would be a single
                list object shared by every instance.
        """
        self.transformer = transformer
        self.controller = controller
        self.layer_list = [] if layer_list is None else layer_list

    def register_transformer_control(self):
        """Install controller-aware forwards on all joint and single blocks.

        Layers are numbered with ONE running counter. Joint blocks are named
        ``joint_0 .. joint_{J-1}`` and single blocks continue at
        ``single_J, single_{J+1}, ...``. The total number of hooked blocks is
        written to ``self.controller.num_layers``.
        """
        index = 0
        for joint_block in self.transformer.transformer_blocks:
            joint_block.forward = joint_transformer_forward(
                joint_block, self.controller, f'joint_{index}')
            index += 1
        # The counter is deliberately not reset: single-block names continue
        # the joint numbering so every hooked layer has a unique integer suffix.
        for single_block in self.transformer.single_transformer_blocks:
            single_block.forward = single_transformer_forward(
                single_block, self.controller, f'single_{index}')
            index += 1
        self.controller.num_layers = index

    def restore_orig_transformer(self):
        """Rebuild every block's ``forward`` with a ``None`` controller.

        NOTE(review): whether this is equivalent to the block's original
        ``forward`` depends on how the factories in
        ``flux_transformer_forward`` treat ``controller=None``. Confirm there.
        """
        place_in_transformer = ''
        for joint_block in self.transformer.transformer_blocks:
            joint_block.forward = joint_transformer_forward(
                joint_block, None, place_in_transformer)
        for single_block in self.transformer.single_transformer_blocks:
            single_block.forward = single_transformer_forward(
                single_block, None, place_in_transformer)
class FeatureControl(abc.ABC):
    """Base hook that counts per-layer calls and groups them into steps.

    Subclasses override ``forward``. Each ``__call__`` runs ``forward`` and
    advances ``cur_layer``. Once ``num_layers`` calls have been made,
    ``cur_layer`` wraps to 0, ``cur_step`` advances and ``between_steps``
    runs.
    """

    def __init__(self):
        # num_layers stays -1 until set externally. While it is -1 the
        # wrap-around check in __call__ never matches, so cur_step does not
        # advance.
        self.num_layers = -1
        self.cur_step = 0
        self.cur_layer = 0

    def step_callback(self, x_t):
        """Identity hook: return ``x_t`` unchanged."""
        return x_t

    def between_steps(self):
        """Hook run after each completed step; a no-op by default."""
        return None

    def forward(self, attn, place_in_transformer: str):
        """Transform one layer's value; must be provided by subclasses."""
        raise NotImplementedError

    def __call__(self, hidden_state, place_in_transformer: str):
        """Run ``forward`` for one layer, update the counters, return the result."""
        result = self.forward(hidden_state, place_in_transformer)
        self.cur_layer += 1
        if self.cur_layer != self.num_layers:
            return result
        # Last layer of this step: wrap the layer counter, advance the step,
        # then fire the hook.
        self.cur_layer = 0
        self.cur_step += 1
        self.between_steps()
        return result

    def reset(self):
        """Zero the step and layer counters (``num_layers`` is kept)."""
        self.cur_step = self.cur_layer = 0
class FeatureReplace(FeatureControl):
    """Overwrite batch items 1.. with batch item 0 on selected layers and steps.

    The replacement happens when the numeric suffix of
    ``place_in_transformer`` is in ``layer_list`` and
    ``cur_step in range(feature_steps)``. Then every batch item after the
    first is overwritten in place from item 0:

    * non-``single`` layers (e.g. ``joint_*``): the whole sequence is replaced;
    * ``single`` layers: the first ``t5_dim`` sequence positions are kept and
      only the remaining positions are replaced.
    """

    def __init__(self, layer_list=None, feature_steps=7, t5_dim=512):
        """Configure which layers and steps to act on.

        Args:
            layer_list: integer layer indices (the numeric suffix of
                ``place_in_transformer``) to act on. Defaults to a fresh empty
                list, meaning no layer is modified. A literal ``[]`` default
                would be shared across instances.
            feature_steps: act only while ``cur_step`` is in
                ``range(0, feature_steps)``.
            t5_dim: number of leading sequence positions left untouched on
                ``single`` layers. The default 512 is the value previously
                hard-coded here. Presumably it is the T5 text-token count;
                confirm it matches the pipeline's max sequence length.
        """
        super().__init__()
        self.layer_list = [] if layer_list is None else layer_list
        self.feature_steps = feature_steps
        self.t5_dim = t5_dim

    def forward(self, hidden_states, place_in_transformer):
        """Blend batch items 1.. toward item 0 in place and return the tensor.

        The sequence length is read from ``shape[1]``, and the mask is
        broadcast as ``[None, :, None]``. This implies ``hidden_states`` is
        rank 3 with the sequence on dim 1 (presumably (batch, seq, dim) --
        TODO confirm).

        Raises:
            ValueError: on a ``single`` layer whose sequence is shorter than
                ``t5_dim``, which breaks the assumed token layout.
        """
        layer_index = int(place_in_transformer.split('_')[-1])
        if (layer_index not in self.layer_list
                or self.cur_step not in range(0, self.feature_steps)):
            return hidden_states

        seq_len = hidden_states.shape[1]
        # mask == 1 -> take item 0 (source); mask == 0 -> keep the item's own value.
        mask = torch.ones(seq_len, device=hidden_states.device, dtype=hidden_states.dtype)
        if 'single' in place_in_transformer:
            if seq_len < self.t5_dim:
                raise ValueError(
                    f'{place_in_transformer}: sequence length {seq_len} is shorter '
                    f'than t5_dim={self.t5_dim}')
            # Keep the leading t5_dim positions; replace only what follows
            # (the original comment: "Only use image latent").
            mask[:self.t5_dim] = 0
        mask = mask[None, :, None]

        source_hs = hidden_states[:1]
        target_hs = hidden_states[1:]
        # Arithmetic blend (rather than torch.where) keeps results identical
        # to the previous implementation.
        hidden_states[1:] = source_hs * mask + target_hs * (1 - mask)
        return hidden_states