import math

import matplotlib.pyplot as plt
import pandas as pd
import torch
from nltk.translate.bleu_score import sentence_bleu

from src.prompts import wrap_prompt
from src.utils import *
class Attribution:
    """Base class for attributing an LLM answer to the retrieved contexts."""

    def __init__(self, llm, explanation_level, K, verbose):
        self.llm = llm
        self.explanation_level = explanation_level
        self.verbose = verbose
        self.K = K
        # context_value() needs direct access to the underlying model and tokenizer
        # for white-box scoring; this assumes the llm wrapper exposes them.
        self.model = getattr(llm, "model", None)
        self.tokenizer = getattr(llm, "tokenizer", None)

    def attribute(self):
        # To be implemented by subclasses.
        pass
    def context_value(self, question: str, contexts: list, answer: str) -> float:
        """Score how well the given contexts support the answer to the question."""
        if "gpt" in self.llm.name:  # black-box models: use BLEU similarity as the score
            prompt = wrap_prompt(question, contexts)
            new_answer = self.llm.query(prompt)
            reference_tokens = answer.split()
            candidate_tokens = new_answer.split()
            # BLEU between the original answer and the answer regenerated from these contexts
            similarity = sentence_bleu([reference_tokens], candidate_tokens)
            return similarity
        else:
            # White-box models: use the log-probability of the answer given the prompt.
            # Encode the prompt and answer separately, then concatenate them.
            prompt = wrap_prompt(question, contexts)
            prompt_ids = self.tokenizer.encode(prompt, return_tensors='pt', add_special_tokens=True).to(self.model.device)
            answer_ids = self.tokenizer.encode(answer, return_tensors='pt', add_special_tokens=False).to(self.model.device)
            combined_ids = torch.cat([prompt_ids, answer_ids], dim=1)
            # Position in the shifted sequence where the answer tokens start
            response_start_pos = prompt_ids.shape[1] - 1
            # Run the model on the combined input IDs
            with torch.no_grad():
                outputs = self.model(combined_ids)
                logits = outputs.logits
            # Shift logits and labels so that logits[:, i] predicts combined_ids[:, i + 1]
            shift_logits = logits[:, :-1, :]
            shift_labels = combined_ids[:, 1:]
            # Convert logits to probabilities and pick the probability of each actual next token
            probs = torch.softmax(shift_logits, dim=-1)
            response_probs = torch.gather(probs, 2, shift_labels.unsqueeze(-1)).squeeze(-1)
            response_log_probs = torch.log(response_probs[0, response_start_pos:])
            # Total log-probability of the answer tokens
            value = torch.sum(response_log_probs).item()
            # Clamp -inf (a zero-probability token) to a large negative constant
            if math.isinf(value):
                value = -1000.0
            return value
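    # Worked example of the white-box score: for a 3-token answer whose tokens get
    # probabilities 0.5, 0.2 and 0.1 under the model, the returned value is
    # log(0.5) + log(0.2) + log(0.1) ≈ -4.61; values closer to 0 indicate contexts
    # that better support the answer.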
    def visualize_results(self, texts, question, answer, important_ids, importance_scores, width=200):
        # Only visualize the top-K most important texts
        topk_ids, topk_scores = get_top_k(important_ids, importance_scores, self.K)
        plot_sentence_importance(question, texts, topk_ids, topk_scores, answer, width=width)
    def visualize_score_func_contribution(self, important_ids, importance_scores, ensemble_list):
        important_ids, importance_scores = get_top_k(important_ids, importance_scores, self.K)
        # Count, for each score function, how many top-K texts it assigns the highest score to
        score_func_contributions = {func: 0 for func in ensemble_list.keys()}
        for important_id in important_ids:
            max_score = -math.inf
            max_score_func = None
            for score_func in ensemble_list.keys():
                for id, score in ensemble_list[score_func]:
                    if id == important_id:
                        if score > max_score:
                            max_score = score
                            max_score_func = score_func
                        break  # Stop scanning this score function's list once the id is found
            if max_score_func is not None:
                score_func_contributions[max_score_func] += 1
        plt.figure(figsize=(10, 6))
        bar_width = 0.3  # Thinner bars
        plt.bar(score_func_contributions.keys(), score_func_contributions.values(), width=bar_width, color='skyblue')
        plt.xlabel('Score Function', fontsize=14)
        plt.ylabel('Number of Important Texts', fontsize=14)
        plt.title('Contribution of Each Score Function', fontsize=16)
        plt.xticks(rotation=45, fontsize=13)
        plt.yticks(fontsize=13)
        plt.tight_layout()
        plt.show()
    def get_data_frame(self, texts, important_ids, importance_scores):
        important_ids, importance_scores = get_top_k(important_ids, importance_scores, self.K)
        data = {
            'Important Texts': [texts[id] for id in important_ids],
            'Important IDs': important_ids,
            'Importance Score': importance_scores
        }
        df = pd.DataFrame(data)
        return df
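

# Minimal usage sketch, assuming an `llm` wrapper that exposes `name`, `query()`,
# `model` and `tokenizer` as used above; the explanation_level value and the example
# question/contexts/answer below are placeholders, not part of the module itself.
if __name__ == "__main__":
    llm = ...  # load an LLM wrapper here (e.g. via the repo's own model loader)
    attribution = Attribution(llm, explanation_level="sentence", K=5, verbose=True)
    question = "Who wrote Hamlet?"
    contexts = [
        "Hamlet is a tragedy written by William Shakespeare.",
        "The Eiffel Tower is located in Paris.",
    ]
    answer = "William Shakespeare"
    # Score how strongly this set of contexts supports the answer
    score = attribution.context_value(question, contexts, answer)
    print("context value:", score)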