from .attribute import *
import numpy as np
from src.utils import *
import time
import torch
import gc
from src.prompts import wrap_prompt_attention
from .attention_utils import *
import spaces

class AttnTraceAttribution(Attribution):
    def __init__(self, llm, explanation_level="segment", K=5, avg_k=5, q=0.4, B=30, verbose=1):
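        """Attention-based context attribution.

        Brief summary of the constructor arguments (descriptions are inferred
        from how the fields are used in this file):
            llm: wrapper exposing `.model`, `.tokenizer`, `.provider`, and
                `._load_model_if_needed()`.
            explanation_level: granularity passed to `split_context` (e.g. "segment").
            K: number of top-scoring contexts to return.
            avg_k: number of highest-attention tokens averaged per context.
            q: fraction of contexts kept in each random subsample.
            B: number of random subsamples (ensemble size).
            verbose: verbosity flag passed to the base class.
        """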
        super().__init__(llm, explanation_level, K, verbose)
        self.llm = llm
        self.model = None  # loaded lazily in attribute()
        self.model_type = llm.provider
        self.tokenizer = llm.tokenizer
        self.avg_k = avg_k
        self.q = q
        self.B = B
        self.explanation_level = explanation_level
 
    def loss_to_importance(self, losses, sentences_id_list):
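        """Convert per-group losses into per-sentence importance scores.

        `losses` is expected to be a list of (index_group, loss) pairs. Whenever
        two consecutive groups differ by exactly one sentence index, the drop in
        loss is credited to that sentence. (This helper is not called elsewhere
        in this file; the description is read off the code below.)
        """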

        importances = np.zeros(len(sentences_id_list))

        for i in range(1, len(losses)):
            group = np.array(losses[i][0])
            last_group = np.array(losses[i - 1][0])
            group_loss = np.array(losses[i][1])
            last_group_loss = np.array(losses[i - 1][1])
            if len(group) - len(last_group) == 1:
                # The sentence added in this step is the one in `group` but not in
                # `last_group`; credit it with the resulting drop in loss.
                feature_index = [item for item in group if item not in last_group]
                importances[feature_index[0]] += (last_group_loss - group_loss)
        return importances
    @spaces.GPU
    def attribute(self, question: str, contexts: list, answer: str, explained_answer: str, customized_template: str = None):
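        """Attribute `explained_answer` to the input contexts via attention.

        High-level flow (a summary of the code below): the contexts are split at
        the configured explanation level; then, for each of `self.B` rounds, a
        random fraction `self.q` of the contexts is kept, the attention from the
        explained-answer tokens to the context tokens is averaged over layers,
        heads, and answer positions, and each sampled context is scored by the
        mean of its top `self.avg_k` token scores. Scores are averaged over the
        rounds in which a context was sampled, and the indices and scores of the
        top `self.K` contexts are returned together with the elapsed time.
        """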
        start_time = time.time()
        if self.llm.model is not None:
            self.model = self.llm.model
        else:
            self.model = self.llm._load_model_if_needed().to("cuda")
        self.layers = range(len(self.model.model.layers))
        model = self.model
        tokenizer = self.tokenizer
        model.eval()  # Set model to evaluation mode
        contexts = split_context(self.explanation_level, contexts)
        previous_answer = get_previous_answer(answer, explained_answer)

        # Tokenize the prompt pieces and the answer; [1:] drops the leading special
        # token (typically BOS) from every piece except the first so that it is not
        # repeated in the concatenated input.
        prompt_part1, prompt_part2 = wrap_prompt_attention(question, customized_template)
        prompt_part1_ids = tokenizer(prompt_part1, return_tensors="pt").input_ids.to(model.device)[0]
        context_ids_list = [tokenizer(context, return_tensors="pt").input_ids.to(model.device)[0][1:] for context in contexts]
        prompt_part2_ids = tokenizer(prompt_part2, return_tensors="pt").input_ids.to(model.device)[0][1:]
        print("previous_answer: ", previous_answer)
        print("explained_answer: ", explained_answer)
        previous_answer_ids = tokenizer(previous_answer, return_tensors="pt").input_ids.to(model.device)[0][1:]
        target_ids = tokenizer(explained_answer, return_tensors="pt").input_ids.to(model.device)[0][1:]
        avg_importance_values = np.zeros(len(context_ids_list))
        idx_frequency = {idx: 0 for idx in range(len(context_ids_list))}
        for t in range(self.B):
            # Randomly subsample a fraction q of the contexts, keeping them in
            # their original order.
            num_samples = int(len(context_ids_list) * self.q)
            sampled_indices = np.sort(np.random.permutation(len(context_ids_list))[:num_samples])
            sampled_context_ids = [context_ids_list[idx] for idx in sampled_indices]
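            # The model input is laid out as:
            #   [prompt_part1][sampled contexts][prompt_part2][previous answer][explained answer]
            # so the explained-answer tokens start at offset self.prompt_length.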
            input_ids = torch.cat([prompt_part1_ids] + sampled_context_ids + [prompt_part2_ids, previous_answer_ids, target_ids], dim=-1).unsqueeze(0)
            self.context_length = sum(len(context_ids) for context_ids in sampled_context_ids)
            self.prompt_length = len(prompt_part1_ids) + self.context_length + len(prompt_part2_ids) + len(previous_answer_ids)
            # Average, one layer at a time (to save memory), the attention that each
            # explained-answer token pays to the sampled context tokens.
            with torch.no_grad():
                outputs = model(input_ids, output_hidden_states=True)
                hidden_states = outputs.hidden_states

                avg_attentions = None  # running sum over layers
                for i in self.layers:
                    attentions = get_attention_weights_one_layer(model, hidden_states, i, attribution_start=self.prompt_length, model_type=self.model_type)
                    # Keep only the columns that correspond to the sampled context tokens.
                    context_attentions = attentions[:, :, :, len(prompt_part1_ids):len(prompt_part1_ids) + self.context_length]
                    if avg_attentions is None:
                        avg_attentions = context_attentions
                    else:
                        avg_attentions += context_attentions
                # Average over layers, then over batch, heads, and answer positions.
                avg_attentions = (avg_attentions / len(self.layers)).mean(dim=0).mean(dim=(0, 1)).to(torch.float16)
            
            importance_values = avg_attentions.to(torch.float32).cpu().numpy()

            # Start position of each sampled context within the context block.
            context_lengths = [len(context_ids) for context_ids in sampled_context_ids[:-1]]
            start_positions = np.cumsum([0] + context_lengths)
            
            # Score each sampled context by the mean of its top avg_k token scores.
            group_importance_values = []
            for start, context_ids in zip(start_positions, sampled_context_ids):
                end = start + len(context_ids)
                values = np.sort(importance_values[start:end])
                k = min(self.avg_k, end - start)  # cap at the actual context length
                group_mean = np.mean(values[-k:])  # mean of the top k values
                group_importance_values.append(group_mean)
            group_importance_values = np.array(group_importance_values)

            # Accumulate scores and sampling counts for the contexts drawn this round.
            for i, idx in enumerate(sampled_indices):
                idx_frequency[idx] += 1
                avg_importance_values[idx] += group_importance_values[i]

        # Average each context's accumulated score over the rounds in which it was sampled.
        for i in range(len(context_ids_list)):
            if idx_frequency[i] != 0:
                avg_importance_values[i] /= idx_frequency[i]

        # Select the K highest-scoring contexts and their importance scores.
        top_k_indices = np.argsort(avg_importance_values)[::-1][:self.K]
        top_k_scores = [avg_importance_values[i] for i in top_k_indices]
        
        end_time = time.time()

        gc.collect()
        torch.cuda.empty_cache()
        return contexts, top_k_indices, top_k_scores, end_time - start_time, None
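

# Example usage (a minimal sketch, kept as a comment so this module stays importable).
# Everything below is illustrative, not part of the class: the `llm` object is an
# assumption; any wrapper exposing `.model`, `.tokenizer`, `.provider`, and
# `._load_model_if_needed()` (as used above) would do.
#
#     attribution = AttnTraceAttribution(llm, explanation_level="segment", K=5, avg_k=5, q=0.4, B=30)
#     contexts, top_k_indices, top_k_scores, elapsed, _ = attribution.attribute(
#         question="Who wrote Hamlet?",
#         contexts=["Hamlet is a tragedy written by William Shakespeare.",
#                   "The Globe Theatre was built in 1599."],
#         answer="William Shakespeare wrote Hamlet.",
#         explained_answer="William Shakespeare",
#     )
#     for idx, score in zip(top_k_indices, top_k_scores):
#         print(idx, score, contexts[idx])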