builder / auth.py
mgbam's picture
Update auth.py
c33cb65 verified
# auth.py
"""
Centralised OAuth 2.0 helper for GitHub, Google Drive and Slack.
Usage
-----
    from auth import oauth_manager

    # 1) Redirect the user to the provider's consent page. Keep `state`
    #    (e.g. in the user's session) so the callback can be checked.
    auth_url, state = oauth_manager.get_authorization_url(
        provider="github",
        redirect_uri="https://your-app.com/callback",
    )

    # 2) In your callback handler, exchange the code for a token
    token = oauth_manager.fetch_token(
        provider="github",
        redirect_uri="https://your-app.com/callback",
        authorization_response=request.url,  # full URL with ?code=...
    )
"""
from __future__ import annotations

import os
from dataclasses import dataclass, field
from typing import Dict, Optional, Tuple

from authlib.integrations.requests_client import OAuth2Session
# ------------------------------------------------------------------ #
# 1  Provider configuration
# ------------------------------------------------------------------ #
@dataclass(frozen=True)
class ProviderConfig:
client_id: str | None
client_secret: str | None
authorize_url: str
token_url: str
scope: str
def _env(name: str) -> str | None:
"""Shorthand for os.getenv with strip()."""
val = os.getenv(name)
return val.strip() if val else None
# Registry of supported providers, keyed by the name callers pass as
# ``provider``. Credentials are read from <PREFIX>_CLIENT_ID /
# <PREFIX>_CLIENT_SECRET once, at import time.
PROVIDERS: Dict[str, ProviderConfig] = {
    "github": ProviderConfig(
        client_id=_env("GITHUB_CLIENT_ID"),
        client_secret=_env("GITHUB_CLIENT_SECRET"),
        authorize_url="https://github.com/login/oauth/authorize",
        token_url="https://github.com/login/oauth/access_token",
        scope="repo read:org",
    ),
    "google": ProviderConfig(
        client_id=_env("GOOGLE_CLIENT_ID"),
        client_secret=_env("GOOGLE_CLIENT_SECRET"),
        # Documented v2 authorization endpoint (also listed in Google's
        # OpenID discovery document); /o/oauth2/auth is the older v1 path.
        authorize_url="https://accounts.google.com/o/oauth2/v2/auth",
        token_url="https://oauth2.googleapis.com/token",
        scope="openid email profile https://www.googleapis.com/auth/drive.readonly",
    ),
    "slack": ProviderConfig(
        client_id=_env("SLACK_CLIENT_ID"),
        client_secret=_env("SLACK_CLIENT_SECRET"),
        authorize_url="https://slack.com/oauth/v2/authorize",
        token_url="https://slack.com/api/oauth.v2.access",
        # NOTE(review): Slack documents `scope` as comma-separated; confirm
        # the space-separated form below is accepted.
        scope="channels:read chat:write",
    ),
}
# ------------------------------------------------------------------ #
# 2  OAuth manager
# ------------------------------------------------------------------ #
class OAuthManager:
    """Tiny wrapper around Authlib’s OAuth2Session per provider.

    Holds a mapping of provider name -> config and builds a fresh
    ``OAuth2Session`` for every authorisation-URL or token-exchange call.
    """

    def __init__(self, providers: Dict[str, ProviderConfig]):
        # Provider key (e.g. "github") -> its OAuth configuration.
        self.providers = providers

    # ---------- helpers -------------------------------------------------
    def _session(self, provider: str, redirect_uri: str) -> OAuth2Session:
        """Build an ``OAuth2Session`` for *provider*.

        Raises:
            KeyError: *provider* is not in ``self.providers``.
            RuntimeError: the provider's client ID or secret is missing.
        """
        cfg = self.providers.get(provider)
        if not cfg:
            raise KeyError(f"Unsupported provider '{provider}'.")
        if not (cfg.client_id and cfg.client_secret):
            raise RuntimeError(
                f"OAuth credentials for '{provider}' are missing. "
                "Set the *_CLIENT_ID and *_CLIENT_SECRET env‑vars."
            )
        return OAuth2Session(
            client_id=cfg.client_id,
            client_secret=cfg.client_secret,
            scope=cfg.scope,
            redirect_uri=redirect_uri,
        )

    # ---------- public API ----------------------------------------------
    def get_authorization_url(
        self,
        provider: str,
        redirect_uri: str,
        state: Optional[str] = None,
        **params: str,
    ) -> Tuple[str, str]:
        """Return ``(auth_url, state)`` for redirecting the user to consent.

        Args:
            provider: key into ``self.providers`` (e.g. ``"github"``).
            redirect_uri: callback URL registered with the provider.
            state: CSRF token to embed; when ``None`` Authlib generates one.
            **params: extra query parameters appended to the authorisation
                URL, e.g. ``access_type="offline", prompt="consent"``, which
                Google requires before it issues a refresh token.

        Store the returned *state* (e.g. in the user's session) and pass it
        to :meth:`fetch_token` so the callback is verified against it.

        Raises:
            KeyError / RuntimeError: see ``_session``.
        """
        sess = self._session(provider, redirect_uri)
        cfg = self.providers[provider]
        auth_url, final_state = sess.create_authorization_url(
            cfg.authorize_url, state=state, **params
        )
        return auth_url, final_state

    def fetch_token(
        self,
        provider: str,
        redirect_uri: str,
        authorization_response: str,
        state: Optional[str] = None,
    ) -> Dict:
        """Exchange the ``?code=...`` in *authorization_response* for a token.

        Args:
            provider: key into ``self.providers``.
            redirect_uri: the same callback URL used for the consent redirect.
            authorization_response: full callback URL including the query.
            state: value returned by :meth:`get_authorization_url`. When
                given, Authlib compares it with the ``state`` query parameter
                of *authorization_response* and raises on mismatch (CSRF
                protection). ``None`` skips that check.

        Returns:
            The token dict from Authlib: ``access_token`` plus whatever else
            the provider sends (``refresh_token`` / ``expires_in`` only when
            the provider issues them).

        Raises:
            KeyError / RuntimeError: see ``_session``.
        """
        sess = self._session(provider, redirect_uri)
        cfg = self.providers[provider]
        return sess.fetch_token(
            cfg.token_url,
            authorization_response=authorization_response,
            state=state,
            # Also sent in the request body: some providers require it there.
            client_secret=cfg.client_secret,
        )
# Shared manager built once from PROVIDERS at import time; callers use
# ``from auth import oauth_manager`` instead of constructing their own.
oauth_manager = OAuthManager(providers=PROVIDERS)