Spaces:
Running
Running
"""Auth proxy for FastMCP HTTP transport. | |
Minimal ASGI proxy that enforces a shared bearer token and forwards | |
requests to the internal FastMCP HTTP server while preserving SSE. | |
MVP scope to avoid bloat: | |
- Header auth preferred; `?key` fallback for Desktop bridge | |
- Strip auth before forwarding | |
- Streaming proxying for SSE (no buffering) | |
- Minimal `/health` endpoint | |
""" | |
from __future__ import annotations | |
import os | |
from typing import Dict, Tuple, List | |
import secrets | |
import httpx | |
from fastapi import FastAPI, HTTPException, Request, Response | |
from fastapi.responses import JSONResponse, StreamingResponse | |
# ASGI application object for this auth proxy (see the module docstring).
# NOTE(review): no route registrations on `app` are visible in this source;
# confirm `health` and `proxy` are attached to it (e.g. via route decorators).
app = FastAPI()
def _read_config() -> Tuple[List[str], str]: | |
# Support multiple tokens via MCP_AUTH_TOKENS (comma-separated) | |
multi = os.getenv("MCP_AUTH_TOKENS", "") | |
single = os.getenv("MCP_AUTH_TOKEN", "") | |
tokens: List[str] = [] | |
if multi: | |
tokens.extend([t.strip() for t in multi.split(",") if t.strip()]) | |
if single: | |
tokens.append(single.strip()) | |
upstream = f"http://{os.getenv('UPSTREAM_HOST', '127.0.0.1')}:{os.getenv('UPSTREAM_PORT', '7870')}" | |
return tokens, upstream | |
def _is_authorized(req: Request, tokens: List[str]) -> bool: | |
if not tokens: | |
# No token configured → open access (useful for local/dev) | |
return True | |
auth_header = req.headers.get("authorization", "") | |
if auth_header.startswith("Bearer "): | |
presented = auth_header.split(" ", 1)[1] | |
for t in tokens: | |
if secrets.compare_digest(presented, t): | |
return True | |
key_param = req.query_params.get("key") | |
if key_param: | |
for t in tokens: | |
if secrets.compare_digest(key_param, t): | |
return True | |
return False | |
async def health() -> Dict[str, str]:
    """Liveness probe reporting that the proxy process is up.

    The payload is static; the upstream server is not contacted.

    NOTE(review): no route decorator for this handler is visible in this
    source. The module docstring implies it serves `/health`; confirm it is
    registered on `app`.
    """
    payload: Dict[str, str] = dict(status="ok")
    return payload
def _sanitize_forward_headers(original_headers: Dict[str, str]) -> Dict[str, str]: | |
# Remove Authorization before forwarding | |
sanitized = {k: v for k, v in original_headers.items() if k.lower() != "authorization"} | |
return sanitized | |
def _sanitize_forward_params(original_params: Dict[str, str]) -> Dict[str, str]: | |
# Remove `key` query param before forwarding | |
return {k: v for k, v in original_params.items() if k != "key"} | |
async def proxy(req: Request, path: str) -> Response:
    """Authenticate *req* and stream it to the upstream FastMCP server.

    The request is forwarded to ``{upstream}/mcp/{path}`` with its own HTTP
    method. The ``key`` query parameter and the auth header are stripped by
    ``_sanitize_forward_params`` / ``_sanitize_forward_headers``. The upstream
    response body is relayed chunk by chunk without buffering, so SSE streams
    stay live.

    NOTE(review): no route decorator for this handler is visible in this
    source. Confirm it is registered on ``app`` for every method it should
    proxy (GET, POST and, for MCP session teardown, presumably DELETE).

    Args:
        req: Incoming client request.
        path: Path suffix appended to the upstream ``/mcp/`` prefix.

    Returns:
        A 401 ``JSONResponse`` when credentials are missing or invalid.
        Otherwise, a ``StreamingResponse`` mirroring the upstream status,
        headers (minus ``transfer-encoding``) and raw body bytes.

    Raises:
        HTTPException: 502 if building or sending the upstream request fails.
    """
    tokens, upstream = _read_config()
    if not _is_authorized(req, tokens):
        # Return 401 with standard WWW-Authenticate header
        return JSONResponse(
            status_code=401,
            content={"detail": "Unauthorized"},
            headers={"WWW-Authenticate": "Bearer"},
        )

    forward_url = f"{upstream}/mcp/{path}"
    forward_headers = _sanitize_forward_headers(dict(req.headers))
    forward_params = _sanitize_forward_params(dict(req.query_params))
    timeout = httpx.Timeout(None)  # no timeout to support long-lived SSE

    # The client is deliberately not used as a context manager: it must stay
    # open for the lifetime of the StreamingResponse, and body_iterator closes
    # it explicitly once streaming ends.
    client = httpx.AsyncClient(timeout=timeout)
    try:
        # Forward the client's real method. Rewriting every non-GET request
        # into POST breaks verbs such as DELETE, which the MCP Streamable HTTP
        # transport uses to terminate a session.
        body = None if req.method in ("GET", "HEAD") else await req.body()
        request = client.build_request(
            req.method,
            forward_url,
            params=forward_params,
            headers=forward_headers,
            content=body,
        )
        upstream_resp = await client.send(request, stream=True)
    except Exception as exc:
        # Best-effort cleanup: a failure while closing must not hide the
        # upstream error that is reported below.
        try:
            await client.aclose()
        except Exception:
            pass
        raise HTTPException(status_code=502, detail=f"Upstream error: {exc}") from exc

    status = upstream_resp.status_code
    # Starlette/the ASGI server re-frames the body, so upstream chunking is dropped.
    headers = {k: v for k, v in upstream_resp.headers.items() if k.lower() != "transfer-encoding"}
    media_type = upstream_resp.headers.get("content-type")

    async def body_iterator():
        try:
            async for chunk in upstream_resp.aiter_raw():
                yield chunk
        except httpx.StreamClosed:
            # Upstream ended the stream; finish cleanly
            return
        finally:
            await upstream_resp.aclose()
            await client.aclose()

    return StreamingResponse(body_iterator(), status_code=status, headers=headers, media_type=media_type)