# as / simple_auth_provider.py
# apple muncy
# Add AS_HOST_NAME EV
# 2017445
"""
Simple OAuth provider for MCP servers.
This module contains a basic OAuth implementation using hardcoded user credentials
for demonstration purposes. No external authentication provider is required.
NOTE: this is a simplified example for demonstration purposes.
This is not a production-ready implementation.
"""
import html
import logging
import os
import secrets
import time
from typing import Any

from pydantic import AnyHttpUrl
from pydantic_settings import BaseSettings, SettingsConfigDict
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import HTMLResponse, RedirectResponse, Response

from mcp.server.auth.provider import (
    AccessToken,
    AuthorizationCode,
    AuthorizationParams,
    OAuthAuthorizationServerProvider,
    RefreshToken,
    construct_redirect_uri,
)
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
logger = logging.getLogger(__name__)
class SimpleAuthSettings(BaseSettings):
    """Settings for the demo OAuth provider (not for production use)."""

    # Field values are looked up by pydantic-settings under the ``MCP_`` prefix.
    model_config = SettingsConfigDict(env_prefix="MCP_")

    # Demo login credentials.
    # The defaults below are taken from the DEMO_USER / DEMO_PASSWORD
    # environment variables, read once when this class body is evaluated
    # (i.e. at import time); the literal values are used when they are unset.
    demo_username: str = os.environ.get("DEMO_USER", "demo_user")
    demo_password: str = os.environ.get("DEMO_PASSWORD", "demo_password")

    # Scope attached to every authorization code issued by the provider.
    mcp_scope: str = "user"
class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
    """
    Simple OAuth provider for demo purposes.

    This provider handles the OAuth flow by:
    1. Providing a simple login form for demo credentials
    2. Issuing MCP tokens after successful authentication
    3. Maintaining token state for introspection

    All state (registered clients, pending authorizations, authorization
    codes, access tokens, user records) lives in plain dictionaries on the
    instance, so it is lost when the process exits.
    """

    def __init__(self, settings: SimpleAuthSettings, auth_callback_url: str, server_url: str):
        """
        Create a provider with empty in-memory stores.

        Args:
            settings: Demo credentials and the scope to grant.
            auth_callback_url: URL of the login page that ``authorize`` sends users to.
            server_url: Base URL of this server; the login form posts to
                ``<server_url>/login/callback``.
        """
        self.settings = settings
        self.auth_callback_url = auth_callback_url
        self.server_url = server_url
        # client_id -> registered client metadata
        self.clients: dict[str, OAuthClientInformationFull] = {}
        # authorization code string -> code record
        self.auth_codes: dict[str, AuthorizationCode] = {}
        # access token string -> token record
        self.tokens: dict[str, AccessToken] = {}
        # state -> parameters of the pending authorization request
        self.state_mapping: dict[str, dict[str, str | None]] = {}
        # Authenticated user information, keyed by username and by access token
        self.user_data: dict[str, dict[str, Any]] = {}

    async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
        """Return the registered client for ``client_id``, or ``None`` if unknown."""
        return self.clients.get(client_id)

    async def register_client(self, client_info: OAuthClientInformationFull):
        """Register (or overwrite) an OAuth client under its ``client_id``."""
        self.clients[client_info.client_id] = client_info

    async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str:
        """
        Record a pending authorization request and return the login page URL.

        Args:
            client: The client requesting authorization.
            params: Authorization request parameters; ``params.state`` is used
                as the lookup key, or a random one is generated when absent.

        Returns:
            ``auth_callback_url`` with ``state`` and ``client_id`` query parameters.
        """
        state = params.state or secrets.token_hex(16)

        # Remember what the callback needs to finish the flow. Values are
        # stored as strings (the bool is stringified and compared against
        # "True" in handle_simple_callback).
        self.state_mapping[state] = {
            "redirect_uri": str(params.redirect_uri),
            "code_challenge": params.code_challenge,
            "redirect_uri_provided_explicitly": str(params.redirect_uri_provided_explicitly),
            "client_id": client.client_id,
            "resource": params.resource,  # RFC 8707
        }

        # ``state`` may be client-supplied, so build the query string with the
        # same helper used for the final redirect instead of raw string
        # concatenation: it URL-encodes the values and copes with a callback
        # URL that already carries query parameters.
        return construct_redirect_uri(self.auth_callback_url, state=state, client_id=client.client_id)

    async def get_login_page(self, state: str) -> HTMLResponse:
        """
        Render the demo login form for the given state.

        Args:
            state: Key of the pending authorization request; echoed back in a
                hidden form field.

        Returns:
            An HTML page whose form posts to ``<server_url>/login/callback``.

        Raises:
            HTTPException: 400 if ``state`` is empty.
        """
        if not state:
            raise HTTPException(400, "Missing state parameter")

        # ``state`` is not under our control (authorize() accepts whatever the
        # client sent), so escape it before placing it inside an attribute.
        safe_state = html.escape(state, quote=True)
        form_action = html.escape(f"{self.server_url.rstrip('/')}/login/callback", quote=True)

        # NOTE(review): the hint text and pre-filled values below are the
        # literal defaults; they are not updated when settings.demo_username /
        # settings.demo_password are overridden.
        html_content = f"""
        <!DOCTYPE html>
        <html>
        <head>
            <title>MCP Demo Authentication</title>
            <style>
                body {{ font-family: Arial, sans-serif; max-width: 500px; margin: 0 auto; padding: 20px; }}
                .form-group {{ margin-bottom: 15px; }}
                input {{ width: 100%; padding: 8px; margin-top: 5px; }}
                button {{ background-color: #4CAF50; color: white; padding: 10px 15px; border: none; cursor: pointer; }}
            </style>
        </head>
        <body>
            <h2>MCP Demo Authentication</h2>
            <p>This is a simplified authentication demo. Use the demo credentials below:</p>
            <p><strong>Username:</strong> demo_user<br>
            <strong>Password:</strong> demo_password</p>
            <form action="{form_action}" method="post">
                <input type="hidden" name="state" value="{safe_state}">
                <div class="form-group">
                    <label>Username:</label>
                    <input type="text" name="username" value="demo_user" required>
                </div>
                <div class="form-group">
                    <label>Password:</label>
                    <input type="password" name="password" value="demo_password" required>
                </div>
                <button type="submit">Sign In</button>
            </form>
        </body>
        </html>
        """
        return HTMLResponse(content=html_content)

    async def handle_login_callback(self, request: Request) -> Response:
        """
        Handle the login form submission.

        Reads ``username``, ``password`` and ``state`` from the posted form,
        validates them via ``handle_simple_callback`` and redirects (302) to
        the resulting URI.

        Raises:
            HTTPException: 400 if a field is missing or is not a plain string;
                plus anything raised by ``handle_simple_callback``.
        """
        form = await request.form()
        username = form.get("username")
        password = form.get("password")
        state = form.get("state")

        if not username or not password or not state:
            raise HTTPException(400, "Missing username, password, or state parameter")

        # Ensure we have strings, not UploadFile objects
        if not isinstance(username, str) or not isinstance(password, str) or not isinstance(state, str):
            raise HTTPException(400, "Invalid parameter types")

        redirect_uri = await self.handle_simple_callback(username, password, state)
        return RedirectResponse(url=redirect_uri, status_code=302)

    async def handle_simple_callback(self, username: str, password: str, state: str) -> str:
        """
        Check the demo credentials and issue an authorization code.

        Args:
            username: Submitted username.
            password: Submitted password.
            state: Key of the pending authorization request created by ``authorize``.

        Returns:
            The client's redirect URI with ``code`` and ``state`` appended.

        Raises:
            HTTPException: 400 for an unknown state, 401 for wrong
                credentials, 500 if the stored state record is incomplete.
        """
        state_data = self.state_mapping.get(state)
        if not state_data:
            raise HTTPException(400, "Invalid state parameter")

        redirect_uri = state_data["redirect_uri"]
        code_challenge = state_data["code_challenge"]
        redirect_uri_provided_explicitly = state_data["redirect_uri_provided_explicitly"] == "True"
        client_id = state_data["client_id"]
        resource = state_data.get("resource")  # RFC 8707

        # These are required values from our own state mapping. Checked with
        # an explicit raise rather than ``assert`` so the guard survives
        # ``python -O``.
        if redirect_uri is None or code_challenge is None or client_id is None:
            raise HTTPException(500, "Incomplete authorization state")

        # Validate demo credentials in constant time. Both comparisons are
        # always evaluated, and the strings are encoded first because
        # compare_digest only accepts ASCII-only str arguments.
        username_ok = secrets.compare_digest(
            username.encode("utf-8"), self.settings.demo_username.encode("utf-8")
        )
        password_ok = secrets.compare_digest(
            password.encode("utf-8"), self.settings.demo_password.encode("utf-8")
        )
        if not (username_ok and password_ok):
            # The state entry is kept so the user can retry the login.
            raise HTTPException(401, "Invalid credentials")

        # Create MCP authorization code (valid for 300 seconds)
        new_code = f"mcp_{secrets.token_hex(16)}"
        self.auth_codes[new_code] = AuthorizationCode(
            code=new_code,
            client_id=client_id,
            redirect_uri=AnyHttpUrl(redirect_uri),
            redirect_uri_provided_explicitly=redirect_uri_provided_explicitly,
            expires_at=time.time() + 300,
            scopes=[self.settings.mcp_scope],
            code_challenge=code_challenge,
            resource=resource,  # RFC 8707
        )

        # Store user data
        self.user_data[username] = {
            "username": username,
            "user_id": f"user_{secrets.token_hex(8)}",
            "authenticated_at": time.time(),
        }

        # The state is single-use once authentication succeeded.
        self.state_mapping.pop(state, None)
        return construct_redirect_uri(redirect_uri, code=new_code, state=state)

    async def load_authorization_code(
        self, client: OAuthClientInformationFull, authorization_code: str
    ) -> AuthorizationCode | None:
        """Return the stored authorization code record, or ``None`` if unknown.

        ``client`` is not consulted here; the lookup is by code string only.
        """
        return self.auth_codes.get(authorization_code)

    async def exchange_authorization_code(
        self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode
    ) -> OAuthToken:
        """
        Exchange an authorization code for an access token.

        Args:
            client: The client redeeming the code.
            authorization_code: A code record previously issued by this provider.

        Returns:
            A Bearer ``OAuthToken`` valid for 3600 seconds.

        Raises:
            ValueError: If the code is not (or no longer) stored.
        """
        # Consume the code up front so it can only ever be redeemed once.
        if self.auth_codes.pop(authorization_code.code, None) is None:
            raise ValueError("Invalid authorization code")

        # Generate MCP access token
        mcp_token = f"mcp_{secrets.token_hex(32)}"

        # Store MCP token
        self.tokens[mcp_token] = AccessToken(
            token=mcp_token,
            client_id=client.client_id,
            scopes=authorization_code.scopes,
            expires_at=int(time.time()) + 3600,
            resource=authorization_code.resource,  # RFC 8707
        )

        # Store user data mapping for this token
        self.user_data[mcp_token] = {
            "username": self.settings.demo_username,
            "user_id": f"user_{secrets.token_hex(8)}",
            "authenticated_at": time.time(),
        }

        return OAuthToken(
            access_token=mcp_token,
            token_type="Bearer",
            expires_in=3600,
            scope=" ".join(authorization_code.scopes),
        )

    async def load_access_token(self, token: str) -> AccessToken | None:
        """Return the access token record, or ``None`` if unknown or expired.

        An expired token is removed from the store (together with its user
        record) as a side effect.
        """
        access_token = self.tokens.get(token)
        if not access_token:
            return None

        # Check if expired
        if access_token.expires_at and access_token.expires_at < time.time():
            self.tokens.pop(token, None)
            self.user_data.pop(token, None)
            return None

        return access_token

    async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None:
        """Load a refresh token - not supported in this example, always ``None``."""
        return None

    async def exchange_refresh_token(
        self,
        client: OAuthClientInformationFull,
        refresh_token: RefreshToken,
        scopes: list[str],
    ) -> OAuthToken:
        """Exchange refresh token - not supported in this example.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Refresh tokens not supported")

    async def revoke_token(
        self, token: str | AccessToken | RefreshToken, token_type_hint: str | None = None
    ) -> None:
        """
        Revoke a token; unknown tokens are ignored.

        Args:
            token: Either the raw token string or a token object exposing a
                ``token`` attribute.
            token_type_hint: Accepted for compatibility; not used.
        """
        # NOTE(review): the SDK's revocation handler appears to pass the
        # loaded token object rather than the raw string - confirm against
        # the installed mcp version. Both forms are accepted here.
        token_value = token if isinstance(token, str) else getattr(token, "token", None)
        if not isinstance(token_value, str):
            return

        self.tokens.pop(token_value, None)
        self.user_data.pop(token_value, None)