Spaces:
Paused
Paused
import os | |
import json | |
import requests | |
from fastapi import FastAPI, Request | |
from fastapi.responses import Response | |
from fastapi.middleware.cors import CORSMiddleware | |
from mcp.server.lowlevel import Server, NotificationOptions | |
from mcp.server.sse import SseServerTransport | |
from mcp import types as mcp_types | |
import uvicorn | |
from sse_starlette import EventSourceResponse | |
import anyio | |
import asyncio | |
import logging | |
from typing import Dict | |
# Logging setup: DEBUG-level root configuration plus a logger named after
# this module.
# NOTE(review): DEBUG output from this module includes raw request bodies and
# Airtable data; consider a quieter level outside development.
logging.basicConfig(level="DEBUG")
logger = logging.getLogger(name=__name__)
app = FastAPI()

# CORS: let browser-based MCP clients (such as Deep Agent) reach this server.
# NOTE(review): wildcard origins combined with allow_credentials=True means any
# website can call these endpoints from a browser — narrow the origin list
# before exposing this publicly ("Adjust for production").
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["GET", "POST", "OPTIONS"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# Load environment variables.
# Airtable credentials and target table. The table id can be overridden with
# AIRTABLE_TABLE_ID; an unset or empty value falls back to the original table.
AIRTABLE_API_TOKEN = os.getenv("AIRTABLE_API_TOKEN")
AIRTABLE_BASE_ID = os.getenv("AIRTABLE_BASE_ID")
TABLE_ID = os.getenv("AIRTABLE_TABLE_ID") or "tblQECi5f7m4y2NEV"
# REST URL of the configured table. Built once at import time, so the
# environment variables above must be set before this module is imported.
AIRTABLE_API_URL = f"https://api.airtable.com/v0/{AIRTABLE_BASE_ID}/{TABLE_ID}"
# Helper function for Airtable API requests
def airtable_request(method, endpoint="", data=None, timeout=30):
    """Send one request to the configured Airtable table and return its JSON.

    Args:
        method: HTTP method, e.g. ``"GET"`` or ``"POST"``.
        endpoint: Optional path appended to ``AIRTABLE_API_URL`` after a
            ``/``; an empty value targets the table URL itself.
        data: JSON-serialisable request body, or ``None`` to send no body.
        timeout: Seconds to wait for the connection/response. ``requests``
            has no timeout by default, so without this a stalled Airtable
            connection would hang the calling tool indefinitely.

    Returns:
        The decoded JSON body of the response.

    Raises:
        RuntimeError: ``AIRTABLE_API_TOKEN`` or ``AIRTABLE_BASE_ID`` is unset
            (otherwise the request would go to a ``.../None/...`` URL or carry
            ``Bearer None``).
        requests.HTTPError: the response status is not 2xx.
        requests.RequestException: network failure or timeout.
    """
    if not AIRTABLE_API_TOKEN or not AIRTABLE_BASE_ID:
        raise RuntimeError("AIRTABLE_API_TOKEN and AIRTABLE_BASE_ID must be set")
    headers = {
        "Authorization": f"Bearer {AIRTABLE_API_TOKEN}",
        "Content-Type": "application/json",
    }
    url = f"{AIRTABLE_API_URL}/{endpoint}" if endpoint else AIRTABLE_API_URL
    response = requests.request(method, url, headers=headers, json=data, timeout=timeout)
    response.raise_for_status()
    return response.json()
# Tool to list records
async def list_records_tool(request: mcp_types.CallToolRequest):
    """List records from the configured Airtable table.

    Args:
        request: The tool-call request. It is only logged; this tool takes
            no arguments.

    Returns:
        ``{"success": True, "result": <JSON string of the Airtable response>}``
        on success, or ``{"success": False, "error": <message>}`` on any
        failure.

    NOTE(review): Airtable's list-records endpoint is paginated (a response
    carries an ``offset`` when more records exist); only the first page is
    fetched here — confirm whether callers need every page.
    """
    logger.debug(f"Received list_records_tool request: {request}")
    try:
        # airtable_request uses blocking ``requests``; run it on a worker
        # thread so the event loop keeps serving other sessions meanwhile.
        records = await asyncio.to_thread(airtable_request, "GET")
        response = {
            "success": True,
            "result": json.dumps(records),
        }
        logger.debug(f"list_records_tool response: {response}")
        return response
    except Exception as e:
        # Tool boundary: report the failure to the caller instead of raising.
        response = {
            "success": False,
            "error": str(e),
        }
        logger.exception(f"list_records_tool error: {response}")
        return response
# Tool to create a record
async def create_record_tool(request: mcp_types.CallToolRequest):
    """Create one record in the configured Airtable table.

    Arguments are read from ``request.params.arguments`` (where an MCP
    ``CallToolRequest`` carries them). Objects without ``params.arguments``
    fall back to a legacy ``request.input`` mapping. The optional
    ``record_data`` argument must be a mapping; it becomes the new record's
    ``fields`` and defaults to an empty mapping.

    Returns:
        ``{"success": True, "result": <JSON string of Airtable's response>}``
        on success, or ``{"success": False, "error": <message>}`` on any
        failure, including a malformed request.
    """
    logger.debug(f"Received create_record_tool request: {request}")
    try:
        params = getattr(request, "params", None)
        if params is not None and hasattr(params, "arguments"):
            # MCP CallToolRequest shape; ``arguments`` may be None.
            arguments = params.arguments or {}
        else:
            arguments = request.input
        record_data = arguments.get("record_data", {})
        if not isinstance(record_data, dict):
            raise TypeError(
                f"record_data must be an object, got {type(record_data).__name__}"
            )
        data = {"records": [{"fields": record_data}]}
        # Blocking HTTP call: keep it off the event loop.
        response_data = await asyncio.to_thread(airtable_request, "POST", data=data)
        response = {
            "success": True,
            "result": json.dumps(response_data),
        }
        logger.debug(f"create_record_tool response: {response}")
        return response
    except Exception as e:
        # Tool boundary: report the failure to the caller instead of raising.
        response = {
            "success": False,
            "error": str(e),
        }
        logger.exception(f"create_record_tool error: {response}")
        return response
# Tool definitions advertised to MCP clients (for Deep Agent to discover them).
# ``inputSchema`` must be a JSON Schema for the tool's arguments object; MCP
# requires ``"type": "object"`` at the top level, with arguments listed under
# ``properties``.
tools = [
    mcp_types.Tool(
        name="list_airtable_records",
        description="Lists all records in the specified Airtable table",
        inputSchema={"type": "object", "properties": {}},
    ),
    mcp_types.Tool(
        name="create_airtable_record",
        description="Creates a new record in the specified Airtable table",
        inputSchema={
            "type": "object",
            "properties": {
                "record_data": {
                    "type": "object",
                    "description": "Airtable fields for the new record (field -> value).",
                },
            },
        },
    ),
]
# Dispatch table from MCP tool name to the coroutine that implements it.
tool_handlers = dict(
    list_airtable_records=list_records_tool,
    create_airtable_record=create_record_tool,
)
# Create MCP server
mcp_server = Server(name="airtable-mcp")
# Kept as plain attributes for any external code that reads them. The MCP SDK's
# low-level Server dispatches through registered handlers, not these
# attributes; the handlers below are what expose the tools to clients.
mcp_server.tool_handlers = tool_handlers
mcp_server.tools = tools


@mcp_server.list_tools()
async def _handle_list_tools() -> list[mcp_types.Tool]:
    """Answer MCP ``tools/list`` with the module-level ``tools`` definitions."""
    return tools


@mcp_server.call_tool()
async def _handle_call_tool(name: str, arguments: dict) -> list[mcp_types.TextContent]:
    """Answer MCP ``tools/call`` by dispatching to ``tool_handlers``.

    The handler receives a ``CallToolRequest`` rebuilt from ``name`` and
    ``arguments``. On success, its ``result`` string is returned as a single
    text content item. On failure, or for an unknown tool name, this function
    raises an exception, which the SDK reports to the client as a tool error.
    """
    handler = tool_handlers.get(name)
    if handler is None:
        raise ValueError(f"Unknown tool: {name}")
    request = mcp_types.CallToolRequest(
        method="tools/call",
        params=mcp_types.CallToolRequestParams(name=name, arguments=arguments or {}),
    )
    outcome = await handler(request)
    if not outcome.get("success"):
        raise RuntimeError(outcome.get("error", f"Tool {name} failed"))
    return [mcp_types.TextContent(type="text", text=outcome["result"])]
# Initialize SseServerTransport; "/airtable/mcp" is the message endpoint path
# handed to the transport.
transport = SseServerTransport("/airtable/mcp")

# Per-session registries of anyio memory send streams, keyed by a string id:
#   write_streams       -> streams tied to SseServerTransport messages
#   sse_stream_writers  -> streams used for manually written SSE events
_MemorySendStream = anyio.streams.memory.MemoryObjectSendStream
write_streams: Dict[str, _MemorySendStream] = {}
sse_stream_writers: Dict[str, _MemorySendStream] = {}
# SSE endpoint for GET requests
# NOTE(review): no route registration for this handler is visible here —
# confirm it is mounted on ``app``.
async def handle_sse(request: Request):
    """Run one MCP session over Server-Sent Events for a GET request.

    ``transport.connect_sse`` is given the raw ASGI ``receive``/``send``
    callables and streams the SSE response itself. That includes the
    ``endpoint`` event telling the client where to POST messages, with the
    session id the transport generated. It yields the pair of memory streams
    that ``mcp_server`` reads requests from and writes responses to. The
    server runs until the session ends.

    The transport keeps its own session registry, so this handler does not
    touch ``write_streams`` or ``sse_stream_writers``.

    Returns:
        An empty ``Response`` once the session is over. The actual SSE
        response was already written through ``request._send``; this mirrors
        the MCP Python SDK's documented usage of ``connect_sse``.

    Raises:
        Exception: anything raised while connecting or serving is logged with
            its traceback and re-raised.
    """
    logger.debug("Handling SSE connection request")
    try:
        async with transport.connect_sse(
            request.scope, request.receive, request._send
        ) as (read_stream, write_stream):
            logger.debug("Running MCP server with streams")
            await mcp_server.run(
                read_stream, write_stream, mcp_server.create_initialization_options()
            )
    except Exception:
        logger.exception("Error in handle_sse")
        raise
    return Response()
class _ResponseAlreadySent(Response):
    """Response object that sends nothing.

    Returned from ``handle_post_message`` when ``transport.handle_post_message``
    has already written the HTTP response through the raw ASGI ``send``. This
    stops the framework from starting a second response on the same request.
    """

    async def __call__(self, scope, receive, send) -> None:
        return None


# Message handling endpoint for POST requests
# NOTE(review): no route registration for this handler is visible here;
# presumably it serves POST /airtable/mcp (the path given to
# SseServerTransport) — confirm it is mounted on ``app``.
async def handle_post_message(request: Request):
    """Forward a client's JSON-RPC POST to the MCP SSE transport.

    The body is read once (for debug logging) and then replayed to
    ``transport.handle_post_message``. The transport uses the ``session_id``
    query parameter to route the message to the MCP session started by the
    SSE handler, and writes the HTTP response itself through ``send``.

    Every JSON-RPC method, including ``initialize`` and ``tools/list``, is
    answered by ``mcp_server``'s own session, so the MCP handshake stays
    consistent. As a consequence, ``tools/list`` and ``tools/call`` only
    succeed if handlers are registered on ``mcp_server``
    (``@mcp_server.list_tools()`` / ``@mcp_server.call_tool()``).

    Returns:
        A no-op response when the transport already responded. Otherwise,
        ``202`` if the transport returned without responding, or ``500`` if it
        raised before responding.
    """
    logger.debug("Handling POST message request")
    body = await request.body()
    logger.debug(f"Received POST message body: {body}")

    body_replayed = False
    response_started = False

    async def replay_receive():
        # request.body() already drained the ASGI receive channel, so the
        # transport gets the cached body first and the live channel afterwards
        # (e.g. for a later http.disconnect message).
        nonlocal body_replayed
        if body_replayed:
            return await request.receive()
        body_replayed = True
        return {"type": "http.request", "body": body, "more_body": False}

    async def tracking_send(message):
        # Remember whether the transport has begun its own HTTP response.
        nonlocal response_started
        if message.get("type") == "http.response.start":
            response_started = True
        await request._send(message)

    try:
        await transport.handle_post_message(request.scope, replay_receive, tracking_send)
    except Exception:
        logger.exception("Error handling POST message")
        if response_started:
            return _ResponseAlreadySent()
        return Response(status_code=500)

    logger.debug("POST message handled successfully")
    if response_started:
        return _ResponseAlreadySent()
    return Response(status_code=202)
# Health check endpoint
# NOTE(review): no route registration is visible for this handler — confirm it
# is mounted on ``app``.
async def health_check():
    """Liveness probe: always reports ``{"status": "healthy"}``."""
    status = "healthy"
    return dict(status=status)
# Endpoint to list tools (for debugging)
# NOTE(review): no route registration is visible for this handler — confirm it
# is mounted on ``app``.
async def list_tools():
    """Debug aid: return every entry of ``tools`` serialised to a plain dict."""
    serialized = [definition.model_dump() for definition in tools]
    return {"tools": serialized}
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from $PORT, defaulting to 7860.
    listen_port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, port=listen_port, host="0.0.0.0")