import os
import random
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D

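# Overview: a reference UNet is run in "write" mode so that every BasicTransformerBlock
# caches its normalized hidden states (and, when a `c_text_proj_layer` is present, the
# encoder hidden states) in a per-module `bank`. The denoising UNet is then run in "read"
# mode, where those cached features are concatenated into the context of self-attention
# and, when `reference_adain` is enabled, used to re-normalize feature statistics
# AdaIN-style.
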
def torch_dfs(model: torch.nn.Module):
    """Return `model` and all of its sub-modules in depth-first order."""
    result = [model]
    for child in model.children():
        result += torch_dfs(child)
    return result


class ReferenceAttentionControl:
    """Hooks a UNet so it either writes reference features into per-module banks ("write")
    or reads them back while denoising ("read")."""

    def __init__(
        self,
        unet=None,
        mode="write",
        do_classifier_free_guidance=False,
        attention_auto_machine_weight=float("inf"),
        gn_auto_machine_weight=1.0,
        style_fidelity=1.0,
        reference_attn=True,
        reference_adain=False,
        fusion_blocks="full",
        batch_size=1,
        is_train=False,
        is_second_stage=False,
        use_jointcond=False,
    ) -> None:
        self.unet = unet
        assert mode in ["read", "write"]
        assert fusion_blocks in ["midup", "full"]
        self.reference_attn = reference_attn
        self.reference_adain = reference_adain
        self.fusion_blocks = fusion_blocks
        self.batch_size = batch_size
        self.is_train = is_train
        self.is_second_stage = is_second_stage
        self.add_clothing_text = getattr(unet, "add_clothing_text", False)
        self.do_classifier_free_guidance = do_classifier_free_guidance
        self.use_jointcond = use_jointcond

        self.register_reference_hooks(
            mode,
            do_classifier_free_guidance,
            attention_auto_machine_weight,
            gn_auto_machine_weight,
            style_fidelity,
            reference_attn,
            reference_adain,
            # Pass as a keyword: positionally this argument would bind to `dtype`.
            fusion_blocks=fusion_blocks,
            batch_size=batch_size,
            is_train=is_train,
            is_second_stage=is_second_stage,
            add_clothing_text=self.add_clothing_text,
            use_jointcond=self.use_jointcond,
        )

    def register_reference_hooks(
        self,
        mode,
        do_classifier_free_guidance,
        attention_auto_machine_weight,
        gn_auto_machine_weight,
        style_fidelity,
        reference_attn,
        reference_adain,
        dtype=torch.float16,
        batch_size=1,
        num_images_per_prompt=1,
        device=torch.device("cpu"),
        fusion_blocks="full",
        is_train=False,
        is_second_stage=False,
        add_clothing_text=False,
        use_jointcond=False,
    ):
        # MODE is closed over by the hacked forward functions defined below.
        MODE = mode

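        # Mask selecting the unconditional half of the batch for classifier-free guidance.
        # The factor of 16 presumably matches the number of samples (e.g. frames) handled per
        # prompt in this pipeline; that is an assumption based on how the mask is consumed
        # below, and the read hook rebuilds a correctly sized mask if the runtime batch
        # does not match.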
        if do_classifier_free_guidance:
            uc_mask = (
                torch.Tensor(
                    [1] * batch_size * num_images_per_prompt * 16
                    + [0] * batch_size * num_images_per_prompt * 16
                )
                .to(device)
                .bool()
            )
        else:
            uc_mask = (
                torch.Tensor([0] * batch_size * num_images_per_prompt * 2)
                .to(device)
                .bool()
            )

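        # Replacement forward for diffusers' BasicTransformerBlock. In "write" mode it runs the
        # block as usual but caches the pre-attention hidden states in `self.bank`; in "read"
        # mode it concatenates the cached reference features onto the self-attention context,
        # finishes the block (cross-attention + feed-forward), and returns early.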
        def hacked_basic_transformer_inner_forward(
            self,
            hidden_states: torch.FloatTensor,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
            timestep: Optional[torch.LongTensor] = None,
            cross_attention_kwargs: Dict[str, Any] = None,
            class_labels: Optional[torch.LongTensor] = None,
        ):
            if self.use_ada_layer_norm:
                norm_hidden_states = self.norm1(hidden_states, timestep)
            elif self.use_ada_layer_norm_zero:
                norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                    hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
                )
            else:
                norm_hidden_states = self.norm1(hidden_states)

            cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
            if self.only_cross_attention:
                attn_output = self.attn1(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                    attention_mask=attention_mask,
                    **cross_attention_kwargs,
                )
            else:
                if MODE == "write":
                    self.bank.append(norm_hidden_states.clone())
                    if getattr(self, "c_text_proj_layer", None):
                        self.bank.append(encoder_hidden_states.clone())
                    attn_output = self.attn1(
                        norm_hidden_states,
                        encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                        attention_mask=attention_mask,
                        **cross_attention_kwargs,
                    )

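                # "read" pass: attend over the block's own tokens concatenated with the cached
                # reference tokens. When an auxiliary reference branch is present (`c_attn1`
                # with a learned `gate_val`), its gated output is added on top of plain
                # self-attention instead.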
                if MODE == "read":
                    if getattr(self, "c_text_proj_layer", None):
                        c_text = self.bank[-1]
                        if c_text.shape[-1] == 2048:
                            self.bank[-1] = self.c_text_proj_layer(c_text)

                    if getattr(self, "c_attn1", None):
                        hidden_states_uc_p = self.attn1(
                            norm_hidden_states,
                            encoder_hidden_states=torch.cat([norm_hidden_states], dim=1),
                            attention_mask=attention_mask,
                        )
                        hidden_states_uc_c = self.c_attn1(
                            norm_hidden_states,
                            encoder_hidden_states=torch.cat(self.bank, dim=1),
                            attention_mask=attention_mask,
                        )
                        hidden_states_uc = (
                            hidden_states_uc_p
                            + hidden_states_uc_c * self.gate_val.to(dtype=hidden_states_uc_c.dtype).tanh()
                            + hidden_states
                        )
                    else:
                        hidden_states_uc = self.attn1(
                            norm_hidden_states,
                            encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1),
                            attention_mask=attention_mask,
                        ) + hidden_states

                    hidden_states_c = hidden_states_uc.clone()
                    if use_jointcond:
                        hidden_states = hidden_states_uc
                    else:
                        if is_train and not is_second_stage:
                            _uc_mask = self.cfg_uc_mask.clone()
                            assert hidden_states.shape[0] == _uc_mask.shape[0], (
                                "In training, cfg_uc_mask is used to drop the reference images, "
                                f"so the batch sizes must match: {hidden_states.shape[0]} vs {_uc_mask.shape[0]}"
                            )
                        else:
                            _uc_mask = uc_mask.clone()

                        if do_classifier_free_guidance and torch.any(_uc_mask) and not is_second_stage:
                            if hidden_states.shape[0] != _uc_mask.shape[0]:
                                _uc_mask = (
                                    torch.Tensor(
                                        [1] * (hidden_states.shape[0] // 2)
                                        + [0] * (hidden_states.shape[0] // 2)
                                    )
                                    .to(device)
                                    .bool()
                                )
                            # For the masked (unconditional / dropped-reference) samples, redo
                            # self-attention without the reference bank.
                            hidden_states_c[_uc_mask] = (
                                self.attn1(
                                    norm_hidden_states[_uc_mask],
                                    encoder_hidden_states=norm_hidden_states[_uc_mask],
                                    attention_mask=attention_mask,
                                )
                                + hidden_states[_uc_mask]
                            )
                        hidden_states = hidden_states_c.clone()

                    if self.attn2 is not None:
                        norm_hidden_states = (
                            self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
                        )
                        hidden_states = (
                            self.attn2(
                                norm_hidden_states,
                                encoder_hidden_states=encoder_hidden_states,
                                attention_mask=attention_mask,
                            )
                            + hidden_states
                        )

                    hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
                    return hidden_states

            if self.use_ada_layer_norm_zero:
                attn_output = gate_msa.unsqueeze(1) * attn_output
            hidden_states = attn_output + hidden_states

            if self.attn2 is not None:
                norm_hidden_states = (
                    self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
                )

                attn_output = self.attn2(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=encoder_attention_mask,
                    **cross_attention_kwargs,
                )
                hidden_states = attn_output + hidden_states

            norm_hidden_states = self.norm3(hidden_states)

            if self.use_ada_layer_norm_zero:
                norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

            ff_output = self.ff(norm_hidden_states)

            if self.use_ada_layer_norm_zero:
                ff_output = gate_mlp.unsqueeze(1) * ff_output

            hidden_states = ff_output + hidden_states

            return hidden_states

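        # The hooks below implement the reference-AdaIN path: in "write" mode each block stores
        # the per-channel mean/variance of its output; in "read" mode the current features are
        # re-normalized to the accumulated reference statistics and blended via `style_fidelity`.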
        def hacked_mid_forward(self, *args, **kwargs):
            eps = 1e-6
            x = self.original_forward(*args, **kwargs)
            if MODE == "write":
                if gn_auto_machine_weight >= self.gn_weight:
                    var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
                    self.mean_bank.append(mean)
                    self.var_bank.append(var)
            if MODE == "read":
                if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
                    var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
                    std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
                    mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))
                    var_acc = sum(self.var_bank) / float(len(self.var_bank))
                    std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
                    x_uc = (((x - mean) / std) * std_acc) + mean_acc
                    x_c = x_uc.clone()
                    if do_classifier_free_guidance and style_fidelity > 0:
                        x_c[uc_mask] = x[uc_mask]
                    x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc
                self.mean_bank = []
                self.var_bank = []
            return x

        def hack_CrossAttnDownBlock2D_forward(
            self,
            hidden_states: torch.FloatTensor,
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
        ):
            eps = 1e-6

            output_states = ()

            for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
                if MODE == "write":
                    if gn_auto_machine_weight >= self.gn_weight:
                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
                        self.mean_bank.append([mean])
                        self.var_bank.append([var])
                if MODE == "read":
                    if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
                        hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
                        hidden_states_c = hidden_states_uc.clone()
                        if do_classifier_free_guidance and style_fidelity > 0:
                            hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)
                        hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc

                output_states = output_states + (hidden_states,)

            if MODE == "read":
                self.mean_bank = []
                self.var_bank = []

            if self.downsamplers is not None:
                for downsampler in self.downsamplers:
                    hidden_states = downsampler(hidden_states)

                output_states = output_states + (hidden_states,)

            return hidden_states, output_states

        def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
            eps = 1e-6

            output_states = ()

            for i, resnet in enumerate(self.resnets):
                hidden_states = resnet(hidden_states, temb)

                if MODE == "write":
                    if gn_auto_machine_weight >= self.gn_weight:
                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
                        self.mean_bank.append([mean])
                        self.var_bank.append([var])
                if MODE == "read":
                    if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
                        hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
                        hidden_states_c = hidden_states_uc.clone()
                        if do_classifier_free_guidance and style_fidelity > 0:
                            hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)
                        hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc

                output_states = output_states + (hidden_states,)

            if MODE == "read":
                self.mean_bank = []
                self.var_bank = []

            if self.downsamplers is not None:
                for downsampler in self.downsamplers:
                    hidden_states = downsampler(hidden_states)

                output_states = output_states + (hidden_states,)

            return hidden_states, output_states

        def hacked_CrossAttnUpBlock2D_forward(
            self,
            hidden_states: torch.FloatTensor,
            res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            upsample_size: Optional[int] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
        ):
            eps = 1e-6

            for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]

                if MODE == "write":
                    if gn_auto_machine_weight >= self.gn_weight:
                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
                        self.mean_bank.append([mean])
                        self.var_bank.append([var])
                if MODE == "read":
                    if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
                        hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
                        hidden_states_c = hidden_states_uc.clone()
                        if do_classifier_free_guidance and style_fidelity > 0:
                            hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)
                        hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc

            if MODE == "read":
                self.mean_bank = []
                self.var_bank = []

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            return hidden_states

        def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
            eps = 1e-6
            for i, resnet in enumerate(self.resnets):
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
                hidden_states = resnet(hidden_states, temb)

                if MODE == "write":
                    if gn_auto_machine_weight >= self.gn_weight:
                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
                        self.mean_bank.append([mean])
                        self.var_bank.append([var])
                if MODE == "read":
                    if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
                        hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
                        hidden_states_c = hidden_states_uc.clone()
                        if do_classifier_free_guidance and style_fidelity > 0:
                            hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)
                        hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc

            if MODE == "read":
                self.mean_bank = []
                self.var_bank = []

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            return hidden_states

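        # Bind the hacked forwards onto the target modules. Each BasicTransformerBlock gets an
        # empty `bank`; during training a per-sample boolean mask drops the reference features
        # for roughly 10% of the batch so that classifier-free guidance remains usable.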
        if self.reference_attn:
            if self.fusion_blocks == "midup":
                attn_modules = [
                    module
                    for module in (torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks))
                    if isinstance(module, BasicTransformerBlock)
                ]
            elif self.fusion_blocks == "full":
                attn_modules = [
                    module
                    for module in (
                        torch_dfs(self.unet.down_blocks)
                        + torch_dfs(self.unet.mid_block)
                        + torch_dfs(self.unet.up_blocks)
                    )
                    if isinstance(module, BasicTransformerBlock)
                ]
            attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])

            if self.is_train:
                cfg_uc_mask = torch.BoolTensor([random.random() < 0.1 for _ in range(self.batch_size)])

            for i, module in enumerate(attn_modules):
                module._original_inner_forward = module.forward
                module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
                module.bank = []
                module.attn_weight = float(i) / float(len(attn_modules))

                if self.is_train:
                    module.cfg_uc_mask = cfg_uc_mask.clone()

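        # Group-norm (AdaIN) hooks: each block gets mean/var banks and a gn_weight that
        # decreases with depth on the down path and increases on the up path; writing is
        # gated by gn_auto_machine_weight >= gn_weight.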
        if self.reference_adain:
            gn_modules = [self.unet.mid_block]
            self.unet.mid_block.gn_weight = 0

            down_blocks = self.unet.down_blocks
            for w, module in enumerate(down_blocks):
                module.gn_weight = 1.0 - float(w) / float(len(down_blocks))
                gn_modules.append(module)

            up_blocks = self.unet.up_blocks
            for w, module in enumerate(up_blocks):
                module.gn_weight = float(w) / float(len(up_blocks))
                gn_modules.append(module)

            for i, module in enumerate(gn_modules):
                if getattr(module, "original_forward", None) is None:
                    module.original_forward = module.forward
                if i == 0:
                    # mid_block
                    module.forward = hacked_mid_forward.__get__(module, torch.nn.Module)
                elif isinstance(module, CrossAttnDownBlock2D):
                    module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D)
                elif isinstance(module, DownBlock2D):
                    module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D)
                elif isinstance(module, CrossAttnUpBlock2D):
                    module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D)
                elif isinstance(module, UpBlock2D):
                    module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D)
                module.mean_bank = []
                module.var_bank = []
                module.gn_weight *= 2

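    # Typical usage (an assumption about the surrounding pipeline, not enforced here): run the
    # reference UNet through its "write" controller, call reader.update(writer) to copy the
    # cached banks into the denoising UNet's "read" controller, denoise, then call clear()
    # before processing the next reference image.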
    def update(self, writer, dtype=torch.float16):
        """Copy the cached reference features from `writer` (write mode) into this reader."""
        if self.reference_attn:
            if self.fusion_blocks == "midup":
                reader_attn_modules = [
                    module
                    for module in (torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks))
                    if isinstance(module, BasicTransformerBlock)
                ]
                writer_attn_modules = [
                    module
                    for module in (torch_dfs(writer.unet.mid_block) + torch_dfs(writer.unet.up_blocks))
                    if isinstance(module, BasicTransformerBlock)
                ]
            elif self.fusion_blocks == "full":
                reader_attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock)]
                writer_attn_modules = [module for module in torch_dfs(writer.unet) if isinstance(module, BasicTransformerBlock)]
            reader_attn_modules = sorted(reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
            writer_attn_modules = sorted(writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])

            if self.is_train:
                cfg_uc_mask = torch.BoolTensor([random.random() < 0.1 for _ in range(self.batch_size)])

            for r, w in zip(reader_attn_modules, writer_attn_modules):
                r.bank = [v.clone().to(dtype) for v in w.bank]

                if self.is_train:
                    r.cfg_uc_mask = cfg_uc_mask.clone()

        if self.reference_adain:
            reader_gn_modules = [self.unet.mid_block]
            for module in self.unet.down_blocks:
                reader_gn_modules.append(module)
            for module in self.unet.up_blocks:
                reader_gn_modules.append(module)

            writer_gn_modules = [writer.unet.mid_block]
            for module in writer.unet.down_blocks:
                writer_gn_modules.append(module)
            for module in writer.unet.up_blocks:
                writer_gn_modules.append(module)

            for r, w in zip(reader_gn_modules, writer_gn_modules):
                if len(w.mean_bank) > 0 and isinstance(w.mean_bank[0], list):
                    r.mean_bank = [[v.clone().to(dtype) for v in vl] for vl in w.mean_bank]
                    r.var_bank = [[v.clone().to(dtype) for v in vl] for vl in w.var_bank]
                else:
                    r.mean_bank = [v.clone().to(dtype) for v in w.mean_bank]
                    r.var_bank = [v.clone().to(dtype) for v in w.var_bank]

    def clear(self):
        """Empty all cached reference banks so the next write pass starts fresh."""
        if self.reference_attn:
            if self.fusion_blocks == "midup":
                reader_attn_modules = [
                    module
                    for module in (torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks))
                    if isinstance(module, BasicTransformerBlock)
                ]
            elif self.fusion_blocks == "full":
                reader_attn_modules = [
                    module
                    for module in (
                        torch_dfs(self.unet.down_blocks)
                        + torch_dfs(self.unet.mid_block)
                        + torch_dfs(self.unet.up_blocks)
                    )
                    if isinstance(module, BasicTransformerBlock)
                ]
            reader_attn_modules = sorted(reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])

            for r in reader_attn_modules:
                r.bank.clear()

        if self.reference_adain:
            reader_gn_modules = [self.unet.mid_block]
            for module in self.unet.down_blocks:
                reader_gn_modules.append(module)
            for module in self.unet.up_blocks:
                reader_gn_modules.append(module)

            for r in reader_gn_modules:
                r.mean_bank.clear()
                r.var_bank.clear()