Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,642 Bytes
f214f36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
from src.prompts import wrap_prompt_self_citation
from src.utils import *
import time
from src.models import create_model
from .attribute import *
import copy
class SelfCitationAttribution(Attribution):
    """Attribution method that asks a language model to cite its own sources.

    Each context piece is labelled with a bracketed index (``[0]``, ``[1]``,
    ...), the labelled pieces are sent to an "explainer" model together with
    the question and the answer, and the bracketed indices found in the
    explainer's reply are returned as the important pieces.
    """

    def __init__(self, llm, explanation_level, K=5, self_citation_model="self", verbose=1):
        """Set up the model handles and pick the explainer model.

        Args:
            llm: Model wrapper. It must expose ``name``; when ``"gpt"`` is not
                part of ``name`` it must also expose ``model`` and
                ``tokenizer``.
            explanation_level: Forwarded to the base class and later passed to
                ``split_context``.
            K: Forwarded to the base class.
            self_citation_model: ``"self"`` to reuse ``llm`` as the explainer;
                any other value is used to build the config path
                ``model_configs/<value>_config.json`` given to
                ``create_model``.
            verbose: Forwarded to the base class.
        """
        super().__init__(llm, explanation_level, K, verbose)
        if "gpt" not in llm.name:
            self.model = llm.model
            self.tokenizer = llm.tokenizer
        else:
            self.model = llm

        # Keep the choice on the instance. The original code read
        # ``self.self_citation_model`` without ever assigning it, which raised
        # AttributeError for every value other than "self".
        self.self_citation_model = self_citation_model
        if self_citation_model == "self":
            # NOTE(review): ``self.llm`` is presumably set by
            # Attribution.__init__ -- confirm against the base class.
            self.explainer = self.llm
        else:
            self.explainer = create_model(f'model_configs/{self_citation_model}_config.json')

    @staticmethod
    def _remove_numbered_patterns(input_string):
        """Return ``input_string`` without ``[<digits>]`` markers and newlines."""
        # Pre-existing markers such as "[3]" would collide with the citation
        # labels added in ``attribute``, so they are stripped first.
        pattern = r'\[\d+\]'
        result = re.sub(pattern, '', input_string)
        result = result.replace('\n', '')
        return result

    @staticmethod
    def _extract_numbers_in_order(input_string):
        """Return the integers inside ``[<digits>]`` markers, in reading order."""
        pattern = r'\[(\d+)\]'
        return [int(num) for num in re.findall(pattern, input_string)]

    def attribute(self, question: str, contexts: list, answer: str):
        """Given question, contexts and answer, return attribution results.

        Args:
            question: The question that was asked.
            contexts: Context passages; they are split by ``split_context``
                according to ``self.explanation_level``.
            answer: The answer to attribute.

        Returns:
            A 5-tuple ``(texts, important_ids, importance_scores, elapsed,
            None)`` where ``texts`` is the output of ``split_context``,
            ``important_ids`` are the cited indices into ``texts`` in the
            order the explainer produced them, ``importance_scores`` are
            descending ranks (``len(important_ids)`` down to ``1``), and
            ``elapsed`` is the wall-clock duration of the explainer query in
            seconds.
        """
        texts = split_context(self.explanation_level, contexts)

        # Build the labelled copies shown to the explainer; ``texts`` itself
        # is returned unmodified.
        citation_texts = [
            f"[{i}]: " + self._remove_numbered_patterns(sentence)
            for i, sentence in enumerate(texts)
        ]
        prompt = wrap_prompt_self_citation(question, citation_texts, answer)

        # Only the explainer query is timed.
        start_time = time.time()
        self_citation = self.explainer.query(prompt)
        end_time = time.time()
        print("Self Citation: ", self_citation)

        important_ids = self._extract_numbers_in_order(self_citation)
        # Drop indices that do not refer to an existing text piece.
        # NOTE(review): repeated citations are kept, so an id may appear more
        # than once in the result.
        important_ids = [i for i in important_ids if i < len(citation_texts)]
        print("Important ids: ", important_ids)

        # Earlier citations receive higher scores.
        importance_scores = list(range(len(important_ids), 0, -1))
        return texts, important_ids, importance_scores, end_time - start_time, None