File size: 7,630 Bytes
a27d8ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
# coding: utf-8
import requests
import json
from io import BytesIO
import os
from typing import Any, Optional, Type
import base64

from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)

from aworld.logs.util import logger


def extract_json_from_model_output(content: str) -> dict:
    """Extract JSON from model output, handling both plain JSON and code-block-wrapped JSON.

    When ``content`` contains a Markdown code fence, only the text between the
    first fence and the next fence (or the end of the text) is parsed. A leading
    language-tag line such as ``json`` is removed; a first line that already
    opens the JSON payload (starts with ``{`` or ``[``) is kept.

    Args:
        content: Raw text returned by the model.

    Returns:
        The value produced by ``json.loads`` on the cleaned content.

    Raises:
        ValueError: If the cleaned content is not valid JSON. The original
            ``json.JSONDecodeError`` is attached as ``__cause__``.
    """
    try:
        # If content is wrapped in code blocks, extract just the JSON part
        if '```' in content:
            content = content.split('```')[1]
            first_line, newline, remainder = content.partition('\n')
            # Drop the first line only when it is a language tag or blank.
            # A fence written as "```{" puts real JSON on that line; removing
            # it would corrupt an otherwise valid payload.
            if newline and not first_line.lstrip().startswith(('{', '[')):
                content = remainder
        # Parse the cleaned content
        return json.loads(content)
    except json.JSONDecodeError as e:
        logger.warning(f'Failed to parse model output: {content} {str(e)}')
        raise ValueError('Could not parse response.') from e


def convert_input_messages(input_messages: list[BaseMessage], model_name: Optional[str]) -> list[BaseMessage]:
    """Adapt input messages to what the planner model accepts.

    Only the model name 'deepseek-reasoner' and names beginning with
    'deepseek-r1' trigger a conversion; for every other name (or ``None``)
    the list passed in is returned as-is.

    Args:
        input_messages: Messages to adapt.
        model_name: Name of the target model, or ``None`` if unknown.

    Returns:
        Either ``input_messages`` itself, or a new list in which messages were
        rewritten for non-function-calling models and successive human / AI
        messages were merged.
    """
    is_deepseek_reasoner = model_name is not None and (
        model_name == 'deepseek-reasoner' or model_name.startswith('deepseek-r1')
    )
    if not is_deepseek_reasoner:
        return input_messages

    result = _convert_messages_for_non_function_calling_models(input_messages)
    # Merge human streaks first, then AI streaks (same order as before).
    for message_class in (HumanMessage, AIMessage):
        result = _merge_successive_messages(result, message_class)
    return result


def _convert_messages_for_non_function_calling_models(input_messages: list[BaseMessage]) -> list[BaseMessage]:
    """Rewrite messages so they carry no tool-specific message types.

    Human and system messages pass through untouched. A tool message becomes a
    human message with the same content. An AI message with tool calls becomes
    an AI message whose content is the JSON dump of those tool calls; an AI
    message without tool calls passes through untouched.

    Args:
        input_messages: Messages to rewrite.

    Returns:
        A new list of messages in the same order.

    Raises:
        ValueError: If a message is none of the handled message types.
    """
    converted = []
    for msg in input_messages:
        if isinstance(msg, (HumanMessage, SystemMessage)):
            converted.append(msg)
            continue
        if isinstance(msg, ToolMessage):
            converted.append(HumanMessage(content=msg.content))
            continue
        if not isinstance(msg, AIMessage):
            raise ValueError(f'Unknown message type: {type(msg)}')
        if msg.tool_calls:
            # Serialise the tool calls into the message text
            converted.append(AIMessage(content=json.dumps(msg.tool_calls)))
        else:
            converted.append(msg)
    return converted


def _merge_successive_messages(messages: list[BaseMessage], class_to_merge: Type[BaseMessage]) -> list[BaseMessage]:
    """Merge each run of consecutive ``class_to_merge`` messages into one message.

    Some models like deepseek-reasoner don't allow multiple human messages in a
    row. This function merges them into one.

    The messages passed in are never modified: when a run has to be merged, its
    first message is shallow-copied and the copy receives the merged content.
    Messages that are not merged are returned as the same objects.

    Merge rules for the content of a run:
      * first message has ``str`` content: the text of each following message
        is appended (for list content, the text of all its text parts).
      * first message has ``list`` content: the parts of each following message
        are appended to a new list (a ``str`` becomes one text part).

    Args:
        messages: Messages in conversation order.
        class_to_merge: Message class whose consecutive occurrences are merged.

    Returns:
        A new list with every run of ``class_to_merge`` collapsed to one message.
    """
    import copy  # function-scope import: only needed here

    def _as_text(content: Any) -> str:
        """Flatten str or list content to plain text, skipping non-text parts."""
        if isinstance(content, str):
            return content
        pieces = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and 'text' in part:
                pieces.append(part['text'])
        return ''.join(pieces)

    merged_messages = []
    in_streak = False  # True while the previous message was a class_to_merge
    head_is_copy = False  # True once merged_messages[-1] is our private copy
    for message in messages:
        if not isinstance(message, class_to_merge):
            merged_messages.append(message)
            in_streak = False
            continue

        if not in_streak:
            # First message of a (possible) run: keep the original object for now.
            merged_messages.append(message)
            in_streak = True
            head_is_copy = False
            continue

        if not head_is_copy:
            # Copy lazily so the caller's message is not altered by the merge;
            # otherwise calling this twice on the same history duplicates text.
            merged_messages[-1] = copy.copy(merged_messages[-1])
            head_is_copy = True

        head = merged_messages[-1]
        if isinstance(head.content, list):
            extra = (
                message.content
                if isinstance(message.content, list)
                else [{'type': 'text', 'text': message.content}]
            )
            # Build a new list: the shallow copy still shares the original list.
            head.content = [*head.content, *extra]
        else:
            head.content = head.content + _as_text(message.content)
    return merged_messages


def save_conversation(input_messages: list[BaseMessage], response: Any, target: str,
                      encoding: Optional[str] = None) -> None:
    """Save conversation history to file.

    The file at ``target`` is opened in write mode (overwriting any existing
    content); the input messages are written first, followed by the response.

    Args:
        input_messages: Messages that were sent to the model.
        response: Model response; it is passed to ``_write_response_to_file``.
        target: Path of the file to write. Missing parent directories are created.
        encoding: Text encoding passed to ``open``; ``None`` uses the platform default.
    """
    # Create parent folders if they do not exist. os.path.dirname() returns ''
    # for a bare filename and os.makedirs('') raises FileNotFoundError, so only
    # create directories when the path actually has a directory component.
    parent_dir = os.path.dirname(target)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(target, 'w', encoding=encoding) as f:
        _write_messages_to_file(f, input_messages)
        _write_response_to_file(f, response)


def _write_messages_to_file(f: Any, messages: list[BaseMessage]) -> None:
    """Write every message to the conversation file.

    For each message a header line with the message class name is written,
    then the content, then an empty line. List content contributes one line
    per dict item whose 'type' is 'text'. String content that parses as JSON
    is re-dumped with an indent of 2, otherwise it is written stripped.

    Args:
        f: Writable text file object.
        messages: Messages to write, in order.
    """
    for message in messages:
        f.write(f' {message.__class__.__name__} \n')

        body = message.content
        if isinstance(body, list):
            for part in body:
                is_text_part = isinstance(part, dict) and part.get('type') == 'text'
                if is_text_part:
                    f.write(part['text'].strip() + '\n')
        elif isinstance(body, str):
            try:
                parsed = json.loads(body)
            except json.JSONDecodeError:
                # Not JSON: write the raw text
                f.write(body.strip() + '\n')
            else:
                f.write(json.dumps(parsed, indent=2) + '\n')

        f.write('\n')


def _write_response_to_file(f: Any, response: Any) -> None:
    """Write the model response to the conversation file.

    A ' RESPONSE' header line is written first, followed by the response
    re-serialised as JSON with an indent of 2.

    Args:
        f: Writable text file object.
        response: Object providing ``model_dump_json(exclude_unset=True)``
            that returns a JSON string.
    """
    f.write(' RESPONSE\n')
    raw_json = response.model_dump_json(exclude_unset=True)
    # Round-trip through json to get an indented rendering
    pretty_json = json.dumps(json.loads(raw_json), indent=2)
    f.write(pretty_json)


# Token-counting helper functions.
# Note: these functions were moved from memory.py and agent.py into utils.py to remove the dependency on the MessageManager class.

def estimate_text_tokens(text: str, estimated_characters_per_token: int = 3) -> int:
    """Give a rough token count for a piece of text.

    The estimate is the character count floor-divided by the assumed number
    of characters per token.

    Args:
        text: Text to measure; empty or falsy text counts as 0 tokens.
        estimated_characters_per_token: Assumed characters per token, default is 3

    Returns:
        Estimated token count
    """
    # Falsy text short-circuits to 0 without performing the division
    return len(text) // estimated_characters_per_token if text else 0


def estimate_message_tokens(message: BaseMessage, image_tokens: int = 800,
                       estimated_characters_per_token: int = 3) -> int:
    """Roughly estimate token count for a single message

    For list content, each part is counted on its own: a dict with an
    'image_url' key costs ``image_tokens``, a dict with a 'text' key or a plain
    string is estimated as text, and any other part is ignored. For
    non-list content, the content plus the string form of ``tool_calls`` (when
    the message has that attribute) is estimated as text.

    Args:
        message: The message to estimate tokens for. A tuple is also accepted
            and is estimated from its string representation.
        image_tokens: Estimated tokens per image, default is 800
        estimated_characters_per_token: Estimated characters per token, default is 3

    Returns:
        Estimated token count
    """
    # Handle tuple case: estimate from the string representation
    if isinstance(message, tuple):
        return estimate_text_tokens(str(message), estimated_characters_per_token)

    content = message.content
    if isinstance(content, list):
        tokens = 0
        for item in content:
            if isinstance(item, str):
                # Plain-string parts are text. Testing 'image_url' in a str is a
                # substring match and must not decide whether it is an image.
                tokens += estimate_text_tokens(item, estimated_characters_per_token)
            elif isinstance(item, dict):
                if 'image_url' in item:
                    tokens += image_tokens
                elif 'text' in item:
                    tokens += estimate_text_tokens(item['text'], estimated_characters_per_token)
            # Parts of any other type carry nothing we can estimate.
        return tokens

    text = content
    if hasattr(message, 'tool_calls'):
        text += str(message.tool_calls)  # type: ignore
    return estimate_text_tokens(text, estimated_characters_per_token)


def estimate_messages_tokens(messages: list[BaseMessage], image_tokens: int = 800,
                        estimated_characters_per_token: int = 3) -> int:
    """Give a rough total token count for a list of messages.

    Each message is estimated with ``estimate_message_tokens`` and the
    results are added up.

    Args:
        messages: Messages to measure
        image_tokens: Estimated tokens per image, default is 800
        estimated_characters_per_token: Estimated characters per token, default is 3

    Returns:
        Sum of the per-message estimates (0 for an empty list)
    """
    return sum(
        estimate_message_tokens(message, image_tokens, estimated_characters_per_token)
        for message in messages
    )