Spaces:
Running
Running
""" | |
Simple OAuth provider for MCP servers.

This module contains a basic OAuth implementation that authenticates against a
single set of hardcoded demo user credentials. No external authentication
provider is required.

NOTE: this is a simplified example for demonstration purposes only.
It is not a production-ready implementation.
""" | |
import html
import logging
import os
import secrets
import time
from typing import Any
from urllib.parse import urlencode

from mcp.server.auth.provider import (
    AccessToken,
    AuthorizationCode,
    AuthorizationParams,
    OAuthAuthorizationServerProvider,
    RefreshToken,
    construct_redirect_uri,
)
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
from pydantic import AnyHttpUrl
from pydantic_settings import BaseSettings, SettingsConfigDict
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import HTMLResponse, RedirectResponse, Response
logger = logging.getLogger(__name__) | |
class SimpleAuthSettings(BaseSettings):
    """Configuration for the demo OAuth provider.

    Holds the single demo account that is allowed to log in and the scope
    attached to the authorization codes the provider issues.
    """

    # Settings are also loaded by pydantic-settings from environment
    # variables carrying the "MCP_" prefix.
    model_config = SettingsConfigDict(env_prefix="MCP_")

    # Demo user credentials.
    # To override the default username and password, set the DEMO_USER and
    # DEMO_PASSWORD environment variables. They are read once, when this
    # class body is executed, and become the field defaults.
    demo_username: str = os.environ.get("DEMO_USER", "demo_user")
    demo_password: str = os.environ.get("DEMO_PASSWORD", "demo_password")

    # MCP OAuth scope
    mcp_scope: str = "user"
class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
    """
    Simple in-memory OAuth provider for demo purposes.

    This provider handles the OAuth flow by:
    1. Providing a simple login form for demo credentials
    2. Issuing MCP tokens after successful authentication
    3. Maintaining token state for introspection

    All state (clients, pending authorizations, codes, tokens) lives in plain
    dictionaries on the instance and is lost when the process exits.
    """

    # Lifetimes of issued credentials, in seconds.
    AUTH_CODE_TTL_SECONDS = 300
    ACCESS_TOKEN_TTL_SECONDS = 3600

    # Built-in demo credentials. The login page only advertises and pre-fills
    # them while the active settings still use these values.
    _DEFAULT_DEMO_USERNAME = "demo_user"
    _DEFAULT_DEMO_PASSWORD = "demo_password"

    def __init__(self, settings: SimpleAuthSettings, auth_callback_url: str, server_url: str):
        """Create a provider with empty in-memory stores.

        Args:
            settings: Demo credentials and the scope to put on issued codes.
            auth_callback_url: URL of the login page that ``authorize`` sends
                the user to.
            server_url: Base URL used to build the login form's POST target
                (``<server_url>/login/callback``).
        """
        self.settings = settings
        self.auth_callback_url = auth_callback_url
        self.server_url = server_url
        self.clients: dict[str, OAuthClientInformationFull] = {}
        self.auth_codes: dict[str, AuthorizationCode] = {}
        self.tokens: dict[str, AccessToken] = {}
        # Pending authorization requests, keyed by OAuth ``state``.
        self.state_mapping: dict[str, dict[str, str | None]] = {}
        # Store authenticated user information (keyed by username at login
        # time and by access token once a code has been exchanged).
        self.user_data: dict[str, dict[str, Any]] = {}

    async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
        """Return the registered client with ``client_id``, or ``None``."""
        return self.clients.get(client_id)

    async def register_client(self, client_info: OAuthClientInformationFull):
        """Register (or overwrite) an OAuth client under its ``client_id``."""
        self.clients[client_info.client_id] = client_info

    async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str:
        """Record the pending authorization and return the login page URL.

        The request parameters are stored under ``state`` so the login
        callback can finish the flow. A random state is generated when the
        client did not supply one.
        """
        state = params.state or secrets.token_hex(16)

        # Store state mapping for callback
        self.state_mapping[state] = {
            "redirect_uri": str(params.redirect_uri),
            "code_challenge": params.code_challenge,
            "redirect_uri_provided_explicitly": str(params.redirect_uri_provided_explicitly),
            "client_id": client.client_id,
            "resource": params.resource,  # RFC 8707
        }

        # Percent-encode the values: ``state`` is chosen by the client and may
        # contain characters such as ``&`` or ``#`` that would otherwise
        # corrupt the query string.
        query = urlencode({"state": state, "client_id": client.client_id})
        return f"{self.auth_callback_url}?{query}"

    async def get_login_page(self, state: str) -> HTMLResponse:
        """Render the login form for the given ``state``.

        Raises:
            HTTPException: 400 when ``state`` is empty.
        """
        if not state:
            raise HTTPException(400, "Missing state parameter")

        # ``state`` arrives from the request, so it must be escaped before it
        # is placed inside an HTML attribute (prevents reflected XSS).
        safe_state = html.escape(state, quote=True)
        form_action = html.escape(f"{self.server_url.rstrip('/')}/login/callback", quote=True)

        uses_default_credentials = (
            self.settings.demo_username == self._DEFAULT_DEMO_USERNAME
            and self.settings.demo_password == self._DEFAULT_DEMO_PASSWORD
        )
        if uses_default_credentials:
            hint = (
                "<p>This is a simplified authentication demo. Use the demo credentials below:</p>\n"
                "<p><strong>Username:</strong> demo_user<br>\n"
                "<strong>Password:</strong> demo_password</p>"
            )
            username_attr = ' value="demo_user"'
            password_attr = ' value="demo_password"'
        else:
            # Credentials were overridden: never print or pre-fill them, and
            # do not advertise defaults that would no longer work.
            hint = "<p>This is a simplified authentication demo. Sign in with the configured credentials.</p>"
            username_attr = ""
            password_attr = ""

        # Create simple login form HTML
        page = f"""
<!DOCTYPE html>
<html>
<head>
<title>MCP Demo Authentication</title>
<style>
body {{ font-family: Arial, sans-serif; max-width: 500px; margin: 0 auto; padding: 20px; }}
.form-group {{ margin-bottom: 15px; }}
input {{ width: 100%; padding: 8px; margin-top: 5px; }}
button {{ background-color: #4CAF50; color: white; padding: 10px 15px; border: none; cursor: pointer; }}
</style>
</head>
<body>
<h2>MCP Demo Authentication</h2>
{hint}
<form action="{form_action}" method="post">
<input type="hidden" name="state" value="{safe_state}">
<div class="form-group">
<label>Username:</label>
<input type="text" name="username"{username_attr} required>
</div>
<div class="form-group">
<label>Password:</label>
<input type="password" name="password"{password_attr} required>
</div>
<button type="submit">Sign In</button>
</form>
</body>
</html>
"""
        return HTMLResponse(content=page)

    async def handle_login_callback(self, request: Request) -> Response:
        """Handle the login form POST and redirect back to the client.

        Raises:
            HTTPException: 400 when a form field is missing or is not a plain
                string; any error raised by ``handle_simple_callback``.
        """
        form = await request.form()
        username = form.get("username")
        password = form.get("password")
        state = form.get("state")

        if not username or not password or not state:
            raise HTTPException(400, "Missing username, password, or state parameter")

        # Ensure we have strings, not UploadFile objects
        if not isinstance(username, str) or not isinstance(password, str) or not isinstance(state, str):
            raise HTTPException(400, "Invalid parameter types")

        redirect_uri = await self.handle_simple_callback(username, password, state)
        return RedirectResponse(url=redirect_uri, status_code=302)

    def _credentials_match(self, username: str, password: str) -> bool:
        """Compare submitted credentials to the configured ones in constant time."""
        # Compare bytes so non-ASCII input is accepted by compare_digest, and
        # evaluate both checks before combining them so a wrong username does
        # not short-circuit the password comparison.
        username_ok = secrets.compare_digest(
            username.encode("utf-8"), self.settings.demo_username.encode("utf-8")
        )
        password_ok = secrets.compare_digest(
            password.encode("utf-8"), self.settings.demo_password.encode("utf-8")
        )
        return username_ok and password_ok

    async def handle_simple_callback(self, username: str, password: str, state: str) -> str:
        """Validate the login and return the client redirect URI with a code.

        Raises:
            HTTPException: 400 for an unknown ``state``, 500 when the stored
                state is incomplete, 401 for wrong credentials.
        """
        state_data = self.state_mapping.get(state)
        if not state_data:
            raise HTTPException(400, "Invalid state parameter")

        redirect_uri = state_data.get("redirect_uri")
        code_challenge = state_data.get("code_challenge")
        client_id = state_data.get("client_id")
        redirect_uri_provided_explicitly = state_data.get("redirect_uri_provided_explicitly") == "True"
        resource = state_data.get("resource")  # RFC 8707

        # These are required values from our own state mapping. Checked
        # explicitly rather than with ``assert`` so the check survives ``-O``.
        if redirect_uri is None or code_challenge is None or client_id is None:
            raise HTTPException(500, "Incomplete authorization state")

        # Validate demo credentials. The state entry is kept on failure so the
        # user can retry from the same login page.
        if not self._credentials_match(username, password):
            raise HTTPException(401, "Invalid credentials")

        # Create MCP authorization code
        new_code = f"mcp_{secrets.token_hex(16)}"
        self.auth_codes[new_code] = AuthorizationCode(
            code=new_code,
            client_id=client_id,
            redirect_uri=AnyHttpUrl(redirect_uri),
            redirect_uri_provided_explicitly=redirect_uri_provided_explicitly,
            expires_at=time.time() + self.AUTH_CODE_TTL_SECONDS,
            scopes=[self.settings.mcp_scope],
            code_challenge=code_challenge,
            resource=resource,  # RFC 8707
        )

        # Store user data
        self.user_data[username] = {
            "username": username,
            "user_id": f"user_{secrets.token_hex(8)}",
            "authenticated_at": time.time(),
        }

        # The pending request is complete; a state value is single-use.
        self.state_mapping.pop(state, None)
        return construct_redirect_uri(redirect_uri, code=new_code, state=state)

    async def load_authorization_code(
        self, client: OAuthClientInformationFull, authorization_code: str
    ) -> AuthorizationCode | None:
        """Return the stored authorization code, or ``None`` if unknown.

        No client-ownership or expiry check is performed here.
        """
        return self.auth_codes.get(authorization_code)

    async def exchange_authorization_code(
        self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode
    ) -> OAuthToken:
        """Exchange an authorization code for an access token.

        Raises:
            ValueError: when the code is not (or no longer) stored.
        """
        # Consume the code up front so it can never be redeemed twice.
        if self.auth_codes.pop(authorization_code.code, None) is None:
            raise ValueError("Invalid authorization code")

        # Generate MCP access token
        mcp_token = f"mcp_{secrets.token_hex(32)}"

        # Store MCP token
        self.tokens[mcp_token] = AccessToken(
            token=mcp_token,
            client_id=client.client_id,
            scopes=authorization_code.scopes,
            expires_at=int(time.time()) + self.ACCESS_TOKEN_TTL_SECONDS,
            resource=authorization_code.resource,  # RFC 8707
        )

        # Store user data mapping for this token
        self.user_data[mcp_token] = {
            "username": self.settings.demo_username,
            "user_id": f"user_{secrets.token_hex(8)}",
            "authenticated_at": time.time(),
        }

        return OAuthToken(
            access_token=mcp_token,
            token_type="Bearer",
            expires_in=self.ACCESS_TOKEN_TTL_SECONDS,
            scope=" ".join(authorization_code.scopes),
        )

    async def load_access_token(self, token: str) -> AccessToken | None:
        """Return the access token if it exists and has not expired.

        Expired tokens are removed from the store as a side effect.
        """
        access_token = self.tokens.get(token)
        if not access_token:
            return None

        # Check if expired
        if access_token.expires_at and access_token.expires_at < time.time():
            del self.tokens[token]
            return None

        return access_token

    async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None:
        """Load a refresh token - not supported in this example."""
        return None

    async def exchange_refresh_token(
        self,
        client: OAuthClientInformationFull,
        refresh_token: RefreshToken,
        scopes: list[str],
    ) -> OAuthToken:
        """Exchange refresh token - not supported in this example."""
        raise NotImplementedError("Refresh tokens not supported")

    async def revoke_token(
        self, token: str | AccessToken | RefreshToken, token_type_hint: str | None = None
    ) -> None:
        """Revoke a token; unknown tokens are ignored.

        Accepts either the raw token string or a token object exposing a
        ``token`` attribute.
        """
        # NOTE(review): the SDK provider protocol appears to pass the loaded
        # token object rather than a string - confirm against the installed
        # SDK version. A model object used directly as a dict key would raise
        # TypeError (unhashable), so the string value is extracted first.
        token_value = token if isinstance(token, str) else getattr(token, "token", None)
        if isinstance(token_value, str):
            self.tokens.pop(token_value, None)