Spaces:
Sleeping
Sleeping
File size: 9,341 Bytes
9afc9c8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 |
# coding: utf-8
import json
import traceback
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Optional, Dict, List
from langchain_core.load import dumpd, load
from langchain_core.messages import BaseMessage, AIMessage, ToolMessage, SystemMessage, HumanMessage
from openai import RateLimitError
from pydantic import BaseModel, ConfigDict, Field, model_serializer, model_validator
from aworld.core.agent.base import AgentResult
from aworld.core.common import ActionResult, Observation
class MessageMetadata(BaseModel):
    """Bookkeeping attached to a single message (currently just a token count)."""

    # Token count attributed to the message; summed into MessageHistory.current_tokens.
    tokens: int = 0
class ManagedMessage(BaseModel):
    """A langchain message paired with its MessageMetadata.

    The custom (de)serialization hooks below keep the wrapped BaseMessage
    round-trippable through plain JSON via langchain's `dumpd`/`load`.
    """
    # The wrapped langchain message.
    message: BaseMessage
    # Bookkeeping for the message (token count).
    metadata: MessageMetadata = Field(default_factory=MessageMetadata)
    # Permit field types pydantic cannot natively validate (BaseMessage).
    model_config = ConfigDict(arbitrary_types_allowed=True)
    # https://github.com/pydantic/pydantic/discussions/7558
    @model_serializer(mode='wrap')
    def to_json(self, original_dump):
        """
        Returns the JSON-compatible dict representation of the model.

        Runs pydantic's default serialization first, then replaces the
        `message` entry with langchain's `dumpd` output, since BaseMessage
        is not serialized correctly by pydantic alone.
        """
        data = original_dump(self)
        # NOTE: We override the message field to use langchain JSON serialization.
        data['message'] = dumpd(self.message)
        return data
    @model_validator(mode='before')
    @classmethod
    def validate(
        cls,
        value: Any,
        *,
        strict: bool | None = None,
        from_attributes: bool | None = None,
        context: Any | None = None,
    ) -> Any:
        """
        Custom pre-validator that uses langchain's `load` function to
        reconstruct the `message` field from its serialized form.

        The keyword-only parameters mirror pydantic's validate signature
        (see the discussion linked above); they are accepted but unused.
        """
        if isinstance(value, dict) and 'message' in value:
            # NOTE: We use langchain's load to convert the serialized payload back into a BaseMessage object.
            value['message'] = load(value['message'])
        return value
class MessageHistory(BaseModel):
    """Ordered collection of managed messages plus a running token total."""

    # Messages in conversation order, each carrying its metadata.
    messages: list[ManagedMessage] = Field(default_factory=list)
    # Running sum of metadata.tokens over all stored messages.
    current_tokens: int = 0

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def add_message(self, message: BaseMessage, metadata: MessageMetadata, position: int | None = None) -> None:
        """Append (or insert at `position`) a message and update the token total."""
        entry = ManagedMessage(message=message, metadata=metadata)
        if position is None:
            self.messages.append(entry)
        else:
            self.messages.insert(position, entry)
        self.current_tokens += metadata.tokens

    def add_model_output(self, output) -> None:
        """Record model output as an AI tool-call message plus an empty tool reply."""
        call = {
            'name': 'AgentOutput',
            'args': output.model_dump(mode='json', exclude_unset=True),
            'id': '1',
            'type': 'tool_call',
        }
        ai_msg = AIMessage(content='', tool_calls=[call])
        self.add_message(ai_msg, MessageMetadata(tokens=100))  # Estimate tokens for tool calls
        # Empty tool response
        reply = ToolMessage(content='', tool_call_id='1')
        self.add_message(reply, MessageMetadata(tokens=10))  # Estimate tokens for empty response

    def get_messages(self) -> list[BaseMessage]:
        """Return the bare BaseMessage objects, in order."""
        return [entry.message for entry in self.messages]

    def get_total_tokens(self) -> int:
        """Return the running token total for the whole history."""
        return self.current_tokens

    def remove_oldest_message(self) -> None:
        """Remove the earliest message that is not a SystemMessage."""
        for index, entry in enumerate(self.messages):
            if isinstance(entry.message, SystemMessage):
                continue
            self.current_tokens -= entry.metadata.tokens
            del self.messages[index]
            break

    def remove_last_state_message(self) -> None:
        """Drop the trailing HumanMessage (latest state), if one is present."""
        # Only prune when more than two messages exist, preserving the initial context.
        if len(self.messages) > 2 and isinstance(self.messages[-1].message, HumanMessage):
            self.current_tokens -= self.messages[-1].metadata.tokens
            self.messages.pop()
class MessageManagerState(BaseModel):
    """State container backing a MessageManager instance."""

    # The conversation history being managed.
    history: MessageHistory = Field(default_factory=MessageHistory)
    # NOTE(review): presumably the next tool-call id to hand out — confirm with MessageManager.
    tool_id: int = 1

    model_config = ConfigDict(arbitrary_types_allowed=True)
class AgentSettings(BaseModel):
    """Tunable options controlling agent execution behavior."""

    max_failures: int = 3
    retry_delay: int = 10  # NOTE(review): unit looks like seconds — confirm at call site
    save_history: bool = True
    history_path: Optional[str] = None  # None -> caller chooses where to save
    max_actions_per_step: int = 10
    validate_output: bool = False
    message_context: Optional[str] = None
class PolicyMetadata(BaseModel):
    """Timing and token accounting for a single step."""

    start_time: float
    end_time: float
    number: int
    input_tokens: int

    @property
    def duration_seconds(self) -> float:
        """Elapsed time for the step (end_time - start_time), in seconds."""
        return self.end_time - self.start_time
class AgentBrain(BaseModel):
    """Snapshot of the agent's current reasoning state."""

    # Assessment of whether the previous goal was achieved.
    evaluation_previous_goal: str
    # Running memory/notes the agent carries forward.
    memory: str
    # The goal the agent intends to pursue next.
    next_goal: str
class AgentHistory(BaseModel):
    """History item for a single agent step.

    Attributes:
        model_output: Parsed model output for this step, if any.
        result: Action results produced during the step.
        metadata: Timing/token metadata for the step, if recorded.
        content: Text content associated with the step, if any.
        base64_img: Base64-encoded image captured for the step, if any.
    """
    model_output: Optional[BaseModel] = None
    result: List[ActionResult]
    metadata: Optional[PolicyMetadata] = None
    content: Optional[str] = None
    base64_img: Optional[str] = None

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def model_dump(self, **kwargs) -> Dict[str, Any]:
        """Custom serialization: dump nested models recursively, keep None entries.

        Returns:
            A plain dict mirroring the declared fields.
        """
        return {
            'model_output': self.model_output.model_dump() if self.model_output else None,
            'result': [r.model_dump(exclude_none=True) for r in self.result],
            'metadata': self.metadata.model_dump() if self.metadata else None,
            # BUG FIX: previously read `self.xml_content`, which is not a field on
            # this model and raised AttributeError on every model_dump() call.
            'content': self.content,
            'base64_img': self.base64_img
        }
class AgentHistoryList(BaseModel):
    """List of agent history items with JSON persistence helpers."""

    history: List[AgentHistory]

    def total_duration_seconds(self) -> float:
        """Sum the recorded duration of every step that has metadata."""
        return sum(h.metadata.duration_seconds for h in self.history if h.metadata)

    def save_to_file(self, filepath: str | Path) -> None:
        """Save history to a JSON file, creating parent directories as needed.

        Raises:
            OSError: if the directory or file cannot be written.
            TypeError: if the serialized data is not JSON-encodable.
        """
        # FIX: dropped the former `except Exception as e: raise e` wrapper — it
        # was a no-op that only added a frame to the traceback; errors now
        # propagate directly to the caller.
        path = Path(filepath)
        path.parent.mkdir(parents=True, exist_ok=True)
        data = self.model_dump()
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2)

    def model_dump(self, **kwargs) -> Dict[str, Any]:
        """Custom serialization that properly uses AgentHistory's model_dump."""
        return {
            'history': [h.model_dump(**kwargs) for h in self.history],
        }

    @classmethod
    def load_from_file(cls, filepath: str | Path) -> 'AgentHistoryList':
        """Load and validate a history previously written by `save_to_file`."""
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return cls.model_validate(data)
class AgentError:
    """Container for agent error handling"""

    # Canned, user-facing error messages.
    VALIDATION_ERROR = 'Invalid model output format. Please follow the correct schema.'
    RATE_LIMIT_ERROR = 'Rate limit reached. Waiting before retry.'
    NO_VALID_ACTION = 'No valid action found'

    @staticmethod
    def format_error(error: Exception, include_trace: bool = False) -> str:
        """Render an exception as a user-facing string.

        Rate-limit errors map to the canned RATE_LIMIT_ERROR message; any other
        exception is stringified, optionally followed by the active stack trace.
        """
        if isinstance(error, RateLimitError):
            return AgentError.RATE_LIMIT_ERROR
        message = str(error)
        if include_trace:
            message = f'{message}\nStacktrace:\n{traceback.format_exc()}'
        return message
class AgentState(BaseModel):
    """Aggregate runtime state for a single Agent instance."""

    # Unique identifier, freshly generated for each agent.
    agent_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    # Step counter, starting at 1.
    n_steps: int = 1
    # Failures in a row; compared against AgentSettings.max_failures elsewhere.
    consecutive_failures: int = 0
    last_result: Optional[List['ActionResult']] = None
    history: AgentHistoryList = Field(default_factory=lambda: AgentHistoryList(history=[]))
    last_plan: Optional[str] = None
    # Pause/stop flags toggled externally to control the run loop.
    paused: bool = False
    stopped: bool = False
    message_manager_state: MessageManagerState = Field(default_factory=MessageManagerState)
@dataclass
class AgentStepInfo:
    """Position of the current step within a bounded run."""

    number: int     # zero-based index of the current step
    max_steps: int  # total number of steps allowed

    def is_last_step(self) -> bool:
        """True when no further steps remain after this one."""
        return self.max_steps - self.number <= 1
@dataclass
class Trajectory:
    """Stores the agent's history, including all observations, info, and AgentResults."""

    # One tuple per step: (observation, info dict, agent result), in order.
    history: List[tuple[Observation, Dict[str, Any], AgentResult]] = field(default_factory=list)

    def add_step(self, observation: Observation, info: Dict[str, Any], agent_result: AgentResult):
        """Record one (observation, info, agent_result) step at the end of the history."""
        self.history.append((observation, info, agent_result))

    def get_history(self) -> List[tuple[Observation, Dict[str, Any], AgentResult]]:
        """Return the full list of recorded steps."""
        return self.history
|