Spaces:
Running
Running
File size: 4,621 Bytes
aee78d5 c2020c7 aee78d5 c2020c7 aee78d5 c2020c7 aee78d5 c2020c7 aee78d5 c2020c7 aee78d5 c2020c7 aee78d5 c2020c7 aee78d5 ba9cfe5 a588e56 ba9cfe5 a588e56 ba9cfe5 a588e56 ba9cfe5 a588e56 ba9cfe5 aee78d5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
"""Auth proxy for FastMCP HTTP transport.

Minimal ASGI proxy that enforces a shared bearer token and forwards
requests to the internal FastMCP HTTP server while preserving SSE.

MVP scope to avoid bloat:
- Header auth preferred; `?key` fallback for Desktop bridge
- Strip auth before forwarding
- Streaming proxying for SSE (no buffering)
- Minimal `/health` endpoint

Configuration (environment variables, re-read on every proxied request):
- `MCP_AUTH_TOKENS`: comma-separated list of accepted tokens
- `MCP_AUTH_TOKEN`: one more accepted token (combined with the list above)
- `UPSTREAM_HOST` / `UPSTREAM_PORT`: upstream FastMCP server,
  default `127.0.0.1` / `7870`

If no token is configured at all, every request is accepted (local/dev use).
"""
from __future__ import annotations
import os
from typing import Dict, Tuple, List
import secrets
import httpx
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.responses import JSONResponse, StreamingResponse
# FastAPI application; `/health` and the `/mcp/...` proxy route are registered
# on it below via decorators.
app: FastAPI = FastAPI()
def _read_config() -> Tuple[List[str], str]:
    """Read accepted auth tokens and the upstream base URL from the environment.

    Tokens come from ``MCP_AUTH_TOKENS`` (comma-separated), followed by
    ``MCP_AUTH_TOKEN``. Both sources are treated the same way:
    - each entry is whitespace-stripped;
    - blank entries are discarded;
    - duplicates are removed (first occurrence wins).

    An empty or whitespace-only variable therefore never yields an
    empty-string token.

    Returns:
        ``(tokens, upstream)``.
        - ``tokens`` may be empty, which ``_is_authorized`` treats as
          "auth disabled".
        - ``upstream`` is ``http://{UPSTREAM_HOST}:{UPSTREAM_PORT}``, with
          defaults ``127.0.0.1`` and ``7870``.
    """
    raw_tokens = os.getenv("MCP_AUTH_TOKENS", "").split(",")
    raw_tokens.append(os.getenv("MCP_AUTH_TOKEN", ""))
    # dict.fromkeys de-duplicates while preserving the original order.
    tokens: List[str] = list(
        dict.fromkeys(t.strip() for t in raw_tokens if t.strip())
    )

    host = os.getenv("UPSTREAM_HOST", "127.0.0.1")
    port = os.getenv("UPSTREAM_PORT", "7870")
    upstream = f"http://{host}:{port}"
    return tokens, upstream
def _as_bytes(text: str) -> bytes:
    """Encode *text* as UTF-8 for constant-time comparison.

    ``secrets.compare_digest`` raises ``TypeError`` for ``str`` arguments
    containing non-ASCII characters; comparing bytes avoids that. The
    ``surrogatepass`` handler keeps encoding total even for values that
    contain lone surrogates (e.g. undecodable environment bytes, PEP 383).
    """
    return text.encode("utf-8", "surrogatepass")


def _is_authorized(req: Request, tokens: List[str]) -> bool:
    """Return True if *req* presents one of the configured *tokens*.

    Credentials are checked from two places:
    - an ``Authorization: Bearer <token>`` header (scheme matched
      case-insensitively, per RFC 7235);
    - as a fallback, the ``?key=<token>`` query parameter.

    An empty credential never matches. With no tokens configured, every
    request is allowed.

    Comparisons are constant-time and done on UTF-8 bytes. A credential with
    non-ASCII characters is therefore simply rejected; it does not raise
    ``TypeError`` and turn into a 500.
    """
    if not tokens:
        # No token configured → open access (useful for local/dev)
        return True

    presented: List[str] = []
    # Auth scheme names are case-insensitive (RFC 7235 §2.1).
    scheme, _, credentials = req.headers.get("authorization", "").partition(" ")
    if scheme.lower() == "bearer":
        presented.append(credentials.strip())
    presented.append(req.query_params.get("key") or "")

    expected = [_as_bytes(t) for t in tokens]
    for candidate in presented:
        if not candidate:
            continue
        candidate_bytes = _as_bytes(candidate)
        if any(secrets.compare_digest(candidate_bytes, t) for t in expected):
            return True
    return False
@app.get("/health")
async def health() -> Dict[str, str]:
    """Report liveness for probes.

    Requires no credentials and never contacts the upstream, so it answers
    even when the FastMCP server is unavailable.
    """
    return dict(status="ok")
def _sanitize_forward_headers(original_headers: Dict[str, str]) -> Dict[str, str]:
    """Return the subset of client request headers that is safe to forward.

    Removed (names compared case-insensitively):
    - ``Authorization``: the proxy's shared token must not reach the upstream.
    - Hop-by-hop headers (RFC 7230 §6.1), plus any header named in
      ``Connection``. These describe only the client-to-proxy connection.
    - ``Content-Length``: httpx recomputes framing from the body it actually
      sends. GET requests are forwarded without a body, so a client-supplied
      length could leave the upstream waiting for bytes that never arrive.

    Every other header is kept with its original name and value.
    """
    hop_by_hop = {
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "trailers",
        "transfer-encoding",
        "upgrade",
    }
    # Headers the client declared as connection-specific in `Connection`.
    connection_value = next(
        (v for k, v in original_headers.items() if k.lower() == "connection"), ""
    )
    listed = {t.strip().lower() for t in connection_value.split(",") if t.strip()}
    dropped = hop_by_hop | listed | {"authorization", "content-length"}
    return {k: v for k, v in original_headers.items() if k.lower() not in dropped}
def _sanitize_forward_params(original_params: Dict[str, str]) -> Dict[str, str]:
    """Return a copy of *original_params* without the ``key`` parameter.

    ``key`` carries the shared token for the query-string auth fallback, so it
    must not be passed on to the upstream. The match is exact and
    case-sensitive; every other parameter is copied unchanged.
    """
    kept: Dict[str, str] = {}
    for name, value in original_params.items():
        if name == "key":
            continue
        kept[name] = value
    return kept
def _filter_response_headers(upstream_headers: httpx.Headers) -> Dict[str, str]:
    """Return upstream response headers minus hop-by-hop ones (RFC 7230 §6.1).

    ``Transfer-Encoding`` in particular must not be relayed, because the ASGI
    server picks its own framing for the streamed body. Headers named in the
    upstream ``Connection`` header are connection-specific and are dropped too.
    """
    hop_by_hop = {
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "trailers",
        "transfer-encoding",
        "upgrade",
    }
    listed = {
        token.strip().lower()
        for token in upstream_headers.get("connection", "").split(",")
        if token.strip()
    }
    return {
        name: value
        for name, value in upstream_headers.items()
        if name.lower() not in hop_by_hop and name.lower() not in listed
    }


@app.api_route("/mcp/{path:path}", methods=["GET", "POST", "DELETE"])
async def proxy(req: Request, path: str) -> Response:
    """Authenticate *req* and stream it to ``{upstream}/mcp/{path}``.

    Unauthorized requests get a 401 with ``WWW-Authenticate: Bearer``.
    Otherwise the request, with credentials stripped, is replayed upstream
    with the same method. The response is then relayed chunk by chunk without
    buffering, so SSE streams stay live.

    Body handling: GET is forwarded without a body; POST and DELETE forward
    the client body verbatim. DELETE is routed because MCP Streamable HTTP
    clients use it to terminate a session.

    Raises:
        HTTPException: 502 if the upstream request cannot be sent.
    """
    tokens, upstream = _read_config()
    if not _is_authorized(req, tokens):
        # Return 401 with standard WWW-Authenticate header
        return JSONResponse(
            status_code=401,
            content={"detail": "Unauthorized"},
            headers={"WWW-Authenticate": "Bearer"},
        )

    forward_url = f"{upstream}/mcp/{path}"
    forward_headers = _sanitize_forward_headers(dict(req.headers))
    forward_params = _sanitize_forward_params(dict(req.query_params))
    # Reads/writes are unbounded so long-lived SSE streams are never cut off.
    # A dead upstream should fail fast with a 502 instead of hanging on connect.
    timeout = httpx.Timeout(None, connect=10.0)

    # The client is deliberately not a context manager. It must outlive this
    # function, and is closed by body_iterator once the stream ends, or below
    # if setting up the stream fails.
    client = httpx.AsyncClient(timeout=timeout)
    upstream_resp = None
    try:
        body = None if req.method == "GET" else await req.body()
        upstream_request = client.build_request(
            req.method,
            forward_url,
            params=forward_params,
            headers=forward_headers,
            content=body,
        )
        upstream_resp = await client.send(upstream_request, stream=True)
        status = upstream_resp.status_code
        headers = _filter_response_headers(upstream_resp.headers)
        media_type = upstream_resp.headers.get("content-type")
    except Exception as exc:
        # Release whatever was acquired before the failure, then report a
        # gateway error. The original exception is chained for the server logs.
        try:
            if upstream_resp is not None:
                await upstream_resp.aclose()
        finally:
            await client.aclose()
        raise HTTPException(status_code=502, detail=f"Upstream error: {exc}") from exc

    async def body_iterator():
        try:
            # aiter_raw relays bytes undecoded, matching the Content-Encoding
            # header passed through above.
            async for chunk in upstream_resp.aiter_raw():
                yield chunk
        except httpx.StreamClosed:
            # Upstream ended the stream; finish cleanly
            return
        finally:
            try:
                await upstream_resp.aclose()
            finally:
                await client.aclose()

    return StreamingResponse(
        body_iterator(), status_code=status, headers=headers, media_type=media_type
    )
|