Duibonduil's picture
Upload 9 files
7d18ad1 verified
# coding: utf-8
# Copyright (c) 2025 inclusionAI.
import inspect
import json
import logging
import traceback
from typing import Any, Dict, List, Optional, Union, get_type_hints
from mcp.types import TextContent, ImageContent, CallToolResult
from mcp import Tool as MCPTool
from pydantic import Field, create_model
from pydantic.fields import FieldInfo # Import FieldInfo type
from aworld.core.common import ActionResult
from aworld.logs.util import logger
# Global registry of FunctionTools servers, keyed by server name.
# Populated by _register_function_tools() and read by the lookup helpers below.
_FUNCTION_TOOLS_REGISTRY: Dict[str, Any] = dict()
def _register_function_tools(function_tools):
    """Store *function_tools* in the global registry under its ``name``.

    An existing entry with the same name is overwritten.
    """
    server_name = function_tools.name
    _FUNCTION_TOOLS_REGISTRY[server_name] = function_tools
    logger.info(f"Registered FunctionTools server: {server_name}")
def get_function_tools(name):
    """Return the FunctionTools server registered as *name*, or ``None`` if there is none."""
    try:
        return _FUNCTION_TOOLS_REGISTRY[name]
    except KeyError:
        return None
def list_function_tools():
    """Return a new list with the names of all registered FunctionTools servers (registration order)."""
    return [server_name for server_name in _FUNCTION_TOOLS_REGISTRY]
class FunctionTools:
    """Function tools server, providing tool registration and calling mechanism similar to MCP.

    Instances are singletons per ``name``. Constructing ``FunctionTools(name)`` again
    returns the instance already stored in the module registry (see ``__new__``). Every
    newly created instance is registered, so it can be looked up with
    ``get_function_tools(name)``.

    Example:
        ```python
        # Create function tools server
        function = FunctionTools("my-server", description="My function tools server")

        # Define tool function
        @function.tool(description="Example search function")
        def search(query: str, limit: int = 10) -> str:
            # Actual search logic
            results = [f"Result {i} for {query}" for i in range(limit)]
            return json.dumps(results)

        # Using pydantic Field defaults
        @function.tool(description="Example search function")
        def search(
            query: str = Field(description="Search query"),
            limit: int = Field(10, description="Max results")
        ) -> str:
            # Actual search logic
            results = [f"Result {i} for {query}" for i in range(limit)]
            return json.dumps(results)
        ```
    """

    # Seconds call_tool() waits for an async tool executed on a helper thread.
    _THREAD_TIMEOUT_SECONDS = 60

    def __new__(cls, name: str, description: Optional[str] = None, version: str = "1.0"):
        """Implement singleton pattern: return the existing instance registered under ``name``.

        Args:
            name: Server name (registry key)
            description: Server description (ignored if the instance already exists)
            version: Server version (ignored if the instance already exists)
        """
        if name in _FUNCTION_TOOLS_REGISTRY:
            logger.info(f"Returning existing FunctionTools instance: {name}")
            return _FUNCTION_TOOLS_REGISTRY[name]
        return super().__new__(cls)

    def __init__(self, name: str, description: Optional[str] = None, version: str = "1.0"):
        """Initialize the function tools server and register it globally.

        Args:
            name: Server name
            description: Server description (defaults to "Function tools server: <name>")
            version: Server version
        """
        # __init__ also runs when __new__ returned an existing instance; don't reset it.
        if hasattr(self, 'name') and self.name == name:
            return
        self.name = name
        self.description = description or f"Function tools server: {name}"
        self.version = version
        self.tools = {}
        _register_function_tools(self)

    def tool(self, description: Optional[str] = None, parameters: Optional[Dict[str, Any]] = None):
        """Decorator that registers the decorated function as a tool under its ``__name__``.

        Args:
            description: Tool description (defaults to "Tool function: <name>")
            parameters: Per-parameter dicts merged over the auto-generated parameter schema

        Returns:
            Decorator that returns the original function unchanged.
        """
        def decorator(func):
            tool_name = func.__name__
            tool_desc = description or f"Tool function: {tool_name}"
            param_schema = self._generate_param_schema(func, parameters)
            self._register_tool(tool_name, func, tool_desc, param_schema)
            # Return the original function so it stays directly callable.
            return func
        return decorator

    def _register_tool(self, name: str, func, description: str, param_schema: Dict[str, Any]):
        """Store a tool's callable, description, parameter schema and async flag."""
        self.tools[name] = {
            "function": func,
            "description": description,
            "parameters": param_schema,
            "is_async": inspect.iscoroutinefunction(func)
        }
        logger.info(f"Registered tool '{name}' to server '{self.name}'")

    @staticmethod
    def _is_undefined(value) -> bool:
        """True for pydantic's "no default" markers (``...`` or ``PydanticUndefined``)."""
        return value is ... or str(value).endswith("PydanticUndefined")

    @classmethod
    def _field_is_required(cls, field_info) -> bool:
        """True when a pydantic ``Field`` supplies neither a default nor a ``default_factory``."""
        if getattr(field_info, "default_factory", None) is not None:
            return False
        return cls._is_undefined(field_info.default)

    def _generate_param_schema(self, func, additional_params: Optional[Dict[str, Any]] = None):
        """Generate a JSON-schema ``object`` from ``func``'s signature (MCP sample format).

        A parameter is required only if it has no default at all, or if its default is a
        pydantic ``Field`` without a default/default_factory. ``None`` defaults make a
        parameter optional but are not written into the schema.

        Args:
            func: Tool function to introspect
            additional_params: Per-parameter dicts merged over the generated properties

        Returns:
            Dict with ``properties``, ``type``, ``required`` and ``title`` keys.
        """
        sig = inspect.signature(func)
        type_hints = get_type_hints(func)
        properties = {}
        required = []

        for name, param in sig.parameters.items():
            # ``self`` and catch-all *args / **kwargs are not individual tool arguments.
            if name == 'self' or param.kind in (inspect.Parameter.VAR_POSITIONAL,
                                                inspect.Parameter.VAR_KEYWORD):
                continue

            param_type = type_hints.get(name, inspect.Parameter.empty)
            # Identity check: ``!=`` would call the default value's own __eq__.
            has_default = param.default is not inspect.Parameter.empty

            param_info = self._type_to_schema(param_type)
            # Title: space-separated capitalized words, e.g. "max_results" -> "Max Results"
            param_info["title"] = " ".join(word.capitalize() for word in name.split("_"))

            if has_default and isinstance(param.default, FieldInfo):
                field_info = param.default
                if field_info.description:
                    param_info["description"] = field_info.description
                if self._field_is_required(field_info):
                    required.append(name)
                else:
                    field_default = field_info.default
                    if field_default is not None and not self._is_undefined(field_default):
                        param_info["default"] = field_default
            elif has_default:
                if param.default is not None:
                    param_info["default"] = param.default
            else:
                required.append(name)

            if additional_params and name in additional_params:
                param_info.update(additional_params[name])
            properties[name] = param_info

        # Special handling (kept for compatibility): query_list is always required.
        if "query_list" in properties and "query_list" not in required:
            required.append("query_list")

        return {
            "properties": properties,
            "type": "object",
            "required": required,
            "title": func.__name__ + "Arguments"
        }

    def _type_to_schema(self, type_hint):
        """Convert a Python type annotation to a minimal JSON-schema fragment.

        ``Optional[X]`` / ``X | None`` maps like ``X``. ``list`` / ``List[X]`` becomes an
        ``array`` whose ``items`` follow ``X``. ``dict`` / ``Dict[...]`` becomes ``object``.
        Anything else, including a missing annotation, falls back to ``string``.
        """
        import types
        import typing

        origin = typing.get_origin(type_hint)

        union_origins = {typing.Union}
        if hasattr(types, "UnionType"):
            union_origins.add(types.UnionType)
        if origin in union_origins:
            non_none = [arg for arg in typing.get_args(type_hint) if arg is not type(None)]
            if len(non_none) == 1:
                return self._type_to_schema(non_none[0])
            return {"type": "string"}

        for py_type, json_type in ((str, "string"), (int, "integer"),
                                   (float, "number"), (bool, "boolean")):
            if type_hint is py_type:
                return {"type": json_type}

        if type_hint is list or origin is list:
            args = typing.get_args(type_hint)
            item_type = args[0] if args else None
            return {
                "type": "array",
                "items": self._type_to_schema(item_type)
            }
        if type_hint is dict or origin is dict:
            return {"type": "object"}
        return {"type": "string"}

    def list_tools(self) -> List[MCPTool]:
        """List all registered tools as MCPTool objects (name, description, input schema)."""
        return [
            MCPTool(name=name, description=info["description"], inputSchema=info["parameters"])
            for name, info in self.tools.items()
        ]

    def _error_result(self, tool_name: str, error: Exception):
        """Log a tool failure and wrap it as an MCP error result (``isError=True``).

        Must be called from inside an ``except`` block so the traceback is available.
        """
        logger.error(f"Error calling tool '{tool_name}': {str(error)}")
        logger.debug(traceback.format_exc())
        return CallToolResult(
            content=[TextContent(type="text", text=f"Error: {str(error)}")],
            isError=True
        )

    async def call_tool_async(self, tool_name: str, arguments: Optional[Dict[str, Any]] = None):
        """Asynchronously call the specified tool function.

        Sync tools run in the loop's default executor so they do not block the loop.

        Args:
            tool_name: Tool name
            arguments: Tool arguments (unknown names are dropped)

        Returns:
            CallToolResult. Exceptions raised by the tool are returned as an error result
            (``isError=True``), not raised.

        Raises:
            ValueError: When the tool doesn't exist
        """
        import asyncio

        if tool_name not in self.tools:
            raise ValueError(f"Tool '{tool_name}' not found in server '{self.name}'")
        tool_info = self.tools[tool_name]
        func = tool_info["function"]
        filtered_args = self._filter_arguments(func, arguments or {})

        try:
            if tool_info["is_async"]:
                result = await func(**filtered_args)
            else:
                loop = asyncio.get_running_loop()
                result = await loop.run_in_executor(None, lambda: func(**filtered_args))
            return self._format_result(result)
        except Exception as e:
            return self._error_result(tool_name, e)

    def call_tool(self, tool_name: str, arguments: Optional[Dict[str, Any]] = None):
        """Synchronously call the specified tool function.

        Async tools are driven to completion via ``_run_async_tool_sync``.

        Args:
            tool_name: Tool name
            arguments: Tool arguments (unknown names are dropped)

        Returns:
            CallToolResult. Exceptions raised by the tool, including timeouts of the
            helper-thread fallback, are returned as an error result (``isError=True``).

        Raises:
            ValueError: When the tool doesn't exist
        """
        if tool_name not in self.tools:
            raise ValueError(f"Tool '{tool_name}' not found in server '{self.name}'")
        tool_info = self.tools[tool_name]
        func = tool_info["function"]
        filtered_args = self._filter_arguments(func, arguments or {})

        try:
            if tool_info["is_async"]:
                result = self._run_async_tool_sync(tool_name, func, filtered_args)
            else:
                result = func(**filtered_args)
            return self._format_result(result)
        except Exception as e:
            return self._error_result(tool_name, e)

    def _run_async_tool_sync(self, tool_name: str, func, filtered_args: Dict[str, Any]):
        """Run coroutine function ``func`` to completion from synchronous code.

        1. Inside a running event loop: apply ``nest_asyncio`` (if installed) and run
           re-entrantly. Otherwise use a fresh loop on a helper thread.
        2. No running loop: reuse this thread's event loop, or create and install one.
        3. If a loop reports "This event loop is already running", fall back to the
           helper-thread approach.
        """
        import asyncio

        try:
            asyncio.get_running_loop()
            in_running_loop = True
        except RuntimeError:
            in_running_loop = False

        try:
            if in_running_loop:
                try:
                    import nest_asyncio
                except ImportError:
                    logger.warning("nest_asyncio not available, using alternative approach")
                    return self._run_coroutine_in_thread(tool_name, func, filtered_args)
                nest_asyncio.apply()
                logger.debug(f"Applied nest_asyncio for {tool_name}")
            try:
                loop = asyncio.get_event_loop()
            except RuntimeError:
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
            return loop.run_until_complete(func(**filtered_args))
        except RuntimeError as e:
            if "This event loop is already running" not in str(e):
                raise
            logger.warning(f"Event loop already running, using thread approach for {tool_name}")
            return self._run_coroutine_in_thread(tool_name, func, filtered_args)

    def _run_coroutine_in_thread(self, tool_name: str, func, filtered_args: Dict[str, Any]):
        """Run ``func(**filtered_args)`` on a new event loop in a helper thread and wait.

        Raises:
            TimeoutError: If the tool doesn't finish within ``_THREAD_TIMEOUT_SECONDS``.
                The helper thread is left running.
            RuntimeError: If the thread ended without reporting a result.
            Exception: Whatever the tool itself raised.
        """
        import asyncio
        import queue
        import threading

        result_queue = queue.Queue()

        def run_async_in_thread():
            new_loop = None
            try:
                new_loop = asyncio.new_event_loop()
                asyncio.set_event_loop(new_loop)
                result_queue.put(("result", new_loop.run_until_complete(func(**filtered_args))))
            except Exception as e:
                result_queue.put(("error", e))
            finally:
                # new_loop may be unset if new_event_loop() itself failed.
                if new_loop is not None:
                    new_loop.close()

        thread = threading.Thread(target=run_async_in_thread)
        thread.start()
        thread.join(timeout=self._THREAD_TIMEOUT_SECONDS)
        if thread.is_alive():
            raise TimeoutError(f"Timeout waiting for {tool_name} to complete")

        # The thread has finished, so a missing entry means it died without reporting;
        # a blocking get() would hang forever.
        try:
            result_type, result_value = result_queue.get_nowait()
        except queue.Empty:
            raise RuntimeError(f"Tool '{tool_name}' finished without producing a result") from None
        if result_type == "error":
            raise result_value
        return result_value

    def _filter_arguments(self, func, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """Keep only the arguments ``func`` accepts and fill in pydantic ``Field`` defaults.

        - If ``func`` declares ``**kwargs``, every argument is passed through.
        - Otherwise, names not in the signature (or naming ``*args``) are dropped and logged.
        - For omitted parameters whose default is a pydantic ``Field``, the field's real
          default (or ``default_factory()``) is injected, so the function does not receive
          the ``FieldInfo`` object itself. Required Fields are left untouched.

        Args:
            func: Function to call
            arguments: Input argument dictionary

        Returns:
            Filtered argument dictionary
        """
        params = inspect.signature(func).parameters
        accepts_var_kwargs = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())
        accepted_names = {
            name for name, p in params.items()
            if p.kind is not inspect.Parameter.VAR_POSITIONAL
        }

        filtered_args = {}
        for name, value in arguments.items():
            if accepts_var_kwargs or name in accepted_names:
                filtered_args[name] = value
            else:
                logger.debug(f"Filtered out argument '{name}' not defined in function {func.__name__}")

        for name, param in params.items():
            field_info = param.default
            if name in filtered_args or not isinstance(field_info, FieldInfo):
                continue
            factory = getattr(field_info, "default_factory", None)
            if factory is not None:
                filtered_args[name] = factory()
            elif not self._is_undefined(field_info.default):
                filtered_args[name] = field_info.default
        return filtered_args

    @staticmethod
    def _guess_image_mime_type(data: bytes) -> str:
        """Guess an image MIME type from magic bytes, defaulting to ``image/jpeg``."""
        if data.startswith(b"\x89PNG\r\n\x1a\n"):
            return "image/png"
        if data.startswith(b"\xff\xd8\xff"):
            return "image/jpeg"
        if data.startswith((b"GIF87a", b"GIF89a")):
            return "image/gif"
        if data[:4] == b"RIFF" and data[8:12] == b"WEBP":
            return "image/webp"
        return "image/jpeg"

    @staticmethod
    def _unwrap_serialized_text(text):
        """If ``text`` looks like a serialized repr (``type='text' text='...'``), extract the inner text."""
        if isinstance(text, str) and text.startswith("type="):
            import re
            match = re.search(r"text=['\"](.+?)['\"]", text)
            if match:
                return match.group(1)
        return text

    @staticmethod
    def _to_json_text(value) -> str:
        """JSON-encode ``value`` (non-ASCII kept). Falls back to ``str(value)`` if not serializable."""
        try:
            return json.dumps(value, ensure_ascii=False)
        except (TypeError, ValueError, RecursionError):
            return str(value)

    def _format_result(self, result):
        """Format a tool's return value as an MCP ``CallToolResult``.

        - ``CallToolResult``: returned unchanged.
        - ``str``: one text item.
        - ``bytes``: one base64 image item (MIME type sniffed from magic bytes).
        - ``TextContent`` / ``ImageContent``: used as the single item.
        - ``dict`` with ``type`` "text"/"image": converted to the matching item.
        - ``dict`` with ``text`` and ``metadata``: text item carrying the metadata.
        - Anything else: JSON text, or ``str()`` if not serializable.
        """
        if isinstance(result, CallToolResult):
            return result

        if isinstance(result, str):
            item = TextContent(type="text", text=result)
        elif isinstance(result, bytes):
            import base64
            item = ImageContent(
                type="image",
                data=base64.b64encode(result).decode('utf-8'),
                mimeType=self._guess_image_mime_type(result)
            )
        elif isinstance(result, (TextContent, ImageContent)):
            item = result
        elif isinstance(result, dict) and result.get("type") == "text":
            item = TextContent(type="text", text=self._unwrap_serialized_text(result.get("text", "")))
        elif isinstance(result, dict) and result.get("type") == "image":
            item = ImageContent(
                type="image",
                data=result.get("data", ""),
                mimeType=result.get("mimeType") or "image/jpeg"
            )
        elif isinstance(result, dict) and "metadata" in result and "text" in result:
            item = TextContent(type="text", text=result["text"], metadata=result["metadata"])
        else:
            item = TextContent(type="text", text=self._to_json_text(result))
        return CallToolResult(content=[item])
class FunctionToolsAdapter:
    """Adapter base class for adapting FunctionTools to the MCPServer interface.

    This class provides basic adaptation functionality only. Specific implementations
    are expected to inherit from it and extend it.
    """

    def __init__(self, name: str):
        """Bind the adapter to the FunctionTools server registered under ``name``.

        Args:
            name: Function tools server name

        Raises:
            ValueError: If no FunctionTools server with that name is registered
        """
        self._function_tools = get_function_tools(name)
        if not self._function_tools:
            raise ValueError(f"FunctionTools '{name}' not found")
        self._name = name

    @property
    def name(self) -> str:
        """Server name."""
        return self._name

    async def list_tools(self) -> List[MCPTool]:
        """List all tools of the wrapped server as MCPTool objects."""
        return self._function_tools.list_tools()

    async def call_tool(self, tool_name: str, arguments: Optional[Dict[str, Any]] = None):
        """Asynchronously call ``tool_name`` on the wrapped server and return its result."""
        return await self._function_tools.call_tool_async(tool_name, arguments)

    def to_action_result(self, result) -> ActionResult:
        """Convert an MCP call result to an AWorld ``ActionResult``.

        Only the first content item is used:
        - TextContent: its text, plus its ``metadata`` attribute if present (else ``{}``).
        - ImageContent: a ``data:`` URI built from the item's ``mimeType``
          (``image/jpeg`` if absent) and its base64 data.
        - Anything else (``None`` result, empty content, other item types): empty content.

        Args:
            result: MCP call result (object with a ``content`` list), or ``None``

        Returns:
            ActionResult with ``keep=True``
        """
        first = result.content[0] if result and result.content else None

        if isinstance(first, TextContent):
            return ActionResult(
                content=first.text,
                keep=True,
                metadata=getattr(first, "metadata", {})
            )
        if isinstance(first, ImageContent):
            mime_type = getattr(first, "mimeType", None) or "image/jpeg"
            return ActionResult(
                content=f"data:{mime_type};base64,{first.data}",
                keep=True,
                metadata=getattr(first, "metadata", {})
            )
        return ActionResult(content="", keep=True)