Spaces:
Sleeping
Sleeping
# coding: utf-8 | |
# Copyright (c) 2025 inclusionAI. | |
import inspect | |
import json | |
import logging | |
import traceback | |
from typing import Any, Dict, List, Optional, Union, get_type_hints | |
from mcp.types import TextContent, ImageContent, CallToolResult | |
from mcp import Tool as MCPTool | |
from pydantic import Field, create_model | |
from pydantic.fields import FieldInfo # Import FieldInfo type | |
from aworld.core.common import ActionResult | |
from aworld.logs.util import logger | |
# Process-wide registry mapping each FunctionTools server name to its
# (per-name singleton) instance; written by _register_function_tools.
_FUNCTION_TOOLS_REGISTRY: Dict[str, "FunctionTools"] = {}
def _register_function_tools(function_tools):
    """Store *function_tools* in the global registry under its ``name``.

    An existing entry with the same name is overwritten. The registration is
    logged at INFO level.
    """
    server_name = function_tools.name
    _FUNCTION_TOOLS_REGISTRY[server_name] = function_tools
    logger.info(f"Registered FunctionTools server: {server_name}")
def get_function_tools(name):
    """Look up a registered function tools server by *name*.

    Returns:
        The registered server instance, or ``None`` if no server with that
        name has been registered.
    """
    registry = _FUNCTION_TOOLS_REGISTRY
    return registry.get(name, None)
def list_function_tools():
    """Return the names of all registered function tools servers.

    The names are returned as a new list, in registry insertion order.
    """
    registry = _FUNCTION_TOOLS_REGISTRY
    return list(registry)
class FunctionTools:
    """Function tools server providing MCP-like tool registration and calling.

    Instances are singletons per ``name``. Constructing ``FunctionTools`` with a
    name that is already registered returns the existing instance unchanged.

    Example:
        ```python
        # Create function tools server
        function = FunctionTools("my-server", description="My function tools server")

        # Define a tool function
        @function.tool(description="Example search function")
        def search(query: str, limit: int = 10) -> str:
            results = [f"Result {i} for {query}" for i in range(limit)]
            return json.dumps(results)

        # Describe parameters with pydantic Field defaults
        @function.tool(description="Example search function")
        def search_with_fields(
            query: str = Field(description="Search query"),
            limit: int = Field(10, description="Max results"),
        ) -> str:
            results = [f"Result {i} for {query}" for i in range(limit)]
            return json.dumps(results)
        ```
    """

    # Upper bound (seconds) for the worker-thread fallback used by ``call_tool``
    # when an async tool cannot run on the current thread's event loop.
    _THREAD_TIMEOUT_SECONDS = 60

    def __new__(cls, name: str, description: Optional[str] = None, version: str = "1.0"):
        """Return the existing instance registered under *name*, or a new one.

        Args:
            name: Server name
            description: Server description
            version: Server version
        """
        if name in _FUNCTION_TOOLS_REGISTRY:
            logger.info(f"Returning existing FunctionTools instance: {name}")
            return _FUNCTION_TOOLS_REGISTRY[name]
        return super().__new__(cls)

    def __init__(self, name: str, description: Optional[str] = None, version: str = "1.0"):
        """Initialize the server and add it to the global registry.

        Python calls ``__init__`` again on the instance returned by ``__new__``.
        An already-initialized instance is therefore left untouched.

        Args:
            name: Server name
            description: Server description
            version: Server version
        """
        if hasattr(self, "name") and self.name == name:
            return
        self.name = name
        self.description = description or f"Function tools server: {name}"
        self.version = version
        self.tools = {}
        _register_function_tools(self)

    def tool(self, description: Optional[str] = None, parameters: Optional[Dict[str, Any]] = None):
        """Decorator that registers a function as a tool on this server.

        Args:
            description: Tool description. Defaults to ``"Tool function: <name>"``.
            parameters: Per-parameter schema fragments merged over the
                auto-generated parameter schema.

        Returns:
            A decorator that registers the function and returns it unchanged.
        """
        def decorator(func):
            tool_name = func.__name__
            tool_desc = description or f"Tool function: {tool_name}"
            param_schema = self._generate_param_schema(func, parameters)
            self._register_tool(tool_name, func, tool_desc, param_schema)
            # The original function stays directly callable.
            return func

        return decorator

    def _register_tool(self, name: str, func, description: str, param_schema: Dict[str, Any]):
        """Store a tool's function, description, schema and async flag under *name*."""
        self.tools[name] = {
            "function": func,
            "description": description,
            "parameters": param_schema,
            "is_async": inspect.iscoroutinefunction(func),
        }
        logger.info(f"Registered tool '{name}' to server '{self.name}'")

    @staticmethod
    def _is_undefined_default(value) -> bool:
        """Return True when *value* means "no default" (``...`` or pydantic's marker)."""
        # Compared by string form so the markers of both pydantic v1 and v2
        # are recognised without importing pydantic internals.
        return value is ... or str(value).endswith("PydanticUndefined")

    def _generate_param_schema(self, func, additional_params: Optional[Dict[str, Any]] = None):
        """Build a JSON Schema for *func*'s parameters in the MCP sample format.

        A parameter is listed as ``required`` only when it has no usable default.
        This covers a missing default, and a ``Field`` with neither a default nor
        a ``default_factory``. Non-``None`` defaults are emitted as ``"default"``.
        ``self``, ``*args`` and ``**kwargs`` are not described.

        Args:
            func: The tool function to inspect.
            additional_params: Optional mapping of parameter name to a schema
                fragment that is merged over the generated entry.

        Returns:
            A dict with ``properties``, ``type`` (``"object"``), ``required``
            and ``title`` keys.
        """
        sig = inspect.signature(func)
        try:
            type_hints = get_type_hints(func)
        except (NameError, TypeError) as e:
            # Unresolvable forward references must not break tool registration.
            # Fall back to the raw (possibly string) annotations.
            logger.warning(f"Could not resolve type hints for {func.__name__}: {e}")
            type_hints = dict(getattr(func, "__annotations__", None) or {})

        properties = {}
        required = []
        for name, param in sig.parameters.items():
            # ``self`` and variadic parameters are not named tool arguments.
            if name == "self" or param.kind in (
                inspect.Parameter.VAR_POSITIONAL,
                inspect.Parameter.VAR_KEYWORD,
            ):
                continue

            param_info = self._type_to_schema(type_hints.get(name, inspect.Parameter.empty))
            # Title: space-separated capitalized words, e.g. "query_list" -> "Query List".
            param_info["title"] = " ".join(word.capitalize() for word in name.split("_"))

            default = param.default
            if isinstance(default, FieldInfo):
                if default.description:
                    param_info["description"] = default.description
                has_factory = getattr(default, "default_factory", None) is not None
                field_default = default.default
                if not has_factory and self._is_undefined_default(field_default):
                    required.append(name)
                elif not has_factory and field_default is not None:
                    param_info["default"] = field_default
                # Field(None) and default_factory fields are optional, with no "default" key.
            elif default is inspect.Parameter.empty:
                required.append(name)
            elif default is not None:
                param_info["default"] = default
            # A plain ``None`` default is optional and emits no "default" key.

            if additional_params and name in additional_params:
                param_info.update(additional_params[name])
            properties[name] = param_info

        # Special handling kept for compatibility: query_list is always required.
        if "query_list" in properties and "query_list" not in required:
            required.append("query_list")

        return {
            "properties": properties,
            "type": "object",
            "required": required,
            "title": func.__name__ + "Arguments",
        }

    def _type_to_schema(self, type_hint):
        """Convert a Python type annotation into a JSON Schema fragment.

        Supported forms:
            - ``str``, ``int``, ``float`` and ``bool``.
            - ``list`` / ``List[X]`` / ``list[X]``, recursing into ``X``.
            - ``dict`` / ``Dict[...]``.
            - ``Optional[X]`` / ``X | None``, mapped to the schema of ``X``.

        Anything else, including multi-type unions, falls back to
        ``{"type": "string"}``.
        """
        import types
        import typing

        origin = typing.get_origin(type_hint)
        union_origins = (typing.Union, getattr(types, "UnionType", typing.Union))
        if origin in union_origins:
            members = [arg for arg in typing.get_args(type_hint) if arg is not type(None)]
            if len(members) == 1:
                return self._type_to_schema(members[0])
            return {"type": "string"}

        if type_hint is str:
            return {"type": "string"}
        if type_hint is int:
            return {"type": "integer"}
        if type_hint is float:
            return {"type": "number"}
        if type_hint is bool:
            return {"type": "boolean"}
        if type_hint is list or origin is list:
            args = typing.get_args(type_hint)
            return {
                "type": "array",
                "items": self._type_to_schema(args[0] if args else None),
            }
        if type_hint is dict or origin is dict:
            return {"type": "object"}
        return {"type": "string"}

    def list_tools(self) -> List[MCPTool]:
        """Return every registered tool as an ``MCPTool`` (annotations left unset)."""
        return [
            MCPTool(
                name=name,
                description=info["description"],
                inputSchema=info["parameters"],
            )
            for name, info in self.tools.items()
        ]

    def _lookup_tool(self, tool_name: str):
        """Return ``(function, is_async)`` for *tool_name*.

        Raises:
            ValueError: If the tool is unknown.
        """
        tool_info = self.tools.get(tool_name)
        if tool_info is None:
            raise ValueError(f"Tool '{tool_name}' not found in server '{self.name}'")
        return tool_info["function"], tool_info["is_async"]

    def _apply_field_defaults(self, func, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """Fill in omitted arguments whose default is a pydantic ``Field(...)``.

        Without this step the function would receive the ``FieldInfo`` object
        itself instead of the value the schema advertised. Fields with no
        default are left unset.
        """
        resolved = dict(arguments)
        for name, param in inspect.signature(func).parameters.items():
            if name in resolved or not isinstance(param.default, FieldInfo):
                continue
            if param.kind is inspect.Parameter.POSITIONAL_ONLY:
                continue  # cannot be supplied by keyword
            field_info = param.default
            factory = getattr(field_info, "default_factory", None)
            if factory is not None:
                resolved[name] = factory()
            elif not self._is_undefined_default(field_info.default):
                resolved[name] = field_info.default
        return resolved

    def _error_result(self, tool_name: str, error: Exception) -> CallToolResult:
        """Log a failed tool call and wrap the error in an MCP error result.

        Call this from inside an ``except`` block so the traceback is available.
        """
        logger.error(f"Error calling tool '{tool_name}': {str(error)}")
        logger.debug(traceback.format_exc())
        return CallToolResult(
            content=[TextContent(type="text", text=f"Error: {str(error)}")],
            isError=True,
        )

    async def call_tool_async(self, tool_name: str, arguments: Optional[Dict[str, Any]] = None):
        """Asynchronously call the specified tool function.

        Async tools are awaited directly. Sync tools run in the running loop's
        default executor so they do not block it.

        Args:
            tool_name: Tool name
            arguments: Tool arguments. Keys the function does not accept are dropped.

        Returns:
            A ``CallToolResult`` with the formatted output. If the tool raised,
            it instead carries an ``Error: ...`` text item and ``isError=True``.

        Raises:
            ValueError: When the tool doesn't exist.
        """
        func, is_async = self._lookup_tool(tool_name)
        filtered_args = self._filter_arguments(func, arguments or {})
        try:
            call_args = self._apply_field_defaults(func, filtered_args)
            if is_async:
                result = await func(**call_args)
            else:
                import asyncio

                loop = asyncio.get_running_loop()
                result = await loop.run_in_executor(None, lambda: func(**call_args))
            return self._format_result(result)
        except Exception as e:
            return self._error_result(tool_name, e)

    def call_tool(self, tool_name: str, arguments: Optional[Dict[str, Any]] = None):
        """Synchronously call the specified tool function.

        Async tools are driven to completion from this synchronous call (see
        ``_run_coroutine_sync``).

        Args:
            tool_name: Tool name
            arguments: Tool arguments. Keys the function does not accept are dropped.

        Returns:
            A ``CallToolResult`` with the formatted output. If the tool raised or
            timed out, it instead carries an ``Error: ...`` text item and
            ``isError=True``.

        Raises:
            ValueError: When the tool doesn't exist.
        """
        func, is_async = self._lookup_tool(tool_name)
        filtered_args = self._filter_arguments(func, arguments or {})
        try:
            call_args = self._apply_field_defaults(func, filtered_args)
            if is_async:
                result = self._run_coroutine_sync(tool_name, func, call_args)
            else:
                result = func(**call_args)
            return self._format_result(result)
        except Exception as e:
            return self._error_result(tool_name, e)

    def _run_coroutine_sync(self, tool_name: str, func, call_args: Dict[str, Any]):
        """Run the async tool *func* to completion from synchronous code.

        Strategy:
            - Inside a running loop, use ``nest_asyncio`` when it is installed;
              otherwise run in a worker thread.
            - With no running loop, reuse (or create) this thread's event loop.
              Loop-bound resources therefore survive between calls.
        """
        import asyncio

        try:
            running_loop = asyncio.get_running_loop()
        except RuntimeError:
            running_loop = None

        if running_loop is not None:
            try:
                import nest_asyncio
            except ImportError:
                logger.warning("nest_asyncio not available, using alternative approach")
                return self._run_coroutine_in_thread(tool_name, func, call_args)
            nest_asyncio.apply()
            logger.debug(f"Applied nest_asyncio for {tool_name}")
            return running_loop.run_until_complete(func(**call_args))

        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = None
        if loop is None or loop.is_closed():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        if loop.is_running():
            # Checked up front rather than by catching RuntimeError afterwards,
            # so the tool can never be executed twice.
            logger.warning(f"Event loop already running, using thread approach for {tool_name}")
            return self._run_coroutine_in_thread(tool_name, func, call_args)
        return loop.run_until_complete(func(**call_args))

    def _run_coroutine_in_thread(self, tool_name: str, func, call_args: Dict[str, Any]):
        """Run the async tool *func* on a fresh event loop in a daemon worker thread.

        Raises:
            TimeoutError: If the tool does not finish within
                ``_THREAD_TIMEOUT_SECONDS``. The daemon thread is abandoned and
                does not block interpreter exit.
            BaseException: Whatever the tool raised is re-raised here.
        """
        import asyncio
        import queue
        import threading

        outcome = queue.Queue(maxsize=1)

        def runner():
            try:
                outcome.put(("result", asyncio.run(func(**call_args))))
            except BaseException as exc:  # forward everything so the caller never waits forever
                outcome.put(("error", exc))

        worker = threading.Thread(target=runner, name=f"function-tool-{tool_name}", daemon=True)
        worker.start()
        worker.join(timeout=self._THREAD_TIMEOUT_SECONDS)
        if worker.is_alive():
            raise TimeoutError(f"Timeout waiting for {tool_name} to complete")

        kind, value = outcome.get_nowait()
        if kind == "error":
            raise value
        return value

    def _filter_arguments(self, func, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """Keep only the arguments *func* can accept by keyword.

        If *func* declares ``**kwargs``, every argument is passed through.
        Otherwise unknown keys are dropped and logged at DEBUG level.

        Args:
            func: Function to call
            arguments: Input argument dictionary

        Returns:
            A new, filtered argument dictionary.
        """
        params = inspect.signature(func).parameters
        if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
            return dict(arguments)

        accepted = {
            name
            for name, p in params.items()
            if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
        }
        filtered_args = {}
        for name, value in arguments.items():
            if name in accepted:
                filtered_args[name] = value
            else:
                logger.debug(f"Filtered out argument '{name}' not defined in function {func.__name__}")
        return filtered_args

    @staticmethod
    def _guess_image_mime_type(data: bytes, default: str = "image/jpeg") -> str:
        """Guess an image MIME type from its leading magic bytes.

        Returns *default* when the format is not recognised.
        """
        if data.startswith(b"\x89PNG\r\n\x1a\n"):
            return "image/png"
        if data.startswith(b"\xff\xd8\xff"):
            return "image/jpeg"
        if data.startswith((b"GIF87a", b"GIF89a")):
            return "image/gif"
        if data[:4] == b"RIFF" and data[8:12] == b"WEBP":
            return "image/webp"
        return default

    @staticmethod
    def _to_json_text(value) -> str:
        """Serialize *value* as JSON (non-ASCII kept), falling back to ``str()``."""
        try:
            return json.dumps(value, ensure_ascii=False)
        except (TypeError, ValueError):
            # Not JSON-serializable (unsupported type, circular reference, ...)
            return str(value)

    def _dict_to_content(self, result: Dict[str, Any]):
        """Convert a dict tool result into one MCP content item.

        Handled shapes:
            - Content-shaped dicts (``{"type": "text"|"image", ...}``) are used as-is.
            - ``{"text", "metadata"}`` dicts keep their metadata.
            - Any other dict is JSON-encoded as text.
        """
        content_type = result.get("type")
        if content_type == "text":
            text_content = result.get("text", "")
            # Repair text that was stringified from a content object, e.g. "type='text' text='...'".
            if isinstance(text_content, str) and text_content.startswith("type="):
                import re

                match = re.search(r"text=['\"](.+?)['\"]", text_content)
                if match:
                    text_content = match.group(1)
            return TextContent(type="text", text=text_content)
        if content_type == "image":
            # ImageContent requires a mimeType. Default to JPEG when the dict has none.
            return ImageContent(
                type="image",
                data=result.get("data", ""),
                mimeType=result.get("mimeType") or "image/jpeg",
            )
        if "metadata" in result and "text" in result:
            return TextContent(type="text", text=result["text"], metadata=result["metadata"])
        return TextContent(type="text", text=self._to_json_text(result))

    def _format_result(self, result):
        """Format a tool's return value as an MCP ``CallToolResult``.

        Conversion rules:
            - ``CallToolResult``: returned unchanged.
            - ``TextContent`` / ``ImageContent``: wrapped as the single content item.
            - ``str``: text content.
            - ``bytes``: base64 image content, MIME type sniffed from the bytes.
            - ``dict``: see ``_dict_to_content``.
            - Anything else: JSON text, or ``str()`` if it is not serializable.
        """
        if isinstance(result, CallToolResult):
            return result

        if isinstance(result, (TextContent, ImageContent)):
            item = result
        elif isinstance(result, str):
            item = TextContent(type="text", text=result)
        elif isinstance(result, bytes):
            import base64

            item = ImageContent(
                type="image",
                data=base64.b64encode(result).decode("utf-8"),
                mimeType=self._guess_image_mime_type(result),
            )
        elif isinstance(result, dict):
            item = self._dict_to_content(result)
        else:
            item = TextContent(type="text", text=self._to_json_text(result))
        return CallToolResult(content=[item])
class FunctionToolsAdapter:
    """Adapter base class exposing a registered FunctionTools server through an
    MCPServer-style interface.

    It provides the basic adaptation; concrete adapters are expected to inherit
    from it and extend it.
    """

    def __init__(self, name: str):
        """Bind the adapter to the FunctionTools server registered as *name*.

        Args:
            name: Function tools server name

        Raises:
            ValueError: If no server with that name is registered.
        """
        self._function_tools = get_function_tools(name)
        if self._function_tools is None:
            raise ValueError(f"FunctionTools '{name}' not found")
        self._name = name

    def name(self) -> str:
        """Return the name of the wrapped server."""
        return self._name

    async def list_tools(self) -> List[MCPTool]:
        """List all tools of the wrapped server as ``MCPTool`` objects."""
        return self._function_tools.list_tools()

    async def call_tool(self, tool_name: str, arguments: Optional[Dict[str, Any]] = None):
        """Asynchronously call *tool_name* on the wrapped server."""
        return await self._function_tools.call_tool_async(tool_name, arguments)

    def to_action_result(self, result) -> ActionResult:
        """Convert an MCP call result into an AWorld ``ActionResult``.

        Only the first content item is used:
            - Text becomes the content string.
            - An image becomes a ``data:<mimeType>;base64,...`` URI. The MIME
              type falls back to ``image/jpeg`` when the item has none.

        Any ``metadata`` attribute on the item is carried over. An empty
        result, or an unsupported first item, yields empty content.

        Args:
            result: MCP call result

        Returns:
            ActionResult object (always ``keep=True``).
        """
        if not result or not result.content:
            return ActionResult(content="", keep=True)

        first = result.content[0]
        if isinstance(first, TextContent):
            return ActionResult(
                content=first.text,
                keep=True,
                metadata=getattr(first, "metadata", {}),
            )
        if isinstance(first, ImageContent):
            # Use the item's declared type instead of assuming every image is JPEG.
            mime_type = getattr(first, "mimeType", None) or "image/jpeg"
            return ActionResult(
                content=f"data:{mime_type};base64,{first.data}",
                keep=True,
                metadata=getattr(first, "metadata", {}),
            )
        return ActionResult(content="", keep=True)