#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""

流式数据包处理器



处理流式protobuf数据包,支持实时解析和WebSocket推送。

"""
import asyncio
from typing import Any, Dict, List, Optional
from datetime import datetime

from .logging import logger
from .protobuf_utils import protobuf_to_dict


class StreamProcessor:
    """流式数据包处理器"""
    
    def __init__(self, websocket_manager=None):
        self.websocket_manager = websocket_manager
        self.active_streams: Dict[str, StreamSession] = {}
        
    async def create_stream_session(self, stream_id: str, message_type: str = "warp.multi_agent.v1.Response") -> 'StreamSession':
        """创建流式会话"""
        session = StreamSession(stream_id, message_type, self.websocket_manager)
        self.active_streams[stream_id] = session
        
        logger.info(f"创建流式会话: {stream_id}, 消息类型: {message_type}")
        return session
    
    async def get_stream_session(self, stream_id: str) -> Optional['StreamSession']:
        """获取流式会话"""
        return self.active_streams.get(stream_id)
    
    async def close_stream_session(self, stream_id: str):
        """关闭流式会话"""
        if stream_id in self.active_streams:
            session = self.active_streams[stream_id]
            await session.close()
            del self.active_streams[stream_id]
            logger.info(f"关闭流式会话: {stream_id}")
    
    async def process_stream_chunk(self, stream_id: str, chunk_data: bytes) -> Dict[str, Any]:
        """处理流式数据块"""
        session = await self.get_stream_session(stream_id)
        if not session:
            raise ValueError(f"流式会话不存在: {stream_id}")
        
        return await session.process_chunk(chunk_data)
    
    async def finalize_stream(self, stream_id: str) -> Dict[str, Any]:
        """完成流式处理"""
        session = await self.get_stream_session(stream_id)
        if not session:
            raise ValueError(f"Stream session not found: {stream_id}")
        
        result = await session.finalize()
        await self.close_stream_session(stream_id)
        return result


class StreamSession:
    """流式会话"""
    
    def __init__(self, session_id: str, message_type: str, websocket_manager=None):
        self.session_id = session_id
        self.message_type = message_type
        self.websocket_manager = websocket_manager
        
        self.chunks: List[bytes] = []
        self.chunk_count = 0
        self.total_size = 0
        self.start_time = datetime.now()
        
        self.parsed_chunks: List[Optional[Dict]] = []
        self.complete_message: Optional[Dict] = None
        
    async def process_chunk(self, chunk_data: bytes) -> Dict[str, Any]:
        """处理单个数据块"""
        self.chunk_count += 1
        self.total_size += len(chunk_data)
        self.chunks.append(chunk_data)
        
        logger.debug(f"流式会话 {self.session_id}: 处理数据块 {self.chunk_count}, 大小 {len(chunk_data)} 字节")
        
        chunk_result = {
            "chunk_index": self.chunk_count - 1,
            "size": len(chunk_data),
            "timestamp": datetime.now().isoformat()
        }
        
        try:
            chunk_json = protobuf_to_dict(chunk_data, self.message_type)
            chunk_result["json_data"] = chunk_json
            chunk_result["parsed_successfully"] = True
            
            self.parsed_chunks.append(chunk_json)
            
            if self.websocket_manager:
                await self.websocket_manager.broadcast({
                    "event": "stream_chunk_parsed",
                    "stream_id": self.session_id,
                    "chunk": chunk_result
                })
                
        except Exception as e:
            chunk_result["error"] = str(e)
            chunk_result["parsed_successfully"] = False
            # Record a placeholder so parsed_chunks stays index-aligned with chunks.
            self.parsed_chunks.append(None)
            logger.warning(f"Failed to parse chunk: {e}")
            
            if self.websocket_manager:
                await self.websocket_manager.broadcast({
                    "event": "stream_chunk_error", 
                    "stream_id": self.session_id,
                    "chunk": chunk_result
                })
        
        return chunk_result
    
    async def finalize(self) -> Dict[str, Any]:
        """完成流式处理,尝试拼接完整消息"""
        duration = (datetime.now() - self.start_time).total_seconds()
        
        logger.info(f"流式会话 {self.session_id} 完成: {self.chunk_count} 块, 总大小 {self.total_size} 字节, 耗时 {duration:.2f}s")
        
        result = {
            "session_id": self.session_id,
            "chunk_count": self.chunk_count,
            "total_size": self.total_size,
            "duration_seconds": duration,
            "chunks": []
        }
        
        for i, chunk in enumerate(self.chunks):
            chunk_info = {
                "index": i,
                "size": len(chunk),
                "hex_preview": chunk[:32].hex() if len(chunk) >= 32 else chunk.hex()
            }
            
            if i < len(self.parsed_chunks) and self.parsed_chunks[i] is not None:
                chunk_info["parsed_data"] = self.parsed_chunks[i]
            
            result["chunks"].append(chunk_info)
        
        try:
            complete_data = b''.join(self.chunks)
            complete_json = protobuf_to_dict(complete_data, self.message_type)
            
            result["complete_message"] = {
                "size": len(complete_data),
                "json_data": complete_json,
                "assembly_successful": True
            }
            
            self.complete_message = complete_json
            
            logger.info(f"流式消息拼接成功: {len(complete_data)} 字节")
            
        except Exception as e:
            result["complete_message"] = {
                "error": str(e),
                "assembly_successful": False
            }
            logger.warning(f"流式消息拼接失败: {e}")
        
        if self.websocket_manager:
            await self.websocket_manager.broadcast({
                "event": "stream_completed",
                "stream_id": self.session_id,
                "result": result
            })
        
        return result
    
    async def close(self):
        """关闭会话"""
        self.chunks.clear()
        self.parsed_chunks.clear()
        self.complete_message = None
        
        logger.debug(f"流式会话 {self.session_id} 已关闭")


class StreamPacketAnalyzer:
    """流式数据包分析器"""
    
    @staticmethod
    def analyze_chunk_patterns(chunks: List[bytes]) -> Dict[str, Any]:
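        """Analyze chunk size distribution and shared byte patterns."""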
        if not chunks:
            return {"error": "无数据块"}
        
        analysis = {
            "total_chunks": len(chunks),
            "size_distribution": {},
            "size_stats": {},
            "pattern_analysis": {}
        }
        
        sizes = [len(chunk) for chunk in chunks]
        analysis["size_stats"] = {
            "min": min(sizes),
            "max": max(sizes),
            "avg": sum(sizes) / len(sizes),
            "total": sum(sizes)
        }
        
        size_ranges = [(0, 100), (100, 500), (500, 1000), (1000, 5000), (5000, float('inf'))]
        for start, end in size_ranges:
            range_name = f"{start}-{end if end != float('inf') else '∞'}"
            count = sum(1 for size in sizes if start <= size < end)
            analysis["size_distribution"][range_name] = count
        
        if len(chunks) >= 2:
            first_bytes = [chunk[:4].hex() for chunk in chunks[:5]]
            analysis["pattern_analysis"]["first_bytes_samples"] = first_bytes

            # Longest prefix (up to 10 bytes) shared by every chunk.
            common_prefix_len = 0
            first_chunk = chunks[0]
            for i in range(min(len(first_chunk), 10)):
                if all(len(chunk) > i and chunk[i] == first_chunk[i] for chunk in chunks[1:]):
                    common_prefix_len = i + 1
                else:
                    break

            if common_prefix_len > 0:
                analysis["pattern_analysis"]["common_prefix_length"] = common_prefix_len
                analysis["pattern_analysis"]["common_prefix_hex"] = first_chunk[:common_prefix_len].hex()
        
        return analysis
    
    @staticmethod
    def extract_streaming_deltas(parsed_chunks: List[Dict]) -> List[Dict]:
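        """Compute per-chunk text deltas from a sequence of parsed chunk dicts."""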
        if not parsed_chunks:
            return []
        
        deltas = []
        previous_content = ""
        
        for i, chunk in enumerate(parsed_chunks):
            delta = {
                "chunk_index": i,
                "timestamp": datetime.now().isoformat()
            }
            
            current_content = StreamPacketAnalyzer._extract_text_content(chunk)
            
            if current_content and current_content != previous_content:
                if previous_content and current_content.startswith(previous_content):
                    delta["content_delta"] = current_content[len(previous_content):]
                    delta["delta_type"] = "append"
                else:
                    delta["content_delta"] = current_content
                    delta["delta_type"] = "replace"
                
                delta["total_content_length"] = len(current_content)
                previous_content = current_content
            else:
                delta["content_delta"] = ""
                delta["delta_type"] = "no_change"
            
            if i > 0 and isinstance(parsed_chunks[i - 1], dict) and isinstance(chunk, dict):
                delta["field_changes"] = StreamPacketAnalyzer._compare_dicts(parsed_chunks[i - 1], chunk)
            
            deltas.append(delta)
        
        return deltas
    
    @staticmethod
    def _extract_text_content(data: Dict) -> str:
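        """Best-effort extraction of text content from a set of known field paths."""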
        text_paths = [
            ["content"],
            ["text"],
            ["message"],
            ["agent_output", "text"],
            ["choices", 0, "delta", "content"],
            ["choices", 0, "message", "content"]
        ]
        
        for path in text_paths:
            try:
                current = data
                for key in path:
                    if isinstance(current, dict) and key in current:
                        current = current[key]
                    elif isinstance(current, list) and isinstance(key, int) and 0 <= key < len(current):
                        current = current[key]
                    else:
                        break
                else:
                    if isinstance(current, str):
                        return current
            except Exception:
                continue
        
        return ""
    
    @staticmethod
    def _compare_dicts(dict1: Dict, dict2: Dict, prefix: str = "") -> List[str]:
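        """Recursively diff two dicts, reporting at most 10 changes per level."""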
        changes = []
        
        all_keys = set(dict1.keys()) | set(dict2.keys())
        
        for key in all_keys:
            current_path = f"{prefix}.{key}" if prefix else key
            
            if key not in dict1:
                changes.append(f"added: {current_path}")
            elif key not in dict2:
                changes.append(f"removed: {current_path}")
            elif dict1[key] != dict2[key]:
                if isinstance(dict1[key], dict) and isinstance(dict2[key], dict):
                    changes.extend(StreamPacketAnalyzer._compare_dicts(dict1[key], dict2[key], current_path))
                else:
                    changes.append(f"modified: {current_path}")
        
        return changes[:10]


_global_processor: Optional[StreamProcessor] = None

def get_stream_processor() -> StreamProcessor:
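    """Return the process-wide StreamProcessor, creating it on first use."""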
    global _global_processor
    if _global_processor is None:
        _global_processor = StreamProcessor()
    return _global_processor


def set_websocket_manager(manager):
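    """Attach a WebSocket manager to the global stream processor."""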
    processor = get_stream_processor()
    processor.websocket_manager = manager
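

if __name__ == "__main__":
    # Minimal usage sketch on dummy byte chunks: the session lifecycle, the
    # pattern analyzer, and the delta extractor. Real input would be
    # protobuf-encoded warp.multi_agent.v1.Response frames; here the per-chunk
    # parse may fail, which process_chunk records without raising. Because of
    # the relative imports above, run this as a module, e.g.
    # `python -m <package>.<module>` (placeholders for the actual package path).
    async def _demo() -> None:
        processor = get_stream_processor()
        await processor.create_stream_session("demo-stream")

        chunks = [b"\x0a\x03foo", b"\x0a\x06foobar"]
        for chunk in chunks:
            await processor.process_stream_chunk("demo-stream", chunk)

        result = await processor.finalize_stream("demo-stream")
        print(f"finalized: {result['chunk_count']} chunks, {result['total_size']} bytes")

        # The analyzer helpers are pure functions over raw bytes / parsed dicts.
        print(StreamPacketAnalyzer.analyze_chunk_patterns(chunks))
        parsed = [{"agent_output": {"text": "Hel"}}, {"agent_output": {"text": "Hello"}}]
        for delta in StreamPacketAnalyzer.extract_streaming_deltas(parsed):
            print(delta["delta_type"], repr(delta["content_delta"]))

    asyncio.run(_demo())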