|
""" |
|
Script for joining .csv candidate data with the .duckdb results.
|
Launches a gradio app to review candidates |
|
""" |
|
import argparse |
|
from pathlib import Path |
|
import pandas as pd |
|
from metrics import load_results |
|
from utils import query_format_models, sha256_hash, get_completions, print_info, regex_compare |
|
import numpy as np |
|
import json |
|
import ast |
|
import gradio as gr |
|
import re |
|
from typing import List |
|
|
|
# Query run against the results connection: joins every stored completion with
# its challenge row (results.prompt_id = challenges.ID), exposes parent_dir as
# ``model`` and the challenge ``answer`` as ``solution``, and keeps only rows
# whose model is one of those passed to query_format_models below.
SQL_QUERY = """
WITH AllResults AS (
SELECT
results.parent_dir AS model,
*
FROM
results.completions results
JOIN
challenges challenges
ON
results.prompt_id = challenges.ID
)
SELECT prompt_id, model, completion, answer as solution, prompt
FROM AllResults
WHERE
AllResults.model IN {models}
""".format(models=query_format_models(['r1','gemini2']))
|
|
|
def _parse(x):
    """
    Parse one raw candidate cell into a Python value (expected to be a list).

    Strings are evaluated with ``ast.literal_eval``; a blank string or a lone
    ``"["`` / ``"]"`` counts as "no candidates" and yields ``[]``. A float NaN
    also yields ``[]``.

    Raises:
        ValueError: if a string cannot be evaluated as a literal, or if ``x``
            is neither a string nor a float NaN.
    """
    if isinstance(x, str):
        stripped = x.strip()
        if not stripped or stripped in ("]", "["):
            return []
        try:
            return ast.literal_eval(x)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as err:
            # Keep the original cause attached instead of swallowing it.
            raise ValueError(f"Bad gen: {x}") from err
    # np.isnan raises TypeError on non-numeric input (None, dict, ...), so the
    # type is checked first to make sure such values reach the ValueError below.
    if isinstance(x, (float, np.floating)) and np.isnan(x):
        return []
    raise ValueError(f"Found unexpected type {type(x)}: {x}")
|
|
|
def _concat(series: pd.Series) -> np.array:
    """
    Merge all non-empty parsed entries of ``series`` into one collection.

    Each element is run through ``_parse``; entries with length zero are
    dropped. The remaining entries are concatenated and de-duplicated with
    ``np.unique``. When nothing is left, an empty list is returned instead of
    an array.
    """
    non_empty = [parsed for parsed in map(_parse, series) if len(parsed) > 0]
    if not non_empty:
        return non_empty
    return np.unique(np.concatenate(non_empty))
|
|
|
def check_candidates(candidates: pd.DataFrame, merged_df: pd.DataFrame):
    """
    Perform a variety of sanity checks ie:

    - all attempted answers are in the completion

    Args:
        candidates: not used by the check; kept so existing callers keep working.
        merged_df: frame with a JSON-encoded ``candidates`` column, a
            ``completion`` column and an ``_original_completion_hash`` column.

    Raises:
        AssertionError: for the first candidate that is neither a
            case-insensitive substring of its completion nor accepted by
            ``regex_compare``. The message is a JSON dump of the offending
            candidate, completion and hash.
    """
    for _, row in merged_df.iterrows():
        completion = row["completion"]
        haystack = completion.lower()
        # A separate local name avoids clobbering the ``candidates`` parameter.
        for cand in json.loads(row["candidates"]):
            needle = cand.lower()
            if needle in haystack or regex_compare(needle, haystack):
                continue
            # Raised explicitly (rather than via ``assert``) so the check
            # still runs when Python is started with -O.
            raise AssertionError(
                json.dumps(
                    {
                        "candidate": cand,
                        "completion": completion,
                        "hash": row["_original_completion_hash"],
                    },
                    indent=4,
                )
            )
|
|
|
def _highlight_words(text, words, color="yellow"):
    """
    Wrap each whole-word, case-insensitive occurrence of ``words`` in ``text``
    in a ``<mark>`` tag with the given background colour.

    All words are matched in one regex pass, so a later word can never match
    inside a ``<mark ...>`` tag that was inserted for an earlier one (e.g. a
    candidate such as "mark" or "color").
    """
    # Longest first: when two words could match at the same position the
    # longer phrase wins. Empty strings are dropped because they would match
    # at every word boundary.
    unique_words = sorted({w for w in words if w}, key=lambda w: (-len(w), w))
    if not unique_words:
        return text
    pattern = r"\b(" + "|".join(re.escape(w) for w in unique_words) + r")\b"
    open_tag = f'<mark style="background-color:{color};">'
    return re.sub(
        pattern,
        lambda match: f"{open_tag}{match.group(0)}</mark>",
        text,
        flags=re.IGNORECASE,
    )


def _render_example_html(example, show_completion):
    """
    Build the HTML shown for one example row.

    ``example`` is indexed by column name (a DataFrame row or a dict) and must
    provide the JSON-encoded ``candidates`` and ``regex_candidates`` columns
    plus the metadata columns used in the table. ``show_completion`` is the
    collection of selected checkbox labels; one "Completion" section is
    appended per selected highlight mode, and none when nothing is selected.
    """
    candidates = json.loads(example['candidates'])
    regex_candidates = json.loads(example['regex_candidates'])

    # NOTE(review): values are interpolated into the HTML without escaping, so
    # markup inside a prompt or completion will be rendered as-is.
    table_html = f"""
    <table>
        <tr><td><b>Completion hash</b></td><td>{example['_original_completion_hash']}</td></tr>
        <tr><td><b>Model</b></td><td>{example['model']}</td></tr>
        <tr><td><b>Prompt ID</b></td><td>{example['prompt_id']}</td></tr>
        <tr><td><b>Solution</b></td><td>{example['solution']}</td></tr>
        <tr><td><b>Prompt</b></td><td>{example['prompt']}</td></tr>
        <tr><td><b>Candidates</b></td><td>{candidates}</td></tr>
        <tr><td><b>Regex Candidates</b></td><td>{regex_candidates}</td></tr>
    </table>
    """

    selected = show_completion or ()
    # Same order as before: regex highlights first, then candidate highlights.
    sections = (
        ("highlight_regex", regex_candidates, "green"),
        ("highlight_candidates", candidates, "yellow"),
    )
    for option, words, color in sections:
        if option not in selected:
            continue
        # Highlighting is only computed for the modes actually displayed.
        completion = _highlight_words(example['completion'], words, color=color)
        table_html += f"""
        <br><b>Completion:</b><br>
        <p>{completion}</p>
        """

    return table_html


def launch_app(df: pd.DataFrame, share_demo: bool = False):
    """
    Launch a Gradio app for browsing the rows of ``df`` one at a time.

    A slider selects the row by position and a checkbox group chooses which
    highlighted view(s) of the completion to append below the metadata table.

    Args:
        df: frame with the columns read by ``_render_example_html``.
        share_demo: forwarded to ``demo.launch(share=...)``.
    """

    def show_table(show_completion, example_idx):
        # Gradio documents the slider value as a float, while ``iloc`` only
        # accepts integer positions, hence the cast.
        return _render_example_html(df.iloc[int(example_idx)], show_completion)

    with gr.Blocks() as demo:
        example_slider = gr.Slider(minimum=0, maximum=len(df)-1, step=1, label="Example", value=0)
        toggle_button = gr.CheckboxGroup(["highlight_candidates", "highlight_regex"])

        with gr.Row():
            gr.HTML('<h1>Candidates Table</h1>')

        table_output = gr.HTML()

        # Re-render whenever the selected row or the highlight options change.
        example_slider.change(show_table, inputs=[toggle_button, example_slider], outputs=[table_output])
        toggle_button.input(show_table, inputs=[toggle_button, example_slider], outputs=[table_output])

    demo.launch(share=share_demo)
|
|
|
def _extract_candidates(row, do_regex: bool) -> str:
    """
    Try to re-extract candidates assuming between quotes

    Args:
        row: mapping-like row. With ``do_regex`` it must provide
            ``completion``; otherwise ``candidates`` and, when those are
            missing or empty, ``generated``.
        do_regex: when True, ignore existing candidates and return every
            double-quoted span found in ``row["completion"]``.

    Returns:
        A JSON-encoded, sorted list of unique extracted strings, or the
        existing ``row["candidates"]`` value unchanged when it is present and
        non-empty.
    """
    if do_regex:
        quoted = {match.group(1) for match in re.finditer(r'"(.+?)"', row["completion"])}
        # Sorted so the output does not depend on set iteration order.
        return json.dumps(sorted(quoted))

    # NOTE(review): the original read ``candidates`` before assigning it
    # (UnboundLocalError); reading the row's existing "candidates" column is
    # the presumed intent - confirm against the input CSV schema.
    candidates = row["candidates"]

    # np.isnan raises TypeError on strings, so the type is checked first.
    is_missing = candidates is None or (
        isinstance(candidates, (float, np.floating)) and np.isnan(candidates)
    )
    if isinstance(candidates, str):
        # Cells loaded from CSV are strings; blank or "[]" means no candidates.
        is_empty = candidates.strip() in ("", "[]")
    else:
        is_empty = isinstance(candidates, (list, tuple)) and len(candidates) == 0

    if is_missing or is_empty:
        found = re.findall(r'"(\w+)"', row["generated"])
        return json.dumps(sorted(set(found)))
    return candidates
|
|
|
def main(candidates: Path, output_csv: Path, launch_gradio: bool, share_demo: bool, do_regex: bool):
    """
    Build (or reload) the merged candidates/completions table and optionally
    launch the review app.

    If ``output_csv`` does not exist, the candidates CSV is read, candidates
    are (re-)extracted and aggregated per completion, merged with the
    completions returned by ``SQL_QUERY``, and the result is written to
    ``output_csv``. Otherwise the table is reloaded from ``output_csv``.
    In both cases a ``regex_candidates`` column is then added.

    Args:
        candidates: path to the .csv file with extracted candidates.
        output_csv: cache path for the merged table.
        launch_gradio: when True, sort by ``prompt_id`` and open the app.
        share_demo: forwarded to ``launch_app``.
        do_regex: accepted for CLI compatibility; not consulted here, the
            regex candidates are always computed because the app reads them.
    """
    group_keys = ["model", "prompt_id", "solution", "prompt", "_original_completion_hash"]

    if not output_csv.exists():
        # Separate name so the ``candidates`` Path argument is not rebound to
        # a DataFrame.
        candidates_df = pd.read_csv(candidates.as_posix())
        conn = load_results()
        completions = conn.sql(SQL_QUERY).df()

        candidates_df["candidates"] = candidates_df.apply(
            lambda row: _extract_candidates(row, False), axis=1
        )
        candidates_df = candidates_df.groupby(group_keys).agg({
            "candidates": "unique"
        }).reset_index()

        # Collapse the per-group array of raw cells into one JSON list.
        candidates_df["candidates"] = candidates_df["candidates"].apply(
            lambda cells: json.dumps(list(_concat(cells)))
        )
        completions["_original_completion_hash"] = completions["completion"].apply(sha256_hash)
        print(completions["model"].value_counts())
        print(candidates_df["model"].value_counts())
        df = candidates_df.merge(
            completions,
            on=["model", "prompt_id", "prompt", "solution", "_original_completion_hash"],
        )
        print(df["model"].value_counts())

        # index=False: otherwise the reload branch below would pick up the
        # saved index as a spurious "Unnamed: 0" column.
        df.to_csv(output_csv, index=False)
    else:
        df = pd.read_csv(output_csv.as_posix())

    df["regex_candidates"] = df.apply(lambda row: _extract_candidates(row, True), axis=1)

    if launch_gradio:
        df = df.sort_values(by="prompt_id")
        launch_app(df, share_demo)
|
|
|
if __name__ == "__main__":
    # Command-line entry point: collect the options and hand them to main().
    cli = argparse.ArgumentParser()
    cli.add_argument("--candidates", type=Path, help="path to .csv data containing extracted candidates", default="data.csv")
    cli.add_argument("--output_csv", type=Path, help="path to .csv output file; will reload from here if path exists", default="output.csv")
    cli.add_argument("-gr","--launch_gradio", action="store_true")
    cli.add_argument("-s", "--share_demo", action="store_true")
    cli.add_argument("-r", "--do_regex", action="store_true")
    options = cli.parse_args()

    # NOTE: both flags are forced on here, regardless of what was passed on
    # the command line.
    options.do_regex = options.launch_gradio = True

    main(**vars(options))