"""Utilities for loading, scoring, and reporting GAIA-style benchmark tasks."""

import json
import re
import string
from pathlib import Path
from typing import Any, Dict, List, Optional
import logging
from tabulate import tabulate

logger = logging.getLogger(__name__)


def normalize_str(input_str: str, remove_punct: bool = True) -> str:
    """Lowercase a string and strip all whitespace and, optionally, punctuation."""
    no_spaces = re.sub(r"\s", "", input_str)
    if remove_punct:
        translator = str.maketrans("", "", string.punctuation)
        return no_spaces.lower().translate(translator)
    return no_spaces.lower()


def split_string(s: str, char_list: Optional[List[str]] = None) -> List[str]:
    """Split a string on any of the given delimiter characters (default "," and ";")."""
    if char_list is None:
        char_list = [",", ";"]
    # re.escape keeps delimiters like "]" or "^" from breaking the character class.
    pattern = f"[{re.escape(''.join(char_list))}]"
    return re.split(pattern, s)


def normalize_number_str(number_str: str) -> float:
    """Strip "$", "%", and "," from a numeric string and parse it as a float.

    Returns float("inf") as a sentinel on failure, so an unparseable answer
    never compares equal to a real ground-truth number.
    """
    for char in ["$", "%", ","]:
        number_str = number_str.replace(char, "")
    try:
        return float(number_str)
    except ValueError:
        logger.error(f"String {number_str} cannot be normalized to a number.")
        return float("inf")


def question_scorer(model_answer: str, ground_truth: str) -> bool:
    """Score a model answer against the ground truth.

    The ground truth determines the comparison mode: a parseable number is
    compared numerically, a ","/";"-delimited string is compared element-wise
    as a list, and anything else is compared as a normalized string.
    """

    def is_float(element: Any) -> bool:
        try:
            float(element)
            return True
        except ValueError:
            return False

    try:
        if is_float(ground_truth):
            logger.info(f"Evaluating {model_answer} as a number.")
            normalized_answer = normalize_number_str(model_answer)
            return normalized_answer == float(ground_truth)

        elif any(char in ground_truth for char in [",", ";"]):
            logger.info(f"Evaluating {model_answer} as a comma separated list.")
            gt_elems = split_string(ground_truth)
            ma_elems = split_string(model_answer)

            if len(gt_elems) != len(ma_elems):
                logger.warning("Answer lists have different lengths, returning False.")
                return False

            comparisons = []
            for ma_elem, gt_elem in zip(ma_elems, gt_elems):
                if is_float(gt_elem):
                    normalized_ma_elem = normalize_number_str(ma_elem)
                    comparisons.append(normalized_ma_elem == float(gt_elem))
                else:
                    ma_elem = normalize_str(ma_elem, remove_punct=False)
                    gt_elem = normalize_str(gt_elem, remove_punct=False)
                    comparisons.append(ma_elem == gt_elem)
            return all(comparisons)
        else:
            logger.info(f"Evaluating {model_answer} as a string.")
            ma_elem = normalize_str(model_answer)
            gt_elem = normalize_str(ground_truth)
            return ma_elem == gt_elem
    except Exception as e:
        logger.error(f"Error during evaluation: {e}")
        return False
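
# Usage sketch for question_scorer (hypothetical inputs, one per comparison mode):
#   question_scorer("$1,234", "1234")   -> True  (numeric: "$", "%", "," stripped)
#   question_scorer("a; b; c", "a;b;c") -> True  (list: split on ","/";", element-wise match)
#   question_scorer("Paris!", "paris")  -> True  (string: whitespace/punctuation-insensitive)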


def load_dataset_meta(path: str, split: str = "validation") -> List[Dict[str, Any]]:
    """Load task metadata for a split as a list, resolving attachment file paths.

    The placeholder task with id "0-0-0-0-0" is skipped.
    """
    data_dir = Path(path) / split

    dataset = []
    with open(data_dir / "metadata.jsonl", "r", encoding="utf-8") as metaf:
        for line in metaf:
            data = json.loads(line)
            if data["task_id"] == "0-0-0-0-0":
                continue
            if data["file_name"]:
                data["file_name"] = data_dir / data["file_name"]
            dataset.append(data)
    return dataset


def load_dataset_meta_dict(
    path: str, split: str = "validation"
) -> Dict[str, Dict[str, Any]]:
    """Like load_dataset_meta, but keyed by task_id for direct lookup."""
    data_dir = Path(path) / split

    dataset = {}
    with open(data_dir / "metadata.jsonl", "r", encoding="utf-8") as metaf:
        for line in metaf:
            data = json.loads(line)
            if data["task_id"] == "0-0-0-0-0":
                continue
            if data["file_name"]:
                data["file_name"] = data_dir / data["file_name"]
            dataset[data["task_id"]] = data
    return dataset


def add_file_path(
    task: Dict[str, Any], file_path: str = "./gaia_dataset", split: str = "validation"
) -> Dict[str, Any]:
    """Append a pointer to the task's attachment (if any) to its question text."""
    if task["file_name"]:
        full_path = Path(file_path) / split / task["file_name"]
        if full_path.suffix in [".pdf", ".docx", ".doc", ".txt"]:
            task["Question"] += f" Here are the necessary document files: {full_path}"

        elif full_path.suffix in [".jpg", ".jpeg", ".png"]:
            task["Question"] += f" Here are the necessary image files: {full_path}"

        elif full_path.suffix in [".xlsx", ".xls", ".csv"]:
            task["Question"] += (
                f" Here are the necessary table files: {full_path}. To process an Excel"
                " file, you can use the excel tool or write Python code to work through"
                " the file step by step and extract the information."
            )
        elif full_path.suffix == ".py":
            task["Question"] += f" Here are the necessary python files: {full_path}"

        else:
            task["Question"] += f" Here are the necessary files: {full_path}"

    return task
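
# Usage sketch for the loaders (assumes a GAIA-style layout under ./gaia_dataset/<split>):
#   tasks = load_dataset_meta("./gaia_dataset", split="validation")
#   tasks = [add_file_path(task, file_path="./gaia_dataset") for task in tasks]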


def report_results(entries):
    """Log overall and per-level accuracy tables for a list of result entries."""
    # Initialize counters
    total_entries = len(entries)
    total_correct = 0

    # Initialize level statistics
    level_stats = {}

    # Process each entry
    for entry in entries:
        level = entry.get("level")
        is_correct = entry.get("is_correct", False)

        # Initialize level stats if not already present
        if level not in level_stats:
            level_stats[level] = {"total": 0, "correct": 0, "accuracy": 0}

        # Update counters
        level_stats[level]["total"] += 1
        if is_correct:
            total_correct += 1
            level_stats[level]["correct"] += 1

    # Calculate accuracy for each level
    for stats in level_stats.values():
        if stats["total"] > 0:
            stats["accuracy"] = (stats["correct"] / stats["total"]) * 100

    # Log overall statistics (guard against an empty entries list)
    logger.info("Overall Statistics:")
    overall_accuracy = (total_correct / total_entries) * 100 if total_entries else 0.0

    # Create overall statistics table
    overall_table = [
        ["Total Entries", total_entries],
        ["Total Correct", total_correct],
        ["Overall Accuracy", f"{overall_accuracy:.2f}%"],
    ]
    logger.info(tabulate(overall_table, tablefmt="grid"))
    logger.info("")

    # Create level statistics table
    logger.info("Statistics by Level:")
    level_table = []
    headers = ["Level", "Total Entries", "Correct Answers", "Accuracy"]

    for level in sorted(level_stats.keys()):
        stats = level_stats[level]
        level_table.append(
            [level, stats["total"], stats["correct"], f"{stats['accuracy']:.2f}%"]
        )

    logger.info(tabulate(level_table, headers=headers, tablefmt="grid"))
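

# Minimal usage sketch, runnable as a script. The answer/ground-truth pairs below
# are hypothetical toy data, included only to exercise question_scorer and
# report_results end to end.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format="%(message)s")

    toy_cases = [
        {"model_answer": "$1,234", "ground_truth": "1234", "level": 1},  # numeric match
        {"model_answer": "a; b; c", "ground_truth": "a;b;c", "level": 1},  # list match
        {"model_answer": "Paris!", "ground_truth": "paris", "level": 2},  # string match
        {"model_answer": "wrong", "ground_truth": "right", "level": 2},  # miss
    ]
    entries = [
        {
            "level": case["level"],
            "is_correct": question_scorer(case["model_answer"], case["ground_truth"]),
        }
        for case in toy_cases
    ]
    report_results(entries)  # expected: 75.00% overall, 100% on level 1, 50% on level 2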