import abc
from distutils.version import LooseVersion

import torch
import torch.nn.functional as F

# Low-resource mode assumes the unconditional and conditional passes run
# separately; the controller then skips the unconditional attention layers
# instead of slicing them out of a combined batch (see AttentionControl).
LOW_RESOURCE = False


def cross_entropy2d(input, target, weight=None, size_average=True):
    """2D cross-entropy for dense prediction; pixels with a negative label are ignored."""
    n, c, h, w = input.size()
    # log_softmax over the class dimension (older torch versions lack the dim argument)
    if LooseVersion(torch.__version__) < LooseVersion('0.3'):
        log_p = F.log_softmax(input)
    else:
        log_p = F.log_softmax(input, dim=1)
    # (n, c, h, w) -> (n, h, w, c), then keep only pixels with a valid (>= 0) label
    log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous()
    log_p = log_p[target.view(n, h, w, 1).repeat(1, 1, 1, c) >= 0]
    log_p = log_p.view(-1, c)
    mask = target >= 0
    target = target[mask]
    loss = F.nll_loss(log_p, target, weight=weight, reduction='sum')
    if size_average:
        # average the summed loss over the number of valid pixels
        loss /= mask.data.sum()
    return loss
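

# Example (a sketch, not part of the original module): typical shapes for
# cross_entropy2d, assuming a hypothetical 21-class segmentation head and a
# label map where ignored pixels carry the value -1:
#
#   logits = torch.randn(2, 21, 64, 64)          # (n, c, h, w) raw class scores
#   labels = torch.randint(-1, 21, (2, 64, 64))  # (n, h, w), -1 marks ignored pixels
#   loss = cross_entropy2d(logits, labels)       # summed NLL averaged over valid pixels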
|
|
|
class AttentionControl(abc.ABC):

    def step_callback(self, x_t):
        # Hook for subclasses that want to modify the latent after each diffusion step.
        return x_t

    def between_steps(self):
        # Hook called once per diffusion step, after all attention layers have been seen.
        return

    @property
    def num_uncond_att_layers(self):
        # In low-resource mode the unconditional pass has its own attention layers,
        # which __call__ skips; otherwise both passes share one batched forward.
        return self.num_att_layers if LOW_RESOURCE else 0

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        if self.cur_att_layer >= self.num_uncond_att_layers:
            if LOW_RESOURCE:
                attn = self.forward(attn, is_cross, place_in_unet)
            else:
                # The batch stacks unconditional and conditional samples;
                # only the conditional half is handed to the subclass.
                h = attn.shape[0]
                attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
            # Every attention layer of the current diffusion step has been processed.
            self.cur_att_layer = 0
            if self.activate:
                self.cur_step += 1
            self.between_steps()
        return attn

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1  # set externally when the attention hooks are registered
        self.cur_att_layer = 0
        self.activate = True  # referenced in __call__, so give it a default here
|

class AttentionStore(AttentionControl):

    @staticmethod
    def get_empty_store():
        # One list of attention maps per UNet block (down/mid/up) and attention type.
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        if self.activate:
            key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        if self.activate:
            if len(self.attention_store) == 0:
                self.attention_store = self.step_store
            else:
                # Accumulate the maps layer by layer so they can be averaged later.
                for key in self.attention_store:
                    for i in range(len(self.attention_store[key])):
                        self.attention_store[key][i] += self.step_store[key][i]
            self.step_store = self.get_empty_store()

    def get_average_attention(self):
        # Average the accumulated maps over the number of completed diffusion steps.
        average_attention = {key: [item / self.cur_step for item in self.attention_store[key]]
                             for key in self.attention_store}
        return average_attention

    def reset(self):
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    def __init__(self):
        super(AttentionStore, self).__init__()
        self.step_store = self.get_empty_store()  # maps collected during the current step
        self.attention_store = {}                 # maps accumulated across steps
        self.activate = True
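

# Minimal usage sketch (an assumption, not part of the original module): in real
# use the controller is called from patched UNet attention layers, and
# num_att_layers is set by whatever code registers those hooks. Here dummy
# attention probabilities are fed in directly just to show the bookkeeping.
if __name__ == "__main__":
    controller = AttentionStore()
    controller.num_att_layers = 2  # normally set during hook registration (assumed here)
    for _ in range(3):  # pretend three diffusion steps
        for place in ("down", "up"):
            # (batch * heads, queries, text tokens); the first half stands in for
            # the unconditional batch and is ignored by the controller.
            dummy = torch.rand(8, 64, 77)
            controller(dummy, is_cross=True, place_in_unet=place)
    averaged = controller.get_average_attention()
    print({key: [tuple(a.shape) for a in maps] for key, maps in averaged.items()})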