import torch
from torch import nn
import math
from PIL import Image, ImageDraw, ImageFont
import logging
import os
import pandas as pd
import csv
import pickle
import numpy as np
from torch.nn import BCELoss
from torch.nn import functional as F
import numbers
from typing import List
def get_all_attention_64(attn_maps_down, attn_maps_mid, attn_maps_up, res=16):
    """Collect cross-attention maps at resolution `res` from the up, mid and
    down blocks, upsample them to 64x64, and average them into one map."""
    result = []
    for attn_map_integrated in attn_maps_up:
        if len(attn_map_integrated) == 0:
            continue
        attn_map = attn_map_integrated.squeeze(0)  # (heads, H*W, tokens)
        b, i, j = attn_map.shape
        H = W = int(math.sqrt(i))
        if H == res:
            item = attn_map.reshape(-1, res, res, attn_map.shape[-1])
            item = item.permute(0, 3, 1, 2)  # (heads, tokens, res, res)
            item = F.interpolate(item, 64, mode='bilinear').permute(0, 2, 3, 1)
            result.append(item)
    for attn_map_integrated in attn_maps_mid:
        attn_map = attn_map_integrated.squeeze(0)
        b, i, j = attn_map.shape
        H = W = int(math.sqrt(i))
        if H == 8:  # mid-block maps are always 8x8
            item = attn_map.reshape(-1, 8, 8, attn_map.shape[-1])
            item = item.permute(0, 3, 1, 2)
            item = F.interpolate(item, 64, mode='bilinear').permute(0, 2, 3, 1)
            result.append(item)
    for attn_map_integrated in attn_maps_down:
        if len(attn_map_integrated) == 0:
            continue
        attn_map = attn_map_integrated.squeeze(0)
        if attn_map.numel() == 0:
            continue
        b, i, j = attn_map.shape
        H = W = int(math.sqrt(i))
        if H == res:
            item = attn_map.reshape(-1, res, res, attn_map.shape[-1])
            item = item.permute(0, 3, 1, 2)
            item = F.interpolate(item, 64, mode='bilinear').permute(0, 2, 3, 1)
            result.append(item)
    result = torch.cat(result, dim=0)  # (total_heads, 64, 64, tokens)
    return result.mean(dim=0)          # (64, 64, tokens)
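

# A minimal smoke-test sketch for get_all_attention_64, not part of the original
# pipeline. The shapes below are assumptions: 8 attention heads, 77 CLIP text
# tokens, 16x16 queries for the up/down blocks and 8x8 for the mid block, each
# map carrying a leading batch dimension of 1 that squeeze(0) removes.
def _demo_get_all_attention_64():
    heads, tokens = 8, 77
    up = [torch.rand(1, heads, 16 * 16, tokens)]
    mid = [torch.rand(1, heads, 8 * 8, tokens)]
    down = [torch.rand(1, heads, 16 * 16, tokens)]
    fused = get_all_attention_64(down, mid, up, res=16)
    print(fused.shape)  # torch.Size([64, 64, 77])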
def compute_loco_v2(attn_maps_down, attn_maps_mid, attn_maps_up, bboxes,
                    object_positions, smooth_attn=True, topk=0.8):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    loss = 0.
    pad_loss = 0.
    alpha = 0.2  # weight of the padding-token loss term
    beta = 0.8   # mix between the inverted SoT map and the EoT map
    total_fg_map = torch.zeros(size=(64, 64), device=device)
    object_number = len(bboxes)
    if object_number == 0:
        return torch.tensor(0., device=device)
    attn16 = get_all_attention_64(attn_maps_down[-1] + attn_maps_down[-2],
                                  attn_maps_mid,
                                  attn_maps_up[0] + attn_maps_up[1], 16)
    all_attn = [attn16]
    for attn_map in all_attn:
        sum_in = 0.
        sum_out = 0.
        i, j, k = attn_map.shape
        H = W = i
        for obj_idx in range(object_number):
            mask = torch.zeros(size=(H, W), device=device)
            for obj_box in bboxes[obj_idx]:
                # boxes are normalized (x_min, y_min, x_max, y_max)
                x_min, y_min, x_max, y_max = int(obj_box[0] * W), \
                    int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H)
                mask[y_min:y_max, x_min:x_max] = 1
                total_fg_map[y_min:y_max, x_min:x_max] = 1
            # sum the attention over all token positions of this object
            ca_map_obj = attn_map[:, :, object_positions[obj_idx]].sum(-1)
            norm_ca_map_obj = (ca_map_obj / ca_map_obj.max()).reshape(H, W)
            sum_in += (norm_ca_map_obj * mask).sum()
            sum_out += (norm_ca_map_obj * (1 - mask)).sum()
        # padding-token constraint built from the SoT (index 0) and EoT
        # (index -1) attention maps, pushed toward the foreground region
        sot_map = attn_map[:, :, 0].reshape(H, W)
        eot_map = attn_map[:, :, -1].reshape(H, W)
        norm_sot_map = (1 - sot_map) / (1 - sot_map).max()
        norm_eot_map = eot_map / eot_map.max()
        pad_map = beta * norm_sot_map + (1 - beta) * norm_eot_map
        fg_map = pad_map * total_fg_map
        # BCE in float32: binary_cross_entropy is numerically unsafe in float16
        bce_loss = F.binary_cross_entropy(
            torch.sigmoid(pad_map.float().reshape(-1)), fg_map.float().reshape(-1))
        pad_loss += bce_loss
        # encourage the attention mass to fall inside the boxes
        loss += (1 - sum_in / (sum_in + sum_out)) ** 2
    return loss + alpha * pad_loss
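

# A hedged usage sketch for compute_loco_v2 with random maps, not a definitive
# call pattern. The nesting is an assumption inferred from the indexing above:
# attn_maps_down and attn_maps_up are lists of per-block lists of maps, boxes
# are normalized (x_min, y_min, x_max, y_max), and object_positions holds the
# prompt-token indices for each object.
def _demo_compute_loco_v2():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    heads, tokens = 8, 77
    down = [[torch.rand(1, heads, 16 * 16, tokens, device=device)] for _ in range(3)]
    mid = [torch.rand(1, heads, 8 * 8, tokens, device=device)]
    up = [[torch.rand(1, heads, 16 * 16, tokens, device=device)] for _ in range(3)]
    bboxes = [[[0.1, 0.1, 0.5, 0.5]]]   # one object with one box
    object_positions = [[2, 3]]          # its token indices in the prompt
    print(compute_loco_v2(down, mid, up, bboxes, object_positions).item())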
def compute_ca_loss(attn_maps_mid, attn_maps_up, bboxes, object_positions):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    loss = 0
    object_number = len(bboxes)
    if object_number == 0:
        return torch.tensor(0., device=device)
    for attn_map_integrated in attn_maps_mid:
        # keep only the conditional half of the classifier-free-guidance batch
        attn_map = attn_map_integrated.chunk(2)[1]
        b, i, j = attn_map.shape
        H = W = int(math.sqrt(i))
        for obj_idx in range(object_number):
            obj_loss = 0
            mask = torch.zeros(size=(H, W), device=device)
            for obj_box in bboxes[obj_idx]:
                x_min, y_min, x_max, y_max = int(obj_box[0] * W), \
                    int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H)
                mask[y_min:y_max, x_min:x_max] = 1
            for obj_position in object_positions[obj_idx]:
                ca_map_obj = attn_map[:, :, obj_position].reshape(b, H, W)
                # fraction of this token's attention that falls inside the boxes
                activation_value = (ca_map_obj * mask).reshape(b, -1).sum(dim=-1) \
                    / ca_map_obj.reshape(b, -1).sum(dim=-1)
                obj_loss += torch.mean((1 - activation_value) ** 2)
            loss += obj_loss / len(object_positions[obj_idx])
    for attn_map_integrated in attn_maps_up[0]:
        attn_map = attn_map_integrated.chunk(2)[1]
        b, i, j = attn_map.shape
        H = W = int(math.sqrt(i))
        for obj_idx in range(object_number):
            obj_loss = 0
            mask = torch.zeros(size=(H, W), device=device)
            for obj_box in bboxes[obj_idx]:
                x_min, y_min, x_max, y_max = int(obj_box[0] * W), \
                    int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H)
                mask[y_min:y_max, x_min:x_max] = 1
            for obj_position in object_positions[obj_idx]:
                ca_map_obj = attn_map[:, :, obj_position].reshape(b, H, W)
                activation_value = (ca_map_obj * mask).reshape(b, -1).sum(dim=-1) \
                    / ca_map_obj.reshape(b, -1).sum(dim=-1)
                obj_loss += torch.mean((1 - activation_value) ** 2)
            loss += obj_loss / len(object_positions[obj_idx])
    # average over objects and over the number of attention maps used
    loss = loss / (object_number * (len(attn_maps_up[0]) + len(attn_maps_mid)))
    return loss
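

# A hedged usage sketch for compute_ca_loss. The leading dimension of 2 is an
# assumption: chunk(2)[1] above suggests each map stacks the unconditional and
# conditional halves of a classifier-free-guidance batch; only attn_maps_up[0]
# is consumed by the function.
def _demo_compute_ca_loss():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tokens = 77
    mid = [torch.rand(2, 8 * 8, tokens, device=device)]
    up = [[torch.rand(2, 16 * 16, tokens, device=device)]]
    bboxes = [[[0.1, 0.1, 0.5, 0.5]]]
    object_positions = [[2, 3]]
    print(compute_ca_loss(mid, up, bboxes, object_positions).item())


if __name__ == '__main__':
    _demo_get_all_attention_64()
    _demo_compute_loco_v2()
    _demo_compute_ca_loss()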