franlucc's picture
fix html highlight
a9892f2
"""
Script for joining .csv candidate data into a .duckdb results.
Launches a gradio app to review candidates
"""
import argparse
import ast
import html
import json
import re
from pathlib import Path
from typing import List

import gradio as gr
import numpy as np
import pandas as pd

from metrics import load_results
from utils import query_format_models, sha256_hash, get_completions, print_info, regex_compare
SQL_QUERY = """
WITH AllResults AS (
SELECT
results.parent_dir AS model,
*
FROM
results.completions results
JOIN
challenges challenges
ON
results.prompt_id = challenges.ID
)
SELECT prompt_id, model, completion, answer as solution, prompt
FROM AllResults
WHERE
AllResults.model IN {models}
""".format(models=query_format_models(['r1','gemini2']))
def _parse(x):
    """Parse one serialized candidate list (a cell of the ``candidates`` column).

    Args:
        x: A string holding a Python literal (expected to be a list of
            candidates), or a missing value (NaN / None).

    Returns:
        The ``ast.literal_eval`` result for a well-formed string, or ``[]`` for a
        blank string, a lone bracket (truncated generation) or a missing value.

    Raises:
        ValueError: if the string is not a valid Python literal ("Bad gen"),
            chained to the underlying parse error, or if ``x`` has any other type.
    """
    if isinstance(x, str):
        stripped = x.strip()
        # Empty output or a lone bracket means the generation was cut off.
        if not stripped or stripped in ("]", "["):
            return []  # bad gen
        try:
            return ast.literal_eval(stripped)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as err:
            raise ValueError(f"Bad gen: {x}") from err
    # Missing cells arrive from pandas as NaN (or occasionally None).
    if x is None or (isinstance(x, (float, np.floating)) and np.isnan(x)):
        return []
    raise ValueError(f"Found unexpected type {type(x)}: {x}")
def _concat(series: pd.Series) -> np.ndarray:
    """Merge several serialized candidate lists into one sorted, de-duplicated array.

    Each element of ``series`` is parsed with ``_parse``. Scalar literals are
    promoted to one-element arrays so they can be concatenated, and empty results
    are dropped.

    Returns:
        ``np.unique`` of all parsed candidates. This is always an ``np.ndarray``,
        and it is empty when nothing was parsed, so callers get a consistent type.
    """
    arrays = [
        arr
        for arr in (np.atleast_1d(parsed) for parsed in map(_parse, series))
        if arr.size > 0
    ]
    if not arrays:
        return np.array([])
    return np.unique(np.concatenate(arrays))
def check_candidates(candidates: pd.DataFrame, merged_df: pd.DataFrame):
    """Run sanity checks on the merged candidate table.

    Currently checks that every attempted answer appears in its completion,
    either as a case-insensitive substring or via ``regex_compare``.

    Args:
        candidates: Unused; kept so existing call sites keep working.
        merged_df: Rows with a ``candidates`` column (JSON-encoded list), plus
            ``completion`` and ``_original_completion_hash`` columns.

    Raises:
        AssertionError: with a JSON description of the first candidate that is
            not found in its completion.
    """
    for _, row in merged_df.iterrows():
        completion = row["completion"]
        lowered = completion.lower()
        for cand in json.loads(row["candidates"]):
            cand_lower = cand.lower()
            if cand_lower in lowered or regex_compare(cand_lower, lowered):
                continue
            # Raise explicitly: a bare `assert` is stripped under `python -O`.
            raise AssertionError(
                json.dumps(
                    {"candidate": cand, "completion": completion, "hash": row["_original_completion_hash"]},
                    indent=4,
                )
            )
def _highlight_words(text, words, color="yellow") -> str:
    """Return ``text`` as HTML-escaped markup with whole-word matches of ``words`` wrapped in ``<mark>``.

    Matching is case-insensitive. A match must not be directly preceded or
    followed by a word character, which also works for candidates with
    non-word edges (e.g. ``C++``). All words are matched in a single pass,
    longest first, so overlapping candidates never produce nested tags. Empty
    candidates are ignored. The text is escaped so that markup in model output
    (e.g. ``<think>`` tags) is shown literally instead of being interpreted.
    """
    if not isinstance(text, str):
        text = "" if pd.isna(text) else str(text)
    unique_words = sorted({str(w) for w in words} - {""}, key=lambda w: (-len(w), w))
    if not unique_words:
        return html.escape(text)
    pattern = re.compile(
        r"(?<!\w)(?:" + "|".join(re.escape(w) for w in unique_words) + r")(?!\w)",
        flags=re.IGNORECASE,
    )
    open_tag = f'<mark style="background-color:{color};">'
    pieces = []
    last = 0
    for match in pattern.finditer(text):
        pieces.append(html.escape(text[last:match.start()]))
        pieces.append(f"{open_tag}{html.escape(match.group(0))}</mark>")
        last = match.end()
    pieces.append(html.escape(text[last:]))
    return "".join(pieces)


def _render_example(example: pd.Series, show_completion) -> str:
    """Render one merged row as an HTML table, plus optionally highlighted completions.

    Args:
        example: One row of the merged table. It must provide the columns shown
            in the table, and ``candidates`` / ``regex_candidates`` as JSON lists.
        show_completion: Selected checkbox values (``"highlight_regex"`` and/or
            ``"highlight_candidates"``); ``None`` means nothing is selected.
    """
    selected = show_completion or []
    candidates = json.loads(example["candidates"])
    regex_candidates = json.loads(example["regex_candidates"])
    fields = [
        ("Completion hash", example["_original_completion_hash"]),
        ("Model", example["model"]),
        ("Prompt ID", example["prompt_id"]),
        ("Solution", example["solution"]),
        ("Prompt", example["prompt"]),
        ("Candidates", candidates),
        ("Regex Candidates", regex_candidates),
    ]
    # Escape every cell: prompts and solutions are raw text, not trusted HTML.
    rows = "\n".join(
        f"<tr><td><b>{label}</b></td><td>{html.escape(str(value))}</td></tr>"
        for label, value in fields
    )
    parts = [f"<table>\n{rows}\n</table>"]
    completion = example["completion"]
    # pre-wrap keeps the completion's line breaks visible in the HTML view.
    if "highlight_regex" in selected:
        body = _highlight_words(completion, regex_candidates, color="green")
        parts.append(
            "<br><b>Completion (regex candidates):</b><br>\n"
            f'<p style="white-space: pre-wrap;">{body}</p>'
        )
    if "highlight_candidates" in selected:
        body = _highlight_words(completion, candidates)
        parts.append(
            "<br><b>Completion (candidates):</b><br>\n"
            f'<p style="white-space: pre-wrap;">{body}</p>'
        )
    return "\n".join(parts)


def launch_app(df: pd.DataFrame, share_demo: bool = False):
    """Launch a Gradio app to browse ``df`` row by row.

    A slider selects the row (by position). A checkbox group toggles
    highlighted views of the completion: candidates in yellow, regex
    candidates in green.

    Args:
        df: Merged table. Each row must have the columns used by ``_render_example``.
        share_demo: Passed to ``demo.launch(share=...)``.

    Raises:
        ValueError: if ``df`` is empty (there would be nothing to display).
    """
    if df.empty:
        raise ValueError("launch_app needs at least one row to display")

    def show_table(show_completion, example_idx):
        # Slider values may arrive as floats (e.g. 3.0); iloc needs an int position.
        return _render_example(df.iloc[int(example_idx)], show_completion)

    with gr.Blocks() as demo:
        example_slider = gr.Slider(minimum=0, maximum=len(df) - 1, step=1, label="Example", value=0)
        toggle_button = gr.CheckboxGroup(["highlight_candidates", "highlight_regex"])
        with gr.Row():
            gr.HTML("<h1>Candidates Table</h1>")
        table_output = gr.HTML()
        inputs = [toggle_button, example_slider]
        # Re-render whenever the slider or the checkbox selection changes.
        example_slider.change(show_table, inputs=inputs, outputs=[table_output])
        toggle_button.input(show_table, inputs=inputs, outputs=[table_output])
        # Render the first example on page load instead of showing an empty panel.
        demo.load(show_table, inputs=inputs, outputs=[table_output])
    demo.launch(share=share_demo)
def _extract_candidates(row, do_regex: bool) -> str:
    """Return the candidate list for one row, (re-)extracting quoted spans when needed.

    Args:
        row: A mapping-like row, e.g. the ``pd.Series`` produced by
            ``DataFrame.apply(..., axis=1)``.
        do_regex: If True, ignore existing candidates and re-extract every
            double-quoted span from ``row["completion"]``. If False, keep
            ``row["candidates"]`` when present; when it is missing (NaN/None)
            or empty, fall back to double-quoted single words in
            ``row["generated"]``.

    Returns:
        A JSON array string of unique candidates in first-appearance order. The
        exception is a present ``row["candidates"]`` value, which is returned
        unchanged. Returns ``"[]"`` when the text to search is missing.
    """
    if do_regex:
        completion = row["completion"]
        if not isinstance(completion, str):  # missing completion (NaN)
            return json.dumps([])
        # An earlier variant also matched *emphasised* spans: r'"(.+?)"|\*(.+?)\*'
        found = [m.group(1) for m in re.finditer(r'"(.+?)"', completion)]
        # dict.fromkeys de-duplicates while keeping a deterministic order.
        return json.dumps(list(dict.fromkeys(found)))

    candidates = row["candidates"]
    missing = (
        candidates is None
        or (isinstance(candidates, (float, np.floating)) and np.isnan(candidates))
        or (isinstance(candidates, (list, tuple)) and len(candidates) == 0)
        or (isinstance(candidates, str) and not candidates.strip())
    )
    if not missing:
        return candidates
    generated = row["generated"]
    if not isinstance(generated, str):
        return json.dumps([])
    words = re.findall(r'"(\w+)"', generated)
    return json.dumps(list(dict.fromkeys(words)))
def main(candidates: Path, output_csv: Path, launch_gradio: bool, share_demo: bool, do_regex: bool):
    """Merge extracted candidates with model completions, cache the result, and optionally review it.

    If ``output_csv`` does not exist, the function:

    - reads the candidate CSV and back-fills missing candidate lists;
    - collapses rows per completion;
    - joins them with the completions returned by ``SQL_QUERY``;
    - writes the merged table to ``output_csv``.

    Otherwise the cached ``output_csv`` is reloaded.

    Args:
        candidates: Path to the .csv containing extracted candidates.
        output_csv: Cache path for the merged table; reused if it exists.
        launch_gradio: Launch the Gradio review app on the merged table.
        share_demo: Passed through to ``launch_app``.
        do_regex: Add a ``regex_candidates`` column, holding quoted spans
            re-extracted from each completion. The column is also added whenever
            ``launch_gradio`` is set, because the review app displays it.
    """
    key_cols = ["model", "prompt_id", "solution", "prompt", "_original_completion_hash"]
    if not output_csv.exists():
        cand_df = pd.read_csv(candidates.as_posix())
        conn = load_results()
        completions = conn.sql(SQL_QUERY).df()

        # Back-fill missing candidate lists, then merge all lists per completion.
        cand_df["candidates"] = [_extract_candidates(row, False) for _, row in cand_df.iterrows()]
        cand_df = cand_df.groupby(key_cols).agg({"candidates": "unique"}).reset_index()
        cand_df["candidates"] = cand_df["candidates"].apply(lambda group: json.dumps(list(_concat(group))))

        completions["_original_completion_hash"] = completions["completion"].apply(sha256_hash)
        print(completions["model"].value_counts())
        print(cand_df["model"].value_counts())
        df = cand_df.merge(completions, on=key_cols)
        print(df["model"].value_counts())
        # Optional sanity check, disabled by default: check_candidates(cand_df, df)
        # index=False: otherwise reloading the cache adds an "Unnamed: 0" column.
        df.to_csv(output_csv, index=False)
    else:
        df = pd.read_csv(output_csv.as_posix())

    # A list comprehension (unlike apply(axis=1)) also works on an empty frame.
    if do_regex or launch_gradio:
        df["regex_candidates"] = [_extract_candidates(row, True) for _, row in df.iterrows()]
    if launch_gradio:
        df = df.sort_values(by="prompt_id")
        launch_app(df, share_demo)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--candidates", type=Path, help="path to .csv data containing extracted candidates", default="data.csv")
parser.add_argument("--output_csv", type=Path, help="path to .csv output file; will reload from here if path exists", default="output.csv")
parser.add_argument("-gr","--launch_gradio", action="store_true")
parser.add_argument("-s", "--share_demo", action="store_true")
parser.add_argument("-r", "--do_regex", action="store_true")
args = parser.parse_args()
args.do_regex = True
args.launch_gradio = True
main(**vars(args))