Spaces:
Sleeping
Sleeping
# coding: utf-8 | |
# Copyright (c) 2025 inclusionAI. | |
import asyncio | |
import time | |
import traceback | |
import aworld.trace as trace | |
from typing import List, Callable, Any | |
from aworld.core.common import TaskItem | |
from aworld.core.context.base import Context | |
from aworld.agents.llm_agent import Agent | |
from aworld.core.event.base import Message, Constants, TopicType, ToolMessage, AgentMessage | |
from aworld.core.task import Task, TaskResponse | |
from aworld.events.manager import EventManager | |
from aworld.logs.util import logger | |
from aworld.runners.handler.agent import DefaultAgentHandler, AgentHandler | |
from aworld.runners.handler.base import DefaultHandler | |
from aworld.runners.handler.output import DefaultOutputHandler | |
from aworld.runners.handler.task import DefaultTaskHandler, TaskHandler | |
from aworld.runners.handler.tool import DefaultToolHandler, ToolHandler | |
from aworld.runners.task_runner import TaskRunner | |
from aworld.utils.common import override_in_subclass, new_instance | |
from aworld.runners.state_manager import EventRuntimeStateManager | |
class TaskEventRunner(TaskRunner): | |
"""Event driven task runner.""" | |
def __init__(self, task: Task, *args, **kwargs): | |
super().__init__(task, *args, **kwargs) | |
self._task_response = None | |
self.event_mng = EventManager(self.context) | |
self.hooks = {} | |
self.background_tasks = set() | |
self.state_manager = EventRuntimeStateManager.instance() | |
async def pre_run(self): | |
await super().pre_run() | |
if self.swarm and not self.swarm.max_steps: | |
self.swarm.max_steps = self.task.conf.get('max_steps', 10) | |
observation = self.observation | |
if not observation: | |
raise RuntimeError("no observation, check run process") | |
self._build_first_message() | |
if self.swarm: | |
# register agent handler | |
for _, agent in self.swarm.agents.items(): | |
agent.set_tools_instances(self.tools, self.tools_conf) | |
if agent.handler: | |
await self.event_mng.register(Constants.AGENT, agent.id(), agent.handler) | |
else: | |
if override_in_subclass('async_policy', agent.__class__, Agent): | |
await self.event_mng.register(Constants.AGENT, agent.id(), agent.async_run) | |
else: | |
await self.event_mng.register(Constants.AGENT, agent.id(), agent.run) | |
# register tool handler | |
for key, tool in self.tools.items(): | |
if tool.handler: | |
await self.event_mng.register(Constants.TOOL, tool.name(), tool.handler) | |
else: | |
await self.event_mng.register(Constants.TOOL, tool.name(), tool.step) | |
handlers = self.event_mng.event_bus.get_topic_handlers( | |
Constants.TOOL, tool.name()) | |
if not handlers: | |
await self.event_mng.register(Constants.TOOL, Constants.TOOL, tool.step) | |
self._stopped = asyncio.Event() | |
# handler of process in framework | |
handler_list = self.conf.get("handlers") | |
if handler_list: | |
handlers = [] | |
for hand in handler_list: | |
handlers.append(new_instance(hand, self)) | |
has_task_handler = False | |
has_tool_handler = False | |
has_agent_handler = False | |
for hand in handlers: | |
if isinstance(hand, TaskHandler): | |
has_task_handler = True | |
elif isinstance(hand, ToolHandler): | |
has_tool_handler = True | |
elif isinstance(hand, AgentHandler): | |
has_agent_handler = True | |
if not has_agent_handler: | |
self.handlers.append(DefaultAgentHandler(runner=self)) | |
if not has_tool_handler: | |
self.handlers.append(DefaultToolHandler(runner=self)) | |
if not has_task_handler: | |
self.handlers.append(DefaultTaskHandler(runner=self)) | |
self.handlers = handlers | |
else: | |
self.handlers = [DefaultAgentHandler(runner=self), | |
DefaultToolHandler(runner=self), | |
DefaultTaskHandler(runner=self), | |
DefaultOutputHandler(runner=self)] | |
def _build_first_message(self): | |
# build the first message | |
if self.agent_oriented: | |
self.init_message = AgentMessage(payload=self.observation, | |
sender='runner', | |
receiver=self.swarm.communicate_agent.id(), | |
session_id=self.context.session_id, | |
headers={'context': self.context}) | |
else: | |
actions = self.observation.content | |
receiver = actions[0].tool_name | |
self.init_message = ToolMessage(payload=self.observation.content, | |
sender='runner', | |
receiver=receiver, | |
session_id=self.context.session_id, | |
headers={'context': self.context}) | |
async def _common_process(self, message: Message) -> List[Message]: | |
event_bus = self.event_mng.event_bus | |
key = message.category | |
transformer = event_bus.get_transform_handlers(key) | |
if transformer: | |
message = await event_bus.transform(message, handler=transformer) | |
results = [] | |
handlers = event_bus.get_handlers(key) | |
async with trace.message_span(message=message): | |
self.state_manager.start_message_node(message) | |
if handlers: | |
if message.topic: | |
handlers = {message.topic: handlers.get(message.topic, [])} | |
elif message.receiver: | |
handlers = {message.receiver: handlers.get( | |
message.receiver, [])} | |
for topic, handler_list in handlers.items(): | |
if not handler_list: | |
logger.warning(f"{topic} no handler, ignore.") | |
continue | |
for handler in handler_list: | |
t = asyncio.create_task( | |
self._handle_task(message, handler)) | |
self.background_tasks.add(t) | |
t.add_done_callback(self.background_tasks.discard) | |
else: | |
# not handler, return raw message | |
results.append(message) | |
t = asyncio.create_task(self._raw_task(results)) | |
self.background_tasks.add(t) | |
t.add_done_callback(self.background_tasks.discard) | |
# wait until it is complete | |
await t | |
self.state_manager.end_message_node(message) | |
return results | |
async def _handle_task(self, message: Message, handler: Callable[..., Any]): | |
con = message | |
async with trace.span(handler.__name__): | |
try: | |
logger.info( | |
f"event_runner _handle_task start, message: {message.id}") | |
if asyncio.iscoroutinefunction(handler): | |
con = await handler(con) | |
else: | |
con = handler(con) | |
logger.info(f"event_runner _handle_task message= {message.id}") | |
if isinstance(con, Message): | |
# process in framework | |
self.state_manager.save_message_handle_result(name=handler.__name__, | |
message=message, | |
result=con) | |
async for event in self._inner_handler_process( | |
results=[con], | |
handlers=self.handlers | |
): | |
await self.event_mng.emit_message(event) | |
else: | |
self.state_manager.save_message_handle_result(name=handler.__name__, | |
message=message) | |
except Exception as e: | |
logger.warning( | |
f"{handler} process fail. {traceback.format_exc()}") | |
error_msg = Message( | |
category=Constants.TASK, | |
payload=TaskItem(msg=str(e), data=message), | |
sender=self.name, | |
session_id=Context.instance().session_id, | |
topic=TopicType.ERROR | |
) | |
self.state_manager.save_message_handle_result(name=handler.__name__, | |
message=message, | |
result=error_msg) | |
await self.event_mng.event_bus.publish(error_msg) | |
async def _raw_task(self, messages: List[Message]): | |
# process in framework | |
async for event in self._inner_handler_process( | |
results=messages, | |
handlers=self.handlers | |
): | |
await self.event_mng.emit_message(event) | |
async def _inner_handler_process(self, results: List[Message], handlers: List[DefaultHandler]): | |
# can use runtime backend to parallel | |
for handler in handlers: | |
for result in results: | |
async for event in handler.handle(result): | |
yield event | |
async def _do_run(self): | |
"""Task execution process in real.""" | |
start = time.time() | |
msg = None | |
answer = None | |
try: | |
while True: | |
if await self.is_stopped(): | |
await self.event_mng.done() | |
logger.info("stop task...") | |
if self._task_response is None: | |
# send msg to output | |
self._task_response = TaskResponse(msg=msg, | |
answer=answer, | |
success=True if not msg else False, | |
id=self.task.id, | |
time_cost=( | |
time.time() - start), | |
usage=self.context.token_usage) | |
break | |
# consume message | |
message: Message = await self.event_mng.consume() | |
# use registered handler to process message | |
await self._common_process(message) | |
except Exception as e: | |
logger.error(f"consume message fail. {traceback.format_exc()}") | |
finally: | |
if await self.is_stopped(): | |
await self.task.outputs.mark_completed() | |
# todo sandbox cleanup | |
if self.swarm and hasattr(self.swarm, 'agents') and self.swarm.agents: | |
for agent_name, agent in self.swarm.agents.items(): | |
try: | |
if hasattr(agent, 'sandbox') and agent.sandbox: | |
await agent.sandbox.cleanup() | |
except Exception as e: | |
logger.warning( | |
f"event_runner Failed to cleanup sandbox for agent {agent_name}: {e}") | |
async def do_run(self, context: Context = None): | |
if self.swarm and not self.swarm.initialized: | |
raise RuntimeError("swarm needs to use `reset` to init first.") | |
async with trace.span("Task_" + self.init_message.session_id): | |
await self.event_mng.emit_message(self.init_message) | |
await self._do_run() | |
return self._task_response | |
async def stop(self): | |
self._stopped.set() | |
async def is_stopped(self): | |
return self._stopped.is_set() | |
def response(self): | |
return self._task_response | |