File size: 4,621 Bytes
aee78d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2020c7
 
aee78d5
 
 
 
 
 
 
 
 
c2020c7
 
 
 
 
 
 
 
 
aee78d5
c2020c7
aee78d5
 
c2020c7
 
aee78d5
 
 
c2020c7
 
 
 
 
aee78d5
c2020c7
 
 
 
aee78d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2020c7
 
aee78d5
 
 
 
 
 
 
 
 
 
 
 
ba9cfe5
a588e56
 
 
 
 
 
ba9cfe5
a588e56
 
 
 
 
ba9cfe5
 
 
 
a588e56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba9cfe5
a588e56
 
 
 
ba9cfe5
aee78d5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""Auth proxy for FastMCP HTTP transport.

Minimal ASGI proxy that enforces a shared bearer token and forwards
requests to the internal FastMCP HTTP server while preserving SSE.

MVP scope to avoid bloat:
- Header auth preferred; `?key` fallback for Desktop bridge
- Strip auth before forwarding
- Streaming proxying for SSE (no buffering)
- Minimal `/health` endpoint
"""

from __future__ import annotations

import os
from typing import Dict, Tuple, List
import secrets

import httpx
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.responses import JSONResponse, StreamingResponse


# The ASGI application served by this module; the `/health` endpoint and the
# authenticated `/mcp/{path}` proxy route are registered on it below.
app = FastAPI()


def _read_config() -> Tuple[List[str], str]:
    """Read the accepted auth tokens and the upstream base URL from the env.

    Tokens come from ``MCP_AUTH_TOKENS`` (comma-separated; entries are
    stripped and blank entries skipped) followed by ``MCP_AUTH_TOKEN`` (stripped,
    appended whenever the variable is non-empty). The upstream URL is built
    from ``UPSTREAM_HOST`` (default ``127.0.0.1``) and ``UPSTREAM_PORT``
    (default ``7870``).

    Returns:
        ``(tokens, upstream_base_url)``. An empty token list means auth is off.
    """
    tokens: List[str] = [
        piece.strip()
        for piece in os.getenv("MCP_AUTH_TOKENS", "").split(",")
        if piece.strip()
    ]

    single_token = os.getenv("MCP_AUTH_TOKEN", "")
    if single_token:
        tokens.append(single_token.strip())

    host = os.getenv("UPSTREAM_HOST", "127.0.0.1")
    port = os.getenv("UPSTREAM_PORT", "7870")
    return tokens, f"http://{host}:{port}"


def _is_authorized(req: Request, tokens: List[str]) -> bool:
    """Return True if *req* presents one of the configured shared tokens.

    The credential is looked for first in an ``Authorization: Bearer <token>``
    header, then in a ``?key=<token>`` query parameter. When *tokens* is
    empty, auth is disabled and every request is accepted (local/dev mode).

    Comparison uses :func:`secrets.compare_digest` on UTF-8 bytes. On ``str``
    arguments ``compare_digest`` only accepts ASCII and raises ``TypeError``
    otherwise, which would let any client turn a bogus credential into a 500.
    Empty presented credentials never match.
    """
    if not tokens:
        # No token configured -> open access (useful for local/dev).
        return True

    def _matches(candidate: str) -> bool:
        # Blank credentials are rejected outright, even if a blank token
        # ended up in the configuration.
        if not candidate:
            return False
        # "surrogatepass" makes the encoding total (it never raises) and
        # injective, so byte equality is exactly string equality.
        presented = candidate.encode("utf-8", "surrogatepass")
        matched = False
        # Check every token without short-circuiting so response timing does
        # not reveal which configured token (if any) matched.
        for token in tokens:
            matched |= secrets.compare_digest(
                presented, token.encode("utf-8", "surrogatepass")
            )
        return matched

    # The auth scheme name is case-insensitive (RFC 9110 §11.1).
    scheme, _, credentials = req.headers.get("authorization", "").partition(" ")
    if scheme.lower() == "bearer" and _matches(credentials.strip()):
        return True

    # Fallback for clients (e.g. the Desktop bridge) that cannot set headers.
    return _matches(req.query_params.get("key") or "")


@app.get("/health")
async def health() -> Dict[str, str]:
    """Report liveness of this proxy process.

    No token check is performed and the upstream is not contacted; the
    endpoint always answers with a fixed ``{"status": "ok"}`` payload.
    """
    status_payload: Dict[str, str] = dict(status="ok")
    return status_payload


# Request headers that are never forwarded upstream:
# - ``authorization`` / ``proxy-authorization`` carry credentials meant for
#   this proxy (the shared token), not for the upstream server;
# - the rest are hop-by-hop headers (RFC 9110 §7.6.1) describing the
#   client<->proxy connection. The proxy re-sends the body it has fully read,
#   so relaying e.g. ``transfer-encoding: chunked`` could misframe the
#   upstream request.
_STRIPPED_REQUEST_HEADERS = frozenset(
    {
        "authorization",
        "proxy-authorization",
        "connection",
        "keep-alive",
        "proxy-connection",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    }
)


def _sanitize_forward_headers(original_headers: Dict[str, str]) -> Dict[str, str]:
    """Return the subset of *original_headers* that is safe to send upstream.

    Drops credentials and hop-by-hop headers, including any header the client
    nominated as hop-by-hop through its ``Connection`` header. Name matching
    is case-insensitive; surviving headers keep their original names, values
    and order. The input mapping is not modified.
    """
    # ``Connection: keep-alive, X-Foo`` marks X-Foo as per-hop too.
    nominated = {
        name.strip().lower()
        for key, value in original_headers.items()
        if key.lower() == "connection"
        for name in value.split(",")
    }
    dropped = _STRIPPED_REQUEST_HEADERS | nominated
    return {k: v for k, v in original_headers.items() if k.lower() not in dropped}


def _sanitize_forward_params(original_params: Dict[str, str]) -> Dict[str, str]:
    """Return a new dict of the query parameters without the ``key`` entry.

    ``key`` is the query-string credential fallback checked by the auth step,
    so it is removed to keep the shared token from reaching the upstream.
    Other parameters keep their order; the input is left untouched.
    """
    forwarded = dict(original_params)
    forwarded.pop("key", None)
    return forwarded


# Hop-by-hop response headers (RFC 9110 §7.6.1). They describe the
# proxy<->upstream connection rather than the payload, so they are not
# relayed to the client; the ASGI server manages its own connection headers.
_UPSTREAM_HOP_BY_HOP_HEADERS = frozenset(
    {
        "connection",
        "keep-alive",
        "proxy-connection",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    }
)


async def _aclose_quietly(resource) -> None:
    """Await ``resource.aclose()``, ignoring any ``Exception`` it raises.

    Used on cleanup paths, where a failed close must not replace the response
    (or exception) that is already being produced.
    """
    try:
        await resource.aclose()
    except Exception:
        # Deliberately best-effort: nothing useful can be done about a failed
        # close, and raising here would mask the real outcome.
        pass


@app.api_route("/mcp/{path:path}", methods=["GET", "POST", "DELETE"])
async def proxy(req: Request, path: str) -> Response:
    """Authenticate *req* and stream it through to the upstream FastMCP server.

    The request method, the ``/mcp/{path}`` path, the query string (minus
    ``key``), the headers (minus credentials and hop-by-hop headers) and the
    body are forwarded to the upstream from ``_read_config``. The upstream
    response is relayed chunk by chunk without buffering, so long-lived SSE
    streams work. DELETE is forwarded alongside GET/POST; the MCP Streamable
    HTTP transport uses it to terminate a session.

    Returns:
        A 401 JSON response with ``WWW-Authenticate: Bearer`` if no valid
        token was presented; otherwise a ``StreamingResponse`` mirroring the
        upstream status code and (end-to-end) headers.

    Raises:
        HTTPException: 502 if anything fails before streaming starts (reading
        the request body, connecting to or sending to the upstream).
    """
    tokens, upstream = _read_config()
    if not _is_authorized(req, tokens):
        # Return 401 with standard WWW-Authenticate header
        return JSONResponse(
            status_code=401,
            content={"detail": "Unauthorized"},
            headers={"WWW-Authenticate": "Bearer"},
        )

    # NOTE(review): `path` is the routed path parameter, which appears to be
    # percent-decoded, so an encoded "?" or "#" would be re-interpreted when
    # building this URL. Confirm upstream paths never contain those.
    forward_url = f"{upstream}/mcp/{path}"
    forward_headers = _sanitize_forward_headers(dict(req.headers))
    forward_params = _sanitize_forward_params(dict(req.query_params))

    # No read/write/pool timeout, so long-lived SSE streams are never cut off.
    # The connect phase is bounded so that an unreachable upstream produces a
    # 502 instead of hanging the request forever.
    timeout = httpx.Timeout(None, connect=10.0)

    # The client is not used as a context manager: it must stay open for as
    # long as the StreamingResponse is being consumed. It is closed explicitly
    # by body_iterator, or on the error path below.
    client = httpx.AsyncClient(timeout=timeout)
    upstream_resp = None
    try:
        # GET carries no body; POST/DELETE bodies are forwarded verbatim.
        body = None if req.method == "GET" else await req.body()
        upstream_request = client.build_request(
            req.method,
            forward_url,
            params=forward_params,
            headers=forward_headers,
            content=body,
        )
        upstream_resp = await client.send(upstream_request, stream=True)
        response_headers = {
            k: v
            for k, v in upstream_resp.headers.items()
            if k.lower() not in _UPSTREAM_HOP_BY_HOP_HEADERS
        }
        media_type = upstream_resp.headers.get("content-type")
    except Exception as exc:
        # Release everything opened so far before reporting the failure.
        if upstream_resp is not None:
            await _aclose_quietly(upstream_resp)
        await _aclose_quietly(client)
        raise HTTPException(status_code=502, detail=f"Upstream error: {exc}") from exc

    async def body_iterator():
        # aiter_raw yields the bytes as received (no decompression). This keeps
        # the relayed Content-Encoding/Content-Length headers consistent with
        # the body and passes SSE events on as soon as they arrive.
        try:
            async for chunk in upstream_resp.aiter_raw():
                yield chunk
        except httpx.StreamClosed:
            # Upstream ended the stream; finish cleanly
            return
        finally:
            # Runs whenever the generator finishes or is closed. Each close is
            # independent, so a failure in the first cannot skip the second.
            await _aclose_quietly(upstream_resp)
            await _aclose_quietly(client)

    return StreamingResponse(
        body_iterator(),
        status_code=upstream_resp.status_code,
        headers=response_headers,
        media_type=media_type,
    )