# NOTE: "Spaces: Sleeping" below is a Hugging Face Spaces status banner captured
# by page extraction; it is not part of the program.
import json
import logging
import os
import re
import traceback
import uuid
from typing import AsyncGenerator

from aworld.agents.llm_agent import Agent
from aworld.config.conf import AgentConfig, TaskConfig
from aworld.core.task import Task
from aworld.output.base import Output
from aworld.output.ui.base import AworldUI
from aworld.output.ui.markdown_aworld_ui import MarkdownAworldUI
from aworld.runner import Runners

from .prompt import system_prompt
from .utils import (
    add_file_path,
    load_dataset_meta_dict,
    question_scorer,
)

logger = logging.getLogger(__name__)
class GaiaAgentRunner:
    """
    Gaia Agent Runner.

    Wraps a single aworld "super" agent configured from LLM connection
    settings. Given a prompt that references a GAIA ``task_id``, it looks up
    the question in the local GAIA dataset, streams the agent's markdown
    output, and scores the final answer against the dataset ground truth.
    A prompt that is not a GAIA task reference is chatted through directly.
    """

    def __init__(
        self,
        llm_provider: str,
        llm_model_name: str,
        llm_base_url: str,
        llm_api_key: str,
        llm_temperature: float = 0.0,
        mcp_config: dict = None,
    ):
        """Build the agent and load GAIA dataset metadata.

        Args:
            llm_provider: Provider identifier passed to ``AgentConfig``.
            llm_model_name: Model name for the LLM backend.
            llm_base_url: Base URL of the LLM API endpoint.
            llm_api_key: API key for the LLM backend.
            llm_temperature: Sampling temperature (default 0.0).
            mcp_config: Optional MCP configuration; its ``mcpServers`` keys
                become the agent's MCP server list. ``None`` means none.
        """
        # BUG FIX: the default used to be a mutable `{}` literal, which is
        # created once and shared across every instance; use None + fill-in.
        if mcp_config is None:
            mcp_config = {}
        self.agent_config = AgentConfig(
            llm_provider=llm_provider,
            llm_model_name=llm_model_name,
            llm_api_key=llm_api_key,
            llm_base_url=llm_base_url,
            llm_temperature=llm_temperature,
        )
        self.super_agent = Agent(
            conf=self.agent_config,
            name="gaia_super_agent",
            system_prompt=system_prompt,
            mcp_config=mcp_config,
            mcp_servers=mcp_config.get("mcpServers", {}).keys(),
        )
        # Dataset root: $GAIA_DATASET_PATH, else ./GAIA/2023 next to this file.
        self.gaia_dataset_path = os.path.abspath(
            os.getenv(
                "GAIA_DATASET_PATH",
                os.path.join(os.path.dirname(os.path.abspath(__file__)), "GAIA", "2023"),
            )
        )
        self.full_dataset = load_dataset_meta_dict(self.gaia_dataset_path)
        logger.info(
            f"Gaia Agent Runner initialized: super_agent={self.super_agent}, agent_config={self.agent_config}, gaia_dataset_path={self.gaia_dataset_path}, full_dataset={len(self.full_dataset)}"
        )

    async def run(self, prompt: str):
        """Run one GAIA task (or a plain chat) and stream markdown chunks.

        Args:
            prompt: Either a JSON string containing a ``task_id`` (the
                matching GAIA question becomes the agent input) or any plain
                text used as the input directly.

        Yields:
            Markdown-formatted chunks from the agent, ending with a scored
            ``gaia_result`` block when a GAIA task was resolved.
        """
        yield (f"\n### GAIA Agent Start!")
        mcp_servers = "\n- ".join(self.super_agent.mcp_servers)
        yield (f"\n```gaia_agent_status\n- {mcp_servers}\n```\n")

        question = None
        data_item = None
        task_id = None
        try:
            json_data = json.loads(prompt)
            task_id = json_data["task_id"]
            data_item = self.full_dataset[task_id]
            # add_file_path rewrites the record so referenced files resolve
            # inside the local dataset directory.
            question = add_file_path(data_item, file_path=self.gaia_dataset_path)[
                "Question"
            ]
            yield (f"\n```gaia_question\n{json.dumps(data_item, indent=2)}\n```\n")
        except Exception:
            # Deliberate best-effort: a non-JSON prompt or an unknown task_id
            # falls through to plain-chat mode below. Log instead of
            # swallowing silently so lookup failures remain diagnosable.
            logger.debug(
                f"Prompt is not a resolvable GAIA task: {traceback.format_exc()}"
            )

        if not question:
            logger.warning(
                "Could not find GAIA question for prompt, chat using prompt directly!"
            )
            yield (f"\n{prompt}\n")
            question = prompt

        try:
            task = Task(
                id=task_id + "." + uuid.uuid1().hex if task_id else uuid.uuid1().hex,
                input=question,
                agent=self.super_agent,
                event_driven=False,
                conf=TaskConfig(max_steps=20),
            )
            last_output: Output = None
            rich_ui = MarkdownAworldUI()
            async for output in Runners.streamed_run_task(task).stream_events():
                logger.info(f"Gaia Agent Output: {output}")
                res = await AworldUI.parse_output(output, rich_ui)
                # parse_output may return a single item, a list, or items that
                # are themselves async generators; flatten and forward all.
                for item in res if isinstance(res, list) else [res]:
                    if isinstance(item, AsyncGenerator):
                        async for sub_item in item:
                            yield sub_item
                    else:
                        yield item
                        last_output = item
            logger.info(f"Gaia Agent Last Output: {last_output}")
            if data_item and last_output:
                final_response = self._judge_answer(data_item, last_output)
                yield final_response
        except Exception:
            logger.error(f"Error processing {prompt}, error: {traceback.format_exc()}")

    def _judge_answer(self, data_item: dict, result: "Output"):
        """Score the agent's final output against the dataset ground truth.

        Args:
            data_item: GAIA record with "task_id", "Level", "Question" and
                "Final answer" keys.
            result: The agent's last output chunk; assumed str-like, with the
                answer wrapped in ``<answer>...</answer>`` tags.

        Returns:
            A markdown string containing a fenced ```gaia_result``` JSON
            block; scored only when the ``<answer>`` tags were found.
        """
        answer = result
        match = re.search(r"<answer>(.*?)</answer>", answer)
        if match:
            answer = match.group(1)
            logger.info(f"Agent answer: {answer}")
            logger.info(f"Correct answer: {data_item['Final answer']}")
            # Score once and reuse (previously question_scorer ran twice).
            correct = question_scorer(answer, data_item["Final answer"])
            if correct:
                logger.info(f"Question {data_item['task_id']} Correct!")
            else:
                logger.info(f"Question {data_item['task_id']} Incorrect!")
            # Create the new result record
            new_result = {
                "task_id": data_item["task_id"],
                "level": data_item["Level"],
                "question": data_item["Question"],
                "answer": data_item["Final answer"],
                "response": answer,
                "is_correct": correct,
            }
            return f"\n## Final Result: {'✅' if correct else '❌'}\n \n```gaia_result\n{json.dumps(new_result, indent=2)}\n```"
        else:
            # No <answer> tags: pass the raw output through, unscored.
            new_result = answer
            return f"\n## Final Result:\n \n```gaia_result\n{json.dumps(new_result, indent=2)}\n```"
if __name__ == "__main__":
    import asyncio
    import argparse
    from datetime import datetime

    logger = logging.getLogger(__name__)

    # All streamed agent output is appended to a timestamped markdown file
    # in ./output next to this script.
    output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "output")
    os.makedirs(output_dir, exist_ok=True)  # race-free; replaces exists()+makedirs()
    output_file = os.path.join(
        output_dir, f"output_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md"
    )

    async def main():
        """Parse CLI args, run the GAIA agent once, append chunks to output_file."""
        parser = argparse.ArgumentParser()
        parser.add_argument("--prompt", type=str, default="")
        args = parser.parse_args()
        try:
            prompt = args.prompt
            llm_provider = os.getenv("LLM_PROVIDER")
            llm_model_name = os.getenv("LLM_MODEL_NAME")
            llm_api_key = os.getenv("LLM_API_KEY")
            llm_base_url = os.getenv("LLM_BASE_URL")
            # BUG FIX: os.getenv returns a *string* whenever the variable is
            # set, but llm_temperature is declared as a float; convert
            # explicitly so AgentConfig always receives a float.
            llm_temperature = float(os.getenv("LLM_TEMPERATURE", "0.0"))

            def send_output(output):
                # Append each streamed chunk to the markdown log.
                with open(output_file, "a") as f:
                    f.write(f"{output}\n")

            async for chunk in GaiaAgentRunner(
                llm_provider=llm_provider,
                llm_model_name=llm_model_name,
                llm_base_url=llm_base_url,
                llm_api_key=llm_api_key,
                llm_temperature=llm_temperature,
            ).run(prompt):
                send_output(chunk)
        except Exception:
            logger.error(
                f"Error processing {args.prompt}, error: {traceback.format_exc()}"
            )

    asyncio.run(main())