"""
Script for joining .csv candidate data into a .duckdb results.
Launches a gradio app to review candidates
"""
import argparse
from pathlib import Path
import pandas as pd
from metrics import load_results
from utils import query_format_models, sha256_hash, get_completions, print_info, regex_compare
import numpy as np
import json
import ast
import gradio as gr
import re
from typing import List
# Pull every (prompt, completion) pair for the models under review by joining
# the completions table against the challenge prompts on prompt_id.
# `results.parent_dir` encodes which model produced each completion;
# `query_format_models` renders the model list as a SQL IN (...) tuple.
SQL_QUERY = """
WITH AllResults AS (
SELECT
results.parent_dir AS model,
*
FROM
results.completions results
JOIN
challenges challenges
ON
results.prompt_id = challenges.ID
)
SELECT prompt_id, model, completion, answer as solution, prompt
FROM AllResults
WHERE
AllResults.model IN {models}
""".format(models=query_format_models(['r1','gemini2']))
def _parse(x):
if isinstance(x, str):
if len(x.strip()) == 0 or x.strip() in ["]","["]:
return [] # bad gen
else:
try:
return ast.literal_eval(x)
except:
raise ValueError(f"Bad gen: {x}")
elif np.isnan(x):
return []
else:
raise ValueError(f"Found unexpected type {type(x)}: {x}")
def _concat(series: pd.Series) -> np.ndarray:
    """Merge a series of stringified candidate lists into one deduplicated array.

    Each element is parsed with `_parse`; elements that parse to an empty
    list are dropped before concatenation.

    Returns:
        A sorted np.ndarray of unique candidates. Fix: the original returned
        a plain empty *list* when nothing parsed, giving an inconsistent
        return type; we now always return an ndarray (callers wrap the
        result in list(), which behaves identically for both).
    """
    parsed = [p for p in (_parse(v) for v in series) if len(p) > 0]
    if not parsed:
        return np.array([])
    return np.unique(np.concatenate(parsed))
def check_candidates(candidates: pd.DataFrame, merged_df: pd.DataFrame):
    """
    Perform a variety of sanity checks ie:
    - all attempted answers are in the completion (verbatim, case-insensitive,
      or via regex_compare)

    Raises AssertionError with a JSON payload (candidate, completion, hash)
    identifying the offending row when a candidate is not found.

    NOTE(review): the `candidates` parameter is unused — the original
    immediately shadowed it with a loop-local of the same name. The
    parameter is kept for interface compatibility; the per-row candidate
    list is read from merged_df.
    """
    for _, row in merged_df.iterrows():
        # Use a distinct local instead of shadowing the parameter (original bug).
        row_candidates = json.loads(row["candidates"])
        comp = row["completion"].lower()
        for cand in row_candidates:
            lowered = cand.lower()
            assert lowered in comp or regex_compare(lowered, comp), \
                json.dumps(
                    {
                        "candidate": cand,
                        "completion": row["completion"],
                        "hash": row["_original_completion_hash"],
                    },
                    indent=4,
                )
def launch_app(df: pd.DataFrame, share_demo: bool = False):
    """
    Launch a gradio app for browsing/reviewing candidate extractions.

    Expects df to contain the columns: completion, candidates,
    regex_candidates, _original_completion_hash, model, prompt_id,
    solution, prompt. `share_demo` is forwarded to gradio's `share` option.
    """

    # Render one example (chosen by the slider) as an HTML table, optionally
    # appending the completion text with candidate matches highlighted.
    def show_table(show_completion, example_idx):
        example = df.iloc[example_idx]

        # Wrap whole-word, case-insensitive matches of each candidate in
        # <mark> tags. Placeholders (<@>...</@>) are inserted first so a later
        # candidate cannot match inside HTML that was already injected.
        def highlight_words(text, candidates, color="yellow"):
            for word in candidates:
                text = re.sub(rf'\b({re.escape(word)})\b', r'<@>\1</@>',
                              text, flags=re.IGNORECASE)
            # The placeholders are fixed literals, so plain str.replace is the
            # right tool (the original used re.sub, and called a no-op
            # '</mark>'.format(color=color) on a literal with no placeholder).
            text = text.replace("<@>", f'<mark style="background-color:{color};">')
            text = text.replace("</@>", '</mark>')
            return text

        candidates = json.loads(example['candidates'])
        regex_candidates = json.loads(example['regex_candidates'])
        highlighted_completion = highlight_words(example['completion'], candidates)
        highlighted_regex_completion = highlight_words(
            example['completion'], regex_candidates, color="green")

        # Core metadata table shown for every example.
        table_html = f"""
        <table>
        <tr><td><b>Completion hash</b></td><td>{example['_original_completion_hash']}</td></tr>
        <tr><td><b>Model</b></td><td>{example['model']}</td></tr>
        <tr><td><b>Prompt ID</b></td><td>{example['prompt_id']}</td></tr>
        <tr><td><b>Solution</b></td><td>{example['solution']}</td></tr>
        <tr><td><b>Prompt</b></td><td>{example['prompt']}</td></tr>
        <tr><td><b>Candidates</b></td><td>{candidates}</td></tr>
        <tr><td><b>Regex Candidates</b></td><td>{regex_candidates}</td></tr>
        </table>
        """
        # Append the completion once per requested highlight mode (both
        # checkboxes checked shows it twice, matching the original behavior).
        if "highlight_regex" in show_completion:
            table_html += f"""
            <br><b>Completion:</b><br>
            <p>{highlighted_regex_completion}</p>
            """
        if "highlight_candidates" in show_completion:
            table_html += f"""
            <br><b>Completion:</b><br>
            <p>{highlighted_completion}</p>
            """
        return table_html

    # Create the Gradio interface.
    with gr.Blocks() as demo:
        # Slider to navigate through examples.
        example_slider = gr.Slider(minimum=0, maximum=len(df) - 1, step=1,
                                   label="Example", value=0)
        # Checkboxes selecting which highlight modes to render.
        toggle_button = gr.CheckboxGroup(["highlight_candidates", "highlight_regex"])
        with gr.Row():
            gr.HTML('<h1>Candidates Table</h1>')
        table_output = gr.HTML()
        # Re-render the table whenever the slider or the checkboxes change.
        example_slider.change(show_table, inputs=[toggle_button, example_slider],
                              outputs=[table_output])
        toggle_button.input(show_table, inputs=[toggle_button, example_slider],
                            outputs=[table_output])
    demo.launch(share=share_demo)
def _extract_candidates(row, do_regex: bool) -> str:
"""
Try to re-extract candidates assuming between quotes
"""
if do_regex:
# pattern = r'"(.+?)"|\*(.+?)\*'
pattern = r'"(.+?)"'
found_c = set([i.group(0)[1:-1] for i in re.finditer(pattern, row["completion"])])
return json.dumps(list(found_c))
elif np.isnan(candidates) or candidates == []:
candidates = re.findall(r'"(\w+)"', row["generated"])
return json.dumps(list(set(candidates)))
else:
return candidates
def main(candidates: Path, output_csv: Path, launch_gradio: bool, share_demo: bool, do_regex:bool):
    """
    Build the merged candidates/completions table and optionally launch the
    gradio review app.

    If output_csv does not exist: read the candidates CSV, pull matching
    completions out of the duckdb results via SQL_QUERY, aggregate candidate
    lists per (model, prompt_id, solution, prompt, completion-hash) group,
    merge on the completion hash, and cache to output_csv. Otherwise reload
    the cached CSV. Regex-based candidates are recomputed either way.
    """
    if not output_csv.exists():
        candidates = pd.read_csv(candidates.as_posix())
        conn = load_results()
        completions = conn.sql(SQL_QUERY).df()
        # NOTE(review): do_regex is ignored on this path — extraction is
        # forced to the non-regex branch; confirm this is intentional.
        candidates["candidates"] = candidates.apply(lambda x: _extract_candidates(x, False), axis=1)
        # One row per unique completion: collect all candidate JSON strings
        # seen for the same group into an array ("unique" aggregation).
        candidates = candidates.groupby(["model","prompt_id","solution","prompt","_original_completion_hash"]).agg({
            "candidates": "unique"
        }).reset_index()
        # Flatten each group's array of JSON lists into one deduplicated JSON list.
        candidates["candidates"] = candidates["candidates"].apply(lambda x: json.dumps(list(_concat(x))))
        # Hash completions so candidate rows can be joined to their source completion.
        completions["_original_completion_hash"] = completions["completion"].apply(sha256_hash)
        print(completions["model"].value_counts())
        print(candidates["model"].value_counts())
        df = candidates.merge(completions, on=["model","prompt_id","prompt","solution","_original_completion_hash"])
        print(df["model"].value_counts())
        # check_candidates(candidates, df)
        # NOTE(review): index=False is not passed, so reloading below yields an
        # extra "Unnamed: 0" column — confirm downstream code tolerates it.
        df.to_csv(output_csv)
    else:
        df = pd.read_csv(output_csv.as_posix())
    # Regex re-extraction always runs on the (possibly reloaded) table.
    df["regex_candidates"] = df.apply(lambda x: _extract_candidates(x, True), axis=1)
    if launch_gradio:
        df = df.sort_values(by="prompt_id")
        launch_app(df, share_demo)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Bug fix: argparse applies `type=` only to command-line strings, never to
    # defaults, so the original string defaults ("data.csv"/"output.csv") would
    # reach main() as str and crash on .exists()/.as_posix(). Use Path defaults.
    parser.add_argument("--candidates", type=Path, default=Path("data.csv"),
                        help="path to .csv data containing extracted candidates")
    parser.add_argument("--output_csv", type=Path, default=Path("output.csv"),
                        help="path to .csv output file; will reload from here if path exists")
    parser.add_argument("-gr", "--launch_gradio", action="store_true")
    parser.add_argument("-s", "--share_demo", action="store_true")
    parser.add_argument("-r", "--do_regex", action="store_true")
    args = parser.parse_args()
    # Bug fix: leftover debug lines unconditionally overwrote the parsed flags
    # (args.do_regex = True; args.launch_gradio = True), making the -r/-gr
    # switches dead. Respect what the user actually passed.
    main(**vars(args))