File size: 4,820 Bytes
f214f36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from src.prompts import wrap_prompt
import torch
import math
from src.utils import *
from nltk.translate.bleu_score import sentence_bleu
import pandas as pd
import matplotlib.pyplot as plt
class Attribution:
    """Base class for attributing an LLM answer to the context texts behind it.

    Subclasses implement :meth:`attribute`. The non-GPT branch of
    :meth:`context_value` reads ``self.tokenizer`` and ``self.model``, which
    this class never assigns -- presumably subclasses set them (TODO confirm).
    """

    def __init__(self, llm, explanation_level, K, verbose):
        """Store configuration shared by all attribution methods.

        Args:
            llm: model wrapper; this class reads ``llm.name`` and calls
                ``llm.query(prompt)``.
            explanation_level: stored as-is; not read by this class.
            K: number of top-scoring texts kept by the visualization and
                report helpers.
            verbose: stored as-is; not read by this class.
        """
        self.llm = llm
        self.explanation_level = explanation_level
        self.verbose = verbose
        self.K = K

    def attribute(self):
        """Placeholder for subclasses; the base version does nothing and returns None."""

    def context_value(self, question: str, contexts: list, answer: str) -> float:
        """Score how strongly ``contexts`` support ``answer`` to ``question``.

        If ``"gpt"`` appears in ``self.llm.name`` (treated as a black-box model),
        the LLM is re-queried with the wrapped prompt, and the result is the
        sentence-level BLEU of the new answer against ``answer``. Both are
        tokenized by whitespace.

        Otherwise, the result is the total log-probability that ``self.model``
        assigns to the tokens of ``answer`` following the wrapped prompt. Higher
        means the answer is more likely given the contexts. An infinite total is
        reported as ``-1000.0``.
        """
        prompt = wrap_prompt(question, contexts)

        if "gpt" in self.llm.name:  # use BLEU score for black-box models
            new_answer = self.llm.query(prompt)
            # NOTE(review): no smoothing function is passed, so short answers
            # with no higher-order n-gram overlap score 0.
            return sentence_bleu([answer.split()], new_answer.split())

        device = self.model.device
        # Encode prompt and answer separately so we know where the answer starts.
        prompt_ids = self.tokenizer.encode(prompt, return_tensors='pt', add_special_tokens=True).to(device)
        answer_ids = self.tokenizer.encode(answer, return_tensors='pt', add_special_tokens=False).to(device)
        combined_ids = torch.cat([prompt_ids, answer_ids], dim=1)

        with torch.no_grad():
            logits = self.model(combined_ids).logits

        return self._answer_log_prob(logits, combined_ids, prompt_ids.shape[1])

    @staticmethod
    def _answer_log_prob(logits, combined_ids, prompt_len: int) -> float:
        """Sum the log-probabilities of the answer tokens that follow the prompt.

        Args:
            logits: model output for ``combined_ids``, indexed as
                ``logits[0, position, vocab]`` (batch size 1 assumed).
            combined_ids: prompt token ids followed by answer token ids,
                shape ``(1, seq_len)``.
            prompt_len: number of prompt tokens at the front of
                ``combined_ids``; must be >= 1 so the first answer token has a
                predicting position.

        Returns:
            The total answer log-probability, or ``-1000.0`` if it is infinite.
        """
        # The logit at position t predicts token t+1, so the answer tokens
        # combined_ids[0, prompt_len:] are scored by logits[0, prompt_len-1:-1].
        answer_logits = logits[0, prompt_len - 1:-1, :].float()
        answer_labels = combined_ids[0, prompt_len:]
        # log_softmax is numerically stable. The old log(softmax(x)) underflowed
        # tiny probabilities to 0 and produced log(0) = -inf.
        log_probs = torch.log_softmax(answer_logits, dim=-1)
        token_log_probs = log_probs.gather(-1, answer_labels.unsqueeze(-1)).squeeze(-1)
        value = token_log_probs.sum().item()

        # Still reachable if the model emits -inf logits (e.g. masked vocab entries).
        if math.isinf(value):
            value = -1000.0
        return value

    def visualize_results(self, texts, question, answer, important_ids, importance_scores, width=200):
        """Plot the top-``self.K`` important texts for ``question``/``answer``.

        Selection is delegated to ``get_top_k`` and drawing to
        ``plot_sentence_importance``; both presumably come from ``src.utils``.
        """
        topk_ids, topk_scores = get_top_k(important_ids, importance_scores, self.K)
        plot_sentence_importance(question, texts, topk_ids, topk_scores, answer, width=width)

    @staticmethod
    def _count_score_func_contributions(important_ids, ensemble_list) -> dict:
        """Count, per score function, how many ids it gave the highest score.

        Args:
            important_ids: ids to attribute.
            ensemble_list: mapping from score-function name to a sequence of
                ``(id, score)`` pairs. If an id repeats, its first pair is used.

        Returns:
            Dict mapping every key of ``ensemble_list`` to its count. Ties go to
            the function listed first. Ids that no function scored are skipped.
        """
        contributions = {func: 0 for func in ensemble_list}

        # First score each function reports per id (one lookup instead of a rescan).
        first_scores = {}
        for func, pairs in ensemble_list.items():
            scores = {}
            for text_id, score in pairs:
                scores.setdefault(text_id, score)
            first_scores[func] = scores

        for important_id in important_ids:
            # Reset per id: previously a stale winner from the prior id was reused.
            best_func = None
            best_score = -math.inf
            for func in ensemble_list:
                score = first_scores[func].get(important_id)
                if score is not None and score > best_score:
                    best_score, best_func = score, func
            if best_func is not None:
                contributions[best_func] += 1
        return contributions

    def visualize_score_func_contribution(self, important_ids, importance_scores, ensemble_list):
        """Bar-plot how many top-``self.K`` texts each score function won.

        ``ensemble_list`` maps score-function name to ``(id, score)`` pairs;
        an id is credited to the function that scored it highest.
        """
        important_ids, importance_scores = get_top_k(important_ids, importance_scores, self.K)
        contributions = self._count_score_func_contributions(important_ids, ensemble_list)

        plt.figure(figsize=(10, 6))
        bar_width = 0.3  # thinner bars than matplotlib's default
        plt.bar(list(contributions.keys()), list(contributions.values()), width=bar_width, color='skyblue')
        plt.xlabel('Score Function', fontsize=14)
        plt.ylabel('Number of Important Texts', fontsize=14)
        plt.title('Contribution of Each Score Function', fontsize=16)
        plt.xticks(rotation=45, fontsize=13)
        plt.yticks(fontsize=13)
        plt.tight_layout()
        plt.show()

    def get_data_frame(self, texts, important_ids, importance_scores):
        """Return a DataFrame of the top-``self.K`` texts with their ids and scores.

        Columns are ``'Important Texts'``, ``'Important IDs'``, and
        ``'Importance Score'``, with one row per id returned by ``get_top_k``.
        """
        important_ids, importance_scores = get_top_k(important_ids, importance_scores, self.K)
        data = {
            'Important Texts': [texts[text_id] for text_id in important_ids],
            'Important IDs': important_ids,
            'Importance Score': importance_scores,
        }
        return pd.DataFrame(data)