File size: 3,131 Bytes
52f4d1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import uuid
from typing import Any
from typing import List, Optional
from datetime import datetime

from pydantic import BaseModel, ConfigDict
from pydantic import Field


class AworldTask(BaseModel):
    """A task submitted for agent execution.

    Bundles identifiers, the agent input, LLM connection settings, execution
    limits and timestamps. ``status`` starts at ``"INIT"`` and is advanced by
    the ``mark_*`` helpers below.
    """

    # These three default to None, so they are annotated Optional: with a bare
    # ``str`` annotation pydantic v2 rejects an explicit None, which breaks
    # round-trips such as ``AworldTask(**task.model_dump())``.
    task_id: Optional[str] = Field(default=None, description="task id")
    agent_id: Optional[str] = Field(default=None, description="agent id")
    agent_input: Optional[str] = Field(default=None, description="agent input")
    session_id: Optional[str] = Field(default=None, description="session id")
    user_id: Optional[str] = Field(default=None, description="user id")
    llm_provider: Optional[str] = Field(default=None, description="llm provider")
    llm_model_name: Optional[str] = Field(default=None, description="llm model name")
    llm_api_key: Optional[str] = Field(default=None, description="llm api key")
    llm_base_url: Optional[str] = Field(default=None, description="llm base url")
    llm_custom_input: Optional[str] = Field(default=None, description="custom_input")
    task_system_prompt: Optional[str] = Field(default=None, description="task_system_prompt")
    mcp_servers: Optional[list[str]] = Field(default=None, description="mcp_servers")
    node_id: Optional[str] = Field(default=None, description="execute task node_id")
    # NOTE(review): named client_id but described as the submitting client's IP.
    client_id: Optional[str] = Field(default=None, description="submit client ip")
    # NOTE(review): the description lists submitted/running/execute_failed/
    # execute_success, but this class itself only ever sets INIT, RUNNING,
    # FAILED and SUCCESS — confirm which vocabulary callers rely on.
    status: Optional[str] = Field(default="INIT", description="submitted/running/execute_failed/execute_success")
    history_messages: Optional[int] = Field(default=100, description="history_message")
    max_steps: Optional[int] = Field(default=100, description="max_steps")
    max_retries: Optional[int] = Field(default=5, description="max_retries use Exponential backoff with jitter")
    # default_factory gives each task its own dict instead of a shared one.
    ext_info: Optional[dict] = Field(default_factory=dict, description="custom")
    created_at: Optional[datetime] = Field(default=None, description="created time")
    updated_at: Optional[datetime] = Field(default=None, description="updated time")

    def mark_running(self) -> None:
        """Set ``status`` to ``'RUNNING'`` (timestamps are not touched)."""
        self.status = 'RUNNING'

    def mark_failed(self) -> None:
        """Set ``status`` to ``'FAILED'`` (timestamps are not touched)."""
        self.status = 'FAILED'

    def mark_success(self) -> None:
        """Set ``status`` to ``'SUCCESS'`` (timestamps are not touched)."""
        self.status = 'SUCCESS'

class AworldTaskResult(BaseModel):
    """Outcome of running an ``AworldTask``: the task, where it ran, and its data."""

    # Optional because the default is None; a bare ``AworldTask`` annotation
    # would reject an explicit None (e.g. when re-validating ``model_dump()``).
    task: Optional[AworldTask] = Field(default=None, description="task")
    # NOTE(review): named server_host but described as a server id — confirm
    # which value is actually stored here.
    server_host: Optional[str] = Field(default=None, description="aworld server id")
    # Typed ``Any``: pydantic performs no validation on this payload.
    data: Any = Field(default=None, description="result data")

class AworldTaskForm(BaseModel):
    """Submission form carrying an optional ``AworldTask`` and a batch id."""

    # default_factory runs per instance, so every form gets a fresh UUID. The
    # previous ``default=str(uuid.uuid4())`` was evaluated once at class
    # definition, giving every form that omitted batch_id the same id.
    batch_id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="batch_id")
    task: Optional[AworldTask] = Field(default=None, description="task")
    user_id: Optional[str] = Field(default=None, description="user id")
    # NOTE(review): named client_id but described as the submitting client's IP.
    client_id: Optional[str] = Field(default=None, description="submit client ip")


class OpenAIChatMessage(BaseModel):
    """A single message in an OpenAI-style chat conversation.

    ``content`` may be plain text or a list (presumably content parts — the
    list items are not validated here). Keys beyond ``role``/``content`` are
    retained because the model allows extra fields.
    """

    model_config = ConfigDict(extra="allow")

    role: str
    content: str | List


class OpenAIChatCompletionForm(BaseModel):
    """Request body for an OpenAI-style chat completion.

    ``stream`` defaults to True when omitted; ``model`` and ``messages`` are
    required. Any additional request keys are accepted and kept.
    """

    model_config = ConfigDict(extra="allow")

    stream: bool = Field(default=True)
    model: str
    messages: List[OpenAIChatMessage]


class FilterForm(BaseModel):
    """Form holding a required ``body`` dict and an optional ``user`` dict.

    Neither dict has a validated schema here; extra top-level keys are allowed.
    """

    model_config = ConfigDict(extra="allow")

    body: dict
    user: Optional[dict] = Field(default=None)