Spaces:
Running
Running
""" | |
Authorization Server for MCP Split Demo.

This server handles OAuth flows, client registration, and token issuance.
It can be replaced with an enterprise authorization server such as Auth0 or Entra ID.

NOTE: this is a simplified example for demonstration purposes.
It is not a production-ready implementation.
"""
import asyncio
import logging
import os
import time

import click
from pydantic import AnyHttpUrl, BaseModel
from starlette.applications import Starlette
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.routing import Route
from uvicorn import Config, Server

from mcp.server.auth.routes import cors_middleware, create_auth_routes
from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions

from simple_auth_provider import SimpleAuthSettings, SimpleOAuthProvider
logger = logging.getLogger(__name__)

# Public host name of this Authorization Server, taken from the AS_HOST_NAME
# environment variable ("localhost" when unset).  The server's public URLs are
# built from it as https://{as_host_name}.
as_host_name = os.environ.get("AS_HOST_NAME", "localhost")
class AuthServerSettings(BaseModel):
    """Settings for the Authorization Server.

    The URL defaults are derived from ``as_host_name`` (the ``AS_HOST_NAME``
    environment variable) and use ``https`` with no explicit port.  They
    describe the server's public address, not its local listening port.
    """

    # Server settings
    host: str = as_host_name  # public host name of the server
    port: int = 7860  # local port the HTTP server listens on
    # Issuer URL advertised by the OAuth routes.
    server_url: AnyHttpUrl = AnyHttpUrl(f"https://{as_host_name}")
    # URL handed to the auth provider as ``auth_callback_path``.  The default
    # points at the GET ``/login`` page (which expects a ``state`` query
    # parameter), the same value ``main`` passes in.  ``/login/callback`` only
    # accepts the POSTed login form, so it cannot serve as this URL.
    auth_callback_path: str = f"https://{as_host_name}/login"
class SimpleAuthProvider(SimpleOAuthProvider):
    """
    Authorization Server provider with simple demo authentication.

    This provider:
    1. Issues MCP tokens after simple credential authentication
    2. Stores token state for introspection by Resource Servers

    Both behaviors come from ``SimpleOAuthProvider``.  This subclass only adds
    start-up logging of the configuration it was given.
    """

    def __init__(self, auth_settings: SimpleAuthSettings, auth_callback_path: str, server_url: str):
        """Initialise the base provider and log its configuration.

        Args:
            auth_settings: Demo authentication settings (provides ``mcp_scope``).
            auth_callback_path: URL forwarded to the base provider
                (``main`` passes the ``/login`` page URL).
            server_url: Public base URL of this Authorization Server.
        """
        super().__init__(auth_settings, auth_callback_path, server_url)
        # Log the settings *instance's* scope rather than the whole object: the
        # demo settings may also carry credentials, which should stay out of
        # the logs.
        logger.info(
            "Init SimpleAuthProvider: auth_callback_path=%s server_url=%s mcp_scope=%s",
            auth_callback_path,
            server_url,
            auth_settings.mcp_scope,
        )
def create_authorization_server(server_settings: AuthServerSettings, auth_settings: SimpleAuthSettings) -> Starlette:
    """Create the Authorization Server application.

    Besides the OAuth routes produced by ``create_auth_routes``, the app exposes:

    * ``GET /login`` - login form for a pending authorization (needs ``state``).
    * ``POST /login/callback`` - handles the submitted login form.
    * ``POST /introspect`` - RFC 7662 token introspection for Resource Servers.

    Args:
        server_settings: Public URLs of this server (issuer URL, login URL).
        auth_settings: Demo authentication settings.  ``mcp_scope`` is used as
            the valid, default and required scope.

    Returns:
        A Starlette application with all routes registered.
    """
    oauth_provider = SimpleAuthProvider(
        auth_settings, server_settings.auth_callback_path, str(server_settings.server_url)
    )

    mcp_auth_settings = AuthSettings(
        issuer_url=server_settings.server_url,
        client_registration_options=ClientRegistrationOptions(
            enabled=True,
            valid_scopes=[auth_settings.mcp_scope],
            default_scopes=[auth_settings.mcp_scope],
        ),
        required_scopes=[auth_settings.mcp_scope],
        resource_server_url=None,  # this process acts only as the Authorization Server
    )

    # Create OAuth routes
    routes = create_auth_routes(
        provider=oauth_provider,
        issuer_url=mcp_auth_settings.issuer_url,
        service_documentation_url=mcp_auth_settings.service_documentation_url,
        client_registration_options=mcp_auth_settings.client_registration_options,
        revocation_options=mcp_auth_settings.revocation_options,
    )

    async def login_page_handler(request: Request) -> Response:
        """Show the login form for the authorization request identified by ``state``."""
        state = request.query_params.get("state")
        if not state:
            raise HTTPException(400, "Missing state parameter")
        return await oauth_provider.get_login_page(state)

    routes.append(Route("/login", endpoint=login_page_handler, methods=["GET"]))

    async def login_callback_handler(request: Request) -> Response:
        """Handle the submitted login form (simple credential authentication)."""
        return await oauth_provider.handle_login_callback(request)

    routes.append(Route("/login/callback", endpoint=login_callback_handler, methods=["POST"]))

    async def introspect_handler(request: Request) -> Response:
        """
        Token introspection endpoint (RFC 7662) for Resource Servers.

        Resource Servers call this endpoint to validate tokens without
        needing direct access to token storage.  A request without a string
        ``token`` form field gets HTTP 400.  Unknown or expired tokens are
        reported as ``{"active": false}``.
        """
        form = await request.form()
        token = form.get("token")
        if not token or not isinstance(token, str):
            return JSONResponse({"active": False}, status_code=400)

        # Look up token in provider
        access_token = await oauth_provider.load_access_token(token)
        if not access_token:
            return JSONResponse({"active": False})

        # A token is no longer valid on or after its expiry time; don't rely
        # solely on the provider having evicted it.
        if access_token.expires_at is not None and access_token.expires_at <= time.time():
            return JSONResponse({"active": False})

        body: dict[str, object] = {
            "active": True,
            "client_id": access_token.client_id,
            "scope": " ".join(access_token.scopes),
            "token_type": "Bearer",
        }
        # Optional RFC 7662 members are only sent when known.  The token's
        # issue time is not tracked, so "iat" is omitted.  Reporting the time
        # of this request as "iat" would be wrong.
        if access_token.expires_at is not None:
            body["exp"] = access_token.expires_at
        if access_token.resource is not None:
            body["aud"] = access_token.resource  # RFC 8707 audience claim
        return JSONResponse(body)

    routes.append(
        Route(
            "/introspect",
            endpoint=cors_middleware(introspect_handler, ["POST", "OPTIONS"]),
            methods=["POST", "OPTIONS"],
        )
    )

    return Starlette(routes=routes)
async def run_server(
    server_settings: AuthServerSettings,
    auth_settings: SimpleAuthSettings,
    bind_host: str = "0.0.0.0",
):
    """Run the Authorization Server until uvicorn shuts down.

    Args:
        server_settings: Public URL and listening port of the server.
        auth_settings: Demo authentication settings passed to the app factory.
        bind_host: Interface uvicorn binds to.  Defaults to all interfaces.
            ``server_settings.host`` is the *public* host name used in the
            advertised URLs, and it may not be a local interface address
            (e.g. when running behind a reverse proxy).
    """
    auth_server = create_authorization_server(server_settings, auth_settings)
    config = Config(
        auth_server,
        host=bind_host,
        port=server_settings.port,
        log_level="info",
    )
    server = Server(config)
    print(
        f"MCP Authorization Server running on {server_settings.server_url} "
        f"(listening on {bind_host}:{server_settings.port})"
    )
    await server.serve()
def main(port: int = 7860) -> int:
    """
    Run the MCP Authorization Server.

    This server handles OAuth flows and can be used by multiple Resource Servers.
    Uses simple hardcoded credentials for demo purposes.

    Args:
        port: Local port to listen on.  Defaults to 7860, the same default as
            ``AuthServerSettings.port``, so ``main()`` works without arguments.

    Returns:
        0 once the server has shut down.
    """
    logging.basicConfig(level=logging.DEBUG)

    # Load simple auth settings
    auth_settings = SimpleAuthSettings()

    # The advertised URL is https on the default port (no ":{port}").  It is the
    # public address, not the local listening port.  Presumably the server is
    # reached through an HTTPS-terminating proxy; adjust if not.
    host = as_host_name
    server_url = f"https://{host}"
    server_settings = AuthServerSettings(
        host=host,
        port=port,
        server_url=AnyHttpUrl(server_url),
        # The provider is given the URL of the GET /login page.
        auth_callback_path=f"{server_url}/login",
    )
    logger.info("auth_callback_path: %s", server_settings.auth_callback_path)

    asyncio.run(run_server(server_settings, auth_settings))
    return 0
if __name__ == "__main__":
    # Parse --port from the command line with click and pass it to `main`
    # explicitly, instead of calling `main` without its `port` argument.
    @click.command()
    @click.option("--port", default=7860, show_default=True, type=int, help="Port to listen on")
    def _cli(port: int) -> int:
        return main(port)

    _cli()