Spaces:
Sleeping
Sleeping
File size: 5,930 Bytes
9c9d7c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import aworld.trace as trace
import aworld.trace.instrumentation.semconv as semconv
from aworld.trace.server import get_trace_server
from aworld.trace.server.util import build_trace_tree
from aworld.core.tool.base import Tool, AgentInput, ToolFactory
from examples.tools.tool_action import GetTraceAction
from aworld.tools.utils import build_observation
from aworld.config.conf import ToolConfig
from aworld.core.common import Observation, ActionModel, ActionResult
from typing import Tuple, Dict, Any, List
from aworld.logs.util import logger
@ToolFactory.register(name="trace",
                      desc="Get the trace of the current execution.",
                      supported_action=GetTraceAction,
                      conf_file_name='trace_tool.yaml')
class TraceTool(Tool):
    """Tool that fetches the span tree of a trace from the registered trace server.

    Each action may carry a ``trace_id`` param; when it is missing, the trace id
    of the currently active span is used instead. Span attributes in the
    returned tree are filtered down to the token-usage attributes listed in
    ``_INCLUDE_ATTRIBUTES``.
    """

    # Span attributes kept by ``choose_attribute``; every other attribute is
    # dropped from the returned trace tree.
    _INCLUDE_ATTRIBUTES = (
        semconv.GEN_AI_USAGE_INPUT_TOKENS,
        semconv.GEN_AI_USAGE_OUTPUT_TOKENS,
        semconv.GEN_AI_USAGE_TOTAL_TOKENS,
    )

    def __init__(self,
                 conf: ToolConfig,
                 **kwargs) -> None:
        """Initialize the TraceTool.

        Args:
            conf: Tool configuration; its ``get_trace_url`` entry is stored on
                ``self.get_trace_url``.
            **kwargs: Forwarded unchanged to ``Tool.__init__``.
        """
        super().__init__(conf, **kwargs)
        self.type = "function"
        # NOTE(review): not read anywhere in this class; traces are fetched from
        # the in-process trace server in ``fetch_trace_data``.
        self.get_trace_url = self.conf.get('get_trace_url')

    def reset(self,
              *,
              seed: int | None = None,
              options: Dict[str, str] | None = None) -> Tuple[AgentInput, dict[str, Any]]:
        """Reset the tool so it can be stepped again.

        Args:
            seed: Unused.
            options: Unused.

        Returns:
            A fresh observation for the GET_TRACE ability and an empty info dict.
        """
        self._finished = False
        return build_observation(observer=self.name(),
                                 ability=GetTraceAction.GET_TRACE.value.name), {}

    def close(self) -> None:
        """Close the tool by marking it finished."""
        self._finished = True

    def do_step(self,
                actions: List[ActionModel],
                **kwargs) -> Tuple[Observation, float, bool, bool, dict[str, Any]]:
        """Fetch the trace for every action and collect the results.

        For each action the trace id is taken from ``action.params["trace_id"]``
        or, if absent, from the current span. Every action appends one
        ``ActionResult`` to ``observation.action_result``; the fetched trace
        dicts are also stringified into ``observation.content``.

        Args:
            actions: Actions to execute.
            **kwargs: ``terminated`` / ``truncated`` are echoed back in the
                result tuple; all kwargs are also merged into the info dict.

        Returns:
            ``(observation, reward, terminated, truncated, info)``. ``reward`` is
            1 once the loop over all actions finishes without an uncaught error
            (individual per-action failures do not change it), otherwise 0.
            ``info["exception"]`` holds the uncaught error message, or ``""``.
        """
        reward = 0
        fail_error = ""
        observation = build_observation(observer=self.name(),
                                        ability=GetTraceAction.GET_TRACE.value.name)
        results = []
        try:
            if not actions:
                return (observation, reward,
                        kwargs.get("terminated", False),
                        kwargs.get("truncated", False),
                        {"exception": "actions is empty"})
            for action in actions:
                trace_id = action.params.get("trace_id", "")
                if not trace_id:
                    # Fall back to the trace of the span currently executing.
                    current_span = trace.get_current_span()
                    if current_span:
                        trace_id = current_span.get_trace_id()
                if not trace_id:
                    logger.warning(f"{action} no trace_id to fetch.")
                    observation.action_result.append(
                        ActionResult(is_done=True,
                                     success=False,
                                     content="",
                                     error="no trace_id to fetch",
                                     keep=False))
                    continue
                # Default to the empty-tree shape so a failed fetch never leaves
                # trace_data unbound or stale from a previous action.
                trace_data = {"trace_id": trace_id, "root_span": []}
                error = ""
                try:
                    trace_data = self.fetch_trace_data(trace_id)
                    logger.info(f"trace_data={trace_data}")
                except Exception as e:
                    error = str(e)
                results.append(trace_data)
                observation.action_result.append(
                    ActionResult(is_done=True,
                                 success=not error,
                                 content=f"{trace_data}",
                                 error=error,
                                 keep=False))
            observation.content = f"{results}"
            reward = 1
        except Exception as e:
            fail_error = str(e)
        finally:
            self._finished = True
        # Kept outside ``finally`` so the early "actions is empty" return above
        # is not overridden and exceptions are not silently swallowed.
        info = {"exception": fail_error}
        info.update(kwargs)
        return (observation, reward, kwargs.get("terminated", False),
                kwargs.get("truncated", False), info)

    def fetch_trace_data(self, trace_id=None):
        """Fetch the span tree of ``trace_id`` from the trace server.

        Args:
            trace_id: Id of the trace to fetch.

        Returns:
            The tree built by ``build_trace_tree`` with its span attributes
            filtered by ``proccess_trace``, shaped like::

                {'trace_id': trace_id, 'root_span': [...]}

            ``{"trace_id": trace_id, "root_span": []}`` is returned when
            ``trace_id`` is empty, no trace server is set, the trace has no
            spans, or any error occurs (errors are logged, not raised).
        """
        empty = {"trace_id": trace_id, "root_span": []}
        if not trace_id:
            return empty
        try:
            trace_server = get_trace_server()
            if not trace_server:
                logger.error("No memory trace server has been set.")
                return empty
            spans = trace_server.get_storage().get_all_spans(trace_id)
            if spans:
                return self.proccess_trace(build_trace_tree(spans))
        except Exception as e:
            logger.error(f"Error fetching trace data: {e}")
        return empty

    def proccess_trace(self, trace_data):
        """Filter span attributes across the whole trace tree, in place.

        Args:
            trace_data: Dict whose ``root_span`` entry is a list of span dicts
                (presumably as produced by ``build_trace_tree`` — verify).
                A missing or ``None`` ``root_span`` is treated as empty.

        Returns:
            ``trace_data`` itself, mutated in place.
        """
        for span in trace_data.get("root_span") or []:
            self.choose_attribute(span)
        return trace_data

    def choose_attribute(self, span):
        """Keep only the token-usage attributes of ``span`` and its descendants.

        Replaces ``span["attributes"]`` with a new dict containing only keys in
        ``_INCLUDE_ATTRIBUTES`` (a missing/``None`` attributes entry becomes
        ``{}``), then recurses into every dict in ``span["children"]``.

        Args:
            span: Span dict, mutated in place.
        """
        origin_attributes = span.get("attributes") or {}
        span["attributes"] = {key: value
                              for key, value in origin_attributes.items()
                              if key in self._INCLUDE_ATTRIBUTES}
        for child in span.get("children") or []:
            self.choose_attribute(child)
|