Spaces:
Sleeping
Sleeping
import traceback | |
from aworld.agents.llm_agent import Agent | |
from aworld.config.conf import AgentConfig, ConfigDict | |
from aworld.core.common import Observation, ActionModel | |
from typing import Dict, Any, List, Union, Callable | |
from aworld.core.tool.base import ToolFactory | |
from aworld.models.llm import call_llm_model, acall_llm_model | |
from aworld.utils.common import sync_exec | |
from aworld.logs.util import logger | |
from examples.tools.common import Tools | |
from examples.tools.tool_action import GetTraceAction | |
from aworld.core.agent.swarm import Swarm | |
from aworld.runner import Runners | |
from aworld.trace.server import get_trace_server | |
from aworld.runners.state_manager import RuntimeStateManager, RunNode | |
import aworld.trace as trace | |
# Configure AWorld tracing at import time, before any agent below is created.
# NOTE(review): called with no arguments, so library defaults apply; what
# exactly gets installed is defined in aworld.trace -- confirm there.
trace.configure()
class TraceAgent(Agent):
    """Agent that pulls trace data from the ``trace`` tool and has the LLM summarise it.

    Both ``policy`` and ``async_policy`` ignore the observation passed in by the
    caller: they run the ``GET_TRACE`` action on the trace tool themselves and
    use the content of the tool's observation as the LLM input.
    """

    def __init__(self,
                 conf: Union[Dict[str, Any], ConfigDict, AgentConfig],
                 resp_parse_func: Callable[..., Any] = None,
                 **kwargs):
        """Create the agent.

        Args:
            conf: Agent configuration, forwarded to ``Agent.__init__``.
            resp_parse_func: Optional custom parser for the LLM response.
            **kwargs: Any other keyword arguments accepted by ``Agent``.
        """
        # Previously this argument was accepted and then silently dropped.
        # Forward it only when supplied so the default path is unchanged.
        # NOTE(review): assumes the base Agent accepts a ``resp_parse_func``
        # keyword (it exposes ``self.resp_parse_func``) -- confirm.
        if resp_parse_func is not None:
            kwargs["resp_parse_func"] = resp_parse_func
        super().__init__(conf, **kwargs)

    def policy(self, observation: Observation, info: Dict[str, Any] = None, **kwargs) -> List[ActionModel]:
        """Use the trace tool to get trace data, then call the LLM to summarise it.

        Args:
            observation: Unused; replaced by the trace tool's observation.
            info: Unused extended information.

        Returns:
            ActionModel sequence parsed from the LLM response.

        Raises:
            RuntimeError: If the LLM call fails or yields no response.
        """
        messages = self._trace_llm_messages(wrap_action_in_list=False)
        try:
            llm_response = call_llm_model(
                self.llm,
                messages=messages,
                model=self.model_name,
                temperature=self.conf.llm_config.llm_temperature
            )
        except Exception as e:
            # Raise from the handler (not from ``finally``) so the original
            # error stays attached as the explicit cause.
            logger.warning(traceback.format_exc())
            logger.error(f"{self.id()} failed to get LLM response")
            raise RuntimeError(
                f"{self.id()} failed to get LLM response") from e
        return self._trace_actions_from_response(llm_response)

    async def async_policy(self, observation: Observation, info: Dict[str, Any] = None, **kwargs) -> List[ActionModel]:
        """Asynchronous variant of ``policy``; only the LLM call is awaited.

        Args:
            observation: Unused; replaced by the trace tool's observation.
            info: Unused extended information.

        Returns:
            ActionModel sequence parsed from the LLM response.

        Raises:
            RuntimeError: If the LLM call fails or yields no response.
        """
        messages = self._trace_llm_messages(wrap_action_in_list=True)
        try:
            llm_response = await acall_llm_model(
                self.llm,
                messages=messages,
                model=self.model_name,
                temperature=self.conf.llm_config.llm_temperature
            )
        except Exception as e:
            logger.warning(traceback.format_exc())
            logger.error(f"{self.id()} failed to get LLM response")
            raise RuntimeError(
                f"{self.id()} failed to get LLM response") from e
        return self._trace_actions_from_response(llm_response)

    def _trace_llm_messages(self, wrap_action_in_list: bool):
        """Run GET_TRACE on the trace tool and build the LLM messages from its observation."""
        self._finished = False
        self.desc_transform()

        tool_name = "trace"
        tool = ToolFactory(tool_name, asyn=False)
        tool.reset()
        action = ActionModel(tool_name=tool_name,
                             action_name=GetTraceAction.GET_TRACE.name,
                             agent_name=self.id(),
                             params={})
        # NOTE(review): the sync path historically passed the bare action and
        # the async path a one-element list. Both call shapes are preserved
        # here; one of them is probably wrong -- check the tool's step() API.
        message = tool.step([action] if wrap_action_in_list else action)
        # The payload is a 5-tuple; only the leading observation is used.
        trace_observation, _, _, _, _ = message.payload
        return self.messages_transform(content=trace_observation.content,
                                       sys_prompt=self.system_prompt,
                                       agent_prompt=self.agent_prompt)

    def _trace_actions_from_response(self, llm_response) -> List[ActionModel]:
        """Validate the LLM response, parse it, and update the finished flag."""
        if not llm_response:
            logger.error(f"{self.id()} failed to get LLM response")
            raise RuntimeError(
                f"{self.id()} failed to get LLM response")

        logger.info(f"Execute response: {llm_response.message}")
        if llm_response.error:
            # An error carried in the response is logged but not fatal.
            logger.info(
                f"{self.id()} llm result error: {llm_response.error}")

        agent_result = sync_exec(self.resp_parse_func, llm_response)
        # No tool call requested means the agent is done.
        if not agent_result.is_call_tool:
            self._finished = True
        return agent_result.actions
# Prompt templates for the three agents. Each ``*_prompt`` carries a
# ``{task}`` placeholder that receives the agent's input.
search_sys_prompt = "You are a helpful search agent."
search_prompt = """
Please act as a search agent, constructing appropriate keywords and searach terms, using search toolkit to collect relevant information, including urls, webpage snapshots, etc.
Here are the question: {task}
pleas only use one action complete this task, at least results 6 pages.
"""

summary_sys_prompt = "You are a helpful general summary agent."
summary_prompt = """
Summarize the following text in one clear and concise paragraph, capturing the key ideas without missing critical points.
Ensure the summary is easy to understand and avoids excessive detail.
Here are the content:
{task}
"""

trace_sys_prompt = "You are a helpful trace agent."
# The attribute name was misspelled "runotype"; it now matches the
# ``run_type=...`` lines that follow, which is the name the model must look up.
# NOTE(review): "AGNET" is kept as-is because the real span value is not
# visible here -- confirm whether trace data carries "AGNET" or "AGENT".
trace_prompt = """
Please act as a trace agent, Using the provided trace data, summarize the token usage of each agent,
whether the run_type attribute of span is an agent or a large model call:
run_type=AGNET represents the agent,
run_type=LLM represents the large model call.
The LLM call of a certain agent is represented as LLM span, which is a child span of that agent span
Here are the content: {task}
"""
def build_run_flow(nodes: List[RunNode]):
    """Print the parent/child structure of run nodes as one tree per root.

    A node with a non-empty ``parent_node_id`` is recorded as a child of that
    parent; every other node is treated as a root. Only trees reachable from
    a root are printed, so a node whose parent id is not itself a printed
    node never appears in the output.

    Args:
        nodes: Run nodes to display. ``None`` or an empty list prints nothing.
    """
    children_of = {}
    root_ids = []

    # ``or ()`` keeps a missing node list from raising on iteration.
    for node in nodes or ():
        parent_id = getattr(node, 'parent_node_id', None)
        if parent_id:
            children_of.setdefault(parent_id, []).append(node.node_id)
        else:
            root_ids.append(node.node_id)

    for root_id in root_ids:
        print("-----------------------------------")
        _print_tree(children_of, root_id, "", True)
        print("-----------------------------------")
def _print_tree(graph, node_id, prefix, is_last): | |
print(prefix + ("└── " if is_last else "├── ") + node_id) | |
if node_id in graph: | |
children = graph[node_id] | |
for i, child in enumerate(children): | |
_print_tree(graph, child, prefix + | |
(" " if is_last else "│ "), i == len(children)-1) | |
if __name__ == "__main__":
    # Demo: search -> summary -> trace agents run in sequence, then the
    # recorded run nodes of the session are printed as a tree.

    # One session id shared by the run and the later state lookup.
    session_id = "123"

    agent_config = AgentConfig(
        llm_provider="openai",
        llm_model_name="gpt-4o",
        llm_temperature=0.3,
        llm_base_url="http://localhost:34567",
        llm_api_key="dummy-key",
    )

    search_agent = Agent(
        conf=agent_config,
        name="search_agent",
        system_prompt=search_sys_prompt,
        agent_prompt=search_prompt,
        tool_names=[Tools.SEARCH_API.value]
    )

    summary_agent = Agent(
        conf=agent_config,
        name="summary_agent",
        system_prompt=summary_sys_prompt,
        agent_prompt=summary_prompt
    )

    # Named ``trace_agent`` so the module-level ``aworld.trace`` import
    # (bound to ``trace``) is not overwritten.
    trace_agent = TraceAgent(
        conf=agent_config,
        name="trace_agent",
        system_prompt=trace_sys_prompt,
        agent_prompt=trace_prompt
    )

    # default is sequence swarm mode
    swarm = Swarm(search_agent, summary_agent, trace_agent,
                  max_steps=1, event_driven=True)

    # The prefix selects the search backend (google, wiki, duck go or baidu),
    # e.g. prefix = "search wiki: "
    prefix = "search baidu:"

    try:
        res = Runners.sync_run(
            input=prefix + """What is an agent.""",
            swarm=swarm,
            session_id=session_id
        )
        print(res.answer)
    except Exception:
        # Best effort: log the failure and still show the recorded run flow.
        logger.error(traceback.format_exc())

    state_manager = RuntimeStateManager.instance()
    nodes = state_manager.get_nodes(session_id)
    logger.info(f"session {session_id} nodes: {nodes}")
    build_run_flow(nodes)
    # Wait on the trace server so the process does not exit immediately.
    get_trace_server().join()