from .attribute import *
import numpy as np
from src.utils import *
import time
import torch
import torch.nn.functional as F
import gc
from src.prompts import wrap_prompt_attention
from .attention_utils import *


class AvgAttentionAttribution(Attribution):
    """Attribute an answer to context segments by averaging attention weights
    over all transformer layers."""

    def __init__(self, llm, explanation_level="segment", K=5, verbose=1):
        super().__init__(llm, explanation_level, K, verbose)
        self.model = llm.model
        self.tokenizer = llm.tokenizer
        self.layers = range(len(llm.model.model.layers))
        self.variant = "default"
        self.explanation_level = explanation_level

    def loss_to_importance(self, losses, sentences_id_list):
        # Each entry of `losses` is (feature_group, loss); when a group grows by
        # exactly one feature, credit that feature with the loss difference.
        importances = np.zeros(len(sentences_id_list))
        for i in range(1, len(losses)):
            group = np.array(losses[i][0])
            last_group = np.array(losses[i - 1][0])
            group_loss = np.array(losses[i][1])
            last_group_loss = np.array(losses[i - 1][1])
            if len(group) - len(last_group) == 1:
                feature_index = [item for item in group if item not in last_group]
                importances[feature_index[0]] += last_group_loss - group_loss
        print("importances: ", importances)
        return importances

    def attribute(self, question: str, contexts: list, answer: str, customized_template: str = None):
        start_time = time.time()
        model = self.model
        tokenizer = self.tokenizer
        model.eval()  # set model to evaluation mode
        contexts = split_context(self.explanation_level, contexts)

        # Tokenize the prompt pieces, the context segments, and the target answer.
        prompt_part1, prompt_part2 = wrap_prompt_attention(question, customized_template)
        prompt_part1_ids = tokenizer(prompt_part1, return_tensors="pt").input_ids.to(model.device)[0]
        # Drop each context's leading BOS token so the segments concatenate cleanly.
        context_ids_list = [
            tokenizer(context, return_tensors="pt").input_ids.to(model.device)[0][1:]
            for context in contexts
        ]
        prompt_part2_ids = tokenizer(prompt_part2, return_tensors="pt").input_ids.to(model.device)[0]
        target_ids = tokenizer(answer, return_tensors="pt").input_ids.to(model.device)[0]

        # Combine prompt, context, and target tokens into a single input sequence.
        sampled_context_ids = context_ids_list
        input_ids = torch.cat(
            [prompt_part1_ids] + sampled_context_ids + [prompt_part2_ids, target_ids], dim=-1
        ).unsqueeze(0)
        self.context_length = sum(len(context_ids) for context_ids in sampled_context_ids)
        self.prompt_length = len(prompt_part1_ids) + self.context_length + len(prompt_part2_ids)
        print("input_ids_shape: ", input_ids.shape)

        # Single forward pass to collect the hidden states of every layer.
        with torch.no_grad():
            outputs = model(input_ids, output_hidden_states=True)
            hidden_states = outputs.hidden_states

        # Recompute attention weights layer by layer (answer tokens as queries)
        # and accumulate their average over the context span.
        with torch.no_grad():
            avg_attentions = None
            for i in self.layers:
                attentions = get_attention_weights_one_layer(
                    model, hidden_states, i, attribution_start=self.prompt_length
                )
                print(attentions.shape)
                context_slice = attentions[
                    :, :, :, len(prompt_part1_ids):len(prompt_part1_ids) + self.context_length
                ]
                if avg_attentions is None:
                    avg_attentions = context_slice
                else:
                    avg_attentions += context_slice

            # Average over layers, then over batch, heads, and query positions.
            avg_attentions = (avg_attentions / len(self.layers)).mean(dim=0).mean(dim=(0, 1)).to(torch.float16)
            gc.collect()
            torch.cuda.empty_cache()

        # Per-token importance values over the context span.
        importance_values = avg_attentions.to(torch.float32).cpu().numpy()
        print("importance_values_shape", importance_values.shape)

        # Start offset of each context segment within the concatenated context span.
        context_lengths = [len(context_ids) for context_ids in sampled_context_ids[:-1]]
        start_positions = np.cumsum([0] + context_lengths)

        # Mean importance value for each context segment.
        group_importance_values = []
        for start, context_ids in zip(start_positions, sampled_context_ids):
            end = start + len(context_ids)
            group_importance_values.append(np.mean(importance_values[start:end]))

        avg_importance_values = np.array(group_importance_values)
        print(len(group_importance_values))

        # Select the top-K context segments by importance.
        top_k_indices = np.argsort(avg_importance_values)[::-1][:self.K]
        top_k_scores = [avg_importance_values[i] for i in top_k_indices]
        end_time = time.time()
        print(f"Topk_indices: {top_k_indices}")
        print(f"Topk_contexts: {[contexts[i] for i in top_k_indices]}")
        print(f"Topk_scores: {top_k_scores}")
        gc.collect()
        torch.cuda.empty_cache()
        return contexts, top_k_indices, top_k_scores, end_time - start_time, None
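

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original module).
# It assumes an LLM wrapper exposing `.model` (a Hugging Face causal LM whose
# decoder layers live at `model.model.layers`, e.g. a Llama-family model) and
# `.tokenizer`. The `SimpleLLM` shim, model name, question, contexts, and
# answer below are hypothetical placeholders, not this repository's API.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    class SimpleLLM:
        """Hypothetical stand-in for the project's LLM wrapper."""

        def __init__(self, name: str):
            self.model = AutoModelForCausalLM.from_pretrained(
                name, torch_dtype=torch.float16, device_map="auto"
            )
            self.tokenizer = AutoTokenizer.from_pretrained(name)

    llm = SimpleLLM("meta-llama/Llama-2-7b-chat-hf")  # assumed model choice
    attributor = AvgAttentionAttribution(llm, explanation_level="segment", K=2)
    contexts, top_k_indices, top_k_scores, elapsed, _ = attributor.attribute(
        question="Who wrote Hamlet?",
        contexts=[
            "Hamlet is a tragedy written by William Shakespeare.",
            "The Eiffel Tower is located in Paris.",
        ],
        answer="William Shakespeare",
    )
    print(f"Attribution took {elapsed:.2f}s; top segments: {top_k_indices}")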