"""Auth proxy for FastMCP HTTP transport. Minimal ASGI proxy that enforces a shared bearer token and forwards requests to the internal FastMCP HTTP server while preserving SSE. MVP scope to avoid bloat: - Header auth preferred; `?key` fallback for Desktop bridge - Strip auth before forwarding - Streaming proxying for SSE (no buffering) - Minimal `/health` endpoint """ from __future__ import annotations import os from typing import Dict, Tuple, List import secrets import httpx from fastapi import FastAPI, HTTPException, Request, Response from fastapi.responses import JSONResponse, StreamingResponse app = FastAPI() def _read_config() -> Tuple[List[str], str]: # Support multiple tokens via MCP_AUTH_TOKENS (comma-separated) multi = os.getenv("MCP_AUTH_TOKENS", "") single = os.getenv("MCP_AUTH_TOKEN", "") tokens: List[str] = [] if multi: tokens.extend([t.strip() for t in multi.split(",") if t.strip()]) if single: tokens.append(single.strip()) upstream = f"http://{os.getenv('UPSTREAM_HOST', '127.0.0.1')}:{os.getenv('UPSTREAM_PORT', '7870')}" return tokens, upstream def _is_authorized(req: Request, tokens: List[str]) -> bool: if not tokens: # No token configured → open access (useful for local/dev) return True auth_header = req.headers.get("authorization", "") if auth_header.startswith("Bearer "): presented = auth_header.split(" ", 1)[1] for t in tokens: if secrets.compare_digest(presented, t): return True key_param = req.query_params.get("key") if key_param: for t in tokens: if secrets.compare_digest(key_param, t): return True return False @app.get("/health") async def health() -> Dict[str, str]: return {"status": "ok"} def _sanitize_forward_headers(original_headers: Dict[str, str]) -> Dict[str, str]: # Remove Authorization before forwarding sanitized = {k: v for k, v in original_headers.items() if k.lower() != "authorization"} return sanitized def _sanitize_forward_params(original_params: Dict[str, str]) -> Dict[str, str]: # Remove `key` query param before forwarding return {k: v for k, v in original_params.items() if k != "key"} @app.api_route("/mcp/{path:path}", methods=["GET", "POST"]) async def proxy(req: Request, path: str) -> Response: tokens, upstream = _read_config() if not _is_authorized(req, tokens): # Return 401 with standard WWW-Authenticate header return JSONResponse( status_code=401, content={"detail": "Unauthorized"}, headers={"WWW-Authenticate": "Bearer"}, ) forward_url = f"{upstream}/mcp/{path}" forward_headers = _sanitize_forward_headers(dict(req.headers)) forward_params = _sanitize_forward_params(dict(req.query_params)) timeout = httpx.Timeout(None) # no timeout to support long-lived SSE # Open the upstream stream without a context manager so we can keep it # alive for the lifetime of the StreamingResponse and close it explicitly. client = httpx.AsyncClient(timeout=timeout) try: if req.method == "GET": request = client.build_request( "GET", forward_url, params=forward_params, headers=forward_headers ) upstream_resp = await client.send(request, stream=True) else: body = await req.body() request = client.build_request( "POST", forward_url, params=forward_params, headers=forward_headers, content=body, ) upstream_resp = await client.send(request, stream=True) status = upstream_resp.status_code headers = {k: v for k, v in upstream_resp.headers.items() if k.lower() != "transfer-encoding"} media_type = upstream_resp.headers.get("content-type") async def body_iterator(): try: async for chunk in upstream_resp.aiter_raw(): yield chunk except httpx.StreamClosed: # Upstream ended the stream; finish cleanly return finally: await upstream_resp.aclose() await client.aclose() return StreamingResponse(body_iterator(), status_code=status, headers=headers, media_type=media_type) except Exception as e: try: await client.aclose() finally: pass raise HTTPException(status_code=502, detail=f"Upstream error: {e}")