# File size: 8,582 Bytes
# fb49ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
import traceback
from aworld.agents.llm_agent import Agent
from aworld.config.conf import AgentConfig, ConfigDict
from aworld.core.common import Observation, ActionModel
from typing import Dict, Any, List, Union, Callable
from aworld.core.tool.base import ToolFactory
from aworld.models.llm import call_llm_model, acall_llm_model
from aworld.utils.common import sync_exec
from aworld.logs.util import logger
from examples.tools.common import Tools
from examples.tools.tool_action import GetTraceAction
from aworld.core.agent.swarm import Swarm
from aworld.runner import Runners
from aworld.trace.server import get_trace_server
from aworld.runners.state_manager import RuntimeStateManager, RunNode
import aworld.trace as trace


# Configure aworld tracing once, at import time, with its default settings
# (no arguments are passed). This runs before any agent below is created.
trace.configure()


class TraceAgent(Agent):
    """Agent that reads trace data from the ``trace`` tool and asks the LLM to summarize it.

    Both ``policy`` and ``async_policy`` ignore the observation they are given:
    they execute the ``GET_TRACE`` action on a freshly reset ``trace`` tool,
    send the content of the observation returned by the tool to the model and
    return the actions parsed from the model response.
    """

    def __init__(self,
                 conf: Union[Dict[str, Any], ConfigDict, AgentConfig],
                 resp_parse_func: Callable[..., Any] = None,
                 **kwargs):
        """Create the agent.

        Args:
            conf: Agent configuration, passed through to ``Agent``.
            resp_parse_func: Optional custom parser for LLM responses.
            **kwargs: Remaining keyword arguments, forwarded to ``Agent``.
        """
        # The parser used to be accepted here and then silently dropped.
        # Forward it only when the caller supplied one, so the default
        # construction path is unchanged.
        # NOTE(review): assumes ``Agent.__init__`` takes ``resp_parse_func`` as
        # a keyword (the methods below read ``self.resp_parse_func``) - confirm.
        if resp_parse_func is not None:
            kwargs["resp_parse_func"] = resp_parse_func
        super().__init__(conf, **kwargs)

    def _new_trace_request(self):
        """Return a freshly reset ``trace`` tool and the GET_TRACE action for it."""
        tool_name = "trace"
        tool = ToolFactory(tool_name, asyn=False)
        tool.reset()
        action = ActionModel(tool_name=tool_name,
                             action_name=GetTraceAction.GET_TRACE.name,
                             agent_name=self.id(),
                             params={})
        return tool, action

    def _trace_messages(self, tool_message):
        """Build the LLM input messages from the message returned by the trace tool."""
        # The payload is unpacked as a 5-tuple; only its first item (the
        # observation holding the trace content) is used.
        trace_observation, _, _, _, _ = tool_message.payload
        return self.messages_transform(content=trace_observation.content,
                                       sys_prompt=self.system_prompt,
                                       agent_prompt=self.agent_prompt)

    def _actions_from_response(self, llm_response) -> List[ActionModel]:
        """Validate the LLM response, parse it and update the finished flag.

        Raises:
            RuntimeError: If the response is missing or falsy.
        """
        if not llm_response:
            logger.error(f"{self.id()} failed to get LLM response")
            raise RuntimeError(f"{self.id()} failed to get LLM response")

        logger.info(f"Execute response: {llm_response.message}")
        if llm_response.error:
            logger.info(f"{self.id()} llm result error: {llm_response.error}")

        agent_result = sync_exec(self.resp_parse_func, llm_response)
        # No further tool call requested means this agent is done.
        if not agent_result.is_call_tool:
            self._finished = True
        return agent_result.actions

    def policy(self, observation: Observation, info: Dict[str, Any] = None, **kwargs) -> List[ActionModel]:
        """Use the trace tool to get trace data, then call the LLM to summarize it.

        Args:
            observation: The state observed from tools in the environment
                (not used; the trace tool's observation is used instead).
            info: Extended information to assist the agent (not used here).

        Returns:
            ActionModel sequence from agent policy.

        Raises:
            RuntimeError: If the LLM call fails or yields no response.
        """
        self._finished = False
        self.desc_transform()

        tool, action = self._new_trace_request()
        # NOTE(review): this passes a bare action while ``async_policy`` passes
        # a one-element list to ``tool.step`` - confirm which form is expected.
        messages = self._trace_messages(tool.step(action))

        try:
            llm_response = call_llm_model(
                self.llm,
                messages=messages,
                model=self.model_name,
                temperature=self.conf.llm_config.llm_temperature
            )
        except Exception as e:
            # Raise from the handler rather than from a ``finally`` clause, so
            # KeyboardInterrupt/SystemExit are not replaced and the original
            # failure stays attached as the explicit cause.
            logger.warning(traceback.format_exc())
            logger.error(f"{self.id()} failed to get LLM response")
            raise RuntimeError(
                f"{self.id()} failed to get LLM response") from e

        return self._actions_from_response(llm_response)

    async def async_policy(self, observation: Observation, info: Dict[str, Any] = None, **kwargs) -> List[ActionModel]:
        """Asynchronous variant of ``policy``; awaits the LLM call.

        Args:
            observation: The state observed from tools in the environment
                (not used; the trace tool's observation is used instead).
            info: Extended information to assist the agent (not used here).

        Returns:
            ActionModel sequence from agent policy.

        Raises:
            RuntimeError: If the LLM call fails or yields no response.
        """
        self._finished = False
        self.desc_transform()

        tool, action = self._new_trace_request()
        # NOTE(review): this passes a one-element list while ``policy`` passes
        # a bare action to ``tool.step`` - confirm which form is expected.
        messages = self._trace_messages(tool.step([action]))

        try:
            llm_response = await acall_llm_model(
                self.llm,
                messages=messages,
                model=self.model_name,
                temperature=self.conf.llm_config.llm_temperature
            )
        except Exception as e:
            # ``asyncio.CancelledError`` is not an ``Exception`` subclass, so
            # task cancellation now propagates instead of being replaced by a
            # RuntimeError raised from a ``finally`` clause.
            logger.warning(traceback.format_exc())
            logger.error(f"{self.id()} failed to get LLM response")
            raise RuntimeError(
                f"{self.id()} failed to get LLM response") from e

        return self._actions_from_response(llm_response)


# Prompt pair for the search agent: a system prompt plus an agent prompt
# template containing a ``{task}`` placeholder (presumably filled in by the
# agent with the task text - confirm against the Agent implementation).
# NOTE(review): the template contains typos ("searach", "pleas"); they are
# left untouched here because the string is runtime text.
search_sys_prompt = "You are a helpful search agent."
search_prompt = """
    Please act as a search agent, constructing appropriate keywords and searach terms, using search toolkit to collect relevant information, including urls, webpage snapshots, etc.

    Here are the question: {task}

    pleas only use one action complete this task, at least results 6 pages.
    """

# Prompt pair for the summary agent; ``{task}`` marks where the text to
# summarize goes.
summary_sys_prompt = "You are a helpful general summary agent."

summary_prompt = """
Summarize the following text in one clear and concise paragraph, capturing the key ideas without missing critical points. 
Ensure the summary is easy to understand and avoids excessive detail.

Here are the content: 
{task}
"""

# Prompt pair for the trace agent; ``{task}`` marks where the trace data goes.
# NOTE(review): "runotype" looks like a typo for "run_type", and "AGNET" may be
# a typo for "AGENT" - verify against the run_type values actually present in
# the span attributes before changing either.
trace_sys_prompt = "You are a helpful trace agent."

trace_prompt = """
    Please act as a trace agent, Using the provided trace data, summarize the token usage of each agent,
    whether the runotype attribute of span is an agent or a large model call: 
        run_type=AGNET represents the agent, 
        run_type=LLM represents the large model call.
    The LLM call of a certain agent is represented as LLM span, which is a child span of that agent span

    Here are the content: {task}
    """


def build_run_flow(nodes: List[RunNode]):
    """Print the parent/child hierarchy of run nodes as text trees.

    A node with a truthy ``parent_node_id`` is attached beneath that parent;
    every other node is treated as a root. One tree is printed per root, each
    framed by a separator line above and below.

    Args:
        nodes: Run nodes exposing ``node_id`` and, optionally,
            ``parent_node_id``.
    """
    separator = "-----------------------------------"
    children_by_parent = {}
    root_ids = []

    for run_node in nodes:
        parent_id = getattr(run_node, 'parent_node_id', None)
        if not parent_id:
            # Missing or empty parent id: this node starts its own tree.
            root_ids.append(run_node.node_id)
            continue
        children_by_parent.setdefault(parent_id, []).append(run_node.node_id)

    for root_id in root_ids:
        print(separator)
        _print_tree(children_by_parent, root_id, "", True)
        print(separator)


def _print_tree(graph, node_id, prefix, is_last):
    """Recursively print ``node_id`` and its descendants as a text tree.

    Args:
        graph: Mapping of a parent node id to the list of its child node ids.
        node_id: Id of the node to print on this line.
        prefix: Indentation text accumulated from the ancestors.
        is_last: Whether this node is the last child of its parent; selects
            the branch glyph and the indentation handed to the children.
    """
    branch = "└── " if is_last else "├── "
    print(prefix + branch + node_id)

    if node_id not in graph:
        return

    child_ids = graph[node_id]
    # A last child leaves blank space below it; otherwise the vertical rail
    # continues down to the following siblings.
    child_prefix = prefix + ("    " if is_last else "│   ")
    final_index = len(child_ids) - 1
    for position, child_id in enumerate(child_ids):
        _print_tree(graph, child_id, child_prefix, position == final_index)


if __name__ == "__main__":
    # Demo: run a search -> summary -> trace swarm on one question, then print
    # the recorded run-node hierarchy for the session and keep the trace
    # server alive.

    # Single source of truth for the session id used by both the run and the
    # state lookup below.
    session_id = "123"

    agent_config = AgentConfig(
        llm_provider="openai",
        llm_model_name="gpt-4o",
        llm_temperature=0.3,

        llm_base_url="http://localhost:34567",
        llm_api_key="dummy-key",
    )

    search = Agent(
        conf=agent_config,
        name="search_agent",
        system_prompt=search_sys_prompt,
        agent_prompt=search_prompt,
        tool_names=[Tools.SEARCH_API.value]
    )

    summary = Agent(
        conf=agent_config,
        name="summary_agent",
        system_prompt=summary_sys_prompt,
        agent_prompt=summary_prompt
    )

    # Named ``trace_agent`` rather than ``trace`` so the module imported as
    # ``trace`` at the top of the file is not rebound at module scope.
    trace_agent = TraceAgent(
        conf=agent_config,
        name="trace_agent",
        system_prompt=trace_sys_prompt,
        agent_prompt=trace_prompt
    )

    # default is sequence swarm mode
    swarm = Swarm(search, summary, trace_agent, max_steps=1, event_driven=True)

    prefix = "search baidu:"
    # can special search google, wiki, duck go, or baidu. such as:
    # prefix = "search wiki: "
    try:
        res = Runners.sync_run(
            input=prefix + """What is an agent.""",
            swarm=swarm,
            session_id=session_id
        )
        print(res.answer)
    except Exception:
        # Best effort: log the failure and still dump whatever run state was
        # recorded for the session.
        logger.error(traceback.format_exc())

    state_manager = RuntimeStateManager.instance()
    nodes = state_manager.get_nodes(session_id)
    logger.info(f"session {session_id} nodes: {nodes}")
    build_run_flow(nodes)
    get_trace_server().join()