File size: 5,692 Bytes
96b6673
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from citekit.cite_modules.LLM import LLM
from citekit.cite_modules.augment_model import Retriever
from citekit.pipeline.pipeline import Pipeline, PIPELINE_OUTPUT,PIPELINE_DOC_CACHE
from citekit.prompt.prompt import Prompt, DocPrompt,ALCEDocPrompt,NewALCEVanillaPrompt
from citekit.Dataset.Dataset import PromptDataset
from citekit.evaluator.evaluator import DefaultEvaluator
from citekit.utils.utils import output_begin_with,output_end_with
import json
from parser import *
import argparse

def one_paragraph(text):
    paras = text.lstrip('\n').split('\n')
    if not paras:
        return ''
    else:
        return paras[0].rstrip('\n')
    
def cut_and_make_as(datakey):
    def f(passage):
        return {datakey:one_paragraph(passage)}
    return f

def if_output(x):
    return not (output_begin_with('check')(x))

def drop_end_and_output(x):
    x = one_paragraph(x)
    if x[-len('End.'):] == 'End.':
        x = x[:-len('End.')]
    if x[:len('output')] =='output' :
        x = x[len('output'):]
    if x[-len('End'):] == 'End':
        x = x[:-len('End')]
    return x


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--save_path", type=str, default='result.json', help="Path to the config file")
    parser.add_argument("--model", type=str, default='gpt-3.5-turbo-0301', help="model name or path")
    parser.add_argument("--shots", type=int, default=2, help="number of shots")
    parser.add_argument("--ndoc", type=int, default=5, help="number of docs")
    parser.add_argument("--pr", action='store_true', help="use cite PR")
    parser.add_argument("--rouge", action='store_true', help="use rouge")
    parser.add_argument("--mauve",  action='store_true', help="eval mauve")
    parser.add_argument("--qa", action='store_true', help="eval qa")
    parser.add_argument("--length", type=bool, default=True, help="eval length")
    parser.add_argument("--temp", type=float, default=0.5, help="temperature")
    parser.add_argument("--claims", action='store_true', help="eval claims")
    parser.add_argument("--qampari", type=str, default=False, help="eval qampari")
    parser.add_argument("--turns", type=int, default=6, help="turns")
    parser.add_argument("--dataset", type=str, default='data/asqa_eval_gtr_top100.json', help="dataset")
    parser.add_argument("--demo", type=str, default='prompts/asqa_default.json', help="demo")
    args = parser.parse_args()

    with open('data/asqa_eval_gtr_top100.json','r',encoding='utf-8') as file:
        dataset = json.load(file)
    with open('prompts/asqa_interact_doc_id.json','r',encoding='utf-8') as file:
        demo = json.load(file)
    documents = [DocPrompt().load_data(list(enumerate(data['docs'][:10])),Title = lambda data: data[1]['title'], Passage = lambda data: data[1]['text']) for data in dataset]

    llm_instruction = demo['instruction']
    dataset = PromptDataset(dataset,'question','answer','qa_pairs', extract = lambda data: ''.join(ALCEDocPrompt().default_load_data_extraction(data['docs'][:10])), docs = lambda data: ALCEDocPrompt().default_load_data(data['docs'][:args.ndoc]))[:100]
    shots = '\n'.join(NewALCEVanillaPrompt().load_data(demo['demos'][:args.shots], INST = lambda _:llm_instruction,
                                                  question = lambda data: data['question'], 
                                                  docs = lambda data:''.join(ALCEDocPrompt().default_load_data_extraction(demo['demos'][0]['docs'][:args.ndoc])),
                                                  answer = lambda data: '\n'.join(data['answer'])))
    
    # llm
    llm_prompt = Prompt(template='<shots><INST><question><extract><record><docs><forceAnswer>',components={'INST':'{INST}\n\n', 
                                                                                             'shots':'{shots}\n',
                                                                                        'question':'Question:{question}\n\n',
                                                                                        'extract':'{extract}\n',
                                                                                        'docs':'{docs}',
                                                                                        'record':'Answer:\n{record}',
                                                                                        'forceAnswer':'\n{forceAnswer}'})
    retriever_prompt = Prompt(template='<IDs>',components={'IDs':'{IDs}'})

    llm = LLM(model=args.model, prompt_maker=llm_prompt, self_prompt={'INST':llm_instruction, 'shots': shots+'\n','forceAnswer': 'Answer: \n'},stop=['\n\n'],max_turn=args.turns)

    eval = DefaultEvaluator(args)


    pipeline = Pipeline(llm = llm, head_prompt_maker=llm_prompt,evaluator = eval,dataset = dataset,save_path=args.save_path)
    retriever = Retriever(prompt_maker=retriever_prompt,pipeline=pipeline,topk=3, documents=documents)

    llm.set_target(retriever, output_begin_with('check'), post_processing=cut_and_make_as('IDs'))
    llm.set_target(llm,if_output, post_processing=lambda x: {'forceAnswer': Prompt.UNABLE})
    llm.add_to_head('record')
    llm.set_output(if_output,post_processing= drop_end_and_output, end=False)
    llm.set_output(output_end_with('End.'), post_processing=drop_end_and_output , end = True)
    llm.set_output(output_end_with('End'), post_processing=drop_end_and_output , end = True)

    retriever.set_target(llm ,post_processing=lambda input, output: {'docs': output, 'forceAnswer': 'Output:'})

    graph = PipelineGraph(pipeline)
    graph.visualize()
    #pipeline.run_on_dataset(datakeys=['question','extract'],init_docs='docs')