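# Run GAIA benchmark questions end-to-end with a single "super agent" built on
# AWorld. Expected environment variables (typically supplied via a .env file):
#   LOG_FILE_PATH     - directory for per-run logs and the results.json checkpoint
#   GAIA_DATASET_PATH - local path to the GAIA dataset (defaults to ./gaia_dataset)
#   LLM_MODEL_NAME, LLM_API_KEY, LLM_BASE_URL - OpenAI-compatible model settings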
import argparse
import json
import logging
import os
import re
import traceback
from typing import Any, Dict, List
from dotenv import load_dotenv
from aworld.config.conf import AgentConfig, TaskConfig
from aworld.agents.llm_agent import Agent
from aworld.core.task import Task
from aworld.runner import Runners
from examples.gaia.prompt import system_prompt
from examples.gaia.utils import (
add_file_path,
load_dataset_meta,
question_scorer,
report_results,
)
# Load .env before anything reads environment variables, then make sure the
# log directory exists.
load_dotenv()
log_file_path = os.getenv("LOG_FILE_PATH", "./logs")  # fallback if unset
os.makedirs(log_file_path, exist_ok=True)
parser = argparse.ArgumentParser()
parser.add_argument(
"--start",
type=int,
default=0,
help="Start index of the dataset",
)
parser.add_argument(
"--end",
type=int,
default=20,
help="End index of the dataset",
)
parser.add_argument(
"--q",
type=str,
    help="Single question task_id, e.g., 0-0-0-0-0. Highest priority: overrides other arguments if provided.",
)
parser.add_argument(
"--skip",
action="store_true",
help="Skip the question if it has been processed before.",
)
parser.add_argument(
"--split",
type=str,
default="validation",
help="Split of the dataset, e.g., validation, test",
)
parser.add_argument(
"--blacklist_file_path",
type=str,
nargs="?",
help="Blacklist file path, e.g., blacklist.txt",
)
args = parser.parse_args()
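# Note on precedence: --q narrows the run to a single task_id and overrides
# --start/--end; the optional blacklist file is expected to contain one
# task_id per line (see the loading logic below).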
def setup_logging():
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    log_file_name = (
        f"/super_agent_{args.q}.log"
        if args.q
        else f"/super_agent_{args.start}_{args.end}.log"
    )
    # Append to a per-run log file inside the log directory.
    file_handler = logging.FileHandler(
        log_file_path + log_file_name,
        mode="a",
        encoding="utf-8",
    )
    file_handler.setLevel(logging.INFO)
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    file_handler.setFormatter(formatter)
    root_logger.addHandler(file_handler)
if __name__ == "__main__":
    setup_logging()
setup_logging()
gaia_dataset_path = os.getenv("GAIA_DATASET_PATH", "./gaia_dataset")
full_dataset = load_dataset_meta(gaia_dataset_path, split=args.split)
logging.info(f"Total questions: {len(full_dataset)}")
agent_config = AgentConfig(
llm_provider="openai",
llm_model_name=os.getenv("LLM_MODEL_NAME", "gpt-4o"),
llm_api_key=os.getenv("LLM_API_KEY", "your_openai_api_key"),
llm_base_url=os.getenv("LLM_BASE_URL", "your_openai_base_url"),
)
super_agent = Agent(
conf=agent_config,
name="gaia_super_agent",
system_prompt=system_prompt,
mcp_servers=[
"e2b-server",
# "filesystem",
"terminal-controller",
"excel",
"calculator",
"ms-playwright",
"audio_server",
"image_server",
"video_server",
"search_server",
"download_server",
"document_server",
# "browser_server",
"youtube_server",
"reasoning_server",
],
)
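    # The MCP servers listed above make up the agent's tool surface: code
    # execution (e2b-server), terminal and spreadsheet control, browser
    # automation (ms-playwright), media understanding (audio/image/video),
    # plus search, download, document, YouTube, and reasoning helpers.
    # Commented-out entries are disabled by default.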
    # Load previous results from the checkpoint file so already-processed
    # questions can be skipped or updated instead of re-run.
    results_path = os.path.join(log_file_path, "results.json")
    if os.path.exists(results_path):
        with open(results_path, "r", encoding="utf-8") as results_f:
            results: List[Dict[str, Any]] = json.load(results_f)
    else:
        results: List[Dict[str, Any]] = []
    # Load blacklisted task_ids, one per line.
    if args.blacklist_file_path and os.path.exists(args.blacklist_file_path):
        with open(args.blacklist_file_path, "r", encoding="utf-8") as f:
            blacklist = set(f.read().splitlines())
    else:
        blacklist = set()  # Empty set if no blacklist file is provided.
    try:
        # Slice the dataset by args.start and args.end; a single `task_id`
        # passed via args.q overrides the slice.
        dataset_slice = (
            [
                dataset_record
                for dataset_record in full_dataset
                if dataset_record["task_id"] == args.q
            ]
            if args.q is not None
            else full_dataset[args.start : args.end]
        )
        # main loop to execute questions
        for i, dataset_i in enumerate(dataset_slice):
            # The checkpoint/blacklist filters below only apply when no single
            # task_id was requested via args.q.
            if not args.q:
                # blacklist
                if dataset_i["task_id"] in blacklist:
                    continue
                # Skip questions already answered correctly, and Level-3
                # questions that have already been attempted even if incorrect.
                if any(
                    (result["task_id"] == dataset_i["task_id"] and result["is_correct"])
                    for result in results
                ) or any(
                    (
                        result["task_id"] == dataset_i["task_id"]
                        and not result["is_correct"]
                        and dataset_i["Level"] == 3
                    )
                    for result in results
                ):
                    continue
                # With --skip, skip any question that has been processed
                # before, whether or not the answer was correct.
                if args.skip and any(
                    result["task_id"] == dataset_i["task_id"] for result in results
                ):
                    continue
# run
try:
logging.info(f"Start to process: {dataset_i['task_id']}")
logging.info(f"Detail: {dataset_i}")
logging.info(f"Question: {dataset_i['Question']}")
logging.info(f"Level: {dataset_i['Level']}")
logging.info(f"Tools: {dataset_i['Annotator Metadata']['Tools']}")
question = add_file_path(
dataset_i, file_path=gaia_dataset_path, split=args.split
)["Question"]
                task = Task(input=question, agent=super_agent, conf=TaskConfig())
                result = Runners.sync_run_task(task=task)
                # Extract the final answer from <answer>...</answer> tags;
                # default to an empty string so re.search never receives None.
                match = re.search(
                    r"<answer>(.*?)</answer>", result[task.id].get("answer", "")
                )
                if match:
                    answer = match.group(1)
                    logging.info(f"Agent answer: {answer}")
                    logging.info(f"Correct answer: {dataset_i['Final answer']}")
                    # Score once and reuse the verdict when building the record.
                    is_correct = question_scorer(answer, dataset_i["Final answer"])
                    if is_correct:
                        logging.info(f"Question {i} Correct!")
                    else:
                        logging.info(f"Question {i} Incorrect!")
                    # Create the new result record
                    new_result = {
                        "task_id": dataset_i["task_id"],
                        "level": dataset_i["Level"],
                        "question": question,
                        "answer": dataset_i["Final answer"],
                        "response": answer,
                        "is_correct": is_correct,
                    }
                    # Check if this task_id already exists in results
                    # (idx/res avoid shadowing the outer loop's i and result).
                    existing_index = next(
                        (
                            idx
                            for idx, res in enumerate(results)
                            if res["task_id"] == dataset_i["task_id"]
                        ),
                        None,
                    )
                    if existing_index is not None:
                        # Update existing record
                        results[existing_index] = new_result
                        logging.info(
                            f"Updated existing record for task_id: {dataset_i['task_id']}"
                        )
                    else:
                        # Append new record
                        results.append(new_result)
                        logging.info(
                            f"Added new record for task_id: {dataset_i['task_id']}"
                        )
            except Exception:
                logging.error(
                    f"Error processing {dataset_i['task_id']}: {traceback.format_exc()}"
                )
continue
except KeyboardInterrupt:
pass
    finally:
        # Report aggregate results and checkpoint everything to disk.
        report_results(results)
        with open(results_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=4, ensure_ascii=False)
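# Usage sketch (assuming this script is saved as run_super_agent.py and a
# .env file provides the variables listed at the top of the file):
#
#   python run_super_agent.py --start 0 --end 20 --split validation
#   python run_super_agent.py --q 0-0-0-0-0
#   python run_super_agent.py --skip --blacklist_file_path blacklist.txt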