Other
English
minecraft
action prediction
File size: 4,263 Bytes
fe11772
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8bcfbaa
fe11772
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import torch
import json
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from datasets import load_dataset
from tqdm import tqdm

# Load the fine-tuned causal LM in float16; device_map="auto" lets
# transformers decide where each layer is placed.
device_map = "auto"
model = AutoModelForCausalLM.from_pretrained(
    "/path/to/llamipa/adapter",
    device_map=device_map,
    torch_dtype=torch.float16,
    return_dict=True,
)

# Tokenizer is loaded from the Meta-Llama-3-8B checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained(
    "/tmpdir/thompson/Meta-Llama-3-8B/",
    add_eos_token=True,
)
# NOTE(review): the pad id is set to eos_token_id + 1 rather than EOS itself;
# presumably this mirrors the fine-tuning configuration - confirm before changing.
tokenizer.pad_token_id = tokenizer.eos_token_id + 1
tokenizer.padding_side = "right"

# Text-generation pipeline; each call produces at most 100 new tokens.
pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=100,
)
print("Device:", model.hf_device_map)

# Incremental test set: one JSON record per line, read via the 'test' split.
test_dataset = load_dataset(
    "json",
    data_files={"test": "/path/to/parser_test_15_incremental.jsonl"},
)["test"]

def check_endpoints(struct, head):
    """
    takes a struct string and a head int and returns only 
    the struct rels with sources that are >= head
    """
    new_rels_list = []
    new_rels = None
    if struct:
        rels = struct.split(' ')
        for rel in rels:
            if len(rel) > 0:
                source = int(rel.split('(')[1].split(',')[0].strip())
                if source >= head:
                    new_rels_list.append(rel)
        if len(new_rels_list) > 0:
            new_rels = ' '.join(new_rels_list)
    return new_rels

def add_previous(sample, previous, predictions):
    """
    Rewrite the 'Structure:' line(s) of a sample with carried-over predictions.

    The integer ``head`` is parsed from the first line of ``sample``, which is
    expected to look like ``Context: <head> <...``. Relations from
    ``previous`` whose source is below ``head`` are dropped (via
    ``check_endpoints``). The survivors are followed by ``predictions``, and
    that combined string replaces the content of every line that starts
    with 'Structure:'.

    Args:
        sample: raw sample text, lines separated by '\\n'.
        previous: structure string carried over from earlier turns, or None.
        predictions: cleaned predictions for the previous turn, or None
            (treated as empty, e.g. when no turn has been generated yet).

    Returns:
        (keep_str, new_output_string): the structure string to carry forward
        (None if the sample has no 'Structure:' line) and the rewritten sample.
    """
    # A missing prediction used to crash with TypeError on '+'; treat as empty.
    predictions = predictions or ''
    lines = sample.split('\n')
    head = int(lines[0].split('Context:')[1].split('<')[0].strip())

    keep_str = None
    rewritten = []
    for line in lines:
        if line.startswith('Structure:'):
            kept = check_endpoints(previous, head)
            # Join only non-empty parts so no stray spaces leak into the prompt
            # or into the structure carried to the next turn.
            keep_str = ' '.join(part for part in (kept, predictions) if part)
            line = 'Structure: ' + keep_str
        rewritten.append(line)
    return keep_str, '\n'.join(rewritten)

def format_gen(preds):
    """
    Turn raw generated text into a clean, space-separated relation string.

    ``preds`` is split on single spaces and each piece is stripped. A piece is
    kept only if it parses as ``LABEL(a,b...)`` where:
    - ``a`` and ``b`` are integers (extra arguments and any text after ')'
      are ignored);
    - ``LABEL`` is one of the known relation labels.
    Kept pieces are re-emitted normalised as ``LABEL(a,b)``. Pieces that fail
    to parse are reported on stdout and dropped.

    Returns:
        The kept relations joined by single spaces ('' if none survive).
    """
    known_labels = ('COM', 'CONTR', 'CORR', 'QAP', 'ACK', 'ELAB', 'CLARIFQ',
                    'COND', 'CONTIN', 'RES', 'EXPL', 'QELAB', 'ALT', 'NARR',
                    'CONFQ', 'SEQ')
    cleaned = []
    for token in (piece.strip() for piece in preds.split(' ')):
        try:
            args = token.split('(')[1].split(')')[0].split(',')
        except IndexError:
            # No '(' in the token at all.
            print('split error one')
            continue
        label = token.split('(')[0].strip()

        endpoints = None
        try:
            endpoints = (int(args[0]), int(args[1]))
        except IndexError:
            # Fewer than two comma-separated arguments.
            print('split error two')
        except ValueError:
            # An endpoint is not an integer.
            print('value error three')

        # Make sure both the endpoints and the label are well-formed.
        if endpoints is None or label not in known_labels:
            continue
        cleaned.append(f'{label}({endpoints[0]},{endpoints[1]})')
    return ' '.join(cleaned)


def formatting_prompts_func(example):
    """
    Wrap a (possibly amended) sample in the DS-generation instruction prompt.

    The prompt begins with the literal '<|begin_of_text|>' marker and ends
    with '\\n ### DS:', after which the model's relations are generated.
    """
    prefix = '<|begin_of_text|>Identify the discourse structure (DS) for the new turn in the following excerpt :\n '
    suffix = '\n ### DS:'
    return ''.join((prefix, example, suffix))

# Run the model turn by turn over the incremental test set. Each turn's
# cleaned predictions are fed back into the next prompt's 'Structure:' line,
# and every raw generation is written to the output file.
# The file is managed by 'with' so it is flushed and closed even if a turn
# raises; utf-8 avoids locale-dependent encode errors on model output.
with open("/path/to/test-output-file.txt", "w", encoding="utf-8") as f:
    new_generations = None
    previous_generations = None
    new_dialogue = False
    for datum in tqdm(test_dataset['sample']):
        if datum.startswith('NEW DIALOGUE'):
            # Marker row: the next sample is the first turn of a new dialogue.
            print("cond satisfied")
            new_dialogue = True
            continue
        if new_dialogue:
            # First turn: use the sample unchanged and forget any structure
            # carried over from the previous dialogue.
            text = formatting_prompts_func(datum)
            previous_generations = None
            new_dialogue = False
        else:
            # Later turn: substitute the model's own earlier predictions,
            # keeping only relations whose source is >= this sample's head
            # (see add_previous / check_endpoints).
            update_prev, amended_text = add_previous(
                datum, previous_generations, new_generations)
            previous_generations = update_prev
            text = formatting_prompts_func(amended_text)
        generated = pipe(text)[0]['generated_text']
        print(generated, file=f)
        # The generated text contains the prompt, which ends with '### DS:';
        # only the part after it is the model's prediction.
        new_generations = format_gen(generated.split('### DS:')[1])