Spaces:
Sleeping
Sleeping
import logging | |
from typing import List, Dict, Any | |
import json | |
import os | |
from contextlib import AsyncExitStack | |
import traceback | |
import requests | |
from mcp.types import TextContent, ImageContent | |
from aworld.core.common import ActionResult | |
from aworld.logs.util import logger | |
from aworld.mcp_client.server import MCPServer, MCPServerSse, MCPServerStdio | |
from aworld.tools import get_function_tools | |
from aworld.utils.common import find_file | |
# Most recent MCP configuration resolved by mcp_tool_desc_transform: either the
# caller-supplied mcp_config or the mcp.json loaded from disk. Starts empty.
MCP_SERVERS_CONFIG = dict()
def _normalize_param_type(raw_type: Any) -> Any:
    """Map a parameter's JSON-schema ``type`` to the value emitted for OpenAI.

    A missing type (``None``) and the Python-style ``"str"`` both become
    ``"string"``; any other value is passed through unchanged.
    """
    if raw_type is None or raw_type == "str":
        return "string"
    return raw_type


def _convert_property(param_info: Dict[str, Any]) -> Dict[str, Any]:
    """Convert one input-schema property into its OpenAI function-schema form.

    Only ``description``, ``type`` and, for arrays, the ``items`` description
    (up to two levels deep, keeping ``properties``/``required`` of object
    items) are carried over.
    """
    param_type = _normalize_param_type(param_info.get("type"))
    converted: Dict[str, Any] = {
        "description": param_info.get("description", ""),
        "type": param_type,
    }
    if param_type != "array":
        return converted

    items_info = param_info.get("items", {})
    item_type = items_info.get("type", "string")
    if item_type == "array":
        # Array of arrays: describe the inner element type as well.
        nested_items = items_info.get("items", {})
        nested_type = nested_items.get("type", "string")
        inner: Dict[str, Any] = {"type": nested_type}
        if nested_type == "object":
            inner["properties"] = nested_items.get("properties", {})
            inner["required"] = nested_items.get("required", [])
        converted["items"] = {"type": item_type, "items": inner}
    elif item_type == "object":
        converted["items"] = {
            "type": item_type,
            "properties": items_info.get("properties", {}),
            "required": items_info.get("required", []),
        }
    else:
        converted["items"] = {"type": "string" if item_type == "str" else item_type}
    return converted


def _tool_to_openai_schema(tool: Any, function_name: str) -> Dict[str, Any]:
    """Build the OpenAI ``{"type": "function", ...}`` entry for one MCP tool.

    ``tool`` must expose ``description`` and ``inputSchema`` (a dict or a
    falsy value); the exported function is named ``function_name``.
    """
    required: List[Any] = []
    properties: Dict[str, Any] = {}
    input_schema = tool.inputSchema
    if input_schema and input_schema.get("properties"):
        required = input_schema.get("required", [])
        properties = {
            param_name: _convert_property(param_info)
            for param_name, param_info in input_schema["properties"].items()
        }
    return {
        "type": "function",
        "function": {
            "name": function_name,
            "description": tool.description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        },
    }


def get_function_tool(sever_name: str) -> List[Dict[str, Any]]:
    """Translate a registered function-tool server's tools into OpenAI schemas.

    Args:
        sever_name: Name the server was registered under with
            ``get_function_tools`` (misspelled parameter name kept for
            backward compatibility).

    Returns:
        One OpenAI function entry per tool, named ``mcp__<server>__<tool>``.
        An empty list when the name is empty, the server is unknown, it has
        no tools, or translating any tool fails.
    """
    if not sever_name:
        return []
    try:
        tool_server = get_function_tools(sever_name)
        if not tool_server:
            return []
        tools = tool_server.list_tools()
        if not tools:
            return []
        openai_tools = [
            _tool_to_openai_schema(tool, f"mcp__{sever_name}__{tool.name}")
            for tool in tools
        ]
    except Exception as e:
        # No `return` in a `finally` here: that used to override this branch
        # (leaking a partially built list) and swallow BaseException.
        logging.warning(
            f"server_name-get_function_tool:{sever_name} translate failed: {e}"
        )
        return []
    logging.info(
        f"✅ function_tool_server #({sever_name}) connected success,tools: {len(tools)}"
    )
    return openai_tools


# String annotation: evaluated lazily, so MCPServer is only needed by type checkers.
async def run(mcp_servers: "list[MCPServer]") -> List[Dict[str, Any]]:
    """List the tools of connected MCP servers as OpenAI function schemas.

    Each tool is exported as ``<server.name>__<tool.name>``. A server whose
    ``list_tools()`` call or schema translation fails is logged and skipped;
    the tools of the remaining servers are still returned.
    """
    openai_tools: List[Dict[str, Any]] = []
    for i, server in enumerate(mcp_servers):
        try:
            tools = await server.list_tools()
            server_tools = [
                _tool_to_openai_schema(tool, f"{server.name}__{tool.name}")
                for tool in tools
            ]
        except Exception as e:
            logging.error(f"❌ server #{i + 1} ({server.name}) connect fail: {e}")
            continue
        openai_tools.extend(server_tools)
        logging.info(
            f"✅ server #{i + 1} ({server.name}) connected success,tools: {len(tools)}"
        )
    return openai_tools
def _replace_env_variables(config: Any) -> None:
    """Substitute ``"${VAR}"`` placeholders in a parsed JSON config, in place.

    Walks nested dicts and lists. Only strings that are entirely a placeholder
    (start with ``${`` and end with ``}``) are replaced, by the value of the
    environment variable ``VAR``; when that variable is unset the placeholder
    is left untouched. Resolved values are deliberately not logged because
    such placeholders may hold credentials (API keys, tokens).
    """
    if isinstance(config, dict):
        slots = list(config.items())
    elif isinstance(config, list):
        slots = list(enumerate(config))
    else:
        return
    for slot, value in slots:
        if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
            env_var_name = value[2:-1]
            resolved = os.getenv(env_var_name)
            if resolved is None:
                logging.info(
                    f"Environment variable {env_var_name} is not set; keeping {value}"
                )
            else:
                config[slot] = resolved
                logging.info(f"Replaced {value} from environment variable {env_var_name}")
        elif isinstance(value, (dict, list)):
            _replace_env_variables(value)


def _load_mcp_config_file():
    """Load ``mcp.json`` from disk and expand ``${VAR}`` placeholders.

    The file found by ``find_file`` (running path) takes priority; otherwise
    ``../config/mcp.json`` relative to this module is used.

    Returns:
        The parsed config dict, or ``None`` if no file exists, it cannot be
        parsed, or its top level is not a JSON object.
    """
    # Priority given to the running path.
    config_path = find_file(filename="mcp.json")
    if not config_path or not os.path.exists(config_path):
        current_dir = os.path.dirname(os.path.abspath(__file__))
        config_path = os.path.normpath(os.path.join(current_dir, "../config/mcp.json"))
    logger.info(f"mcp conf path: {config_path}")
    if not os.path.exists(config_path):
        logging.info(f"mcp config is not exist: {config_path}")
        return None
    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    except Exception as e:
        logging.info(f"load config fail: {e}")
        return None
    if not isinstance(config, dict):
        logging.info(f"load config fail: top level of {config_path} is not an object")
        return None
    _replace_env_variables(config)
    return config


async def mcp_tool_desc_transform(
    tools: List[str] = None, mcp_config: Dict[str, Any] = None
) -> List[Dict[str, Any]]:
    """Default implement transform framework standard protocol to openai protocol of tool description.

    Args:
        tools: Optional allow-list of server names from ``mcpServers``;
            ``None`` means every enabled server.
        mcp_config: Config dict with an ``mcpServers`` mapping. When falsy,
            ``mcp.json`` is loaded from disk instead (with ``${VAR}``
            placeholders expanded from the environment).

    Returns:
        OpenAI function schemas for the tools of every server that connected.
        Servers with a ``url`` are connected over SSE, servers with a
        ``command`` over stdio, and the rest are ignored. An empty list when
        no config is available or no server qualifies.

    Side effects:
        Stores the resolved config in the module-level ``MCP_SERVERS_CONFIG``.
    """
    global MCP_SERVERS_CONFIG

    if mcp_config:
        config = mcp_config
    else:
        config = _load_mcp_config_file()
        if config is None:
            return []
    MCP_SERVERS_CONFIG = config

    server_configs = []
    for server_name, server_config in config.get("mcpServers", {}).items():
        # Skip disabled servers and those not in the requested subset.
        if server_config.get("disabled", False):
            continue
        if tools is not None and server_name not in tools:
            continue
        if "url" in server_config:
            server_configs.append(
                {
                    "name": "mcp__" + server_name,
                    "type": "sse",
                    "params": {"url": server_config["url"]},
                }
            )
        elif "command" in server_config:
            server_configs.append(
                {
                    "name": "mcp__" + server_name,
                    "type": "stdio",
                    "params": {
                        "command": server_config["command"],
                        "args": server_config.get("args", []),
                        "env": server_config.get("env", {}),
                        "cwd": server_config.get("cwd"),
                        "encoding": server_config.get("encoding", "utf-8"),
                        "encoding_error_handler": server_config.get(
                            "encoding_error_handler", "strict"
                        ),
                    },
                }
            )
    if not server_configs:
        return []

    # The exit stack keeps every connection open while tools are listed and
    # closes them all before returning.
    async with AsyncExitStack() as stack:
        servers = []
        for server_config in server_configs:
            server_cls = MCPServerSse if server_config["type"] == "sse" else MCPServerStdio
            try:
                server = server_cls(
                    name=server_config["name"], params=server_config["params"]
                )
                servers.append(await stack.enter_async_context(server))
            except (KeyboardInterrupt, SystemExit):
                raise
            except BaseException as err:
                # Deliberately broad so one failing server does not stop the
                # others. NOTE(review): this also swallows
                # asyncio.CancelledError; confirm that is intended.
                logging.error(
                    f"Failed to get tools for MCP server '{server_config['name']}'.\n"
                    f"Error: {err}\n"
                    f"Traceback:\n{traceback.format_exc()}"
                )
        openai_tools = await run(servers)
    return openai_tools
def _list_api_tools(server_name: str, server_config: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Fetch ``<url>/list_tools`` from an ``"api"`` server and convert the tools.

    Each tool becomes an OpenAI function named ``mcp__<server>__<tool>``;
    parameters that declare a ``default`` are dropped from its ``properties``.
    An optional ``"timeout"`` (seconds) in ``server_config`` bounds the
    request; without it no timeout is applied. Any failure (network, JSON,
    missing keys) is logged, and the tools converted before it are returned.
    """
    converted: List[Dict[str, Any]] = []
    try:
        api_result = requests.get(
            server_config["url"] + "/list_tools",
            timeout=server_config.get("timeout"),
        )
        # A requests.Response is falsy for 4xx/5xx status codes.
        if not api_result or not api_result.text:
            return converted
        data = json.loads(api_result.text)
        if not data or not data.get("tools"):
            return converted
        for item in data["tools"]:
            parameters = item["parameters"]
            converted.append(
                {
                    "type": "function",
                    "function": {
                        "name": "mcp__" + server_name + "__" + item["name"],
                        "description": item["description"],
                        "parameters": {
                            **parameters,
                            "properties": {
                                k: v
                                for k, v in parameters.get("properties", {}).items()
                                if "default" not in v
                            },
                        },
                    },
                }
            )
    except Exception as e:
        logging.warning(f"server_name:{server_name} translate failed: {e}")
    return converted


async def sandbox_mcp_tool_desc_transform(
    tools: List[str] = None, mcp_config: Dict[str, Any] = None
) -> List[Dict[str, Any]]:
    """Build OpenAI function schemas for every server in a sandbox MCP config.

    Server kinds are chosen by each entry's ``type``:
    ``"function_tool"`` (in-process tools), ``"api"`` (HTTP ``list_tools``),
    ``"sse"``; any other entry with a ``command`` is started over stdio.
    Entries that cannot be handled (e.g. missing ``command``/``url``) are
    logged and skipped instead of aborting the whole transform.

    Args:
        tools: Optional allow-list of server names; ``None`` means all.
        mcp_config: Config dict with an ``mcpServers`` mapping.

    Returns:
        ``None`` when ``mcp_config`` is falsy, otherwise the list of schemas
        (possibly empty).
    """
    # todo sandbox mcp_config get from registry
    if not mcp_config:
        return None

    openai_tools: List[Dict[str, Any]] = []
    server_configs = []
    for server_name, server_config in mcp_config.get("mcpServers", {}).items():
        # Skip disabled servers and those not in the requested subset.
        if server_config.get("disabled", False):
            continue
        if tools is not None and server_name not in tools:
            continue
        server_type = server_config.get("type", "")
        if server_type == "function_tool":
            try:
                openai_tools.extend(get_function_tool(server_name))
            except Exception as e:
                logging.warning(f"server_name:{server_name} translate failed: {e}")
        elif server_type == "api":
            openai_tools.extend(_list_api_tools(server_name, server_config))
        elif server_type == "sse":
            if "url" not in server_config:
                logging.warning(f"SSE server {server_name} has no 'url'; skipping")
                continue
            server_configs.append(
                {
                    "name": "mcp__" + server_name,
                    "type": "sse",
                    "params": {
                        "url": server_config["url"],
                        "headers": server_config.get("headers"),
                    },
                }
            )
        elif "command" in server_config:
            # Any other type (including "stdio" or no type) is run over stdio.
            server_configs.append(
                {
                    "name": "mcp__" + server_name,
                    "type": "stdio",
                    "params": {
                        "command": server_config["command"],
                        "args": server_config.get("args", []),
                        "env": server_config.get("env", {}),
                        "cwd": server_config.get("cwd"),
                        "encoding": server_config.get("encoding", "utf-8"),
                        "encoding_error_handler": server_config.get(
                            "encoding_error_handler", "strict"
                        ),
                    },
                }
            )
        else:
            logging.warning(f"MCP server {server_name} has no 'command'; skipping")

    if not server_configs:
        return openai_tools

    # Keep every connection open while tools are listed; all are closed on exit.
    async with AsyncExitStack() as stack:
        servers = []
        for server_config in server_configs:
            server_cls = MCPServerSse if server_config["type"] == "sse" else MCPServerStdio
            try:
                server = server_cls(
                    name=server_config["name"], params=server_config["params"]
                )
                servers.append(await stack.enter_async_context(server))
            except (KeyboardInterrupt, SystemExit):
                raise
            except BaseException as err:
                # Deliberately broad so one failing server does not stop the
                # others. NOTE(review): also swallows asyncio.CancelledError.
                logging.error(
                    f"Failed to get tools for MCP server '{server_config['name']}'.\n"
                    f"Error: {err}\n"
                    f"Traceback:\n{traceback.format_exc()}"
                )
        mcp_openai_tools = await run(servers)
    if mcp_openai_tools:
        openai_tools.extend(mcp_openai_tools)
    return openai_tools
async def call_function_tool(
    server_name: str,
    tool_name: str,
    parameter: Dict[str, Any] = None,
    mcp_config: Dict[str, Any] = None,
) -> ActionResult:
    """Invoke a tool on a registered function-tool server.

    Only the first content item of the call result is used. Text is returned
    as-is. An image is returned as a base64 ``data:`` URL using the item's own
    ``mimeType`` (falling back to ``image/jpeg`` when it is not set).

    Args:
        server_name: Name the server was registered under with ``get_function_tools``.
        tool_name: Tool to invoke.
        parameter: Arguments forwarded to ``call_tool``.
        mcp_config: Unused; accepted for signature parity with ``call_api``.

    Returns:
        ActionResult: Result carrying the content and any ``metadata`` found
        in the content item's extra fields. Content is empty when the server is
        unknown, the call returns nothing, the first item is neither text nor
        image, or the call raises.
    """
    empty_result = ActionResult(
        tool_name=server_name, action_name=tool_name, content="", keep=True
    )
    try:
        tool_server = get_function_tools(server_name)
        if not tool_server:
            return empty_result
        call_result_raw = tool_server.call_tool(tool_name, parameter)
        if not call_result_raw or not call_result_raw.content:
            return empty_result
        first = call_result_raw.content[0]
        if isinstance(first, TextContent):
            content = first.text
        elif isinstance(first, ImageContent):
            mime_type = getattr(first, "mimeType", None) or "image/jpeg"
            content = f"data:{mime_type};base64,{first.data}"
        else:
            # Any other content type is not translated.
            return empty_result
        # model_extra is None when the pydantic model does not keep extra fields.
        metadata = (first.model_extra or {}).get("metadata", {})
        return ActionResult(
            tool_name=server_name,
            action_name=tool_name,
            content=content,
            keep=True,
            metadata=metadata,
        )
    except Exception as e:
        logging.warning(f"call_function_tool ({server_name})({tool_name}) failed: {e}")
        return empty_result
async def call_api(
    server_name: str,
    tool_name: str,
    parameter: Dict[str, Any] = None,
    mcp_config: Dict[str, Any] = None,
) -> ActionResult:
    """Specifically handle API type server calls.

    POSTs ``parameter`` as JSON to ``<url>/<tool_name>`` of the ``"api"``
    server named ``server_name``. The raw response text becomes the result
    content regardless of HTTP status. An optional ``"timeout"`` (seconds) in
    the server's config is passed to ``requests.post``; when it is absent, no
    timeout is applied. Note: ``requests.post`` is synchronous and blocks the
    event loop while the request is in flight.

    Args:
        server_name: Server name
        tool_name: Tool name
        parameter: Parameters
        mcp_config: MCP configuration containing an ``mcpServers`` mapping

    Returns:
        ActionResult: Call result. Content is empty when the server is missing
        or not of type ``"api"``, and ``"Error calling API: ..."`` when the
        request fails.
    """
    action_result = ActionResult(
        tool_name=server_name, action_name=tool_name, content="", keep=True
    )
    if not mcp_config or mcp_config.get("mcpServers") is None:
        return action_result
    server_config = mcp_config["mcpServers"].get(server_name)
    if not server_config:
        return action_result
    if server_config.get("type", "") != "api":
        logging.warning(
            f"Server {server_name} is not API type, should use call_tool instead"
        )
        return action_result
    try:
        response = requests.post(
            url=server_config["url"] + "/" + tool_name,
            headers={"Content-Type": "application/json"},
            json=parameter,
            timeout=server_config.get("timeout"),
        )
        return ActionResult(
            tool_name=server_name,
            action_name=tool_name,
            content=response.text,
            keep=True,
        )
    except Exception as e:
        logging.warning(f"call_api ({server_name})({tool_name}) failed: {e}")
        return ActionResult(
            tool_name=server_name,
            action_name=tool_name,
            content=f"Error calling API: {str(e)}",
            keep=True,
        )
async def get_server_instance(
    server_name: str, mcp_config: Dict[str, Any] = None
) -> Any:
    """Create and connect a server instance for ``server_name``.

    Args:
        server_name: Server name (key in ``mcp_config["mcpServers"]``).
        mcp_config: MCP configuration.

    Returns:
        A connected ``MCPServerSse`` (type ``"sse"``) or ``MCPServerStdio``
        (any other type except ``"api"``), or ``None`` when the config or
        server entry is missing, the server is of type ``"api"``, or
        creation/connection fails.
    """
    if not mcp_config or mcp_config.get("mcpServers") is None:
        return None
    server_config = mcp_config["mcpServers"].get(server_name)
    if not server_config:
        return None

    server = None
    try:
        server_type = server_config.get("type", "")
        if server_type == "api":
            # API servers are called per request (see call_api) and need no
            # persistent connection, so no instance is created for them.
            logging.info(f"API server {server_name} doesn't need persistent connection")
            return None
        if server_type == "sse":
            server = MCPServerSse(
                name=server_name,
                params={
                    "url": server_config["url"],
                    "headers": server_config.get("headers"),
                    "timeout": server_config.get("timeout", 5.0),
                    "sse_read_timeout": server_config.get("sse_read_timeout", 300.0),
                },
            )
            await server.connect()
            logging.info(f"Successfully connected to SSE server: {server_name}")
            return server
        # stdio type
        params = {
            "command": server_config["command"],
            "args": server_config.get("args", []),
            "env": server_config.get("env", {}),
            "cwd": server_config.get("cwd"),
            "encoding": server_config.get("encoding", "utf-8"),
            "encoding_error_handler": server_config.get(
                "encoding_error_handler", "strict"
            ),
        }
        server = MCPServerStdio(name=server_name, params=params)
        await server.connect()
        logging.info(f"Successfully connected to stdio server: {server_name}")
        return server
    except Exception as e:
        logging.warning(f"Failed to create server instance for {server_name}: {e}")
        if server is not None:
            # Release whatever a failed connect() may have opened (transport,
            # subprocess) instead of dropping the object on the floor.
            cleanup = getattr(server, "cleanup", None)
            if cleanup is not None:
                try:
                    await cleanup()
                except Exception as cleanup_err:
                    logging.warning(
                        f"Failed to cleanup server {server_name}: {cleanup_err}"
                    )
        return None
async def cleanup_server(server):
    """Clean up server connection.

    Calls the server's ``cleanup()`` method, or ``close()`` if it has no
    ``cleanup``. Either may be a coroutine function or a plain method. Does
    nothing for ``None`` or an object with neither method. Failures are
    logged, not raised.

    Args:
        server: Server instance
    """
    import inspect

    if server is None:
        return
    name = getattr(server, "name", "unknown")
    closer = getattr(server, "cleanup", None) or getattr(server, "close", None)
    if closer is None:
        logging.info(f"Server {name} has no cleanup/close method; nothing to clean up")
        return
    try:
        result = closer()
        # Only await when the method actually returned an awaitable; awaiting
        # the None from a synchronous close() would raise TypeError.
        if inspect.isawaitable(result):
            await result
        logging.info(f"Successfully cleaned up server: {name}")
    except Exception as e:
        logging.warning(f"Failed to cleanup server: {e}")