#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
流式数据包处理器
处理流式protobuf数据包,支持实时解析和WebSocket推送。
"""
from typing import Any, Dict, List, Optional
from datetime import datetime
from .logging import logger
from .protobuf_utils import protobuf_to_dict
class StreamProcessor:
"""流式数据包处理器"""
def __init__(self, websocket_manager=None):
self.websocket_manager = websocket_manager
self.active_streams: Dict[str, StreamSession] = {}
async def create_stream_session(self, stream_id: str, message_type: str = "warp.multi_agent.v1.Response") -> 'StreamSession':
"""创建流式会话"""
session = StreamSession(stream_id, message_type, self.websocket_manager)
self.active_streams[stream_id] = session
logger.info(f"创建流式会话: {stream_id}, 消息类型: {message_type}")
return session
async def get_stream_session(self, stream_id: str) -> Optional['StreamSession']:
"""获取流式会话"""
return self.active_streams.get(stream_id)
async def close_stream_session(self, stream_id: str):
"""关闭流式会话"""
if stream_id in self.active_streams:
session = self.active_streams[stream_id]
await session.close()
del self.active_streams[stream_id]
logger.info(f"关闭流式会话: {stream_id}")
async def process_stream_chunk(self, stream_id: str, chunk_data: bytes) -> Dict[str, Any]:
"""处理流式数据块"""
session = await self.get_stream_session(stream_id)
if not session:
raise ValueError(f"流式会话不存在: {stream_id}")
return await session.process_chunk(chunk_data)
async def finalize_stream(self, stream_id: str) -> Dict[str, Any]:
"""完成流式处理"""
session = await self.get_stream_session(stream_id)
if not session:
raise ValueError(f"流式会话不存在: {stream_id}")
result = await session.finalize()
await self.close_stream_session(stream_id)
return result
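# Minimal usage sketch for StreamProcessor (illustrative only; "req-1" and
# chunk_bytes are hypothetical, and "warp.multi_agent.v1.Response" is assumed
# to be resolvable by protobuf_utils):
#
#     processor = StreamProcessor()
#     await processor.create_stream_session("req-1")
#     chunk_result = await processor.process_stream_chunk("req-1", chunk_bytes)
#     final_result = await processor.finalize_stream("req-1")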
class StreamSession:
"""流式会话"""
def __init__(self, session_id: str, message_type: str, websocket_manager=None):
self.session_id = session_id
self.message_type = message_type
self.websocket_manager = websocket_manager
self.chunks: List[bytes] = []
self.chunk_count = 0
self.total_size = 0
self.start_time = datetime.now()
        self.parsed_chunks: List[Optional[Dict]] = []
self.complete_message: Optional[Dict] = None
async def process_chunk(self, chunk_data: bytes) -> Dict[str, Any]:
"""处理单个数据块"""
self.chunk_count += 1
self.total_size += len(chunk_data)
self.chunks.append(chunk_data)
logger.debug(f"流式会话 {self.session_id}: 处理数据块 {self.chunk_count}, 大小 {len(chunk_data)} 字节")
chunk_result = {
"chunk_index": self.chunk_count - 1,
"size": len(chunk_data),
"timestamp": datetime.now().isoformat()
}
try:
chunk_json = protobuf_to_dict(chunk_data, self.message_type)
chunk_result["json_data"] = chunk_json
chunk_result["parsed_successfully"] = True
self.parsed_chunks.append(chunk_json)
if self.websocket_manager:
await self.websocket_manager.broadcast({
"event": "stream_chunk_parsed",
"stream_id": self.session_id,
"chunk": chunk_result
})
        except Exception as e:
            chunk_result["error"] = str(e)
            chunk_result["parsed_successfully"] = False
            # Append a placeholder so parsed_chunks stays index-aligned with
            # chunks; finalize() relies on this when attaching parsed data.
            self.parsed_chunks.append(None)
            logger.warning(f"Failed to parse chunk: {e}")
if self.websocket_manager:
await self.websocket_manager.broadcast({
"event": "stream_chunk_error",
"stream_id": self.session_id,
"chunk": chunk_result
})
return chunk_result
async def finalize(self) -> Dict[str, Any]:
"""完成流式处理,尝试拼接完整消息"""
duration = (datetime.now() - self.start_time).total_seconds()
logger.info(f"流式会话 {self.session_id} 完成: {self.chunk_count} 块, 总大小 {self.total_size} 字节, 耗时 {duration:.2f}s")
result = {
"session_id": self.session_id,
"chunk_count": self.chunk_count,
"total_size": self.total_size,
"duration_seconds": duration,
"chunks": []
}
for i, chunk in enumerate(self.chunks):
chunk_info = {
"index": i,
"size": len(chunk),
"hex_preview": chunk[:32].hex() if len(chunk) >= 32 else chunk.hex()
}
if i < len(self.parsed_chunks):
chunk_info["parsed_data"] = self.parsed_chunks[i]
result["chunks"].append(chunk_info)
try:
complete_data = b''.join(self.chunks)
complete_json = protobuf_to_dict(complete_data, self.message_type)
result["complete_message"] = {
"size": len(complete_data),
"json_data": complete_json,
"assembly_successful": True
}
self.complete_message = complete_json
logger.info(f"流式消息拼接成功: {len(complete_data)} 字节")
except Exception as e:
result["complete_message"] = {
"error": str(e),
"assembly_successful": False
}
logger.warning(f"流式消息拼接失败: {e}")
if self.websocket_manager:
await self.websocket_manager.broadcast({
"event": "stream_completed",
"stream_id": self.session_id,
"result": result
})
return result
async def close(self):
"""关闭会话"""
self.chunks.clear()
self.parsed_chunks.clear()
self.complete_message = None
logger.debug(f"流式会话 {self.session_id} 已关闭")
class StreamPacketAnalyzer:
"""流式数据包分析器"""
@staticmethod
    def analyze_chunk_patterns(chunks: List[bytes]) -> Dict[str, Any]:
        """Compute size statistics and byte-pattern hints for a chunk sequence."""
        if not chunks:
            return {"error": "no chunks"}
analysis = {
"total_chunks": len(chunks),
"size_distribution": {},
"size_stats": {},
"pattern_analysis": {}
}
sizes = [len(chunk) for chunk in chunks]
analysis["size_stats"] = {
"min": min(sizes),
"max": max(sizes),
"avg": sum(sizes) / len(sizes),
"total": sum(sizes)
}
size_ranges = [(0, 100), (100, 500), (500, 1000), (1000, 5000), (5000, float('inf'))]
for start, end in size_ranges:
range_name = f"{start}-{end if end != float('inf') else '∞'}"
count = sum(1 for size in sizes if start <= size < end)
analysis["size_distribution"][range_name] = count
        if len(chunks) >= 2:
            # Sample the leading four bytes of the first few chunks to expose framing patterns.
            analysis["pattern_analysis"]["first_bytes_samples"] = [chunk[:4].hex() for chunk in chunks[:5]]
            # Find the longest byte prefix (up to 10 bytes) shared by every chunk.
            common_prefix_len = 0
            first_chunk = chunks[0]
            for i in range(min(len(first_chunk), 10)):
                if all(len(chunk) > i and chunk[i] == first_chunk[i] for chunk in chunks[1:]):
                    common_prefix_len = i + 1
                else:
                    break
            if common_prefix_len > 0:
                analysis["pattern_analysis"]["common_prefix_length"] = common_prefix_len
                analysis["pattern_analysis"]["common_prefix_hex"] = first_chunk[:common_prefix_len].hex()
        return analysis
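    # Illustrative call on synthetic chunks (hypothetical bytes, not real Warp
    # frames):
    #
    #     stats = StreamPacketAnalyzer.analyze_chunk_patterns(
    #         [b"\x0a\x05hello", b"\x0a\x05world"])
    #     # stats["size_stats"] == {"min": 7, "max": 7, "avg": 7.0, "total": 14}
    #     # stats["pattern_analysis"]["common_prefix_hex"] == "0a05"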
@staticmethod
    def extract_streaming_deltas(parsed_chunks: List[Dict]) -> List[Dict]:
        """Derive per-chunk text deltas (append/replace/no_change) from parsed chunks."""
        if not parsed_chunks:
            return []
deltas = []
previous_content = ""
for i, chunk in enumerate(parsed_chunks):
delta = {
"chunk_index": i,
"timestamp": datetime.now().isoformat()
}
current_content = StreamPacketAnalyzer._extract_text_content(chunk)
if current_content and current_content != previous_content:
if previous_content and current_content.startswith(previous_content):
delta["content_delta"] = current_content[len(previous_content):]
delta["delta_type"] = "append"
else:
delta["content_delta"] = current_content
delta["delta_type"] = "replace"
delta["total_content_length"] = len(current_content)
previous_content = current_content
else:
delta["content_delta"] = ""
delta["delta_type"] = "no_change"
if i > 0:
delta["field_changes"] = StreamPacketAnalyzer._compare_dicts(parsed_chunks[i-1], chunk)
deltas.append(delta)
return deltas
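    # Example with synthetic parsed chunks carrying cumulative text:
    #
    #     StreamPacketAnalyzer.extract_streaming_deltas(
    #         [{"text": "Hel"}, {"text": "Hello"}])
    #     # chunk 0 -> delta_type "replace", content_delta "Hel"
    #     # chunk 1 -> delta_type "append",  content_delta "lo"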
@staticmethod
    def _extract_text_content(data: Dict) -> str:
        """Pull the primary text field from a parsed chunk, trying common paths in order."""
text_paths = [
["content"],
["text"],
["message"],
["agent_output", "text"],
["choices", 0, "delta", "content"],
["choices", 0, "message", "content"]
]
for path in text_paths:
try:
current = data
for key in path:
if isinstance(current, dict) and key in current:
current = current[key]
elif isinstance(current, list) and isinstance(key, int) and 0 <= key < len(current):
current = current[key]
else:
break
else:
if isinstance(current, str):
return current
except Exception:
continue
return ""
@staticmethod
    def _compare_dicts(dict1: Dict, dict2: Dict, prefix: str = "") -> List[str]:
        """Recursively diff two dicts, returning at most the first 10 change descriptions."""
changes = []
all_keys = set(dict1.keys()) | set(dict2.keys())
for key in all_keys:
current_path = f"{prefix}.{key}" if prefix else key
if key not in dict1:
changes.append(f"添加: {current_path}")
elif key not in dict2:
changes.append(f"删除: {current_path}")
elif dict1[key] != dict2[key]:
if isinstance(dict1[key], dict) and isinstance(dict2[key], dict):
changes.extend(StreamPacketAnalyzer._compare_dicts(dict1[key], dict2[key], current_path))
else:
changes.append(f"修改: {current_path}")
return changes[:10]
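    # e.g. _compare_dicts({"a": 1}, {"a": 2, "b": 3}) could return
    # ["modified: a", "added: b"] (keys come from a set, so order may vary).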
_global_processor: Optional[StreamProcessor] = None
def get_stream_processor() -> StreamProcessor:
    """Return the process-wide StreamProcessor singleton, creating it on first use."""
    global _global_processor
if _global_processor is None:
_global_processor = StreamProcessor()
return _global_processor
def set_websocket_manager(manager):
    """Attach a WebSocket manager to the global stream processor for event broadcasts."""
    processor = get_stream_processor()
processor.websocket_manager = manager
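

if __name__ == "__main__":
    # Smoke-test sketch; assumes package context (e.g.
    # `python -m warp2protobuf.core.stream_processor`) so the relative imports
    # resolve. The chunk bytes below are synthetic: if protobuf_utils cannot
    # parse them, they are simply reported as per-chunk parse failures.
    import asyncio

    async def _demo() -> None:
        processor = get_stream_processor()
        await processor.create_stream_session("demo-stream")
        for chunk in (b"\x0a\x03foo", b"\x0a\x03bar"):
            await processor.process_stream_chunk("demo-stream", chunk)
        result = await processor.finalize_stream("demo-stream")
        print(f"chunks={result['chunk_count']}, bytes={result['total_size']}")

    asyncio.run(_demo())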