# coding: utf-8
import requests
import json
from io import BytesIO
import os
from typing import Any, Optional, Type
import base64

from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)

from aworld.logs.util import logger


def extract_json_from_model_output(content: str) -> dict:
    """Extract JSON from model output, handling both plain JSON and code-block-wrapped JSON."""
    try:
        # If content is wrapped in code blocks, extract just the JSON part
        if '```' in content:
            # Find the JSON content between code blocks
            content = content.split('```')[1]
            # Remove language identifier if present (e.g., 'json\n')
            if '\n' in content:
                content = content.split('\n', 1)[1]
        # Parse the cleaned content
        return json.loads(content)
    except json.JSONDecodeError as e:
        logger.warning(f'Failed to parse model output: {content} {str(e)}')
        raise ValueError('Could not parse response.')
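
# Usage sketch (illustrative, not part of the original module); the strings
# below are hypothetical model outputs:
#
#   extract_json_from_model_output('{"action": "done"}')
#   # -> {'action': 'done'}
#   extract_json_from_model_output('```json\n{"action": "done"}\n```')
#   # -> {'action': 'done'}  (fences and the 'json' language tag are stripped first)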


def convert_input_messages(input_messages: list[BaseMessage], model_name: Optional[str]) -> list[BaseMessage]:
    """Convert input messages to a format that is compatible with the planner model"""
    if model_name is None:
        return input_messages
    if model_name == 'deepseek-reasoner' or model_name.startswith('deepseek-r1'):
        converted_input_messages = _convert_messages_for_non_function_calling_models(input_messages)
        merged_input_messages = _merge_successive_messages(converted_input_messages, HumanMessage)
        merged_input_messages = _merge_successive_messages(merged_input_messages, AIMessage)
        return merged_input_messages
    return input_messages
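
# Usage sketch with hypothetical messages and model names. For DeepSeek
# reasoning models the pipeline is: flatten tool traffic into plain text, then
# merge successive messages of the same class; any other model passes through:
#
#   msgs = [SystemMessage(content='be brief'), HumanMessage(content='hi')]
#   convert_input_messages(msgs, 'gpt-4o')              # returned unchanged
#   convert_input_messages(msgs, 'deepseek-reasoner')   # converted and merged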


def _convert_messages_for_non_function_calling_models(input_messages: list[BaseMessage]) -> list[BaseMessage]:
    """Convert messages for non-function-calling models"""
    output_messages = []
    for message in input_messages:
        if isinstance(message, HumanMessage):
            output_messages.append(message)
        elif isinstance(message, SystemMessage):
            output_messages.append(message)
        elif isinstance(message, ToolMessage):
            # Tool results become plain human messages
            output_messages.append(HumanMessage(content=message.content))
        elif isinstance(message, AIMessage):
            # Serialize tool calls into the text content, since the model cannot consume them natively
            if message.tool_calls:
                tool_calls = json.dumps(message.tool_calls)
                output_messages.append(AIMessage(content=tool_calls))
            else:
                output_messages.append(message)
        else:
            raise ValueError(f'Unknown message type: {type(message)}')
    return output_messages
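
# Conversion sketch on hypothetical inputs:
#
#   _convert_messages_for_non_function_calling_models([
#       ToolMessage(content='clicked ok', tool_call_id='1'),
#       AIMessage(content='', tool_calls=[{'name': 'click', 'args': {}, 'id': '1'}]),
#   ])
#   # -> [HumanMessage(content='clicked ok'),
#   #     AIMessage(content='[{"name": "click", "args": {}, "id": "1"}]')]
#   # (approximately; the exact tool_call fields depend on the langchain-core version)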


def _merge_successive_messages(messages: list[BaseMessage], class_to_merge: Type[BaseMessage]) -> list[BaseMessage]:
    """Some models, like deepseek-reasoner, don't allow multiple human messages in a row, so merge them into one."""
    merged_messages = []
    streak = 0
    for message in messages:
        if isinstance(message, class_to_merge):
            streak += 1
            if streak > 1:
                if isinstance(message.content, list):
                    merged_messages[-1].content += message.content[0]['text']  # type: ignore
                else:
                    merged_messages[-1].content += message.content
            else:
                merged_messages.append(message)
        else:
            merged_messages.append(message)
            streak = 0
    return merged_messages
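
# Merging sketch (hypothetical messages): two HumanMessages in a row collapse
# into one; note that the first message's content is mutated in place:
#
#   _merge_successive_messages(
#       [HumanMessage(content='step 1. '), HumanMessage(content='step 2.'), AIMessage(content='ok')],
#       HumanMessage,
#   )
#   # -> [HumanMessage(content='step 1. step 2.'), AIMessage(content='ok')]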


def save_conversation(input_messages: list[BaseMessage], response: Any, target: str,
                      encoding: Optional[str] = None) -> None:
    """Save conversation history to file."""
    # Create folders if they do not exist (dirname is empty when target is a bare filename)
    dirname = os.path.dirname(target)
    if dirname:
        os.makedirs(dirname, exist_ok=True)

    with open(
        target,
        'w',
        encoding=encoding,
    ) as f:
        _write_messages_to_file(f, input_messages)
        _write_response_to_file(f, response)
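
# Usage sketch (hypothetical paths/objects). `response` must expose a
# pydantic-v2 style model_dump_json(), since _write_response_to_file relies on it:
#
#   save_conversation(messages, agent_output, 'logs/run_001/conversation.txt')
#   # -> logs/run_001/ is created if missing, then each message is written as a
#   #    ' ClassName ' header plus its text, followed by a ' RESPONSE' block.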


def _write_messages_to_file(f: Any, messages: list[BaseMessage]) -> None:
    """Write messages to conversation file"""
    for message in messages:
        f.write(f' {message.__class__.__name__} \n')

        if isinstance(message.content, list):
            for item in message.content:
                if isinstance(item, dict) and item.get('type') == 'text':
                    f.write(item['text'].strip() + '\n')
        elif isinstance(message.content, str):
            try:
                content = json.loads(message.content)
                f.write(json.dumps(content, indent=2) + '\n')
            except json.JSONDecodeError:
                f.write(message.content.strip() + '\n')

        f.write('\n')


def _write_response_to_file(f: Any, response: Any) -> None:
    """Write model response to conversation file"""
    f.write(' RESPONSE\n')
    f.write(json.dumps(json.loads(response.model_dump_json(exclude_unset=True)), indent=2))


# Token counting related functions
# Note: these functions were moved from memory.py and agent.py to utils.py,
# removing the dependency on the MessageManager class.
def estimate_text_tokens(text: str, estimated_characters_per_token: int = 3) -> int:
    """Roughly estimate token count in text.

    Args:
        text: The text to estimate tokens for
        estimated_characters_per_token: Estimated characters per token, default is 3

    Returns:
        Estimated token count
    """
    if not text:
        return 0
    # Use character count divided by average characters per token to estimate tokens
    return len(text) // estimated_characters_per_token
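
# Quick sanity check of the heuristic (3 characters per token by default):
#
#   estimate_text_tokens('hello world')        # 11 chars // 3 -> 3
#   estimate_text_tokens('')                   # -> 0
#   estimate_text_tokens('hello world', 4)     # 11 chars // 4 -> 2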


def estimate_message_tokens(message: BaseMessage, image_tokens: int = 800,
                            estimated_characters_per_token: int = 3) -> int:
    """Roughly estimate token count for a single message.

    Args:
        message: The message to estimate tokens for
        image_tokens: Estimated tokens per image, default is 800
        estimated_characters_per_token: Estimated characters per token, default is 3

    Returns:
        Estimated token count
    """
    tokens = 0
    # Handle tuple case
    if isinstance(message, tuple):
        # Convert to string and estimate tokens
        message_str = str(message)
        return estimate_text_tokens(message_str, estimated_characters_per_token)

    if isinstance(message.content, list):
        for item in message.content:
            if 'image_url' in item:
                tokens += image_tokens
            elif isinstance(item, dict) and 'text' in item:
                tokens += estimate_text_tokens(item['text'], estimated_characters_per_token)
    else:
        msg = message.content
        if hasattr(message, 'tool_calls'):
            msg += str(message.tool_calls)  # type: ignore
        tokens += estimate_text_tokens(msg, estimated_characters_per_token)
    return tokens
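
# Estimation sketch for a hypothetical multimodal message: each image part
# costs a flat image_tokens (800), each text part len(text) // 3:
#
#   estimate_message_tokens(HumanMessage(content=[
#       {'type': 'text', 'text': 'hello world'},                  # 11 // 3 -> 3
#       {'type': 'image_url', 'image_url': {'url': 'data:...'}},  # -> 800
#   ]))
#   # -> 803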


def estimate_messages_tokens(messages: list[BaseMessage], image_tokens: int = 800,
                             estimated_characters_per_token: int = 3) -> int:
    """Roughly estimate total token count for a list of messages.

    Args:
        messages: The list of messages to estimate tokens for
        image_tokens: Estimated tokens per image, default is 800
        estimated_characters_per_token: Estimated characters per token, default is 3

    Returns:
        Estimated total token count
    """
    total_tokens = 0
    for msg in messages:
        total_tokens += estimate_message_tokens(msg, image_tokens, estimated_characters_per_token)
    return total_tokens
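
# Aggregate sketch over a hypothetical conversation:
#
#   estimate_messages_tokens([
#       SystemMessage(content='be brief'),    # 8 chars  // 3 -> 2
#       HumanMessage(content='hello world'),  # 11 chars // 3 -> 3
#   ])
#   # -> 5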