#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Streaming packet processor.

Handles streaming protobuf packets, with support for incremental (real-time)
parsing and WebSocket push notifications.
"""
from typing import Any, Dict, List, Optional
from datetime import datetime

from .logging import logger
from .protobuf_utils import protobuf_to_dict
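
# Typical flow: a caller creates a session keyed by stream_id, feeds raw
# protobuf chunks to it as they arrive, then finalizes to reassemble and
# parse the complete message. A runnable sketch is at the bottom of this
# module.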
class StreamProcessor:
    """Streaming packet processor: tracks active stream sessions by ID."""

    def __init__(self, websocket_manager=None):
        self.websocket_manager = websocket_manager
        self.active_streams: Dict[str, "StreamSession"] = {}

    async def create_stream_session(self, stream_id: str, message_type: str = "warp.multi_agent.v1.Response") -> "StreamSession":
        """Create a streaming session and register it."""
        session = StreamSession(stream_id, message_type, self.websocket_manager)
        self.active_streams[stream_id] = session
        logger.info(f"Created stream session: {stream_id}, message type: {message_type}")
        return session

    async def get_stream_session(self, stream_id: str) -> Optional["StreamSession"]:
        """Look up a streaming session by ID."""
        return self.active_streams.get(stream_id)

    async def close_stream_session(self, stream_id: str):
        """Close a streaming session and drop it from the registry."""
        if stream_id in self.active_streams:
            session = self.active_streams[stream_id]
            await session.close()
            del self.active_streams[stream_id]
            logger.info(f"Closed stream session: {stream_id}")

    async def process_stream_chunk(self, stream_id: str, chunk_data: bytes) -> Dict[str, Any]:
        """Process one chunk of streamed data for the given session."""
        session = await self.get_stream_session(stream_id)
        if not session:
            raise ValueError(f"Stream session does not exist: {stream_id}")
        return await session.process_chunk(chunk_data)

    async def finalize_stream(self, stream_id: str) -> Dict[str, Any]:
        """Finish streaming: finalize the session, then close it."""
        session = await self.get_stream_session(stream_id)
        if not session:
            raise ValueError(f"Stream session does not exist: {stream_id}")
        result = await session.finalize()
        await self.close_stream_session(stream_id)
        return result
class StreamSession:
    """A single streaming session: buffers chunks and parses them as they arrive."""

    def __init__(self, session_id: str, message_type: str, websocket_manager=None):
        self.session_id = session_id
        self.message_type = message_type
        self.websocket_manager = websocket_manager
        self.chunks: List[bytes] = []
        self.chunk_count = 0
        self.total_size = 0
        self.start_time = datetime.now()
        # One entry per chunk: the parsed dict, or None if parsing failed,
        # so indices stay aligned with self.chunks.
        self.parsed_chunks: List[Optional[Dict]] = []
        self.complete_message: Optional[Dict] = None

    async def process_chunk(self, chunk_data: bytes) -> Dict[str, Any]:
        """Buffer a single chunk, try to parse it, and broadcast the outcome."""
        self.chunk_count += 1
        self.total_size += len(chunk_data)
        self.chunks.append(chunk_data)
        logger.debug(f"Stream session {self.session_id}: chunk {self.chunk_count}, {len(chunk_data)} bytes")
        chunk_result = {
            "chunk_index": self.chunk_count - 1,
            "size": len(chunk_data),
            "timestamp": datetime.now().isoformat()
        }
        try:
            chunk_json = protobuf_to_dict(chunk_data, self.message_type)
            chunk_result["json_data"] = chunk_json
            chunk_result["parsed_successfully"] = True
            self.parsed_chunks.append(chunk_json)
            if self.websocket_manager:
                await self.websocket_manager.broadcast({
                    "event": "stream_chunk_parsed",
                    "stream_id": self.session_id,
                    "chunk": chunk_result
                })
        except Exception as e:
            chunk_result["error"] = str(e)
            chunk_result["parsed_successfully"] = False
            # Keep indices aligned with self.chunks even when parsing fails.
            self.parsed_chunks.append(None)
            logger.warning(f"Failed to parse chunk: {e}")
            if self.websocket_manager:
                await self.websocket_manager.broadcast({
                    "event": "stream_chunk_error",
                    "stream_id": self.session_id,
                    "chunk": chunk_result
                })
        return chunk_result
    async def finalize(self) -> Dict[str, Any]:
        """Finish the stream and try to reassemble the complete message."""
        duration = (datetime.now() - self.start_time).total_seconds()
        logger.info(f"Stream session {self.session_id} finished: {self.chunk_count} chunks, {self.total_size} bytes total, {duration:.2f}s elapsed")
        result = {
            "session_id": self.session_id,
            "chunk_count": self.chunk_count,
            "total_size": self.total_size,
            "duration_seconds": duration,
            "chunks": []
        }
        for i, chunk in enumerate(self.chunks):
            chunk_info = {
                "index": i,
                "size": len(chunk),
                "hex_preview": chunk[:32].hex()
            }
            if i < len(self.parsed_chunks) and self.parsed_chunks[i] is not None:
                chunk_info["parsed_data"] = self.parsed_chunks[i]
            result["chunks"].append(chunk_info)
        try:
            complete_data = b''.join(self.chunks)
            complete_json = protobuf_to_dict(complete_data, self.message_type)
            result["complete_message"] = {
                "size": len(complete_data),
                "json_data": complete_json,
                "assembly_successful": True
            }
            self.complete_message = complete_json
            logger.info(f"Reassembled streamed message: {len(complete_data)} bytes")
        except Exception as e:
            result["complete_message"] = {
                "error": str(e),
                "assembly_successful": False
            }
            logger.warning(f"Failed to reassemble streamed message: {e}")
        if self.websocket_manager:
            await self.websocket_manager.broadcast({
                "event": "stream_completed",
                "stream_id": self.session_id,
                "result": result
            })
        return result
    async def close(self):
        """Close the session and release buffered data."""
        self.chunks.clear()
        self.parsed_chunks.clear()
        self.complete_message = None
        logger.debug(f"Stream session {self.session_id} closed")
class StreamPacketAnalyzer:
    """Static analysis helpers for streamed packet chunks."""

    @staticmethod
    def analyze_chunk_patterns(chunks: List[bytes]) -> Dict[str, Any]:
        """Analyze size distribution and byte-level patterns across chunks."""
        if not chunks:
            return {"error": "no chunks"}
        analysis = {
            "total_chunks": len(chunks),
            "size_distribution": {},
            "size_stats": {},
            "pattern_analysis": {}
        }
        sizes = [len(chunk) for chunk in chunks]
        analysis["size_stats"] = {
            "min": min(sizes),
            "max": max(sizes),
            "avg": sum(sizes) / len(sizes),
            "total": sum(sizes)
        }
        size_ranges = [(0, 100), (100, 500), (500, 1000), (1000, 5000), (5000, float('inf'))]
        for start, end in size_ranges:
            range_name = f"{start}-{end if end != float('inf') else '∞'}"
            count = sum(1 for size in sizes if start <= size < end)
            analysis["size_distribution"][range_name] = count
        if len(chunks) >= 2:
            first_bytes = [chunk[:4].hex() for chunk in chunks[:5]]
            analysis["pattern_analysis"]["first_bytes_samples"] = first_bytes
            # Find the longest prefix (up to 10 bytes) shared by every chunk.
            common_prefix_len = 0
            first_chunk = chunks[0]
            for i in range(min(len(first_chunk), 10)):
                if all(len(chunk) > i and chunk[i] == first_chunk[i] for chunk in chunks[1:]):
                    common_prefix_len = i + 1
                else:
                    break
            if common_prefix_len > 0:
                analysis["pattern_analysis"]["common_prefix_length"] = common_prefix_len
                analysis["pattern_analysis"]["common_prefix_hex"] = first_chunk[:common_prefix_len].hex()
        return analysis
    @staticmethod
    def extract_streaming_deltas(parsed_chunks: List[Dict]) -> List[Dict]:
        """Compute per-chunk content deltas from a list of parsed chunks."""
        if not parsed_chunks:
            return []
        deltas = []
        previous_content = ""
        for i, chunk in enumerate(parsed_chunks):
            delta = {
                "chunk_index": i,
                "timestamp": datetime.now().isoformat()
            }
            current_content = StreamPacketAnalyzer._extract_text_content(chunk)
            if current_content and current_content != previous_content:
                if previous_content and current_content.startswith(previous_content):
                    # Cumulative content: emit only the newly appended suffix.
                    delta["content_delta"] = current_content[len(previous_content):]
                    delta["delta_type"] = "append"
                else:
                    delta["content_delta"] = current_content
                    delta["delta_type"] = "replace"
                delta["total_content_length"] = len(current_content)
                previous_content = current_content
            else:
                delta["content_delta"] = ""
                delta["delta_type"] = "no_change"
            if i > 0:
                delta["field_changes"] = StreamPacketAnalyzer._compare_dicts(parsed_chunks[i - 1], chunk)
            deltas.append(delta)
        return deltas
    @staticmethod
    def _extract_text_content(data: Dict) -> str:
        """Try a list of known key paths and return the first string found."""
        text_paths = [
            ["content"],
            ["text"],
            ["message"],
            ["agent_output", "text"],
            ["choices", 0, "delta", "content"],
            ["choices", 0, "message", "content"]
        ]
        for path in text_paths:
            try:
                current = data
                for key in path:
                    if isinstance(current, dict) and key in current:
                        current = current[key]
                    elif isinstance(current, list) and isinstance(key, int) and 0 <= key < len(current):
                        current = current[key]
                    else:
                        break
                else:
                    # Loop finished without a break: the full path resolved.
                    if isinstance(current, str):
                        return current
            except Exception:
                continue
        return ""
    @staticmethod
    def _compare_dicts(dict1: Dict, dict2: Dict, prefix: str = "") -> List[str]:
        """List keys added, removed, or changed between two dicts (capped at 10)."""
        changes = []
        all_keys = set(dict1.keys()) | set(dict2.keys())
        for key in all_keys:
            current_path = f"{prefix}.{key}" if prefix else key
            if key not in dict1:
                changes.append(f"added: {current_path}")
            elif key not in dict2:
                changes.append(f"removed: {current_path}")
            elif dict1[key] != dict2[key]:
                if isinstance(dict1[key], dict) and isinstance(dict2[key], dict):
                    changes.extend(StreamPacketAnalyzer._compare_dicts(dict1[key], dict2[key], current_path))
                else:
                    changes.append(f"changed: {current_path}")
        return changes[:10]
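
# Example of delta extraction on hypothetical parsed chunks (shapes follow
# the paths known to _extract_text_content):
#
#   StreamPacketAnalyzer.extract_streaming_deltas(
#       [{"content": "Hel"}, {"content": "Hello"}])
#
# yields chunk 0 with delta_type "replace" and content_delta "Hel", then
# chunk 1 with delta_type "append" and content_delta "lo", plus
# field_changes ["changed: content"].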
_global_processor: Optional[StreamProcessor] = None

def get_stream_processor() -> StreamProcessor:
    """Return the module-wide StreamProcessor singleton, creating it on first use."""
    global _global_processor
    if _global_processor is None:
        _global_processor = StreamProcessor()
    return _global_processor

def set_websocket_manager(manager):
    """Attach a WebSocket manager to the global processor."""
    processor = get_stream_processor()
    processor.websocket_manager = manager
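
# Minimal usage sketch. The byte strings below are arbitrary placeholders:
# depending on how protobuf_utils decodes the configured message type they
# may parse or hit the error path, and the demo handles both. _EchoManager
# is a hypothetical stand-in for the real WebSocket manager; anything with
# an async broadcast(dict) method works. Because this module uses relative
# imports, run it with `python -m <package>.<module>` from its package.
if __name__ == "__main__":
    import asyncio

    class _EchoManager:
        """Hypothetical manager that just prints broadcast events."""
        async def broadcast(self, event: dict):
            print("broadcast:", event.get("event"), event.get("stream_id"))

    async def _demo():
        set_websocket_manager(_EchoManager())
        processor = get_stream_processor()
        await processor.create_stream_session("demo-stream")
        for chunk in (b"\x0a\x03abc", b"\x0a\x03def"):  # placeholder bytes
            info = await processor.process_stream_chunk("demo-stream", chunk)
            print("chunk:", info["chunk_index"], "parsed:", info["parsed_successfully"])
        summary = await processor.finalize_stream("demo-stream")
        print("chunks:", summary["chunk_count"], "bytes:", summary["total_size"])
        # The analyzer works on raw chunks, independently of any session.
        print(StreamPacketAnalyzer.analyze_chunk_patterns([b"\x0a\x03abc", b"\x0a\x03def"]))

    asyncio.run(_demo())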