# LeoWalker's picture
# Proxy: keep httpx stream open for SSE (no context manager); explicit aclose() to avoid StreamClosed; still bypassed when no token
# a588e56
"""Auth proxy for FastMCP HTTP transport.
Minimal ASGI proxy that enforces a shared bearer token and forwards
requests to the internal FastMCP HTTP server while preserving SSE.
MVP scope to avoid bloat:
- Header auth preferred; `?key` fallback for Desktop bridge
- Strip auth before forwarding
- Streaming proxying for SSE (no buffering)
- Minimal `/health` endpoint
"""
from __future__ import annotations
import os
from typing import Dict, Tuple, List
import secrets
import httpx
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.responses import JSONResponse, StreamingResponse
# ASGI application object; the `/health` and `/mcp/{path}` handlers in this
# module register themselves on it through its route decorators.
app = FastAPI()
def _read_config() -> Tuple[List[str], str]:
    """Read the accepted auth tokens and the upstream base URL from the environment.

    Environment variables:
        MCP_AUTH_TOKENS: comma-separated list of accepted tokens.
        MCP_AUTH_TOKEN: a single accepted token, listed after the ones above.
        UPSTREAM_HOST: upstream host (default ``127.0.0.1``).
        UPSTREAM_PORT: upstream port (default ``7870``).

    Returns:
        A ``(tokens, upstream)`` tuple. ``tokens`` contains every configured
        token with surrounding whitespace stripped; blank entries are dropped,
        so the list is empty when no usable token is configured. ``upstream``
        is ``http://<host>:<port>``.
    """
    # Treat both variables uniformly: the multi-token entries first, then the
    # single token, preserving the original ordering.
    candidates = os.getenv("MCP_AUTH_TOKENS", "").split(",")
    candidates.append(os.getenv("MCP_AUTH_TOKEN", ""))

    # Strip before testing for emptiness so a whitespace-only value can never
    # end up in the list as an empty-string token.
    stripped = (candidate.strip() for candidate in candidates)
    tokens: List[str] = [token for token in stripped if token]

    host = os.getenv("UPSTREAM_HOST", "127.0.0.1")
    port = os.getenv("UPSTREAM_PORT", "7870")
    upstream = f"http://{host}:{port}"
    return tokens, upstream
def _is_authorized(req: Request, tokens: List[str]) -> bool:
    """Decide whether a request carries one of the accepted tokens.

    Args:
        req: Incoming request; its ``Authorization`` header and ``key`` query
            parameter are inspected.
        tokens: Accepted tokens. An empty list disables authentication.

    Returns:
        ``True`` when ``tokens`` is empty, or when the bearer credential or
        the ``key`` query parameter equals one of the tokens; ``False``
        otherwise.
    """
    if not tokens:
        # No token configured → open access (useful for local/dev)
        return True

    presented_values = []

    # HTTP auth scheme names are case-insensitive, so accept "bearer",
    # "BEARER", etc. in addition to "Bearer".
    scheme, _, credential = req.headers.get("authorization", "").partition(" ")
    if scheme.lower() == "bearer" and credential:
        presented_values.append(credential)

    key_param = req.query_params.get("key")
    if key_param:
        presented_values.append(key_param)

    # secrets.compare_digest() raises TypeError for str arguments containing
    # non-ASCII characters. Comparing UTF-8 bytes keeps the constant-time
    # comparison while making arbitrary client input a plain mismatch.
    accepted = [token.encode("utf-8") for token in tokens]
    for presented in presented_values:
        presented_bytes = presented.encode("utf-8")
        for token_bytes in accepted:
            if secrets.compare_digest(presented_bytes, token_bytes):
                return True
    return False
@app.get("/health")
async def health() -> Dict[str, str]:
    """Liveness endpoint: always answers with a fixed ``status: ok`` payload.

    The handler performs no token check and does not contact the upstream.
    """
    payload = {"status": "ok"}
    return payload
def _sanitize_forward_headers(original_headers: Dict[str, str]) -> Dict[str, str]:
    """Return a copy of the headers with any Authorization entry left out.

    The header name is matched case-insensitively; every other header is
    copied unchanged and in its original order. The input is not modified.
    """
    forwarded: Dict[str, str] = {}
    for name, value in original_headers.items():
        if name.lower() == "authorization":
            # The credential is for this proxy only; never pass it upstream.
            continue
        forwarded[name] = value
    return forwarded
def _sanitize_forward_params(original_params: Dict[str, str]) -> Dict[str, str]:
    """Return a copy of the query parameters without the ``key`` entry.

    Only the exact name ``key`` is removed; all other parameters keep their
    values and order. The input mapping is left untouched.
    """
    remaining = dict(original_params)
    # `key` carries the auth token for this proxy, so it is not forwarded.
    remaining.pop("key", None)
    return remaining
async def _close_upstream(upstream_resp: httpx.Response | None, client: httpx.AsyncClient) -> None:
    """Close the upstream response (when one was opened) and then the client.

    The client is closed even if closing the response raises.
    """
    try:
        if upstream_resp is not None:
            await upstream_resp.aclose()
    finally:
        await client.aclose()


@app.api_route("/mcp/{path:path}", methods=["GET", "POST"])
async def proxy(req: Request, path: str) -> Response:
    """Authenticate a request and stream it through to the upstream ``/mcp`` path.

    Args:
        req: Incoming request (GET or POST).
        path: Remainder of the URL after ``/mcp/``; appended to the upstream URL.

    Returns:
        A 401 JSON response with ``WWW-Authenticate: Bearer`` when the request
        is not authorized; otherwise a ``StreamingResponse`` that relays the
        upstream status, headers (minus ``Transfer-Encoding``) and raw body
        chunks as they arrive.

    Raises:
        HTTPException: 502 when reading the request body or opening the
            upstream request fails.
    """
    tokens, upstream = _read_config()
    if not _is_authorized(req, tokens):
        # Return 401 with standard WWW-Authenticate header
        return JSONResponse(
            status_code=401,
            content={"detail": "Unauthorized"},
            headers={"WWW-Authenticate": "Bearer"},
        )

    forward_url = f"{upstream}/mcp/{path}"
    forward_headers = _sanitize_forward_headers(dict(req.headers))
    forward_params = _sanitize_forward_params(dict(req.query_params))

    # No timeout, to support long-lived SSE. The client is deliberately not
    # used as a context manager: it has to outlive this function and stay open
    # until the StreamingResponse body has been fully consumed.
    client = httpx.AsyncClient(timeout=httpx.Timeout(None))
    upstream_resp = None
    try:
        if req.method == "GET":
            method, body = "GET", None
        else:
            method, body = "POST", await req.body()
        request = client.build_request(
            method,
            forward_url,
            params=forward_params,
            headers=forward_headers,
            content=body,
        )
        upstream_resp = await client.send(request, stream=True)

        status = upstream_resp.status_code
        headers = {
            k: v for k, v in upstream_resp.headers.items() if k.lower() != "transfer-encoding"
        }
        media_type = upstream_resp.headers.get("content-type")
    except Exception as e:
        # Release whatever was opened. Cleanup is best-effort: a failure while
        # closing must not replace the 502 that describes the real problem.
        try:
            await _close_upstream(upstream_resp, client)
        except Exception:
            pass
        raise HTTPException(status_code=502, detail=f"Upstream error: {e}") from e

    async def body_iterator():
        try:
            async for chunk in upstream_resp.aiter_raw():
                yield chunk
        except httpx.StreamClosed:
            # Upstream ended the stream; finish cleanly
            return
        finally:
            await _close_upstream(upstream_resp, client)

    return StreamingResponse(
        body_iterator(), status_code=status, headers=headers, media_type=media_type
    )