# Source: Final_Assignment_AWorld/examples/gaia/gaia_agent_runner.py
# (uploaded by Duibonduil, "Upload 9 files", commit 3a235a9, verified)
import json
import logging
import os
import re
import traceback
import uuid
from typing import AsyncGenerator, Optional

from aworld.config.conf import AgentConfig, TaskConfig
from aworld.agents.llm_agent import Agent
from aworld.core.task import Task
from aworld.runner import Runners
from aworld.output.ui.base import AworldUI
from aworld.output.ui.markdown_aworld_ui import MarkdownAworldUI
from aworld.output.base import Output

from .utils import (
    add_file_path,
    load_dataset_meta_dict,
    question_scorer,
)
from .prompt import system_prompt
logger = logging.getLogger(__name__)


class GaiaAgentRunner:
    """
    Gaia Agent Runner.

    Wraps one AWorld ``Agent`` named "gaia_super_agent" and streams its work on
    a single prompt as markdown text. A prompt that is a JSON object whose
    ``"task_id"`` is found in the loaded GAIA dataset is replaced by that
    dataset question, and the agent's final answer is scored; any other prompt
    is sent to the agent as-is.
    """

    def __init__(
        self,
        llm_provider: str,
        llm_model_name: str,
        llm_base_url: str,
        llm_api_key: str,
        llm_temperature: float = 0.0,
        mcp_config: Optional[dict] = None,
    ):
        """Build the agent and load the GAIA dataset metadata.

        Args:
            llm_provider: Provider name forwarded to ``AgentConfig``.
            llm_model_name: Model name forwarded to ``AgentConfig``.
            llm_base_url: API base URL forwarded to ``AgentConfig``.
            llm_api_key: API key forwarded to ``AgentConfig``.
            llm_temperature: Sampling temperature forwarded to ``AgentConfig``.
            mcp_config: MCP configuration handed to the agent; the agent is
                given every server name under its ``"mcpServers"`` key.
                ``None`` (the default) means no MCP servers.

        The dataset directory is ``$GAIA_DATASET_PATH`` when set, otherwise
        ``GAIA/2023`` next to this file.
        """
        # Fresh dict per call: a ``{}`` default would be a single object shared
        # by every runner and passed on to Agent, which may mutate it.
        if mcp_config is None:
            mcp_config = {}

        self.agent_config = AgentConfig(
            llm_provider=llm_provider,
            llm_model_name=llm_model_name,
            llm_api_key=llm_api_key,
            llm_base_url=llm_base_url,
            llm_temperature=llm_temperature,
        )
        self.super_agent = Agent(
            conf=self.agent_config,
            name="gaia_super_agent",
            system_prompt=system_prompt,
            mcp_config=mcp_config,
            # Snapshot the server names as a list instead of a live dict view.
            mcp_servers=list(mcp_config.get("mcpServers", {}).keys()),
        )
        self.gaia_dataset_path = os.path.abspath(
            os.getenv(
                "GAIA_DATASET_PATH",
                os.path.join(os.path.dirname(os.path.abspath(__file__)), "GAIA", "2023"),
            )
        )
        self.full_dataset = load_dataset_meta_dict(self.gaia_dataset_path)
        logger.info(
            f"Gaia Agent Runner initialized: super_agent={self.super_agent}, agent_config={self.agent_config}, gaia_dataset_path={self.gaia_dataset_path}, full_dataset={len(self.full_dataset)}"
        )

    async def run(self, prompt: str) -> AsyncGenerator[str, None]:
        """Run the agent on ``prompt`` and stream markdown output.

        Args:
            prompt: Either a JSON object such as ``{"task_id": "..."}`` naming a
                GAIA task in the loaded dataset, or free-form text for the agent.

        Yields:
            In order: a start banner, the MCP server list, the GAIA question
            record (or the raw prompt when no task matched), every item
            produced by ``AworldUI.parse_output`` for the agent's events, and,
            for GAIA tasks, a ``gaia_result`` block from ``_judge_answer``.

        Errors raised while running the agent are logged, not re-raised; the
        stream simply ends.
        """
        yield "\n### GAIA Agent Start!"
        mcp_servers = "\n- ".join(self.super_agent.mcp_servers)
        yield f"\n```gaia_agent_status\n- {mcp_servers}\n```\n"

        question = None
        data_item = None
        task_id = None
        try:
            json_data = json.loads(prompt)
            task_id = json_data["task_id"]
            data_item = self.full_dataset[task_id]
            question = add_file_path(data_item, file_path=self.gaia_dataset_path)[
                "Question"
            ]
            yield f"\n```gaia_question\n{json.dumps(data_item, indent=2)}\n```\n"
        except Exception:
            # Best effort: non-JSON text, a missing/unknown task_id, etc. all
            # fall back to plain chat below. Log the reason instead of hiding it.
            logger.debug("Prompt is not a known GAIA task", exc_info=True)

        if not question:
            logger.warning(
                "Could not find GAIA question for prompt, chat using prompt directly!"
            )
            # No resolved question means no dataset answer to score against.
            data_item = None
            yield f"\n{prompt}\n"
            question = prompt

        try:
            task = Task(
                # f-string so a non-str task_id (e.g. a JSON number) cannot raise.
                id=f"{task_id}.{uuid.uuid1().hex}" if task_id else uuid.uuid1().hex,
                input=question,
                agent=self.super_agent,
                event_driven=False,
                conf=TaskConfig(max_steps=20),
            )
            # Text of the most recent UI item; this is what gets scored.
            last_output = None
            rich_ui = MarkdownAworldUI()
            async for output in Runners.streamed_run_task(task).stream_events():
                logger.info(f"Gaia Agent Output: {output}")
                res = await AworldUI.parse_output(output, rich_ui)
                for item in res if isinstance(res, list) else [res]:
                    if isinstance(item, AsyncGenerator):
                        # Forward the streamed chunks and keep their joined text.
                        # The generator object itself cannot be searched for
                        # <answer> tags. (Assumes chunks are text fragments of
                        # one message — TODO confirm against MarkdownAworldUI.)
                        chunks = []
                        async for sub_item in item:
                            yield sub_item
                            if sub_item is not None:
                                chunks.append(str(sub_item))
                        last_output = "".join(chunks)
                    else:
                        yield item
                        last_output = item
            logger.info(f"Gaia Agent Last Output: {last_output}")
            if data_item and last_output:
                yield self._judge_answer(data_item, last_output)
        except Exception:
            logger.exception(f"Error processing {prompt}")

    def _judge_answer(self, data_item: dict, result: object) -> str:
        """Extract the agent's ``<answer>`` and score it against the dataset.

        Args:
            data_item: GAIA record; reads ``task_id``, ``Level``, ``Question``
                and ``Final answer``.
            result: Final agent output; converted to ``str`` before searching.

        Returns:
            A markdown ``gaia_result`` block. If an ``<answer>...</answer>`` tag
            is present, the block holds a JSON record with ``is_correct`` from
            ``question_scorer`` and a ✅/❌ marker. Otherwise it holds the raw
            output as a JSON string and no score.
        """
        answer = result if isinstance(result, str) else str(result)
        # DOTALL so an answer spanning several lines is still captured.
        match = re.search(r"<answer>(.*?)</answer>", answer, re.DOTALL)
        if not match:
            return f"\n## Final Result:\n \n```gaia_result\n{json.dumps(answer, indent=2)}\n```"

        answer = match.group(1).strip()
        expected = data_item["Final answer"]
        logger.info(f"Agent answer: {answer}")
        logger.info(f"Correct answer: {expected}")
        correct = question_scorer(answer, expected)
        if correct:
            logger.info(f"Question {data_item['task_id']} Correct!")
        else:
            logger.info(f"Question {data_item['task_id']} Incorrect!")

        new_result = {
            "task_id": data_item["task_id"],
            "level": data_item["Level"],
            "question": data_item["Question"],
            "answer": expected,
            "response": answer,
            "is_correct": correct,
        }
        return f"\n## Final Result: {'✅' if correct else '❌'}\n \n```gaia_result\n{json.dumps(new_result, indent=2)}\n```"
if __name__ == "__main__":
    import argparse
    import asyncio
    from datetime import datetime

    # NOTE: this module uses relative imports (``.utils``, ``.prompt``), so it
    # must be launched as a module (``python -m <package>.gaia_agent_runner``)
    # rather than as a plain script — TODO confirm the package path.

    # Show the runner's INFO logs; without a handler only WARNING+ would appear.
    logging.basicConfig(level=logging.INFO)

    output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "output")
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(
        output_dir, f"output_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md"
    )

    async def main():
        """Run ``--prompt`` through GaiaAgentRunner, appending each chunk to output_file.

        LLM settings are read from the environment: LLM_PROVIDER,
        LLM_MODEL_NAME, LLM_API_KEY, LLM_BASE_URL and LLM_TEMPERATURE
        (default 0.0). Any error is logged with its traceback.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument("--prompt", type=str, default="")
        args = parser.parse_args()

        def send_output(output):
            # Explicit UTF-8: results contain ✅/❌, which the platform default
            # encoding (e.g. cp1252 on Windows) cannot encode.
            with open(output_file, "a", encoding="utf-8") as f:
                f.write(f"{output}\n")

        try:
            runner = GaiaAgentRunner(
                llm_provider=os.getenv("LLM_PROVIDER"),
                llm_model_name=os.getenv("LLM_MODEL_NAME"),
                llm_base_url=os.getenv("LLM_BASE_URL"),
                llm_api_key=os.getenv("LLM_API_KEY"),
                # Environment values are strings; the config expects a float.
                llm_temperature=float(os.getenv("LLM_TEMPERATURE", "0.0")),
            )
            async for chunk in runner.run(args.prompt):
                send_output(chunk)
        except Exception:
            logger.exception(f"Error processing {args.prompt}")

    asyncio.run(main())