# aworld/runners/handler/output.py
import json
from collections.abc import AsyncGenerator
from aworld.core.task import TaskResponse
from aworld.models.model_response import ModelResponse
from aworld.runners.handler.base import DefaultHandler
from aworld.output.base import StepOutput, MessageOutput, ToolResultOutput, Output
from aworld.core.common import TaskItem
from aworld.core.context.base import Context
from aworld.core.event.base import Message, Constants, TopicType
from aworld.logs.util import logger


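class DefaultOutputHandler(DefaultHandler):
    """Routes OUTPUT-category messages into the running task's outputs."""

    def __init__(self, runner):
        self.runner = runner

    async def handle(self, message):
        """Add the message payload to ``self.runner.task.outputs``.

        ``Output`` payloads are added as-is; ``ModelResponse`` and async-generator
        payloads are wrapped in ``MessageOutput``. A ``TaskResponse`` on the
        FINISHED or ERROR topic marks the outputs as completed. Other payload
        types are ignored. If the task has no outputs, or the payload cannot be
        converted, a TASK message with ``stop=True`` is yielded on the ERROR topic.
        """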
        if message.category != Constants.OUTPUT:
            return
        # 1. get outputs
        outputs = self.runner.task.outputs
        if not outputs:
            yield Message(
                category=Constants.TASK,
                payload=TaskItem(msg="Cannot get outputs.", data=message, stop=True),
                sender=self.name(),
                session_id=Context.instance().session_id,
                topic=TopicType.ERROR,
                headers={"context": message.context}
            )
            return
        # 2. build Output
        payload = message.payload
        mark_complete = False
        output = None
        try:
            if isinstance(payload, Output):
                output = payload
            elif isinstance(payload, TaskResponse):
                logger.info(f"output get task_response with usage: {json.dumps(payload.usage)}")
                if message.topic in (TopicType.FINISHED, TopicType.ERROR):
                    mark_complete = True
            elif isinstance(payload, (ModelResponse, AsyncGenerator)):
                output = MessageOutput(source=payload)
        except Exception as e:
            logger.warning(f"Failed to parse output: {e}")
            yield Message(
                category=Constants.TASK,
                payload=TaskItem(msg="Failed to parse output.", data=payload, stop=True),
                sender=self.name(),
                session_id=Context.instance().session_id,
                topic=TopicType.ERROR,
                headers={"context": message.context}
            )
        finally:
            if output is not None:
                if not output.metadata:
                    output.metadata = {}
                output.metadata['sender'] = message.sender
                output.metadata['receiver'] = message.receiver
                await outputs.add_output(output)
            if mark_complete:
                await outputs.mark_completed()

        return