# Renamed from AWorld-main/aworlddistributed/aworldspace/utils/utils.py
# to aworlddistributed/aworldspace/utils/utils.py (commit 086fc6e, verified).
import json | |
import re | |
import string | |
from pathlib import Path | |
from typing import Any, Dict, List, Optional | |
from loguru import logger | |
from tabulate import tabulate | |
def normalize_str(input_str, remove_punct=True) -> str:
    """Lower-case *input_str* with every whitespace character removed.

    :param input_str: string to normalize.
    :param remove_punct: when True, ASCII punctuation is stripped as well.
    :return: normalized string.
    """
    collapsed = re.sub(r"\s", "", input_str).lower()
    if not remove_punct:
        return collapsed
    # str.translate performs the punctuation strip in a single C-level pass.
    return collapsed.translate(str.maketrans("", "", string.punctuation))
def split_string(s: str, char_list: Optional[List[str]] = None) -> list[str]:
    """Split *s* on any of the single-character delimiters in *char_list*.

    :param s: string to split.
    :param char_list: delimiter characters; defaults to ["," , ";"].
    :return: list of substrings (re.split semantics, empty pieces kept).
    """
    if char_list is None:
        char_list = [",", ";"]
    # Bug fix: escape each delimiter so regex metacharacters cannot corrupt
    # the character class (e.g. ["a", "-", "z"] previously became the range
    # [a-z], and a backslash delimiter escaped the following character).
    pattern = f"[{''.join(re.escape(c) for c in char_list)}]"
    return re.split(pattern, s)
def normalize_number_str(number_str: str) -> float:
    """Parse *number_str* as a float after stripping $, % and , symbols.

    :param number_str: raw numeric string, possibly with currency/percent
        signs or thousands separators.
    :return: parsed float, or float("inf") when parsing fails.
    """
    cleaned = number_str.replace("$", "").replace("%", "").replace(",", "")
    try:
        return float(cleaned)
    except ValueError:
        logger.error(f"String {cleaned} cannot be normalized to number str.")
        # Sentinel keeps the caller's equality comparison well-defined.
        return float("inf")
def question_scorer(model_answer: str, ground_truth: str) -> bool:
    """Score *model_answer* against GAIA-style *ground_truth*.

    Dispatches on the ground-truth format: a number, a ","/";"-delimited
    list, or a plain string. Returns True on a match; any unexpected
    exception is logged and scored as False.
    """

    def is_float(element: Any) -> bool:
        # True when element parses as a float.
        try:
            float(element)
        except ValueError:
            return False
        return True

    try:
        # Case 1: numeric ground truth -> numeric comparison.
        if is_float(ground_truth):
            logger.info(f"Evaluating {model_answer} as a number.")
            return normalize_number_str(model_answer) == float(ground_truth)

        # Case 2: delimited list -> element-wise comparison.
        if any(sep in ground_truth for sep in [",", ";"]):
            logger.info(f"Evaluating {model_answer} as a comma separated list.")
            gt_parts = split_string(ground_truth)
            ma_parts = split_string(model_answer)
            if len(gt_parts) != len(ma_parts):
                logger.warning("Answer lists have different lengths, returning False.")
                return False
            results = []
            for ma_part, gt_part in zip(ma_parts, gt_parts):
                if is_float(gt_part):
                    results.append(normalize_number_str(ma_part) == float(gt_part))
                else:
                    results.append(
                        normalize_str(ma_part, remove_punct=False)
                        == normalize_str(gt_part, remove_punct=False)
                    )
            return all(results)

        # Case 3: plain string comparison with punctuation stripped.
        logger.info(f"Evaluating {model_answer} as a string.")
        return normalize_str(model_answer) == normalize_str(ground_truth)
    except Exception as e:
        logger.error(f"Error during evaluation: {e}")
        return False
def load_dataset_meta(path: str, split: str = "validation"):
    """Load the GAIA metadata.jsonl records of *split* as a list.

    Skips the placeholder task "0-0-0-0-0" and resolves each non-empty
    "file_name" to a Path under the split directory.

    :param path: dataset root directory.
    :param split: dataset split sub-directory name.
    :return: list of task dicts.
    """
    split_dir = Path(path) / split
    tasks = []
    with open(split_dir / "metadata.jsonl", "r", encoding="utf-8") as metaf:
        for raw_line in metaf:
            record = json.loads(raw_line)
            if record["task_id"] == "0-0-0-0-0":
                continue
            if record["file_name"]:
                record["file_name"] = split_dir / record["file_name"]
            tasks.append(record)
    return tasks
def load_dataset_meta_dict(path: str, split: str = "validation"):
    """Load the GAIA metadata.jsonl records of *split* keyed by task_id.

    Skips the placeholder task "0-0-0-0-0" and resolves each non-empty
    "file_name" to a Path under the split directory.

    :param path: dataset root directory.
    :param split: dataset split sub-directory name.
    :return: dict mapping task_id -> task dict.
    """
    split_dir = Path(path) / split
    tasks = {}
    with open(split_dir / "metadata.jsonl", "r", encoding="utf-8") as metaf:
        for raw_line in metaf:
            record = json.loads(raw_line)
            if record["task_id"] == "0-0-0-0-0":
                continue
            if record["file_name"]:
                record["file_name"] = split_dir / record["file_name"]
            tasks[record["task_id"]] = record
    return tasks
def add_file_path(
    task: Dict[str, Any], file_path: str = "./gaia_dataset", split: str = "validation"
):
    """Append the task's attachment path, with a type-specific hint, to its Question.

    :param task: task record with "file_name" and "Question" keys (mutated in place).
    :param file_path: dataset root directory.
    :param split: dataset split sub-directory name.
    :return: the same task dict.
    """
    if task["file_name"]:
        full_path = Path(f"{file_path}/{split}") / task["file_name"]
        suffix = full_path.suffix
        if suffix in [".pdf", ".docx", ".doc", ".txt"]:
            task["Question"] += f" Here are the necessary document files: {full_path}"
        elif suffix in [".jpg", ".jpeg", ".png"]:
            task["Question"] += f" Here are the necessary image files: {full_path}"
        # Bug fix: the original list contained "xls" without the leading dot,
        # so ".xls" workbooks fell through to the generic "files" branch.
        elif suffix in [".xlsx", ".xls", ".csv"]:
            task["Question"] += (
                f" Here are the necessary table files: {full_path}, for processing excel file,"
                " you can use the excel tool or write python code to process the file"
                " step-by-step and get the information."
            )
        elif suffix in [".py"]:
            task["Question"] += f" Here are the necessary python files: {full_path}"
        else:
            task["Question"] += f" Here are the necessary files: {full_path}"
    return task
def report_results(entries):
    """Log overall and per-level accuracy tables for evaluated entries.

    :param entries: list of dicts carrying "level" and "is_correct" keys.
    :return: None (output goes through the logger).
    """
    total_entries = len(entries)
    if total_entries == 0:
        # Bug fix: the original divided by len(entries) unconditionally,
        # raising ZeroDivisionError when there were no results to report.
        logger.warning("No entries to report.")
        return

    total_correct = 0
    # level -> {"total", "correct", "accuracy"}
    level_stats = {}

    for entry in entries:
        level = entry.get("level")
        is_correct = entry.get("is_correct", False)
        if level not in level_stats:
            level_stats[level] = {"total": 0, "correct": 0, "accuracy": 0}
        level_stats[level]["total"] += 1
        if is_correct:
            total_correct += 1
            level_stats[level]["correct"] += 1

    # Per-level accuracy as a percentage.
    for stats in level_stats.values():
        if stats["total"] > 0:
            stats["accuracy"] = (stats["correct"] / stats["total"]) * 100

    logger.info("Overall Statistics:")
    overall_accuracy = (total_correct / total_entries) * 100
    overall_table = [
        ["Total Entries", total_entries],
        ["Total Correct", total_correct],
        ["Overall Accuracy", f"{overall_accuracy:.2f}%"],
    ]
    logger.success(tabulate(overall_table, tablefmt="grid"))
    logger.info("")

    logger.info("Statistics by Level:")
    headers = ["Level", "Total Entries", "Correct Answers", "Accuracy"]
    level_table = []
    for level in sorted(level_stats.keys()):
        stats = level_stats[level]
        level_table.append(
            [level, stats["total"], stats["correct"], f"{stats['accuracy']:.2f}%"]
        )
    logger.success(tabulate(level_table, headers=headers, tablefmt="grid"))
import uuid | |
import time | |
from typing import List | |
import inspect | |
from typing import get_type_hints, Tuple | |
def stream_message_template(model: str, message: str):
    """Build an OpenAI-style streaming chat-completion chunk carrying *message*.

    :param model: model identifier echoed in the chunk.
    :param message: delta text content for the single choice.
    :return: chunk dict in chat.completion.chunk format.
    """
    choice = {
        "index": 0,
        "delta": {"content": message},
        "logprobs": None,
        "finish_reason": None,
    }
    return {
        # Unique per-chunk id, namespaced by the model name.
        "id": f"{model}-{str(uuid.uuid4())}",
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model,
        "choices": [choice],
    }
def get_last_user_message(messages: List[dict]) -> str:
    """Return the text of the most recent user message, or None if absent.

    For multi-part (list) content, the first "text" part wins; a list with
    no text part is returned as-is.
    """
    for msg in reversed(messages):
        if msg["role"] != "user":
            continue
        content = msg["content"]
        if isinstance(content, list):
            for part in content:
                if part["type"] == "text":
                    return part["text"]
        return content
    return None
def get_last_assistant_message(messages: List[dict]) -> str:
    """Return the text of the most recent assistant message, or None if absent.

    For multi-part (list) content, the first "text" part wins; a list with
    no text part is returned as-is.
    """
    for msg in reversed(messages):
        if msg["role"] != "assistant":
            continue
        content = msg["content"]
        if isinstance(content, list):
            for part in content:
                if part["type"] == "text":
                    return part["text"]
        return content
    return None
def get_system_message(messages: List[dict]) -> dict:
    """Return the first system-role message dict, or None if there is none."""
    return next((msg for msg in messages if msg["role"] == "system"), None)
def remove_system_message(messages: List[dict]) -> List[dict]:
    """Return a new list with every system-role message filtered out."""
    kept = []
    for msg in messages:
        if msg["role"] != "system":
            kept.append(msg)
    return kept
def pop_system_message(messages: List[dict]) -> Tuple[dict, List[dict]]:
    """Split *messages* into (system message or None, remaining messages)."""
    system = get_system_message(messages)
    remainder = remove_system_message(messages)
    return system, remainder
def add_or_update_system_message(content: str, messages: List[dict]) -> List[dict]:
    """
    Adds a new system message at the beginning of the messages list
    or prepends *content* to the existing system message at the beginning.

    :param content: The text to be added or prepended.
    :param messages: The list of message dictionaries (mutated in place).
    :return: The updated list of message dictionaries.
    """
    if messages and messages[0].get("role") == "system":
        # Bug fix: the original used "+=" together with an f-string that
        # re-embedded the old content, producing
        # old + content + "\n" + old (duplicating the existing text).
        messages[0]["content"] = f"{content}\n{messages[0]['content']}"
    else:
        # Insert at the beginning
        messages.insert(0, {"role": "system", "content": content})
    return messages
def doc_to_dict(docstring):
    """Parse a reST-style docstring into {"description": ..., "params": {...}}.

    :param docstring: docstring text; ":param name: desc" lines become params.
    :return: dict with "description" (summary line) and "params" mapping.
    """
    lines = docstring.split("\n")
    # Bug fix: the original read lines[1] unconditionally and raised
    # IndexError on single-line input — get_tools_specs passes the bare
    # function name when __doc__ is None. Fall back to the only line.
    description = lines[1].strip() if len(lines) > 1 else lines[0].strip()
    param_dict = {}
    for line in lines:
        if ":param" in line:
            line = line.replace(":param", "").strip()
            param, desc = line.split(":", 1)
            param_dict[param.strip()] = desc.strip()
    return {"description": description, "params": param_dict}
def get_tools_specs(tools) -> List[dict]:
    """Build OpenAI function-calling specs for the public callables of *tools*.

    Each non-dunder callable attribute yields a spec with its name,
    a description parsed from its docstring (falling back to the name),
    a JSON-schema-like "parameters" object derived from type hints, and
    the list of parameters without defaults as "required".
    """
    public_callables = [
        (attr, getattr(tools, attr))
        for attr in dir(tools)
        if callable(getattr(tools, attr)) and not attr.startswith("__")
    ]

    specs = []
    for fn_name, fn in public_callables:
        # Docstring may be absent; the bare name then serves as both source
        # text and fallback description.
        doc = doc_to_dict(fn.__doc__ or fn_name)
        doc_params = doc.get("params", {})

        properties = {}
        for param_name, annotation in get_type_hints(fn).items():
            if param_name == "return":
                continue
            prop = {"type": annotation.__name__.lower()}
            if hasattr(annotation, "__args__"):
                # e.g. Literal[...] -> enumerate the allowed values.
                prop["enum"] = annotation.__args__
            # TODO: multi-line desc?
            prop["description"] = doc_params.get(param_name, param_name)
            properties[param_name] = prop

        required = [
            name
            for name, param in inspect.signature(fn).parameters.items()
            if param.default is param.empty
        ]

        specs.append(
            {
                "name": fn_name,
                "description": doc.get("description", fn_name),
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
            }
        )
    return specs