File size: 4,162 Bytes
a27d8ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# coding: utf-8

import json
import traceback
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional, Dict, List

from openai import RateLimitError
from pydantic import BaseModel, ConfigDict, Field

from aworld.core.common import ActionResult


class PolicyMetadata(BaseModel):
    """Bookkeeping recorded for one step: its number, input token count and timestamps."""
    # Timestamps bracketing the step; presumably wall-clock seconds (e.g. time.time()) — confirm at call site.
    start_time: float
    end_time: float
    # Step number as supplied by the caller.
    number: int
    # Number of input tokens attributed to this step.
    input_tokens: int

    @property
    def duration_seconds(self) -> float:
        """Elapsed time of the step, computed as ``end_time - start_time``."""
        elapsed = self.end_time - self.start_time
        return elapsed


class AgentBrain(BaseModel):
    """Snapshot of the agent's reasoning state for one step.

    Every field is optional text and defaults to None. The fields are typed
    ``Optional[str]`` so an explicit ``None`` (e.g. a JSON ``null``) validates.
    With a bare ``str`` annotation Pydantic v2 would reject it, because it does
    not validate defaults but does validate supplied values.
    """
    # Assessment of whether the previous goal was achieved.
    evaluation_previous_goal: Optional[str] = None
    # Free-form notes carried between steps.
    memory: Optional[str] = None
    # Reasoning text for the current step.
    thought: Optional[str] = None
    # Goal the agent intends to pursue next.
    next_goal: Optional[str] = None


class AgentHistory(BaseModel):
    """History item recording one agent step.

    Holds the model output for the step, the list of action results, optional
    timing metadata, and optional text content / base64-encoded image payloads.
    """
    # Parsed model output for the step, or None when absent.
    model_output: Optional[BaseModel] = None
    # Results of the actions executed in this step (required).
    result: List[ActionResult]
    # Timing/token metadata for the step, or None when not recorded.
    metadata: Optional[PolicyMetadata] = None
    # Optional text content captured for the step.
    content: Optional[str] = None
    # Optional image payload, stored as a base64 string.
    base64_img: Optional[str] = None

    # Allow field annotations that are not Pydantic-native types.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    def model_dump(self, **kwargs) -> Dict[str, Any]:
        """Serialize this item to a plain dict.

        ``kwargs`` are accepted for signature compatibility with
        ``BaseModel.model_dump`` but are ignored. Nested models use their own
        ``model_dump``; action results are dumped with ``exclude_none=True``.

        Returns:
            Dict with keys ``model_output``, ``result``, ``metadata``,
            ``content`` and ``base64_img``.
        """
        return {
            'model_output': self.model_output.model_dump() if self.model_output is not None else None,
            'result': [r.model_dump(exclude_none=True) for r in self.result],
            'metadata': self.metadata.model_dump() if self.metadata is not None else None,
            # Serialize the declared ``content`` field; the model has no ``xml_content``
            # attribute, and reading one raised AttributeError on every dump.
            'content': self.content,
            'base64_img': self.base64_img,
        }


class AgentHistoryList(BaseModel):
    """Ordered list of AgentHistory items with JSON save/load helpers."""
    history: List[AgentHistory]

    def total_duration_seconds(self) -> float:
        """Return the summed ``duration_seconds`` of all steps.

        Steps whose ``metadata`` is None contribute nothing; an empty history
        yields 0.0.
        """
        total = 0.0
        for h in self.history:
            if h.metadata:
                total += h.metadata.duration_seconds
        return total

    def save_to_file(self, filepath: str | Path) -> None:
        """Write the history to ``filepath`` as indented JSON encoded in UTF-8.

        Missing parent directories are created. Any exception raised while
        creating directories, serializing, or writing propagates to the caller
        unchanged.
        """
        path = Path(filepath)
        path.parent.mkdir(parents=True, exist_ok=True)
        # Serialize before opening the file so a serialization error does not
        # leave an existing file truncated.
        data = self.model_dump()
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2)

    def model_dump(self, **kwargs) -> Dict[str, Any]:
        """Serialize to ``{'history': [...]}`` using each item's custom ``model_dump``.

        ``kwargs`` are forwarded to ``AgentHistory.model_dump``.
        """
        return {
            'history': [h.model_dump(**kwargs) for h in self.history],
        }

    @classmethod
    def load_from_file(cls, filepath: str | Path) -> 'AgentHistoryList':
        """Read JSON from ``filepath`` (UTF-8) and validate it into an AgentHistoryList."""
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return cls.model_validate(data)


class AgentError:
    """Canned error messages and a formatter for errors raised during agent steps."""
    VALIDATION_ERROR = 'Invalid model output format. Please follow the correct schema.'
    RATE_LIMIT_ERROR = 'Rate limit reached. Waiting before retry.'
    NO_VALID_ACTION = 'No valid action found'

    @staticmethod
    def format_error(error: Exception, include_trace: bool = False) -> str:
        """Return a message string describing ``error``.

        Args:
            error: The exception to describe.
            include_trace: When True, append the exception's own traceback.

        Returns:
            ``RATE_LIMIT_ERROR`` for an openai ``RateLimitError``. Otherwise
            ``str(error)``, followed by a ``Stacktrace:`` section when
            ``include_trace`` is set.
        """
        if isinstance(error, RateLimitError):
            return AgentError.RATE_LIMIT_ERROR
        if include_trace:
            # Format the traceback attached to ``error`` itself.
            # traceback.format_exc() describes whatever exception is currently being
            # handled: outside an except block it yields 'NoneType: None', and while
            # handling a different exception it would show the wrong trace.
            trace = ''.join(
                traceback.format_exception(type(error), error, error.__traceback__)
            )
            return f'{error}\nStacktrace:\n{trace}'
        return str(error)


class AgentState(BaseModel):
    """Mutable run-time state tracked for a single Agent instance."""

    # Identifier for the agent; each instance gets a freshly generated UUID4 string.
    agent_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    # Step counter; defaults to 1.
    n_steps: int = 1
    # Number of failures in a row (presumably reset on success — confirm in caller).
    consecutive_failures: int = 0
    # Action results from the most recent step, or None before any step ran.
    last_result: Optional[List[ActionResult]] = None
    # Accumulated history; the factory gives every instance its own empty list.
    history: AgentHistoryList = Field(default_factory=lambda: AgentHistoryList(history=[]))
    # Most recent plan text, if one was produced.
    last_plan: Optional[str] = None
    # Run-control flags.
    paused: bool = False
    stopped: bool = False


@dataclass
class AgentStepInfo:
    """Position of the current step within a run bounded by ``max_steps``."""
    number: int
    max_steps: int

    def is_last_step(self) -> bool:
        """Return True once ``number`` has reached ``max_steps - 1`` or gone past it."""
        final_index = self.max_steps - 1
        return not self.number < final_index