Spaces:
Running
Running
import base64 | |
import contextlib | |
import os | |
import re | |
import tempfile | |
import warnings | |
from collections.abc import AsyncIterator, Sequence | |
from io import BytesIO | |
from pathlib import Path | |
from typing import TYPE_CHECKING, Any | |
import gradio_client.utils as client_utils | |
from mcp import types | |
from mcp.server import Server | |
from mcp.server.sse import SseServerTransport | |
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager | |
from PIL import Image | |
from starlette.applications import Starlette | |
from starlette.requests import Request | |
from starlette.responses import JSONResponse, Response | |
from starlette.routing import Mount, Route | |
from starlette.types import Receive, Scope, Send | |
from gradio import processing_utils, route_utils, utils | |
from gradio.blocks import BlockFunction | |
from gradio.components import State | |
from gradio.data_classes import FileData | |
if TYPE_CHECKING: | |
from gradio.blocks import BlockContext, Blocks | |
from gradio.components import Component | |
DEFAULT_TEMP_DIR = os.environ.get("GRADIO_TEMP_DIR") or str( | |
Path(tempfile.gettempdir()) / "gradio" | |
) | |
class GradioMCPServer: | |
""" | |
A class for creating an MCP server around a Gradio app. | |
Args: | |
blocks: The Blocks app to create the MCP server for. | |
""" | |
def __init__(self, blocks: "Blocks", root_path: str): | |
self.blocks = blocks | |
self.api_info = self.blocks.get_api_info() | |
self.mcp_server = self.create_mcp_server() | |
self.request = None | |
self.root_url = None | |
tool_prefix = utils.get_space() | |
if tool_prefix: | |
tool_prefix = tool_prefix.split("/")[-1] + "_" | |
self.tool_prefix = re.sub(r"[^a-zA-Z0-9]", "_", tool_prefix) | |
else: | |
self.tool_prefix = "" | |
self.tool_to_endpoint = self.get_tool_to_endpoint() | |
self.warn_about_state_inputs() | |
manager = StreamableHTTPSessionManager( | |
app=self.mcp_server, json_response=False, stateless=True | |
) | |
async def handle_streamable_http( | |
scope: Scope, receive: Receive, send: Send | |
) -> None: | |
request = Request(scope, receive) | |
self.request = request | |
self.root_url = route_utils.get_root_url( | |
request=request, | |
route_path="/gradio_api/mcp/http", | |
root_path=root_path, | |
) | |
await manager.handle_request(scope, receive, send) | |
async def lifespan(app: Starlette) -> AsyncIterator[None]: # noqa: ARG001 | |
"""Context manager for managing session manager lifecycle.""" | |
async with manager.run(): | |
try: | |
yield | |
finally: | |
pass | |
self.lifespan = lifespan | |
self.manager = manager | |
self.handle_streamable_http = handle_streamable_http | |
def get_tool_to_endpoint(self) -> dict[str, str]: | |
""" | |
Gets all of the tools that are exposed by the Gradio app and also | |
creates a mapping from the tool names to the endpoint names in the API docs. | |
""" | |
tool_to_endpoint = {} | |
for endpoint_name, endpoint_info in self.api_info["named_endpoints"].items(): | |
if endpoint_info["show_api"]: | |
block_fn = self.get_block_fn_from_endpoint_name(endpoint_name) | |
if block_fn is None or block_fn.fn is None: | |
continue | |
fn_name = ( | |
getattr(block_fn.fn, "__name__", None) | |
or ( | |
hasattr(block_fn.fn, "__class__") | |
and getattr(block_fn.fn.__class__, "__name__", None) | |
) | |
or endpoint_name.lstrip("/") | |
) | |
tool_name = self.tool_prefix + fn_name | |
while tool_name in tool_to_endpoint: | |
tool_name = tool_name + "_" | |
tool_to_endpoint[tool_name] = endpoint_name | |
return tool_to_endpoint | |
def warn_about_state_inputs(self) -> None: | |
""" | |
Warn about tools that have gr.State inputs. | |
""" | |
for _, endpoint_name in self.tool_to_endpoint.items(): | |
block_fn = self.get_block_fn_from_endpoint_name(endpoint_name) | |
if block_fn and any(isinstance(input, State) for input in block_fn.inputs): | |
warnings.warn( | |
"This MCP server includes a tool that has a gr.State input, which will not be " | |
"updated between tool calls. The original, default value of the State will be " | |
"used each time." | |
) | |
def create_mcp_server(self) -> Server: | |
""" | |
Create an MCP server for the given Gradio Blocks app. | |
Parameters: | |
blocks: The Blocks app to create the MCP server for. | |
Returns: | |
The MCP server. | |
""" | |
server = Server(str(self.blocks.title or "Gradio App")) | |
async def call_tool( | |
name: str, arguments: dict[str, Any] | |
) -> list[types.TextContent | types.ImageContent]: | |
""" | |
Call a tool on the Gradio app. | |
Args: | |
name: The name of the tool to call. | |
arguments: The arguments to pass to the tool. | |
""" | |
_, filedata_positions = self.get_input_schema(name) | |
processed_kwargs = self.convert_strings_to_filedata( | |
arguments, filedata_positions | |
) | |
endpoint_name = self.tool_to_endpoint.get(name) | |
if endpoint_name is None: | |
raise ValueError(f"Unknown tool for this Gradio app: {name}") | |
block_fn = self.get_block_fn_from_endpoint_name(endpoint_name) | |
assert block_fn is not None # noqa: S101 | |
if endpoint_name in self.api_info["named_endpoints"]: | |
parameters_info = self.api_info["named_endpoints"][endpoint_name][ | |
"parameters" | |
] | |
processed_args = client_utils.construct_args( | |
parameters_info, | |
(), | |
processed_kwargs, | |
) | |
else: | |
processed_args = [] | |
processed_args = self.insert_empty_state(block_fn.inputs, processed_args) | |
output = await self.blocks.process_api( | |
block_fn=block_fn, | |
inputs=processed_args, | |
request=self.request, | |
) | |
processed_args = self.pop_returned_state(block_fn.inputs, processed_args) | |
return self.postprocess_output_data(output["data"]) | |
async def list_tools() -> list[types.Tool]: | |
""" | |
List all tools on the Gradio app. | |
""" | |
tools = [] | |
for tool_name, endpoint_name in self.tool_to_endpoint.items(): | |
block_fn = self.get_block_fn_from_endpoint_name(endpoint_name) | |
assert block_fn is not None and block_fn.fn is not None # noqa: S101 | |
description, parameters, returns = utils.get_function_description( | |
block_fn.fn | |
) | |
if returns: | |
description += ( | |
("" if description.endswith(".") else ".") | |
+ " Returns: " | |
+ ", ".join(returns) | |
) | |
schema, _ = self.get_input_schema(tool_name, parameters) | |
tools.append( | |
types.Tool( | |
name=tool_name, | |
description=description, | |
inputSchema=schema, | |
) | |
) | |
return tools | |
return server | |
def launch_mcp_on_sse(self, app: Starlette, subpath: str, root_path: str) -> None: | |
""" | |
Launch the MCP server on the SSE transport. | |
Parameters: | |
app: The Gradio app to mount the MCP server on. | |
subpath: The subpath to mount the MCP server on. E.g. "/gradio_api/mcp" | |
""" | |
messages_path = "/messages/" | |
sse = SseServerTransport(messages_path) | |
async def handle_sse(request): | |
self.request = request | |
self.root_url = route_utils.get_root_url( | |
request=request, | |
route_path="/gradio_api/mcp/sse", | |
root_path=root_path, | |
) | |
try: | |
async with sse.connect_sse( | |
request.scope, request.receive, request._send | |
) as streams: | |
await self.mcp_server.run( | |
streams[0], | |
streams[1], | |
self.mcp_server.create_initialization_options(), | |
) | |
return Response() | |
except Exception as e: | |
print(f"MCP SSE connection error: {str(e)}") | |
raise | |
app.mount( | |
subpath, | |
Starlette( | |
routes=[ | |
Route( | |
"/schema", | |
endpoint=self.get_complete_schema, # Not required for MCP but useful for debugging | |
), | |
Route("/sse", endpoint=handle_sse), | |
Mount("/messages/", app=sse.handle_post_message), | |
Mount("/http/", app=self.handle_streamable_http), | |
], | |
), | |
) | |
def get_block_fn_from_endpoint_name( | |
self, endpoint_name: str | |
) -> "BlockFunction | None": | |
""" | |
Get the BlockFunction for a given endpoint name (e.g. "/predict"). | |
Parameters: | |
endpoint_name: The name of the endpoint to get the BlockFunction for. | |
Returns: | |
The BlockFunction for the given endpoint name, or None if it is not found. | |
""" | |
block_fn = next( | |
( | |
fn | |
for fn in self.blocks.fns.values() | |
if fn.api_name == endpoint_name.lstrip("/") | |
), | |
None, | |
) | |
return block_fn | |
def insert_empty_state( | |
inputs: Sequence["Component | BlockContext"], data: list | |
) -> list: | |
""" | |
Insert None placeholder values for any State input components, as State inputs | |
are not included in the endpoint schema. | |
""" | |
for i, input_component_type in enumerate(inputs): | |
if isinstance(input_component_type, State): | |
data.insert(i, None) | |
return data | |
def pop_returned_state( | |
inputs: Sequence["Component | BlockContext"], data: list | |
) -> list: | |
""" | |
Remove any values corresponding to State output components from the data | |
as State outputs are not included in the endpoint schema. | |
""" | |
for i, input_component_type in enumerate(inputs): | |
if isinstance(input_component_type, State): | |
data.pop(i) | |
return data | |
def get_input_schema( | |
self, | |
tool_name: str, | |
parameters: dict[str, str] | None = None, | |
) -> tuple[dict[str, Any], list[list[str | int]]]: | |
""" | |
Get the input schema of the Gradio app API, appropriately formatted for MCP. | |
Parameters: | |
tool_name: The name of the tool to get the schema for, e.g. "predict" | |
parameters: The description and parameters of the tool to get the schema for. | |
Returns: | |
- The input schema of the Gradio app API. | |
- A list of positions of FileData objects in the input schema. | |
""" | |
endpoint_name = self.tool_to_endpoint.get(tool_name) | |
if endpoint_name is None: | |
raise ValueError(f"Unknown tool for this Gradio app: {tool_name}") | |
named_endpoints = self.api_info["named_endpoints"] | |
endpoint_info = named_endpoints.get(endpoint_name) | |
assert endpoint_info is not None # noqa: S101 | |
schema = { | |
"type": "object", | |
"properties": { | |
p["parameter_name"]: { | |
**p["type"], | |
**( | |
{"description": parameters[p["parameter_name"]]} | |
if parameters and p["parameter_name"] in parameters | |
else {} | |
), | |
**( | |
{"default": p["parameter_default"]} | |
if "parameter_default" in p and p["parameter_default"] | |
else {} | |
), | |
} | |
for p in endpoint_info["parameters"] | |
}, | |
} | |
return self.simplify_filedata_schema(schema) | |
async def get_complete_schema(self, request) -> JSONResponse: # noqa: ARG002 | |
""" | |
Get the complete schema of the Gradio app API. (For debugging purposes) | |
Parameters: | |
request: The Starlette request object. | |
Returns: | |
A JSONResponse containing a dictionary mapping tool names to their input schemas. | |
""" | |
if not self.api_info: | |
return JSONResponse({}) | |
schemas = [] | |
for tool_name, endpoint_name in self.tool_to_endpoint.items(): | |
block_fn = self.get_block_fn_from_endpoint_name(endpoint_name) | |
assert block_fn is not None and block_fn.fn is not None # noqa: S101 | |
description, parameters, returns = utils.get_function_description( | |
block_fn.fn | |
) | |
if returns: | |
description += ( | |
("" if description.endswith(".") else ".") | |
+ " Returns: " | |
+ ", ".join(returns) | |
) | |
schema, _ = self.get_input_schema(tool_name, parameters) | |
info = { | |
"name": tool_name, | |
"description": description, | |
"inputSchema": schema, | |
} | |
schemas.append(info) | |
return JSONResponse(schemas) | |
def simplify_filedata_schema( | |
self, schema: dict[str, Any] | |
) -> tuple[dict[str, Any], list[list[str | int]]]: | |
""" | |
Parses a schema of a Gradio app API to identify positions of FileData objects. Replaces them with base64 | |
strings while keeping track of their positions so that they can be converted back to FileData objects | |
later. | |
Parameters: | |
schema: The original schema of the Gradio app API. | |
Returns: | |
A tuple containing the simplified schema and the positions of the FileData objects. | |
""" | |
def is_gradio_filedata(obj: Any, defs: dict[str, Any]) -> bool: | |
if not isinstance(obj, dict): | |
return False | |
if "$ref" in obj: | |
ref = obj["$ref"] | |
if ref.startswith("#/$defs/"): | |
key = ref.split("/")[-1] | |
obj = defs.get(key, {}) | |
else: | |
return False | |
props = obj.get("properties", {}) | |
meta = props.get("meta", {}) | |
if "$ref" in meta: | |
ref = meta["$ref"] | |
if ref.startswith("#/$defs/"): | |
key = ref.split("/")[-1] | |
meta = defs.get(key, {}) | |
else: | |
return False | |
type_field = meta.get("properties", {}).get("_type", {}) | |
default_type = meta.get("default", {}).get("_type") | |
return ( | |
type_field.get("const") == "gradio.FileData" | |
or default_type == "gradio.FileData" | |
) | |
def traverse( | |
node: Any, | |
path: list[str | int] | None = None, | |
defs: dict[str, Any] | None = None, | |
) -> Any: | |
if path is None: | |
path = [] | |
if defs is None: | |
defs = {} | |
if isinstance(node, dict): | |
if "$defs" in node: | |
defs.update(node["$defs"]) | |
if is_gradio_filedata(node, defs): | |
filedata_positions.append(path.copy()) | |
for key in ["properties", "additional_description", "$defs"]: | |
node.pop(key, None) | |
node["type"] = "string" | |
node["format"] = "a http or https url to a file" | |
result = {} | |
is_schema_root = "type" in node and "properties" in node | |
for key, value in node.items(): | |
if is_schema_root and key == "properties": | |
result[key] = traverse(value, path, defs) | |
else: | |
path.append(key) | |
result[key] = traverse(value, path, defs) | |
path.pop() | |
return result | |
elif isinstance(node, list): | |
result = [] | |
for i, item in enumerate(node): | |
path.append(i) | |
result.append(traverse(item, path, defs)) | |
path.pop() | |
return result | |
return node | |
filedata_positions: list[list[str | int]] = [] | |
simplified_schema = traverse(schema) | |
return simplified_schema, filedata_positions | |
def convert_strings_to_filedata( | |
self, value: Any, filedata_positions: list[list[str | int]] | |
) -> Any: | |
""" | |
Convert specific string values back to FileData objects based on their positions. | |
This is used to convert string values (as base64 encoded strings) to FileData | |
dictionaries so that they can be passed into .preprocess() logic of a Gradio app. | |
Parameters: | |
value: The input data to process, which can be an arbitrary nested data structure | |
that may or may not contain strings that should be converted to FileData objects. | |
filedata_positions: List of paths to positions in the input data that should be converted to FileData objects. | |
Returns: | |
The processed data with strings converted to FileData objects where appropriate. Base64 | |
encoded strings are first saved to a temporary file and then converted to a FileData object. | |
Example: | |
>>> convert_strings_to_filedata( | |
{"image": "data:image/jpeg;base64,..."}, | |
[["image"]] | |
) | |
>>> {'image': FileData(path='<temporary file path>')}, | |
""" | |
def traverse(node: Any, path: list[str | int] | None = None) -> Any: | |
if path is None: | |
path = [] | |
if isinstance(node, dict): | |
return { | |
key: traverse(value, path + [key]) for key, value in node.items() | |
} | |
elif isinstance(node, list): | |
return [traverse(item, path + [i]) for i, item in enumerate(node)] | |
elif isinstance(node, str) and path in filedata_positions: | |
if node.startswith("data:"): | |
# Even though base64 is not officially part of our schema, some MCP clients | |
# might return base64 encoded strings, so try to save it to a temporary file. | |
return FileData( | |
path=processing_utils.save_base64_to_cache( | |
node, DEFAULT_TEMP_DIR | |
) | |
) | |
elif node.startswith(("http://", "https://")): | |
return FileData(path=node) | |
else: | |
raise ValueError( | |
f"Invalid file data format, provide a url ('http://...' or 'https://...'). Received: {node}" | |
) | |
return node | |
return traverse(value) | |
def get_image(file_path: str) -> Image.Image | None: | |
""" | |
If a filepath is a valid image, returns a PIL Image object. Otherwise returns None. | |
""" | |
if not os.path.exists(file_path): | |
return None | |
ext = os.path.splitext(file_path.lower())[1] | |
if ext not in Image.registered_extensions(): | |
return None | |
try: | |
return Image.open(file_path) | |
except Exception: | |
return None | |
def get_base64_data(image: Image.Image, format: str) -> str: | |
""" | |
Returns a base64 encoded string of the image. | |
""" | |
buffer = BytesIO() | |
image.save(buffer, format=format) | |
return base64.b64encode(buffer.getvalue()).decode("utf-8") | |
def postprocess_output_data( | |
self, data: Any | |
) -> list[types.TextContent | types.ImageContent]: | |
""" | |
Postprocess the output data from the Gradio app to convert FileData objects back to base64 encoded strings. | |
Parameters: | |
data: The output data to postprocess. | |
""" | |
return_values = [] | |
if self.root_url: | |
data = processing_utils.add_root_url(data, self.root_url, None) | |
for output in data: | |
if client_utils.is_file_obj_with_meta(output): | |
if image := self.get_image(output["path"]): | |
image_format = image.format or "png" | |
base64_data = self.get_base64_data(image, image_format) | |
mimetype = f"image/{image_format.lower()}" | |
return_value = [ | |
types.ImageContent( | |
type="image", data=base64_data, mimeType=mimetype | |
), | |
types.TextContent( | |
type="text", | |
text=f"Image URL: {output['url'] or output['path']}", | |
), | |
] | |
else: | |
return_value = [ | |
types.TextContent( | |
type="text", text=str(output["url"] or output["path"]) | |
) | |
] | |
else: | |
return_value = [types.TextContent(type="text", text=str(output))] | |
return_values.extend(return_value) | |
return return_values | |