Spaces:
Running
Running
#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
""" | |
Warp Protobuf 桥接服务器启动文件 | |
提供 Warp API 桥接端点(/api/warp/send_stream、/api/warp/send_stream_sse)、/healthz 与 /v1/models。已移除 GUI、静态文件与 WebSocket 功能。 | |
""" | |
import os | |
import asyncio | |
import json | |
from pathlib import Path | |
import uvicorn | |
from fastapi import FastAPI, HTTPException | |
from contextlib import asynccontextmanager | |
# 新增:类型导入 | |
from typing import Any, Dict, List | |
from warp2protobuf.api.protobuf_routes import app as protobuf_app | |
from warp2protobuf.core.logging import logger | |
from warp2protobuf.core.auth import acquire_anonymous_access_token | |
from warp2protobuf.config.models import get_all_unique_models | |
# 导入OpenAI兼容路由 | |
from protobuf2openai.router import router as openai_router | |
# ============= JSON Schema 清理函数已移至 warp2protobuf.core.schema_sanitizer 模块 ============= | |
# ============= 应用创建 ============= | |
def create_app() -> FastAPI:
    """Return the protobuf bridge app with the OpenAI-compatible routes attached.

    The imported ``protobuf_app`` instance is extended in place and returned;
    no new FastAPI application is constructed here.
    """
    # The OpenAI-compatible endpoints (e.g. the model list) come from the router.
    protobuf_app.include_router(openai_router)
    return protobuf_app
def create_app_with_lifespan() -> FastAPI:
    """Build the top-level FastAPI application driven by ``lifespan``.

    The returned app:
      * uses ``lifespan`` for startup/shutdown handling,
      * has a CORS middleware that allows every origin, method and header
        (with credentials enabled),
      * mounts ``protobuf_app`` under ``/api``,
      * includes the OpenAI-compatible router,
      * serves a service-info document on ``/`` and a health probe on
        ``/healthz``.

    Returns:
        The configured FastAPI application.
    """
    app = FastAPI(
        title="Warp Protobuf Bridge Server",
        description="Warp API 桥接服务器,提供 Protobuf 编解码与 OpenAI 兼容接口",
        version="1.0.0",
        lifespan=lifespan,
    )

    # CORS middleware (imported locally, as the surrounding file does).
    from fastapi.middleware.cors import CORSMiddleware
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Protobuf bridge endpoints live under the /api prefix.
    app.mount("/api", protobuf_app)
    # OpenAI-compatible routes.
    app.include_router(openai_router)

    # ----- root path and health check -----
    # The handlers must be registered on the app; a bare nested ``async def``
    # is never reachable over HTTP.
    @app.get("/")
    async def root():
        """Describe the service and list its main endpoints."""
        return {
            "service": "Warp Protobuf Bridge Server",
            "status": "running",
            "endpoints": {
                "health": "/healthz",
                "models": "/v1/models",
                "protobuf_bridge": "/api/warp/send_stream",
                "sse_bridge": "/api/warp/send_stream_sse",
                "auth_refresh": "/api/auth/refresh"
            }
        }

    @app.get("/healthz")
    async def health_check():
        """Liveness probe."""
        return {"status": "ok", "service": "Warp Protobuf Bridge Server"}

    return app
############################################################ | |
# server_message_data 深度编解码工具 | |
############################################################ | |
# 说明: | |
# 根据抓包与分析,server_message_data 是 Base64URL 编码的 proto3 消息: | |
# - 字段 1:string(通常为 36 字节 UUID) | |
# - 字段 3:google.protobuf.Timestamp(字段1=seconds,字段2=nanos) | |
# 可能出现:仅 Timestamp、仅 UUID、或 UUID + Timestamp。 | |
from typing import Dict, Optional, Tuple | |
import base64 | |
from datetime import datetime, timezone | |
try: | |
from zoneinfo import ZoneInfo # Python 3.9+ | |
except Exception: | |
ZoneInfo = None # type: ignore | |
def _b64url_decode_padded(s: str) -> bytes: | |
t = s.replace("-", "+").replace("_", "/") | |
pad = (-len(t)) % 4 | |
if pad: | |
t += "=" * pad | |
return base64.b64decode(t) | |
def _b64url_encode_nopad(b: bytes) -> str: | |
return base64.urlsafe_b64encode(b).decode("ascii").rstrip("=") | |
def _read_varint(buf: bytes, i: int) -> Tuple[int, int]: | |
shift = 0 | |
val = 0 | |
while i < len(buf): | |
b = buf[i] | |
i += 1 | |
val |= (b & 0x7F) << shift | |
if not (b & 0x80): | |
return val, i | |
shift += 7 | |
if shift > 63: | |
break | |
raise ValueError("invalid varint") | |
def _write_varint(v: int) -> bytes: | |
out = bytearray() | |
vv = int(v) | |
while True: | |
to_write = vv & 0x7F | |
vv >>= 7 | |
if vv: | |
out.append(to_write | 0x80) | |
else: | |
out.append(to_write) | |
break | |
return bytes(out) | |
def _make_key(field_no: int, wire_type: int) -> bytes:
    """Build the varint-encoded field key for a field number and wire type."""
    # The wire type occupies the low three bits; the field number sits above.
    tag = wire_type | (field_no << 3)
    return _write_varint(tag)
def _decode_timestamp(buf: bytes) -> Tuple[Optional[int], Optional[int]]:
    """Parse the body of a serialized ``google.protobuf.Timestamp``.

    Field 1 is ``seconds`` (int64) and field 2 is ``nanos`` (int32); both
    are varints. Because those are signed types, a negative value arrives
    as a 64-bit two's-complement varint and is converted back to a negative
    Python int here instead of being reported as a huge positive number.

    Args:
        buf: The raw bytes of the Timestamp message.

    Returns:
        ``(seconds, nanos)``; either element is ``None`` when the field is
        absent.

    Raises:
        ValueError: propagated from ``_read_varint`` on a malformed varint.
    """
    def _to_signed64(raw: int) -> int:
        # Keep the low 64 bits, then reinterpret them as two's complement.
        raw &= 0xFFFFFFFFFFFFFFFF
        return raw - (1 << 64) if raw >> 63 else raw

    i = 0
    seconds: Optional[int] = None
    nanos: Optional[int] = None
    while i < len(buf):
        key, i = _read_varint(buf, i)
        field_no = key >> 3
        wt = key & 0x07
        if wt == 0:  # varint
            val, i = _read_varint(buf, i)
            if field_no == 1:
                seconds = _to_signed64(val)
            elif field_no == 2:
                nanos = _to_signed64(val)
        elif wt == 2:  # length-delimited (not expected inside Timestamp): skip
            ln, i2 = _read_varint(buf, i)
            i = i2 + ln
        elif wt == 1:  # fixed 64-bit: skip
            i += 8
        elif wt == 5:  # fixed 32-bit: skip
            i += 4
        else:  # unsupported wire type: stop parsing
            break
    return seconds, nanos
def _encode_timestamp(seconds: Optional[int], nanos: Optional[int]) -> bytes:
    """Serialize a Timestamp body from optional ``seconds`` and ``nanos``.

    A field is written (as a varint, wire type 0) only when its value is not
    ``None``; ``seconds`` is field 1 and ``nanos`` is field 2.
    """
    encoded = bytearray()
    for field_no, value in ((1, seconds), (2, nanos)):
        if value is None:
            continue
        encoded.extend(_make_key(field_no, 0))
        encoded.extend(_write_varint(int(value)))
    return bytes(encoded)
def decode_server_message_data(b64url: str) -> Dict:
    """Decode a Base64URL ``server_message_data`` value into a dict.

    Returns:
        A dict containing whichever of ``uuid``, ``seconds`` and ``nanos``
        were found. If the Base64URL step itself fails, a dict with
        ``error`` and ``raw_b64url`` keys is returned instead.

    Raises:
        ValueError: propagated from ``_read_varint`` when the decoded bytes
            contain a malformed varint.
    """
    try:
        payload = _b64url_decode_padded(b64url)
    except Exception as e:
        return {"error": f"base64url decode failed: {e}", "raw_b64url": b64url}

    message_id: Optional[str] = None
    ts_seconds: Optional[int] = None
    ts_nanos: Optional[int] = None

    pos = 0
    total = len(payload)
    while pos < total:
        tag, pos = _read_varint(payload, pos)
        field_no, wire_type = tag >> 3, tag & 0x07
        if wire_type == 0:
            # Varint: not expected at this level, consume and ignore.
            _, pos = _read_varint(payload, pos)
        elif wire_type == 1:
            pos += 8
        elif wire_type == 5:
            pos += 4
        elif wire_type == 2:
            # Length-delimited: a size prefix followed by that many bytes.
            size, start = _read_varint(payload, pos)
            pos = start + size
            chunk = payload[start:pos]
            if field_no == 1:
                # Field 1 carries the UUID string.
                try:
                    message_id = chunk.decode("utf-8")
                except Exception:
                    message_id = None
            elif field_no == 3:
                # Field 3 carries an embedded Timestamp message.
                ts_seconds, ts_nanos = _decode_timestamp(chunk)
        else:
            break

    found = (("uuid", message_id), ("seconds", ts_seconds), ("nanos", ts_nanos))
    return {name: value for name, value in found if value is not None}
def encode_server_message_data(uuid: str = None, seconds: int = None, nanos: int = None) -> str:
    """Encode uuid/seconds/nanos into an unpadded Base64URL string.

    A non-empty ``uuid`` is written as field 1 (length-delimited). When
    either ``seconds`` or ``nanos`` is given, a Timestamp message is written
    as field 3 (length-delimited).
    """
    def _length_delimited(field_no: int, body: bytes) -> bytes:
        # key, then byte length, then the bytes themselves
        return _make_key(field_no, 2) + _write_varint(len(body)) + body

    encoded = bytearray()
    if uuid:
        encoded += _length_delimited(1, uuid.encode("utf-8"))
    if not (seconds is None and nanos is None):
        encoded += _length_delimited(3, _encode_timestamp(seconds, nanos))
    return _b64url_encode_nopad(bytes(encoded))
async def startup_tasks():
    """Run the startup checks and log the available endpoints.

    Steps:
      1. Initialise the protobuf runtime; a failure is logged and re-raised.
      2. Inspect the JWT token: request an anonymous token when none exists,
         or attempt a refresh when the current one is not valid. Every
         failure in this step is logged as a warning and not re-raised.
      3. Log the list of API endpoints.
    """
    banner = "=" * 60
    logger.info(banner)
    logger.info("Warp Protobuf编解码服务器启动")
    logger.info(banner)

    # Step 1: protobuf runtime.
    try:
        from warp2protobuf.core.protobuf import ensure_proto_runtime
        ensure_proto_runtime()
        logger.info("✅ Protobuf运行时初始化成功")
    except Exception as e:
        logger.error(f"❌ Protobuf运行时初始化失败: {e}")
        raise

    # Step 2: JWT token.
    try:
        from warp2protobuf.core.auth import get_jwt_token, is_token_expired, refresh_jwt_if_needed
        token = get_jwt_token()
        if not token:
            # No token at all: try to obtain an anonymous one.
            logger.warning("⚠️ 未找到JWT token,尝试申请匿名访问token用于额度初始化…")
            try:
                new_token = await acquire_anonymous_access_token()
                if new_token:
                    logger.info("✅ 匿名访问token申请成功")
                else:
                    logger.warning("⚠️ 匿名访问token申请失败")
            except Exception as e2:
                logger.warning(f"⚠️ 匿名访问token申请异常: {e2}")
        elif is_token_expired(token):
            # A token exists but is reported expired: try to refresh it.
            logger.warning("⚠️ JWT token无效或已过期,尝试自动刷新…")
            try:
                refresh_success = await refresh_jwt_if_needed()
                if refresh_success:
                    logger.info("✅ JWT token自动刷新成功")
                else:
                    logger.warning("⚠️ JWT token自动刷新失败,建议手动运行: uv run refresh_jwt.py")
            except Exception as e3:
                logger.warning(f"⚠️ JWT token自动刷新异常: {e3},建议手动运行: uv run refresh_jwt.py")
        else:
            logger.info("✅ JWT token有效")
    except Exception as e:
        logger.warning(f"⚠️ JWT检查失败: {e}")

    # Step 3: endpoint overview.
    logger.info("-" * 40)
    endpoint_lines = (
        "可用的API端点:",
        " GET / - 服务信息",
        " GET /healthz - 健康检查",
        " GET /v1/models - 模型列表(OpenAI兼容)",
        " POST /api/warp/send_stream - Warp API 转发(返回解析事件)",
        " POST /api/warp/send_stream_sse - Warp API 转发(实时SSE)",
        " POST /api/auth/refresh - 刷新JWT token",
    )
    for line in endpoint_lines:
        logger.info(line)
    logger.info(banner)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan handler.

    Runs ``startup_tasks`` before the application starts serving and yields
    control for the lifetime of the app. Nothing is done on shutdown.

    It is wrapped with ``asynccontextmanager`` because the ``lifespan``
    argument of FastAPI expects an async context manager, not a bare async
    generator function.
    """
    # Startup phase.
    await startup_tasks()
    yield
    # Shutdown phase: nothing to clean up at the moment.
def main():
    """Create the lifespan-enabled application and serve it with uvicorn.

    A keyboard interrupt is logged and swallowed; any other exception is
    logged and re-raised.
    """
    server_options = {
        "host": "0.0.0.0",
        "port": 8000,
        "log_level": "info",
        "access_log": True,
    }
    application = create_app_with_lifespan()
    try:
        uvicorn.run(application, **server_options)
    except KeyboardInterrupt:
        logger.info("服务器被用户停止")
    except Exception as e:
        logger.error(f"服务器启动失败: {e}")
        raise
if __name__ == "__main__": | |
main() | |