Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,652 Bytes
f214f36 ee19553 f214f36 3a7a5c6 f214f36 3a7a5c6 f214f36 2fae289 e311fe1 f214f36 3a7a5c6 e311fe1 3a7a5c6 f214f36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
from .attribute import *
import numpy as np
from src.utils import *
import time
import torch.nn.functional as F
import gc
from src.prompts import wrap_prompt_attention
from .attention_utils import *
import spaces
class AttnTraceAttribution(Attribution):
    """Context attribution via averaged attention weights (AttnTrace).

    Repeatedly subsamples the context segments, runs the model once per
    subsample, and averages the attention mass that the explained-answer
    tokens place on each context segment. Segments are scored by the mean of
    their top-``avg_k`` token attentions, averaged over the ``B`` rounds in
    which they were sampled; the ``K`` highest-scoring segments are returned.
    """

    def __init__(self, llm, explanation_level="segment", K=5, avg_k=5, q=0.4, B=30, verbose=1):
        """
        Args:
            llm: project LLM wrapper exposing ``.provider``, ``.tokenizer``,
                ``.model`` and ``_load_model_if_needed()``.
            explanation_level: granularity passed to ``split_context``
                (e.g. "segment").
            K: number of top-scoring contexts to return from ``attribute``.
            avg_k: number of top token-level attentions averaged per context.
            q: fraction of contexts kept in each random subsample round.
            B: number of subsampling rounds.
            verbose: verbosity level forwarded to the base class.
        """
        super().__init__(llm, explanation_level, K, verbose)
        self.llm = llm
        self.model = None  # loaded lazily in attribute()
        self.model_type = llm.provider
        self.tokenizer = llm.tokenizer
        self.avg_k = avg_k
        self.q = q
        self.B = B
        self.explanation_level = explanation_level

    def loss_to_importance(self, losses, sentences_id_list):
        """Convert a sequence of (feature_group, loss) pairs into per-feature
        importance scores.

        Whenever two consecutive groups differ by exactly one added feature,
        the loss decrease between them is attributed to that feature.

        Args:
            losses: list of ``(group, loss)`` pairs, where ``group`` is a list
                of feature indices and ``loss`` is a scalar.
            sentences_id_list: full feature-id list (only its length is used).

        Returns:
            np.ndarray of accumulated importance, one entry per feature.
        """
        importances = np.zeros(len(sentences_id_list))
        for i in range(1, len(losses)):
            group = np.array(losses[i][0])
            last_group = np.array(losses[i - 1][0])
            group_loss = np.array(losses[i][1])
            last_group_loss = np.array(losses[i - 1][1])
            # Attribute only when exactly one feature was added to the group.
            if len(group) - len(last_group) == 1:
                feature_index = [item for item in group if item not in last_group]
                # A loss drop after adding the feature counts as importance.
                importances[feature_index[0]] += (last_group_loss - group_loss)
        return importances

    @spaces.GPU
    def attribute(self, question: str, contexts: list, answer: str, explained_answer: str, customized_template: str = None):
        """Score each context segment's contribution to ``explained_answer``.

        Args:
            question: the user question embedded in the prompt template.
            contexts: raw contexts; re-split per ``self.explanation_level``.
            answer: the full model answer.
            explained_answer: the answer span being attributed.
            customized_template: optional prompt template override.

        Returns:
            Tuple ``(contexts, top_k_indices, top_k_scores, elapsed_seconds, None)``.
        """
        start_time = time.time()
        # Fix: identity comparison with None (was `!= None`).
        if self.llm.model is not None:
            self.model = self.llm.model
        else:
            self.model = self.llm._load_model_if_needed().to("cuda")
        self.layers = range(len(self.model.model.layers))
        model = self.model
        tokenizer = self.tokenizer
        model.eval()  # inference only
        contexts = split_context(self.explanation_level, contexts)
        previous_answer = get_previous_answer(answer, explained_answer)
        # Tokenize the template halves, each context, and the answer pieces.
        # Every piece after the first drops its leading token ([1:]) —
        # presumably the tokenizer's BOS token; TODO confirm per tokenizer.
        prompt_part1, prompt_part2 = wrap_prompt_attention(question, customized_template)
        prompt_part1_ids = tokenizer(prompt_part1, return_tensors="pt").input_ids.to(model.device)[0]
        context_ids_list = [tokenizer(context, return_tensors="pt").input_ids.to(model.device)[0][1:] for context in contexts]
        prompt_part2_ids = tokenizer(prompt_part2, return_tensors="pt").input_ids.to(model.device)[0][1:]
        print("previous_answer: ", previous_answer)
        print("explained_answer: ", explained_answer)
        previous_answer_ids = tokenizer(previous_answer, return_tensors="pt").input_ids.to(model.device)[0][1:]
        target_ids = tokenizer(explained_answer, return_tensors="pt").input_ids.to(model.device)[0][1:]
        avg_importance_values = np.zeros(len(context_ids_list))
        idx_frequency = {idx: 0 for idx in range(len(context_ids_list))}
        for t in range(self.B):
            # Randomly subsample a q-fraction of the contexts; sorting keeps
            # the surviving contexts in their original order in the prompt.
            num_samples = int(len(context_ids_list) * self.q)
            sampled_indices = np.sort(np.random.permutation(len(context_ids_list))[:num_samples])
            sampled_context_ids = [context_ids_list[idx] for idx in sampled_indices]
            input_ids = torch.cat([prompt_part1_ids] + sampled_context_ids + [prompt_part2_ids, previous_answer_ids, target_ids], dim=-1).unsqueeze(0)
            self.context_length = sum(len(context_ids) for context_ids in sampled_context_ids)
            self.prompt_length = len(prompt_part1_ids) + self.context_length + len(prompt_part2_ids) + len(previous_answer_ids)
            # Recompute per-layer attention from hidden states instead of
            # output_attentions=True, to save memory on long prompts.
            with torch.no_grad():
                outputs = model(input_ids, output_hidden_states=True)
                hidden_states = outputs.hidden_states
            with torch.no_grad():
                avg_attentions = None  # running sum across layers
                for layer_idx in self.layers:
                    attentions = get_attention_weights_one_layer(model, hidden_states, layer_idx, attribution_start=self.prompt_length, model_type=self.model_type)
                    # Keep only the key columns that fall in the context span.
                    context_cols = attentions[:, :, :, len(prompt_part1_ids):len(prompt_part1_ids) + self.context_length]
                    if avg_attentions is None:
                        avg_attentions = context_cols
                    else:
                        avg_attentions += context_cols
                avg_attentions = (avg_attentions / (len(self.layers))).mean(dim=0).mean(dim=(0, 1)).to(torch.float16)
                importance_values = avg_attentions.to(torch.float32).cpu().numpy()
            # Map token-level importances back onto per-context groups via
            # cumulative start offsets of each sampled context.
            context_lengths = [len(context_ids) for context_ids in sampled_context_ids[:-1]]
            start_positions = np.cumsum([0] + context_lengths)
            group_importance_values = []
            for start, context_ids in zip(start_positions, sampled_context_ids):
                end = start + len(context_ids)
                values = np.sort(importance_values[start:end])
                k = min(self.avg_k, end - start)  # guard against short contexts
                group_mean = np.mean(values[-k:])  # mean of the top-k token scores
                group_importance_values.append(group_mean)
            group_importance_values = np.array(group_importance_values)
            for idx in sampled_indices:
                idx_frequency[idx] += 1
            for i, idx in enumerate(sampled_indices):
                avg_importance_values[idx] += group_importance_values[i]
        # Average each context's score over the rounds it was sampled in.
        # Fix: iterate indices directly (original enumerated context_ids_list
        # and silently discarded the tensor bound to `idx`).
        for i in range(len(context_ids_list)):
            if idx_frequency[i] != 0:
                avg_importance_values[i] /= idx_frequency[i]
        top_k_indices = np.argsort(avg_importance_values)[::-1][:self.K]
        top_k_scores = [avg_importance_values[i] for i in top_k_indices]
        end_time = time.time()
        gc.collect()
        torch.cuda.empty_cache()
        return contexts, top_k_indices, top_k_scores, end_time - start_time, None
|