import argparse
import json
import logging
import os
import re
import traceback
from typing import Any, Dict, List

from dotenv import load_dotenv

from aworld.config.conf import AgentConfig, TaskConfig
from aworld.agents.llm_agent import Agent
from aworld.core.task import Task
from aworld.runner import Runners
from examples.gaia.prompt import system_prompt
from examples.gaia.utils import (
    add_file_path,
    load_dataset_meta,
    question_scorer,
    report_results,
)
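# Expected environment variables (typically supplied via a .env file):
#   LOG_FILE_PATH     - directory for log files and the results.json checkpoint
#   GAIA_DATASET_PATH - root of the local GAIA dataset (defaults to ./gaia_dataset)
#   LLM_MODEL_NAME, LLM_API_KEY, LLM_BASE_URL - LLM endpoint configuration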
# Load environment variables before any os.getenv() lookups below
load_dotenv()

# Create the log directory if it doesn't exist.
# "./logs" is an assumed fallback; the original relied on LOG_FILE_PATH being set.
LOG_DIR = os.getenv("LOG_FILE_PATH", "./logs")
os.makedirs(LOG_DIR, exist_ok=True)
parser = argparse.ArgumentParser()
parser.add_argument(
    "--start",
    type=int,
    default=0,
    help="Start index of the dataset",
)
parser.add_argument(
    "--end",
    type=int,
    default=20,
    help="End index of the dataset",
)
parser.add_argument(
    "--q",
    type=str,
    help="Task ID of a single question, e.g., 0-0-0-0-0. Highest priority: overrides --start/--end if provided.",
)
parser.add_argument(
    "--skip",
    action="store_true",
    help="Skip questions that have already been processed.",
)
parser.add_argument(
    "--split",
    type=str,
    default="validation",
    help="Dataset split, e.g., validation or test",
)
parser.add_argument(
    "--blacklist_file_path",
    type=str,
    nargs="?",
    help="Path to a blacklist file of task IDs, e.g., blacklist.txt",
)
args = parser.parse_args()
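# Example invocations (script path is illustrative):
#   python run.py --split validation --start 0 --end 20 --skip
#   python run.py --q <task_id>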
def setup_logging():
    logging_logger = logging.getLogger()
    logging_logger.setLevel(logging.INFO)

    log_file_name = (
        f"super_agent_{args.q}.log"
        if args.q
        else f"super_agent_{args.start}_{args.end}.log"
    )
    file_handler = logging.FileHandler(
        os.path.join(LOG_DIR, log_file_name),
        mode="a",
        encoding="utf-8",
    )
    file_handler.setLevel(logging.INFO)
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    file_handler.setFormatter(formatter)
    logging_logger.addHandler(file_handler)
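# Note: only a file handler is attached, so nothing is echoed to the console.
# To also mirror logs to stdout, an extra handler could be added (optional):
# logging.getLogger().addHandler(logging.StreamHandler())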
if __name__ == "__main__":
    setup_logging()

    gaia_dataset_path = os.getenv("GAIA_DATASET_PATH", "./gaia_dataset")
    full_dataset = load_dataset_meta(gaia_dataset_path, split=args.split)
    logging.info(f"Total questions: {len(full_dataset)}")
    agent_config = AgentConfig(
        llm_provider="openai",
        llm_model_name=os.getenv("LLM_MODEL_NAME", "gpt-4o"),
        llm_api_key=os.getenv("LLM_API_KEY", "your_openai_api_key"),
        llm_base_url=os.getenv("LLM_BASE_URL", "your_openai_base_url"),
    )
    super_agent = Agent(
        conf=agent_config,
        name="gaia_super_agent",
        system_prompt=system_prompt,
        mcp_servers=[
            "e2b-server",
            # "filesystem",
            "terminal-controller",
            "excel",
            "calculator",
            "ms-playwright",
            "audio_server",
            "image_server",
            "video_server",
            "search_server",
            "download_server",
            "document_server",
            # "browser_server",
            "youtube_server",
            "reasoning_server",
        ],
    )
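    # Each entry above names an MCP tool server the agent may call at runtime;
    # commented-out entries are disabled. The exact set available depends on
    # the local MCP server configuration shipped with the examples.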
    # Load results from the checkpoint file if one exists
    results_path = os.path.join(LOG_DIR, "results.json")
    if os.path.exists(results_path):
        with open(results_path, "r", encoding="utf-8") as results_f:
            results: List[Dict[str, Any]] = json.load(results_f)
    else:
        results: List[Dict[str, Any]] = []
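    # `results` acts as a resumable checkpoint: rerunning the script updates
    # per-task records in place rather than appending duplicates (see below).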
    # Load blacklisted `task_id` values
    if args.blacklist_file_path and os.path.exists(args.blacklist_file_path):
        with open(args.blacklist_file_path, "r", encoding="utf-8") as f:
            blacklist = set(f.read().splitlines())
    else:
        blacklist = set()  # Empty set if no blacklist file is given
    try:
        # Slice the dataset by args.start and args.end, overridden by args.q
        # (a single `task_id`) when provided
        dataset_slice = (
            [
                dataset_record
                for dataset_record in full_dataset
                if dataset_record["task_id"] == args.q
            ]
            if args.q is not None
            else full_dataset[args.start : args.end]
        )
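        # Questions are filtered before running: blacklisted task_ids are
        # dropped, questions already answered correctly (or answered
        # incorrectly at Level 3) are not retried, and --skip additionally
        # avoids any previously attempted question.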
        # Main loop to execute questions
        for i, dataset_i in enumerate(dataset_slice):
            # If a specific `task_id` was requested, run only that question
            if args.q and args.q != dataset_i["task_id"]:
                continue

            # The filters below apply only when args.q is not given
            if not args.q:
                # Blacklisted question
                if dataset_i["task_id"] in blacklist:
                    continue
                # Already done: either answered correctly, or answered
                # incorrectly but at Level 3 (not retried)
                if any(
                    (result["task_id"] == dataset_i["task_id"] and result["is_correct"])
                    for result in results
                ) or any(
                    (
                        result["task_id"] == dataset_i["task_id"]
                        and not result["is_correct"]
                        and dataset_i["Level"] == 3
                    )
                    for result in results
                ):
                    continue
                # --skip: avoid any question attempted before, regardless of
                # whether the previous attempt was correct
                if args.skip and any(
                    (result["task_id"] == dataset_i["task_id"])
                    for result in results
                ):
                    continue
            # Run the agent on this question
            try:
                logging.info(f"Start to process: {dataset_i['task_id']}")
                logging.info(f"Detail: {dataset_i}")
                logging.info(f"Question: {dataset_i['Question']}")
                logging.info(f"Level: {dataset_i['Level']}")
                logging.info(f"Tools: {dataset_i['Annotator Metadata']['Tools']}")

                question = add_file_path(
                    dataset_i, file_path=gaia_dataset_path, split=args.split
                )["Question"]
                task = Task(input=question, agent=super_agent, conf=TaskConfig())
                result = Runners.sync_run_task(task=task)

                # The final answer is expected inside <answer></answer> tags
                # (per the system prompt); fall back to "" if it is missing,
                # so a malformed response is recorded as incorrect rather
                # than reusing a stale `answer` from a previous iteration
                match = re.search(
                    r"<answer>(.*?)</answer>", result[task.id].get("answer") or ""
                )
                answer = match.group(1) if match else ""
                logging.info(f"Agent answer: {answer}")
                logging.info(f"Correct answer: {dataset_i['Final answer']}")

                is_correct = question_scorer(answer, dataset_i["Final answer"])
                if is_correct:
                    logging.info(f"Question {i} Correct!")
                else:
                    logging.info(f"Question {i} Incorrect!")
                # Create the new result record, scoring the answer only once
                new_result = {
                    "task_id": dataset_i["task_id"],
                    "level": dataset_i["Level"],
                    "question": question,
                    "answer": dataset_i["Final answer"],
                    "response": answer,
                    "is_correct": is_correct,
                }

                # Check if this task_id already exists in results
                existing_index = next(
                    (
                        idx
                        for idx, result in enumerate(results)
                        if result["task_id"] == dataset_i["task_id"]
                    ),
                    None,
                )
                if existing_index is not None:
                    # Update the existing record
                    results[existing_index] = new_result
                    logging.info(
                        f"Updated existing record for task_id: {dataset_i['task_id']}"
                    )
                else:
                    # Append a new record
                    results.append(new_result)
                    logging.info(
                        f"Added new record for task_id: {dataset_i['task_id']}"
                    )
            except Exception:
                logging.error(f"Error processing {i}: {traceback.format_exc()}")
                continue
    except KeyboardInterrupt:
        pass
    finally:
        # Report aggregate results and persist the checkpoint
        report_results(results)
        with open(results_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=4, ensure_ascii=False)
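        # Each record in results.json has the shape (illustrative):
        #   {"task_id": ..., "level": ..., "question": ...,
        #    "answer": <ground truth>, "response": <agent answer>,
        #    "is_correct": bool}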