# coding: utf-8 import json import traceback import uuid from dataclasses import dataclass, field from pathlib import Path from typing import Any, Optional, Dict, List from langchain_core.load import dumpd, load from langchain_core.messages import BaseMessage, AIMessage, ToolMessage, SystemMessage, HumanMessage from openai import RateLimitError from pydantic import BaseModel, ConfigDict, Field, model_serializer, model_validator from aworld.core.agent.base import AgentResult from aworld.core.common import ActionResult, Observation class MessageMetadata(BaseModel): """Metadata for a message""" tokens: int = 0 class ManagedMessage(BaseModel): """A message with its metadata""" message: BaseMessage metadata: MessageMetadata = Field(default_factory=MessageMetadata) model_config = ConfigDict(arbitrary_types_allowed=True) # https://github.com/pydantic/pydantic/discussions/7558 @model_serializer(mode='wrap') def to_json(self, original_dump): """ Returns the JSON representation of the model. It uses langchain's `dumps` function to serialize the `message` property before encoding the overall dict with json.dumps. """ data = original_dump(self) # NOTE: We override the message field to use langchain JSON serialization. data['message'] = dumpd(self.message) return data @model_validator(mode='before') @classmethod def validate( cls, value: Any, *, strict: bool | None = None, from_attributes: bool | None = None, context: Any | None = None, ) -> Any: """ Custom validator that uses langchain's `loads` function to parse the message if it is provided as a JSON string. """ if isinstance(value, dict) and 'message' in value: # NOTE: We use langchain's load to convert the JSON string back into a BaseMessage object. value['message'] = load(value['message']) return value class MessageHistory(BaseModel): """History of messages with metadata""" messages: list[ManagedMessage] = Field(default_factory=list) current_tokens: int = 0 model_config = ConfigDict(arbitrary_types_allowed=True) def add_message(self, message: BaseMessage, metadata: MessageMetadata, position: int | None = None) -> None: """Add message with metadata to history""" if position is None: self.messages.append(ManagedMessage(message=message, metadata=metadata)) else: self.messages.insert(position, ManagedMessage(message=message, metadata=metadata)) self.current_tokens += metadata.tokens def add_model_output(self, output) -> None: """Add model output as AI message""" tool_calls = [ { 'name': 'AgentOutput', 'args': output.model_dump(mode='json', exclude_unset=True), 'id': '1', 'type': 'tool_call', } ] msg = AIMessage( content='', tool_calls=tool_calls, ) self.add_message(msg, MessageMetadata(tokens=100)) # Estimate tokens for tool calls # Empty tool response tool_message = ToolMessage(content='', tool_call_id='1') self.add_message(tool_message, MessageMetadata(tokens=10)) # Estimate tokens for empty response def get_messages(self) -> list[BaseMessage]: """Get all messages""" return [m.message for m in self.messages] def get_total_tokens(self) -> int: """Get total tokens in history""" return self.current_tokens def remove_oldest_message(self) -> None: """Remove oldest non-system message""" for i, msg in enumerate(self.messages): if not isinstance(msg.message, SystemMessage): self.current_tokens -= msg.metadata.tokens self.messages.pop(i) break def remove_last_state_message(self) -> None: """Remove last state message from history""" if len(self.messages) > 2 and isinstance(self.messages[-1].message, HumanMessage): self.current_tokens -= self.messages[-1].metadata.tokens self.messages.pop() class MessageManagerState(BaseModel): """Holds the state for MessageManager""" history: MessageHistory = Field(default_factory=MessageHistory) tool_id: int = 1 model_config = ConfigDict(arbitrary_types_allowed=True) class AgentSettings(BaseModel): """Options for the agent""" max_failures: int = 3 retry_delay: int = 10 save_history: bool = True history_path: Optional[str] = None max_actions_per_step: int = 10 validate_output: bool = False message_context: Optional[str] = None class PolicyMetadata(BaseModel): """Metadata for a single step including timing information""" start_time: float end_time: float number: int input_tokens: int @property def duration_seconds(self) -> float: """Calculate step duration in seconds""" return self.end_time - self.start_time class AgentBrain(BaseModel): """Current state of the agent""" evaluation_previous_goal: str memory: str next_goal: str class AgentHistory(BaseModel): """History item for agent actions""" model_output: Optional[BaseModel] = None result: List[ActionResult] metadata: Optional[PolicyMetadata] = None content: Optional[str] = None base64_img: Optional[str] = None model_config = ConfigDict(arbitrary_types_allowed=True) def model_dump(self, **kwargs) -> Dict[str, Any]: """Custom serialization handling""" return { 'model_output': self.model_output.model_dump() if self.model_output else None, 'result': [r.model_dump(exclude_none=True) for r in self.result], 'metadata': self.metadata.model_dump() if self.metadata else None, 'content': self.xml_content, 'base64_img': self.base64_img } class AgentHistoryList(BaseModel): """List of agent history items""" history: List[AgentHistory] def total_duration_seconds(self) -> float: """Get total duration of all steps in seconds""" total = 0.0 for h in self.history: if h.metadata: total += h.metadata.duration_seconds return total def save_to_file(self, filepath: str | Path) -> None: """Save history to JSON file with proper serialization""" try: Path(filepath).parent.mkdir(parents=True, exist_ok=True) data = self.model_dump() with open(filepath, 'w', encoding='utf-8') as f: json.dump(data, f, indent=2) except Exception as e: raise e def model_dump(self, **kwargs) -> Dict[str, Any]: """Custom serialization that properly uses AgentHistory's model_dump""" return { 'history': [h.model_dump(**kwargs) for h in self.history], } @classmethod def load_from_file(cls, filepath: str | Path) -> 'AgentHistoryList': """Load history from JSON file""" with open(filepath, 'r', encoding='utf-8') as f: data = json.load(f) return cls.model_validate(data) class AgentError: """Container for agent error handling""" VALIDATION_ERROR = 'Invalid model output format. Please follow the correct schema.' RATE_LIMIT_ERROR = 'Rate limit reached. Waiting before retry.' NO_VALID_ACTION = 'No valid action found' @staticmethod def format_error(error: Exception, include_trace: bool = False) -> str: """Format error message based on error type and optionally include trace""" if isinstance(error, RateLimitError): return AgentError.RATE_LIMIT_ERROR if include_trace: return f'{str(error)}\nStacktrace:\n{traceback.format_exc()}' return f'{str(error)}' class AgentState(BaseModel): """Holds all state information for an Agent""" agent_id: str = Field(default_factory=lambda: str(uuid.uuid4())) n_steps: int = 1 consecutive_failures: int = 0 last_result: Optional[List['ActionResult']] = None history: AgentHistoryList = Field(default_factory=lambda: AgentHistoryList(history=[])) last_plan: Optional[str] = None paused: bool = False stopped: bool = False message_manager_state: MessageManagerState = Field(default_factory=MessageManagerState) @dataclass class AgentStepInfo: number: int max_steps: int def is_last_step(self) -> bool: """Check if this is the last step""" return self.number >= self.max_steps - 1 @dataclass class Trajectory: """Stores the agent's history, including all observations, info, and AgentResults.""" history: List[tuple[Observation, Dict[str, Any], AgentResult]] = field(default_factory=list) def add_step(self, observation: Observation, info: Dict[str, Any], agent_result: AgentResult): """Add a step to the history""" self.history.append((observation, info, agent_result)) def get_history(self) -> List[tuple[Observation, Dict[str, Any], AgentResult]]: """Retrieve the complete history""" return self.history