# AttnTrace: src/attribution/perturbation_based.py
from .attribute import *
import numpy as np
import random
from src.utils import *
import time
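
# Perturbation-based context attribution (TracLLM-style): each context segment is
# scored by how much the value of the answer changes when the context is perturbed,
# i.e. when segments are removed, isolated, or sampled into coalitions. The helpers
# manual_zip, unzip_tuples, split_context and mean_of_percent are assumed to come
# from src.utils, and context_value is assumed to be inherited from Attribution,
# returning a scalar measure of how well the given contexts support producing
# `answer` (e.g., a likelihood-based score).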
class PerturbationBasedAttribution(Attribution):
    def __init__(self, llm, explanation_level="segment", K=5, attr_type="tracllm",
                 score_funcs=['stc', 'loo', 'denoised_shapley'], sh_N=5, w=2, beta=0.2, verbose=1):
        super().__init__(llm, explanation_level, K, verbose)
        self.K = K                      # top-K nodes kept per level of the hierarchical search
        self.w = w                      # extra weight applied to LOO scores when ensembling
        self.sh_N = sh_N                # number of random permutations for Monte Carlo Shapley
        self.attr_type = attr_type      # "tracllm" (hierarchical search) or "vanilla_perturb"
        self.score_funcs = score_funcs  # perturbation scores to compute and ensemble
        self.beta = beta                # fraction passed to mean_of_percent for denoised Shapley
        if "gpt" not in self.llm.name:  # local (non-GPT API) models expose model/tokenizer
            self.model = llm.model
            self.tokenizer = llm.tokenizer
        self.func_map = {
            "shapley": self.shapley_scores,
            "denoised_shapley": self.denoised_shapley_scores,
            "stc": self.stc_scores,
            "loo": self.loo_scores,
        }
    def marginal_contributions(self, question: str, contexts: list, answer: str) -> list:
        """
        Sample marginal contributions for a Monte Carlo Shapley approximation,
        treating each occurrence of a context separately (duplicates included).

        Parameters:
        - question: the user question.
        - contexts: a list of contexts, possibly with duplicates.
        - answer: the answer being attributed.

        self.sh_N random permutations are sampled, and self.context_value is used
        as the coalition value function.

        Returns:
        - A list with, for each context occurrence, the list of its sampled
          marginal contributions.
        """
        k = len(contexts)
        # One list of sampled marginal contributions per context occurrence
        shapley_values = [[] for _ in range(k)]
        for _ in range(self.sh_N):
            # Random permutation of context indices (indices, not texts, so duplicates stay distinct)
            perm_indices = random.sample(range(k), k)
            # Value of the empty coalition (empty context)
            coalition_value = self.context_value(question, [""], answer)
            for i, index in enumerate(perm_indices):
                # Coalition of the first i + 1 contexts in the permutation,
                # restored to their original document order
                coalition = [contexts[idx] for idx in sorted(perm_indices[:i + 1])]
                # Value of the current coalition
                context_value = self.context_value(question, coalition, answer)
                marginal_contribution = context_value - coalition_value
                # Marginal contribution of the context added at this step
                shapley_values[index].append(marginal_contribution)
                # Reuse this coalition's value at the next step
                coalition_value = context_value
        return shapley_values
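
    # The Shapley value of context i is estimated from the sampled marginal
    # contributions as phi_i ~ (1/N) * sum over permutations pi of
    # [ v(S_pi + {i}) - v(S_pi) ], where S_pi is the set of contexts preceding i
    # in permutation pi, v is context_value, and N = self.sh_N. shapley_scores
    # takes the plain mean of these samples, while denoised_shapley_scores
    # aggregates them with mean_of_percent (controlled by beta) to reduce
    # Monte Carlo noise.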
    def shapley_scores(self, question: str, contexts: list, answer: str) -> list:
        """
        Estimate Shapley values with a Monte Carlo approximation.

        Parameters:
        - question: the user question.
        - contexts: a list of contexts.
        - answer: the answer being attributed.

        Returns:
        - An array with each context's approximate Shapley value, i.e. the mean
          of its sampled marginal contributions.
        """
        marginal_values = self.marginal_contributions(question, contexts, answer)
        shapley_values = np.zeros(len(marginal_values))
        for i, value_list in enumerate(marginal_values):
            shapley_values[i] = np.mean(value_list)
        return shapley_values
    def denoised_shapley_scores(self, question: str, contexts: list, answer: str) -> list:
        """Shapley estimate in which each context's sampled marginal contributions
        are aggregated by mean_of_percent (controlled by self.beta) instead of a
        plain mean, to reduce Monte Carlo noise."""
        marginal_values = self.marginal_contributions(question, contexts, answer)
        new_shapley_values = np.zeros(len(marginal_values))
        for i, value_list in enumerate(marginal_values):
            new_shapley_values[i] = mean_of_percent(value_list, self.beta)
        return new_shapley_values
    def stc_scores(self, question: str, contexts: list, answer: str) -> list:
        """Score each context by how much it alone raises the answer's value over
        an empty context."""
        k = len(contexts)
        scores = np.zeros(k)
        base_score = self.context_value(question, [''], answer)
        for i, text in enumerate(contexts):
            scores[i] = self.context_value(question, [text], answer) - base_score
        return scores.tolist()
    def loo_scores(self, question: str, contexts: list, answer: str) -> list:
        """Leave-one-out: score each context by the drop in the answer's value
        when that context is removed from the full set."""
        k = len(contexts)
        scores = np.zeros(k)
        v_all = self.context_value(question, contexts, answer)
        for i, text in enumerate(contexts):
            rest_texts = contexts[:i] + contexts[i + 1:]
            scores[i] = v_all - self.context_value(question, rest_texts, answer)
        return scores.tolist()
    def tracllm(self, question: str, contexts: list, answer: str, score_func):
        """
        Hierarchical search: repeatedly split the context into halves, score the
        resulting nodes with `score_func`, and keep only the top-K nodes per level,
        until every surviving node is a single segment.
        """
        # Each node is a list of (segment_text, original_index) pairs
        current_nodes = [manual_zip(contexts, list(range(len(contexts))))]
        current_nodes_scores = None

        def get_important_nodes(nodes, importance_values):
            # Keep the top-K nodes by importance, then restore their original order
            order = sorted(range(len(nodes)), key=lambda i: importance_values[i], reverse=True)
            k = min(self.K, len(nodes))
            top_indices = sorted(order[:k])
            important_nodes = [nodes[i] for i in top_indices]
            important_nodes_scores = [importance_values[i] for i in top_indices]
            return important_nodes, important_nodes_scores

        level = 0
        while len(current_nodes) > 0 and any(len(node) > 1 for node in current_nodes):
            level += 1
            if self.verbose == 1:
                print(f"======= layer: {level} =======")
            # Split every multi-segment node in half
            new_nodes = []
            for node in current_nodes:
                if len(node) > 1:
                    mid = len(node) // 2
                    new_nodes.append(node[:mid])
                    new_nodes.append(node[mid:])
                else:
                    new_nodes.append(node)
            if len(new_nodes) <= self.K:
                # Few enough nodes: keep them all and keep splitting
                current_nodes = new_nodes
            else:
                # Score each node's concatenated text, then prune to the top-K nodes
                importance_values = self.func_map[score_func](
                    question, [" ".join(unzip_tuples(node)[0]) for node in new_nodes], answer)
                current_nodes, current_nodes_scores = get_important_nodes(new_nodes, importance_values)
        # Flatten the surviving single-segment nodes into (segment_text, original_index) pairs
        flattened_current_nodes = [item for sublist in current_nodes for item in sublist]
        return flattened_current_nodes, current_nodes_scores
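
    # Worked example (illustrative): with 16 segments and K=4, level 1 splits the
    # root into 2 nodes of 8 segments (kept, since 2 <= K); level 2 gives 4 nodes
    # of 4 segments (kept); level 3 gives 8 nodes of 2 segments, which are scored
    # and pruned to the top 4; level 4 splits those into 8 single segments and
    # prunes to the top 4, which are returned together with their scores.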
    def vanilla_explanation(self, question: str, texts: list, answer: str, score_func):
        """Score every segment directly with `score_func`, without hierarchical pruning."""
        texts_scores = self.func_map[score_func](question, texts, answer)
        return texts, texts_scores
    def attribute(self, question: str, contexts: list, answer: str):
        """
        Given a question, contexts and an answer, return the attribution results:
        the split segments, the ids and scores of the most important segments
        (ensembled across score_funcs by keeping each segment's maximum weighted
        score, with LOO scores weighted by self.w), the elapsed time, and the
        per-score-function results.
        """
        ensemble_list = dict()
        texts = split_context(self.explanation_level, contexts)
        start_time = time.time()
        importance_dict = {}
        max_score_func_dict = {}
        score_funcs = self.score_funcs
        for score_func in score_funcs:
            if self.verbose == 1:
                print(f"-Start {score_func}")
            # LOO scores receive extra weight self.w in the ensemble
            weight = self.w if score_func == "loo" else 1
            if self.attr_type == "tracllm":
                important_nodes, importance_scores = self.tracllm(question, texts, answer, score_func)
                important_texts, important_ids = unzip_tuples(important_nodes)
            elif self.attr_type == "vanilla_perturb":
                important_texts, importance_scores = self.vanilla_explanation(question, texts, answer, score_func)
                # vanilla_explanation scores every segment in its original order
                important_ids = list(range(len(important_texts)))
            else:
                raise ValueError("Unsupported attr_type.")
            ensemble_list[score_func] = list(zip(important_ids, importance_scores))
            # Keep, for each segment, the maximum weighted score across score functions
            for idx, important_id in enumerate(important_ids):
                weighted_score = weight * importance_scores[idx]
                if important_id not in importance_dict or importance_dict[important_id] < weighted_score:
                    importance_dict[important_id] = weighted_score
                    max_score_func_dict[important_id] = score_func
        end_time = time.time()
        important_ids = list(importance_dict.keys())
        importance_scores = list(importance_dict.values())
        return texts, important_ids, importance_scores, end_time - start_time, ensemble_list
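

# Example usage (a sketch, not part of the original module): it assumes an `llm`
# wrapper exposing .name (and .model/.tokenizer for non-GPT models) and a
# context_value implementation inherited from the Attribution base class.
#
#   attributor = PerturbationBasedAttribution(
#       llm, explanation_level="segment", K=5, attr_type="tracllm",
#       score_funcs=["stc", "loo", "denoised_shapley"],
#   )
#   texts, ids, scores, elapsed, ensemble = attributor.attribute(question, contexts, answer)
#   for i, s in sorted(zip(ids, scores), key=lambda t: t[1], reverse=True)[:5]:
#       print(f"segment {i}: score={s:.4f}\n{texts[i]}")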