import torch
import abc
from distutils.version import LooseVersion
import torch.nn.functional as F

LOW_RESOURCE = False


def cross_entropy2d(input, target, weight=None, size_average=True):
    # input: (n, c, h, w) logits; target: (n, h, w) class indices, negative values are ignored
    n, c, h, w = input.size()
    # F.log_softmax gained an explicit `dim` argument after PyTorch 0.3
    if LooseVersion(torch.__version__) < LooseVersion('0.3'):
        log_p = F.log_softmax(input)
    else:
        log_p = F.log_softmax(input, dim=1)
    # flatten to (n*h*w, c) and drop the pixels whose target label is negative
    log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous()
    log_p = log_p[target.view(n, h, w, 1).repeat(1, 1, 1, c) >= 0]
    log_p = log_p.view(-1, c)

    mask = target >= 0
    target = target[mask]
    loss = F.nll_loss(log_p, target, weight=weight, reduction='sum')
    if size_average:
        loss /= mask.data.sum()
    return loss


class AttentionControl(abc.ABC):

    def step_callback(self, x_t):
        return x_t

    def between_steps(self):
        return

    @property
    def num_uncond_att_layers(self):
        return self.num_att_layers if LOW_RESOURCE else 0

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        if self.cur_att_layer >= self.num_uncond_att_layers:
            if LOW_RESOURCE:
                attn = self.forward(attn, is_cross, place_in_unet)
            else:
                # first half of the batch is the unconditional branch;
                # only the conditional half is handed to forward()
                h = attn.shape[0]
                attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
            # one full pass over every attention layer == one diffusion step
            self.cur_att_layer = 0
            if self.activate:
                self.cur_step += 1
            self.between_steps()
        return attn

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0


class AttentionStore(AttentionControl):

    @staticmethod
    def get_empty_store():
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        if self.activate:
            key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        if self.activate:
            if len(self.attention_store) == 0:
                self.attention_store = self.step_store
            else:
                # accumulate this step's maps into the running per-layer sums
                for key in self.attention_store:
                    for i in range(len(self.attention_store[key])):
                        self.attention_store[key][i] += self.step_store[key][i]
            self.step_store = self.get_empty_store()

    def get_average_attention(self):
        # average the accumulated maps over the number of completed diffusion steps
        average_attention = {key: [item / self.cur_step for item in self.attention_store[key]]
                             for key in self.attention_store}
        return average_attention

    def reset(self):
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    def __init__(self):
        super(AttentionStore, self).__init__()
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.activate = True
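

# Hedged usage sketch (not part of the module above): in a real pipeline the
# controller is invoked from inside the UNet's attention layers, and
# num_att_layers is set while registering those hooks. Here synthetic attention
# tensors stand in for the UNet so the step/layer bookkeeping can be exercised
# in isolation; all shapes, layer counts, and step counts are illustrative
# assumptions.
if __name__ == "__main__":
    store = AttentionStore()
    store.num_att_layers = 2  # assumed: 2 cross-attention layers hooked

    num_steps = 4
    for _ in range(num_steps):
        for place in ("down", "up"):
            # batch of 2: first half unconditional, second half conditional
            attn = torch.rand(2, 64, 77)
            store(attn, is_cross=True, place_in_unet=place)

    # after num_steps full passes, maps are summed per layer and averaged here
    avg = store.get_average_attention()
    print({k: [a.shape for a in v] for k, v in avg.items() if v})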