""" Script for joining .csv candidate data into a .duckdb results. Launches a gradio app to review candidates """ import argparse from pathlib import Path import pandas as pd from metrics import load_results from utils import query_format_models, sha256_hash, get_completions, print_info, regex_compare import numpy as np import json import ast import gradio as gr import re from typing import List SQL_QUERY = """ WITH AllResults AS ( SELECT results.parent_dir AS model, * FROM results.completions results JOIN challenges challenges ON results.prompt_id = challenges.ID ) SELECT prompt_id, model, completion, answer as solution, prompt FROM AllResults WHERE AllResults.model IN {models} """.format(models=query_format_models(['r1','gemini2'])) def _parse(x): if isinstance(x, str): if len(x.strip()) == 0 or x.strip() in ["]","["]: return [] # bad gen else: try: return ast.literal_eval(x) except: raise ValueError(f"Bad gen: {x}") elif np.isnan(x): return [] else: raise ValueError(f"Found unexpected type {type(x)}: {x}") def _concat(series: pd.Series) -> np.array: items = list(filter(lambda x: len(x) > 0, map(_parse, series))) if len(items) > 0: return np.unique(np.concatenate(items)) else: return items def check_candidates(candidates: pd.DataFrame, merged_df: pd.DataFrame): """ Perform a variety of sanity checks ie: - all attempted answers are in the completion """ for _,row in merged_df.iterrows(): candidates = json.loads(row["candidates"]) comp = row["completion"].lower() for c in candidates: assert c.lower() in comp or regex_compare(c.lower(), comp), \ json.dumps({"candidate":c, "completion":row["completion"], "hash": row["_original_completion_hash"]}, indent=4) def launch_app(df: pd.DataFrame, share_demo: bool = False): # Define function to display table and toggle completion def show_table(show_completion, example_idx): # Extract the row based on the slider index example = df.iloc[example_idx] # Function to highlight words from the candidates list def highlight_words(text, candidates, color="yellow"): for word in candidates: # Use word boundaries to ensure we only match whole words text = re.sub(rf'\b({re.escape(word)})\b', r'<@>\1', text, flags=re.IGNORECASE) text = re.sub("<@>",f'', text) text = re.sub("",''.format(color=color), text) return text # Highlight words in the 'completion' column candidates = json.loads(example['candidates']) regex_candidates = json.loads(example['regex_candidates']) highlighted_completion = highlight_words(example['completion'], candidates) highlighted_regex_completion = highlight_words(example['completion'], regex_candidates, color="green") # Create a table with the core columns table_html = f"""
Completion hash{example['_original_completion_hash']}
Model{example['model']}
Prompt ID{example['prompt_id']}
Solution{example['solution']}
Prompt{example['prompt']}
Candidates{candidates}
Regex Candidates{regex_candidates}
""" # If the toggle is checked, show the 'completion' column with highlighted words if "highlight_regex" in show_completion: completion = highlighted_regex_completion table_html += f"""
Completion:

{completion}

""" if "highlight_candidates" in show_completion: completion = highlighted_completion table_html += f"""
Completion:

{completion}

""" return table_html # Create the Gradio interface with gr.Blocks() as demo: # Slider to navigate through examples example_slider = gr.Slider(minimum=0, maximum=len(df)-1, step=1, label="Example", value=0) # Toggle button for showing/hiding completion toggle_button = gr.CheckboxGroup(["highlight_candidates", "highlight_regex"]) with gr.Row(): gr.HTML('

Candidates Table

') # Table display table_output = gr.HTML() # Set interaction behavior: update the table when slider or checkbox changes example_slider.change(show_table, inputs=[toggle_button, example_slider], outputs=[table_output]) toggle_button.input(show_table, inputs=[toggle_button, example_slider], outputs=[table_output]) # Launch the app demo.launch(share=share_demo) def _extract_candidates(row, do_regex: bool) -> str: """ Try to re-extract candidates assuming between quotes """ if do_regex: # pattern = r'"(.+?)"|\*(.+?)\*' pattern = r'"(.+?)"' found_c = set([i.group(0)[1:-1] for i in re.finditer(pattern, row["completion"])]) return json.dumps(list(found_c)) elif np.isnan(candidates) or candidates == []: candidates = re.findall(r'"(\w+)"', row["generated"]) return json.dumps(list(set(candidates))) else: return candidates def main(candidates: Path, output_csv: Path, launch_gradio: bool, share_demo: bool, do_regex:bool): if not output_csv.exists(): candidates = pd.read_csv(candidates.as_posix()) conn = load_results() completions = conn.sql(SQL_QUERY).df() candidates["candidates"] = candidates.apply(lambda x: _extract_candidates(x, False), axis=1) candidates = candidates.groupby(["model","prompt_id","solution","prompt","_original_completion_hash"]).agg({ "candidates": "unique" }).reset_index() candidates["candidates"] = candidates["candidates"].apply(lambda x: json.dumps(list(_concat(x)))) completions["_original_completion_hash"] = completions["completion"].apply(sha256_hash) print(completions["model"].value_counts()) print(candidates["model"].value_counts()) df = candidates.merge(completions, on=["model","prompt_id","prompt","solution","_original_completion_hash"]) print(df["model"].value_counts()) # check_candidates(candidates, df) df.to_csv(output_csv) else: df = pd.read_csv(output_csv.as_posix()) df["regex_candidates"] = df.apply(lambda x: _extract_candidates(x, True), axis=1) if launch_gradio: df = df.sort_values(by="prompt_id") launch_app(df, share_demo) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--candidates", type=Path, help="path to .csv data containing extracted candidates", default="data.csv") parser.add_argument("--output_csv", type=Path, help="path to .csv output file; will reload from here if path exists", default="output.csv") parser.add_argument("-gr","--launch_gradio", action="store_true") parser.add_argument("-s", "--share_demo", action="store_true") parser.add_argument("-r", "--do_regex", action="store_true") args = parser.parse_args() args.do_regex = True args.launch_gradio = True main(**vars(args))