import argparse
import json
import logging
import os
import re
import traceback
from typing import Any, Dict, List

from dotenv import load_dotenv

from aworld.config.conf import AgentConfig, TaskConfig
from aworld.agents.llm_agent import Agent
from aworld.core.task import Task
from aworld.runner import Runners
from examples.gaia.prompt import system_prompt
from examples.gaia.utils import (
    add_file_path,
    load_dataset_meta,
    question_scorer,
    report_results,
)

# Load .env before any environment lookups below
load_dotenv()

# Directory for log files and the results checkpoint
# (the "./logs" fallback is a default of convenience; set LOG_FILE_PATH in .env)
LOG_DIR = os.getenv("LOG_FILE_PATH", "./logs")

# Create the log directory if it doesn't exist
os.makedirs(LOG_DIR, exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument(
    "--start",
    type=int,
    default=0,
    help="Start index of the dataset",
)
parser.add_argument(
    "--end",
    type=int,
    default=20,
    help="End index of the dataset",
)
parser.add_argument(
    "--q",
    type=str,
    help="Question index, e.g., 0-0-0-0-0. Highest priority: overrides other arguments if provided.",
)
parser.add_argument(
    "--skip",
    action="store_true",
    help="Skip the question if it has been processed before.",
)
parser.add_argument(
    "--split",
    type=str,
    default="validation",
    help="Split of the dataset, e.g., validation, test",
)
parser.add_argument(
    "--blacklist_file_path",
    type=str,
    nargs="?",
    help="Blacklist file path, e.g., blacklist.txt",
)
args = parser.parse_args()


def setup_logging():
    logging_logger = logging.getLogger()
    logging_logger.setLevel(logging.INFO)

    log_file_name = (
        f"super_agent_{args.q}.log"
        if args.q
        else f"super_agent_{args.start}_{args.end}.log"
    )
    file_handler = logging.FileHandler(
        os.path.join(LOG_DIR, log_file_name),
        mode="a",
        encoding="utf-8",
    )
    file_handler.setLevel(logging.INFO)
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    file_handler.setFormatter(formatter)
    logging_logger.addHandler(file_handler)


if __name__ == "__main__":
    setup_logging()

    gaia_dataset_path = os.getenv("GAIA_DATASET_PATH", "./gaia_dataset")
    full_dataset = load_dataset_meta(gaia_dataset_path, split=args.split)
    logging.info(f"Total questions: {len(full_dataset)}")

    agent_config = AgentConfig(
        llm_provider="openai",
        llm_model_name=os.getenv("LLM_MODEL_NAME", "gpt-4o"),
        llm_api_key=os.getenv("LLM_API_KEY", "your_openai_api_key"),
        llm_base_url=os.getenv("LLM_BASE_URL", "your_openai_base_url"),
    )

    super_agent = Agent(
        conf=agent_config,
        name="gaia_super_agent",
        system_prompt=system_prompt,
        mcp_servers=[
            "e2b-server",
            # "filesystem",
            "terminal-controller",
            "excel",
            "calculator",
            "ms-playwright",
            "audio_server",
            "image_server",
            "video_server",
            "search_server",
            "download_server",
            "document_server",
            # "browser_server",
            "youtube_server",
            "reasoning_server",
        ],
    )

    # Load results from the checkpoint file if it exists
    results_path = os.path.join(LOG_DIR, "results.json")
    if os.path.exists(results_path):
        with open(results_path, "r", encoding="utf-8") as results_f:
            results: List[Dict[str, Any]] = json.load(results_f)
    else:
        results: List[Dict[str, Any]] = []

    # Load blacklisted `task_id`s, one per line
    if args.blacklist_file_path and os.path.exists(args.blacklist_file_path):
        with open(args.blacklist_file_path, "r", encoding="utf-8") as f:
            blacklist = set(f.read().splitlines())
    else:
        blacklist = set()  # empty set if no blacklist file is given

    try:
        # Slice the dataset by args.start/args.end; overridden by args.q
        # (a single `task_id`) when provided
        dataset_slice = (
            [
                dataset_record
                for dataset_record in full_dataset
                if dataset_record["task_id"] in args.q
            ]
            if args.q is not None
            else full_dataset[args.start : args.end]
        )
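
        # Informational (added line): how many questions survive slicing
        # before the per-question filters applied in the loop below
        logging.info(f"Questions selected for this run: {len(dataset_slice)}")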
        # Main loop to execute questions
        for i, dataset_i in enumerate(dataset_slice):
            # A specified `task_id` takes precedence over everything else
            if args.q and args.q != dataset_i["task_id"]:
                continue

            # The following filters only apply when args.q is None
            if not args.q:
                # Skip blacklisted questions
                if dataset_i["task_id"] in blacklist:
                    continue

                # Skip questions already answered correctly, and Level-3
                # questions already answered incorrectly
                if any(
                    result["task_id"] == dataset_i["task_id"] and result["is_correct"]
                    for result in results
                ) or any(
                    (
                        result["task_id"] == dataset_i["task_id"]
                        and not result["is_correct"]
                        and dataset_i["Level"] == 3
                    )
                    for result in results
                ):
                    continue

                # With --skip, skip any question processed before, correct or not
                if args.skip and any(
                    result["task_id"] == dataset_i["task_id"] for result in results
                ):
                    continue

            # Run the agent on this question
            try:
                logging.info(f"Start to process: {dataset_i['task_id']}")
                logging.info(f"Detail: {dataset_i}")
                logging.info(f"Question: {dataset_i['Question']}")
                logging.info(f"Level: {dataset_i['Level']}")
                logging.info(f"Tools: {dataset_i['Annotator Metadata']['Tools']}")

                question = add_file_path(
                    dataset_i, file_path=gaia_dataset_path, split=args.split
                )["Question"]
                task = Task(input=question, agent=super_agent, conf=TaskConfig())
                result = Runners.sync_run_task(task=task)

                # The system prompt asks the agent to wrap its final answer
                # in <answer></answer> tags
                match = re.search(
                    r"<answer>(.*?)</answer>", result[task.id].get("answer") or ""
                )
                if match:
                    answer = match.group(1)
                    logging.info(f"Agent answer: {answer}")
                    logging.info(f"Correct answer: {dataset_i['Final answer']}")

                    is_correct = question_scorer(answer, dataset_i["Final answer"])
                    if is_correct:
                        logging.info(f"Question {i} Correct!")
                    else:
                        logging.info(f"Question {i} Incorrect!")

                    # Create the new result record
                    new_result = {
                        "task_id": dataset_i["task_id"],
                        "level": dataset_i["Level"],
                        "question": question,
                        "answer": dataset_i["Final answer"],
                        "response": answer,
                        "is_correct": is_correct,
                    }

                    # Check if this task_id already exists in results
                    existing_index = next(
                        (
                            idx
                            for idx, result in enumerate(results)
                            if result["task_id"] == dataset_i["task_id"]
                        ),
                        None,
                    )
                    if existing_index is not None:
                        # Update existing record
                        results[existing_index] = new_result
                        logging.info(
                            f"Updated existing record for task_id: {dataset_i['task_id']}"
                        )
                    else:
                        # Append new record
                        results.append(new_result)
                        logging.info(
                            f"Added new record for task_id: {dataset_i['task_id']}"
                        )
                else:
                    logging.warning(
                        f"No <answer></answer> tags found for task_id: {dataset_i['task_id']}"
                    )
            except Exception:
                logging.error(f"Error processing {i}: {traceback.format_exc()}")
                continue
    except KeyboardInterrupt:
        pass
    finally:
        # Report and checkpoint the results
        report_results(results)
        with open(results_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=4, ensure_ascii=False)
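
# Example invocations (the module path below is an assumption inferred from the
# `examples.gaia.*` imports; run from the repository root with a configured .env):
#   python -m examples.gaia.run --start 0 --end 20 --split validation
#   python -m examples.gaia.run --q 0-0-0-0-0 --skip
#   python -m examples.gaia.run --blacklist_file_path blacklist.txt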