Spaces:
Sleeping
Sleeping
File size: 5,648 Bytes
bc5e560 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
# coding: utf-8
# Copyright (c) 2025 inclusionAI.
import abc
import time
from typing import AsyncGenerator
from aworld.core.common import TaskItem
from aworld.core.tool.base import Tool, AsyncTool
from aworld.core.event.base import Message, Constants, TopicType
from aworld.core.task import TaskResponse
from aworld.logs.util import logger
from aworld.output import Output
from aworld.runners.handler.base import DefaultHandler
from aworld.runners.hook.hook_factory import HookFactory
from aworld.runners.hook.hooks import HookPoint
class TaskHandler(DefaultHandler):
    """Base handler for task events, bound to the runner that owns the task.

    On construction the handler resolves the hooks configured on
    ``runner.task.hooks`` through ``HookFactory.get_class`` and stores the
    results in ``self.hooks``, keyed the same way as the task configuration.
    """

    # NOTE: Python 2 style metaclass declaration; Python 3 treats this as a
    # plain class attribute. Left unchanged on purpose.
    __metaclass__ = abc.ABCMeta

    def __init__(self, runner: 'TaskEventRunner'):
        """Bind the handler to ``runner`` and resolve the task's hooks.

        Args:
            runner: The runner whose ``task.hooks`` mapping (hook point ->
                iterable of hook identifiers) is resolved here.
        """
        self.runner = runner
        # Number of restarts already triggered by error events.
        self.retry_count = 0
        # hook point -> list of whatever HookFactory.get_class returned.
        self.hooks = {}

        configured = runner.task.hooks
        if not configured:
            return

        for point, hook_names in configured.items():
            resolved = (HookFactory.get_class(hook_name) for hook_name in hook_names)
            # Identifiers the factory cannot resolve (falsy result) are dropped.
            self.hooks[point] = [hook_cls for hook_cls in resolved if hook_cls]

    @classmethod
    def name(cls):
        """Return the identifier this handler uses as a message sender."""
        return "_task_handler"
class DefaultTaskHandler(TaskHandler):
    """Default handling of task-category messages for a task runner."""

    # Maximum number of restarts attempted after error events before giving up.
    MAX_RETRIES = 3

    def _build_response(self, **fields) -> TaskResponse:
        """Build a ``TaskResponse`` carrying the runner's id, elapsed time and usage."""
        return TaskResponse(id=self.runner.task.id,
                            time_cost=(time.time() - self.runner.start_time),
                            usage=self.runner.context.token_usage,
                            **fields)

    async def handle(self, message: Message) -> AsyncGenerator[Message, None]:
        """Dispatch a task message according to its topic.

        Messages whose category is not ``Constants.TASK`` are ignored.

        Per topic:
            * ``SUBSCRIBE_TOOL``: register every ``Tool``/``AsyncTool`` found in
              ``message.payload.data`` with the runner's event manager.
            * ``SUBSCRIBE_AGENT``: nothing is done.
            * ``ERROR``: run the ERROR hooks; if the payload asks to stop, stop
              the runner and record a failed response, otherwise yield a START
              message with an empty payload to restart the task.
            * ``FINISHED``: run the FINISHED hooks, record a successful response
              and stop the runner.
            * ``START``: run the START hooks, then yield the message itself if
              it has a payload, else the runner's ``init_message``.
            * ``OUTPUT``: yield the message unchanged.
            * ``HUMAN_CONFIRM``: forward the payload to the task outputs (when
              both exist), record a successful response and stop the runner.

        Args:
            message: The incoming event.

        Yields:
            Messages produced by hooks and by the topic-specific handling.

        Raises:
            Exception: On an ERROR message without ``stop`` once
                ``MAX_RETRIES`` restarts have already been attempted.
        """
        if message.category != Constants.TASK:
            return

        logger.info(f"task handler receive message: {message}")

        headers = {"context": message.context}
        topic = message.topic
        task_item: TaskItem = message.payload

        if topic == TopicType.SUBSCRIBE_TOOL:
            new_tools = message.payload.data
            for name, tool in new_tools.items():
                if isinstance(tool, (Tool, AsyncTool)):
                    await self.runner.event_mng.register(Constants.TOOL, name, tool.step)
                    logger.info(f"dynamic register {name} tool.")
                else:
                    logger.warning(f"Unknown tool instance: {tool}")
            return
        elif topic == TopicType.SUBSCRIBE_AGENT:
            return
        elif topic == TopicType.ERROR:
            async for event in self.run_hooks(message, HookPoint.ERROR):
                yield event

            if task_item.stop:
                await self.runner.stop()
                logger.warning(f"task {self.runner.task.id} stop, cause: {task_item.msg}")
                self.runner._task_response = self._build_response(msg=task_item.msg,
                                                                  answer='',
                                                                  success=False)
                return

            # restart
            logger.warning(f"The task {self.runner.task.id} will be restarted due to error: {task_item.msg}.")
            if self.retry_count >= self.MAX_RETRIES:
                raise Exception(f"The task {self.runner.task.id} failed, due to error: {task_item.msg}.")

            self.retry_count += 1
            # An empty payload makes the START branch fall back to the
            # runner's init message.
            yield Message(
                category=Constants.TASK,
                payload='',
                sender=self.name(),
                session_id=self.runner.context.session_id,
                topic=TopicType.START,
                headers=headers
            )
        elif topic == TopicType.FINISHED:
            async for event in self.run_hooks(message, HookPoint.FINISHED):
                yield event

            self.runner._task_response = self._build_response(answer=str(message.payload),
                                                              success=True)
            await self.runner.stop()
            logger.info(f"{self.runner.task.id} finished.")
        elif topic == TopicType.START:
            async for event in self.run_hooks(message, HookPoint.START):
                yield event

            logger.info(f"task start event: {message}, will send init message.")
            if message.payload:
                yield message
            else:
                yield self.runner.init_message
        elif topic == TopicType.OUTPUT:
            yield message
        elif topic == TopicType.HUMAN_CONFIRM:
            # `warning` matches the other log calls here; `warn` is the
            # deprecated alias on standard-library loggers.
            logger.warning("=============== Get human confirm, pause execution ===============")
            if self.runner.task.outputs and message.payload:
                await self.runner.task.outputs.add_output(Output(data=message.payload))
            self.runner._task_response = self._build_response(answer=str(message.payload),
                                                              success=True)
            await self.runner.stop()

    async def run_hooks(self, message: Message, hook_point: str) -> AsyncGenerator[Message, None]:
        """Invoke every hook registered for ``hook_point`` with ``message``.

        A hook that raises is logged and skipped so the remaining hooks still
        run.

        Args:
            message: The event passed to each hook.
            hook_point: Key into ``self.hooks`` selecting the hooks to run.

        Yields:
            Each truthy value returned by a hook.
        """
        for hook in self.hooks.get(hook_point, []):
            # Only the hook call is guarded, and only for `Exception`:
            # GeneratorExit / cancellation arriving at the `yield` below must
            # propagate instead of being logged as a hook failure.
            try:
                msg = hook(message)
            except Exception:
                logger.warning(f"{hook.point()} {hook.name()} execute fail.")
                continue
            if msg:
                yield msg
|