vkjlwnvbioWBV / warp2protobuf /core /protobuf_utils.py
devme's picture
Upload 90 files
9314c03 verified
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Protobuf utility functions
Shared functions for protobuf encoding/decoding across the application.
"""
from typing import Any, Dict
from fastapi import HTTPException
from .logging import logger
from .protobuf import ensure_proto_runtime, msg_cls
from google.protobuf.json_format import MessageToDict
from google.protobuf import struct_pb2
from google.protobuf.descriptor import FieldDescriptor as _FD
from .server_message_data import decode_server_message_data, encode_server_message_data
def protobuf_to_dict(protobuf_bytes: bytes, message_type: str) -> Dict:
"""将protobuf字节转换为字典"""
ensure_proto_runtime()
try:
MessageClass = msg_cls(message_type)
message = MessageClass()
message.ParseFromString(protobuf_bytes)
data = MessageToDict(message, preserving_proto_field_name=True)
# 在转换阶段自动解析 server_message_data(Base64URL -> 结构化对象)
data = _decode_smd_inplace(data)
return data
except Exception as e:
logger.error(f"Protobuf解码失败: {e}")
raise HTTPException(500, f"Protobuf解码失败: {e}")
def dict_to_protobuf_bytes(data_dict: Dict, message_type: str = "warp.multi_agent.v1.Request") -> bytes:
"""字典转protobuf字节的包装函数"""
ensure_proto_runtime()
try:
MessageClass = msg_cls(message_type)
message = MessageClass()
# 在转换阶段自动处理 server_message_data(对象 -> Base64URL 字符串)
safe_dict = _encode_smd_inplace(data_dict)
_populate_protobuf_from_dict(message, safe_dict, path="$")
return message.SerializeToString()
except Exception as e:
logger.error(f"Protobuf编码失败: {e}")
raise HTTPException(500, f"Protobuf编码失败: {e}")
def _fill_google_value_dynamic(value_msg: Any, py_value: Any) -> None:
"""在动态 google.protobuf.Value 消息上填充 Python 值(不创建 struct_pb2.Value 实例)。"""
try:
if py_value is None:
setattr(value_msg, "null_value", 0)
return
if isinstance(py_value, bool):
setattr(value_msg, "bool_value", bool(py_value))
return
if isinstance(py_value, (int, float)):
setattr(value_msg, "number_value", float(py_value))
return
if isinstance(py_value, str):
setattr(value_msg, "string_value", py_value)
return
if isinstance(py_value, dict):
struct_value = getattr(value_msg, "struct_value")
_fill_google_struct_dynamic(struct_value, py_value)
return
if isinstance(py_value, list):
list_value = getattr(value_msg, "list_value")
values_rep = getattr(list_value, "values")
for item in py_value:
sub = values_rep.add()
_fill_google_value_dynamic(sub, item)
return
setattr(value_msg, "string_value", str(py_value))
except Exception as e:
logger.warning(f"填充 google.protobuf.Value 失败: {e}")
def _fill_google_struct_dynamic(struct_msg: Any, py_dict: Dict[str, Any]) -> None:
"""在动态 google.protobuf.Struct 上填充 Python dict(不使用 struct_pb2.Struct.update)。"""
try:
fields_map = getattr(struct_msg, "fields")
for k, v in py_dict.items():
sub_val = fields_map[k]
_fill_google_value_dynamic(sub_val, v)
except Exception as e:
logger.warning(f"填充 google.protobuf.Struct 失败: {e}")
def _python_to_struct_value(py_value: Any) -> struct_pb2.Value:
v = struct_pb2.Value()
if py_value is None:
v.null_value = struct_pb2.NULL_VALUE
elif isinstance(py_value, bool):
v.bool_value = bool(py_value)
elif isinstance(py_value, (int, float)):
v.number_value = float(py_value)
elif isinstance(py_value, str):
v.string_value = py_value
elif isinstance(py_value, dict):
s = struct_pb2.Struct()
s.update(py_value)
v.struct_value.CopyFrom(s)
elif isinstance(py_value, list):
lv = struct_pb2.ListValue()
for item in py_value:
lv.values.append(_python_to_struct_value(item))
v.list_value.CopyFrom(lv)
else:
v.string_value = str(py_value)
return v
def _populate_protobuf_from_dict(proto_msg, data_dict: Dict, path: str = "$"):
for key, value in data_dict.items():
current_path = f"{path}.{key}"
if not hasattr(proto_msg, key):
logger.warning(f"忽略未知字段: {current_path}")
continue
field = getattr(proto_msg, key)
fd = None
descriptor = getattr(proto_msg, "DESCRIPTOR", None)
if descriptor is not None:
fd = descriptor.fields_by_name.get(key)
try:
if (
fd is not None
and fd.type == _FD.TYPE_MESSAGE
and fd.message_type is not None
and fd.message_type.full_name == "google.protobuf.Struct"
and isinstance(value, dict)
):
_fill_google_struct_dynamic(field, value)
continue
except Exception as e:
logger.warning(f"处理 Struct 字段 {current_path} 失败: {e}")
if isinstance(field, struct_pb2.Struct) and isinstance(value, dict):
try:
field.update(value)
except Exception as e:
logger.warning(f"填充Struct失败: {current_path}: {e}")
continue
try:
if (
fd is not None
and fd.type == _FD.TYPE_MESSAGE
and fd.message_type is not None
and fd.message_type.GetOptions().map_entry
and isinstance(value, dict)
):
value_desc = fd.message_type.fields_by_name.get("value")
for mk, mv in value.items():
try:
if value_desc is not None and value_desc.type == _FD.TYPE_MESSAGE:
if value_desc.message_type is not None and value_desc.message_type.full_name == "google.protobuf.Value":
_fill_google_value_dynamic(field[mk], mv)
else:
sub_msg = field[mk]
if isinstance(mv, dict):
_populate_protobuf_from_dict(sub_msg, mv, path=f"{current_path}.{mk}")
else:
try:
logger.warning(f"map值类型不匹配,期望message: {current_path}.{mk}")
except Exception:
pass
else:
field[mk] = mv
except Exception as me:
logger.warning(f"设置 map 字段 {current_path}.{mk} 失败: {me}")
continue
except Exception as e:
logger.warning(f"处理 map 字段 {current_path} 失败: {e}")
if isinstance(value, dict):
try:
_populate_protobuf_from_dict(field, value, path=current_path)
except Exception as e:
logger.error(f"填充子消息失败: {current_path}: {e}")
raise
elif isinstance(value, list):
# 处理 repeated enum:允许传入字符串名称或数字
try:
if fd is not None and fd.type == _FD.TYPE_ENUM:
enum_desc = getattr(fd, "enum_type", None)
resolved_values = []
for item in value:
if isinstance(item, str):
ev = enum_desc.values_by_name.get(item) if enum_desc is not None else None
if ev is not None:
resolved_values.append(ev.number)
else:
try:
resolved_values.append(int(item))
except Exception:
logger.warning(f"无法解析枚举值 '{item}' 为 {current_path},已忽略")
else:
try:
resolved_values.append(int(item))
except Exception:
logger.warning(f"无法转换枚举值 {item} 为整数: {current_path}")
field.extend(resolved_values)
continue
except Exception as e:
logger.warning(f"处理 repeated enum 字段 {current_path} 失败: {e}")
if value and isinstance(value[0], dict):
try:
for idx, item in enumerate(value):
new_item = field.add() # type: ignore[attr-defined]
_populate_protobuf_from_dict(new_item, item, path=f"{current_path}[{idx}]")
except Exception as e:
logger.warning(f"填充复合数组失败 {current_path}: {e}")
else:
try:
field.extend(value)
except Exception as e:
logger.warning(f"设置数组字段 {current_path} 失败: {e}")
else:
if key in ["in_progress", "resume_conversation"]:
field.SetInParent()
else:
try:
# 处理标量 enum:允许传入字符串名称或数字
if fd is not None and fd.type == _FD.TYPE_ENUM:
enum_desc = getattr(fd, "enum_type", None)
if isinstance(value, str):
ev = enum_desc.values_by_name.get(value) if enum_desc is not None else None
if ev is not None:
setattr(proto_msg, key, ev.number)
continue
try:
setattr(proto_msg, key, int(value))
continue
except Exception:
pass
# 其余情况直接赋值,若类型不匹配由底层抛错
setattr(proto_msg, key, value)
except Exception as e:
logger.warning(f"设置字段 {current_path} 失败: {e}")
# ===== server_message_data 递归处理 =====
def _encode_smd_inplace(obj: Any) -> Any:
if isinstance(obj, dict):
new_d: Dict[str, Any] = {}
for k, v in obj.items():
if k in ("server_message_data", "serverMessageData") and isinstance(v, dict):
try:
b64 = encode_server_message_data(
uuid=v.get("uuid"),
seconds=v.get("seconds"),
nanos=v.get("nanos"),
)
new_d[k] = b64
except Exception:
new_d[k] = v
else:
new_d[k] = _encode_smd_inplace(v)
return new_d
elif isinstance(obj, list):
return [_encode_smd_inplace(x) for x in obj]
else:
return obj
def _decode_smd_inplace(obj: Any) -> Any:
if isinstance(obj, dict):
new_d: Dict[str, Any] = {}
for k, v in obj.items():
if k in ("server_message_data", "serverMessageData") and isinstance(v, str):
try:
dec = decode_server_message_data(v)
new_d[k] = dec
except Exception:
new_d[k] = v
else:
new_d[k] = _decode_smd_inplace(v)
return new_d
elif isinstance(obj, list):
return [_decode_smd_inplace(x) for x in obj]
else:
return obj