import abc
import gc
import math
import numbers
from collections import defaultdict
from difflib import SequenceMatcher
from typing import Dict, List, Optional, Tuple, Union

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.attention_processor import FluxAttnProcessor2_0
from PIL import Image
from scipy.ndimage import binary_dilation
from skimage.filters import threshold_otsu
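
# NOTE: everything below assumes FLUX-style joint attention where the token
# sequence is [text tokens, latent tokens] and the source and target prompts
# are denoised together in one batch (source first, then target); see
# AttentionControl.get_model_info() for the assumed dimensions
# (512 T5 tokens, 4096 latent patches, e.g. a 1024x1024 image).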
class AttentionControl(abc.ABC):

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0
        self.get_model_info()

    def get_model_info(self):
        # 512 T5 text tokens followed by 4096 latent tokens (64x64 patches).
        t5_dim = 512
        latent_dim = 4096
        attn_dim = t5_dim + latent_dim
        index_all = torch.arange(attn_dim)
        t5_index, latent_index = index_all.split([t5_dim, latent_dim])
        patch_order = ['t5', 'latent']
        self.model_info = {
            't5_dim': t5_dim,
            'latent_dim': latent_dim,
            'attn_dim': attn_dim,
            't5_index': t5_index,
            'latent_index': latent_index,
            'patch_order': patch_order,
        }
    def step_callback(self, x_t):
        return x_t

    def between_steps(self):
        return

    def forward(self, q, k, v, place_in_transformer: str):
        raise NotImplementedError

    def __call__(self, q, k, v, place_in_transformer: str):
        hs = self.forward(q, k, v, place_in_transformer)
        # Advance the step counter once every attention layer has been visited.
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers:
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return hs

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0
    def scaled_dot_product_attention(self, query, key, value, attn_mask=None,
                                     dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
        # Same math as F.scaled_dot_product_attention, but returns the attention
        # weights themselves (not the attended values) so they can be edited.
        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
        attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
        if is_causal:
            assert attn_mask is None
            temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
            attn_bias = attn_bias.to(query.dtype)
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask
        attn_weight = query @ key.transpose(-2, -1) * scale_factor
        attn_weight += attn_bias
        attn_weight = torch.softmax(attn_weight, dim=-1)
        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
        return attn_weight
    def split_attn(self, attn, q='latent', k='latent'):
        # Slice one (query-block, key-block) sub-matrix out of the joint
        # attention map, e.g. ('latent', 't5') is the latent->text block.
        patch_order = self.model_info['patch_order']
        t5_dim = self.model_info['t5_dim']
        latent_dim = self.model_info['latent_dim']
        clip_dim = self.model_info.get('clip_dim', None)
        idx_q = patch_order.index(q)
        idx_k = patch_order.index(k)
        split = [t5_dim, latent_dim]
        return attn.split(split, dim=-2)[idx_q].split(split, dim=-1)[idx_k].clone()
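
# split_attn() above slices one block out of the joint attention map. With the
# dimensions from get_model_info() (512 text + 4096 latent tokens), a map of
# shape (B, H, 4608, 4608) decomposes as:
#
#   split_attn(attn, 't5',     't5')     -> attn[..., :512, :512]   text -> text
#   split_attn(attn, 't5',     'latent') -> attn[..., :512, 512:]   text -> latent
#   split_attn(attn, 'latent', 't5')     -> attn[..., 512:, :512]   latent -> text ("cross-attention")
#   split_attn(attn, 'latent', 'latent') -> attn[..., 512:, 512:]   latent -> latent ("self-attention")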
class AttentionAdapter(AttentionControl):

    def __init__(
        self,
        ca_layer_list=list(range(13, 45)),
        sa_layer_list=list(range(22, 45)),
        method='replace_topk',
        topk=1,
        text_scale=1,
        mappers=None,
        alphas=None,
        ca_steps=10,
        sa_steps=7,
        save_source_ca=False,
        save_target_ca=False,
        use_sa_replace=True,
        attn_adj_from=0,
        z_value=1.0,  # only used by method='replace_z'; this default is an assumption
    ):
        super().__init__()
        self.ca_layer_list = ca_layer_list
        self.sa_layer_list = sa_layer_list
        self.method = method
        self.topk = topk
        self.text_scale = text_scale
        self.use_sa_replace = use_sa_replace
        self.ca_steps = ca_steps
        self.sa_steps = sa_steps
        self.mappers = mappers
        self.alphas = alphas
        self.save_source_ca = save_source_ca
        self.save_target_ca = save_target_ca
        self.attn_adj_from = attn_adj_from
        self.z_value = z_value
        self.source_ca = None
        self.source_attn = {}
    @staticmethod
    def get_empty_store():
        return defaultdict(list)
    def refine_ca(self, source_ca, target_ca):
        # Blend source and target text attention per token: alphas ~ 1 keeps the
        # source column, alphas ~ 0 keeps the (scaled) target column.
        source_ca_replace = source_ca[:, :, self.mappers].permute(2, 0, 1, 3)
        new_ca = source_ca_replace * self.alphas + target_ca * (1 - self.alphas) * self.text_scale
        return new_ca

    def replace_ca(self, source_ca, target_ca):
        # Redistribute the source text-attention columns onto the target tokens.
        new_ca = torch.einsum('hpw,bwn->bhpn', source_ca, self.mappers)
        return new_ca
    def get_index_from_source(self, attn, topk):
        if self.method == 'replace_topk':
            # Mark entries that exceed the k-th largest value in each row.
            sa_max = torch.topk(attn, k=topk, dim=-1)[0][..., [-1]]
            idx_from_source = (attn > sa_max)
        elif self.method == 'replace_z':
            # Mark entries whose log-attention lies z_value std-devs above the row mean.
            log_attn = torch.log(attn)
            idx_from_source = log_attn > (log_attn.mean(-1, keepdim=True) + self.z_value * log_attn.std(-1, keepdim=True))
        else:
            raise ValueError(f"Unknown method: {self.method}")
        return idx_from_source
    def forward(self, q, k, v, place_in_transformer: str):
        layer_index = int(place_in_transformer.split('_')[-1])
        use_ca_replace = False
        use_sa_replace = False
        if (layer_index in self.ca_layer_list) and (self.cur_step < self.ca_steps):
            if self.mappers is not None:
                use_ca_replace = True
        if (layer_index in self.sa_layer_list) and (self.cur_step < self.sa_steps):
            use_sa_replace = True
        if not (use_sa_replace or use_ca_replace):
            # Nothing to edit at this layer/step: fall back to standard attention.
            return F.scaled_dot_product_attention(q, k, v)

        latent_index = self.model_info['latent_index']
        t5_index = self.model_info['t5_index']
        clip_index = self.model_info.get('clip_index', None)

        # Compute the explicit attention map so its blocks can be edited.
        attn = self.scaled_dot_product_attention(q, k, v)
        source_attn = attn[:1]   # batch 0: source (reference) prompt
        target_attn = attn[1:]   # remaining batches: target (edited) prompt(s)
        source_hs = source_attn @ v[:1]
        source_ca = self.split_attn(source_attn, 'latent', 't5')
        target_ca = self.split_attn(target_attn, 'latent', 't5')
        if use_ca_replace:
            # Optionally accumulate row-normalized cross-attention maps across layers.
            if self.save_source_ca:
                if layer_index == self.ca_layer_list[0]:
                    self.source_ca = source_ca / source_ca.sum(dim=-1, keepdim=True)
                else:
                    self.source_ca += source_ca / source_ca.sum(dim=-1, keepdim=True)
            if self.save_target_ca:
                if layer_index == self.ca_layer_list[0]:
                    self.target_ca = target_ca / target_ca.sum(dim=-1, keepdim=True)
                else:
                    self.target_ca += target_ca / target_ca.sum(dim=-1, keepdim=True)
            # Inject the source cross-attention into the target branch.
            if self.alphas is None:
                target_ca = self.replace_ca(source_ca[0], target_ca)
            else:
                target_ca = self.refine_ca(source_ca[0], target_ca)
        target_sa = self.split_attn(target_attn, 'latent', 'latent')
        if use_sa_replace:
            if self.cur_step < self.attn_adj_from:
                topk = 1
            else:
                topk = self.topk
            if self.method == 'base':
                new_sa = self.split_attn(target_attn, 'latent', 'latent')
            else:
                source_sa = self.split_attn(source_attn, 'latent', 'latent')
                if topk <= 1:
                    # Plain replacement: copy the source self-attention as-is.
                    new_sa = source_sa.clone().repeat(len(target_attn), 1, 1, 1)
                else:
                    idx_from_source = self.get_index_from_source(source_sa, topk)
                    # Keep the target's attention values at the source's top-k positions ...
                    new_sa = target_sa.clone()
                    new_sa.mul_(idx_from_source)
                    # ... renormalize them to the mass the source assigns there ...
                    new_sa.div_(new_sa.sum(-1, keepdim=True))
                    new_sa.nan_to_num_(0)
                    new_sa.mul_((source_sa * idx_from_source).sum(-1, keepdim=True))
                    # ... and take the remaining attention values from the source self-attention.
                    new_sa.add_(source_sa * idx_from_source.logical_not())
                    # Additional normalization (optional):
                    # new_sa.mul_(target_sa.sum(dim=-1, keepdim=True) / new_sa.sum(dim=-1, keepdim=True))
            target_sa = new_sa
        # Recombine the edited blocks into hidden states.
        target_l_to_l = target_sa @ v[1:, :, latent_index]
        target_l_to_t = target_ca @ v[1:, :, t5_index]
        if self.alphas is None:
            target_latent_hs = target_l_to_l + target_l_to_t * self.text_scale
        else:
            # Text scaling was already applied inside self.refine_ca().
            target_latent_hs = target_l_to_l + target_l_to_t
        target_text_hs = target_attn[:, :, t5_index, :] @ v[1:]
        target_hs = torch.cat([target_text_hs, target_latent_hs], dim=-2)
        hs = torch.cat([source_hs, target_hs])
        return hs
    def reset(self):
        super().reset()
        del self.source_attn
        gc.collect()
        torch.cuda.empty_cache()
        self.source_attn = {}
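
# The mappers/alphas passed to AttentionAdapter describe how target-prompt
# tokens relate to source-prompt tokens (prompt-to-prompt style), as implied by
# how they are consumed above:
#   * replace_ca(): mappers is a float tensor of shape (B_tgt, W_src, W_tgt)
#     that redistributes source text-attention columns onto target tokens.
#   * refine_ca():  mappers is an integer index tensor (B_tgt, W_tgt) mapping
#     each target token to a source token, and alphas is a broadcastable blend
#     weight (~1 where a token is shared between the two prompts).
# A minimal sketch of building an index mapper with difflib.SequenceMatcher;
# this is an illustration under those assumptions, not the repo's own builder:
#
#   def build_index_mapper(src_ids, tgt_ids, length=512):
#       mapper = list(range(length))              # default: keep the same column
#       alphas = [0.0] * length
#       for blk in SequenceMatcher(None, src_ids, tgt_ids).get_matching_blocks():
#           for off in range(blk.size):           # shared tokens copy the source column
#               mapper[blk.b + off] = blk.a + off
#               alphas[blk.b + off] = 1.0
#       return torch.tensor([mapper]), torch.tensor([alphas], dtype=torch.float32)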
class AttnCollector:

    def __init__(self, transformer, controller, attn_processor_class, layer_list=None):
        self.transformer = transformer
        self.controller = controller
        self.attn_processor_class = attn_processor_class
    def restore_orig_attention(self):
        # Re-install processors with no controller attached.
        attn_procs = {}
        place = ''
        for name in self.transformer.attn_processors.keys():
            attn_procs[name] = self.attn_processor_class(
                controller=None, place_in_transformer=place,
            )
        self.transformer.set_attn_processor(attn_procs)
        self.controller.num_att_layers = 0
    def register_attention_control(self):
        # Wrap every attention layer so it routes q/k/v through the controller.
        attn_procs = {}
        count = 0
        for i, name in enumerate(self.transformer.attn_processors.keys()):
            if 'single' in name:
                place = f'single_{i}'
            else:
                place = f'joint_{i}'
            count += 1
            attn_procs[name] = self.attn_processor_class(
                controller=self.controller, place_in_transformer=place,
            )
        self.transformer.set_attn_processor(attn_procs)
        # The controller advances its step counter after this many attention calls.
        self.controller.num_att_layers = count
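
# Example wiring (illustrative sketch only). It assumes a custom attention
# processor class exists elsewhere in this repo that accepts
# (controller=..., place_in_transformer=...) and forwards q/k/v to the
# controller; the name `ControlledFluxAttnProcessor` below is a placeholder,
# as are `pipe`, `mappers`, and `alphas`.
#
#   controller = AttentionAdapter(mappers=mappers, alphas=alphas, topk=3)
#   collector = AttnCollector(
#       transformer=pipe.transformer,
#       controller=controller,
#       attn_processor_class=ControlledFluxAttnProcessor,
#   )
#   collector.register_attention_control()   # every layer now calls the controller
#   # ... run the paired (source, target) denoising loop, batch = [source, target] ...
#   collector.restore_orig_attention()       # put back plain processors
#   controller.reset()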