from .attribute import *
from .attention_utils import *

import gc
import time

import numpy as np
import torch

from src.prompts import wrap_prompt_attention
from src.utils import *

class AvgAttentionAttribution(Attribution):
    """Attribute an answer to context segments by averaging the attention
    that the answer tokens pay to each context token, across all layers
    and heads, in a single forward pass."""

    def __init__(self, llm, explanation_level="segment", K=5, verbose=1):
        super().__init__(llm, explanation_level, K, verbose)
        self.model = llm.model
        self.tokenizer = llm.tokenizer
        self.layers = range(len(llm.model.model.layers))  # average over every decoder layer
        self.variant = "default"
        self.explanation_level = explanation_level
 
    def loss_to_importance(self, losses, sentences_id_list):
        """Convert a sequence of (feature_group, loss) pairs into per-feature
        importances: whenever a group extends the previous group by exactly
        one feature, the resulting loss drop is credited to that feature."""
        importances = np.zeros(len(sentences_id_list))

        for i in range(1, len(losses)):
            group = np.array(losses[i][0])
            last_group = np.array(losses[i - 1][0])

            group_loss = np.array(losses[i][1])
            last_group_loss = np.array(losses[i - 1][1])
            if len(group) - len(last_group) == 1:
                # The single feature added in this step gets the loss delta.
                feature_index = [item for item in group if item not in last_group]
                importances[feature_index[0]] += (last_group_loss - group_loss)
        print("importances: ", importances)
        return importances
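
    # Worked example (hypothetical numbers): with
    #   losses = [([0], 1.00), ([0, 2], 0.35)]
    # the second group extends {0} to {0, 2}, so the loss drop
    # 1.00 - 0.35 = 0.65 is credited to importances[2].
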
    def attribute(self, question: str, contexts: list, answer: str, customized_template: str = None):
        start_time = time.time()
        model = self.model
        tokenizer = self.tokenizer
        model.eval()  # inference only; disables dropout etc.
        contexts = split_context(self.explanation_level, contexts)

        # Tokenize the prompt pieces and the target answer separately so that
        # the token span of each context segment is known.
        prompt_part1, prompt_part2 = wrap_prompt_attention(question, customized_template)
        prompt_part1_ids = tokenizer(prompt_part1, return_tensors="pt").input_ids.to(model.device)[0]
        # [1:] drops the BOS token the tokenizer prepends to each context.
        context_ids_list = [tokenizer(context, return_tensors="pt").input_ids.to(model.device)[0][1:] for context in contexts]
        prompt_part2_ids = tokenizer(prompt_part2, return_tensors="pt").input_ids.to(model.device)[0]
        target_ids = tokenizer(answer, return_tensors="pt").input_ids.to(model.device)[0]

        # Assemble the full sequence:
        # [prompt_part1 | context_1 ... context_n | prompt_part2 | answer]
        sampled_context_ids = context_ids_list
        input_ids = torch.cat([prompt_part1_ids] + sampled_context_ids + [prompt_part2_ids, target_ids], dim=-1).unsqueeze(0)
        self.context_length = sum(len(context_ids) for context_ids in sampled_context_ids)
        self.prompt_length = len(prompt_part1_ids) + self.context_length + len(prompt_part2_ids)

        print("input_ids_shape: ", input_ids.shape)
        with torch.no_grad():
            # Single forward pass; the per-layer hidden states are reused
            # below to recompute each layer's attention weights.
            outputs = model(input_ids, output_hidden_states=True)
        hidden_states = outputs.hidden_states
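        # With HuggingFace causal LMs, hidden_states is a tuple of
        # (num_layers + 1) tensors of shape [1, seq_len, hidden_size]:
        # the embedding output first, then one entry per decoder layer.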
        with torch.no_grad():
            avg_attentions = None  # running sum over layers
            for i in self.layers:
                attentions = get_attention_weights_one_layer(model, hidden_states, i, attribution_start=self.prompt_length)
                # Keep only the attention paid to the context tokens.
                context_slice = attentions[:, :, :, len(prompt_part1_ids):len(prompt_part1_ids) + self.context_length]
                if avg_attentions is None:
                    avg_attentions = context_slice
                else:
                    avg_attentions += context_slice
            # Average over layers, then over batch, heads, and answer positions.
            avg_attentions = (avg_attentions / len(self.layers)).mean(dim=0).mean(dim=(0, 1)).to(torch.float16)
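            # Shape sketch (assuming get_attention_weights_one_layer returns
            # weights shaped [batch=1, heads, answer_len, seq_len], as with
            # standard HF attention): the slice keeps
            # [1, heads, answer_len, context_len], and the two means collapse
            # every axis but the last, leaving a [context_len] vector of
            # average attention mass per context token.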
            
        gc.collect()
        torch.cuda.empty_cache()
        # Convert attention scores to importance values
        importance_values = avg_attentions.to(torch.float32).cpu().numpy()
        print("importance_values_shape", importance_values.shape)

        # Map flat token positions back to context segments: cumulative sums
        # of the segment lengths give each segment's start offset.
        context_lengths = [len(context_ids) for context_ids in sampled_context_ids[:-1]]
        start_positions = np.cumsum([0] + context_lengths)
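        # Example (hypothetical lengths): segments of 4, 2, and 5 tokens give
        # start_positions = cumsum([0, 4, 2]) = [0, 4, 6], i.e. token spans
        # [0:4], [4:6], and [6:11].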
        
        # Score each context segment by the mean importance of its tokens.
        group_importance_values = []
        for start, context_ids in zip(start_positions, sampled_context_ids):
            end = start + len(context_ids)
            group_mean = np.mean(importance_values[start:end])
            group_importance_values.append(group_mean)

        avg_importance_values = np.array(group_importance_values)
        print("number of context groups:", len(group_importance_values))

        # Select the top-K most important context segments.
        top_k_indices = np.argsort(avg_importance_values)[::-1][:self.K]
        top_k_scores = [avg_importance_values[i] for i in top_k_indices]

        end_time = time.time()
        print(f"Topk_indices: {top_k_indices}")
        print(f"Topk_contexts: {[contexts[i] for i in top_k_indices]}")
        print(f"Topk_scores: {top_k_scores}")

        gc.collect()
        torch.cuda.empty_cache()
        return contexts, top_k_indices, top_k_scores, end_time - start_time, None
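

# Minimal usage sketch (assumptions: `llm` is this repo's LLM wrapper exposing
# `.model` (a HuggingFace decoder-only LM) and `.tokenizer`; how `llm` is
# constructed is not shown in this file):
#
#   attributor = AvgAttentionAttribution(llm, explanation_level="segment", K=2)
#   contexts, top_idx, top_scores, elapsed, _ = attributor.attribute(
#       question="Who wrote Hamlet?",
#       contexts=["Shakespeare wrote Hamlet.", "The sky is blue."],
#       answer="Shakespeare",
#   )
#   # top_idx ranks the context segments by average attention mass.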