|
|
|
""" |
|
Centralised OAuth 2.0 helper for GitHub, Google Drive and Slack. |
|
|
|
Usage |
|
----- |
|
from auth import oauth_manager |
|
|
|
# 1) Redirect user to consent page |
|
auth_url, state = oauth_manager.get_authorization_url( |
|
provider="github", |
|
redirect_uri="https://your‑app.com/callback" |
|
) |
|
|
|
# 2) In your callback handler, exchange the code for a token |
|
token = oauth_manager.fetch_token( |
|
provider="github", |
|
redirect_uri="https://your‑app.com/callback", |
|
authorization_response=request.url # full URL with ?code=... |
|
) |
|
""" |
|
|
|
from __future__ import annotations |
|
|
|
import os |
|
from dataclasses import dataclass |
|
from typing import Dict, Optional, Tuple |
|
|
|
from authlib.integrations.requests_client import OAuth2Session |
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class ProviderConfig:
    """Immutable endpoint and credential settings for one OAuth 2.0 provider.

    The credential fields may be ``None`` (for example when the backing
    environment variables are not set).
    """

    # OAuth client credentials issued by the provider; None when not configured.
    client_id: Optional[str]
    client_secret: Optional[str]

    # URL of the provider's authorization (user consent) endpoint.
    authorize_url: str

    # URL of the provider's token endpoint (authorization code -> token).
    token_url: str

    # Scopes requested during authorization, as a single space-separated string.
    scope: str
|
|
|
|
|
def _env(name: str) -> str | None: |
|
"""Shorthand for os.getenv with strip().""" |
|
val = os.getenv(name) |
|
return val.strip() if val else None |
|
|
|
|
|
# Provider registry keyed by the name callers pass as ``provider=``.
# Credentials are read once, when this module is imported, from the
# <PREFIX>_CLIENT_ID / <PREFIX>_CLIENT_SECRET environment variables.
PROVIDERS: Dict[str, ProviderConfig] = {
    "github": ProviderConfig(
        authorize_url="https://github.com/login/oauth/authorize",
        token_url="https://github.com/login/oauth/access_token",
        scope="repo read:org",
        client_id=_env("GITHUB_CLIENT_ID"),
        client_secret=_env("GITHUB_CLIENT_SECRET"),
    ),
    # NOTE(review): Google documents that a refresh_token is only returned
    # when the authorization request carries access_type=offline; nothing
    # in this module sets it — confirm whether refresh tokens are needed.
    # Google's discovery document also lists .../o/oauth2/v2/auth as the
    # current authorization endpoint; the URL below is kept unchanged.
    "google": ProviderConfig(
        authorize_url="https://accounts.google.com/o/oauth2/auth",
        token_url="https://oauth2.googleapis.com/token",
        scope="openid email profile https://www.googleapis.com/auth/drive.readonly",
        client_id=_env("GOOGLE_CLIENT_ID"),
        client_secret=_env("GOOGLE_CLIENT_SECRET"),
    ),
    # NOTE(review): on Slack's oauth/v2/authorize, `scope` requests bot-token
    # scopes; user-token scopes use a separate `user_scope` parameter.
    "slack": ProviderConfig(
        authorize_url="https://slack.com/oauth/v2/authorize",
        token_url="https://slack.com/api/oauth.v2.access",
        scope="channels:read chat:write",
        client_id=_env("SLACK_CLIENT_ID"),
        client_secret=_env("SLACK_CLIENT_SECRET"),
    ),
}
|
|
|
|
|
|
|
|
|
|
|
class OAuthManager:
    """Tiny wrapper around Authlib's OAuth2Session per provider.

    A fresh ``OAuth2Session`` is built on every call, so this object keeps no
    per-user state. Callers must store the ``state`` returned by
    :meth:`get_authorization_url` between the redirect and the callback.
    """

    def __init__(self, providers: Dict[str, ProviderConfig]):
        # Mapping of provider name -> endpoint/credential configuration.
        self.providers = providers

    def _session(self, provider: str, redirect_uri: str) -> OAuth2Session:
        """Build an Authlib session for *provider* bound to *redirect_uri*.

        Raises:
            KeyError: *provider* is not present in ``self.providers``.
            RuntimeError: the provider's client id or secret is missing/empty.
        """
        cfg = self.providers.get(provider)
        if cfg is None:
            raise KeyError(f"Unsupported provider '{provider}'.")
        if not (cfg.client_id and cfg.client_secret):
            raise RuntimeError(
                f"OAuth credentials for '{provider}' are missing. "
                "Set the *_CLIENT_ID and *_CLIENT_SECRET env‑vars."
            )
        return OAuth2Session(
            client_id=cfg.client_id,
            client_secret=cfg.client_secret,
            scope=cfg.scope,
            redirect_uri=redirect_uri,
        )

    def get_authorization_url(
        self, provider: str, redirect_uri: str, state: Optional[str] = None
    ) -> Tuple[str, str]:
        """Return ``(auth_url, state)`` for the given provider.

        When *state* is omitted, Authlib generates one. Store the returned
        state (e.g. in the user's session) and pass it as ``state=`` to
        :meth:`fetch_token` so the callback is checked against it, which
        mitigates CSRF.

        Raises:
            KeyError: unknown *provider*.
            RuntimeError: the provider's credentials are not configured.
        """
        sess = self._session(provider, redirect_uri)
        cfg = self.providers[provider]
        auth_url, final_state = sess.create_authorization_url(
            cfg.authorize_url, state=state
        )
        return auth_url, final_state

    def fetch_token(
        self,
        provider: str,
        redirect_uri: str,
        authorization_response: str,
        state: Optional[str] = None,
    ) -> Dict:
        """Exchange the ``?code=...`` in *authorization_response* for a token.

        Args:
            provider: key into ``self.providers``.
            redirect_uri: the same redirect URI used to build the
                authorization URL.
            authorization_response: the full callback URL, including its
                query string.
            state: the value returned by :meth:`get_authorization_url`.
                When given, Authlib compares it with the ``state`` query
                parameter of *authorization_response* and raises on a
                mismatch. When omitted, no state check is performed.

        Returns:
            The token dict from Authlib. It carries ``access_token``. Fields
            such as ``refresh_token`` and ``expires_in`` appear only when the
            provider issues them.

        Raises:
            KeyError: unknown *provider*.
            RuntimeError: the provider's credentials are not configured.
        """
        sess = self._session(provider, redirect_uri)
        cfg = self.providers[provider]
        return sess.fetch_token(
            cfg.token_url,
            authorization_response=authorization_response,
            # Previously never forwarded, so the callback's state went
            # unchecked even though get_authorization_url promised it.
            state=state,
            # NOTE(review): the session already holds client_secret. This extra
            # kwarg presumably adds it to the token request body as well;
            # verify against Authlib and each provider before removing it.
            client_secret=cfg.client_secret,
        )
|
|
|
|
|
|
|
# Default manager instance, built from PROVIDERS when this module is imported.
oauth_manager: OAuthManager = OAuthManager(providers=PROVIDERS)
|
|