from __future__ import annotations

import asyncio
import json
import time
import uuid
from typing import Any, Dict, List, Optional

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse

from .logging import logger
from .models import ChatCompletionsRequest, ChatMessage
from .reorder import reorder_messages_for_anthropic
from .helpers import normalize_content_to_list, segments_to_text
from .packets import packet_template, map_history_to_warp_messages, attach_user_and_tools_to_inputs
from .state import STATE
from .bridge import initialize_once, bridge_send_stream
from .sse_transform import stream_openai_sse

# Import warp2protobuf modules directly instead of going through HTTP calls
from warp2protobuf.config.models import get_all_unique_models
from warp2protobuf.core.auth import refresh_jwt_if_needed

router = APIRouter()


@router.get("/")  # route path assumed; adjust to the actual deployment
def root():
    return {"service": "OpenAI Chat Completions (Warp bridge) - Streaming", "status": "ok"}


@router.get("/healthz")  # route path assumed; adjust to the actual deployment
def health_check():
    return {"status": "ok", "service": "OpenAI Chat Completions (Warp bridge) - Streaming"}


@router.get("/v1/models")  # route path assumed from the OpenAI-compatible convention
def list_models():
    """OpenAI-compatible model listing. Direct call to get_all_unique_models."""
    try:
        models = get_all_unique_models()
        return {"object": "list", "data": models}
    except Exception as e:
        logger.error(f"❌ Failed to fetch model list: {e}")
        raise HTTPException(500, f"Failed to fetch model list: {str(e)}")


@router.post("/v1/chat/completions")  # route path assumed from the OpenAI-compatible convention
async def chat_completions(req: ChatCompletionsRequest):
    try:
        await initialize_once()
    except Exception as e:
        logger.warning(f"[OpenAI Compat] initialize_once failed or skipped: {e}")

    if not req.messages:
        raise HTTPException(400, "messages must not be empty")

    # 1) Log the raw Chat Completions request body as received
    try:
        logger.info("[OpenAI Compat] Received Chat Completions request body (raw): %s", json.dumps(req.dict(), ensure_ascii=False))
    except Exception:
        logger.info("[OpenAI Compat] Failed to serialize the raw Chat Completions request body")

    # Reorder messages into the shape expected by the Anthropic-style bridge
    history: List[ChatMessage] = reorder_messages_for_anthropic(list(req.messages))

    # 2) Log the reordered request body (post-reorder)
    try:
        logger.info("[OpenAI Compat] Reordered request body (post-reorder): %s", json.dumps({
            **req.dict(),
            "messages": [m.dict() for m in history]
        }, ensure_ascii=False))
    except Exception:
        logger.info("[OpenAI Compat] Failed to serialize the reordered request body (post-reorder)")
    system_prompt_text: Optional[str] = None
    try:
        chunks: List[str] = []
        for _m in history:
            if _m.role == "system":
                _txt = segments_to_text(normalize_content_to_list(_m.content))
                if _txt.strip():
                    chunks.append(_txt)
        if chunks:
            system_prompt_text = "\n\n".join(chunks)
    except Exception:
        system_prompt_text = None
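
    # Build the Warp bridge packet: reuse the baseline task id when available, otherwise start a new task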
    task_id = STATE.baseline_task_id or str(uuid.uuid4())
    packet = packet_template()
    packet["task_context"] = {
        "tasks": [{
            "id": task_id,
            "description": "",
            "status": {"in_progress": {}},
            "messages": map_history_to_warp_messages(history, task_id, None, False),
        }],
        "active_task_id": task_id,
    }
    packet.setdefault("settings", {}).setdefault("model_config", {})
    packet["settings"]["model_config"]["base"] = req.model or packet["settings"]["model_config"].get("base") or "claude-4.1-opus"
    if STATE.conversation_id:
        packet.setdefault("metadata", {})["conversation_id"] = STATE.conversation_id
    attach_user_and_tools_to_inputs(packet, history, system_prompt_text)
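
    # Map OpenAI "function" tools onto MCP tool declarations understood by the Warp bridge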
    if req.tools:
        mcp_tools: List[Dict[str, Any]] = []
        for t in req.tools:
            if t.type != "function" or not t.function:
                continue
            mcp_tools.append({
                "name": t.function.name,
                "description": t.function.description or "",
                "input_schema": t.function.parameters or {},
            })
        if mcp_tools:
            packet.setdefault("mcp_context", {}).setdefault("tools", []).extend(mcp_tools)

    # 3) Log the request body converted to protobuf JSON (the packet sent to the bridge)
    try:
        logger.info("[OpenAI Compat] Request body converted to protobuf JSON: %s", json.dumps(packet, ensure_ascii=False))
    except Exception:
        logger.info("[OpenAI Compat] Failed to serialize the protobuf JSON request body")

    created_ts = int(time.time())
    completion_id = str(uuid.uuid4())
    model_id = req.model or "warp-default"
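
    # Streaming mode: transform bridge output into OpenAI-style SSE chunks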
    if req.stream:
        async def _agen():
            async for chunk in stream_openai_sse(packet, completion_id, created_ts, model_id):
                yield chunk
        return StreamingResponse(_agen(), media_type="text/event-stream", headers={"Cache-Control": "no-cache", "Connection": "keep-alive"})
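
    # Non-streaming mode: single round-trip to the bridge, with one retry after a JWT refresh on 429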
    try:
        bridge_resp = await bridge_send_stream(packet)
    except Exception as e:
        # On a 429 error (quota exhausted), try refreshing the JWT and retry once
        if "429" in str(e):
            try:
                await refresh_jwt_if_needed()
                logger.warning("[OpenAI Compat] Tried JWT refresh after 429 error")
                bridge_resp = await bridge_send_stream(packet)
            except Exception as _e:
                logger.warning("[OpenAI Compat] JWT refresh attempt failed after 429: %s", _e)
                raise HTTPException(429, f"bridge_error: {e}")
        else:
            raise HTTPException(502, f"bridge_error: {e}")
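
    # Persist conversation and task identifiers returned by the bridge for reuse on later requests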
    try:
        STATE.conversation_id = bridge_resp.get("conversation_id") or STATE.conversation_id
        ret_task_id = bridge_resp.get("task_id")
        if isinstance(ret_task_id, str) and ret_task_id:
            STATE.baseline_task_id = ret_task_id
    except Exception:
        pass
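
    # Extract MCP tool calls from the bridge's parsed events and convert them to OpenAI tool_calls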
    tool_calls: List[Dict[str, Any]] = []
    try:
        parsed_events = bridge_resp.get("parsed_events", []) or []
        for ev in parsed_events:
            evd = ev.get("parsed_data") or ev.get("raw_data") or {}
            client_actions = evd.get("client_actions") or evd.get("clientActions") or {}
            actions = client_actions.get("actions") or client_actions.get("Actions") or []
            for action in actions:
                add_msgs = action.get("add_messages_to_task") or action.get("addMessagesToTask") or {}
                if not isinstance(add_msgs, dict):
                    continue
                for message in add_msgs.get("messages", []) or []:
                    tc = message.get("tool_call") or message.get("toolCall") or {}
                    call_mcp = tc.get("call_mcp_tool") or tc.get("callMcpTool") or {}
                    if isinstance(call_mcp, dict) and call_mcp.get("name"):
                        try:
                            args_obj = call_mcp.get("args", {}) or {}
                            args_str = json.dumps(args_obj, ensure_ascii=False)
                        except Exception:
                            args_str = "{}"
                        tool_calls.append({
                            "id": tc.get("tool_call_id") or str(uuid.uuid4()),
                            "type": "function",
                            "function": {"name": call_mcp.get("name"), "arguments": args_str},
                        })
    except Exception:
        pass
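
    # Assemble the OpenAI-style chat.completion response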
    if tool_calls:
        msg_payload = {"role": "assistant", "content": "", "tool_calls": tool_calls}
        finish_reason = "tool_calls"
    else:
        response_text = bridge_resp.get("response", "")
        msg_payload = {"role": "assistant", "content": response_text}
        finish_reason = "stop"

    final = {
        "id": completion_id,
        "object": "chat.completion",
        "created": created_ts,
        "model": model_id,
        "choices": [{"index": 0, "message": msg_payload, "finish_reason": finish_reason}],
    }
    return final