# coding: utf-8
# Copyright (c) 2025 inclusionAI.
import re
import time
import traceback
import json
from typing import Dict, Any, Optional, List, Union, Tuple
from dataclasses import dataclass, field

from langchain_core.messages import HumanMessage, BaseMessage, AIMessage, ToolMessage
from pydantic import ValidationError

from aworld.core.agent.base import AgentFactory, AgentResult
from aworld.agents.llm_agent import Agent
from examples.browsers.prompts import SystemPrompt
from examples.browsers.utils import convert_input_messages, extract_json_from_model_output, estimate_messages_tokens
from examples.browsers.common import AgentState, AgentStepInfo, AgentHistory, PolicyMetadata, AgentBrain
from aworld.config.conf import AgentConfig, ConfigDict
from aworld.core.common import Observation, ActionModel, ToolActionInfo, ActionResult
from aworld.logs.util import logger
from examples.browsers.prompts import AgentMessagePrompt
from examples.tools.tool_action import BrowserAction
@dataclass
class Trajectory:
    """Store the agent's step history.

    Each step is the tuple
    ``(input_messages, observation, info, output_message, agent_result)``
    in the order the steps were added.
    """

    # ``field(default_factory=list)`` only yields a per-instance list because the
    # class is decorated with ``@dataclass``; without the decorator ``history``
    # would be a single shared ``Field`` object with no ``append`` method.
    history: List[tuple[List[BaseMessage], Observation, Dict[str, Any], AIMessage, AgentResult]] = field(
        default_factory=list)

    def add_step(self, input_messages: List[BaseMessage], observation: Observation, info: Dict[str, Any],
                 output_message: AIMessage, agent_result: AgentResult):
        """Append one step to the history.

        Args:
            input_messages: Messages that were sent to the LLM for this step.
            observation: Observation the step was based on.
            info: Extra info passed along with the observation.
            output_message: Raw LLM reply; may be ``None`` when the LLM call failed.
            agent_result: Parsed result (or an error placeholder) for this step.
        """
        self.history.append((input_messages, observation, info, output_message, agent_result))

    def get_history(self) -> List[tuple[List[BaseMessage], Observation, Dict[str, Any], AIMessage, AgentResult]]:
        """Return the complete history list (the live list, not a copy)."""
        return self.history

    def save_history(self, file_path: str):
        """Write the LLM inputs and outputs of every step to ``file_path`` as JSON.

        Only the ``type``/``content`` of each input message and the ``content``
        of the output message are written; observations, info and agent results
        are not serialized.

        Args:
            file_path: Destination path; the file is overwritten.
        """
        his_li = []
        for input_messages, _observation, _info, output_message, _agent_result in self.history:
            llm_input = [{"type": message.type, "content": message.content} for message in input_messages]
            # Error steps are recorded with no output message; store null for
            # them instead of failing on ``None.content``.
            llm_output = output_message.content if output_message is not None else None
            his_li.append({"llm_input": llm_input, "llm_output": llm_output})
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(his_li, f, ensure_ascii=False, indent=4)
class BrowserAgent(Agent):
    """LLM agent that turns observations into ``browser`` tool actions.

    On every ``policy`` call the full prompt is rebuilt from the recorded
    trajectory plus the current observation, the LLM is invoked, and its JSON
    reply is parsed into ``ActionModel`` objects.
    """

    def __init__(self, conf: Union[Dict[str, Any], ConfigDict, AgentConfig], **kwargs):
        """Initialise agent state, the action prompt and an empty trajectory.

        Args:
            conf: Agent configuration, forwarded to the base ``Agent``.
            **kwargs: Extra arguments forwarded to the base ``Agent``.
        """
        super(BrowserAgent, self).__init__(conf, **kwargs)
        self.state = AgentState()
        self.settings = self.conf
        # The provider name is prefixed with "chat" on whichever config level
        # supplied it (llm_config takes precedence over the top-level value).
        provider = self.conf.llm_config.llm_provider if self.conf.llm_config.llm_provider else self.conf.llm_provider
        if self.conf.llm_config.llm_provider:
            self.conf.llm_config.llm_provider = "chat" + provider
        else:
            self.conf.llm_provider = "chat" + provider

        self.save_file_path = self.conf.save_file_path
        self.available_actions = self._build_action_prompt()
        # Initialize trajectory
        self.trajectory = Trajectory()
        # ``policy`` calls ``reset`` lazily on first use while this is False.
        self._init = False

    def reset(self, options: Dict[str, Any]):
        """Reset the base agent and start a fresh trajectory.

        Args:
            options: Reset options forwarded to the base ``Agent.reset``.
        """
        super(BrowserAgent, self).reset(options)
        # Reset trajectory
        self.trajectory = Trajectory()
        self._init = True

    def _build_action_prompt(self) -> str:
        """Return a newline-joined text description of every ``BrowserAction`` member."""

        def _prompt(info: ToolActionInfo) -> str:
            # "<desc>: \n{<name>: <param schema>}" -- the schema part is omitted
            # when the action declares no input params.
            s = f'{info.desc}: \n'
            s += '{' + str(info.name) + ': '
            if info.input_params:
                s += str({k: {"title": k, "type": v.type} for k, v in info.input_params.items()})
            s += '}'
            return s

        return "\n".join(_prompt(member.value) for member in BrowserAction.__members__.values())

    def _log_message_sequence(self, input_messages: List[BaseMessage]) -> None:
        """Log the sequence of messages for debugging purposes.

        Text is logged in 500-character chunks, base64 images are replaced by a
        placeholder, and tool-call arguments are truncated to 1000 characters.
        """
        logger.info(f"[agent] 🔍 Invoking LLM with {len(input_messages)} messages")
        logger.info("[agent] 📝 Messages sequence:")

        for i, msg in enumerate(input_messages):
            prefix = msg.type
            logger.info(f"[agent] Message {i + 1}: {prefix} ===================================")
            if isinstance(msg.content, list):
                for item in msg.content:
                    if not isinstance(item, dict):
                        # Content lists may hold plain strings, which have no ``.get``.
                        logger.info(f"[agent] Text content: {item}")
                        continue
                    if item.get('type') == 'text':
                        logger.info(f"[agent] Text content: {item.get('text')}")
                    elif item.get('type') == 'image_url':
                        # Only print the first 30 characters of image URL to avoid printing entire base64
                        image_url = item.get('image_url', {}).get('url', '')
                        if image_url.startswith('data:image'):
                            logger.info("[agent] Image: [Base64 image data]")
                        else:
                            logger.info(f"[agent] Image URL: {image_url[:30]}...")
            else:
                content = str(msg.content)
                chunk_size = 500
                for j in range(0, len(content), chunk_size):
                    chunk = content[j:j + chunk_size]
                    if j == 0:
                        logger.info(f"[agent] Content: {chunk}")
                    else:
                        logger.info(f"[agent] Content (continued): {chunk}")

            if isinstance(msg, AIMessage) and msg.tool_calls:
                for tool_call in msg.tool_calls:
                    logger.info(f"[agent] Tool call: {tool_call.get('name')} - ID: {tool_call.get('id')}")
                    args = str(tool_call.get('args', {}))[:1000]
                    logger.info(f"[agent] Tool args: {args}...")

    def save_process(self, file_path: str):
        """Dump the recorded trajectory to ``file_path`` via ``Trajectory.save_history``."""
        self.trajectory.save_history(file_path)

    def policy(self,
               observation: Observation,
               info: Dict[str, Any] = None, **kwargs) -> Union[List[ActionModel], None]:
        """Decide the next browser actions for ``observation``.

        Args:
            observation: Current observation; its ``content`` is used as the
                task when the agent has not been reset yet.
            info: Extra info stored alongside the step in the trajectory.
            **kwargs: Unused.

        Returns:
            The actions parsed from the LLM reply, or a single ``stop`` action
            when no result was produced or the agent is stopped/paused.

        Raises:
            RuntimeError: If invoking or parsing the LLM fails; the failed step
                is recorded in the trajectory first.
        """
        start_time = time.time()

        if self._init is False:
            self.reset({"task": observation.content})

        self._finished = False
        # Save current observation to state for message construction
        self.state.last_result = observation.action_result

        if self.conf.max_steps <= self.state.n_steps:
            logger.info('Last step finishing up')

        logger.info(f'[agent] step {self.state.n_steps}')

        # Build the full prompt (history + current observation + last-step notice)
        input_messages = self.build_messages_from_trajectory_and_observation(observation=observation)

        # Estimate token count
        tokens = self._estimate_tokens_for_messages(input_messages)

        llm_result = None
        output_message = None
        try:
            # Log the message sequence
            self._log_message_sequence(input_messages)

            output_message, llm_result = self._do_policy(input_messages)

            if not llm_result:
                logger.error("[agent] ❌ Failed to parse LLM response")
                # 'browser' is the tool name used for every action in this agent.
                return [ActionModel(agent_name=self.id(), tool_name='browser', action_name="stop")]

            self.state.n_steps += 1

            if self.state.stopped or self.state.paused:
                logger.info('Browser gent paused after getting state')
                return [ActionModel(agent_name=self.id(), tool_name='browser', action_name="stop")]

            tool_action = llm_result.actions

            # Add the current step to the trajectory
            self.trajectory.add_step(input_messages, observation, info, output_message, llm_result)

        except Exception as e:
            logger.warning(traceback.format_exc())
            logger.error(f"[agent] ❌ Error parsing LLM response: {str(e)}")

            # Create an AgentResult object with an empty actions list
            error_result = AgentResult(
                current_state=AgentBrain(
                    evaluation_previous_goal="Failed due to error",
                    memory=f"Error occurred: {str(e)}",
                    thought="Recover from error",
                    next_goal="Recover from error"
                ),
                actions=[]  # Empty actions list
            )

            # Add the error state to the trajectory (output_message may be None here)
            self.trajectory.add_step(input_messages, observation, info, output_message, error_result)
            raise RuntimeError("Browser agent encountered exception while making the policy.", e) from e
        finally:
            # Runs on the early "stop" returns as well as on normal completion.
            if llm_result:
                metadata = PolicyMetadata(
                    number=self.state.n_steps,
                    start_time=start_time,
                    end_time=time.time(),
                    input_tokens=tokens,
                )
                self._make_history_item(llm_result, observation, observation.action_result, metadata)
            else:
                logger.warning("no result to record!")

        return tool_action

    def _do_policy(self, input_messages: list[BaseMessage]) -> Tuple[AIMessage, AgentResult]:
        """Invoke the LLM and parse its JSON reply into an ``AgentResult``.

        When the reply is not parseable JSON the LLM is re-invoked with a format
        reminder, up to ``max_llm_json_retries`` parse attempts (default 3, at
        least 1).

        Args:
            input_messages: Prompt messages for the LLM.

        Returns:
            ``(output_message, agent_result)``. An empty first reply yields a
            single ``stop`` action; a reply without actions yields ``done``.

        Raises:
            RuntimeError: If the first LLM invocation raises.
            ValueError: If the reply cannot be parsed after all attempts.
        """
        THINK_TAGS = re.compile(r'<think>.*?</think>', re.DOTALL)

        def _remove_think_tags(text: str) -> str:
            """Remove think tags from text"""
            return re.sub(THINK_TAGS, '', text)

        input_messages = self._convert_input_messages(input_messages)
        output_message = None
        try:
            output_message = self.llm.invoke(input_messages)
        except Exception as e:
            logger.error(f"[agent] Response content: {output_message}")
            raise RuntimeError('call llm fail, please check llm conf and network.') from e

        if not output_message or not output_message.content:
            logger.warning("[agent] LLM returned empty response")
            return output_message, AgentResult(
                current_state=AgentBrain(evaluation_previous_goal="", memory="", thought="", next_goal=""),
                actions=[ActionModel(agent_name=self.id(), tool_name='browser', action_name="stop")])

        if self.model_name == 'deepseek-reasoner':
            output_message.content = _remove_think_tags(output_message.content)

        try:
            # Get max retries from config; always make at least one parse attempt
            # so ``parsed_json`` is either bound or an error is raised.
            configured_retries = self.settings.get('max_llm_json_retries', 3)
            max_retries = max(1, configured_retries if configured_retries is not None else 3)
            retry_count = 0
            json_parse_error = None
            parsed_json = None

            while retry_count < max_retries:
                try:
                    # A retry may come back empty; count that as a parse failure
                    # instead of dereferencing ``None.content``.
                    if not output_message or not output_message.content:
                        raise ValueError("LLM returned empty response")
                    parsed_json = extract_json_from_model_output(output_message.content)
                    # If parsing succeeds, break out of the retry loop
                    json_parse_error = None
                    break
                except ValueError as e:
                    # Store the error and retry
                    json_parse_error = e
                    retry_count += 1
                    logger.warning(f"[agent] Failed to parse JSON (attempt {retry_count}/{max_retries}): {str(e)}")

                    if retry_count < max_retries:
                        # Add a reminder message about JSON format with specific structure guidance
                        format_reminder = HumanMessage(
                            content="Your responses must be always JSON with the specified format. Make sure your response includes a 'current_state' object with 'evaluation_previous_goal', 'memory', and 'next_goal' fields, and an 'action' array with the actions to perform. Do not include any explanatory text, only return the raw JSON.")
                        retry_messages = input_messages.copy()
                        retry_messages.append(format_reminder)

                        # Retry with the updated messages
                        logger.info(
                            f"[agent] Retrying LLM invocation ({retry_count}/{max_retries}) with format reminder")
                        output_message = self.llm.invoke(retry_messages)

                        # Check for empty response during retry
                        if not output_message or not output_message.content:
                            logger.warning(
                                f"[agent] LLM returned empty response on retry attempt {retry_count}/{max_retries}")
                            # Continue to next retry instead of immediately returning
                            continue

                        if self.model_name == 'deepseek-reasoner':
                            output_message.content = _remove_think_tags(output_message.content)

            # If all retries failed, raise the last error
            if json_parse_error:
                logger.error(f"[agent] ❌ All {max_retries} attempts to parse JSON failed")
                raise json_parse_error

            logger.info(f"llm response: {parsed_json}")
            try:
                agent_brain = AgentBrain(**parsed_json['current_state'])
            except Exception as brain_error:
                # A missing or malformed 'current_state' is tolerated; the
                # actions are still used with no brain state attached.
                logger.warning(f"[agent] Could not build AgentBrain from response: {brain_error}")
                agent_brain = None

            # The model may answer with either an 'action' or an 'actions' key.
            actions = parsed_json.get('action')
            result = []
            if not actions:
                actions = parsed_json.get("actions")
            if not actions:
                logger.warning("agent not policy an action.")
                self._finished = True
                return output_message, AgentResult(current_state=agent_brain,
                                                   actions=[ActionModel(tool_name='browser',
                                                                        agent_name=self.id(),
                                                                        action_name="done")])

            for action in actions:
                if "action_name" in action:
                    # Explicit form: {"action_name": ..., "params": {...}}
                    action_name = action['action_name']
                    browser_action = BrowserAction.get_value_by_name(action_name)
                    if not browser_action:
                        logger.warning(f"Unsupported action: {action_name}")
                    if action_name == "done":
                        self._finished = True
                    action_model = ActionModel(agent_name=self.id(),
                                               tool_name='browser',
                                               action_name=action_name,
                                               params=action.get('params', {}))
                    result.append(action_model)
                else:
                    # Compact form: {"<action_name>": {<params>}, ...}
                    for k, v in action.items():
                        browser_action = BrowserAction.get_value_by_name(k)
                        if not browser_action:
                            logger.warning(f"Unsupported action: {k}")

                        action_model = ActionModel(agent_name=self.id(), tool_name='browser', action_name=k, params=v)
                        result.append(action_model)
                        if k == "done":
                            self._finished = True
            return output_message, AgentResult(current_state=agent_brain, actions=result)
        except (ValueError, ValidationError) as e:
            logger.warning(f'Failed to parse model output: {output_message} {str(e)}')
            raise ValueError('Could not parse response.') from e

    def _convert_input_messages(self, input_messages: list[BaseMessage]) -> list[BaseMessage]:
        """Convert input messages to the correct format.

        Only ``deepseek-reasoner`` / ``deepseek-r1*`` models go through
        ``convert_input_messages``; other models get the list back unchanged.
        """
        if self.model_name == 'deepseek-reasoner' or self.model_name.startswith('deepseek-r1'):
            return convert_input_messages(input_messages, self.model_name)
        else:
            return input_messages

    def _make_history_item(self,
                           model_output: AgentResult | None,
                           state: Observation,
                           result: list[ActionResult],
                           metadata: Optional[PolicyMetadata] = None) -> None:
        """Append an ``AgentHistory`` entry for this step to ``self.state.history``.

        Args:
            model_output: Parsed LLM result for the step.
            state: Observation the step was based on; supplies the DOM tree
                text, the action results and the screenshot.
            result: Not read by this method (``state.action_result`` is used).
            metadata: Timing and token information for the step.
        """
        content = ""
        if hasattr(state, 'dom_tree') and state.dom_tree is not None:
            if hasattr(state.dom_tree, 'element_tree'):
                content = state.dom_tree.element_tree.__repr__()
            else:
                content = str(state.dom_tree)

        history_item = AgentHistory(model_output=model_output,
                                    result=state.action_result,
                                    metadata=metadata,
                                    content=content,
                                    base64_img=state.image if hasattr(state, 'image') else None)

        self.state.history.history.append(history_item)

    def _process_action_result(self, action_result, messages, tool_call=None):
        """Append a message describing ``action_result`` to ``messages``.

        Args:
            action_result: Result with ``content``, ``error`` and ``success``.
            messages: Message list that is extended in place.
            tool_call: Action entry that produced the result; only used in logs.

        Returns:
            True when ``action_result.error`` is not None.

        Raises:
            ValueError: If the result has no content, carries an error, and
                still reports ``success is True``.
        """
        if action_result.content is not None:
            messages.append(HumanMessage(content='Action result: ' + action_result.content))
        elif action_result.error is not None:
            # Assemble error message when error information exists
            messages.append(HumanMessage(content='Action result: ' + action_result.error))
            if tool_call is not None:
                logger.warning(f"Action {tool_call} failed: {action_result.error}")
            else:
                logger.warning(f"Action failed: {action_result.error}")
            # An error together with success=True is contradictory; treat the result as invalid
            if action_result.success is True:
                error_msg = f"Invalid result: success=True but error message exists: {action_result.error}"
                logger.error(error_msg)
                raise ValueError(error_msg)
        return action_result.error is not None

    def build_messages_from_trajectory_and_observation(self, observation: Optional[Observation] = None) -> List[
        BaseMessage]:
        """Build complete message history from trajectory and current observation.

        The prompt is: system message, optional action context, task message,
        an example tool call, optional file paths, one AI/tool message pair per
        recorded step (preceded by the results of the previous step's actions),
        the current observation, and a final notice once ``max_steps`` is reached.

        Args:
            observation: Current observation object, if None current observation won't be added

        Returns:
            The list of messages to send to the LLM.
        """
        messages = []
        # Add system message
        system_message = SystemPrompt(
            max_actions_per_step=self.settings.get('max_actions_per_step')
        ).get_system_message()
        if isinstance(system_message, tuple):
            system_message = system_message[0]
        messages.append(system_message)

        tool_calling_method = self.settings.get("tool_calling_method")
        llm_provider = self.conf.llm_provider if self.conf.llm_provider else self.conf.llm_config.llm_provider

        # NOTE(review): __init__ prefixes the provider with "chat", so the
        # deepseek comparisons below look like they can never match -- confirm
        # whether the model name was intended here.
        if tool_calling_method == 'raw' or (tool_calling_method == 'auto' and (
                llm_provider == 'deepseek-reasoner' or llm_provider.startswith('deepseek-r1'))):
            message_context = f'\n\nAvailable actions: {self.available_actions}'
        else:
            message_context = None

        # Add task context (if any)
        if message_context:
            context_message = HumanMessage(content='Context for the task' + message_context)
            messages.append(context_message)

        # Add task message
        task_message = HumanMessage(
            content=f'Your ultimate task is: """{self.task}""". If you achieved your ultimate task, stop everything and use the done action in the next step to complete the task. If not, continue as usual.'
        )
        messages.append(task_message)

        # Add example output
        placeholder_message = HumanMessage(content='Example output:')
        messages.append(placeholder_message)

        # Add example tool call (it owns tool call id '1')
        tool_calls = [
            {
                'name': 'AgentOutput',
                'args': {
                    'current_state': {
                        'evaluation_previous_goal': 'Success - I opend the first page',
                        'memory': 'Starting with the new task. I have completed 1/10 steps',
                        'thought': 'From the current page I can get information about all the companies.',
                        'next_goal': 'Click on company a',
                    },
                    'action': [{'click_element': {'index': 0}}],
                },
                'id': '1',
                'type': 'tool_call',
            }
        ]
        example_tool_call = AIMessage(
            content='',
            tool_calls=tool_calls,
        )
        messages.append(example_tool_call)

        # Add first tool message with "Browser started" content
        messages.append(ToolMessage(content='Browser started', tool_call_id='1'))

        # Add task history marker
        messages.append(HumanMessage(content='[Your task history memory starts here]'))

        # Add available file paths (if any)
        if self.settings.get('available_file_paths'):
            filepaths_msg = HumanMessage(
                content=f'Here are file paths you can use: {self.settings.get("available_file_paths")}')
            messages.append(filepaths_msg)

        previous_action_entries = []

        # Add messages from the history trajectory. Tool call ids start at 2
        # because '1' belongs to the example above; numbering per step keeps
        # every id unique and stable from one policy call to the next.
        for tool_id, (_input_msgs, obs, _info, _output_msg, llm_result) in enumerate(
                self.trajectory.get_history(), start=2):
            # Check the previous step's actionResult
            has_error = False
            if obs.action_result is not None:
                # The previous action entries should match with action results
                if len(previous_action_entries) == 0:
                    # No previous actions to pair the results with: only log it
                    logger.info(
                        f"History item with action_result count ({len(obs.action_result)}) with empty previous actions - skipping count check")
                elif len(previous_action_entries) == len(obs.action_result):
                    for i, one_action_result in enumerate(obs.action_result):
                        has_error = self._process_action_result(one_action_result, messages,
                                                                previous_action_entries[i]) or has_error
                else:
                    # Sizes don't match: logged as an error, the results are not added
                    error_msg = f"Action results count ({len(obs.action_result)}) doesn't match action entries count ({len(previous_action_entries)})"
                    logger.error(error_msg)
                    has_error = True

            # Add agent response
            if llm_result:
                # Create AI message; the 'actions' field is re-keyed as 'action'
                # in the compact {name: params} form shown in the example.
                output_data = llm_result.model_dump(mode='json', exclude_unset=True)
                action_entries = [{action.action_name: action.params} for action in llm_result.actions]
                output_data["action"] = action_entries
                if "actions" in output_data:
                    del output_data["actions"]

                tool_calls = [
                    {
                        'name': 'AgentOutput',
                        'args': output_data,
                        'id': str(tool_id),
                        'type': 'tool_call',
                    }
                ]
                previous_action_entries = action_entries
                ai_message = AIMessage(
                    content='',
                    tool_calls=tool_calls,
                )
                messages.append(ai_message)

                # Add empty tool message after each AIMessage
                messages.append(ToolMessage(content='', tool_call_id=str(tool_id)))

        # Add current observation
        if observation:
            # Check if the current observation has an action_result with error
            has_error = False
            if hasattr(observation, 'action_result') and observation.action_result is not None:
                # Match action results with previous actions
                if len(previous_action_entries) == 0:
                    # No previous actions to pair the results with: only log it
                    logger.info(
                        f"Current observation with action_result count ({len(observation.action_result)}) with empty previous actions - skipping count check")
                elif len(previous_action_entries) == len(observation.action_result):
                    for i, one_action_result in enumerate(observation.action_result):
                        has_error = self._process_action_result(one_action_result, messages,
                                                                previous_action_entries[i]) or has_error
                else:
                    # Sizes don't match: logged as an error, the results are not added
                    error_msg = f"Action results count ({len(observation.action_result)}) doesn't match action entries count ({len(previous_action_entries)})"
                    logger.error(error_msg)
                    has_error = True

            # If there's an error, only the raw observation content is appended
            if has_error and observation.content:
                messages.append(HumanMessage(content=observation.content))
            # If no error, process the observation normally
            elif not has_error:
                step_info = AgentStepInfo(number=self.state.n_steps, max_steps=self.conf.max_steps)
                if hasattr(observation, 'dom_tree') and observation.dom_tree:
                    state_message = AgentMessagePrompt(
                        observation,
                        self.state.last_result,
                        include_attributes=self.settings.get('include_attributes'),
                        step_info=step_info,
                    ).get_user_message(self.settings.get('use_vision'))
                    messages.append(state_message)
                elif observation.content:
                    messages.append(HumanMessage(content=observation.content))

        # Add special message for the last step
        if self.conf.max_steps <= self.state.n_steps:
            last_step_message = (
                '\nNow comes your last step. Use only the "done" action now. No other actions - so here your action sequence must have length 1.\n'
                '\nIf the task is not yet fully finished as requested by the user, set success in "done" to false! E.g. if not all steps are fully completed.\n'
                '\nIf the task is fully finished, set success in "done" to true.\n'
                '\nInclude everything you found out for the ultimate task in the done text.\n'
            )
            messages.append(HumanMessage(content=[{'type': 'text', 'text': last_step_message}]))

        return messages

    def _estimate_tokens_for_messages(self, messages: List[BaseMessage]) -> int:
        """Roughly estimate token count for message list.

        Delegates to ``estimate_messages_tokens`` with the ``image_tokens``
        (default 800) and ``estimated_characters_per_token`` (default 3) settings.
        """
        return estimate_messages_tokens(
            messages,
            image_tokens=self.settings.get('image_tokens', 800),
            estimated_characters_per_token=self.settings.get('estimated_characters_per_token', 3)
        )