File size: 5,930 Bytes
9c9d7c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import aworld.trace as trace
import aworld.trace.instrumentation.semconv as semconv
from aworld.trace.server import get_trace_server
from aworld.trace.server.util import build_trace_tree
from aworld.core.tool.base import Tool, AgentInput, ToolFactory
from examples.tools.tool_action import GetTraceAction
from aworld.tools.utils import build_observation
from aworld.config.conf import ToolConfig
from aworld.core.common import Observation, ActionModel, ActionResult
from typing import Tuple, Dict, Any, List
from aworld.logs.util import logger


@ToolFactory.register(name="trace",
                      desc="Get the trace of the current execution.",
                      supported_action=GetTraceAction,
                      conf_file_name='trace_tool.yaml')
class TraceTool(Tool):
    """
    Tool that returns the spans recorded for a trace.

    Each action may name a ``trace_id`` in its params; when it does not, the
    trace id of the current span (``trace.get_current_span()``) is used. Spans
    are read from the storage of the server returned by ``get_trace_server()``,
    turned into a tree with ``build_trace_tree``, and every span's attributes
    are then reduced to the GenAI token-usage keys.
    """

    def __init__(self,
                 conf: ToolConfig,
                 **kwargs) -> None:
        """
        Initialize the TraceTool.

        Args:
            conf: tool config; its ``get_trace_url`` entry is stored on the
                instance (it is not read by the methods of this class).
            **kwargs: forwarded to ``Tool.__init__``.
        """
        super().__init__(conf, **kwargs)
        self.type = "function"
        self.get_trace_url = self.conf.get('get_trace_url')

    def reset(self,
              *,
              seed: int | None = None,
              options: Dict[str, str] | None = None) -> Tuple[AgentInput, dict[str, Any]]:
        """
        Reset the tool so it can be stepped again.

        Args:
            seed: unused.
            options: unused.
        Returns:
            A fresh observation for the GET_TRACE ability and an empty info dict.
        """
        self._finished = False
        return build_observation(observer=self.name(),
                                 ability=GetTraceAction.GET_TRACE.value.name), {}

    def close(self) -> None:
        """
        Close the tool by marking it finished.

        Returns:
            None
        """
        self._finished = True

    def do_step(self,
                actions: List[ActionModel],
                **kwargs) -> Tuple[Observation, float, bool, bool, dict[str, Any]]:
        """
        Fetch the trace for each action and report it in the observation.

        One ``ActionResult`` is appended to ``observation.action_result`` per
        action; the fetched trace dicts are also collected and stringified
        into ``observation.content``.

        Args:
            actions: actions to run; each may carry a ``trace_id`` param.
            **kwargs: ``terminated``/``truncated`` are echoed back in the
                return tuple, and all kwargs are merged into the info dict.
        Returns:
            (observation, reward, terminated, truncated, info). ``reward`` is 1
            once every action has been processed (even if some failed) and 0
            when ``actions`` is empty or the loop raised. ``info['exception']``
            holds the loop's error message, or "" on success.
        """
        reward = 0
        fail_error = ""
        observation = build_observation(observer=self.name(),
                                        ability=GetTraceAction.GET_TRACE.value.name)
        results = []
        try:
            if not actions:
                # The ``finally`` below still marks the tool finished.
                return (observation, reward,
                        kwargs.get("terminated", False),
                        kwargs.get("truncated", False),
                        {"exception": "actions is empty"})
            for action in actions:
                trace_id = (action.params or {}).get("trace_id", "")
                if not trace_id:
                    # Fall back to the trace of the currently active span.
                    current_span = trace.get_current_span()
                    if current_span:
                        trace_id = current_span.get_trace_id()
                if not trace_id:
                    logger.warning(f"{action} no trace_id to fetch.")
                    observation.action_result.append(
                        ActionResult(is_done=True,
                                     success=False,
                                     content="",
                                     error="no trace_id to fetch",
                                     keep=False))
                    continue
                # Pre-bind a fallback so a failing fetch cannot leave
                # ``trace_data`` unbound (which used to raise NameError below
                # and abort all remaining actions).
                trace_data = {"trace_id": trace_id, "root_span": []}
                error = ""
                try:
                    trace_data = self.fetch_trace_data(trace_id)
                    logger.info(f"trace_data={trace_data}")
                except Exception as e:
                    error = str(e)
                    logger.warning(f"fetch trace {trace_id} failed: {error}")
                results.append(trace_data)
                observation.action_result.append(
                    ActionResult(is_done=True,
                                 success=not error,
                                 content=f"{trace_data}",
                                 error=error,
                                 keep=False))

            observation.content = f"{results}"
            reward = 1
        except Exception as e:
            fail_error = str(e)
            logger.warning(f"trace tool step failed: {fail_error}")
        finally:
            self._finished = True

        info = {"exception": fail_error}
        info.update(kwargs)
        return (observation, reward, kwargs.get("terminated", False),
                kwargs.get("truncated", False), info)

    def fetch_trace_data(self, trace_id=None):
        """
        Fetch the span tree of ``trace_id`` from the trace server.

        Returns trace data shaped like::

            {
                'trace_id': trace_id,
                'root_span': [],
            }

        ``root_span`` is empty when no trace id is given, no trace server is
        set, the storage holds no spans for the trace, or any step raises.
        Errors are logged rather than propagated.
        """
        empty = {"trace_id": trace_id, "root_span": []}
        if not trace_id:
            return empty
        try:
            trace_server = get_trace_server()
            if not trace_server:
                logger.error("No memory trace server has been set.")
                return empty
            spans = trace_server.get_storage().get_all_spans(trace_id)
            if not spans:
                return empty
            return self.proccess_trace(build_trace_tree(spans))
        except Exception as e:
            logger.error(f"Error fetching trace data: {e}")
            return empty

    def proccess_trace(self, trace_data):
        """
        Filter the attributes of every span listed in ``trace_data['root_span']``.

        The dict is modified in place and also returned; a missing or None
        ``root_span`` entry is treated as empty. (The misspelled method name is
        kept so existing callers and overrides keep working.)
        """
        for span in trace_data.get("root_span") or []:
            self.choose_attribute(span)
        return trace_data

    def choose_attribute(self, span):
        """
        Keep only GenAI token-usage attributes on ``span`` and its descendants.

        ``span['attributes']`` is replaced in place by a new dict holding only
        the input/output/total token-count keys that are present; the same
        filter is then applied recursively to each entry of ``span['children']``.
        """
        include_attr = (semconv.GEN_AI_USAGE_INPUT_TOKENS,
                        semconv.GEN_AI_USAGE_OUTPUT_TOKENS,
                        semconv.GEN_AI_USAGE_TOTAL_TOKENS)
        origin_attributes = span.get("attributes") or {}
        span["attributes"] = {key: value
                              for key, value in origin_attributes.items()
                              if key in include_attr}
        for child in span.get("children") or []:
            self.choose_attribute(child)