File size: 9,295 Bytes
9afc9c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
# coding: utf-8
# Copyright (c) 2025 inclusionAI.

import json
import time
import traceback
from typing import Dict, Any, Optional, List, Union

from langchain_core.messages import HumanMessage, BaseMessage, SystemMessage

from examples.android.prompts import SYSTEM_PROMPT, LAST_STEP_PROMPT
from examples.android.utils import (
    AgentState,
    AgentHistory,
    AgentHistoryList,
    ActionResult,
    PolicyMetadata,
    AgentBrain,
    Trajectory
)
from examples.browsers.common import AgentStepInfo
from aworld.config.conf import AgentConfig, ConfigDict
from aworld.core.agent.base import AgentResult
from aworld.agents.llm_agent import Agent
from aworld.core.common import Observation, ActionModel, ToolActionInfo
from aworld.logs.util import logger
from examples.tools.tool_action import AndroidAction


class AndroidAgent(Agent):
    """Agent that drives an Android device through a vision-language model.

    Each ``policy`` call sends the task, the current step number and the
    current screenshot to the LLM. It parses the JSON reply into
    ``ActionModel`` objects targeting the ``'android'`` tool, and records
    the step in the agent's history and trajectory.
    """

    def __init__(self, conf: Union[Dict[str, Any], ConfigDict, AgentConfig], **kwargs):
        super(AndroidAgent, self).__init__(conf, **kwargs)
        # Prefix the configured provider with "chat". Prefer llm_config's provider
        # and fall back to the top-level one, writing back to whichever was used.
        # The prefix is applied only once, so a config object shared by several
        # agents does not end up as "chatchat...".
        llm_config = self.conf.llm_config
        if llm_config.llm_provider:
            llm_config.llm_provider = self._with_chat_prefix(llm_config.llm_provider)
        else:
            self.conf.llm_provider = self._with_chat_prefix(self.conf.llm_provider)
        self.available_actions_desc = self._build_action_prompt()
        # Settings (alias of self.conf; read for save_history / history_path)
        self.settings = self.conf

    @staticmethod
    def _with_chat_prefix(provider: str) -> str:
        """Return *provider* prefixed with ``"chat"`` unless it already starts with it."""
        if isinstance(provider, str) and provider.startswith("chat"):
            return provider
        # Non-str values (e.g. None) still fail here with TypeError, as before.
        return "chat" + provider

    def reset(self, options: Dict[str, Any]):
        """Reset per-episode state, history and trajectory.

        NOTE(review): this re-runs the *base* ``__init__`` with ``options`` rather
        than calling a base ``reset``, so the "chat" provider prefixing done in
        ``AndroidAgent.__init__`` is not re-applied, and ``self.settings`` keeps
        pointing at the previous conf object — confirm this is intended.
        """
        super(AndroidAgent, self).__init__(options)
        # State
        self.state = AgentState()
        # History
        self.history = AgentHistoryList(history=[])
        self.trajectory = Trajectory(history=[])

    def _build_action_prompt(self) -> str:
        """Describe every ``AndroidAction`` member for inclusion in a prompt.

        Each action is rendered as ``"<desc>:\\n{<name>: {param: {title, type}}}"``
        and the descriptions are joined with newlines.
        """
        def _prompt(info: ToolActionInfo) -> str:
            s = f'{info.desc}:\n'
            s += '{' + str(info.name) + ': '
            if info.input_params:
                s += str({k: {"title": k, "type": v} for k, v in info.input_params.items()})
            s += '}'
            return s

        # Iterate over all android actions (including any aliases in __members__).
        return "\n".join(_prompt(member.value) for member in AndroidAction.__members__.values())

    def policy(self,
               observation: Observation,
               info: Optional[Dict[str, Any]] = None,
               **kwargs) -> Union[List[ActionModel], None]:
        """Ask the vision-language model for the next Android action(s).

        Args:
            observation: Current environment observation. ``dom_tree`` holds the UI
                hierarchy and ``image`` the base64-encoded screenshot.
            info: Extra step information, forwarded to the trajectory record.

        Returns:
            The parsed actions. Returns a single ``stop`` action if the agent was
            paused/stopped while the LLM was being queried, or ``None`` (after
            stopping the agent) when the observation carries no UI tree.

        Raises:
            RuntimeError: wrapping (and chained to) any error raised while
                querying the LLM, parsing its reply or recording the step.
        """
        self._finished = False
        step_info = AgentStepInfo(number=self.state.n_steps, max_steps=self.conf.max_steps)
        last_step_msg = None
        if step_info.is_last_step():
            # Add last step warning if needed
            last_step_msg = HumanMessage(content=LAST_STEP_PROMPT)
            logger.info('Last step finishing up')

        logger.info(f'[agent] 📍 Step {self.state.n_steps}')
        step_start_time = time.time()

        try:
            xml_content, base64_img = observation.dom_tree, observation.image

            if xml_content is None:
                logger.error("[agent] ⚠ Failed to get UI state, stopping task")
                self.stop()
                return None

            screenshot = base64_img if base64_img else ""
            self.state.last_result = (xml_content, screenshot)

            logger.info("[agent] 🤖 Analyzing current state with LLM...")
            # Send the screenshot of *this* observation. The original code sent
            # self.state.image, which nothing in this class ever sets.
            a_step_msg = HumanMessage(content=[
                {
                    "type": "text",
                    "text": f"""
                        Task: {self.task}
                        Current Step: {self.state.n_steps}
                        
                        Please analyze the current interface and decide the next action. Please directly return the response in JSON format without any other text or code block markers.
                    """
                },
                {
                    "type": "image_url",
                    "image_url": f"data:image/jpeg;base64,{screenshot}"
                }
            ])

            messages = [SystemMessage(content=SYSTEM_PROMPT)]
            if last_step_msg:
                messages.append(last_step_msg)
            messages.append(a_step_msg)

            logger.info(f"[agent] VLM Input last message: {messages[-1]}")
            llm_result = None
            try:
                llm_result = self._do_policy(messages)

                if self.state.stopped or self.state.paused:
                    logger.info('Android agent paused after getting state')
                    return [ActionModel(tool_name='android', action_name="stop")]

                tool_action = llm_result.actions

                step_metadata = PolicyMetadata(
                    start_time=step_start_time,
                    end_time=time.time(),
                    number=self.state.n_steps,
                    input_tokens=1  # placeholder; token usage is not measured here
                )

                history_item = AgentHistory(
                    result=[ActionResult(success=True)],
                    metadata=step_metadata,
                    content=xml_content,
                    base64_img=base64_img
                )
                self.history.history.append(history_item)

                if self.settings.save_history and self.settings.history_path:
                    self.history.save_to_file(self.settings.history_path)

                logger.info(f'📍 Step {self.state.n_steps} starts to execute')

                self.state.n_steps += 1
                self.state.consecutive_failures = 0
                return tool_action

            except Exception as e:
                logger.warning(traceback.format_exc())
                # Chain the original error so its traceback is preserved.
                raise RuntimeError("Android agent encountered exception while making the policy.", e) from e
            finally:
                # Runs on success, on the early "stop" return and on failure alike:
                # whatever the LLM produced is recorded in the trajectory.
                if llm_result:
                    self.trajectory.add_step(observation, info, llm_result)
                    metadata = PolicyMetadata(
                        number=self.state.n_steps,
                        start_time=step_start_time,
                        end_time=time.time(),
                        input_tokens=1
                    )
                    self._make_history_item(llm_result, observation, metadata)
                else:
                    logger.warning("no result to record!")

        except Exception as e:
            # The inner handler re-wraps every error as RuntimeError, so a JSON
            # parsing failure can only be recognized through the chained cause.
            if isinstance(e, json.JSONDecodeError) or isinstance(e.__cause__, json.JSONDecodeError):
                logger.error("[agent] ❌ JSON parsing error")
            else:
                logger.error(f"[agent] ❌ Action execution error: {str(e)}")
            raise

    @staticmethod
    def _strip_code_fence(content: str) -> str:
        """Remove a surrounding Markdown code fence (```json or ```) from *content*.

        Surrounding whitespace is stripped before the fences are checked. Without
        that, a reply such as ``"```json\\n{...}\\n```\\n"`` (trailing newline after
        the closing fence) would not be unwrapped and ``json.loads`` would fail.
        """
        text = content.strip()
        text = text.removeprefix("```json").removeprefix("```")
        return text.strip().removesuffix("```").strip()

    def _do_policy(self, input_messages: list[BaseMessage]) -> AgentResult:
        """Invoke the LLM and convert its JSON reply into android ``ActionModel``s.

        The reply must contain a ``current_state`` object (fed to ``AgentBrain``)
        and a list of actions under ``action`` or, if that is missing or empty,
        ``actions``. Actions without a ``type`` are skipped with a warning. The
        action type ``'type'`` is mapped to the ``'input_text'`` action name.

        Raises:
            json.JSONDecodeError: the reply is not valid JSON.
            KeyError: the reply has no ``current_state``.
            ValueError: the reply contains neither ``action`` nor ``actions``.
        """
        response = self.llm.invoke(input_messages)
        content = self._strip_code_fence(response.content)

        action_data = json.loads(content)
        brain_state = AgentBrain(**action_data["current_state"])

        logger.info(f"[agent] ⚠ Eval: {brain_state.evaluation_previous_goal}")
        logger.info(f"[agent] 🧠 Memory: {brain_state.memory}")
        logger.info(f"[agent] 🎯 Next goal: {brain_state.next_goal}")

        actions = action_data.get('action') or action_data.get('actions')

        # print actions
        logger.info(f"[agent] VLM Output actions: {actions}")
        if actions is None:
            raise ValueError("LLM response contains neither 'action' nor 'actions'")

        result = []
        for action in actions:
            action_type = action.get('type')
            if not action_type:
                logger.warning(f"Action missing type: {action}")
                continue

            if action_type == 'type':
                action_type = 'input_text'
            # Copy so the parsed LLM payload is not mutated below.
            params = dict(action.get('params') or {})
            if 'index' in action:
                params['index'] = action['index']
            # The raw (unmapped) type is always forwarded as a parameter.
            params['type'] = action['type']
            if 'text' in action:
                params['text'] = action['text']

            result.append(ActionModel(
                tool_name='android',
                action_name=action_type,
                params=params
            ))

        return AgentResult(current_state=brain_state, actions=result)

    def _make_history_item(self,
                           model_output: AgentResult | None,
                           state: Observation,
                           metadata: Optional[PolicyMetadata] = None) -> None:
        """Append an ``AgentHistory`` entry for *model_output* to ``self.state.history``.

        ``state`` may be given as a plain dict, in which case it is converted to
        an ``Observation`` first.

        NOTE(review): ``policy`` also appends to ``self.history`` (a separate list),
        so each successful step is recorded in two places — confirm intended.
        """
        if isinstance(state, dict):
            state = Observation(**state)

        history_item = AgentHistory(
            model_output=model_output,
            result=state.action_result,
            metadata=metadata,
            content=state.dom_tree,
            base64_img=state.image
        )
        self.state.history.history.append(history_item)

    def pause(self) -> None:
        """Pause the agent"""
        logger.info('🔄 Pausing Agent')
        self.state.paused = True

    def resume(self) -> None:
        """Resume the agent"""
        logger.info('▶️ Agent resuming')
        self.state.paused = False

    def stop(self) -> None:
        """Stop the agent"""
        logger.info('⏹️ Agent stopping')
        self.state.stopped = True