Spaces:
Running
Running
#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
""" | |
Warp API客户端模块 | |
处理与Warp API的通信,包括protobuf数据发送和SSE响应解析。 | |
""" | |
import httpx | |
import os | |
import base64 | |
import binascii | |
from typing import Optional, Any, Dict | |
from urllib.parse import urlparse | |
import socket | |
from ..core.logging import logger | |
from ..core.protobuf_utils import protobuf_to_dict | |
from ..core.auth import get_valid_jwt, acquire_anonymous_access_token, refresh_jwt_if_needed | |
from ..config.settings import WARP_URL as CONFIG_WARP_URL | |
def _get(d: Dict[str, Any], *names: str) -> Any: | |
"""Return the first matching key value (camelCase/snake_case tolerant).""" | |
for name in names: | |
if name in d: | |
return d[name] | |
return None | |
def _get_event_type(event_data: dict) -> str:
    """Classify a decoded SSE event into a short label for log output.

    Returns ``"INITIALIZATION"`` for events carrying an ``init`` key. For
    events carrying a client-actions dict it returns
    ``"CLIENT_ACTIONS_EMPTY"`` or ``"CLIENT_ACTIONS(<labels>)"`` with one
    label per action. Otherwise it returns ``"FINISHED"`` when a
    ``finished`` key is present and ``"UNKNOWN_EVENT"`` when it is not.
    """
    if "init" in event_data:
        return "INITIALIZATION"

    client_actions = _get(event_data, "client_actions", "clientActions")
    if not isinstance(client_actions, dict):
        return "FINISHED" if "finished" in event_data else "UNKNOWN_EVENT"

    actions = _get(client_actions, "actions", "Actions") or []
    if not actions:
        return "CLIENT_ACTIONS_EMPTY"

    # (snake_case key, camelCase key, label). Order matters: the first kind
    # found on an action decides its label.
    known_kinds = (
        ("create_task", "createTask", "CREATE_TASK"),
        ("append_to_message_content", "appendToMessageContent", "APPEND_CONTENT"),
        ("add_messages_to_task", "addMessagesToTask", "ADD_MESSAGE"),
        ("tool_call", "toolCall", "TOOL_CALL"),
        ("tool_response", "toolResponse", "TOOL_RESPONSE"),
    )

    labels = []
    for action in actions:
        label = "UNKNOWN_ACTION"
        for snake_key, camel_key, kind_label in known_kinds:
            if _get(action, snake_key, camel_key) is not None:
                label = kind_label
                break
        labels.append(label)

    return f"CLIENT_ACTIONS({', '.join(labels)})"
async def send_protobuf_to_warp_api(
    protobuf_bytes: bytes, show_all_events: bool = True
) -> tuple[str, Optional[str], Optional[str]]:
    """Send a protobuf payload to the Warp API and assemble the streamed reply.

    The payload is POSTed to ``CONFIG_WARP_URL`` and the response is read as
    a server-sent-event stream. The ``data:`` lines of each event are joined,
    decoded from hex or base64, parsed as
    ``warp.multi_agent.v1.ResponseEvent`` and scanned for agent text, which
    is concatenated into the returned string.

    At most two attempts are made. A first-attempt 401 triggers a JWT
    refresh (falling back to an anonymous token when the refresh reports
    failure). A first-attempt 429 whose body reports exhausted quota
    triggers an anonymous-token request. Any other non-200 status, or a
    failure on the second attempt, is returned as an error string.

    Args:
        protobuf_bytes: Serialized request body.
        show_all_events: When true, the decoded content of every event is
            also written to the log at INFO level.

    Returns:
        ``(text, conversation_id, task_id)``. On an HTTP error ``text`` is
        an error message and both ids are ``None``. When the stream carried
        no agent text, ``text`` is a warning placeholder.

    Raises:
        Exception: Anything raised by the transport, the auth helpers or the
            event handling is logged with a traceback and re-raised.
    """
    import re

    # Compiled once per call instead of once per attempt / per event.
    whitespace_re = re.compile(r"\s+")
    hex_re = re.compile(r"[0-9a-fA-F]+")

    def _first_of(d: Any, *names: str) -> Any:
        """Return the value of the first key in *names* found in dict *d*."""
        if isinstance(d, dict):
            for n in names:
                if n in d:
                    return d[n]
        return None

    def _decode_payload(data_str: str) -> Optional[bytes]:
        """Decode SSE payload text given as hex or (URL-safe) base64."""
        s = whitespace_re.sub("", data_str or "")
        if not s:
            return None
        if hex_re.fullmatch(s):
            try:
                return bytes.fromhex(s)
            except ValueError:
                pass  # e.g. odd length: not hex after all, try base64
        padded = s + "=" * (-len(s) % 4)
        for decoder in (base64.urlsafe_b64decode, base64.b64decode):
            try:
                return decoder(padded)
            except (binascii.Error, ValueError):
                continue
        return None

    def _agent_text(message: Any) -> str:
        """Return the agent output text carried by *message*, or ""."""
        agent_output = _first_of(message, "agent_output", "agentOutput") or {}
        return agent_output.get("text", "")

    warp_url = CONFIG_WARP_URL
    try:
        logger.info(f"发送 {len(protobuf_bytes)} 字节到Warp API")
        logger.info(f"数据包前32字节 (hex): {protobuf_bytes[:32].hex()}")
        logger.info(f"发送请求到: {warp_url}")
        conversation_id = None
        task_id = None
        complete_response = []
        event_count = 0
        verify_opt = os.getenv("WARP_INSECURE_TLS", "").lower() not in ("1", "true", "yes")
        if not verify_opt:
            logger.warning("TLS verification disabled via WARP_INSECURE_TLS for Warp API client")
        async with httpx.AsyncClient(
            http2=True, timeout=httpx.Timeout(60.0), verify=verify_opt, trust_env=True
        ) as client:
            jwt = await get_valid_jwt()
            # Up to two attempts: the second one only happens after the
            # first failed with 401/429 and a replacement token was obtained.
            for attempt in range(2):
                headers = {
                    "accept": "text/event-stream",
                    "content-type": "application/x-protobuf",
                    "x-warp-client-version": "v0.2025.08.06.08.12.stable_02",
                    "x-warp-os-category": "Windows",
                    "x-warp-os-name": "Windows",
                    "x-warp-os-version": "11 (26100)",
                    "authorization": f"Bearer {jwt}",
                    "content-length": str(len(protobuf_bytes)),
                }
                async with client.stream("POST", warp_url, headers=headers, content=protobuf_bytes) as response:
                    if response.status_code != 200:
                        error_bytes = await response.aread()
                        # errors="replace": a non-UTF-8 error body must not
                        # turn an HTTP error report into a decode exception.
                        error_content = (
                            error_bytes.decode("utf-8", errors="replace")
                            if error_bytes
                            else "No error content"
                        )
                        if response.status_code == 401 and attempt == 0:
                            # Invalid JWT: refresh it, else fall back to an
                            # anonymous token, then retry once.
                            logger.warning("WARP API 返回 401 (token无效)。尝试刷新JWT token并重试一次…")
                            try:
                                if await refresh_jwt_if_needed():
                                    jwt = await get_valid_jwt()
                                    logger.info("JWT token刷新成功,重试API调用")
                                    continue
                                logger.warning("JWT token刷新失败,尝试申请匿名token")
                                new_jwt = await acquire_anonymous_access_token()
                                if new_jwt:
                                    jwt = new_jwt
                                    continue
                            except Exception as e:
                                logger.warning(f"JWT token刷新异常: {e}")
                        elif response.status_code == 429 and attempt == 0 and (
                            ("No remaining quota" in error_content)
                            or ("No AI requests remaining" in error_content)
                        ):
                            # Quota exhausted: request an anonymous token and
                            # retry once.
                            logger.warning("WARP API 返回 429 (配额用尽)。尝试申请匿名token并重试一次…")
                            try:
                                new_jwt = await acquire_anonymous_access_token()
                            except Exception as e:
                                logger.warning(f"匿名token申请异常: {e}")
                                new_jwt = None
                            if new_jwt:
                                jwt = new_jwt
                                continue
                            logger.error("匿名token申请失败,无法重试。")
                        # Any other error, or no way left to retry.
                        logger.error(f"WARP API HTTP ERROR {response.status_code}: {error_content}")
                        return f"❌ Warp API Error (HTTP {response.status_code}): {error_content}", None, None

                    logger.info(f"✅ 收到HTTP {response.status_code}响应")
                    logger.info("开始处理SSE事件流...")
                    pending = []  # data: payloads of the event being read
                    async for line in response.aiter_lines():
                        if line.startswith("data:"):
                            payload = line[5:].strip()
                            if not payload:
                                continue
                            if payload == "[DONE]":
                                logger.info("收到[DONE]标记,结束处理")
                                break
                            pending.append(payload)
                            continue
                        # Only a blank line terminates an event; every other
                        # line is ignored.
                        if line.strip() or not pending:
                            continue
                        raw_bytes = _decode_payload("".join(pending))
                        pending = []
                        if raw_bytes is None:
                            logger.debug("跳过无法解析的SSE数据块(非hex/base64或不完整)")
                            continue
                        try:
                            event_data = protobuf_to_dict(raw_bytes, "warp.multi_agent.v1.ResponseEvent")
                        except Exception as parse_error:
                            logger.debug(f"解析事件失败,跳过: {str(parse_error)[:100]}")
                            continue
                        event_count += 1
                        event_type = _get_event_type(event_data)
                        logger.info(f"🔄 Event #{event_count}: {event_type}")
                        if show_all_events:
                            logger.info(f" 📋 Event data: {str(event_data)}...")
                        if "init" in event_data:
                            init_data = event_data["init"]
                            # Accept either key spelling, as for other fields.
                            conversation_id = init_data.get(
                                "conversation_id", init_data.get("conversationId", conversation_id)
                            )
                            task_id = init_data.get("task_id", init_data.get("taskId", task_id))
                            logger.info(f"会话初始化: {conversation_id}")
                        client_actions = _first_of(event_data, "client_actions", "clientActions")
                        if not isinstance(client_actions, dict):
                            continue
                        actions = _first_of(client_actions, "actions", "Actions") or []
                        for i, action in enumerate(actions, 1):
                            logger.info(f" 🎯 Action #{i}: {list(action.keys())}")
                            append_data = _first_of(action, "append_to_message_content", "appendToMessageContent")
                            if isinstance(append_data, dict):
                                text_content = _agent_text(append_data.get("message", {}))
                                if text_content:
                                    complete_response.append(text_content)
                                    logger.info(f" 📝 Text Fragment: {text_content[:100]}...")
                            messages_data = _first_of(action, "add_messages_to_task", "addMessagesToTask")
                            if isinstance(messages_data, dict):
                                messages = messages_data.get("messages", [])
                                task_id = messages_data.get("task_id", messages_data.get("taskId", task_id))
                                for j, message in enumerate(messages, 1):
                                    logger.info(f" 📨 Message #{j}: {list(message.keys())}")
                                    text_content = _agent_text(message)
                                    if text_content:
                                        complete_response.append(text_content)
                                        logger.info(f" 📝 Complete Message: {text_content[:100]}...")

                full_response = "".join(complete_response)
                logger.info("="*60)
                logger.info("📊 SSE STREAM SUMMARY")
                logger.info("="*60)
                logger.info(f"📈 Total Events Processed: {event_count}")
                logger.info(f"🆔 Conversation ID: {conversation_id}")
                logger.info(f"🆔 Task ID: {task_id}")
                logger.info(f"📝 Response Length: {len(full_response)} characters")
                logger.info("="*60)
                if full_response:
                    logger.info("✅ Stream processing completed successfully")
                    return full_response, conversation_id, task_id
                logger.warning("⚠️ No text content received in response")
                return "Warning: No response content received", conversation_id, task_id
    except Exception as e:
        import traceback
        logger.error("="*60)
        logger.error("WARP API CLIENT EXCEPTION")
        logger.error("="*60)
        logger.error(f"Exception Type: {type(e).__name__}")
        logger.error(f"Exception Message: {str(e)}")
        logger.error(f"Request URL: {warp_url}")
        logger.error(f"Request Size: {len(protobuf_bytes)}")
        logger.error("Python Traceback:")
        logger.error(traceback.format_exc())
        logger.error("="*60)
        raise
async def send_protobuf_to_warp_api_parsed(protobuf_bytes: bytes) -> tuple[str, Optional[str], Optional[str], list]:
    """Send a protobuf payload to the Warp API and return the parsed events.

    The payload is POSTed to ``CONFIG_WARP_URL`` and the response is read as
    a server-sent-event stream. The ``data:`` lines of each event are joined,
    decoded from hex or base64 and parsed as
    ``warp.multi_agent.v1.ResponseEvent``. Every parsed event is recorded,
    and agent text found in it is concatenated into the returned string. An
    event that fails to parse or to process is logged at DEBUG level and
    skipped.

    At most two attempts are made. A first-attempt 401 triggers a JWT
    refresh (falling back to an anonymous token when the refresh reports
    failure). A first-attempt 429 whose body reports exhausted quota
    triggers an anonymous-token request. Any other non-200 status, or a
    failure on the second attempt, is returned as an error string.

    Args:
        protobuf_bytes: Serialized request body.

    Returns:
        ``(text, conversation_id, task_id, parsed_events)``. Each entry of
        ``parsed_events`` is a dict with the keys ``event_number``,
        ``event_type`` and ``parsed_data``. On an HTTP error ``text`` is an
        error message, both ids are ``None`` and the list is empty.

    Raises:
        Exception: Anything raised by the transport or the auth helpers is
            logged with a traceback and re-raised.
    """
    import re

    # Compiled once per call instead of once per attempt / per event.
    whitespace_re = re.compile(r"\s+")
    hex_re = re.compile(r"[0-9a-fA-F]+")

    def _first_of(d: Any, *names: str) -> Any:
        """Return the value of the first key in *names* found in dict *d*."""
        if isinstance(d, dict):
            for n in names:
                if n in d:
                    return d[n]
        return None

    def _decode_payload(data_str: str) -> Optional[bytes]:
        """Decode SSE payload text given as hex or (URL-safe) base64."""
        s = whitespace_re.sub("", data_str or "")
        if not s:
            return None
        if hex_re.fullmatch(s):
            try:
                return bytes.fromhex(s)
            except ValueError:
                pass  # e.g. odd length: not hex after all, try base64
        padded = s + "=" * (-len(s) % 4)
        for decoder in (base64.urlsafe_b64decode, base64.b64decode):
            try:
                return decoder(padded)
            except (binascii.Error, ValueError):
                continue
        return None

    def _agent_text(message: Any) -> str:
        """Return the agent output text carried by *message*, or ""."""
        agent_output = _first_of(message, "agent_output", "agentOutput") or {}
        return agent_output.get("text", "")

    warp_url = CONFIG_WARP_URL
    try:
        logger.info(f"发送 {len(protobuf_bytes)} 字节到Warp API (解析模式)")
        logger.info(f"数据包前32字节 (hex): {protobuf_bytes[:32].hex()}")
        logger.info(f"发送请求到: {warp_url}")
        conversation_id = None
        task_id = None
        complete_response = []
        parsed_events = []
        event_count = 0
        verify_opt = os.getenv("WARP_INSECURE_TLS", "").lower() not in ("1", "true", "yes")
        if not verify_opt:
            logger.warning("TLS verification disabled via WARP_INSECURE_TLS for Warp API client")
        async with httpx.AsyncClient(
            http2=True, timeout=httpx.Timeout(60.0), verify=verify_opt, trust_env=True
        ) as client:
            jwt = await get_valid_jwt()
            # Up to two attempts: the second one only happens after the
            # first failed with 401/429 and a replacement token was obtained.
            for attempt in range(2):
                headers = {
                    "accept": "text/event-stream",
                    "content-type": "application/x-protobuf",
                    "x-warp-client-version": "v0.2025.08.06.08.12.stable_02",
                    "x-warp-os-category": "Windows",
                    "x-warp-os-name": "Windows",
                    "x-warp-os-version": "11 (26100)",
                    "authorization": f"Bearer {jwt}",
                    "content-length": str(len(protobuf_bytes)),
                }
                async with client.stream("POST", warp_url, headers=headers, content=protobuf_bytes) as response:
                    if response.status_code != 200:
                        error_bytes = await response.aread()
                        # errors="replace": a non-UTF-8 error body must not
                        # turn an HTTP error report into a decode exception.
                        error_content = (
                            error_bytes.decode("utf-8", errors="replace")
                            if error_bytes
                            else "No error content"
                        )
                        if response.status_code == 401 and attempt == 0:
                            # Invalid JWT: refresh it, else fall back to an
                            # anonymous token, then retry once.
                            logger.warning("WARP API 返回 401 (token无效, 解析模式)。尝试刷新JWT token并重试一次…")
                            try:
                                if await refresh_jwt_if_needed():
                                    jwt = await get_valid_jwt()
                                    logger.info("JWT token刷新成功,重试API调用 (解析模式)")
                                    continue
                                logger.warning("JWT token刷新失败,尝试申请匿名token (解析模式)")
                                new_jwt = await acquire_anonymous_access_token()
                                if new_jwt:
                                    jwt = new_jwt
                                    continue
                            except Exception as e:
                                logger.warning(f"JWT token刷新异常 (解析模式): {e}")
                        elif response.status_code == 429 and attempt == 0 and (
                            ("No remaining quota" in error_content)
                            or ("No AI requests remaining" in error_content)
                        ):
                            # Quota exhausted: request an anonymous token and
                            # retry once.
                            logger.warning("WARP API 返回 429 (配额用尽, 解析模式)。尝试申请匿名token并重试一次…")
                            try:
                                new_jwt = await acquire_anonymous_access_token()
                            except Exception as e:
                                logger.warning(f"匿名token申请异常 (解析模式): {e}")
                                new_jwt = None
                            if new_jwt:
                                jwt = new_jwt
                                continue
                            logger.error("匿名token申请失败,无法重试 (解析模式)。")
                        # Any other error, or no way left to retry.
                        logger.error(f"WARP API HTTP ERROR (解析模式) {response.status_code}: {error_content}")
                        return f"❌ Warp API Error (HTTP {response.status_code}): {error_content}", None, None, []

                    logger.info(f"✅ 收到HTTP {response.status_code}响应 (解析模式)")
                    logger.info("开始处理SSE事件流...")
                    pending = []  # data: payloads of the event being read
                    async for line in response.aiter_lines():
                        if line.startswith("data:"):
                            payload = line[5:].strip()
                            if not payload:
                                continue
                            if payload == "[DONE]":
                                logger.info("收到[DONE]标记,结束处理")
                                break
                            pending.append(payload)
                            continue
                        # Only a blank line terminates an event; every other
                        # line is ignored.
                        if line.strip() or not pending:
                            continue
                        raw_bytes = _decode_payload("".join(pending))
                        pending = []
                        if raw_bytes is None:
                            logger.debug("跳过无法解析的SSE数据块(非hex/base64或不完整)")
                            continue
                        # Best effort per event: a failure while parsing or
                        # processing one event skips it, not the whole stream.
                        try:
                            event_data = protobuf_to_dict(raw_bytes, "warp.multi_agent.v1.ResponseEvent")
                            event_count += 1
                            event_type = _get_event_type(event_data)
                            parsed_events.append(
                                {"event_number": event_count, "event_type": event_type, "parsed_data": event_data}
                            )
                            logger.info(f"🔄 Event #{event_count}: {event_type}")
                            logger.debug(f" 📋 Event data: {str(event_data)}...")
                            if "init" in event_data:
                                init_data = event_data["init"]
                                # Accept either key spelling, as for other fields.
                                conversation_id = init_data.get(
                                    "conversation_id", init_data.get("conversationId", conversation_id)
                                )
                                task_id = init_data.get("task_id", init_data.get("taskId", task_id))
                                logger.info(f"会话初始化: {conversation_id}")
                            client_actions = _first_of(event_data, "client_actions", "clientActions")
                            if isinstance(client_actions, dict):
                                actions = _first_of(client_actions, "actions", "Actions") or []
                                for i, action in enumerate(actions, 1):
                                    logger.info(f" 🎯 Action #{i}: {list(action.keys())}")
                                    append_data = _first_of(
                                        action, "append_to_message_content", "appendToMessageContent"
                                    )
                                    if isinstance(append_data, dict):
                                        text_content = _agent_text(append_data.get("message", {}))
                                        if text_content:
                                            complete_response.append(text_content)
                                            logger.info(f" 📝 Text Fragment: {text_content[:100]}...")
                                    messages_data = _first_of(action, "add_messages_to_task", "addMessagesToTask")
                                    if isinstance(messages_data, dict):
                                        messages = messages_data.get("messages", [])
                                        task_id = messages_data.get(
                                            "task_id", messages_data.get("taskId", task_id)
                                        )
                                        for j, message in enumerate(messages, 1):
                                            logger.info(f" 📨 Message #{j}: {list(message.keys())}")
                                            text_content = _agent_text(message)
                                            if text_content:
                                                complete_response.append(text_content)
                                                logger.info(f" 📝 Complete Message: {text_content[:100]}...")
                        except Exception as parse_err:
                            logger.debug(f"解析事件失败,跳过: {str(parse_err)[:100]}")
                            continue

                full_response = "".join(complete_response)
                logger.info("="*60)
                logger.info("📊 SSE STREAM SUMMARY (解析模式)")
                logger.info("="*60)
                logger.info(f"📈 Total Events Processed: {event_count}")
                logger.info(f"🆔 Conversation ID: {conversation_id}")
                logger.info(f"🆔 Task ID: {task_id}")
                logger.info(f"📝 Response Length: {len(full_response)} characters")
                logger.info(f"🎯 Parsed Events Count: {len(parsed_events)}")
                logger.info("="*60)
                logger.info("✅ Stream processing completed successfully (解析模式)")
                return full_response, conversation_id, task_id, parsed_events
    except Exception as e:
        import traceback
        logger.error("="*60)
        logger.error("WARP API CLIENT EXCEPTION (解析模式)")
        logger.error("="*60)
        logger.error(f"Exception Type: {type(e).__name__}")
        logger.error(f"Exception Message: {str(e)}")
        logger.error(f"Request URL: {warp_url}")
        logger.error(f"Request Size: {len(protobuf_bytes)}")
        logger.error("Python Traceback:")
        logger.error(traceback.format_exc())
        logger.error("="*60)
        raise