# Duibonduil's picture
# Upload 9 files
# fb49ac2 verified
import traceback
from aworld.agents.llm_agent import Agent
from aworld.config.conf import AgentConfig, ConfigDict
from aworld.core.common import Observation, ActionModel
from typing import Dict, Any, List, Union, Callable
from aworld.core.tool.base import ToolFactory
from aworld.models.llm import call_llm_model, acall_llm_model
from aworld.utils.common import sync_exec
from aworld.logs.util import logger
from examples.tools.common import Tools
from examples.tools.tool_action import GetTraceAction
from aworld.core.agent.swarm import Swarm
from aworld.runner import Runners
from aworld.trace.server import get_trace_server
from aworld.runners.state_manager import RuntimeStateManager, RunNode
import aworld.trace as trace
# Initialise aworld tracing at import time, before any of the agents defined
# and run later in this file are created.
trace.configure()
class TraceAgent(Agent):
    """Agent that summarizes trace data with an LLM.

    On every step the agent calls the ``trace`` tool's ``GET_TRACE`` action,
    feeds the returned content to the LLM together with ``system_prompt`` and
    ``agent_prompt``, and parses the LLM response with ``resp_parse_func``.
    The ``observation`` passed to ``policy``/``async_policy`` is not used; it
    is replaced by the trace tool's output.
    """

    # ToolFactory name of the tool that serves trace data.
    _TRACE_TOOL_NAME = "trace"

    def __init__(self,
                 conf: Union[Dict[str, Any], ConfigDict, AgentConfig],
                 resp_parse_func: Callable[..., Any] = None,
                 **kwargs):
        """Create the agent.

        Args:
            conf: Agent configuration (dict, ConfigDict or AgentConfig).
            resp_parse_func: Optional parser for the LLM response. When given,
                it replaces ``self.resp_parse_func``; when ``None``, whatever
                parser the base ``Agent`` set up is used.
            **kwargs: Forwarded unchanged to ``Agent.__init__``.
        """
        super().__init__(conf, **kwargs)
        # This argument used to be accepted and silently dropped.
        if resp_parse_func is not None:
            self.resp_parse_func = resp_parse_func

    def policy(self,
               observation: Observation,
               info: Union[Dict[str, Any], None] = None,
               **kwargs) -> List[ActionModel]:
        """Use the trace tool to get trace data and call the LLM to summarize it.

        Args:
            observation: Not used; replaced by the trace tool's output.
            info: Not used; kept for signature compatibility (defaults to
                ``None`` instead of a shared mutable ``{}``).

        Returns:
            ActionModel sequence parsed from the LLM response.

        Raises:
            RuntimeError: If the LLM call raises or returns no response. The
                underlying exception, if any, is chained as ``__cause__``.
        """
        messages = self._trace_messages()
        try:
            llm_response = call_llm_model(
                self.llm,
                messages=messages,
                model=self.model_name,
                temperature=self.conf.llm_config.llm_temperature,
            )
        except Exception as exc:
            raise self._trace_llm_failure() from exc
        return self._trace_actions(llm_response)

    async def async_policy(self,
                           observation: Observation,
                           info: Union[Dict[str, Any], None] = None,
                           **kwargs) -> List[ActionModel]:
        """Async variant of ``policy``; the LLM is called with ``acall_llm_model``.

        Args, return value and raised errors are the same as for ``policy``.
        """
        # NOTE(review): the trace tool is created with asyn=False and stepped
        # synchronously, so fetching trace data blocks the event loop.
        messages = self._trace_messages()
        try:
            llm_response = await acall_llm_model(
                self.llm,
                messages=messages,
                model=self.model_name,
                temperature=self.conf.llm_config.llm_temperature,
            )
        except Exception as exc:
            raise self._trace_llm_failure() from exc
        return self._trace_actions(llm_response)

    def _trace_messages(self):
        """Reset step state, fetch trace data and build the LLM input messages."""
        self._finished = False
        self.desc_transform()
        tool = ToolFactory(self._TRACE_TOOL_NAME, asyn=False)
        tool.reset()
        action = ActionModel(tool_name=self._TRACE_TOOL_NAME,
                             action_name=GetTraceAction.GET_TRACE.name,
                             agent_name=self.id(),
                             params={})
        # Pass a list of actions, as the async path always did; presumably
        # Tool.step expects a list (the sync path used to pass a bare action).
        message = tool.step([action])
        trace_observation, _, _, _, _ = message.payload
        return self.messages_transform(content=trace_observation.content,
                                       sys_prompt=self.system_prompt,
                                       agent_prompt=self.agent_prompt)

    def _trace_llm_failure(self) -> RuntimeError:
        """Log the in-flight exception and build the error for a failed LLM call.

        Must be called inside an ``except`` block so ``traceback.format_exc()``
        captures the failure.
        """
        logger.warning(traceback.format_exc())
        logger.error(f"{self.id()} failed to get LLM response")
        return RuntimeError(f"{self.id()} failed to get LLM response")

    def _trace_actions(self, llm_response) -> List[ActionModel]:
        """Parse ``llm_response`` into actions and update ``_finished``.

        Raises:
            RuntimeError: If ``llm_response`` is falsy (no response).
        """
        if not llm_response:
            logger.error(f"{self.id()} failed to get LLM response")
            raise RuntimeError(f"{self.id()} failed to get LLM response")
        logger.info(f"Execute response: {llm_response.message}")
        if llm_response.error:
            # Reported but not fatal: the response is still handed to the parser.
            logger.warning(f"{self.id()} llm result error: {llm_response.error}")
        agent_result = sync_exec(self.resp_parse_func, llm_response)
        # Finished once the parsed result no longer requests a tool call.
        if not agent_result.is_call_tool:
            self._finished = True
        return agent_result.actions
# Prompts for the search, summary and trace agents built in ``main``. Each
# agent prompt has a single ``{task}`` placeholder, presumably filled in by the
# Agent with the content it receives.
search_sys_prompt = "You are a helpful search agent."
search_prompt = """
Please act as a search agent, constructing appropriate keywords and search terms, using the search toolkit to collect relevant information, including urls, webpage snapshots, etc.
Here is the question: {task}
Please complete this task with only one action, returning at least 6 pages of results.
"""
summary_sys_prompt = "You are a helpful general summary agent."
summary_prompt = """
Summarize the following text in one clear and concise paragraph, capturing the key ideas without missing critical points.
Ensure the summary is easy to understand and avoids excessive detail.
Here is the content:
{task}
"""
trace_sys_prompt = "You are a helpful trace agent."
# NOTE(review): "AGNET" is kept verbatim on purpose; presumably it is the
# literal run_type value emitted in the trace data. Confirm before "fixing" it.
trace_prompt = """
Please act as a trace agent. Using the provided trace data, summarize the token usage of each agent.
Use the run_type attribute of each span to tell whether the span is an agent or a large model call:
run_type=AGNET represents the agent,
run_type=LLM represents the large model call.
The LLM call of a certain agent is represented as an LLM span, which is a child span of that agent span.
Here is the content: {task}
"""
def build_run_flow(nodes: "Union[List[RunNode], None]") -> None:
    """Print every run tree found in ``nodes`` to stdout.

    A node with a truthy ``parent_node_id`` is attached as a child of that
    parent. Every other node starts its own tree, and each tree is printed
    between separator lines. Children keep the order in which they appear in
    ``nodes``.

    Args:
        nodes: Run nodes, i.e. objects with ``node_id`` and an optional
            ``parent_node_id``. ``None`` or an empty list prints nothing.
    """
    children_of: Dict[Any, List[Any]] = {}
    roots: List[Any] = []
    # ``or []`` guards against a session that has no recorded nodes (None).
    for node in nodes or []:
        parent_id = getattr(node, "parent_node_id", None)
        if parent_id:
            children_of.setdefault(parent_id, []).append(node.node_id)
        else:
            roots.append(node.node_id)
    for root in roots:
        print("-----------------------------------")
        _print_tree(children_of, root, "", True)
        print("-----------------------------------")
def _print_tree(graph, node_id, prefix, is_last):
    """Recursively print ``node_id`` and its descendants as an ASCII tree.

    Args:
        graph: Mapping of node id -> list of child node ids.
        node_id: Id of the node to print (any type; formatted with ``str``).
        prefix: Indentation inherited from the ancestors.
        is_last: Whether ``node_id`` is the last child of its parent; selects
            the ``└──`` vs ``├──`` connector.
    """
    connector = "└── " if is_last else "├── "
    # f-string formatting tolerates non-str ids (plain ``+`` raised TypeError).
    print(f"{prefix}{connector}{node_id}")
    children = graph.get(node_id, [])
    # The continuation is as wide as the 4-character connector, so children
    # line up under their parent's name.
    child_prefix = prefix + ("    " if is_last else "│   ")
    for index, child in enumerate(children):
        _print_tree(graph, child, child_prefix, index == len(children) - 1)
def main() -> None:
    """Run the search -> summary -> trace swarm once, then print the run flow.

    Any failure of the run itself is logged rather than raised, so the
    recorded run nodes are still printed afterwards. Finally, this waits on the
    trace server.
    """
    session_id = "123"
    agent_config = AgentConfig(
        llm_provider="openai",
        llm_model_name="gpt-4o",
        llm_temperature=0.3,
        llm_base_url="http://localhost:34567",
        llm_api_key="dummy-key",
    )
    search = Agent(
        conf=agent_config,
        name="search_agent",
        system_prompt=search_sys_prompt,
        agent_prompt=search_prompt,
        tool_names=[Tools.SEARCH_API.value]
    )
    summary = Agent(
        conf=agent_config,
        name="summary_agent",
        system_prompt=summary_sys_prompt,
        agent_prompt=summary_prompt
    )
    # Named trace_agent so it does not shadow the imported ``aworld.trace`` module.
    trace_agent = TraceAgent(
        conf=agent_config,
        name="trace_agent",
        system_prompt=trace_sys_prompt,
        agent_prompt=trace_prompt
    )
    # default is sequence swarm mode
    swarm = Swarm(search, summary, trace_agent, max_steps=1, event_driven=True)
    prefix = "search baidu:"
    # Other engines can be selected via the prefix: google, wiki, duck go or baidu, e.g.
    # prefix = "search wiki: "
    try:
        res = Runners.sync_run(
            input=prefix + """What is an agent.""",
            swarm=swarm,
            session_id=session_id
        )
        print(res.answer)
    except Exception:
        logger.error(traceback.format_exc())
    state_manager = RuntimeStateManager.instance()
    nodes = state_manager.get_nodes(session_id)
    logger.info(f"session {session_id} nodes: {nodes}")
    build_run_flow(nodes)
    trace_server = get_trace_server()
    # NOTE(review): guard in case no trace server was started; join() presumably
    # keeps the process alive so the trace server stays reachable.
    if trace_server is not None:
        trace_server.join()


if __name__ == "__main__":
    main()