# -*- coding: utf-8 -*-
# @Time    : 2025/1/2
# @Author  : wenshao
# @ProjectName: browser-use-webui
# @FileName: custom_agent.py
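"""Custom browser-use agent.

Extends ``browser_use.agent.service.Agent`` with extra task context (``add_infos``),
a custom message manager, and per-step memory / task-progress tracking.
"""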
import asyncio
import base64
import io
import json
import logging
import os
import textwrap
import time
import uuid
from io import BytesIO
from pathlib import Path
from typing import Any, Optional, Type, TypeVar

from dotenv import load_dotenv
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
    BaseMessage,
    SystemMessage,
)
from openai import RateLimitError
from PIL import Image, ImageDraw, ImageFont
from pydantic import BaseModel, ValidationError

from browser_use.agent.message_manager.service import MessageManager
from browser_use.agent.prompts import AgentMessagePrompt, SystemPrompt
from browser_use.agent.service import Agent
from browser_use.agent.views import (
    ActionResult,
    AgentError,
    AgentHistory,
    AgentHistoryList,
    AgentOutput,
    AgentStepInfo,
)
from browser_use.browser.browser import Browser
from browser_use.browser.context import BrowserContext
from browser_use.browser.views import BrowserState, BrowserStateHistory
from browser_use.controller.registry.views import ActionModel
from browser_use.controller.service import Controller
from browser_use.dom.history_tree_processor.service import (
    DOMHistoryElement,
    HistoryTreeProcessor,
)
from browser_use.telemetry.service import ProductTelemetry
from browser_use.telemetry.views import (
    AgentEndTelemetryEvent,
    AgentRunTelemetryEvent,
    AgentStepErrorTelemetryEvent,
)
from browser_use.utils import time_execution_async

from .custom_views import CustomAgentOutput, CustomAgentStepInfo
from .custom_massage_manager import CustomMassageManager

logger = logging.getLogger(__name__)

class CustomAgent(Agent):
    def __init__(
        self,
        task: str,
        llm: BaseChatModel,
        add_infos: str = '',
        browser: Browser | None = None,
        browser_context: BrowserContext | None = None,
        controller: Controller = Controller(),
        use_vision: bool = True,
        save_conversation_path: Optional[str] = None,
        max_failures: int = 5,
        retry_delay: int = 10,
        system_prompt_class: Type[SystemPrompt] = SystemPrompt,
        max_input_tokens: int = 128000,
        validate_output: bool = False,
        include_attributes: list[str] = [
            'title',
            'type',
            'name',
            'role',
            'tabindex',
            'aria-label',
            'placeholder',
            'value',
            'alt',
            'aria-expanded',
        ],
        max_error_length: int = 400,
        max_actions_per_step: int = 10,
    ):
        super().__init__(task, llm, browser, browser_context, controller, use_vision, save_conversation_path,
                         max_failures, retry_delay, system_prompt_class, max_input_tokens, validate_output,
                         include_attributes, max_error_length, max_actions_per_step)
        self.add_infos = add_infos
        # Swap in the custom message manager (configured like the base one) so the
        # extra task context and step info can be woven into the prompts.
        self.message_manager = CustomMassageManager(
            llm=self.llm,
            task=self.task,
            action_descriptions=self.controller.registry.get_prompt_description(),
            system_prompt_class=self.system_prompt_class,
            max_input_tokens=self.max_input_tokens,
            include_attributes=self.include_attributes,
            max_error_length=self.max_error_length,
            max_actions_per_step=self.max_actions_per_step,
        )

    def _setup_action_models(self) -> None:
        """Setup dynamic action models from controller's registry"""
        # Get the dynamic action model from controller's registry
        self.ActionModel = self.controller.registry.create_action_model()
        # Create output model with the dynamic actions
        self.AgentOutput = CustomAgentOutput.type_with_custom_actions(self.ActionModel)

    def _log_response(self, response: CustomAgentOutput) -> None:
        """Log the model's response"""
        if 'Success' in response.current_state.prev_action_evaluation:
            emoji = '✅'
        elif 'Failed' in response.current_state.prev_action_evaluation:
            emoji = '❌'
        else:
            emoji = '🤷'

        logger.info(f'{emoji} Eval: {response.current_state.prev_action_evaluation}')
        logger.info(f'🧠 New Memory: {response.current_state.important_contents}')
        logger.info(f'⏳ Task Progress: {response.current_state.completed_contents}')
        logger.info(f'🤔 Thought: {response.current_state.thought}')
        logger.info(f'🎯 Summary: {response.current_state.summary}')
        for i, action in enumerate(response.action):
            logger.info(
                f'🛠️ Action {i + 1}/{len(response.action)}: {action.model_dump_json(exclude_unset=True)}'
            )

    def update_step_info(self, model_output: CustomAgentOutput, step_info: Optional[CustomAgentStepInfo] = None):
        """Accumulate the model's new memory and task progress into the running step info."""
        if step_info is None:
            return

        step_info.step_number += 1
        important_contents = model_output.current_state.important_contents
        if important_contents and 'None' not in important_contents and important_contents not in step_info.memory:
            step_info.memory += important_contents + '\n'

        completed_contents = model_output.current_state.completed_contents
        if completed_contents and 'None' not in completed_contents:
            step_info.task_progress = completed_contents

    async def get_next_action(self, input_messages: list[BaseMessage]) -> AgentOutput:
        """Get next action from LLM based on current state"""
        # Use the async LLM API so the running event loop is not blocked while waiting.
        ret = await self.llm.ainvoke(input_messages)
        # The model may wrap its JSON reply in markdown code fences; strip them before parsing.
        parsed_json = json.loads(ret.content.replace('```json', '').replace('```', ''))
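        # Illustrative only: the parsed JSON is expected to match the CustomAgentOutput
        # schema implied by _log_response, roughly
        #   {
        #     "current_state": {
        #       "prev_action_evaluation": "Success - ...",
        #       "important_contents": "...",
        #       "completed_contents": "...",
        #       "thought": "...",
        #       "summary": "..."
        #     },
        #     "action": [{"click_element": {"index": 3}}]
        #   }
        # where the action payloads depend on the controller's registered actions.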
        parsed: AgentOutput = self.AgentOutput(**parsed_json)
        # cut the number of actions to max_actions_per_step
        parsed.action = parsed.action[: self.max_actions_per_step]
        self._log_response(parsed)
        self.n_steps += 1

        return parsed

    async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None:
        """Execute one step of the task"""
        logger.info(f'\n📍 Step {self.n_steps}')
        state = None
        model_output = None
        result: list[ActionResult] = []

        try:
            state = await self.browser_context.get_state(use_vision=self.use_vision)
            self.message_manager.add_state_message(state, self._last_result, step_info)
            input_messages = self.message_manager.get_messages()
            model_output = await self.get_next_action(input_messages)
            self.update_step_info(model_output, step_info)
            logger.info(f'🧠 All Memory: {step_info.memory}')
            self._save_conversation(input_messages, model_output)
            # we don't want the whole state in the chat history
            self.message_manager._remove_last_state_message()
            self.message_manager.add_model_output(model_output)

            result: list[ActionResult] = await self.controller.multi_act(
                model_output.action, self.browser_context
            )
            self._last_result = result

            if len(result) > 0 and result[-1].is_done:
                logger.info(f'📄 Result: {result[-1].extracted_content}')

            self.consecutive_failures = 0

        except Exception as e:
            result = self._handle_step_error(e)
            self._last_result = result

        finally:
            if not result:
                return
            for r in result:
                if r.error:
                    self.telemetry.capture(
                        AgentStepErrorTelemetryEvent(
                            agent_id=self.agent_id,
                            error=r.error,
                        )
                    )
            if state:
                self._make_history_item(model_output, state, result)

    async def run(self, max_steps: int = 100) -> AgentHistoryList:
        """Execute the task with maximum number of steps"""
        try:
            logger.info(f'🚀 Starting task: {self.task}')

            self.telemetry.capture(
                AgentRunTelemetryEvent(
                    agent_id=self.agent_id,
                    task=self.task,
                )
            )

            step_info = CustomAgentStepInfo(
                task=self.task,
                add_infos=self.add_infos,
                step_number=1,
                max_steps=max_steps,
                memory='',
                task_progress='',
            )

            for step in range(max_steps):
                if self._too_many_failures():
                    break

                await self.step(step_info)

                if self.history.is_done():
                    if (
                        self.validate_output and step < max_steps - 1
                    ):  # if last step, we don't need to validate
                        if not await self._validate_output():
                            continue

                    logger.info('✅ Task completed successfully')
                    break
            else:
                # for-else: only reached when the loop finished without a break
                logger.info('❌ Failed to complete task in maximum steps')

            return self.history

        finally:
            self.telemetry.capture(
                AgentEndTelemetryEvent(
                    agent_id=self.agent_id,
                    task=self.task,
                    success=self.history.is_done(),
                    steps=len(self.history.history),
                )
            )
            if not self.injected_browser_context:
                await self.browser_context.close()

            if not self.injected_browser and self.browser:
                await self.browser.close()
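

if __name__ == '__main__':
    # Minimal usage sketch, not part of the agent implementation above. It assumes
    # `langchain_openai` is installed and OPENAI_API_KEY is set in the environment;
    # the model name, task, and add_infos strings are placeholders. With no browser
    # or context passed in, the agent creates and closes its own in run().
    from langchain_openai import ChatOpenAI

    async def _demo():
        agent = CustomAgent(
            task='Open https://example.com and report the page title',
            add_infos='If a cookie banner appears, dismiss it first.',
            llm=ChatOpenAI(model='gpt-4o'),
        )
        history = await agent.run(max_steps=10)
        # final_result() comes from browser_use's AgentHistoryList
        print(history.final_result())

    asyncio.run(_demo())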