"""
Simple OAuth provider for MCP servers.

This module contains a basic OAuth implementation using hardcoded user credentials
for demonstration purposes. No external authentication provider is required.

NOTE: this is a simplified example for demonstration purposes.
This is not a production-ready implementation.
"""

import html
import logging
import os
import secrets
import time
from typing import Any
from urllib.parse import urlencode

from pydantic import AnyHttpUrl
from pydantic_settings import BaseSettings, SettingsConfigDict
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import HTMLResponse, RedirectResponse, Response

from mcp.server.auth.provider import (
    AccessToken,
    AuthorizationCode,
    AuthorizationParams,
    OAuthAuthorizationServerProvider,
    RefreshToken,
    construct_redirect_uri,
)
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken

logger = logging.getLogger(__name__)

# Built-in demo credentials, used when no override is configured.
DEFAULT_DEMO_USERNAME = "demo_user"
DEFAULT_DEMO_PASSWORD = "demo_password"

# Lifetimes (in seconds) of the artifacts issued by SimpleOAuthProvider.
AUTH_CODE_LIFETIME_SECONDS = 300
ACCESS_TOKEN_LIFETIME_SECONDS = 3600


class SimpleAuthSettings(BaseSettings):
    """Simple OAuth settings for demo purposes."""

    model_config = SettingsConfigDict(env_prefix="MCP_")

    # Demo user credentials.
    # To override the default user name and password, set the DEMO_USER and
    # DEMO_PASSWORD environment variables before this module is imported (they
    # are read once, when this class body is evaluated). Empty values are ignored,
    # because an empty credential could never pass the login form's
    # non-empty check. pydantic-settings additionally reads MCP_DEMO_USERNAME /
    # MCP_DEMO_PASSWORD (via env_prefix) when a settings object is created.
    demo_username: str = os.getenv("DEMO_USER") or DEFAULT_DEMO_USERNAME
    demo_password: str = os.getenv("DEMO_PASSWORD") or DEFAULT_DEMO_PASSWORD

    # MCP OAuth scope
    mcp_scope: str = "user"


class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
    """
    Simple OAuth provider for demo purposes.

    This provider handles the OAuth flow by:
    1. Providing a simple login form for demo credentials
    2. Issuing MCP tokens after successful authentication
    3. Maintaining token state for introspection

    All state (clients, codes, tokens, pending logins) is kept in in-memory dicts
    on the instance.
    """

    def __init__(
        self,
        settings: SimpleAuthSettings,
        auth_callback_url: str,
        server_url: str,
        login_callback_path: str = "/login/callback",
    ):
        """
        Args:
            settings: Demo credentials and the scope granted to issued tokens.
            auth_callback_url: URL of the login page; ``authorize`` sends the user
                agent here with ``state`` and ``client_id`` query parameters.
            server_url: Base URL of this server, used to build the login form's
                POST target.
            login_callback_path: Path (relative to ``server_url``) of the route that
                calls ``handle_login_callback``.
                NOTE(review): the default assumes that route is mounted at
                ``POST /login/callback`` — confirm against the server's route table.
        """
        self.settings = settings
        self.auth_callback_url = auth_callback_url
        self.server_url = server_url
        self.login_callback_path = login_callback_path
        self.clients: dict[str, OAuthClientInformationFull] = {}
        self.auth_codes: dict[str, AuthorizationCode] = {}
        self.tokens: dict[str, AccessToken] = {}
        # Pending logins keyed by OAuth state: written by authorize(), consumed by
        # handle_simple_callback().
        self.state_mapping: dict[str, dict[str, str | None]] = {}
        # Store authenticated user information (keyed by username after a login,
        # and by access token after a code exchange).
        self.user_data: dict[str, dict[str, Any]] = {}

    async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
        """Get OAuth client information."""
        return self.clients.get(client_id)

    async def register_client(self, client_info: OAuthClientInformationFull):
        """Register a new OAuth client."""
        self.clients[client_info.client_id] = client_info

    async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str:
        """Generate an authorization URL for simple login flow.

        Records the request parameters under ``state`` (generating a random state
        when the client sent none) and returns the login-page URL to visit.
        """
        state = params.state or secrets.token_hex(16)

        # Store state mapping for callback
        self.state_mapping[state] = {
            "redirect_uri": str(params.redirect_uri),
            "code_challenge": params.code_challenge,
            # Stored as "True"/"False" to fit the str-valued mapping; decoded in
            # handle_simple_callback().
            "redirect_uri_provided_explicitly": str(params.redirect_uri_provided_explicitly),
            "client_id": client.client_id,
            "resource": params.resource,  # RFC 8707
        }

        # state is client-controlled, so percent-encode the query rather than
        # splicing raw values (which may contain '&', '#', '+', ...) into the URL.
        query = urlencode({"state": state, "client_id": client.client_id})
        return f"{self.auth_callback_url}?{query}"

    async def get_login_page(self, state: str) -> HTMLResponse:
        """Generate login page HTML for the given state.

        The page contains a form that POSTs ``username``, ``password`` and
        ``state`` — the fields ``handle_login_callback`` reads.

        Raises:
            HTTPException: 400 if ``state`` is missing or is not a pending login.
        """
        if not state:
            raise HTTPException(400, "Missing state parameter")
        if state not in self.state_mapping:
            raise HTTPException(400, "Invalid state parameter")

        # Only advertise the credentials when the built-in defaults are active;
        # overridden credentials must not be printed on the login page.
        using_defaults = (
            self.settings.demo_username == DEFAULT_DEMO_USERNAME
            and self.settings.demo_password == DEFAULT_DEMO_PASSWORD
        )
        if using_defaults:
            credentials_hint = (
                "<p>This is a simplified authentication demo. Use the demo credentials below:</p>\n"
                f"    <p><strong>Username:</strong> {html.escape(DEFAULT_DEMO_USERNAME)}<br>\n"
                f"    <strong>Password:</strong> {html.escape(DEFAULT_DEMO_PASSWORD)}</p>"
            )
        else:
            credentials_hint = (
                "<p>This is a simplified authentication demo. Sign in with the configured demo credentials.</p>"
            )

        form_action = self.server_url.rstrip("/") + "/" + self.login_callback_path.lstrip("/")

        # Create simple login form HTML. Every interpolated value is escaped because
        # state arrives from the request query string.
        html_content = f"""<!DOCTYPE html>
<html>
<head>
    <title>MCP Demo Authentication</title>
</head>
<body>
    <h2>MCP Demo Authentication</h2>
    {credentials_hint}
    <form action="{html.escape(form_action)}" method="post">
        <input type="hidden" name="state" value="{html.escape(state)}">
        <label>Username: <input type="text" name="username" required></label><br>
        <label>Password: <input type="password" name="password" required></label><br>
        <button type="submit">Sign In</button>
    </form>
</body>
</html>
"""
        return HTMLResponse(content=html_content)

    async def handle_login_callback(self, request: Request) -> Response:
        """Handle login form submission callback.

        Raises:
            HTTPException: 400 if a field is missing or is not a plain string.
        """
        form = await request.form()
        username = form.get("username")
        password = form.get("password")
        state = form.get("state")

        if not username or not password or not state:
            raise HTTPException(400, "Missing username, password, or state parameter")

        # Ensure we have strings, not UploadFile objects
        if not isinstance(username, str) or not isinstance(password, str) or not isinstance(state, str):
            raise HTTPException(400, "Invalid parameter types")

        redirect_uri = await self.handle_simple_callback(username, password, state)
        return RedirectResponse(url=redirect_uri, status_code=302)

    async def handle_simple_callback(self, username: str, password: str, state: str) -> str:
        """Handle simple authentication callback and return redirect URI.

        Validates the demo credentials, issues an authorization code bound to the
        pending request, and returns the client's redirect URI carrying ``code``
        and ``state``.

        Raises:
            HTTPException: 400 for an unknown or incomplete state, 401 for bad credentials.
        """
        state_data = self.state_mapping.get(state)
        if not state_data:
            raise HTTPException(400, "Invalid state parameter")

        redirect_uri = state_data["redirect_uri"]
        code_challenge = state_data["code_challenge"]
        redirect_uri_provided_explicitly = state_data["redirect_uri_provided_explicitly"] == "True"
        client_id = state_data["client_id"]
        resource = state_data.get("resource")  # RFC 8707

        # These are required values from our own state mapping. Check explicitly
        # instead of with assert, which is stripped under `python -O`.
        if redirect_uri is None or code_challenge is None or client_id is None:
            raise HTTPException(400, "Invalid state parameter")

        # Validate demo credentials with a constant-time comparison. Both checks are
        # evaluated before combining so the username does not short-circuit the
        # password comparison.
        username_ok = secrets.compare_digest(username.encode("utf-8"), self.settings.demo_username.encode("utf-8"))
        password_ok = secrets.compare_digest(password.encode("utf-8"), self.settings.demo_password.encode("utf-8"))
        if not (username_ok and password_ok):
            raise HTTPException(401, "Invalid credentials")

        # Create MCP authorization code
        new_code = f"mcp_{secrets.token_hex(16)}"
        auth_code = AuthorizationCode(
            code=new_code,
            client_id=client_id,
            redirect_uri=AnyHttpUrl(redirect_uri),
            redirect_uri_provided_explicitly=redirect_uri_provided_explicitly,
            expires_at=time.time() + AUTH_CODE_LIFETIME_SECONDS,
            scopes=[self.settings.mcp_scope],
            code_challenge=code_challenge,
            resource=resource,  # RFC 8707
        )
        self.auth_codes[new_code] = auth_code

        # Store user data
        self.user_data[username] = {
            "username": username,
            "user_id": f"user_{secrets.token_hex(8)}",
            "authenticated_at": time.time(),
        }

        # The login is complete; removing the state prevents it from being reused.
        del self.state_mapping[state]
        return construct_redirect_uri(redirect_uri, code=new_code, state=state)

    async def load_authorization_code(
        self, client: OAuthClientInformationFull, authorization_code: str
    ) -> AuthorizationCode | None:
        """Load an authorization code."""
        return self.auth_codes.get(authorization_code)

    async def exchange_authorization_code(
        self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode
    ) -> OAuthToken:
        """Exchange authorization code for tokens.

        Raises:
            ValueError: if the code is not (or is no longer) outstanding.
        """
        if authorization_code.code not in self.auth_codes:
            raise ValueError("Invalid authorization code")

        # Generate MCP access token
        mcp_token = f"mcp_{secrets.token_hex(32)}"

        # Store MCP token
        self.tokens[mcp_token] = AccessToken(
            token=mcp_token,
            client_id=client.client_id,
            scopes=authorization_code.scopes,
            expires_at=int(time.time()) + ACCESS_TOKEN_LIFETIME_SECONDS,
            resource=authorization_code.resource,  # RFC 8707
        )

        # Store user data mapping for this token
        self.user_data[mcp_token] = {
            "username": self.settings.demo_username,
            "user_id": f"user_{secrets.token_hex(8)}",
            "authenticated_at": time.time(),
        }

        # Authorization codes are single-use.
        del self.auth_codes[authorization_code.code]

        return OAuthToken(
            access_token=mcp_token,
            token_type="Bearer",
            expires_in=ACCESS_TOKEN_LIFETIME_SECONDS,
            scope=" ".join(authorization_code.scopes),
        )

    async def load_access_token(self, token: str) -> AccessToken | None:
        """Load and validate an access token; expired tokens are purged and yield None."""
        access_token = self.tokens.get(token)
        if not access_token:
            return None

        # Check if expired
        if access_token.expires_at and access_token.expires_at < time.time():
            self.tokens.pop(token, None)
            self.user_data.pop(token, None)
            return None

        return access_token

    async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None:
        """Load a refresh token - not supported in this example."""
        return None

    async def exchange_refresh_token(
        self,
        client: OAuthClientInformationFull,
        refresh_token: RefreshToken,
        scopes: list[str],
    ) -> OAuthToken:
        """Exchange refresh token - not supported in this example."""
        raise NotImplementedError("Refresh tokens not supported")

    async def revoke_token(
        self, token: str | AccessToken | RefreshToken, token_type_hint: str | None = None
    ) -> None:
        """Revoke a token.

        Accepts either the raw token string or a loaded token object exposing a
        ``.token`` attribute. Unknown tokens are ignored.
        NOTE(review): the MCP revocation handler appears to pass the loaded token
        object; pydantic models are unhashable, so they cannot be used as dict keys
        directly — hence the normalisation below.
        """
        key = token if isinstance(token, str) else getattr(token, "token", None)
        if key is None:
            return
        self.tokens.pop(key, None)
        self.user_data.pop(key, None)