Spaces:
Sleeping
Sleeping
File size: 8,582 Bytes
fb49ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 |
import traceback
from aworld.agents.llm_agent import Agent
from aworld.config.conf import AgentConfig, ConfigDict
from aworld.core.common import Observation, ActionModel
from typing import Dict, Any, List, Union, Callable
from aworld.core.tool.base import ToolFactory
from aworld.models.llm import call_llm_model, acall_llm_model
from aworld.utils.common import sync_exec
from aworld.logs.util import logger
from examples.tools.common import Tools
from examples.tools.tool_action import GetTraceAction
from aworld.core.agent.swarm import Swarm
from aworld.runner import Runners
from aworld.trace.server import get_trace_server
from aworld.runners.state_manager import RuntimeStateManager, RunNode
import aworld.trace as trace
trace.configure()
class TraceAgent(Agent):
    """Agent that fetches trace data via the `trace` tool and asks the LLM to summarize it.

    Both entry points (`policy` sync, `async_policy` async) share the same flow:
    pull the current trace from the trace tool, turn it into chat messages,
    call the model, then parse the response into actions.
    """

    def __init__(self,
                 conf: Union[Dict[str, Any], ConfigDict, AgentConfig],
                 resp_parse_func: Callable[..., Any] = None,
                 **kwargs):
        """Create the trace agent.

        Args:
            conf: Agent configuration (dict, ConfigDict or AgentConfig).
            resp_parse_func: Optional parser applied to the LLM response.
                Forwarded to the base Agent; the original accepted this
                argument but silently dropped it.
        """
        if resp_parse_func is not None:
            # Forward only when supplied so the base class keeps its own default.
            kwargs['resp_parse_func'] = resp_parse_func
        super().__init__(conf, **kwargs)

    def _fetch_trace_observation(self) -> Observation:
        """Run the trace tool synchronously and return the Observation with trace data."""
        tool_name = "trace"
        tool = ToolFactory(tool_name, asyn=False)
        tool.reset()
        action = ActionModel(tool_name=tool_name,
                             action_name=GetTraceAction.GET_TRACE.name,
                             agent_name=self.id(),
                             params={})
        # NOTE(review): the original sync path passed a bare ActionModel while
        # the async path passed [action]; unified on the list form here —
        # confirm against the Tool.step signature.
        message = tool.step([action])
        observation, _, _, _, _ = message.payload
        return observation

    def _finalize(self, llm_response) -> List[ActionModel]:
        """Validate the LLM response and convert it into a list of actions.

        Raises:
            RuntimeError: if no LLM response was produced.
        """
        if llm_response is None:
            # Checked on the success path only. The original raised this inside
            # `finally`, which masked the real exception from the LLM call.
            logger.error(f"{self.id()} failed to get LLM response")
            raise RuntimeError(f"{self.id()} failed to get LLM response")
        if llm_response.error:
            # Error responses are logged but still handed to the parser,
            # matching the original behavior.
            logger.info(f"{self.id()} llm result error: {llm_response.error}")
        agent_result = sync_exec(self.resp_parse_func, llm_response)
        if not agent_result.is_call_tool:
            self._finished = True
        return agent_result.actions

    def policy(self, observation: Observation, info: Dict[str, Any] = {}, **kwargs) -> List[ActionModel]:
        """Use the trace tool to get trace data, and call the LLM to summarize it.

        Args:
            observation: The state observed from tools in the environment.
            info: Extended information used to assist the agent to decide a policy.

        Returns:
            ActionModel sequence from agent policy.
        """
        self._finished = False
        self.desc_transform()
        observation = self._fetch_trace_observation()
        messages = self.messages_transform(content=observation.content,
                                           sys_prompt=self.system_prompt,
                                           agent_prompt=self.agent_prompt)
        try:
            llm_response = call_llm_model(
                self.llm,
                messages=messages,
                model=self.model_name,
                temperature=self.conf.llm_config.llm_temperature
            )
            logger.info(f"Execute response: {llm_response.message}")
        except Exception:
            # logger.warn is deprecated; logger.warning is the supported name.
            logger.warning(traceback.format_exc())
            raise
        return self._finalize(llm_response)

    async def async_policy(self, observation: Observation, info: Dict[str, Any] = {}, **kwargs) -> List[ActionModel]:
        """Async variant of `policy`: same flow, but awaits the LLM call.

        Args:
            observation: The state observed from tools in the environment.
            info: Extended information used to assist the agent to decide a policy.

        Returns:
            ActionModel sequence from agent policy.
        """
        self._finished = False
        self.desc_transform()
        observation = self._fetch_trace_observation()
        messages = self.messages_transform(content=observation.content,
                                           sys_prompt=self.system_prompt,
                                           agent_prompt=self.agent_prompt)
        try:
            llm_response = await acall_llm_model(
                self.llm,
                messages=messages,
                model=self.model_name,
                temperature=self.conf.llm_config.llm_temperature
            )
            logger.info(f"Execute response: {llm_response.message}")
        except Exception:
            logger.warning(traceback.format_exc())
            raise
        return self._finalize(llm_response)
search_sys_prompt = "You are a helpful search agent."
# Prompt typos fixed ("searach", "pleas", grammar) so the model receives
# clean instructions.
search_prompt = """
Please act as a search agent, constructing appropriate keywords and search terms, using search toolkit to collect relevant information, including urls, webpage snapshots, etc.
Here is the question: {task}
Please only use one action to complete this task, with at least 6 result pages.
"""
summary_sys_prompt = "You are a helpful general summary agent."
summary_prompt = """
Summarize the following text in one clear and concise paragraph, capturing the key ideas without missing critical points.
Ensure the summary is easy to understand and avoids excessive detail.
Here are the content:
{task}
"""
trace_sys_prompt = "You are a helpful trace agent."
# "runotype" -> "run_type" and "AGNET" -> "AGENT": the attribute name and enum
# value must match what the trace spans actually carry, or the model will look
# for non-existent fields.
trace_prompt = """
Please act as a trace agent. Using the provided trace data, summarize the token usage of each agent.
The run_type attribute of a span indicates whether it is an agent or a large model call:
run_type=AGENT represents the agent,
run_type=LLM represents the large model call.
The LLM call of a certain agent is represented as an LLM span, which is a child span of that agent span.
Here are the content: {task}
"""
def build_run_flow(nodes: List["RunNode"]):
    """Print an ASCII tree of the run graph encoded in *nodes*.

    Each node that carries a parent_node_id is attached under that parent;
    nodes without one are roots, and each root's subtree is printed between
    separator lines.

    Args:
        nodes: RunNode objects exposing `node_id` and (optionally)
            `parent_node_id`.
    """
    children_by_parent: Dict[str, List[str]] = {}
    roots: List[str] = []
    for node in nodes:
        # getattr with a default replaces the hasattr + attribute double-touch.
        parent_id = getattr(node, 'parent_node_id', None)
        if parent_id:
            children_by_parent.setdefault(parent_id, []).append(node.node_id)
        else:
            roots.append(node.node_id)
    for root_id in roots:
        print("-----------------------------------")
        _print_tree(children_by_parent, root_id, "", True)
        print("-----------------------------------")


def _print_tree(graph, node_id, prefix, is_last):
    """Recursively print *node_id* and its children with box-drawing connectors."""
    print(prefix + ("└── " if is_last else "├── ") + node_id)
    children = graph.get(node_id, [])
    for i, child in enumerate(children):
        # Continuation indent is 4 chars wide to line up under the 4-char
        # "└── "/"├── " connectors.
        _print_tree(graph, child,
                    prefix + ("    " if is_last else "│   "),
                    i == len(children) - 1)
if __name__ == "__main__":
    agent_config = AgentConfig(
        llm_provider="openai",
        llm_model_name="gpt-4o",
        llm_temperature=0.3,
        llm_base_url="http://localhost:34567",
        llm_api_key="dummy-key",
    )
    search = Agent(
        conf=agent_config,
        name="search_agent",
        system_prompt=search_sys_prompt,
        agent_prompt=search_prompt,
        tool_names=[Tools.SEARCH_API.value]
    )
    summary = Agent(
        conf=agent_config,
        name="summary_agent",
        system_prompt=summary_sys_prompt,
        agent_prompt=summary_prompt
    )
    # Renamed from `trace` to avoid shadowing the `aworld.trace` module
    # imported (and configured) at the top of the file.
    trace_agent = TraceAgent(
        conf=agent_config,
        name="trace_agent",
        system_prompt=trace_sys_prompt,
        agent_prompt=trace_prompt
    )
    # Default is sequence swarm mode: search -> summary -> trace.
    swarm = Swarm(search, summary, trace_agent, max_steps=1, event_driven=True)
    # Single definition of the session id, used for both the run and the
    # state-manager lookup below.
    session_id = "123"
    prefix = "search baidu:"
    # Can also search google, wiki, duckduckgo, or baidu. For example:
    # prefix = "search wiki: "
    try:
        res = Runners.sync_run(
            input=prefix + """What is an agent.""",
            swarm=swarm,
            session_id=session_id
        )
        print(res.answer)
    except Exception:
        logger.error(traceback.format_exc())
    state_manager = RuntimeStateManager.instance()
    nodes = state_manager.get_nodes(session_id)
    logger.info(f"session {session_id} nodes: {nodes}")
    build_run_flow(nodes)
    # Block so the trace server stays available for inspection after the run.
    get_trace_server().join()
|