File size: 7,763 Bytes
8a17d1d
 
 
 
 
 
 
 
a72f911
8a17d1d
 
 
 
 
a72f911
8a17d1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a72f911
8a17d1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a72f911
8a17d1d
 
 
 
 
 
 
 
 
 
a72f911
8a17d1d
 
a9892f2
 
 
8a17d1d
 
 
 
a72f911
8a17d1d
a72f911
8a17d1d
 
 
 
 
 
 
 
 
a72f911
8a17d1d
 
 
a72f911
 
 
8a17d1d
 
a72f911
8a17d1d
a72f911
 
 
 
 
 
8a17d1d
 
 
 
 
 
 
 
 
a72f911
8a17d1d
 
 
 
 
 
 
 
 
a72f911
8a17d1d
 
 
 
a72f911
 
 
 
 
 
 
 
 
 
 
 
 
 
8a17d1d
a72f911
8a17d1d
 
 
 
 
a72f911
8a17d1d
 
 
a72f911
8a17d1d
a72f911
 
 
8a17d1d
a72f911
8a17d1d
 
 
 
a72f911
 
8a17d1d
a72f911
8a17d1d
 
 
 
4d9c7c0
 
8a17d1d
 
a72f911
8a17d1d
a72f911
d03bad1
8a17d1d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
"""
Script for joining .csv candidate data with the completion results stored in .duckdb.
Launches a Gradio app to review the extracted candidates.
"""
import argparse
from pathlib import Path
import pandas as pd
from metrics import load_results
from utils import query_format_models, sha256_hash, get_completions, print_info, regex_compare
import numpy as np
import json
import ast
import gradio as gr
import re
from typing import List

# Query selecting the completions whose results directory (results.parent_dir,
# exposed as `model`) is 'r1' or 'gemini2', each joined with its challenge row
# on results.prompt_id = challenges.ID.
# Returned columns: prompt_id, model, completion, solution (the `answer`
# column, renamed) and prompt.
# NOTE(review): query_format_models presumably renders the list as a SQL tuple
# literal suitable for `IN` — confirm in utils.
SQL_QUERY = """
WITH AllResults AS (
    SELECT 
        results.parent_dir AS model,
        *
    FROM 
        results.completions results
    JOIN 
        challenges challenges
    ON 
        results.prompt_id = challenges.ID
)
SELECT prompt_id, model, completion, answer as solution, prompt
FROM AllResults                    
WHERE
    AllResults.model IN {models}
""".format(models=query_format_models(['r1','gemini2']))

def _parse(x):
    """
    Parse one raw ``candidates`` cell into a Python value (normally a list).

    Args:
        x: a string holding a Python/JSON literal such as ``"['a', 'b']"``,
            or a missing marker (``None`` or float NaN).

    Returns:
        ``ast.literal_eval`` of the (whitespace-stripped) string, or ``[]`` for
        blank strings, a lone ``"["`` / ``"]"`` (bad generations) and missing
        values.

    Raises:
        ValueError: if ``x`` is a string that is not a valid literal
            ("Bad gen: ..."), or if ``x`` is of any other unsupported type/value.
    """
    if isinstance(x, str):
        stripped = x.strip()
        if stripped in ("", "[", "]"):
            return []  # bad gen
        try:
            # Parse the stripped text: before Python 3.10 literal_eval rejects
            # leading whitespace with an IndentationError.
            return ast.literal_eval(stripped)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as err:
            raise ValueError(f"Bad gen: {x}") from err
    # Missing cells: pandas yields NaN for empty numeric/object cells; None is
    # treated the same. Type-check first so np.isnan never sees a non-float.
    if x is None or (isinstance(x, (float, np.floating)) and np.isnan(x)):
        return []
    raise ValueError(f"Found unexpected type {type(x)}: {x}")

def _concat(series: pd.Series) -> np.ndarray:
    """
    Merge many raw candidate cells into one de-duplicated, sorted array.

    Each element of ``series`` is parsed with ``_parse``. Empty results are
    dropped, and the rest are concatenated and passed through ``np.unique``,
    which also sorts them.

    Args:
        series: any iterable of raw candidate cells (strings / missing values).

    Returns:
        The sorted unique candidates. When no element yields a candidate, an
        empty array is returned rather than a bare list, so the return type is
        always an ndarray.
    """
    parsed = [value for value in map(_parse, series) if len(value) > 0]
    if not parsed:
        return np.array([])
    return np.unique(np.concatenate(parsed))

def check_candidates(candidates: pd.DataFrame, merged_df: pd.DataFrame):
    """
    Sanity-check the merged data: every attempted answer must appear in its completion.

    A candidate passes if its lower-cased text is a substring of the
    lower-cased completion, or if ``regex_compare`` accepts the pair.

    Args:
        candidates: unused; kept only so existing call sites keep working.
        merged_df: rows with a JSON-encoded list in ``candidates``, the raw
            ``completion`` text and ``_original_completion_hash``.

    Raises:
        AssertionError: on the first candidate not found in its completion. The
            message is a JSON dump of the candidate, the completion and the hash.
    """
    for _, row in merged_df.iterrows():
        completion = row["completion"]
        lowered = completion.lower()
        for cand in json.loads(row["candidates"]):
            needle = cand.lower()
            if needle in lowered or regex_compare(needle, lowered):
                continue
            # Raise explicitly (not `assert`) so the check still runs under `python -O`.
            raise AssertionError(json.dumps(
                {"candidate": cand, "completion": completion, "hash": row["_original_completion_hash"]},
                indent=4,
            ))
    
def launch_app(df: pd.DataFrame, share_demo: bool = False):
    """
    Launch a Gradio app for browsing the merged candidates row by row.

    A slider picks a row of ``df`` by position and renders its metadata as an
    HTML table. A checkbox group can append the completion text with the row's
    ``regex_candidates`` (green) and/or ``candidates`` (yellow) highlighted.
    Each checked option appends its own copy of the completion.

    All text taken from ``df`` is HTML-escaped before rendering. Markup in
    prompts or completions (e.g. ``<think>`` tags or ``a < b``) is therefore
    shown literally instead of being interpreted by the browser.

    Args:
        df: rows to review. The view reads the columns
            ``_original_completion_hash``, ``model``, ``prompt_id``,
            ``solution``, ``prompt``, ``completion`` and the JSON-encoded list
            columns ``candidates`` and ``regex_candidates``.
        share_demo: forwarded to ``demo.launch(share=...)``.

    Raises:
        ValueError: if ``df`` has no rows (there is nothing to display).
    """
    import html  # only the review UI needs HTML escaping

    if len(df) == 0:
        raise ValueError("launch_app: the DataFrame to review is empty")

    def highlight_words(text, candidates, color="yellow"):
        """
        Return ``text`` HTML-escaped, with whole-word, case-insensitive
        occurrences of any candidate wrapped in a ``<mark>`` of ``color``.
        """
        if not isinstance(text, str):
            text = "" if pd.isna(text) else str(text)
        # Drop empty candidates (they would match at every word boundary) and
        # order longest-first so overlapping candidates yield one highlight
        # instead of nested <mark> tags.
        words = sorted({str(w) for w in candidates if str(w)}, key=lambda w: (-len(w), w))
        if not words:
            return html.escape(text)
        pattern = re.compile(r"\b(?:" + "|".join(map(re.escape, words)) + r")\b", flags=re.IGNORECASE)
        open_tag = f'<mark style="background-color:{color};">'
        pieces, last = [], 0
        for match in pattern.finditer(text):
            # Matches are located on the raw text; each piece is escaped separately
            # so the inserted <mark> tags are the only markup in the result.
            pieces.append(html.escape(text[last:match.start()]))
            pieces.append(open_tag + html.escape(match.group(0)) + "</mark>")
            last = match.end()
        pieces.append(html.escape(text[last:]))
        return "".join(pieces)

    def show_table(show_completion, example_idx):
        """Render row ``example_idx`` of ``df`` as HTML for the ``gr.HTML`` output."""
        # The slider value may arrive as a float; positional indexing needs an int.
        example = df.iloc[int(example_idx)]
        selected = show_completion or []

        def esc(value):
            return html.escape(str(value))

        candidates = json.loads(example['candidates'])
        regex_candidates = json.loads(example['regex_candidates'])
        table_html = f"""
        <table>
            <tr><td><b>Completion hash</b></td><td>{esc(example['_original_completion_hash'])}</td></tr>
            <tr><td><b>Model</b></td><td>{esc(example['model'])}</td></tr>
            <tr><td><b>Prompt ID</b></td><td>{esc(example['prompt_id'])}</td></tr>
            <tr><td><b>Solution</b></td><td>{esc(example['solution'])}</td></tr>
            <tr><td><b>Prompt</b></td><td>{esc(example['prompt'])}</td></tr>
            <tr><td><b>Candidates</b></td><td>{esc(candidates)}</td></tr>
            <tr><td><b>Regex Candidates</b></td><td>{esc(regex_candidates)}</td></tr>
        </table>
        """
        # Each checked option appends a separately highlighted copy of the
        # completion; regex highlights come first.
        for option, words, color in (
            ("highlight_regex", regex_candidates, "green"),
            ("highlight_candidates", candidates, "yellow"),
        ):
            if option in selected:
                table_html += f"""
            <br><b>Completion:</b><br>
            <p style="white-space: pre-wrap;">{highlight_words(example['completion'], words, color)}</p>
            """
        return table_html

    with gr.Blocks() as demo:
        # Slider to navigate through examples (positional row index).
        example_slider = gr.Slider(minimum=0, maximum=len(df) - 1, step=1, label="Example", value=0)
        # Which highlighted copies of the completion to show.
        toggle_button = gr.CheckboxGroup(["highlight_candidates", "highlight_regex"])

        with gr.Row():
            gr.HTML('<h1>Candidates Table</h1>')

        table_output = gr.HTML()

        inputs = [toggle_button, example_slider]
        example_slider.change(show_table, inputs=inputs, outputs=[table_output])
        toggle_button.input(show_table, inputs=inputs, outputs=[table_output])
        # Render the first row as soon as the page loads.
        demo.load(show_table, inputs=inputs, outputs=[table_output])

    demo.launch(share=share_demo)

def _extract_candidates(row, do_regex: bool) -> str:
    """
    Return the candidate answers for one row, JSON-encoded.

    Args:
        row: a mapping-like row (e.g. a ``pd.Series`` from
            ``DataFrame.apply(..., axis=1)``). ``do_regex=True`` reads
            ``row["completion"]``. Otherwise ``row["candidates"]`` is read, with
            ``row["generated"]`` as a fallback.
        do_regex: selects the extraction mode (see Returns).

    Returns:
        With ``do_regex=True``: a JSON list of every distinct substring enclosed
        in double quotes in ``row["completion"]`` (non-greedy, within a line),
        sorted for deterministic output. Returns ``"[]"`` when the completion is
        not a string (e.g. NaN).

        With ``do_regex=False``: ``row["candidates"]`` unchanged when it holds a
        value. When it is missing (``None``/NaN) or an empty list, the result is
        a sorted JSON list of the distinct double-quoted word-character tokens
        found in ``row["generated"]`` (``"[]"`` if that is not a string).
    """
    if do_regex:
        completion = row["completion"]
        if not isinstance(completion, str):
            return json.dumps([])
        found = {match.group(1) for match in re.finditer(r'"(.+?)"', completion)}
        return json.dumps(sorted(found))

    # Read the stored value first: it must be bound before the missing-value check
    # (previously this branch referenced it unassigned and always raised).
    candidates = row["candidates"]
    is_missing = candidates is None or (
        isinstance(candidates, (float, np.floating)) and np.isnan(candidates)
    )
    if is_missing or (isinstance(candidates, list) and not candidates):
        # No stored candidates: fall back to double-quoted words in the raw generation.
        generated = row["generated"]
        if not isinstance(generated, str):
            return json.dumps([])
        return json.dumps(sorted(set(re.findall(r'"(\w+)"', generated))))
    return candidates

def _build_merged_df(candidates_csv: Path) -> pd.DataFrame:
    """Join the candidate CSV with the completions selected by ``SQL_QUERY``."""
    cand_df = pd.read_csv(candidates_csv.as_posix())
    conn = load_results()
    completions = conn.sql(SQL_QUERY).df()

    # Fill in candidates for rows where none were stored (see _extract_candidates).
    cand_df["candidates"] = cand_df.apply(lambda row: _extract_candidates(row, False), axis=1)
    # One row per completion: gather every candidate cell seen for it ...
    key_cols = ["model", "prompt_id", "solution", "prompt", "_original_completion_hash"]
    cand_df = cand_df.groupby(key_cols).agg({"candidates": "unique"}).reset_index()
    # ... then flatten, de-duplicate and JSON-encode them.
    cand_df["candidates"] = cand_df["candidates"].apply(lambda cells: json.dumps(list(_concat(cells))))

    # Hash the completion text so it can serve as a join key with the CSV.
    completions["_original_completion_hash"] = completions["completion"].apply(sha256_hash)
    print(completions["model"].value_counts())
    print(cand_df["model"].value_counts())
    merged = cand_df.merge(completions, on=["model", "prompt_id", "prompt", "solution", "_original_completion_hash"])
    print(merged["model"].value_counts())
    return merged


def main(candidates: Path, output_csv: Path, launch_gradio: bool, share_demo: bool, do_regex: bool):
    """
    Build (or reload) the merged candidates/completions table and optionally
    launch the review app.

    On the first run (``output_csv`` does not exist), the candidate CSV is
    grouped per completion and its candidate lists are merged and de-duplicated.
    It is then inner-joined with the completions returned by ``SQL_QUERY`` on
    model, prompt id, prompt, solution and the hash of the completion text. The
    result is cached to ``output_csv``, and later runs reload it from there.

    Args:
        candidates: path to the CSV of extracted candidates. It is expected to
            carry ``candidates`` and ``generated`` columns plus the join keys
            above — TODO confirm against the producer of that CSV.
        output_csv: cache path for the merged table.
        launch_gradio: start the Gradio review app.
        share_demo: forwarded to ``launch_app``.
        do_regex: add a ``regex_candidates`` column re-extracted from each
            completion. It is always added when the app is launched, because
            the app displays it.
    """
    if output_csv.exists():
        df = pd.read_csv(output_csv.as_posix())
    else:
        df = _build_merged_df(candidates)
        # Optionally sanity-check the extractions here: check_candidates(None, df)
        df.to_csv(output_csv, index=False)

    # launch_app renders regex_candidates, so build it on the first (non-cached)
    # run too, not only when reloading.
    if do_regex or launch_gradio:
        df["regex_candidates"] = df.apply(lambda row: _extract_candidates(row, True), axis=1)

    if launch_gradio:
        df = df.sort_values(by="prompt_id")
        launch_app(df, share_demo)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--candidates", type=Path, help="path to .csv data containing extracted candidates", default="data.csv")
    parser.add_argument("--output_csv", type=Path, help="path to .csv output file; will reload from here if path exists", default="output.csv")
    # The review app and the regex re-extraction are on by default; the --no_*
    # flags switch them off. -gr/-r are still accepted for backward compatibility.
    parser.add_argument("-gr", "--launch_gradio", action="store_true", default=True,
                        help="launch the Gradio review app (default)")
    parser.add_argument("--no_gradio", dest="launch_gradio", action="store_false",
                        help="do not launch the Gradio review app")
    parser.add_argument("-s", "--share_demo", action="store_true",
                        help="create a public Gradio share link")
    parser.add_argument("-r", "--do_regex", action="store_true", default=True,
                        help="re-extract double-quoted strings from completions (default)")
    parser.add_argument("--no_regex", dest="do_regex", action="store_false",
                        help="skip the regex re-extraction (still done if the app is launched)")
    args = parser.parse_args()
    main(**vars(args))