# SMDL-Attribution/models/submodular_vit_torch.py
import math
import random
import numpy as np
from tqdm import tqdm
import cv2
from PIL import Image
import torch
import torch.nn.functional as F
# import torchvision.transforms as transforms
from itertools import combinations
from collections import OrderedDict
class MultiModalSubModularExplanation(object):
def __init__(self,
model,
semantic_feature,
preproccessing_function,
k = 40,
lambda1 = 1.0,
lambda2 = 1.0,
lambda3 = 1.0,
lambda4 = 1.0,
device = "cuda"):
super(MultiModalSubModularExplanation, self).__init__()
# Parameters of the submodular
self.k = k
self.model = model
self.semantic_feature = semantic_feature
self.preproccessing_function = preproccessing_function
self.lambda1 = lambda1
self.lambda2 = lambda2
self.lambda3 = lambda3
self.lambda4 = lambda4
self.device = device
def partition_collection(self, image_set):
"""
Divide m image elements into n sets
"""
image_set_size = len(image_set)
sample_size_per_partition = image_set_size
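        # NOTE: sample_size_per_partition equals the full set size, so V_partition below holds a single partition containing every element.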
image_set_clone = list(image_set)
random.shuffle(image_set_clone)
V_partition = [image_set_clone[i: i + sample_size_per_partition] for i in range(0, image_set_size, sample_size_per_partition)]
assert len(V_partition[0]) == sample_size_per_partition
self.s_size = sample_size_per_partition
        # assert image_set_size > sample_size_per_partition * self.k  # i.e., this is just self.n > self.k ?
return V_partition
def merge_image(self, sub_index_set, partition_image_set):
"""
merge image
"""
sub_image_set_ = np.array(partition_image_set)[sub_index_set]
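        # The selected sub-images are assumed to be disjoint masked regions, so a pixel-wise sum reconstructs their union.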
image = sub_image_set_.sum(0)
return image.astype(np.uint8)
# def compute_effectiveness_score(self, features):
# """
    # Computes Effectiveness Score: The point should be distant from all the other elements in the subset.
# features: torch.Size(batch, d)
# """
# norm_feature = F.normalize(features, p=2, dim=1)
# # Consine Similarity
# cosine_similarity = torch.mm(norm_feature, norm_feature.t())
# cosine_similarity = torch.clamp(cosine_similarity, min=-1, max=1)
# # Normlize 0-1
# cosine_dist = torch.arccos(cosine_similarity) / math.pi
# if cosine_dist.shape[0] == 1:
# eye = 1 - torch.eye(norm_feature.shape[0], device=self.device)
# masked_dist = cosine_dist * eye
# e_score = torch.sum(torch.min(masked_dist, dim=1).values)
# else:
# eye = torch.eye(norm_feature.shape[0], device=self.device)
# adjusted_cosine_dist = cosine_dist + eye
# e_score = torch.sum(torch.min(adjusted_cosine_dist, dim=1).values)
# return e_score # tensor(0.0343, device='cuda:0')
# def proccess_compute_effectiveness_score_v1(self, components_image_feature, combination_list):
# """
# Compute each S's effectiveness score
# """
# e_scores = []
# for sub_index in combination_list:
# sub_feature_set = components_image_feature[sub_index]
# e_score = self.compute_effectiveness_score(sub_feature_set)
# e_scores.append(e_score)
# return torch.stack(e_scores)
def proccess_compute_confidence_score(self):
"""
Compute confidence score
"""
# visual_features = self.model(batch_input_images)
# predicted_scores = torch.softmax(visual_features @ self.semantic_feature.T, dim=-1)
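        # Confidence is 1 minus the normalized Shannon entropy of the predicted class distribution:
        # a peaked distribution (low entropy) gives a confidence close to 1, a uniform one close to 0.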
entropy = - torch.sum(self.predicted_scores * torch.log(self.predicted_scores + 1e-7), dim=1)
max_entropy = torch.log(torch.tensor(self.predicted_scores.shape[1])).to(self.device)
confidence = 1 - (entropy / max_entropy)
return confidence
def proccess_compute_effectiveness_score(self, sub_index_sets):
"""
Compute each S's effectiveness score
"""
e_scores = []
for sub_index in sub_index_sets:
            cosine_dist = self.effectiveness_dist[:, np.array(sub_index)]  # columns restricted to the subset: [n_elements, len(sub_index)]
            cosine_dist = cosine_dist[np.array(sub_index), :]  # rows too: pairwise distances within the subset, [len(sub_index), len(sub_index)]
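            # The diagonal (self-distance) is 0 and would dominate the row-wise minimum, so the identity
            # matrix is added first; each element is then compared with its nearest *other* element.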
eye = torch.eye(cosine_dist.shape[0], device=self.device)
adjusted_cosine_dist = cosine_dist + eye
e_score = torch.sum(torch.min(adjusted_cosine_dist, dim=1).values)
e_scores.append(e_score)
effectiveness_score = torch.stack(e_scores)
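        # A singleton subset has no pairwise distances, so its effectiveness term is defined as zero.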
if len(sub_index_sets[0]) == 1:
effectiveness_score = effectiveness_score * 0
return effectiveness_score
def proccess_compute_consistency_score(self, batch_input_images):
"""
Compute each consistency score
"""
with torch.no_grad():
visual_features = self.model(batch_input_images)
self.predicted_scores = torch.softmax(visual_features @ self.semantic_feature.T, dim=-1)
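        # Consistency is the softmax probability the model assigns to the target class for each merged image.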
consistency_scores = self.predicted_scores[:, self.target_label]
return consistency_scores
def evaluation_maximun_sample(self,
main_set,
candidate_set,
partition_image_set):
"""
Given a subset, return a best sample index
"""
sub_index_sets = []
for candidate_ in candidate_set:
sub_index_sets.append(
np.concatenate((main_set, np.array([candidate_]))).astype(int))
        # Merge the images for each candidate index set
sub_images = torch.stack([
self.preproccessing_function(
self.merge_image(sub_index_set, partition_image_set)
) for sub_index_set in sub_index_sets])
batch_input_images = sub_images.to(self.device)
with torch.no_grad():
# 2. Effectiveness Score
score_effectiveness = self.proccess_compute_effectiveness_score(sub_index_sets)
# 3. Consistency Score
score_consistency = self.proccess_compute_consistency_score(batch_input_images)
# 1. Confidence Score
score_confidence = self.proccess_compute_confidence_score()
# 4. Collaboration Score
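            # The collaboration score evaluates the complement image (original minus the selected regions):
            # if removing the selected regions lowers the target-class probability, this score is high.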
sub_images_reverse = torch.stack([
self.preproccessing_function(
self.org_img - self.merge_image(sub_index_set, partition_image_set)
) for sub_index_set in sub_index_sets])
batch_input_images_reverse = sub_images_reverse.to(self.device)
score_collaboration = 1 - self.proccess_compute_consistency_score(batch_input_images_reverse)
# 1. Confidence Score
# score_confidence = self.proccess_compute_confidence_score()
# submodular score
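            # Greedy marginal gain: a weighted sum of the four terms; the candidate whose addition
            # maximizes this score is appended to the current subset.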
smdl_score = self.lambda1 * score_confidence + self.lambda2 * score_effectiveness + self.lambda3 * score_consistency + self.lambda4 * score_collaboration
# smdl_score = self.lambda2 * score_effectiveness + self.lambda3 * score_consistency + self.lambda4 * score_collaboration
arg_max_index = smdl_score.argmax().cpu().item()
# if self.lambda1 != 0:
self.saved_json_file["confidence_score"].append(score_confidence[arg_max_index].cpu().item())
self.saved_json_file["effectiveness_score"].append(score_effectiveness[arg_max_index].cpu().item())
self.saved_json_file["consistency_score"].append(score_consistency[arg_max_index].cpu().item())
self.saved_json_file["collaboration_score"].append(score_collaboration[arg_max_index].cpu().item())
self.saved_json_file["smdl_score"].append(smdl_score[arg_max_index].cpu().item())
return sub_index_sets[arg_max_index]
def save_file_init(self):
self.saved_json_file = {}
self.saved_json_file["sub-k"] = self.k
self.saved_json_file["confidence_score"] = []
self.saved_json_file["effectiveness_score"] = []
self.saved_json_file["consistency_score"] = []
self.saved_json_file["collaboration_score"] = []
self.saved_json_file["smdl_score"] = []
self.saved_json_file["lambda1"] = self.lambda1
self.saved_json_file["lambda2"] = self.lambda2
self.saved_json_file["lambda3"] = self.lambda3
self.saved_json_file["lambda4"] = self.lambda4
def calculate_distance_of_each_element(self, partition_image_set):
"""
        Compute the pairwise angular distance between every pair of elements, producing the distance matrix used by the effectiveness score.
"""
with torch.no_grad():
partition_images = torch.stack([
self.preproccessing_function(
partition_image
) for partition_image in partition_image_set]).to(self.device)
partition_image_features = self.model(partition_images)
norm_feature = F.normalize(partition_image_features, p=2, dim=1)
            # Cosine similarity
cosine_similarity = torch.mm(norm_feature, norm_feature.t())
cosine_similarity = torch.clamp(cosine_similarity, min=-1, max=1)
            # Map cosine similarity to angular distance, normalized to [0, 1]
self.effectiveness_dist = torch.arccos(cosine_similarity) / math.pi
def get_merge_set(self, partition):
"""
"""
Subset = np.array([])
indexes = np.arange(len(partition))
# First calculate the similarity of each element to facilitate calculation of effectiveness score.
self.calculate_distance_of_each_element(partition)
self.smdl_score_best = 0
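        # Greedy selection: at each of the k steps, evaluate all remaining elements and add the one
        # with the highest combined submodular score.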
for j in tqdm(range(self.k)):
diff = np.setdiff1d(indexes, np.array(Subset)) # in indexes but not in Subset
sub_candidate_indexes = diff
Subset = self.evaluation_maximun_sample(Subset, sub_candidate_indexes, partition)
return Subset
def __call__(self, image_set, id = None):
"""
Compute Source Face Submodular Score
@image_set: [mask_image 1, ..., mask_image m] (cv2 format)
"""
# V_partition = self.partition_collection(image_set) # [ [image1, image2, ...], [image1, image2, ...], ... ]
self.save_file_init()
self.org_img = np.array(image_set).sum(0).astype(np.uint8)
source_image = self.preproccessing_function(self.org_img)
self.source_feature = self.model(source_image.unsqueeze(0).to(self.device))
self.target_label = id
Subset_merge = np.array(image_set)
Submodular_Subset = self.get_merge_set(Subset_merge) # array([17, 42, 49, ...])
submodular_image_set = Subset_merge[Submodular_Subset] # sub_k x (112, 112, 3)
submodular_image = submodular_image_set.sum(0).astype(np.uint8)
self.saved_json_file["smdl_score_max"] = max(self.saved_json_file["smdl_score"])
self.saved_json_file["smdl_score_max_index"] = self.saved_json_file["smdl_score"].index(self.saved_json_file["smdl_score_max"])
return submodular_image, submodular_image_set, self.saved_json_file
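

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The encoder, the
# semantic (class-embedding) matrix, the preprocessing transform, and the way
# the image is split into masked elements are all placeholder assumptions
# here; in the real SMDL-Attribution pipeline these come from the surrounding
# code that builds the ViT model and the superpixel sub-regions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torchvision.transforms as transforms

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Hypothetical stand-ins: any encoder mapping (B, 3, H, W) images to
    # (B, d) features, plus a (num_classes, d) semantic-feature matrix.
    dummy_model = torch.nn.Sequential(
        torch.nn.AdaptiveAvgPool2d(1),
        torch.nn.Flatten(),
        torch.nn.Linear(3, 512),
    ).to(device).eval()
    dummy_semantic_feature = torch.randn(1000, 512, device=device)

    preprocess = transforms.ToTensor()  # HWC uint8 -> CHW float in [0, 1]

    # Toy elements: split a random 224x224 image into 16 disjoint masked patches.
    org = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
    elements = []
    for i in range(4):
        for j in range(4):
            mask = np.zeros_like(org)
            mask[i * 56:(i + 1) * 56, j * 56:(j + 1) * 56] = \
                org[i * 56:(i + 1) * 56, j * 56:(j + 1) * 56]
            elements.append(mask)

    explainer = MultiModalSubModularExplanation(
        model=dummy_model,
        semantic_feature=dummy_semantic_feature,
        preproccessing_function=preprocess,
        k=8,
        device=device)

    submodular_image, submodular_image_set, saved_json = explainer(elements, id=0)
    print(submodular_image.shape, submodular_image_set.shape, saved_json["smdl_score_max"])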