# AttnTrace / src / attribution / attntrace.py
# SecureLLMSys's picture
# update
# e311fe1
from .attribute import *
import numpy as np
from src.utils import *
import time
import torch.nn.functional as F
import gc
from src.prompts import wrap_prompt_attention
from .attention_utils import *
import spaces
class AttnTraceAttribution(Attribution):
    """Rank context segments by the attention the explained answer pays to them.

    Over ``B`` rounds a random subset of the context segments is placed in the
    prompt, the attention from the answer tokens to the context tokens is
    averaged over all layers, and each sampled segment is scored by the mean
    of its ``avg_k`` highest token scores. Per-segment scores are averaged
    over the rounds in which the segment was sampled.

    Args:
        llm: Wrapper exposing ``provider``, ``tokenizer``, ``model`` and
            ``_load_model_if_needed()``.
        explanation_level: Granularity forwarded to ``split_context``.
        K: Number of top segments to return (forwarded to the base class;
            read back as ``self.K`` -- presumably set there, verify).
        avg_k: Number of highest token scores averaged per segment.
        q: Fraction of the segments sampled in each round.
        B: Number of sampling rounds.
        verbose: Forwarded to the base class.
    """

    def __init__(self, llm, explanation_level="segment", K=5, avg_k=5, q=0.4, B=30, verbose=1):
        super().__init__(llm, explanation_level, K, verbose)
        self.llm = llm
        self.model = None  # resolved on the first call to attribute()
        self.model_type = llm.provider
        self.tokenizer = llm.tokenizer
        self.avg_k = avg_k
        self.q = q
        self.B = B
        self.explanation_level = explanation_level

    def loss_to_importance(self, losses, sentences_id_list):
        """Turn a sequence of (group, loss) records into per-feature importances.

        For each consecutive pair of records whose group grew by exactly one
        element, the loss decrease ``previous_loss - current_loss`` is credited
        to the first element that is new in the current group.

        Args:
            losses: Indexable sequence whose entries hold the group at index 0
                and its loss at index 1.
            sentences_id_list: Only its length is used, to size the result.

        Returns:
            ``np.ndarray`` of length ``len(sentences_id_list)``.
        """
        importances = np.zeros(len(sentences_id_list))
        for i in range(1, len(losses)):
            group = np.array(losses[i][0])
            last_group = np.array(losses[i - 1][0])
            group_loss = np.array(losses[i][1])
            last_group_loss = np.array(losses[i - 1][1])
            if len(group) - len(last_group) != 1:
                continue
            feature_index = [item for item in group if item not in last_group]
            importances[feature_index[0]] += last_group_loss - group_loss
        return importances

    def _encode(self, text, device, drop_first_token=True):
        """Tokenize ``text`` into a 1-D id tensor on ``device``.

        ``drop_first_token`` removes the leading id; the caller keeps it only
        for the very first prompt piece (presumably a BOS token the tokenizer
        prepends -- verify for tokenizers that do not add one).
        """
        ids = self.tokenizer(text, return_tensors="pt").input_ids.to(device)[0]
        return ids[1:] if drop_first_token else ids

    def _context_token_scores(self, model, input_ids, context_start, context_end):
        """Return one layer-averaged attention score per context token.

        Runs a single forward pass, asks ``get_attention_weights_one_layer``
        for every layer's attention starting at ``self.prompt_length``, keeps
        the key positions ``[context_start, context_end)`` and averages over
        layers and the remaining leading dimensions.
        """
        with torch.no_grad():
            hidden_states = model(input_ids, output_hidden_states=True).hidden_states
            summed = None
            for layer in self.layers:
                attentions = get_attention_weights_one_layer(
                    model,
                    hidden_states,
                    layer,
                    attribution_start=self.prompt_length,
                    model_type=self.model_type,
                )
                context_attn = attentions[:, :, :, context_start:context_end]
                if summed is None:
                    summed = context_attn
                else:
                    # In-place accumulation avoids holding every layer's map.
                    summed += context_attn
            # NOTE(review): the float16 round trip is kept from the original
            # implementation; it truncates precision before the numpy step.
            averaged = (summed / len(self.layers)).mean(dim=0).mean(dim=(0, 1)).to(torch.float16)
        return averaged.to(torch.float32).cpu().numpy()

    def _segment_scores(self, token_scores, segment_lengths):
        """Pool token scores into one score per segment (mean of top ``avg_k``).

        Segments are laid out back to back in ``token_scores``. A segment with
        no tokens scores 0.0: averaging an empty slice would yield NaN, and
        NaN ends up ranked first once the ascending argsort is reversed.
        """
        scores = []
        start = 0
        for length in segment_lengths:
            if length == 0:
                scores.append(0.0)
                continue
            ordered = np.sort(token_scores[start:start + length])
            k = min(self.avg_k, length)
            scores.append(np.mean(ordered[-k:]))
            start += length
        return np.array(scores)

    @spaces.GPU
    def attribute(self, question: str, contexts: list, answer: str, explained_answer: str, customized_template: str = None):
        """Score context segments for ``explained_answer`` and return the top ``K``.

        Args:
            question: Question used to build the prompt.
            contexts: Raw contexts; split by ``split_context`` at
                ``self.explanation_level``.
            answer: Full generated answer.
            explained_answer: Part of the answer to attribute; the text
                returned by ``get_previous_answer(answer, explained_answer)``
                is placed before it in the model input.
            customized_template: Optional template forwarded to
                ``wrap_prompt_attention``.

        Returns:
            Tuple ``(contexts, top_k_indices, top_k_scores, elapsed, None)``
            where ``contexts`` is the split context list, ``top_k_indices`` is
            an array of segment indices in descending score order,
            ``top_k_scores`` the matching scores, and ``elapsed`` the wall
            time in seconds measured before cache cleanup.
        """
        start_time = time.time()
        if self.llm.model is not None:
            self.model = self.llm.model
        else:
            self.model = self.llm._load_model_if_needed().to("cuda")
        self.layers = range(len(self.model.model.layers))
        model = self.model
        model.eval()

        contexts = split_context(self.explanation_level, contexts)
        previous_answer = get_previous_answer(answer, explained_answer)

        # Tokenize every piece of the input separately so the context span
        # can be located by token offsets later on.
        prompt_part1, prompt_part2 = wrap_prompt_attention(question, customized_template)
        device = model.device
        prompt_part1_ids = self._encode(prompt_part1, device, drop_first_token=False)
        context_ids_list = [self._encode(context, device) for context in contexts]
        prompt_part2_ids = self._encode(prompt_part2, device)
        print("previous_answer: ", previous_answer)
        print("explained_answer: ", explained_answer)
        previous_answer_ids = self._encode(previous_answer, device)
        target_ids = self._encode(explained_answer, device)

        num_contexts = len(context_ids_list)
        avg_importance_values = np.zeros(num_contexts)
        idx_frequency = np.zeros(num_contexts, dtype=int)

        # Sample a fraction q of the segments per round, but never zero of
        # them: int() truncation would otherwise leave every score at 0 when
        # num_contexts * q < 1.
        num_samples = int(num_contexts * self.q)
        if num_contexts:
            num_samples = min(max(num_samples, 1), num_contexts)

        context_start = len(prompt_part1_ids)
        for _ in range(self.B):
            # Sorting keeps the sampled segments in their original order.
            sampled_indices = np.sort(np.random.permutation(num_contexts)[:num_samples])
            sampled_context_ids = [context_ids_list[idx] for idx in sampled_indices]
            input_ids = torch.cat(
                [prompt_part1_ids] + sampled_context_ids + [prompt_part2_ids, previous_answer_ids, target_ids],
                dim=-1,
            ).unsqueeze(0)

            segment_lengths = [len(context_ids) for context_ids in sampled_context_ids]
            self.context_length = sum(segment_lengths)
            self.prompt_length = (
                context_start + self.context_length + len(prompt_part2_ids) + len(previous_answer_ids)
            )

            importance_values = self._context_token_scores(
                model, input_ids, context_start, context_start + self.context_length
            )
            group_importance_values = self._segment_scores(importance_values, segment_lengths)

            # sampled_indices holds no duplicates, so fancy-index += is safe.
            idx_frequency[sampled_indices] += 1
            avg_importance_values[sampled_indices] += group_importance_values

        # Average each segment over the rounds in which it was sampled;
        # never-sampled segments keep a score of 0.
        sampled_mask = idx_frequency > 0
        avg_importance_values[sampled_mask] /= idx_frequency[sampled_mask]

        top_k_indices = np.argsort(avg_importance_values)[::-1][:self.K]
        top_k_scores = [avg_importance_values[i] for i in top_k_indices]
        end_time = time.time()

        gc.collect()
        torch.cuda.empty_cache()
        return contexts, top_k_indices, top_k_scores, end_time - start_time, None