from .attribute import *
import numpy as np
import random
from src.utils import *
import time
from sklearn.linear_model import LinearRegression
from scipy.spatial.distance import cosine


class PerturbationBasedAttribution(Attribution):
    """Perturbation-based context attribution.

    Supports several per-text scoring functions (Shapley, denoised Shapley, STC, LOO)
    and two attribution modes: "tracllm" (tree-based search over context segments)
    and "vanilla_perturb" (score every segment directly).
    """

    def __init__(self, llm, explanation_level="segment", K=5, attr_type="tracllm",
                 score_funcs=['stc', 'loo', 'denoised_shapley'], sh_N=5, w=2, beta=0.2, verbose=1):
        super().__init__(llm, explanation_level, K, verbose)
        self.K = K                    # number of texts/nodes kept per level and returned
        self.w = w                    # ensemble weight applied to LOO scores
        self.sh_N = sh_N              # number of Monte Carlo permutations for Shapley estimation
        self.attr_type = attr_type    # "tracllm" or "vanilla_perturb"
        self.score_funcs = score_funcs
        self.beta = beta              # fraction parameter passed to mean_of_percent for denoised Shapley
        if "gpt" not in self.llm.name:
            self.model = llm.model
            self.tokenizer = llm.tokenizer

        self.func_map = {
            "shapley": self.shapley_scores,
            "denoised_shapley": self.denoised_shapley_scores,
            "stc": self.stc_scores,
            "loo": self.loo_scores,
        }


    def marginal_contributions(self, question: str, contexts: list, answer: str) -> list:
        """
        Sample marginal contributions for a Monte Carlo (permutation-based) Shapley
        approximation. Each occurrence of a context, even if duplicated, is treated
        separately. The coalition value function is self.context_value, and
        self.sh_N random permutations are sampled.

        Parameters:
        - question: the question posed to the LLM.
        - contexts: a list of contexts, possibly with duplicates.
        - answer: the answer being attributed.

        Returns:
        - A list of length len(contexts); entry i holds the marginal contributions
          observed for context i across the sampled permutations.
        """

        k = len(contexts)

        # One list of sampled marginal contributions per context occurrence
        shapley_values = [[] for _ in range(k)]

        for j in range(self.sh_N):

            # Random permutation of the context indices (indices keep duplicates distinct)
            perm_indices = random.sample(range(k), k)

            # Baseline coalition value: the empty context
            coalition_value = self.context_value(question, [""], answer)
            
            for i, index in enumerate(perm_indices):
                # Coalition of contexts seen so far in this permutation, restored to
                # their original order (sorting indices also handles duplicates correctly)
                coalition_indices = sorted(perm_indices[:i + 1])
                coalition = [contexts[idx] for idx in coalition_indices]

                # Marginal contribution of this context given the preceding coalition
                context_value = self.context_value(question, coalition, answer)
                marginal_contribution = context_value - coalition_value

                # Record the sample for the specific context occurrence at this index
                shapley_values[index].append(marginal_contribution)

                # The current value becomes the baseline for the next step
                coalition_value = context_value
        return shapley_values
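
    # Note (added illustration): the quantity approximated above is the Shapley value
    #   phi_i = E_pi[ v(S_pi(i) ∪ {i}) - v(S_pi(i)) ],
    # where pi is a random permutation, S_pi(i) is the set of contexts preceding i in pi,
    # and v(.) is self.context_value; the expectation is estimated from sh_N sampled permutations.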

    def shapley_scores(self, question: str, contexts: list, answer: str) -> list:
        """
        Estimate Shapley values with a Monte Carlo approximation: self.sh_N random
        permutations, with coalition values given by self.context_value.

        Returns:
        - An array with every context's estimated Shapley value (the mean of its
          sampled marginal contributions).
        """
        marginal_values = self.marginal_contributions(question, contexts, answer)
        shapley_values = np.zeros(len(marginal_values))
        for i, value_list in enumerate(marginal_values):
            shapley_values[i] = np.mean(value_list)

        return shapley_values
 
    def denoised_shapley_scores(self, question: str, contexts: list, answer: str) -> list:
        """Denoised Shapley: aggregate each context's sampled marginal contributions
        with mean_of_percent(values, beta) instead of a plain mean, to suppress noisy
        samples."""
        marginal_values = self.marginal_contributions(question, contexts, answer)
        new_shapley_values = np.zeros(len(marginal_values))
        for i, value_list in enumerate(marginal_values):
            new_shapley_values[i] = mean_of_percent(value_list, self.beta)
        return new_shapley_values
    
    def stc_scores(self, question: str, contexts: list, answer: str) -> list:
        """STC: score each context by the value it achieves on its own, relative to
        the empty-context baseline."""
        k = len(contexts)
        scores = np.zeros(k)
        baseline_score = self.context_value(question, [''], answer)
        for i, text in enumerate(contexts):
            scores[i] = self.context_value(question, [text], answer) - baseline_score
        return scores.tolist()

    def loo_scores(self, question: str, contexts: list, answer: str) -> list:
        """Leave-one-out: score each context by the drop in value when it is removed
        from the full set of contexts."""
        k = len(contexts)
        scores = np.zeros(k)
        v_all = self.context_value(question, contexts, answer)
        for i, text in enumerate(contexts):
            rest_texts = contexts[:i] + contexts[i + 1:]
            scores[i] = v_all - self.context_value(question, rest_texts, answer)
        return scores.tolist()
 
    def tracllm(self, question: str, contexts: list, answer: str, score_func):
        """
        Tree-based attribution search: start from a single node holding all
        (text, index) pairs, repeatedly split every multi-text node in half, score
        the resulting nodes with score_func, and keep only the top-K nodes per
        level until every surviving node contains a single text.
        """
        current_nodes = [manual_zip(contexts, list(range(len(contexts))))]
        current_nodes_scores = None

        def get_important_nodes(nodes, importance_values):
            # Keep the top-K nodes by importance, then restore their original order
            combined = list(enumerate(zip(nodes, importance_values)))
            k = min(self.K, len(combined))
            top = sorted(combined, key=lambda x: x[1][1], reverse=True)[:k]
            top.sort(key=lambda x: x[0])
            important_nodes = [node for _, (node, _) in top]
            important_nodes_scores = [score for _, (_, score) in top]
            return important_nodes, important_nodes_scores

        level = 0
        while len(current_nodes) > 0 and any(len(node) > 1 for node in current_nodes):
            level += 1
            if self.verbose == 1:
                print(f"======= layer: {level} =======")
            # Split every multi-text node in half
            new_nodes = []
            for node in current_nodes:
                if len(node) > 1:
                    mid = len(node) // 2
                    node_left, node_right = node[:mid], node[mid:]
                    new_nodes.append(node_left)
                    new_nodes.append(node_right)
                else:
                    new_nodes.append(node)
            if len(new_nodes) <= self.K:
                # Few enough nodes: keep them all without scoring
                current_nodes = new_nodes
            else:
                # Score each node on its concatenated text and prune to the top-K
                importance_values = self.func_map[score_func](
                    question, [" ".join(unzip_tuples(node)[0]) for node in new_nodes], answer)
                current_nodes, current_nodes_scores = get_important_nodes(new_nodes, importance_values)

        if current_nodes_scores is None:
            # No pruning was ever needed (len(contexts) <= K): score the leaves directly
            # so callers always receive one score per returned text.
            current_nodes_scores = self.func_map[score_func](
                question, [" ".join(unzip_tuples(node)[0]) for node in current_nodes], answer)

        flattened_current_nodes = [item for sublist in current_nodes for item in sublist]
        return flattened_current_nodes, current_nodes_scores
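
    # Added illustration (not from the original source): the search with 8 texts and K = 4.
    #   level 1: [t1..t8] -> [t1..t4], [t5..t8]                  (2 nodes <= K, keep all)
    #   level 2:          -> [t1,t2], [t3,t4], [t5,t6], [t7,t8]  (4 nodes <= K, keep all)
    #   level 3:          -> 8 singleton nodes, scored with score_func, top-4 kept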

    
    def vanilla_explanation(self, question: str, texts: list, answer: str, score_func):
        """Score every text directly with score_func, without the tree search."""
        texts_scores = self.func_map[score_func](question, texts, answer)
        return texts, texts_scores

    def attribute(self, question: str, contexts: list, answer: str):
        """
        Given a question, contexts and an answer, return the attribution results:
        (texts, important_ids, importance_scores, elapsed_time, ensemble_list),
        where ensemble_list maps each score function to its (id, score) pairs.
        """
        ensemble_list = dict()
        texts = split_context(self.explanation_level, contexts)
        start_time = time.time()
        importance_dict = {}
        max_score_func_dict = {}

        score_funcs = self.score_funcs
        for score_func in score_funcs:
            if self.verbose == 1:
                print(f"-Start {score_func}")
            # LOO scores are up-weighted by self.w in the ensemble; other scores use weight 1
            if score_func == "loo":
                weight = self.w
            else:
                weight = 1

            if self.attr_type == "tracllm":
                important_nodes, importance_scores = self.tracllm(question, texts, answer, score_func)
                important_texts, important_ids = unzip_tuples(important_nodes)
            elif self.attr_type == "vanilla_perturb":
                important_texts, importance_scores = self.vanilla_explanation(question, texts, answer, score_func)
                # vanilla_explanation returns the texts in their original order,
                # so the ids are simply 0..len(texts)-1 (robust to duplicate texts)
                important_ids = list(range(len(important_texts)))
            else:
                raise ValueError("Unsupported attr_type.")

            ensemble_list[score_func] = list(zip(important_ids, importance_scores))
            # Ensemble rule: for each text, keep the maximum weighted score across
            # score functions, and remember which function produced it
            for idx, important_id in enumerate(important_ids):
                if important_id in importance_dict:
                    if importance_dict[important_id] < weight * importance_scores[idx]:
                        max_score_func_dict[important_id] = score_func
                    importance_dict[important_id] = max(importance_dict[important_id], weight * importance_scores[idx])
                else:
                    importance_dict[important_id] = weight * importance_scores[idx]
                    max_score_func_dict[important_id] = score_func

        end_time = time.time()

        important_ids = list(importance_dict.keys())
        importance_scores = list(importance_dict.values())
        return texts, important_ids, importance_scores, end_time - start_time, ensemble_list
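

# Usage sketch (illustrative only, not part of the original module). The `llm`
# argument is whatever model wrapper this repo's Attribution base class expects:
# an object with a .name attribute (and .model / .tokenizer for non-GPT models);
# `load_llm` below is a hypothetical stand-in for the repo's own loading code.
#
#   llm = load_llm("llama-3")   # hypothetical helper
#   attributor = PerturbationBasedAttribution(
#       llm, explanation_level="segment", K=5, attr_type="tracllm",
#       score_funcs=["stc", "loo", "denoised_shapley"])
#   texts, ids, scores, elapsed, ensemble = attributor.attribute(question, contexts, answer)
#   top = sorted(zip(ids, scores), key=lambda x: x[1], reverse=True)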