# Author: franlucc
# Commit 1a9dcdb: "add missing utils"
import pandas as pd
from hashlib import sha256
from typing import List, Tuple, Dict, Any
import math
import re
# Prompt/label text; it is not referenced in this part of the file.
# NOTE(review): presumably consumed by an answer-extraction step elsewhere — confirm.
EXTRACTION_PROMPT: str = "All attempted answers, correct and incorrect"
def regex_compare(a: str, b: str) -> bool:
    """
    Return True if the word characters of ``a`` occur contiguously in those of ``b``.

    Both strings are first reduced to their regex word characters (letters,
    digits and underscore, Unicode-aware); everything else (whitespace,
    punctuation) is discarded. The check is asymmetric containment, not
    equality: ``regex_compare("42", "The answer is 42.")`` is True, but the
    reverse is False.

    NOTE(review): because this is a substring test, an ``a`` with no word
    characters (empty or punctuation-only) matches every ``b``, and partial
    tokens match (``"12"`` is found in ``"123"``). Confirm this is intended.

    Args:
        a: the string whose word characters must be found.
        b: the string searched.

    Returns:
        True if ``a``'s word characters are a substring of ``b``'s.
    """
    a_chars = "".join(re.findall(r"\w", a))
    b_chars = "".join(re.findall(r"\w", b))
    # Substring containment already includes equality, so no separate == test.
    return a_chars in b_chars
def print_info(db_connection):
    """
    Print every table reported by ``SHOW TABLES`` together with its columns.

    For each table it prints ``"Table: <name>"``, then runs ``DESCRIBE`` on
    that table and prints one ``" - <name> (<type>)"`` line per column. A
    blank line follows each table.

    Args:
        db_connection: object exposing ``execute(sql).fetchall()``.
            NOTE(review): the SQL used (SHOW TABLES / DESCRIBE) and the
            double-quote identifier quoting assume DuckDB — confirm.
    """
    tables = db_connection.execute("SHOW TABLES").fetchall()
    for table in tables:
        table_name = table[0]
        print(f"Table: {table_name}")
        # Quote the identifier so names with hyphens, spaces or other special
        # characters don't break the statement; embedded double quotes are
        # escaped by doubling them (standard SQL identifier quoting).
        quoted_name = '"' + str(table_name).replace('"', '""') + '"'
        columns = db_connection.execute(f"DESCRIBE {quoted_name}").fetchall()
        for column in columns:
            # column[0] is the column name, column[1] is the data type
            print(f" - {column[0]} ({column[1]})")
        print()  # blank line between tables for readability
def query_format_models(models: List[str], *, prefix: str = "completions-") -> str:
    """
    Format model names as a SQL tuple literal for ``WHERE <model_column> IN <models>``.

    Each name gets ``prefix`` prepended and is wrapped in single quotes. Any
    single quote inside a name is doubled so it cannot terminate the string
    literal early, which would break or inject into the query.

    Args:
        models: model names, without the prefix.
        prefix: string prepended to every model name (default ``"completions-"``).

    Returns:
        A string such as ``"('completions-a','completions-b')"``. An empty
        ``models`` yields ``"('')"``, which is still valid SQL; ``x IN ('')``
        matches only rows where x is the empty string.
    """
    quoted = ["'" + (prefix + m).replace("'", "''") + "'" for m in models]
    if not quoted:
        # `IN ()` is a syntax error, so keep the original single-empty-literal form.
        return "('')"
    return "(" + ",".join(quoted) + ")"
def get_completions(db_connector, query: str, **query_kwargs) -> pd.DataFrame:
    """
    Run ``query`` and keep one completion per (prompt_id, model, solution, prompt).

    ``query`` is filled in with ``query.format(**query_kwargs)`` before it is
    executed. Literal braces in the SQL must therefore be doubled. The values
    are spliced in verbatim (no escaping), so pass only trusted values.

    The result is grouped on prompt_id / model / solution / prompt. When a
    group has several completions, only the first non-null ``completion`` is
    kept. Rows whose key columns contain nulls are kept as their own groups.
    Any other columns are dropped, and the rows come back sorted by the keys.

    Args:
        db_connector: object exposing ``sql(query).df()`` returning a DataFrame
            (DuckDB-style API — NOTE(review): confirm).
        query: SQL template with ``str.format`` placeholders.
        **query_kwargs: values substituted into ``query``.

    Returns:
        DataFrame with columns prompt_id, model, solution, prompt, completion.
    """
    df = db_connector.sql(query.format(**query_kwargs)).df()
    # dropna=False: by default pandas silently discards every row whose group
    # key contains a null (e.g. a prompt with no reference solution).
    df = (
        df.groupby(["prompt_id", "model", "solution", "prompt"], dropna=False)
        .agg({"completion": "first"})
        .reset_index()
    )
    return df
def sha256_hash(text: str) -> str:
    """Return the hex SHA-256 digest of ``text`` encoded as UTF-8."""
    hasher = sha256()
    hasher.update(bytes(text, "utf-8"))
    return hasher.hexdigest()