import json
import logging
import re
import string
from pathlib import Path
from typing import Any, Dict, List, Optional

from tabulate import tabulate

logger = logging.getLogger(__name__)


def normalize_str(input_str: str, remove_punct: bool = True) -> str:
    """Lower-case a string, strip all whitespace, and optionally remove punctuation."""
    no_spaces = re.sub(r"\s", "", input_str)
    if remove_punct:
        translator = str.maketrans("", "", string.punctuation)
        return no_spaces.lower().translate(translator)
    return no_spaces.lower()


def split_string(s: str, char_list: Optional[List[str]] = None) -> List[str]:
    """Split a string on any of the given separator characters (default: comma and semicolon)."""
    if char_list is None:
        char_list = [",", ";"]
    pattern = f"[{''.join(char_list)}]"
    return re.split(pattern, s)


def normalize_number_str(number_str: str) -> float:
    """Strip currency, percent, and thousands symbols, then parse the result as a float."""
    for char in ["$", "%", ","]:
        number_str = number_str.replace(char, "")
    try:
        return float(number_str)
    except ValueError:
        logger.error(f"String {number_str} cannot be normalized to a number.")
        return float("inf")


def question_scorer(model_answer: str, ground_truth: str) -> bool:
    """Score a model answer against the ground truth.

    The ground truth is treated as a number, a comma/semicolon-separated list,
    or a plain string, and the model answer is normalized accordingly.
    """

    def is_float(element: Any) -> bool:
        try:
            float(element)
            return True
        except ValueError:
            return False

    try:
        if is_float(ground_truth):
            logger.info(f"Evaluating {model_answer} as a number.")
            normalized_answer = normalize_number_str(model_answer)
            return normalized_answer == float(ground_truth)
        elif any(char in ground_truth for char in [",", ";"]):
            logger.info(f"Evaluating {model_answer} as a comma-separated list.")
            gt_elems = split_string(ground_truth)
            ma_elems = split_string(model_answer)

            # Answers with a different number of elements cannot match.
            if len(gt_elems) != len(ma_elems):
                logger.warning("Answer lists have different lengths, returning False.")
                return False

            comparisons = []
            for ma_elem, gt_elem in zip(ma_elems, gt_elems):
                if is_float(gt_elem):
                    normalized_ma_elem = normalize_number_str(ma_elem)
                    comparisons.append(normalized_ma_elem == float(gt_elem))
                else:
                    ma_elem = normalize_str(ma_elem, remove_punct=False)
                    gt_elem = normalize_str(gt_elem, remove_punct=False)
                    comparisons.append(ma_elem == gt_elem)
            return all(comparisons)
        else:
            logger.info(f"Evaluating {model_answer} as a string.")
            ma_elem = normalize_str(model_answer)
            gt_elem = normalize_str(ground_truth)
            return ma_elem == gt_elem
    except Exception as e:
        logger.error(f"Error during evaluation: {e}")
        return False


def load_dataset_meta(path: str, split: str = "validation") -> List[Dict[str, Any]]:
    """Load task metadata from <path>/<split>/metadata.jsonl as a list of task dicts."""
    data_dir = Path(path) / split

    dataset = []
    with open(data_dir / "metadata.jsonl", "r", encoding="utf-8") as metaf:
        for line in metaf:
            data = json.loads(line)
            # Skip the placeholder task.
            if data["task_id"] == "0-0-0-0-0":
                continue
            if data["file_name"]:
                data["file_name"] = data_dir / data["file_name"]
            dataset.append(data)
    return dataset


def load_dataset_meta_dict(path: str, split: str = "validation") -> Dict[str, Dict[str, Any]]:
    """Load task metadata from <path>/<split>/metadata.jsonl as a dict keyed by task_id."""
    data_dir = Path(path) / split

    dataset = {}
    with open(data_dir / "metadata.jsonl", "r", encoding="utf-8") as metaf:
        for line in metaf:
            data = json.loads(line)
            # Skip the placeholder task.
            if data["task_id"] == "0-0-0-0-0":
                continue
            if data["file_name"]:
                data["file_name"] = data_dir / data["file_name"]
            dataset[data["task_id"]] = data
    return dataset


def add_file_path(
    task: Dict[str, Any], file_path: str = "./gaia_dataset", split: str = "validation"
) -> Dict[str, Any]:
    """Append a hint about the task's attached file (if any) to the task question."""
    if task["file_name"]:
        file_path = Path(f"{file_path}/{split}") / task["file_name"]
        if file_path.suffix in [".pdf", ".docx", ".doc", ".txt"]:
            task["Question"] += f" Here are the necessary document files: {file_path}"
        elif file_path.suffix in [".jpg", ".jpeg", ".png"]:
            task["Question"] += f" Here are the necessary image files: {file_path}"
        elif file_path.suffix in [".xlsx", ".xls", ".csv"]:
            task["Question"] += (
                f" Here are the necessary table files: {file_path}, for processing excel file,"
                " you can use the excel tool or write python code to process the file"
                " step-by-step and get the information."
            )
        elif file_path.suffix in [".py"]:
            task["Question"] += f" Here are the necessary python files: {file_path}"
        else:
            task["Question"] += f" Here are the necessary files: {file_path}"
    return task


def report_results(entries) -> None:
    """Log overall and per-level accuracy tables for a list of scored entries."""
    # Initialize counters
    total_entries = len(entries)
    total_correct = 0

    # Initialize level statistics
    level_stats = {}

    # Process each entry
    for entry in entries:
        level = entry.get("level")
        is_correct = entry.get("is_correct", False)

        # Initialize level stats if not already present
        if level not in level_stats:
            level_stats[level] = {"total": 0, "correct": 0, "accuracy": 0}

        # Update counters
        level_stats[level]["total"] += 1
        if is_correct:
            total_correct += 1
            level_stats[level]["correct"] += 1

    # Calculate accuracy for each level
    for level, stats in level_stats.items():
        if stats["total"] > 0:
            stats["accuracy"] = (stats["correct"] / stats["total"]) * 100

    # Print overall statistics
    logger.info("Overall Statistics:")
    overall_accuracy = (total_correct / total_entries) * 100 if total_entries else 0.0

    # Create overall statistics table
    overall_table = [
        ["Total Entries", total_entries],
        ["Total Correct", total_correct],
        ["Overall Accuracy", f"{overall_accuracy:.2f}%"],
    ]
    # The standard logging module has no `success` level, so the tables are logged at INFO.
    logger.info(tabulate(overall_table, tablefmt="grid"))
    logger.info("")

    # Create level statistics table
    logger.info("Statistics by Level:")
    level_table = []
    headers = ["Level", "Total Entries", "Correct Answers", "Accuracy"]
    for level in sorted(level_stats.keys()):
        stats = level_stats[level]
        level_table.append(
            [level, stats["total"], stats["correct"], f"{stats['accuracy']:.2f}%"]
        )
    logger.info(tabulate(level_table, headers=headers, tablefmt="grid"))
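

# Minimal usage sketch (an illustrative addition, not part of the original module): it shows
# which branch of `question_scorer` handles numeric, list, and string ground truths, and the
# entry shape `report_results` expects. The sample answers and entries below are made up.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Numeric ground truth: currency symbols and thousands separators are stripped.
    print(question_scorer("$1,234", "1234"))  # True
    # List ground truth: elements are compared pairwise after whitespace normalization.
    print(question_scorer("apples; pears", "apples;pears"))  # True
    # String ground truth: case, whitespace, and punctuation are removed before comparing.
    print(question_scorer("New York.", "new york"))  # True

    # `report_results` expects one dict per scored task with `level` and `is_correct` keys.
    report_results(
        [
            {"level": 1, "is_correct": True},
            {"level": 1, "is_correct": False},
            {"level": 2, "is_correct": True},
        ]
    )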