# Draft-then-revise citation pipeline built with citekit: the LLM writes a draft
# answer, each draft sentence is used as a BM25 retrieval query, and the LLM then
# revises the answer using the retrieved documents, adding citations.
from citekit.cite_modules.LLM import LLM
from citekit.cite_modules.augment_model import Retriever
from citekit.pipeline.pipeline import Pipeline, PIPELINE_OUTPUT, PIPELINE_DOC_CACHE
from citekit.prompt.prompt import Prompt, DocPrompt
from citekit.Dataset.Dataset import PromptDataset
from citekit.evaluator.evaluator import DefaultEvaluator
from citekit.utils.utils import output_begin_with, make_as, output_end_with, one_paragraph, remove_citations
import json
import argparse
import nltk
import re


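# Note: nltk.sent_tokenize needs the 'punkt' tokenizer data; run
# nltk.download('punkt') once if it is not already installed.

# each_make_as(key) builds a post-processing hook: it sentence-splits the model
# output, keeps at most the first three sentences, and wraps each one via
# make_as(key) so it can be passed to the next module (here, one retriever
# 'query' per draft sentence).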
def each_make_as(key):
    def function(output):
        sents = nltk.sent_tokenize(one_paragraph(output))
        if len(sents) > 3:
            sents = sents[:3]
        return [make_as(key)(sent) for sent in sents]
    return function


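# add_citation(ls) concatenates one answer per retrieved document and attaches the
# matching citation marker [i+1] to each, placing it just before the sentence-final
# punctuation when there is one. For example,
#   add_citation(['Paris is the capital of France.', 'It hosts the Eiffel Tower'])
# yields 'Paris is the capital of France[1]. It hosts the Eiffel Tower[2]'
# (assuming one_paragraph leaves these single-sentence strings unchanged).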
def add_citation(ls):
    output = ''
    pattern = r'([.!?])\s*$'
    for i, answer in enumerate(ls):
        cite = f'[{i+1}]'
        answer = one_paragraph(answer)
        if not answer:
            # Nothing was generated for this document; skip it.
            continue
        # Insert the citation marker before the terminal punctuation,
        # or append it if the answer has no terminal punctuation.
        answer = re.sub(pattern, rf'{cite}\1 ', answer)
        if cite not in answer:
            answer += cite
        output += answer
    return output


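# Example invocation (the file name run_pipeline.py is only a placeholder for
# wherever this script is saved; the data paths match the argparse defaults below):
#   python run_pipeline.py --model gpt-3.5-turbo --shots 2 --top_k 1 \
#       --dataset data/asqa_eval_gtr_top100.json --demo prompts/asqa_default.json \
#       --pr --rouge --add_cite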
if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument("--save_path", type=str, default='res.json', help="path to save the results")
    parser.add_argument("--model", type=str, default='gpt-3.5-turbo', help="model name or path")
    parser.add_argument("--shots", type=int, default=2, help="number of shots")
    parser.add_argument("--ndoc", type=int, default=3, help="number of docs")
    parser.add_argument("--pr", action='store_true', help="evaluate citation precision/recall")
    parser.add_argument("--rouge", action='store_true', help="evaluate ROUGE")
    parser.add_argument("--temp", type=float, default=0.5, help="temperature")
    parser.add_argument("--qa", action='store_true', help="evaluate QA")
    parser.add_argument("--mauve", action='store_true', help="evaluate MAUVE")
    parser.add_argument("--length", action='store_true', default=True, help="evaluate length")
    parser.add_argument("--claims", action='store_true', help="evaluate claims")
    parser.add_argument("--qampari", action='store_true', help="evaluate in QAMPARI format")
    parser.add_argument("--dataset", type=str, default='data/asqa_eval_gtr_top100.json', help="dataset path")
    parser.add_argument("--demo", type=str, default='prompts/asqa_default.json', help="few-shot demo file")
    parser.add_argument("--add_cite", action='store_true', help="manually add citations to the revised answers")
    parser.add_argument("--top_k", type=int, default=1, help="number of docs to retrieve per query")
    args = parser.parse_args()

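    # Load the evaluation data and the few-shot demonstrations.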
    file_path = args.dataset
    demo_path = args.demo
    with open(file_path, 'r', encoding='utf-8') as file:
        dataset = json.load(file)
    with open(demo_path, 'r', encoding='utf-8') as file:
        demo = json.load(file)

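    # Each example is expected to provide a 'docs' list of retrieved passages with
    # 'title' and 'text' fields, plus 'question', 'answer', 'qa_pairs', 'answers',
    # and 'claims' for prompting and evaluation.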
    # Build one DocPrompt per example from its retrieved passages.
    documents = [DocPrompt().load_data(list(enumerate(data['docs'])),
                                       Title=lambda data: data[1]['title'],
                                       Passage=lambda data: data[1]['text'])
                 for data in dataset]

    # Keep the first 200 examples.
    dataset = PromptDataset(dataset, 'question', 'answer', 'qa_pairs', 'answers', 'claims')[:200]

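    # Two instructions: the first-turn draft is written from the question alone;
    # the revision turn (after retrieval) uses either a plain single-sentence
    # rewrite (--add_cite, with citations appended afterwards by add_citation)
    # or a cited rewrite.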
    llm_instruction = 'Instruction: Write an accurate, engaging, and concise answer for the given question. Use an unbiased and journalistic tone.'
    if args.add_cite:
        llm_instruction_after = 'Instruction: Revise and correct the answer into an accurate, engaging, and concise answer for the given question, using only the provided document and only one sentence. Use an unbiased and journalistic tone. Your revised answer must contain only one short sentence.'
    else:
        llm_instruction_after = 'Instruction: Revise and correct the answer into an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite them properly. Use an unbiased and journalistic tone. Always cite for any factual claim. When citing several search results, use [1][2][3]. Cite at least one document and at most three documents in each sentence. If multiple documents support the sentence, only cite a minimum sufficient subset of the documents. Your revised answer must contain only one short sentence.'

    shots = '\n\n'.join(llm_instruction + '\n\nQuestion: ' + d['question'] + '\n\nAnswer: ' + remove_citations(d['answer']) for d in demo['demos'][:args.shots])

    llm_prompt = Prompt(template='<shots><INST><question><docs><answer>\n\nAnswer: ',
                        components={'INST': '{INST}\n\n',
                                    'shots': '{shots}\n\n',
                                    'question': 'Question: {question}\n\n',
                                    'docs': '{docs}',
                                    'answer': '\nThis is the answer you should revise based on the provided document: \n{answer}'})
    retriever_prompt = Prompt(template='<query>', components={'query': '{query}'})

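    # Pipeline wiring: on turn 1 the LLM drafts an answer from the question alone;
    # each draft sentence is sent to the BM25 retriever as a query, and the
    # retrieved documents plus the draft go back to the LLM with the revision
    # instruction; the turn-2 output is the final (optionally post-cited) answer.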
    llm = LLM(model=args.model, prompt_maker=llm_prompt, self_prompt={'INST': llm_instruction, 'shots': shots}, stop=['\n', '\n\n'])
    evaluator = DefaultEvaluator(args)
    pipeline = Pipeline(llm=llm, head_prompt_maker=llm_prompt, evaluator=evaluator, dataset=dataset, save_path=args.save_path)
    retriever = Retriever(prompt_maker=retriever_prompt, pipeline=pipeline, retrieve_by='bm25', documents=documents, topk=args.top_k)

    # Turn 1: split the draft into sentences and query the retriever with each one.
    llm.set_target(retriever, lambda self: self.turns == 1, post_processing=each_make_as('query'))
    # Turn 2: emit the final answer, appending citations manually when --add_cite is set.
    if args.add_cite:
        llm.set_output(lambda self: self.turns > 1, post_processing=add_citation, end=True)
    else:
        llm.set_output(lambda self: self.turns > 1, post_processing=lambda ls: ''.join(map(one_paragraph, ls)), end=True)
    # The retriever hands its documents (and the original draft) back to the LLM
    # together with the revision instruction; the few-shot block is disabled.
    retriever.set_target(llm, post_processing=lambda input, output: {'docs': output, 'answer': input, 'INST': llm_instruction_after, 'shots': Prompt.UNABLE})

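    # Run the pipeline on every example; only the 'question' field is passed in as
    # input, and results with evaluation scores are saved to args.save_path.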
    pipeline.run_on_dataset(datakeys=['question'])