from .attribute import *
import numpy as np
import random
from src.utils import *
import time
from sklearn.linear_model import LinearRegression
from scipy.spatial.distance import cosine


class PerturbationBasedAttribution(Attribution):
    def __init__(self, llm, explanation_level="segment", K=5, attr_type="tracllm",
                 score_funcs=['stc', 'loo', 'denoised_shapley'], sh_N=5, w=2, beta=0.2, verbose=1):
        super().__init__(llm, explanation_level, K, verbose)
        self.K = K
        self.w = w
        self.sh_N = sh_N
        self.attr_type = attr_type
        self.score_funcs = score_funcs
        self.beta = beta
        if "gpt" not in self.llm.name:
            self.model = llm.model
            self.tokenizer = llm.tokenizer
        self.func_map = {
            "shapley": self.shapley_scores,
            "denoised_shapley": self.denoised_shapley_scores,
            "stc": self.stc_scores,
            "loo": self.loo_scores,
        }

    def marginal_contributions(self, question: str, contexts: list, answer: str) -> list:
        """
        Sample marginal contributions for a Monte Carlo approximation of
        Shapley values, treating each occurrence of a duplicated context as a
        separate player.

        Parameters:
        - question: the user question.
        - contexts: a list of contexts, possibly with duplicates.
        - answer: the answer being attributed; self.context_value scores a
          coalition of contexts against it.
        - self.sh_N random permutations are used for the approximation.

        Returns:
        - A list of length len(contexts); entry i holds the sampled marginal
          contributions of context i.
        """
        k = len(contexts)
        # One list of sampled marginal contributions per context occurrence
        shapley_values = [[] for _ in range(k)]
        for _ in range(self.sh_N):
            # Permute indices rather than texts so duplicates stay distinct
            perm_indices = random.sample(range(k), k)
            # Value of the empty coalition
            coalition_value = self.context_value(question, [""], answer)
            for i, index in enumerate(perm_indices):
                # Coalition of the first i+1 contexts in this permutation,
                # restored to their original document order (sorting the
                # indices, not the texts, keeps duplicates separate)
                coalition = [contexts[idx] for idx in sorted(perm_indices[:i + 1])]
                context_value = self.context_value(question, coalition, answer)
                marginal_contribution = context_value - coalition_value
                shapley_values[index].append(marginal_contribution)
                # The current coalition value is the baseline for the next step
                coalition_value = context_value
        return shapley_values

    def shapley_scores(self, question: str, contexts: list, answer: str) -> list:
        """
        Estimate Shapley values with a Monte Carlo approximation: each
        context's score is the mean of its sampled marginal contributions.

        Returns:
        - A list with every context's Shapley value.
""" marginal_values= self.marginal_contributions(question, contexts, answer) shapley_values = np.zeros(len(marginal_values)) for i,value_list in enumerate(marginal_values): shapley_values[i] = np.mean(value_list) return shapley_values def denoised_shapley_scores(self, question:str, contexts:list, answer:str) -> list: marginal_values = self.marginal_contributions(question, contexts, answer) new_shapley_values = np.zeros(len(marginal_values)) for i,value_list in enumerate(marginal_values): new_shapley_values[i] = mean_of_percent(value_list,self.beta) return new_shapley_values def stc_scores(self, question:str, contexts:list, answer:str) -> list: k = len(contexts) scores = np.zeros(k) goal_score = self.context_value(question,[''],answer) for i,text in enumerate(contexts): scores[i] = (self.context_value(question, [text], answer) - goal_score) return scores.tolist() def loo_scores(self, question:str, contexts:list, answer:str) -> list: k = len(contexts) scores = np.zeros(k) v_all = self.context_value(question, contexts, answer) for i,text in enumerate(contexts): rest_texts = contexts[:i] + contexts[i+1:] scores[i] = v_all - self.context_value(question, rest_texts, answer) return scores.tolist() def tracllm(self, question:str, contexts:list, answer:str, score_func): current_nodes =[manual_zip(contexts, list(range(len(contexts))))] current_nodes_scores = None def get_important_nodes(nodes,importance_values): combined = list(zip(nodes, importance_values)) combined_sorted = sorted(combined, key=lambda x: x[1], reverse=True) # Determine the number of top nodes to keep k = min(self.K, len(combined)) top_nodes = combined_sorted[:k] top_nodes_sorted = sorted(top_nodes, key=lambda x: combined.index(x)) # Extract the top k important nodes and their scores in the original order important_nodes = [node for node, _ in top_nodes_sorted] important_nodes_scores = [score for _, score in top_nodes_sorted] return important_nodes, important_nodes_scores level = 0 while len(current_nodes)>0 and any(len(node) > 1 for node in current_nodes): level+=1 if self.verbose == 1: print(f"======= layer: {level}=======") new_nodes = [] for node in current_nodes: if len(node)>1: mid = len(node) // 2 node_left, node_right = node[:mid], node[mid:] new_nodes.append(node_left) new_nodes.append(node_right) else: new_nodes.append(node) if len(new_nodes)<= self.K: current_nodes = new_nodes else: importance_values= self.func_map[score_func](question, [" ".join(unzip_tuples(node)[0]) for node in new_nodes], answer) current_nodes,current_nodes_scores = get_important_nodes(new_nodes,importance_values) flattened_current_nodes = [item for sublist in current_nodes for item in sublist] return flattened_current_nodes, current_nodes_scores def vanilla_explanation(self, question:str, texts:list, answer:str,score_func): texts_scores = self.func_map[score_func](question, texts, answer) return texts,texts_scores def attribute(self, question:str, contexts:list, answer:str): """ Given question, contexts and answer, return attribution results """ ensemble_list = dict() texts = split_context(self.explanation_level,contexts) start_time = time.time() importance_dict = {} max_score_func_dict = {} score_funcs = self.score_funcs for score_func in score_funcs: if self.verbose == 1: print(f"-Start {score_func}") if score_func == "loo": weight = self.w else: weight = 1 if self.attr_type == "tracllm": important_nodes,importance_scores = self.tracllm(question, texts, answer,score_func) important_texts, important_ids = unzip_tuples(important_nodes) elif 
                important_texts, importance_scores = self.vanilla_explanation(question, texts, answer, score_func)
                texts = split_context(self.explanation_level, contexts)
                important_ids = [texts.index(text) for text in important_texts]
            else:
                raise ValueError("Unsupported attr_type.")
            ensemble_list[score_func] = list(zip(important_ids, importance_scores))
            for idx, important_id in enumerate(important_ids):
                weighted_score = weight * importance_scores[idx]
                # NOTE: the source is truncated from this point on; the ensemble
                # below is a reconstruction that keeps, per segment, the maximum
                # weighted score across score functions and records which
                # function produced it.
                if important_id in importance_dict:
                    if importance_dict[important_id] < weighted_score:
                        importance_dict[important_id] = weighted_score
                        max_score_func_dict[important_id] = score_func
                else:
                    importance_dict[important_id] = weighted_score
                    max_score_func_dict[important_id] = score_func
        if self.verbose == 1:
            print(f"Attribution time: {time.time() - start_time:.2f}s")
        # Rank segments by ensembled importance and keep the top K (the exact
        # return signature is not visible in the truncated source)
        sorted_items = sorted(importance_dict.items(), key=lambda kv: kv[1], reverse=True)[:self.K]
        important_ids = [i for i, _ in sorted_items]
        importance_scores = [s for _, s in sorted_items]
        return important_ids, importance_scores, ensemble_list, max_score_func_dict
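
# Usage sketch (illustrative only, not part of the original module): `llm`,
# `question`, `contexts`, and `answer` are assumed to exist. `llm` must expose
# .name (and .model / .tokenizer for non-GPT models), and context_value is
# inherited from the Attribution base class in .attribute. The unpacked return
# values follow the reconstructed signature of attribute() above.
#
#   attributor = PerturbationBasedAttribution(
#       llm,
#       explanation_level="segment",
#       K=5,
#       attr_type="tracllm",
#       score_funcs=["stc", "loo", "denoised_shapley"],
#   )
#   important_ids, scores, ensemble, best_funcs = attributor.attribute(
#       question, contexts, answer
#   )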