File size: 4,537 Bytes
c33cb65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ae3bf6
c33cb65
 
 
4ae3bf6
 
c33cb65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ae3bf6
c33cb65
 
 
 
 
 
 
4ae3bf6
c33cb65
 
 
 
 
 
 
4ae3bf6
c33cb65
 
 
 
4ae3bf6
 
c33cb65
 
 
 
4ae3bf6
c33cb65
4ae3bf6
c33cb65
4ae3bf6
c33cb65
 
 
 
 
 
4ae3bf6
c33cb65
 
 
4ae3bf6
c33cb65
4ae3bf6
c33cb65
 
4ae3bf6
c33cb65
 
 
 
 
 
4ae3bf6
 
c33cb65
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# auth.py
"""
Centralised OAuth 2.0 helper for GitHub, Google Drive and Slack.

Usage
-----
from auth import oauth_manager

# 1)  Redirect user to consent page
auth_url, state = oauth_manager.get_authorization_url(
    provider="github",
    redirect_uri="https://your-app.com/callback"
)

# 2)  In your callback handler, exchange the code for a token
token = oauth_manager.fetch_token(
    provider="github",
    redirect_uri="https://your-app.com/callback",
    authorization_response=request.url  # full URL with ?code=...
)
"""

from __future__ import annotations

import os
from dataclasses import dataclass
from typing import Dict, Optional, Tuple

from authlib.integrations.requests_client import OAuth2Session


# ------------------------------------------------------------------ #
# 1  Provider configuration
# ------------------------------------------------------------------ #
@dataclass(frozen=True)
class ProviderConfig:
    """Immutable OAuth 2.0 settings for a single provider."""

    # Client credentials; either one may be None when it is not configured.
    client_id: Optional[str]
    client_secret: Optional[str]
    # URL the user is sent to in order to grant consent.
    authorize_url: str
    # URL where the authorization code is exchanged for a token.
    token_url: str
    # Scopes requested during authorization, as a single string.
    scope: str


def _env(name: str) -> str | None:
    """Return the stripped value of environment variable *name*, or None.

    Unset, empty and whitespace-only values all yield ``None`` so that callers
    see a single "not configured" value instead of sometimes getting ``""``.
    """
    value = os.getenv(name, "").strip()
    # An empty string after stripping means nothing usable was configured.
    return value or None


def _provider_from_env(
    prefix: str, *, authorize_url: str, token_url: str, scope: str
) -> ProviderConfig:
    """Build a ProviderConfig with credentials read from the environment.

    The client id and secret come from ``<prefix>_CLIENT_ID`` and
    ``<prefix>_CLIENT_SECRET`` respectively.
    """
    return ProviderConfig(
        client_id=_env(f"{prefix}_CLIENT_ID"),
        client_secret=_env(f"{prefix}_CLIENT_SECRET"),
        authorize_url=authorize_url,
        token_url=token_url,
        scope=scope,
    )


# Registry of supported providers, keyed by provider name.
PROVIDERS: Dict[str, ProviderConfig] = {
    # Credentials: GITHUB_CLIENT_ID / GITHUB_CLIENT_SECRET
    "github": _provider_from_env(
        "GITHUB",
        authorize_url="https://github.com/login/oauth/authorize",
        token_url="https://github.com/login/oauth/access_token",
        scope="repo read:org",
    ),
    # Credentials: GOOGLE_CLIENT_ID / GOOGLE_CLIENT_SECRET
    "google": _provider_from_env(
        "GOOGLE",
        authorize_url="https://accounts.google.com/o/oauth2/auth",
        token_url="https://oauth2.googleapis.com/token",
        scope="openid email profile https://www.googleapis.com/auth/drive.readonly",
    ),
    # Credentials: SLACK_CLIENT_ID / SLACK_CLIENT_SECRET
    "slack": _provider_from_env(
        "SLACK",
        authorize_url="https://slack.com/oauth/v2/authorize",
        token_url="https://slack.com/api/oauth.v2.access",
        scope="channels:read chat:write",
    ),
}


# ------------------------------------------------------------------ #
# 2  OAuth manager
# ------------------------------------------------------------------ #
class OAuthManager:
    """Thin per-provider wrapper around Authlib's ``OAuth2Session``.

    A new session is created for every call, so the manager keeps nothing
    between the redirect and the callback. The caller must store the ``state``
    returned by `get_authorization_url` and hand it to `fetch_token`.
    """

    def __init__(self, providers: Dict[str, ProviderConfig]):
        """Store the registry mapping provider names to their configuration."""
        self.providers = providers

    # ---------- helpers -------------------------------------------------
    def _session(self, provider: str, redirect_uri: str) -> OAuth2Session:
        """Create an OAuth2Session for *provider*.

        Raises:
            KeyError: *provider* is not in the registry.
            RuntimeError: the provider's client id or secret is not set.
        """
        cfg = self.providers.get(provider)
        if not cfg:
            raise KeyError(f"Unsupported provider '{provider}'.")
        if not (cfg.client_id and cfg.client_secret):
            raise RuntimeError(
                f"OAuth credentials for '{provider}' are missing. "
                "Set the *_CLIENT_ID and *_CLIENT_SECRET env‑vars."
            )
        return OAuth2Session(
            client_id=cfg.client_id,
            client_secret=cfg.client_secret,
            scope=cfg.scope,
            redirect_uri=redirect_uri,
        )

    # ---------- public API ----------------------------------------------
    def get_authorization_url(
        self, provider: str, redirect_uri: str, state: Optional[str] = None
    ) -> Tuple[str, str]:
        """Return ``(auth_url, state)`` for the given provider.

        Args:
            provider: Key of the provider in the registry.
            redirect_uri: Callback URL registered with the provider.
            state: Optional caller-chosen anti-CSRF value; when omitted, the
                value produced by Authlib is returned instead.

        Persist the returned *state* (for example in the user's web session)
        and pass it to `fetch_token` so the callback can be verified.

        Raises:
            KeyError: *provider* is not in the registry.
            RuntimeError: the provider's credentials are not configured.
        """
        sess = self._session(provider, redirect_uri)
        cfg = self.providers[provider]
        auth_url, final_state = sess.create_authorization_url(
            cfg.authorize_url, state=state
        )
        return auth_url, final_state

    def fetch_token(
        self,
        provider: str,
        redirect_uri: str,
        authorization_response: str,
        state: Optional[str] = None,
    ) -> Dict:
        """Exchange the ``?code=…`` in the callback URL for an access token.

        Args:
            provider: Key of the provider in the registry.
            redirect_uri: The same callback URL used to start the flow.
            authorization_response: Full callback URL including its query.
            state: The value returned by `get_authorization_url`. When given,
                it is forwarded to Authlib, which compares it with the
                ``state`` in *authorization_response* and raises a
                mismatching-state error if they differ. When ``None`` no
                comparison is requested, so callers should always pass it.

        Returns:
            The token dict from Authlib (``access_token`` plus whatever else
            the provider issues, such as ``refresh_token`` or ``expires_in``).

        Raises:
            KeyError: *provider* is not in the registry.
            RuntimeError: the provider's credentials are not configured.
        """
        sess = self._session(provider, redirect_uri)
        cfg = self.providers[provider]
        return sess.fetch_token(
            cfg.token_url,
            authorization_response=authorization_response,
            # The session is new, so it has no state of its own to check
            # against; the caller-supplied value is the only one available.
            state=state,
            client_secret=cfg.client_secret,  # some providers require it explicitly
        )


# Module-level instance shared by the rest of the application.
oauth_manager = OAuthManager(providers=PROVIDERS)