Spaces:
Sleeping
Sleeping
import aworld.trace as trace | |
import aworld.trace.instrumentation.semconv as semconv | |
from aworld.trace.server import get_trace_server | |
from aworld.trace.server.util import build_trace_tree | |
from aworld.core.tool.base import Tool, AgentInput, ToolFactory | |
from examples.tools.tool_action import GetTraceAction | |
from aworld.tools.utils import build_observation | |
from aworld.config.conf import ToolConfig | |
from aworld.core.common import Observation, ActionModel, ActionResult | |
from typing import Tuple, Dict, Any, List | |
from aworld.logs.util import logger | |
class TraceTool(Tool):
    """Tool that returns trace (span-tree) data from the registered trace server.

    Each action may carry a ``trace_id`` param. When it is missing, the trace id
    of the current span (``trace.get_current_span()``) is used instead. Fetched
    span trees are pruned so that each span's ``attributes`` only keep the
    GenAI token-usage keys (see ``choose_attribute``).
    """

    def __init__(self,
                 conf: ToolConfig,
                 **kwargs) -> None:
        """
        Initialize the TraceTool
        Args:
            conf: tool config; its optional ``get_trace_url`` entry is stored on
                the instance (not used by the methods in this class).
            **kwargs: forwarded to ``Tool.__init__``.
        Return:
            None
        """
        super(TraceTool, self).__init__(conf, **kwargs)
        self.type = "function"
        self.get_trace_url = self.conf.get('get_trace_url')

    def reset(self,
              *,
              seed: int | None = None,
              options: Dict[str, str] | None = None) -> Tuple[AgentInput, dict[str, Any]]:
        """
        Reset the executor and return an empty initial observation.
        Args:
            seed: unused.
            options: unused.
        Returns:
            AgentInput, dict[str, Any]: a fresh observation and an empty info dict.
        """
        self._finished = False
        return build_observation(observer=self.name(),
                                 ability=GetTraceAction.GET_TRACE.value.name), {}

    def close(self) -> None:
        """
        Close the executor (marks the tool as finished).
        Returns:
            None
        """
        self._finished = True

    def do_step(self,
                actions: List[ActionModel],
                **kwargs) -> Tuple[Observation, float, bool, bool, dict[str, Any]]:
        """Fetch trace data for every action and collect it in one observation.

        Args:
            actions: actions to execute. Each may carry a ``trace_id`` param;
                when it is absent, the current span's trace id is used.
            **kwargs: ``terminated``/``truncated`` flags are echoed back in the
                result tuple, and all kwargs are merged into the returned info.

        Returns:
            ``(observation, reward, terminated, truncated, info)``. One
            ``ActionResult`` is appended per action. ``reward`` is 1 once the
            loop finishes without an unexpected error, otherwise 0.
            ``info["exception"]`` holds that error's text, or "".
        """
        reward = 0
        fail_error = ""
        observation = build_observation(observer=self.name(),
                                        ability=GetTraceAction.GET_TRACE.value.name)
        results = []
        try:
            if not actions:
                return (observation, reward,
                        kwargs.get("terminated", False),
                        kwargs.get("truncated", False),
                        {"exception": "actions is empty"})
            for action in actions:
                # Tolerate actions whose params are None.
                trace_id = (action.params or {}).get("trace_id", "")
                if not trace_id:
                    current_span = trace.get_current_span()
                    if current_span:
                        trace_id = current_span.get_trace_id()
                if not trace_id:
                    logger.warning(f"{action} no trace_id to fetch.")
                    observation.action_result.append(
                        ActionResult(is_done=True,
                                     success=False,
                                     content="",
                                     error="no trace_id to fetch",
                                     keep=False))
                    continue
                # Pre-bind an empty result so a failed fetch never leaves
                # ``trace_data`` unbound or holding the previous action's data.
                trace_data = {"trace_id": trace_id, "root_span": []}
                error = ""
                try:
                    trace_data = self.fetch_trace_data(trace_id)
                    logger.info(f"trace_data={trace_data}")
                except Exception as e:
                    error = str(e)
                results.append(trace_data)
                observation.action_result.append(
                    ActionResult(is_done=True,
                                 success=not error,
                                 content=f"{trace_data}",
                                 error=f"{error}",
                                 keep=False))
            observation.content = f"{results}"
            reward = 1
        except Exception as e:
            fail_error = str(e)
            logger.warning(f"TraceTool do_step failed: {fail_error}")
        finally:
            self._finished = True

        info = {"exception": fail_error}
        info.update(kwargs)
        return (observation, reward, kwargs.get("terminated", False),
                kwargs.get("truncated", False), info)

    def fetch_trace_data(self, trace_id=None):
        '''
        Fetch trace data from the trace server.

        Returns the processed span tree for ``trace_id``. It falls back to an
        empty result of the form
            {
                'trace_id': trace_id,
                'root_span': [],
            }
        in these cases: no trace_id, no trace server registered, no spans
        stored, or any error while fetching (the error is logged, not raised).
        '''
        if not trace_id:
            return {"trace_id": trace_id, "root_span": []}
        try:
            trace_server = get_trace_server()
            if not trace_server:
                logger.error("No memory trace server has been set.")
            else:
                spans = trace_server.get_storage().get_all_spans(trace_id)
                if spans:
                    return self.proccess_trace(build_trace_tree(spans))
        except Exception as e:
            logger.error(f"Error fetching trace data: {e}")
        return {"trace_id": trace_id, "root_span": []}

    def proccess_trace(self, trace_data):
        """Prune attributes of every span in ``trace_data`` in place and return it.

        Expects a dict whose optional ``root_span`` entry is a list of span dicts.
        This shape is inferred from the keys read here and in
        ``choose_attribute``; confirm it against ``build_trace_tree``.
        (The name's spelling is kept for backward compatibility.)
        """
        for span in trace_data.get("root_span") or []:
            self.choose_attribute(span)
        return trace_data

    def choose_attribute(self, span):
        """Recursively keep only token-usage keys in each span's ``attributes``.

        Mutates ``span`` and its ``children`` in place. A missing or empty
        ``attributes`` entry becomes ``{}``.
        """
        include_attr = (semconv.GEN_AI_USAGE_INPUT_TOKENS,
                        semconv.GEN_AI_USAGE_OUTPUT_TOKENS,
                        semconv.GEN_AI_USAGE_TOTAL_TOKENS)
        origin_attributes = span.get("attributes") or {}
        span["attributes"] = {key: value
                              for key, value in origin_attributes.items()
                              if key in include_attr}
        for child in span.get("children") or []:
            self.choose_attribute(child)