from __future__ import annotations import json import uuid from typing import Any, AsyncGenerator, Dict from .logging import logger from .helpers import _get # 导入内部模块,替代HTTP调用 from warp2protobuf.core.protobuf_utils import dict_to_protobuf_bytes from warp2protobuf.warp.api_client import send_protobuf_to_warp_api_parsed import httpx import os from warp2protobuf.core.protobuf_utils import protobuf_to_dict from warp2protobuf.core.auth import get_valid_jwt, acquire_anonymous_access_token, refresh_jwt_if_needed from warp2protobuf.config.settings import WARP_URL as CONFIG_WARP_URL from warp2protobuf.core.schema_sanitizer import sanitize_mcp_input_schema_in_packet def _get_event_type(event_data: dict) -> str: """获取事件类型""" if "init" in event_data: return "INIT" elif "client_actions" in event_data or "clientActions" in event_data: return "CLIENT_ACTIONS" elif "finished" in event_data: return "FINISHED" else: return "UNKNOWN_EVENT" async def stream_openai_sse(packet: Dict[str, Any], completion_id: str, created_ts: int, model_id: str) -> AsyncGenerator[str, None]: """使用直接模块调用实现的SSE流处理,替代HTTP调用""" try: # 发出首个OpenAI格式的SSE事件 first = { "id": completion_id, "object": "chat.completion.chunk", "created": created_ts, "model": model_id, "choices": [{"index": 0, "delta": {"role": "assistant"}}], } logger.info("[OpenAI Compat] 转换后的 SSE(emit): %s", json.dumps(first, ensure_ascii=False)) yield f"data: {json.dumps(first, ensure_ascii=False)}\n\n" # 应用schema清理 wrapped = {"json_data": packet} wrapped = sanitize_mcp_input_schema_in_packet(wrapped) actual_data = wrapped.get("json_data", packet) # 转换为protobuf protobuf_bytes = dict_to_protobuf_bytes(actual_data, "warp.multi_agent.v1.Request") logger.info(f"[OpenAI Compat] JSON编码为protobuf成功: {len(protobuf_bytes)} 字节") tool_calls_emitted = False try: # 直接处理SSE流,实时返回事件 warp_url = CONFIG_WARP_URL verify_opt = True insecure_env = os.getenv("WARP_INSECURE_TLS", "").lower() if insecure_env in ("1", "true", "yes"): verify_opt = False logger.warning("TLS verification disabled via WARP_INSECURE_TLS for OpenAI SSE streaming") async with httpx.AsyncClient(http2=True, timeout=httpx.Timeout(60.0), verify=verify_opt, trust_env=True) as client: # 最多尝试两次:第一次失败且为401/429时尝试刷新token并重试一次 for attempt in range(2): jwt = await get_valid_jwt() if attempt == 0 else jwt headers = { "accept": "text/event-stream", "content-type": "application/x-protobuf", "x-warp-client-version": "v0.2025.08.06.08.12.stable_02", "x-warp-os-category": "Windows", "x-warp-os-name": "Windows", "x-warp-os-version": "11 (26100)", "authorization": f"Bearer {jwt}", "content-length": str(len(protobuf_bytes)), } async with client.stream("POST", warp_url, headers=headers, content=protobuf_bytes) as response: if response.status_code != 200: error_text = await response.aread() error_content = error_text.decode('utf-8') if error_text else "No error content" # 检测JWT token无效错误并在第一次失败时尝试刷新token if response.status_code == 401 and attempt == 0: logger.warning("WARP API 返回 401 (token无效, OpenAI SSE)。尝试刷新JWT token并重试一次…") try: refresh_success = await refresh_jwt_if_needed() if refresh_success: jwt = await get_valid_jwt() logger.info("JWT token刷新成功,重试API调用 (OpenAI SSE)") continue else: logger.warning("JWT token刷新失败,尝试申请匿名token (OpenAI SSE)") new_jwt = await acquire_anonymous_access_token() if new_jwt: jwt = new_jwt continue except Exception as e: logger.warning(f"JWT token刷新异常 (OpenAI SSE): {e}") # 检测配额耗尽错误并在第一次失败时尝试申请匿名token elif response.status_code == 429 and attempt == 0 and ( ("No remaining quota" in error_content) or ("No AI requests remaining" in error_content) ): logger.warning("WARP API 返回 429 (配额用尽, OpenAI SSE)。尝试申请匿名token并重试一次…") try: new_jwt = await acquire_anonymous_access_token() if new_jwt: jwt = new_jwt continue except Exception: pass # 其他错误或第二次失败 logger.error(f"WARP API HTTP ERROR (OpenAI SSE) {response.status_code}: {error_content}") raise Exception(f"Warp API Error (HTTP {response.status_code}): {error_content}") logger.info(f"✅ 收到HTTP {response.status_code}响应 (OpenAI SSE)") logger.info("开始实时处理SSE事件流...") import re as _re def _parse_payload_bytes(data_str: str): s = _re.sub(r"\s+", "", data_str or "") if not s: return None if _re.fullmatch(r"[0-9a-fA-F]+", s or ""): try: return bytes.fromhex(s) except Exception: pass pad = "=" * ((4 - (len(s) % 4)) % 4) try: import base64 as _b64 return _b64.urlsafe_b64decode(s + pad) except Exception: try: return _b64.b64decode(s + pad) except Exception: return None current_data = "" event_count = 0 async for line in response.aiter_lines(): if line.startswith("data:"): payload = line[5:].strip() if not payload: continue if payload == "[DONE]": logger.info("收到[DONE]标记,结束处理") break current_data += payload continue if (line.strip() == "") and current_data: raw_bytes = _parse_payload_bytes(current_data) current_data = "" if raw_bytes is None: continue try: event_data = protobuf_to_dict(raw_bytes, "warp.multi_agent.v1.ResponseEvent") event_count += 1 event_type = _get_event_type(event_data) logger.info(f"🔄 实时处理 Event #{event_count}: {event_type}") # 实时处理每个事件 async for chunk in _process_single_event(event_data, completion_id, created_ts, model_id, tool_calls_emitted): if chunk.get("tool_calls_emitted"): tool_calls_emitted = True if chunk.get("sse_data"): yield chunk["sse_data"] except Exception as parse_err: logger.debug(f"解析事件失败,跳过: {str(parse_err)[:100]}") continue logger.info(f"✅ 实时流处理完成,共处理 {event_count} 个事件") break # 成功处理,跳出重试循环 except Exception as e: logger.error(f"[OpenAI Compat] Stream processing failed: {e}") raise e # 发出完成标记 logger.info("[OpenAI Compat] 转换后的 SSE(emit): [DONE]") yield "data: [DONE]\n\n" except Exception as e: logger.error(f"[OpenAI Compat] Stream processing failed: {e}") error_chunk = { "id": completion_id, "object": "chat.completion.chunk", "created": created_ts, "model": model_id, "choices": [{"index": 0, "delta": {}, "finish_reason": "error"}], "error": {"message": str(e)}, } logger.info("[OpenAI Compat] 转换后的 SSE(emit error): %s", json.dumps(error_chunk, ensure_ascii=False)) yield f"data: {json.dumps(error_chunk, ensure_ascii=False)}\n\n" yield "data: [DONE]\n\n" async def _process_single_event(event_data: dict, completion_id: str, created_ts: int, model_id: str, tool_calls_emitted: bool) -> AsyncGenerator[dict, None]: """处理单个事件并生成SSE数据""" if "init" in event_data: return # 处理完成事件 if "finished" in event_data: done_chunk = { "id": completion_id, "object": "chat.completion.chunk", "created": created_ts, "model": model_id, "choices": [{"index": 0, "delta": {}, "finish_reason": ("tool_calls" if tool_calls_emitted else "stop")}], } logger.info("[OpenAI Compat] 转换后的 SSE(emit done): %s", json.dumps(done_chunk, ensure_ascii=False)) yield {"sse_data": f"data: {json.dumps(done_chunk, ensure_ascii=False)}\n\n"} return client_actions = _get(event_data, "client_actions", "clientActions") if isinstance(client_actions, dict): actions = _get(client_actions, "actions", "Actions") or [] for action in actions: # 处理文本追加 append_data = _get(action, "append_to_message_content", "appendToMessageContent") if isinstance(append_data, dict): message = append_data.get("message", {}) agent_output = _get(message, "agent_output", "agentOutput") or {} text_content = agent_output.get("text", "") if text_content: delta = { "id": completion_id, "object": "chat.completion.chunk", "created": created_ts, "model": model_id, "choices": [{"index": 0, "delta": {"content": text_content}}], } logger.info("[OpenAI Compat] 转换后的 SSE(emit): %s", json.dumps(delta, ensure_ascii=False)) yield {"sse_data": f"data: {json.dumps(delta, ensure_ascii=False)}\n\n"} # 处理消息添加 messages_data = _get(action, "add_messages_to_task", "addMessagesToTask") if isinstance(messages_data, dict): messages = messages_data.get("messages", []) for message in messages: # 处理工具调用 tool_call = _get(message, "tool_call", "toolCall") or {} call_mcp = _get(tool_call, "call_mcp_tool", "callMcpTool") or {} if isinstance(call_mcp, dict) and call_mcp.get("name"): try: args_obj = call_mcp.get("args", {}) or {} args_str = json.dumps(args_obj, ensure_ascii=False) except Exception: args_str = "{}" tool_call_id = tool_call.get("tool_call_id") or str(uuid.uuid4()) delta = { "id": completion_id, "object": "chat.completion.chunk", "created": created_ts, "model": model_id, "choices": [{ "index": 0, "delta": { "tool_calls": [{ "index": 0, "id": tool_call_id, "type": "function", "function": {"name": call_mcp.get("name"), "arguments": args_str}, }] } }], } logger.info("[OpenAI Compat] 转换后的 SSE(emit tool_calls): %s", json.dumps(delta, ensure_ascii=False)) yield {"sse_data": f"data: {json.dumps(delta, ensure_ascii=False)}\n\n", "tool_calls_emitted": True} else: # 处理文本消息 agent_output = _get(message, "agent_output", "agentOutput") or {} text_content = agent_output.get("text", "") if text_content: delta = { "id": completion_id, "object": "chat.completion.chunk", "created": created_ts, "model": model_id, "choices": [{"index": 0, "delta": {"content": text_content}}], } logger.info("[OpenAI Compat] 转换后的 SSE(emit): %s", json.dumps(delta, ensure_ascii=False)) yield {"sse_data": f"data: {json.dumps(delta, ensure_ascii=False)}\n\n"}