# coding: utf-8 # Copyright (c) 2025 inclusionAI. import abc import time from typing import AsyncGenerator from aworld.core.common import TaskItem from aworld.core.tool.base import Tool, AsyncTool from aworld.core.event.base import Message, Constants, TopicType from aworld.core.task import TaskResponse from aworld.logs.util import logger from aworld.output import Output from aworld.runners.handler.base import DefaultHandler from aworld.runners.hook.hook_factory import HookFactory from aworld.runners.hook.hooks import HookPoint class TaskHandler(DefaultHandler): __metaclass__ = abc.ABCMeta def __init__(self, runner: 'TaskEventRunner'): self.runner = runner self.retry_count = 0 self.hooks = {} if runner.task.hooks: for k, vals in runner.task.hooks.items(): self.hooks[k] = [] for v in vals: cls = HookFactory.get_class(v) if cls: self.hooks[k].append(cls) @classmethod def name(cls): return "_task_handler" class DefaultTaskHandler(TaskHandler): async def handle(self, message: Message) -> AsyncGenerator[Message, None]: if message.category != Constants.TASK: return logger.info(f"task handler receive message: {message}") headers = {"context": message.context} topic = message.topic task_item: TaskItem = message.payload if topic == TopicType.SUBSCRIBE_TOOL: new_tools = message.payload.data for name, tool in new_tools.items(): if isinstance(tool, Tool) or isinstance(tool, AsyncTool): await self.runner.event_mng.register(Constants.TOOL, name, tool.step) logger.info(f"dynamic register {name} tool.") else: logger.warning(f"Unknown tool instance: {tool}") return elif topic == TopicType.SUBSCRIBE_AGENT: return elif topic == TopicType.ERROR: async for event in self.run_hooks(message, HookPoint.ERROR): yield event if task_item.stop: await self.runner.stop() logger.warning(f"task {self.runner.task.id} stop, cause: {task_item.msg}") self.runner._task_response = TaskResponse(msg=task_item.msg, answer='', success=False, id=self.runner.task.id, time_cost=(time.time() - self.runner.start_time), usage=self.runner.context.token_usage) return # restart logger.warning(f"The task {self.runner.task.id} will be restarted due to error: {task_item.msg}.") if self.retry_count >= 3: raise Exception(f"The task {self.runner.task.id} failed, due to error: {task_item.msg}.") self.retry_count += 1 yield Message( category=Constants.TASK, payload='', sender=self.name(), session_id=self.runner.context.session_id, topic=TopicType.START, headers=headers ) elif topic == TopicType.FINISHED: async for event in self.run_hooks(message, HookPoint.FINISHED): yield event self.runner._task_response = TaskResponse(answer=str(message.payload), success=True, id=self.runner.task.id, time_cost=(time.time() - self.runner.start_time), usage=self.runner.context.token_usage) await self.runner.stop() logger.info(f"{self.runner.task.id} finished.") elif topic == TopicType.START: async for event in self.run_hooks(message, HookPoint.START): yield event logger.info(f"task start event: {message}, will send init message.") if message.payload: yield message else: yield self.runner.init_message elif topic == TopicType.OUTPUT: yield message elif topic == TopicType.HUMAN_CONFIRM: logger.warn("=============== Get human confirm, pause execution ===============") if self.runner.task.outputs and message.payload: await self.runner.task.outputs.add_output(Output(data=message.payload)) self.runner._task_response = TaskResponse(answer=str(message.payload), success=True, id=self.runner.task.id, time_cost=(time.time() - self.runner.start_time), usage=self.runner.context.token_usage) await self.runner.stop() async def run_hooks(self, message: Message, hook_point: str) -> AsyncGenerator[Message, None]: hooks = self.hooks.get(hook_point, []) for hook in hooks: try: msg = hook(message) if msg: yield msg except: logger.warning(f"{hook.point()} {hook.name()} execute fail.")