import json
import logging
import re
import string
from pathlib import Path
from typing import Any, Dict, List, Optional

from tabulate import tabulate

logger = logging.getLogger(__name__)


def normalize_str(input_str: str, remove_punct: bool = True) -> str:
    """Lowercase a string and strip all whitespace, optionally stripping punctuation too."""
    no_spaces = re.sub(r"\s", "", input_str)
    if remove_punct:
        translator = str.maketrans("", "", string.punctuation)
        return no_spaces.lower().translate(translator)
    else:
        return no_spaces.lower()


def split_string(s: str, char_list: Optional[List[str]] = None) -> List[str]:
    """Split a string on any of the given separator characters (default: "," and ";")."""
    if char_list is None:
        char_list = [",", ";"]
    pattern = f"[{''.join(char_list)}]"
    return re.split(pattern, s)


def normalize_number_str(number_str: str) -> float:
    """Strip "$", "%", and "," from a numeric string and parse it as a float (inf on failure)."""
    for char in ["$", "%", ","]:
        number_str = number_str.replace(char, "")
    try:
        return float(number_str)
    except ValueError:
        logger.error(f"String {number_str} cannot be normalized to a number.")
        return float("inf")


def question_scorer(model_answer: str, ground_truth: str) -> bool:
    """Return True if model_answer matches ground_truth (numeric, list, or string comparison)."""

    def is_float(element: Any) -> bool:
        try:
            float(element)
            return True
        except ValueError:
            return False

    try:
        if is_float(ground_truth):
            logger.info(f"Evaluating {model_answer} as a number.")
            normalized_answer = normalize_number_str(model_answer)
            return normalized_answer == float(ground_truth)
        elif any(char in ground_truth for char in [",", ";"]):
            logger.info(f"Evaluating {model_answer} as a comma-separated list.")
            gt_elems = split_string(ground_truth)
            ma_elems = split_string(model_answer)
            if len(gt_elems) != len(ma_elems):
                logger.warning("Answer lists have different lengths, returning False.")
                return False
            comparisons = []
            for ma_elem, gt_elem in zip(ma_elems, gt_elems):
                if is_float(gt_elem):
                    normalized_ma_elem = normalize_number_str(ma_elem)
                    comparisons.append(normalized_ma_elem == float(gt_elem))
                else:
                    ma_elem = normalize_str(ma_elem, remove_punct=False)
                    gt_elem = normalize_str(gt_elem, remove_punct=False)
                    comparisons.append(ma_elem == gt_elem)
            return all(comparisons)
        else:
            logger.info(f"Evaluating {model_answer} as a string.")
            ma_elem = normalize_str(model_answer)
            gt_elem = normalize_str(ground_truth)
            return ma_elem == gt_elem
    except Exception as e:
        logger.error(f"Error during evaluation: {e}")
        return False
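

# Illustrative sketch (`_demo_question_scorer` is a hypothetical helper, not part of the
# original module): exercises the three branches of question_scorer on toy answers.
def _demo_question_scorer() -> None:
    assert question_scorer("1,234.5", "1234.5")  # numeric ground truth: compared as floats
    assert question_scorer("a; b; c", "a;b;c")  # list ground truth: element-wise comparison
    assert question_scorer("Paris.", "paris")  # string ground truth: normalized comparison
    assert not question_scorer("a;b", "a;b;c")  # length mismatch returns False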


def load_dataset_meta(path: str, split: str = "validation"):
    data_dir = Path(path) / split
    dataset = []
    with open(data_dir / "metadata.jsonl", "r", encoding="utf-8") as metaf:
        lines = metaf.readlines()
        for line in lines:
            data = json.loads(line)
            if data["task_id"] == "0-0-0-0-0":
                continue
            if data["file_name"]:
                data["file_name"] = data_dir / data["file_name"]
            dataset.append(data)
    return dataset


def load_dataset_meta_dict(path: str, split: str = "validation"):
    data_dir = Path(path) / split
    dataset = {}
    with open(data_dir / "metadata.jsonl", "r", encoding="utf-8") as metaf:
        lines = metaf.readlines()
        for line in lines:
            data = json.loads(line)
            if data["task_id"] == "0-0-0-0-0":
                continue
            if data["file_name"]:
                data["file_name"] = data_dir / data["file_name"]
            dataset[data["task_id"]] = data
    return dataset
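

# Illustrative usage sketch: loading metadata from a local copy of the dataset. The
# "./gaia_dataset" path is an assumption (it mirrors the default of add_file_path below),
# and metadata.jsonl must exist under "<path>/<split>/" for this to run.
def _demo_load_dataset(path: str = "./gaia_dataset") -> None:
    tasks = load_dataset_meta(path, split="validation")  # list of task dicts
    task_index = load_dataset_meta_dict(path, split="validation")  # same records keyed by task_id
    logger.info(f"Loaded {len(tasks)} tasks ({len(task_index)} unique task ids).")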


def add_file_path(
    task: Dict[str, Any], file_path: str = "./gaia_dataset", split: str = "validation"
):
    """Append the task's attached file path (and a processing hint) to its Question text."""
    if task["file_name"]:
        file_path = Path(f"{file_path}/{split}") / task["file_name"]
        if file_path.suffix in [".pdf", ".docx", ".doc", ".txt"]:
            task["Question"] += f" Here are the necessary document files: {file_path}"
        elif file_path.suffix in [".jpg", ".jpeg", ".png"]:
            task["Question"] += f" Here are the necessary image files: {file_path}"
        elif file_path.suffix in [".xlsx", ".xls", ".csv"]:
            task["Question"] += (
                f" Here are the necessary table files: {file_path}; to process an Excel file,"
                " you can use the excel tool or write python code to process the file"
                " step-by-step and get the information."
            )
        elif file_path.suffix == ".py":
            task["Question"] += f" Here are the necessary python files: {file_path}"
        else:
            task["Question"] += f" Here are the necessary files: {file_path}"
    return task
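

# Illustrative sketch (`_demo_add_file_path` is hypothetical): how a task's Question text is
# extended with its attached file, assuming the metadata schema used above ("file_name" and
# "Question" fields). No file is actually read, so this runs without the dataset present.
def _demo_add_file_path() -> None:
    task = {"task_id": "demo", "Question": "What is in the table?", "file_name": "table.xlsx"}
    task = add_file_path(task, file_path="./gaia_dataset", split="validation")
    logger.info(task["Question"])  # now mentions gaia_dataset/validation/table.xlsx and the excel hint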


def report_results(entries):
    # Initialize counters
    total_entries = len(entries)
    total_correct = 0

    # Per-level statistics
    level_stats = {}

    # Process each entry
    for entry in entries:
        level = entry.get("level")
        is_correct = entry.get("is_correct", False)

        # Initialize level stats if not already present
        if level not in level_stats:
            level_stats[level] = {"total": 0, "correct": 0, "accuracy": 0}

        # Update counters
        level_stats[level]["total"] += 1
        if is_correct:
            total_correct += 1
            level_stats[level]["correct"] += 1

    # Calculate accuracy for each level
    for level, stats in level_stats.items():
        if stats["total"] > 0:
            stats["accuracy"] = (stats["correct"] / stats["total"]) * 100

    # Log overall statistics
    logger.info("Overall Statistics:")
    overall_accuracy = (total_correct / total_entries) * 100 if total_entries else 0.0

    overall_table = [
        ["Total Entries", total_entries],
        ["Total Correct", total_correct],
        ["Overall Accuracy", f"{overall_accuracy:.2f}%"],
    ]
    logger.info(tabulate(overall_table, tablefmt="grid"))
    logger.info("")

    # Log per-level statistics
    logger.info("Statistics by Level:")
    level_table = []
    headers = ["Level", "Total Entries", "Correct Answers", "Accuracy"]
    for level in sorted(level_stats.keys()):
        stats = level_stats[level]
        level_table.append(
            [level, stats["total"], stats["correct"], f"{stats['accuracy']:.2f}%"]
        )
    logger.info(tabulate(level_table, headers=headers, tablefmt="grid"))
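

# Illustrative sketch (`_demo_report_results` is hypothetical): the entry shape expected by
# report_results, with the "level" and "is_correct" keys read above. With these three toy
# entries the overall accuracy is 66.67%, level 1 is 50.00%, and level 2 is 100.00%.
def _demo_report_results() -> None:
    report_results(
        [
            {"level": 1, "is_correct": True},
            {"level": 1, "is_correct": False},
            {"level": 2, "is_correct": True},
        ]
    )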