File size: 12,540 Bytes
9314c03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""

Protobuf utility functions



Shared functions for protobuf encoding/decoding across the application.

"""
from typing import Any, Dict
from fastapi import HTTPException
from .logging import logger
from .protobuf import ensure_proto_runtime, msg_cls
from google.protobuf.json_format import MessageToDict
from google.protobuf import struct_pb2
from google.protobuf.descriptor import FieldDescriptor as _FD
from .server_message_data import decode_server_message_data, encode_server_message_data





def protobuf_to_dict(protobuf_bytes: bytes, message_type: str) -> Dict:
    """Decode serialized protobuf bytes into a plain dict.

    Args:
        protobuf_bytes: Wire-format bytes of a ``message_type`` message.
        message_type: Fully-qualified message name, resolved via ``msg_cls``.

    Returns:
        The ``MessageToDict`` output (original proto field names preserved),
        with every ``server_message_data`` string decoded into a structured
        object.

    Raises:
        HTTPException: Status 500 if resolving, parsing or converting fails.
            The underlying error is chained as ``__cause__``.
    """
    ensure_proto_runtime()

    try:
        message = msg_cls(message_type)()
        message.ParseFromString(protobuf_bytes)

        data = MessageToDict(message, preserving_proto_field_name=True)

        # Decode server_message_data during conversion (Base64URL -> structured object).
        return _decode_smd_inplace(data)

    except Exception as e:
        logger.error(f"Protobuf解码失败: {e}")
        # Chain explicitly so the original decode error is reported as the cause.
        raise HTTPException(500, f"Protobuf解码失败: {e}") from e





def dict_to_protobuf_bytes(data_dict: Dict, message_type: str = "warp.multi_agent.v1.Request") -> bytes:
    """Serialize a plain dict into protobuf wire-format bytes.

    ``server_message_data`` objects are first encoded to their Base64URL string
    form. The resulting copy is then written onto a fresh ``message_type``
    message. Unknown or mistyped fields are generally logged and skipped rather
    than rejected.

    Args:
        data_dict: Mapping of field names to values for ``message_type``.
            It is not mutated.
        message_type: Fully-qualified message name, resolved via ``msg_cls``.

    Returns:
        The serialized message bytes.

    Raises:
        HTTPException: Status 500 if any step fails. The underlying error is
            chained as ``__cause__``.
    """
    ensure_proto_runtime()

    try:
        message = msg_cls(message_type)()

        # Encode server_message_data during conversion (object -> Base64URL string).
        safe_dict = _encode_smd_inplace(data_dict)

        _populate_protobuf_from_dict(message, safe_dict, path="$")

        return message.SerializeToString()

    except Exception as e:
        logger.error(f"Protobuf编码失败: {e}")
        # Chain explicitly so the original encode error is reported as the cause.
        raise HTTPException(500, f"Protobuf编码失败: {e}") from e




def _fill_google_value_dynamic(value_msg: Any, py_value: Any) -> None:
    """Fill a (dynamic) ``google.protobuf.Value`` message in place from a Python value.

    No ``struct_pb2.Value`` instances are created; only fields on ``value_msg``
    are set. Mapping:

    * ``None`` -> ``null_value`` (0, i.e. NULL_VALUE)
    * ``bool`` -> ``bool_value``
    * ``int`` / ``float`` -> ``number_value`` (stored as float)
    * ``str`` -> ``string_value``
    * ``dict`` -> ``struct_value`` (filled by ``_fill_google_struct_dynamic``)
    * ``list`` -> ``list_value`` (each item filled recursively)
    * anything else -> ``string_value`` holding ``str(py_value)``

    This is best effort: errors are logged as warnings and otherwise swallowed.
    """
    try:
        if py_value is None:
            setattr(value_msg, "null_value", 0)
            return
        # bool must be tested before int/float because bool is a subclass of int.
        if isinstance(py_value, bool):
            setattr(value_msg, "bool_value", bool(py_value))
            return
        if isinstance(py_value, (int, float)):
            setattr(value_msg, "number_value", float(py_value))
            return
        if isinstance(py_value, str):
            setattr(value_msg, "string_value", py_value)
            return
        if isinstance(py_value, dict):
            struct_value = getattr(value_msg, "struct_value")
            # Reading a sub-message does not mark it set. Without this call an
            # empty dict writes nothing, and the Value would end up with no kind.
            struct_value.SetInParent()
            _fill_google_struct_dynamic(struct_value, py_value)
            return
        if isinstance(py_value, list):
            list_value = getattr(value_msg, "list_value")
            # Same reasoning: [] must still select list_value.
            list_value.SetInParent()
            values_rep = getattr(list_value, "values")
            for item in py_value:
                _fill_google_value_dynamic(values_rep.add(), item)
            return
        setattr(value_msg, "string_value", str(py_value))
    except Exception as e:
        logger.warning(f"填充 google.protobuf.Value 失败: {e}")




def _fill_google_struct_dynamic(struct_msg: Any, py_dict: Dict[str, Any]) -> None:
    """Fill a (dynamic) ``google.protobuf.Struct`` message in place from a dict.

    Each key becomes an entry in ``struct_msg.fields``, and its Value is filled
    by ``_fill_google_value_dynamic``. ``struct_pb2.Struct.update`` is
    deliberately not used.

    The struct is marked present in its parent even when ``py_dict`` is empty.
    Failures are logged as warnings. They are handled per entry, so one bad
    entry does not prevent the remaining entries from being written.
    """
    try:
        fields_map = getattr(struct_msg, "fields")
        entries = list(py_dict.items())
        # An empty dict writes no entries. Mark the struct as set explicitly so
        # that {} becomes an empty-but-present Struct instead of being dropped.
        struct_msg.SetInParent()
    except Exception as e:
        logger.warning(f"填充 google.protobuf.Struct 失败: {e}")
        return

    for k, v in entries:
        try:
            _fill_google_value_dynamic(fields_map[k], v)
        except Exception as e:
            logger.warning(f"填充 google.protobuf.Struct 失败: {e}")




def _python_to_struct_value(py_value: Any) -> struct_pb2.Value:
    """Build and return a new ``struct_pb2.Value`` from a plain Python value.

    Mapping: None -> null_value, bool -> bool_value, int/float -> number_value
    (as float), str -> string_value, dict -> struct_value, list -> list_value
    (items converted recursively). Any other type is stored as ``str(py_value)``
    in string_value.
    """
    result = struct_pb2.Value()

    if py_value is None:
        result.null_value = struct_pb2.NULL_VALUE
        return result

    # bool is checked ahead of int/float because bool subclasses int.
    if isinstance(py_value, bool):
        result.bool_value = bool(py_value)
        return result

    if isinstance(py_value, (int, float)):
        result.number_value = float(py_value)
        return result

    if isinstance(py_value, str):
        result.string_value = py_value
        return result

    if isinstance(py_value, dict):
        nested_struct = struct_pb2.Struct()
        nested_struct.update(py_value)
        result.struct_value.CopyFrom(nested_struct)
        return result

    if isinstance(py_value, list):
        nested_list = struct_pb2.ListValue()
        for element in py_value:
            nested_list.values.append(_python_to_struct_value(element))
        result.list_value.CopyFrom(nested_list)
        return result

    result.string_value = str(py_value)
    return result




def _mark_present(msg: Any) -> None:
    """Mark a singular sub-message as set in its parent, even if it stays empty.

    Objects without ``SetInParent`` (repeated/map containers, scalars) are left
    untouched.
    """
    set_in_parent = getattr(msg, "SetInParent", None)
    if callable(set_in_parent):
        set_in_parent()


def _populate_protobuf_from_dict(proto_msg, data_dict: Dict, path: str = "$"):
    """Recursively copy values from a plain dict onto a protobuf message in place.

    This is best effort. Most per-field problems (unknown keys, type mismatches)
    are logged as warnings and the field is skipped. A failure while filling a
    nested sub-message is logged and re-raised.

    Args:
        proto_msg: The message (or sub-message) to populate.
        data_dict: Mapping of protobuf field names to Python values.
        path: JSONPath-like location of ``proto_msg``, used only in log messages.

    Per-value handling:
        * ``google.protobuf.Struct`` field + dict: filled entry by entry.
        * map field + dict: filled per key; message values recurse, and
          ``google.protobuf.Value`` values are filled dynamically.
        * dict: recurses into the sub-message. The sub-message is marked present
          first, so an empty dict still yields a set (empty) sub-message.
        * list: extends the repeated field. Repeated enums accept names or
          numbers, and lists of dicts add one message per item.
        * other scalars: assigned directly. Scalar enums accept names or numbers.
          The keys "in_progress" and "resume_conversation" are only marked
          present (their scalar value is ignored).
    """
    # Loop-invariant: the descriptor of the message being populated.
    descriptor = getattr(proto_msg, "DESCRIPTOR", None)
    for key, value in data_dict.items():
        current_path = f"{path}.{key}"
        fd = descriptor.fields_by_name.get(key) if descriptor is not None else None
        # With a descriptor, only declared fields are accepted. hasattr() alone
        # would also match message methods such as "Clear" or "ParseFromString".
        known = fd is not None if descriptor is not None else hasattr(proto_msg, key)
        if not known:
            logger.warning(f"忽略未知字段: {current_path}")
            continue

        field = getattr(proto_msg, key)
        is_message_field = (
            fd is not None
            and fd.type == _FD.TYPE_MESSAGE
            and fd.message_type is not None
        )

        # google.protobuf.Struct field given a dict.
        try:
            if (
                is_message_field
                and fd.message_type.full_name == "google.protobuf.Struct"
                and isinstance(value, dict)
            ):
                # Mark present first so that {} still produces a set Struct.
                _mark_present(field)
                _fill_google_struct_dynamic(field, value)
                continue
        except Exception as e:
            logger.warning(f"处理 Struct 字段 {current_path} 失败: {e}")

        # Fallback for a generated struct_pb2.Struct instance.
        if isinstance(field, struct_pb2.Struct) and isinstance(value, dict):
            try:
                _mark_present(field)
                field.update(value)
            except Exception as e:
                logger.warning(f"填充Struct失败: {current_path}: {e}")
            continue

        # Map fields: the field's message type is a synthetic map-entry message.
        try:
            if (
                is_message_field
                and fd.message_type.GetOptions().map_entry
                and isinstance(value, dict)
            ):
                value_desc = fd.message_type.fields_by_name.get("value")
                for mk, mv in value.items():
                    try:
                        if value_desc is not None and value_desc.type == _FD.TYPE_MESSAGE:
                            if (
                                value_desc.message_type is not None
                                and value_desc.message_type.full_name == "google.protobuf.Value"
                            ):
                                _fill_google_value_dynamic(field[mk], mv)
                            else:
                                # Accessing field[mk] creates the entry, as before.
                                sub_msg = field[mk]
                                if isinstance(mv, dict):
                                    _populate_protobuf_from_dict(sub_msg, mv, path=f"{current_path}.{mk}")
                                else:
                                    logger.warning(f"map值类型不匹配,期望message: {current_path}.{mk}")
                        else:
                            field[mk] = mv
                    except Exception as me:
                        logger.warning(f"设置 map 字段 {current_path}.{mk} 失败: {me}")
                continue
        except Exception as e:
            logger.warning(f"处理 map 字段 {current_path} 失败: {e}")

        if isinstance(value, dict):
            try:
                # Reading a sub-message does not mark it set. Without this, an
                # empty dict would be silently dropped from the output.
                _mark_present(field)
                _populate_protobuf_from_dict(field, value, path=current_path)
            except Exception as e:
                logger.error(f"填充子消息失败: {current_path}: {e}")
                raise
        elif isinstance(value, list):
            # Repeated enum: accept either enum value names or numbers.
            try:
                if fd is not None and fd.type == _FD.TYPE_ENUM:
                    enum_desc = getattr(fd, "enum_type", None)
                    resolved_values = []
                    for item in value:
                        if isinstance(item, str):
                            ev = enum_desc.values_by_name.get(item) if enum_desc is not None else None
                            if ev is not None:
                                resolved_values.append(ev.number)
                            else:
                                try:
                                    resolved_values.append(int(item))
                                except Exception:
                                    logger.warning(f"无法解析枚举值 '{item}' 为 {current_path},已忽略")
                        else:
                            try:
                                resolved_values.append(int(item))
                            except Exception:
                                logger.warning(f"无法转换枚举值 {item} 为整数: {current_path}")
                    field.extend(resolved_values)
                    continue
            except Exception as e:
                logger.warning(f"处理 repeated enum 字段 {current_path} 失败: {e}")
            # Only the first element decides whether this is a list of messages.
            if value and isinstance(value[0], dict):
                try:
                    for idx, item in enumerate(value):
                        new_item = field.add()  # type: ignore[attr-defined]
                        _populate_protobuf_from_dict(new_item, item, path=f"{current_path}[{idx}]")
                except Exception as e:
                    logger.warning(f"填充复合数组失败 {current_path}: {e}")
            else:
                try:
                    field.extend(value)
                except Exception as e:
                    logger.warning(f"设置数组字段 {current_path} 失败: {e}")
        else:
            if key in ["in_progress", "resume_conversation"]:
                # These keys are treated as presence markers: the scalar value is ignored.
                field.SetInParent()
            else:
                try:
                    # Scalar enum: accept either an enum value name or a number.
                    if fd is not None and fd.type == _FD.TYPE_ENUM:
                        enum_desc = getattr(fd, "enum_type", None)
                        if isinstance(value, str):
                            ev = enum_desc.values_by_name.get(value) if enum_desc is not None else None
                            if ev is not None:
                                setattr(proto_msg, key, ev.number)
                                continue
                            try:
                                setattr(proto_msg, key, int(value))
                                continue
                            except Exception:
                                pass
                        # Otherwise assign as-is; a type mismatch is raised by protobuf.
                    setattr(proto_msg, key, value)
                except Exception as e:
                    logger.warning(f"设置字段 {current_path} 失败: {e}")


# ===== server_message_data 递归处理 =====

def _encode_smd_inplace(obj: Any) -> Any:
    if isinstance(obj, dict):
        new_d: Dict[str, Any] = {}
        for k, v in obj.items():
            if k in ("server_message_data", "serverMessageData") and isinstance(v, dict):
                try:
                    b64 = encode_server_message_data(
                        uuid=v.get("uuid"),
                        seconds=v.get("seconds"),
                        nanos=v.get("nanos"),
                    )
                    new_d[k] = b64
                except Exception:
                    new_d[k] = v
            else:
                new_d[k] = _encode_smd_inplace(v)
        return new_d
    elif isinstance(obj, list):
        return [_encode_smd_inplace(x) for x in obj]
    else:
        return obj


def _decode_smd_inplace(obj: Any) -> Any:
    if isinstance(obj, dict):
        new_d: Dict[str, Any] = {}
        for k, v in obj.items():
            if k in ("server_message_data", "serverMessageData") and isinstance(v, str):
                try:
                    dec = decode_server_message_data(v)
                    new_d[k] = dec
                except Exception:
                    new_d[k] = v
            else:
                new_d[k] = _decode_smd_inplace(v)
        return new_d
    elif isinstance(obj, list):
        return [_decode_smd_inplace(x) for x in obj]
    else:
        return obj