File size: 28,005 Bytes
9314c03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""

Warp API客户端模块



处理与Warp API的通信,包括protobuf数据发送和SSE响应解析。

"""
import httpx
import os
import base64
import binascii
from typing import Optional, Any, Dict
from urllib.parse import urlparse
import socket

from ..core.logging import logger
from ..core.protobuf_utils import protobuf_to_dict
from ..core.auth import get_valid_jwt, acquire_anonymous_access_token, refresh_jwt_if_needed
from ..config.settings import WARP_URL as CONFIG_WARP_URL


def _get(d: Dict[str, Any], *names: str) -> Any:
    """Return the first matching key value (camelCase/snake_case tolerant)."""
    for name in names:
        if name in d:
            return d[name]
    return None


def _get_event_type(event_data: dict) -> str:
    """Determine the type of SSE event for logging"""
    if "init" in event_data:
        return "INITIALIZATION"
    client_actions = _get(event_data, "client_actions", "clientActions")
    if isinstance(client_actions, dict):
        actions = _get(client_actions, "actions", "Actions") or []
        if not actions:
            return "CLIENT_ACTIONS_EMPTY"
        
        action_types = []
        for action in actions:
            if _get(action, "create_task", "createTask") is not None:
                action_types.append("CREATE_TASK")
            elif _get(action, "append_to_message_content", "appendToMessageContent") is not None:
                action_types.append("APPEND_CONTENT")
            elif _get(action, "add_messages_to_task", "addMessagesToTask") is not None:
                action_types.append("ADD_MESSAGE")
            elif _get(action, "tool_call", "toolCall") is not None:
                action_types.append("TOOL_CALL")
            elif _get(action, "tool_response", "toolResponse") is not None:
                action_types.append("TOOL_RESPONSE")
            else:
                action_types.append("UNKNOWN_ACTION")
        
        return f"CLIENT_ACTIONS({', '.join(action_types)})"
    elif "finished" in event_data:
        return "FINISHED"
    else:
        return "UNKNOWN_EVENT"


async def send_protobuf_to_warp_api(
    protobuf_bytes: bytes, show_all_events: bool = True
) -> tuple[str, Optional[str], Optional[str]]:
    """Send a protobuf payload to the Warp API and return the concatenated text reply.

    The payload is POSTed to ``CONFIG_WARP_URL`` and the response is consumed
    as a Server-Sent Events stream. The ``data:`` lines of each event are
    joined, decoded from hex or base64, and parsed as a
    ``warp.multi_agent.v1.ResponseEvent``. Text found in
    ``append_to_message_content`` and ``add_messages_to_task`` client actions
    is accumulated in arrival order.

    At most two attempts are made. On the first attempt only, a 401 triggers a
    JWT refresh (falling back to an anonymous token) and a 429 whose body
    reports exhausted quota triggers an anonymous-token request; when a new
    token is obtained the request is retried once.

    Args:
        protobuf_bytes: Serialized request body.
        show_all_events: When true, the full decoded content of every event is
            logged at INFO level.

    Returns:
        ``(text, conversation_id, task_id)``. For a non-200 response ``text``
        is an error message and both ids are ``None``. For a 200 response that
        yields no text, ``text`` is a fixed warning string.

    Raises:
        Exception: Anything raised while sending or processing is logged with
            its traceback and re-raised unchanged.
    """
    import re
    import traceback

    def _lookup(d: Any, *names: str) -> Any:
        """Return the value of the first of *names* present in dict *d*, else None."""
        if isinstance(d, dict):
            for n in names:
                if n in d:
                    return d[n]
        return None

    def _parse_payload_bytes(data_str: str) -> Optional[bytes]:
        """Decode an SSE data payload given as hex or (URL-safe) base64 text."""
        s = re.sub(r"\s+", "", data_str or "")
        if not s:
            return None
        if re.fullmatch(r"[0-9a-fA-F]+", s):
            try:
                return bytes.fromhex(s)
            except ValueError:
                # Odd number of hex digits: fall through and try base64.
                pass
        pad = "=" * (-len(s) % 4)
        try:
            return base64.urlsafe_b64decode(s + pad)
        except ValueError:
            try:
                return base64.b64decode(s + pad)
            except ValueError:
                return None

    warp_url = CONFIG_WARP_URL
    try:
        logger.info(f"发送 {len(protobuf_bytes)} 字节到Warp API")
        logger.info(f"数据包前32字节 (hex): {protobuf_bytes[:32].hex()}")
        logger.info(f"发送请求到: {warp_url}")

        conversation_id = None
        task_id = None
        complete_response = []
        event_count = 0

        verify_opt = True
        insecure_env = os.getenv("WARP_INSECURE_TLS", "").lower()
        if insecure_env in ("1", "true", "yes"):
            verify_opt = False
            logger.warning("TLS verification disabled via WARP_INSECURE_TLS for Warp API client")

        async with httpx.AsyncClient(http2=True, timeout=httpx.Timeout(60.0), verify=verify_opt, trust_env=True) as client:
            # At most two attempts: a recoverable 401/429 on the first attempt
            # obtains a new token and retries once.
            for attempt in range(2):
                # The first attempt fetches a token; a retry reuses the token
                # assigned by the recovery path below.
                jwt = await get_valid_jwt() if attempt == 0 else jwt
                headers = {
                    "accept": "text/event-stream",
                    "content-type": "application/x-protobuf",
                    "x-warp-client-version": "v0.2025.08.06.08.12.stable_02",
                    "x-warp-os-category": "Windows",
                    "x-warp-os-name": "Windows",
                    "x-warp-os-version": "11 (26100)",
                    "authorization": f"Bearer {jwt}",
                    "content-length": str(len(protobuf_bytes)),
                }
                async with client.stream("POST", warp_url, headers=headers, content=protobuf_bytes) as response:
                    if response.status_code != 200:
                        error_text = await response.aread()
                        # errors="replace" keeps a non-UTF-8 body from raising and
                        # hiding the HTTP error being reported.
                        error_content = error_text.decode("utf-8", errors="replace") if error_text else "No error content"

                        # Invalid JWT: on the first failure try to refresh the token.
                        if response.status_code == 401 and attempt == 0:
                            logger.warning("WARP API 返回 401 (token无效)。尝试刷新JWT token并重试一次…")
                            try:
                                refresh_success = await refresh_jwt_if_needed()
                                if refresh_success:
                                    jwt = await get_valid_jwt()
                                    logger.info("JWT token刷新成功,重试API调用")
                                    continue
                                else:
                                    logger.warning("JWT token刷新失败,尝试申请匿名token")
                                    new_jwt = await acquire_anonymous_access_token()
                                    if new_jwt:
                                        jwt = new_jwt
                                        continue
                            except Exception as e:
                                logger.warning(f"JWT token刷新异常: {e}")
                            logger.error(f"WARP API HTTP ERROR {response.status_code}: {error_content}")
                            return f"❌ Warp API Error (HTTP {response.status_code}): {error_content}", None, None

                        # Quota exhausted: on the first failure try an anonymous token.
                        if response.status_code == 429 and attempt == 0 and (
                            ("No remaining quota" in error_content) or ("No AI requests remaining" in error_content)
                        ):
                            logger.warning("WARP API 返回 429 (配额用尽)。尝试申请匿名token并重试一次…")
                            try:
                                new_jwt = await acquire_anonymous_access_token()
                            except Exception as e:
                                logger.warning(f"匿名token申请异常: {e}")
                                new_jwt = None
                            if new_jwt:
                                jwt = new_jwt
                                # Leave this response and make the next attempt.
                                continue
                            else:
                                logger.error("匿名token申请失败,无法重试。")
                                logger.error(f"WARP API HTTP ERROR {response.status_code}: {error_content}")
                                return f"❌ Warp API Error (HTTP {response.status_code}): {error_content}", None, None
                        # Any other error, or a failure on the second attempt.
                        logger.error(f"WARP API HTTP ERROR {response.status_code}: {error_content}")
                        return f"❌ Warp API Error (HTTP {response.status_code}): {error_content}", None, None

                    logger.info(f"✅ 收到HTTP {response.status_code}响应")
                    logger.info("开始处理SSE事件流...")

                    # Payload text accumulated from the data: lines of the current event.
                    current_data = ""

                    async for line in response.aiter_lines():
                        if line.startswith("data:"):
                            payload = line[5:].strip()
                            if not payload:
                                continue
                            if payload == "[DONE]":
                                logger.info("收到[DONE]标记,结束处理")
                                break
                            current_data += payload
                            continue

                        # Only a blank line with pending data completes an event.
                        if line.strip() or not current_data:
                            continue

                        raw_bytes = _parse_payload_bytes(current_data)
                        current_data = ""
                        if raw_bytes is None:
                            logger.debug("跳过无法解析的SSE数据块(非hex/base64或不完整)")
                            continue
                        try:
                            event_data = protobuf_to_dict(raw_bytes, "warp.multi_agent.v1.ResponseEvent")
                        except Exception as parse_error:
                            logger.debug(f"解析事件失败,跳过: {str(parse_error)[:100]}")
                            continue
                        event_count += 1

                        event_type = _get_event_type(event_data)
                        logger.info(f"🔄 Event #{event_count}: {event_type}")
                        if show_all_events:
                            logger.info(f"   📋 Event data: {str(event_data)}...")

                        if "init" in event_data:
                            init_data = event_data["init"]
                            # Accept camelCase ids too, as is done for task_id/taskId below.
                            conversation_id = init_data.get(
                                "conversation_id", init_data.get("conversationId", conversation_id)
                            )
                            task_id = init_data.get("task_id", init_data.get("taskId", task_id))
                            logger.info(f"会话初始化: {conversation_id}")

                        # Client actions are handled for every event, not only for
                        # events that also carry "init"; otherwise the reply text
                        # in ordinary client-action events is never collected.
                        client_actions = _lookup(event_data, "client_actions", "clientActions")
                        if not isinstance(client_actions, dict):
                            continue
                        actions = _lookup(client_actions, "actions", "Actions") or []
                        for i, action in enumerate(actions, start=1):
                            logger.info(f"   🎯 Action #{i}: {list(action.keys())}")

                            # Text fragment appended to an existing message.
                            append_data = _lookup(action, "append_to_message_content", "appendToMessageContent")
                            if isinstance(append_data, dict):
                                message = append_data.get("message", {})
                                agent_output = _lookup(message, "agent_output", "agentOutput") or {}
                                text_content = agent_output.get("text", "")
                                if text_content:
                                    complete_response.append(text_content)
                                    logger.info(f"   📝 Text Fragment: {text_content[:100]}...")

                            # Whole messages added to a task.
                            messages_data = _lookup(action, "add_messages_to_task", "addMessagesToTask")
                            if isinstance(messages_data, dict):
                                messages = messages_data.get("messages", [])
                                task_id = messages_data.get("task_id", messages_data.get("taskId", task_id))
                                for j, message in enumerate(messages, start=1):
                                    logger.info(f"   📨 Message #{j}: {list(message.keys())}")
                                    agent_output = _lookup(message, "agent_output", "agentOutput")
                                    if agent_output is None:
                                        continue
                                    text_content = (agent_output or {}).get("text", "")
                                    if text_content:
                                        complete_response.append(text_content)
                                        logger.info(f"   📝 Complete Message: {text_content[:100]}...")

                    full_response = "".join(complete_response)
                    logger.info("="*60)
                    logger.info("📊 SSE STREAM SUMMARY")
                    logger.info("="*60)
                    logger.info(f"📈 Total Events Processed: {event_count}")
                    logger.info(f"🆔 Conversation ID: {conversation_id}")
                    logger.info(f"🆔 Task ID: {task_id}")
                    logger.info(f"📝 Response Length: {len(full_response)} characters")
                    logger.info("="*60)
                    if full_response:
                        logger.info(f"✅ Stream processing completed successfully")
                        return full_response, conversation_id, task_id
                    else:
                        logger.warning("⚠️ No text content received in response")
                        return "Warning: No response content received", conversation_id, task_id
    except Exception as e:
        logger.error("="*60)
        logger.error("WARP API CLIENT EXCEPTION")
        logger.error("="*60)
        logger.error(f"Exception Type: {type(e).__name__}")
        logger.error(f"Exception Message: {str(e)}")
        logger.error(f"Request URL: {warp_url}")
        logger.error(f"Request Size: {len(protobuf_bytes)}")
        logger.error("Python Traceback:")
        logger.error(traceback.format_exc())
        logger.error("="*60)
        raise


async def send_protobuf_to_warp_api_parsed(protobuf_bytes: bytes) -> tuple[str, Optional[str], Optional[str], list]:
    """Send a protobuf payload to the Warp API and return the reply text plus every parsed event.

    The payload is POSTed to ``CONFIG_WARP_URL`` and the response is consumed
    as a Server-Sent Events stream. The ``data:`` lines of each event are
    joined, decoded from hex or base64, and parsed as a
    ``warp.multi_agent.v1.ResponseEvent``. Each parsed event is recorded, and
    text found in ``append_to_message_content`` and ``add_messages_to_task``
    client actions is accumulated in arrival order. An event whose parsing or
    processing raises is logged at DEBUG level and skipped.

    At most two attempts are made. On the first attempt only, a 401 triggers a
    JWT refresh (falling back to an anonymous token) and a 429 whose body
    reports exhausted quota triggers an anonymous-token request; when a new
    token is obtained the request is retried once.

    Args:
        protobuf_bytes: Serialized request body.

    Returns:
        ``(text, conversation_id, task_id, parsed_events)``. Each entry of
        ``parsed_events`` is a dict with ``event_number``, ``event_type`` and
        ``parsed_data``. For a non-200 response ``text`` is an error message,
        both ids are ``None`` and the event list is empty.

    Raises:
        Exception: Anything raised outside per-event processing is logged with
            its traceback and re-raised unchanged.
    """
    import re
    import traceback

    def _lookup(d: Any, *names: str) -> Any:
        """Return the value of the first of *names* present in dict *d*, else None."""
        if isinstance(d, dict):
            for n in names:
                if n in d:
                    return d[n]
        return None

    def _parse_payload_bytes(data_str: str) -> Optional[bytes]:
        """Decode an SSE data payload given as hex or (URL-safe) base64 text."""
        s = re.sub(r"\s+", "", data_str or "")
        if not s:
            return None
        if re.fullmatch(r"[0-9a-fA-F]+", s):
            try:
                return bytes.fromhex(s)
            except ValueError:
                # Odd number of hex digits: fall through and try base64.
                pass
        pad = "=" * (-len(s) % 4)
        try:
            return base64.urlsafe_b64decode(s + pad)
        except ValueError:
            try:
                return base64.b64decode(s + pad)
            except ValueError:
                return None

    warp_url = CONFIG_WARP_URL
    try:
        logger.info(f"发送 {len(protobuf_bytes)} 字节到Warp API (解析模式)")
        logger.info(f"数据包前32字节 (hex): {protobuf_bytes[:32].hex()}")
        logger.info(f"发送请求到: {warp_url}")

        conversation_id = None
        task_id = None
        complete_response = []
        parsed_events = []
        event_count = 0

        verify_opt = True
        insecure_env = os.getenv("WARP_INSECURE_TLS", "").lower()
        if insecure_env in ("1", "true", "yes"):
            verify_opt = False
            logger.warning("TLS verification disabled via WARP_INSECURE_TLS for Warp API client")

        async with httpx.AsyncClient(http2=True, timeout=httpx.Timeout(60.0), verify=verify_opt, trust_env=True) as client:
            # At most two attempts: a recoverable 401/429 on the first attempt
            # obtains a new token and retries once.
            for attempt in range(2):
                # The first attempt fetches a token; a retry reuses the token
                # assigned by the recovery path below.
                jwt = await get_valid_jwt() if attempt == 0 else jwt
                headers = {
                    "accept": "text/event-stream",
                    "content-type": "application/x-protobuf",
                    "x-warp-client-version": "v0.2025.08.06.08.12.stable_02",
                    "x-warp-os-category": "Windows",
                    "x-warp-os-name": "Windows",
                    "x-warp-os-version": "11 (26100)",
                    "authorization": f"Bearer {jwt}",
                    "content-length": str(len(protobuf_bytes)),
                }
                async with client.stream("POST", warp_url, headers=headers, content=protobuf_bytes) as response:
                    if response.status_code != 200:
                        error_text = await response.aread()
                        # errors="replace" keeps a non-UTF-8 body from raising and
                        # hiding the HTTP error being reported.
                        error_content = error_text.decode("utf-8", errors="replace") if error_text else "No error content"

                        # Invalid JWT: on the first failure try to refresh the token.
                        if response.status_code == 401 and attempt == 0:
                            logger.warning("WARP API 返回 401 (token无效, 解析模式)。尝试刷新JWT token并重试一次…")
                            try:
                                refresh_success = await refresh_jwt_if_needed()
                                if refresh_success:
                                    jwt = await get_valid_jwt()
                                    logger.info("JWT token刷新成功,重试API调用 (解析模式)")
                                    continue
                                else:
                                    logger.warning("JWT token刷新失败,尝试申请匿名token (解析模式)")
                                    new_jwt = await acquire_anonymous_access_token()
                                    if new_jwt:
                                        jwt = new_jwt
                                        continue
                            except Exception as e:
                                logger.warning(f"JWT token刷新异常 (解析模式): {e}")
                            logger.error(f"WARP API HTTP ERROR (解析模式) {response.status_code}: {error_content}")
                            return f"❌ Warp API Error (HTTP {response.status_code}): {error_content}", None, None, []

                        # Quota exhausted: on the first failure try an anonymous token.
                        if response.status_code == 429 and attempt == 0 and (
                            ("No remaining quota" in error_content) or ("No AI requests remaining" in error_content)
                        ):
                            logger.warning("WARP API 返回 429 (配额用尽, 解析模式)。尝试申请匿名token并重试一次…")
                            try:
                                new_jwt = await acquire_anonymous_access_token()
                            except Exception as e:
                                logger.warning(f"匿名token申请异常 (解析模式): {e}")
                                new_jwt = None
                            if new_jwt:
                                jwt = new_jwt
                                # Leave this response and make the next attempt.
                                continue
                            else:
                                logger.error("匿名token申请失败,无法重试 (解析模式)。")
                                logger.error(f"WARP API HTTP ERROR (解析模式) {response.status_code}: {error_content}")
                                return f"❌ Warp API Error (HTTP {response.status_code}): {error_content}", None, None, []
                        # Any other error, or a failure on the second attempt.
                        logger.error(f"WARP API HTTP ERROR (解析模式) {response.status_code}: {error_content}")
                        return f"❌ Warp API Error (HTTP {response.status_code}): {error_content}", None, None, []

                    logger.info(f"✅ 收到HTTP {response.status_code}响应 (解析模式)")
                    logger.info("开始处理SSE事件流...")

                    # Payload text accumulated from the data: lines of the current event.
                    current_data = ""

                    async for line in response.aiter_lines():
                        if line.startswith("data:"):
                            payload = line[5:].strip()
                            if not payload:
                                continue
                            if payload == "[DONE]":
                                logger.info("收到[DONE]标记,结束处理")
                                break
                            current_data += payload
                            continue

                        # Only a blank line with pending data completes an event.
                        if line.strip() or not current_data:
                            continue

                        raw_bytes = _parse_payload_bytes(current_data)
                        current_data = ""
                        if raw_bytes is None:
                            logger.debug("跳过无法解析的SSE数据块(非hex/base64或不完整)")
                            continue
                        # Best-effort: a failure anywhere in parsing or processing
                        # one event skips that event and keeps the stream going.
                        try:
                            event_data = protobuf_to_dict(raw_bytes, "warp.multi_agent.v1.ResponseEvent")
                            event_count += 1
                            event_type = _get_event_type(event_data)
                            parsed_events.append(
                                {"event_number": event_count, "event_type": event_type, "parsed_data": event_data}
                            )
                            logger.info(f"🔄 Event #{event_count}: {event_type}")
                            logger.debug(f"   📋 Event data: {str(event_data)}...")

                            if "init" in event_data:
                                init_data = event_data["init"]
                                # Accept camelCase ids too, as is done for task_id/taskId below.
                                conversation_id = init_data.get(
                                    "conversation_id", init_data.get("conversationId", conversation_id)
                                )
                                task_id = init_data.get("task_id", init_data.get("taskId", task_id))
                                logger.info(f"会话初始化: {conversation_id}")

                            client_actions = _lookup(event_data, "client_actions", "clientActions")
                            if isinstance(client_actions, dict):
                                actions = _lookup(client_actions, "actions", "Actions") or []
                                for i, action in enumerate(actions, start=1):
                                    logger.info(f"   🎯 Action #{i}: {list(action.keys())}")

                                    # Text fragment appended to an existing message.
                                    append_data = _lookup(action, "append_to_message_content", "appendToMessageContent")
                                    if isinstance(append_data, dict):
                                        message = append_data.get("message", {})
                                        agent_output = _lookup(message, "agent_output", "agentOutput") or {}
                                        text_content = agent_output.get("text", "")
                                        if text_content:
                                            complete_response.append(text_content)
                                            logger.info(f"   📝 Text Fragment: {text_content[:100]}...")

                                    # Whole messages added to a task.
                                    messages_data = _lookup(action, "add_messages_to_task", "addMessagesToTask")
                                    if isinstance(messages_data, dict):
                                        messages = messages_data.get("messages", [])
                                        task_id = messages_data.get("task_id", messages_data.get("taskId", task_id))
                                        for j, message in enumerate(messages, start=1):
                                            logger.info(f"   📨 Message #{j}: {list(message.keys())}")
                                            agent_output = _lookup(message, "agent_output", "agentOutput")
                                            if agent_output is None:
                                                continue
                                            text_content = (agent_output or {}).get("text", "")
                                            if text_content:
                                                complete_response.append(text_content)
                                                logger.info(f"   📝 Complete Message: {text_content[:100]}...")
                        except Exception as parse_err:
                            logger.debug(f"解析事件失败,跳过: {str(parse_err)[:100]}")
                            continue

                    full_response = "".join(complete_response)
                    logger.info("="*60)
                    logger.info("📊 SSE STREAM SUMMARY (解析模式)")
                    logger.info("="*60)
                    logger.info(f"📈 Total Events Processed: {event_count}")
                    logger.info(f"🆔 Conversation ID: {conversation_id}")
                    logger.info(f"🆔 Task ID: {task_id}")
                    logger.info(f"📝 Response Length: {len(full_response)} characters")
                    logger.info(f"🎯 Parsed Events Count: {len(parsed_events)}")
                    logger.info("="*60)

                    logger.info(f"✅ Stream processing completed successfully (解析模式)")
                    return full_response, conversation_id, task_id, parsed_events
    except Exception as e:
        logger.error("="*60)
        logger.error("WARP API CLIENT EXCEPTION (解析模式)")
        logger.error("="*60)
        logger.error(f"Exception Type: {type(e).__name__}")
        logger.error(f"Exception Message: {str(e)}")
        logger.error(f"Request URL: {warp_url}")
        logger.error(f"Request Size: {len(protobuf_bytes)}")
        logger.error("Python Traceback:")
        logger.error(traceback.format_exc())
        logger.error("="*60)
        raise