Spaces:
Running
Running
#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
""" | |
JWT Authentication for Warp API | |
Handles JWT token management, refresh, and validation. | |
Integrates functionality from refresh_jwt.py. | |
仅限匿名账户,JWT和refresh token保存在内存中。 | |
""" | |
import base64 | |
import json | |
import os | |
import time | |
from pathlib import Path | |
import httpx | |
import asyncio | |
from dotenv import load_dotenv, set_key | |
from ..config.settings import REFRESH_TOKEN_B64, REFRESH_URL, CLIENT_VERSION, OS_CATEGORY, OS_NAME, OS_VERSION | |
from .logging import logger, log | |
# In-memory token storage (anonymous accounts only).
# Access token cached by set_memory_jwt_token(); empty string means "not set".
_memory_jwt_token: str = ""
# Refresh token cached by set_memory_refresh_token(); empty string means "not set".
_memory_refresh_token: str = ""
# Raised to True the first time either setter stores a token for an anonymous user.
_is_anonymous_user: bool = False
def set_memory_jwt_token(token: str, is_anonymous: bool = False) -> None:
    """Cache a JWT in process memory, but only for anonymous sessions.

    Args:
        token: The JWT string to cache.
        is_anonymous: When true, the token is stored and the module-level
            anonymous flag is raised. When false, nothing is stored and only
            a debug message is logged.
    """
    global _memory_jwt_token, _is_anonymous_user
    if not is_anonymous:
        logger.debug("非匿名用户,JWT token不保存到内存")
        return
    _memory_jwt_token, _is_anonymous_user = token, True
    logger.info("JWT token已保存到内存(匿名用户)")
def get_memory_jwt_token() -> str:
    """Return the JWT currently cached in memory (empty string if none)."""
    # Reading a module global needs no `global` declaration.
    return _memory_jwt_token
def set_memory_refresh_token(token: str, is_anonymous: bool = False) -> None:
    """Cache a refresh token in process memory, but only for anonymous sessions.

    Args:
        token: The refresh token to cache.
        is_anonymous: When true, the token is stored and the module-level
            anonymous flag is raised. When false, nothing is stored and only
            a debug message is logged.
    """
    global _memory_refresh_token, _is_anonymous_user
    if not is_anonymous:
        logger.debug("非匿名用户,refresh token不保存到内存")
        return
    _memory_refresh_token, _is_anonymous_user = token, True
    logger.info("Refresh token已保存到内存(匿名用户)")
def get_memory_refresh_token() -> str:
    """Return the refresh token currently cached in memory (empty string if none)."""
    # Reading a module global needs no `global` declaration.
    return _memory_refresh_token
def is_anonymous_user() -> bool:
    """Report whether a token has been cached in memory for an anonymous user."""
    # Reading a module global needs no `global` declaration.
    return _is_anonymous_user
def decode_jwt_payload(token: str) -> dict:
    """Decode the (unverified) payload segment of a JWT.

    The signature is NOT checked; this is only meant for reading claims such
    as ``exp`` from a token the caller already trusts.

    Args:
        token: A compact JWT of the form ``header.payload.signature``.

    Returns:
        The payload claims as a dict. An empty dict is returned when the
        token is not a string, does not have exactly three segments, cannot
        be base64url/UTF-8/JSON decoded, or decodes to something other than
        a JSON object.
    """
    if not isinstance(token, str):
        return {}
    segments = token.split('.')
    if len(segments) != 3:
        return {}
    encoded = segments[1]
    # JWTs strip base64 padding; restore it to a multiple of four characters.
    encoded += '=' * (-len(encoded) % 4)
    try:
        # binascii.Error, UnicodeDecodeError and json.JSONDecodeError are all
        # ValueError subclasses, so this covers every decoding failure.
        payload = json.loads(base64.urlsafe_b64decode(encoded).decode('utf-8'))
    except ValueError as e:
        logger.debug(f"Error decoding JWT: {e}")
        return {}
    # Valid JSON that is not an object (e.g. a bare number or list) would
    # violate the declared return type and break callers that index by claim.
    return payload if isinstance(payload, dict) else {}
def is_token_expired(token: str, buffer_minutes: int = 5) -> bool:
    """Tell whether a JWT is expired or will expire within a safety buffer.

    Args:
        token: The compact JWT to inspect.
        buffer_minutes: Treat the token as expired when it has this many
            minutes (or fewer) left before its ``exp`` time.

    Returns:
        True when the token cannot be decoded, has no usable numeric ``exp``
        claim, or expires within the buffer; False otherwise.
    """
    payload = decode_jwt_payload(token)
    if not isinstance(payload, dict):
        return True
    expiry_time = payload.get('exp')
    # A missing or non-numeric `exp` cannot be compared against the clock;
    # treat such a token as expired instead of raising TypeError.
    # bool is excluded explicitly because it is a subclass of int.
    if isinstance(expiry_time, bool) or not isinstance(expiry_time, (int, float)):
        return True
    return (expiry_time - time.time()) <= buffer_minutes * 60
async def refresh_jwt_token() -> dict:
    """Request a new JWT from the refresh endpoint.

    The refresh token is chosen in this order of priority:
    in-memory refresh token (anonymous user) > ``WARP_REFRESH_TOKEN``
    environment variable > the built-in default request body decoded from
    ``REFRESH_TOKEN_B64``.

    Returns:
        The parsed JSON response on HTTP 200, otherwise an empty dict.
        Any exception raised while sending the request or parsing the
        response is logged and also results in an empty dict.
    """
    from urllib.parse import urlencode

    logger.info("Refreshing JWT token...")
    # Prefer the in-memory refresh token (anonymous user).
    refresh_token = get_memory_refresh_token()
    if refresh_token:
        logger.info("使用内存中的refresh token(匿名用户)")
    else:
        # Next, fall back to the refresh token from the environment.
        refresh_token = os.getenv("WARP_REFRESH_TOKEN")
        if refresh_token:
            logger.info("使用环境变量中的refresh token")

    if refresh_token:
        # The body is application/x-www-form-urlencoded, so the token must be
        # percent-encoded; raw interpolation corrupts tokens containing
        # characters such as '&', '=', '+' or '/'.
        payload = urlencode(
            {"grant_type": "refresh_token", "refresh_token": refresh_token}
        ).encode("utf-8")
    else:
        logger.info("使用默认的refresh token")
        # The default is a pre-built request body, not a bare token.
        payload = base64.b64decode(REFRESH_TOKEN_B64)

    headers = {
        "x-warp-client-version": CLIENT_VERSION,
        "x-warp-os-category": OS_CATEGORY,
        "x-warp-os-name": OS_NAME,
        "x-warp-os-version": OS_VERSION,
        "content-type": "application/x-www-form-urlencoded",
        "accept": "*/*",
        "accept-encoding": "gzip, br",
        "content-length": str(len(payload)),
    }
    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                REFRESH_URL,
                headers=headers,
                content=payload,
            )
            if response.status_code == 200:
                token_data = response.json()
                logger.info("Token refresh successful")
                return token_data
            logger.error(f"Token refresh failed: {response.status_code}")
            logger.error(f"Response: {response.text}")
            return {}
    except Exception as e:
        # Best-effort boundary: callers only distinguish "got data" from "{}".
        logger.error(f"Error refreshing token: {e}")
        return {}
def update_env_file(new_jwt: str, is_anonymous: bool = False) -> bool:
    """Persist a new JWT: memory for anonymous users, ``.env`` otherwise.

    For non-anonymous users the token is written to the ``WARP_JWT`` key of
    ``.env``. If that write fails, the token is cached in memory instead
    (using the anonymous path).

    Args:
        new_jwt: The JWT to persist.
        is_anonymous: Whether the token belongs to an anonymous user.

    Returns:
        Always True.
    """
    if not is_anonymous:
        dotenv_file = Path(".env")
        try:
            set_key(str(dotenv_file), "WARP_JWT", new_jwt)
            logger.info("Updated .env file with new JWT token")
            return True
        except Exception as e:
            logger.error(f"Error updating .env file: {e}")
            # Writing .env failed; fall through to the in-memory fallback.
            logger.warning("尝试将JWT token保存到内存作为备用方案")
    set_memory_jwt_token(new_jwt, is_anonymous=True)
    return True
def update_env_refresh_token(refresh_token: str, is_anonymous: bool = False) -> bool:
    """Persist a refresh token: memory for anonymous users, ``.env`` otherwise.

    For non-anonymous users the token is written to the ``WARP_REFRESH_TOKEN``
    key of ``.env``. If that write fails, the token is cached in memory
    instead (using the anonymous path).

    Args:
        refresh_token: The refresh token to persist.
        is_anonymous: Whether the token belongs to an anonymous user.

    Returns:
        Always True.
    """
    if not is_anonymous:
        dotenv_file = Path(".env")
        try:
            set_key(str(dotenv_file), "WARP_REFRESH_TOKEN", refresh_token)
            logger.info("Updated .env with WARP_REFRESH_TOKEN")
            return True
        except Exception as e:
            logger.error(f"Error updating .env WARP_REFRESH_TOKEN: {e}")
            # Writing .env failed; fall through to the in-memory fallback.
            logger.warning("尝试将refresh token保存到内存作为备用方案")
    set_memory_refresh_token(refresh_token, is_anonymous=True)
    return True
async def check_and_refresh_token() -> bool:
    """Make sure a usable JWT is stored, refreshing it when needed.

    The current token is taken from memory first (anonymous user) and then
    from the ``WARP_JWT`` environment variable. A refresh is attempted when
    no token exists or when it expires within 15 minutes.

    Returns:
        True when the existing token is still valid or a refreshed token was
        stored; False when the refresh failed or returned an expired token.
    """
    was_anonymous = is_anonymous_user()
    # Memory takes priority; fall back to the environment variable.
    active_jwt = get_memory_jwt_token() or os.getenv("WARP_JWT")

    if not active_jwt:
        logger.warning("No JWT token found in memory or environment")
        refreshed = await refresh_jwt_token()
        if not refreshed or "access_token" not in refreshed:
            return False
        # A refresh token held in memory implies an anonymous user.
        anonymous = bool(get_memory_refresh_token())
        return update_env_file(refreshed["access_token"], is_anonymous=anonymous)

    logger.debug("Checking current JWT token expiration...")
    if not is_token_expired(active_jwt, buffer_minutes=15):
        claims = decode_jwt_payload(active_jwt)
        if claims and 'exp' in claims:
            hours_left = (claims['exp'] - time.time()) / 3600
            logger.debug(f"Current token is still valid ({hours_left:.1f} hours remaining)")
        else:
            logger.debug("Current token appears valid")
        return True

    logger.info("JWT token is expired or expiring soon, refreshing...")
    refreshed = await refresh_jwt_token()
    if not refreshed or "access_token" not in refreshed:
        logger.error("Failed to get new token from refresh")
        return False

    fresh_jwt = refreshed["access_token"]
    if is_token_expired(fresh_jwt, buffer_minutes=0):
        logger.warning("New token appears to be invalid or expired")
        return False

    logger.info("New token is valid")
    # Anonymous if a refresh token is in memory or the flag was already set.
    anonymous = bool(get_memory_refresh_token()) or was_anonymous
    return update_env_file(fresh_jwt, is_anonymous=anonymous)
async def get_valid_jwt() -> str:
    """Return a JWT for API calls, refreshing it first when necessary.

    The token is read from memory (anonymous user) and then from the
    ``WARP_JWT`` environment variable after reloading ``.env``. A refresh is
    attempted when no token exists or when it expires within 2 minutes.

    Returns:
        The JWT to use. If a refresh fails, the existing token is returned.

    Raises:
        RuntimeError: No token is available and the refresh did not yield one.
    """
    from dotenv import load_dotenv as _load

    def _reread_token():
        # Memory first; otherwise reload .env and read the environment.
        token = get_memory_jwt_token()
        if not token:
            _load(override=True)
            token = os.getenv("WARP_JWT")
        return token

    _load(override=True)
    jwt = get_memory_jwt_token() or os.getenv("WARP_JWT")

    if not jwt:
        logger.info("No JWT token found in memory or environment, attempting to refresh...")
        if await check_and_refresh_token():
            jwt = _reread_token()
        if not jwt:
            raise RuntimeError("WARP_JWT is not set and refresh failed")

    if is_token_expired(jwt, buffer_minutes=2):
        logger.info("JWT token is expired or expiring soon, attempting to refresh...")
        if await check_and_refresh_token():
            jwt = _reread_token()
            if not jwt or is_token_expired(jwt, buffer_minutes=0):
                logger.warning("Warning: New token has short expiry but proceeding anyway")
        else:
            logger.warning("Warning: JWT token refresh failed, trying to use existing token")
    return jwt
def get_jwt_token() -> str:
    """Return the stored JWT without refreshing it.

    The in-memory token (anonymous user) wins; otherwise ``.env`` is loaded
    and ``WARP_JWT`` is read from the environment.

    Returns:
        The JWT, or an empty string when none is available.
    """
    cached = get_memory_jwt_token()
    if not cached:
        # Nothing in memory: load .env and consult the environment.
        from dotenv import load_dotenv as _load
        _load()
        return os.getenv("WARP_JWT", "")
    return cached
async def refresh_jwt_if_needed() -> bool:
    """Run check_and_refresh_token() and never raise.

    Returns:
        The result of check_and_refresh_token(), or False when it raised
        (the error is logged).
    """
    try:
        outcome = await check_and_refresh_token()
    except Exception as e:
        logger.error(f"JWT refresh failed: {e}")
        return False
    return outcome
# ============ Anonymous token acquisition (quota refresh) ============
# GraphQL endpoint that _create_anonymous_user() POSTs the CreateAnonymousUser mutation to.
_ANON_GQL_URL = "https://app.warp.dev/graphql/v2?op=CreateAnonymousUser"
# Endpoint that _exchange_id_token_for_refresh_token() POSTs the idToken to; a `key` query parameter is appended there.
_IDENTITY_TOOLKIT_BASE = "https://identitytoolkit.googleapis.com/v1/accounts:signInWithCustomToken"
def _extract_google_api_key_from_refresh_url() -> str:
    """Pull the ``key`` query parameter out of ``REFRESH_URL``.

    Returns:
        The first ``key`` value, or an empty string when the parameter is
        absent or the URL cannot be parsed.
    """
    try:
        # REFRESH_URL like: https://app.warp.dev/proxy/token?key=API_KEY
        from urllib.parse import urlparse, parse_qs
        query_params = parse_qs(urlparse(REFRESH_URL).query)
        return query_params.get("key", [""])[0]
    except Exception:
        return ""
async def _create_anonymous_user() -> dict:
    """Send the CreateAnonymousUser GraphQL mutation.

    Returns:
        The parsed JSON response body.

    Raises:
        RuntimeError: The endpoint answered with a status other than 200.
    """
    # GraphQL payload per anonymous.MD
    mutation = (
        "mutation CreateAnonymousUser($input: CreateAnonymousUserInput!, $requestContext: RequestContext!) {\n"
        " createAnonymousUser(input: $input, requestContext: $requestContext) {\n"
        " __typename\n"
        " ... on CreateAnonymousUserOutput {\n"
        " expiresAt\n"
        " anonymousUserType\n"
        " firebaseUid\n"
        " idToken\n"
        " isInviteValid\n"
        " responseContext { serverVersion }\n"
        " }\n"
        " ... on UserFacingError {\n"
        " error { __typename message }\n"
        " responseContext { serverVersion }\n"
        " }\n"
        " }\n"
        "}\n"
    )
    os_context = {
        "category": OS_CATEGORY,
        "linuxKernelVersion": None,
        "name": OS_NAME,
        "version": OS_VERSION,
    }
    request_body = {
        "query": mutation,
        "variables": {
            "input": {
                "anonymousUserType": "NATIVE_CLIENT_ANONYMOUS_USER_FEATURE_GATED",
                "expirationType": "NO_EXPIRATION",
                "referralCode": None,
            },
            "requestContext": {
                "clientContext": {"version": CLIENT_VERSION},
                "osContext": os_context,
            },
        },
        "operationName": "CreateAnonymousUser",
    }
    request_headers = {
        "accept-encoding": "gzip, br",
        "content-type": "application/json",
        "x-warp-client-version": CLIENT_VERSION,
        "x-warp-os-category": OS_CATEGORY,
        "x-warp-os-name": OS_NAME,
        "x-warp-os-version": OS_VERSION,
    }
    async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
        reply = await client.post(_ANON_GQL_URL, headers=request_headers, json=request_body)
        if reply.status_code == 200:
            return reply.json()
        raise RuntimeError(f"CreateAnonymousUser failed: HTTP {reply.status_code} {reply.text[:200]}")
async def _exchange_id_token_for_refresh_token(id_token: str) -> dict:
    """POST an idToken to the signInWithCustomToken endpoint.

    The API key comes from the ``key`` query parameter of ``REFRESH_URL``;
    a built-in key is used when none is found there.

    Args:
        id_token: The idToken returned by the CreateAnonymousUser mutation.

    Returns:
        The parsed JSON response body.

    Raises:
        RuntimeError: The endpoint answered with a status other than 200.
    """
    api_key = _extract_google_api_key_from_refresh_url() or "AIzaSyBdy3O3S9hrdayLJxJ7mriBR4qgUaUygAs"
    endpoint = f"{_IDENTITY_TOOLKIT_BASE}?key={api_key}"
    request_headers = {
        "accept-encoding": "gzip, br",
        "content-type": "application/x-www-form-urlencoded",
        "x-warp-client-version": CLIENT_VERSION,
        "x-warp-os-category": OS_CATEGORY,
        "x-warp-os-name": OS_NAME,
        "x-warp-os-version": OS_VERSION,
    }
    form_fields = {"returnSecureToken": "true", "token": id_token}
    async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
        reply = await client.post(endpoint, headers=request_headers, data=form_fields)
        if reply.status_code == 200:
            return reply.json()
        raise RuntimeError(f"signInWithCustomToken failed: HTTP {reply.status_code} {reply.text[:200]}")
async def acquire_anonymous_access_token() -> str:
    """Acquire a new anonymous access token (quota refresh) and keep it in memory.

    Steps:
        1. Create an anonymous user via GraphQL to obtain an idToken.
        2. Exchange the idToken for a refresh token.
        3. POST the refresh token to ``REFRESH_URL`` to obtain an access token.

    Both the refresh token and the access token are stored through the
    anonymous (in-memory) path.

    Returns:
        The new access token string.

    Raises:
        RuntimeError: Any step returned a non-200 status or a response that
            lacks the expected field.
    """
    from urllib.parse import urlencode

    logger.info("Acquiring anonymous access token via GraphQL + Identity Toolkit…")
    data = await _create_anonymous_user()
    try:
        id_token = data["data"]["createAnonymousUser"].get("idToken")
    except (KeyError, TypeError, AttributeError):
        # Response did not have the expected shape; reported just below
        # together with the raw response.
        id_token = None
    if not id_token:
        raise RuntimeError(f"CreateAnonymousUser did not return idToken: {data}")

    signin = await _exchange_id_token_for_refresh_token(id_token)
    refresh_token = signin.get("refreshToken")
    if not refresh_token:
        raise RuntimeError(f"signInWithCustomToken did not return refreshToken: {signin}")
    # The anonymous user's refresh token is kept in memory.
    update_env_refresh_token(refresh_token, is_anonymous=True)

    # Now call Warp proxy token endpoint to get access_token using this refresh token.
    # The body is form-urlencoded, so the token value must be percent-encoded.
    payload = urlencode(
        {"grant_type": "refresh_token", "refresh_token": refresh_token}
    ).encode("utf-8")
    headers = {
        "x-warp-client-version": CLIENT_VERSION,
        "x-warp-os-category": OS_CATEGORY,
        "x-warp-os-name": OS_NAME,
        "x-warp-os-version": OS_VERSION,
        "content-type": "application/x-www-form-urlencoded",
        "accept": "*/*",
        "accept-encoding": "gzip, br",
        "content-length": str(len(payload)),
    }
    async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
        resp = await client.post(REFRESH_URL, headers=headers, content=payload)
        if resp.status_code != 200:
            raise RuntimeError(f"Acquire access_token failed: HTTP {resp.status_code} {resp.text[:200]}")
        token_data = resp.json()
    access = token_data.get("access_token")
    if not access:
        raise RuntimeError(f"No access_token in response: {token_data}")
    # The anonymous user's access token is kept in memory.
    update_env_file(access, is_anonymous=True)
    return access
def print_token_info(): | |
# 优先检查内存中的JWT token | |
current_jwt = get_memory_jwt_token() | |
if not current_jwt: | |
current_jwt = os.getenv("WARP_JWT") | |
if not current_jwt: | |
logger.info("No JWT token found in memory or environment") | |
return | |
payload = decode_jwt_payload(current_jwt) | |
if not payload: | |
logger.info("Cannot decode JWT token") | |
return | |
logger.info("=== JWT Token Information ===") | |
if 'email' in payload: | |
logger.info(f"Email: {payload['email']}") | |
if 'user_id' in payload: | |
logger.info(f"User ID: {payload['user_id']}") |