# NOTE: web-scrape artifacts (page header, file-size line, commit hash, and a
# line-number gutter) were removed here; the module proper begins with the
# imports below.
from .attribute import *
import numpy as np
import random
from src.utils import *
import time
from sklearn.linear_model import LinearRegression
from scipy.spatial.distance import cosine
class PerturbationBasedAttribution(Attribution):
    """Attribute an LLM's answer to segments of its input context via perturbation.

    Segments are removed or isolated and the resulting change in the answer's
    value under the model (``self.context_value``, inherited from
    ``Attribution``) is used as an importance score.

    Supported scoring functions (selected via ``score_funcs``):
      - ``stc``              : single-text contribution vs. an empty context
      - ``loo``              : leave-one-out drop in value
      - ``shapley``          : Monte-Carlo Shapley value estimate
      - ``denoised_shapley`` : Shapley estimate averaged over only the top
                               ``beta`` fraction of marginal contributions

    ``attr_type`` selects between the hierarchical ``tracllm`` search and a
    flat ``vanilla_perturb`` pass over all segments.
    """

    def __init__(self, llm, explanation_level="segment", K=5, attr_type="tracllm",
                 score_funcs=['stc', 'loo', 'denoised_shapley'], sh_N=5, w=2,
                 beta=0.2, verbose=1):
        """
        Parameters:
        - llm: wrapped language model; must expose ``.name`` (and ``.model`` /
          ``.tokenizer`` for non-GPT models).
        - explanation_level: granularity passed to ``split_context``.
        - K: number of most important nodes kept per search layer.
        - attr_type: "tracllm" (hierarchical) or "vanilla_perturb" (flat).
        - score_funcs: iterable of score-function names; treated as read-only.
        - sh_N: number of random permutations for Shapley estimation.
        - w: ensemble weight applied to "loo" scores in ``attribute``.
        - beta: fraction of largest marginal contributions kept by
          ``denoised_shapley``.
        - verbose: 1 to print progress.
        """
        super().__init__(llm, explanation_level, K, verbose)
        self.K = K
        self.w = w
        self.sh_N = sh_N
        self.attr_type = attr_type
        self.score_funcs = score_funcs
        self.beta = beta
        # GPT models are reached through an API; only local models expose
        # their weights and tokenizer directly.
        if "gpt" not in self.llm.name:
            self.model = llm.model
            self.tokenizer = llm.tokenizer
        # Dispatch table mapping score-function names to bound methods.
        self.func_map = {
            "shapley": self.shapley_scores,
            "denoised_shapley": self.denoised_shapley_scores,
            "stc": self.stc_scores,
            "loo": self.loo_scores,
        }

    def marginal_contributions(self, question: str, contexts: list, answer: str) -> list:
        """Sample marginal contributions of each context via random permutations.

        For ``self.sh_N`` random permutations the contexts are added one at a
        time (always presented in original document order) and the increase in
        ``context_value`` is credited to the context just added.  Every
        occurrence of a duplicated context is tracked separately by index.

        Returns:
        - A list of length ``len(contexts)``; entry ``i`` holds the sampled
          marginal contributions of ``contexts[i]``.
        """
        k = len(contexts)
        # One sample list per context occurrence.
        shapley_values = [[] for _ in range(k)]
        for _ in range(self.sh_N):
            perm_indices = random.sample(range(k), k)
            # Value of the empty coalition (empty-context baseline).
            coalition_value = self.context_value(question, [""], answer)
            for i, index in enumerate(perm_indices):
                # BUGFIX: sort the coalition's *indices* to restore original
                # document order.  The previous
                # ``sorted(coalition, key=lambda x: contexts.index(x))`` broke
                # with duplicate contexts because ``index()`` always returns
                # the first occurrence (it was also O(n^2)).
                coalition_indices = sorted(perm_indices[:i + 1])
                coalition = [contexts[idx] for idx in coalition_indices]
                context_value = self.context_value(question, coalition, answer)
                # Credit the value gained by adding contexts[index].
                shapley_values[index].append(context_value - coalition_value)
                coalition_value = context_value
        return shapley_values

    def shapley_scores(self, question: str, contexts: list, answer: str) -> list:
        """Approximate Shapley value per context: the mean of its sampled
        marginal contributions."""
        marginal_values = self.marginal_contributions(question, contexts, answer)
        shapley_values = np.zeros(len(marginal_values))
        for i, value_list in enumerate(marginal_values):
            shapley_values[i] = np.mean(value_list)
        return shapley_values

    def denoised_shapley_scores(self, question: str, contexts: list, answer: str) -> list:
        """Denoised Shapley value: average over only the top ``self.beta``
        fraction of each context's marginal contributions (via the project
        helper ``mean_of_percent``), suppressing noise from uninformative
        permutations."""
        marginal_values = self.marginal_contributions(question, contexts, answer)
        new_shapley_values = np.zeros(len(marginal_values))
        for i, value_list in enumerate(marginal_values):
            new_shapley_values[i] = mean_of_percent(value_list, self.beta)
        return new_shapley_values

    def stc_scores(self, question: str, contexts: list, answer: str) -> list:
        """Single-text contribution: value of each context alone minus the
        empty-context baseline."""
        # Baseline is hoisted out of the loop: it does not depend on the text.
        baseline = self.context_value(question, [''], answer)
        scores = np.zeros(len(contexts))
        for i, text in enumerate(contexts):
            scores[i] = self.context_value(question, [text], answer) - baseline
        return scores.tolist()

    def loo_scores(self, question: str, contexts: list, answer: str) -> list:
        """Leave-one-out: drop in value when each context is removed from the
        full set."""
        v_all = self.context_value(question, contexts, answer)
        scores = np.zeros(len(contexts))
        for i in range(len(contexts)):
            rest_texts = contexts[:i] + contexts[i + 1:]
            scores[i] = v_all - self.context_value(question, rest_texts, answer)
        return scores.tolist()

    def tracllm(self, question: str, contexts: list, answer: str, score_func):
        """Hierarchical attribution search over context segments.

        Starting from a single node holding all ``(text, original_id)`` pairs,
        each layer splits every multi-segment node in half; whenever more than
        ``self.K`` nodes exist, only the ``self.K`` highest-scoring nodes
        (per ``score_func``) are kept.  The loop terminates once every
        surviving node is a single segment.

        Returns:
        - flattened list of surviving ``(text, id)`` pairs, and their scores.
          NOTE(review): when pruning ran, the scores come from the last layer
          at which it occurred — verify alignment with the flattened nodes.
        """
        current_nodes = [manual_zip(contexts, list(range(len(contexts))))]
        current_nodes_scores = None

        def get_important_nodes(nodes, importance_values):
            # Track each node's original position explicitly; restoring order
            # with ``combined.index(x)`` is ambiguous when two (node, score)
            # pairs compare equal.
            ranked = sorted(enumerate(zip(nodes, importance_values)),
                            key=lambda item: item[1][1], reverse=True)
            k = min(self.K, len(ranked))
            # Keep the top-k, then restore original node order.
            top = sorted(ranked[:k], key=lambda item: item[0])
            important_nodes = [node for _, (node, _) in top]
            important_nodes_scores = [score for _, (_, score) in top]
            return important_nodes, important_nodes_scores

        level = 0
        while len(current_nodes) > 0 and any(len(node) > 1 for node in current_nodes):
            level += 1
            if self.verbose == 1:
                print(f"======= layer: {level}=======")
            # Split every multi-segment node in half; singletons pass through.
            new_nodes = []
            for node in current_nodes:
                if len(node) > 1:
                    mid = len(node) // 2
                    new_nodes.append(node[:mid])
                    new_nodes.append(node[mid:])
                else:
                    new_nodes.append(node)
            if len(new_nodes) <= self.K:
                # Few enough nodes: keep them all, no scoring needed this layer.
                current_nodes = new_nodes
            else:
                # Score each node's concatenated text, then prune to top-K.
                importance_values = self.func_map[score_func](
                    question,
                    [" ".join(unzip_tuples(node)[0]) for node in new_nodes],
                    answer)
                current_nodes, current_nodes_scores = get_important_nodes(
                    new_nodes, importance_values)
        if current_nodes_scores is None and current_nodes:
            # BUGFIX: with at most K segments no pruning layer ever runs, and
            # the scores stayed None — which crashed the caller's zip().
            # Score the surviving nodes once instead.
            current_nodes_scores = self.func_map[score_func](
                question,
                [" ".join(unzip_tuples(node)[0]) for node in current_nodes],
                answer)
        flattened_current_nodes = [item for sublist in current_nodes for item in sublist]
        return flattened_current_nodes, current_nodes_scores

    def vanilla_explanation(self, question: str, texts: list, answer: str, score_func):
        """Flat attribution: score every segment directly with ``score_func``.
        Returns ``(texts, scores)`` with segments unchanged and in order."""
        texts_scores = self.func_map[score_func](question, texts, answer)
        return texts, texts_scores

    def attribute(self, question: str, contexts: list, answer: str):
        """Run attribution over ``contexts`` for the given question/answer.

        Each configured score function is run; per-segment scores are
        ensembled by keeping, for every segment id, the maximum weighted score
        across functions ("loo" scores are weighted by ``self.w``).

        Returns:
        - texts: the context segments produced by ``split_context``
        - important_ids: ids of segments that received a score
        - importance_scores: the ensembled score per id (same order)
        - elapsed wall-clock seconds
        - ensemble_list: per score function, its raw (id, score) pairs
        """
        ensemble_list = dict()
        texts = split_context(self.explanation_level, contexts)
        start_time = time.time()
        importance_dict = {}       # segment id -> best weighted score so far
        max_score_func_dict = {}   # segment id -> score function achieving it
        for score_func in self.score_funcs:
            if self.verbose == 1:
                print(f"-Start {score_func}")
            # Leave-one-out scores are up-weighted in the ensemble.
            weight = self.w if score_func == "loo" else 1
            if self.attr_type == "tracllm":
                important_nodes, importance_scores = self.tracllm(
                    question, texts, answer, score_func)
                important_texts, important_ids = unzip_tuples(important_nodes)
            elif self.attr_type == "vanilla_perturb":
                important_texts, importance_scores = self.vanilla_explanation(
                    question, texts, answer, score_func)
                # vanilla_explanation returns the segments unchanged and in
                # order, so ids are positional.  (BUGFIX: texts.index(text)
                # mapped duplicate segments to their first occurrence.)
                important_ids = list(range(len(important_texts)))
            else:
                raise ValueError("Unsupported attr_type.")
            ensemble_list[score_func] = list(zip(important_ids, importance_scores))
            for idx, important_id in enumerate(important_ids):
                weighted_score = weight * importance_scores[idx]
                if important_id in importance_dict:
                    # Keep the maximum weighted score and remember its source.
                    if importance_dict[important_id] < weighted_score:
                        max_score_func_dict[important_id] = score_func
                        importance_dict[important_id] = weighted_score
                else:
                    importance_dict[important_id] = weighted_score
                    max_score_func_dict[important_id] = score_func
        end_time = time.time()
        important_ids = list(importance_dict.keys())
        importance_scores = list(importance_dict.values())
        return texts, important_ids, importance_scores, end_time - start_time, ensemble_list
|