Spaces:
Running
Running
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Protobuf runtime for Warp API.

Handles protobuf compilation, message creation, and request building.
"""
import os | |
import re | |
import json | |
import time | |
import uuid | |
import pathlib | |
import tempfile | |
from typing import Any, Dict, List, Optional, Tuple | |
from google.protobuf import descriptor_pool, descriptor_pb2 | |
from google.protobuf.descriptor import FieldDescriptor as FD | |
from google.protobuf.message_factory import GetMessageClass | |
from google.protobuf import struct_pb2 | |
from ..config.settings import PROTO_DIR, CLIENT_VERSION, OS_CATEGORY, OS_NAME, OS_VERSION, TEXT_FIELD_NAMES, PATH_HINT_BONUS | |
from .logging import logger, log | |
# Module-level protobuf state. Both names start empty and are rebound together
# by _load_pool_from_descset(): _pool is the descriptor pool built from the
# compiled bundle, ALL_MSGS the fully-qualified names of its message types.
_pool: Optional[descriptor_pool.DescriptorPool] = None
ALL_MSGS: List[str] = []
def _find_proto_files(root: pathlib.Path) -> List[str]:
    """Collect the .proto files under *root* that should be compiled.

    A fixed list of essential files is looked up directly inside *root*
    first. Only when none of them exists is the tree scanned recursively,
    skipping files whose lower-cased name contains an exclusion keyword.

    Returns:
        Path strings of the selected files; empty when *root* does not exist.
    """
    if not root.exists():
        return []

    essential_files = (
        "options.proto",        # base options, depended on by the other files
        "file_content.proto",   # file content definitions
        "attachment.proto",     # attachment definitions
        "todo.proto",           # TODO definitions
        "suggestions.proto",    # suggestion definitions
        "debug.proto",          # debug definitions
        "input_context.proto",  # input context
        "citations.proto",      # citation definitions
        "task.proto",           # task definitions
        "request.proto",        # request definitions
        "response.proto",       # response definitions
    )

    found_files: List[str] = []
    for file_name in essential_files:
        candidate = root / file_name
        if not candidate.exists():
            continue
        found_files.append(str(candidate))
        logger.debug(f"Found essential proto file: {file_name}")

    if not found_files:
        # Fallback: take every .proto in the tree except test/sample fixtures.
        logger.warning("Essential proto files not found, scanning all files...")
        exclude_patterns = (
            "unittest", "test", "sample_messages", "java_features",
            "legacy_features", "descriptor_test",
        )
        for proto_file in root.rglob("*.proto"):
            lowered = proto_file.name.lower()
            if any(pattern in lowered for pattern in exclude_patterns):
                continue
            found_files.append(str(proto_file))

    logger.info(f"Selected {len(found_files)} proto files for compilation")
    return found_files
def _build_descset(proto_files: List[str], includes: List[str]) -> bytes:
    """Compile *proto_files* into a serialized descriptor set using grpc_tools.

    Args:
        proto_files: Paths of the .proto files handed to protoc.
        includes: Directories passed to protoc as ``-I`` include paths.

    Returns:
        The raw bytes of the descriptor set written by protoc
        (``--include_imports`` is enabled).

    Raises:
        RuntimeError: protoc returned a non-zero code or wrote no output file.
    """
    from grpc_tools import protoc

    # The ``_proto`` directory inside the grpc_tools package is added as an
    # extra include path when it can be located; failure to locate it is
    # tolerated and simply omits that include.
    try:
        from importlib.resources import files as pkg_files
        tool_inc = str(pkg_files("grpc_tools").joinpath("_proto"))
    except Exception:
        tool_inc = None

    # The scratch directory is only needed until the bundle has been read
    # back, so it is removed on exit instead of being left behind on disk.
    with tempfile.TemporaryDirectory(prefix="desc_") as scratch:
        out = pathlib.Path(scratch) / "bundle.pb"
        args = ["protoc", f"--descriptor_set_out={out}", "--include_imports"]
        args.extend(f"-I{inc}" for inc in includes)
        if tool_inc:
            args.append(f"-I{tool_inc}")
        args.extend(proto_files)

        rc = protoc.main(args)
        if rc != 0 or not out.exists():
            raise RuntimeError("protoc failed to produce descriptor set")
        return out.read_bytes()
def _load_pool_from_descset(descset: bytes):
    """Parse a serialized FileDescriptorSet and install it as module state.

    Every file in the set is added to a fresh descriptor pool, then the
    fully-qualified name of each message type (nested ones included) is
    collected. On success the module globals ``_pool`` and ``ALL_MSGS`` are
    rebound and the number of message types is logged.
    """
    global _pool, ALL_MSGS

    file_set = descriptor_pb2.FileDescriptorSet()
    file_set.ParseFromString(descset)

    fresh_pool = descriptor_pool.DescriptorPool()
    collected: List[str] = []

    def collect(message, scope):
        # Pre-order walk: record the message itself, then its nested types.
        qualified = f"{scope}.{message.name}" if scope else message.name
        collected.append(qualified)
        for child in message.nested_type:
            collect(child, qualified)

    # All files are registered before any names are gathered.
    for file_proto in file_set.file:
        fresh_pool.Add(file_proto)

    for file_proto in file_set.file:
        for top_level in file_proto.message_type:
            collect(top_level, file_proto.package)

    _pool, ALL_MSGS = fresh_pool, collected
    log(f"proto loaded: {len(ALL_MSGS)} message type(s)")
def ensure_proto_runtime():
    """Build the descriptor pool on first use; do nothing once it exists.

    Raises:
        RuntimeError: no .proto file was selected under ``PROTO_DIR``
            (errors from compilation propagate as well).
    """
    if _pool is None:
        proto_paths = _find_proto_files(PROTO_DIR)
        if not proto_paths:
            raise RuntimeError(f"No .proto found under {PROTO_DIR}")
        compiled = _build_descset(proto_paths, [str(PROTO_DIR)])
        _load_pool_from_descset(compiled)
def msg_cls(full: str):
    """Return the generated message class for the fully-qualified name *full*.

    The descriptor pool is initialised on demand, so this can be called
    before any other function of the module has loaded the proto runtime.
    Lookup errors raised by the pool for an unknown name propagate unchanged.
    """
    # Without this, a call made before the runtime was loaded would fail with
    # an AttributeError on the still-unset pool.
    ensure_proto_runtime()
    desc = _pool.FindMessageTypeByName(full)  # type: ignore
    return GetMessageClass(desc)
def _list_text_paths(desc, max_depth=6):
    """Enumerate every string field reachable from *desc* with a score.

    Message-typed fields are followed recursively; descent stops once the
    nesting level exceeds *max_depth*. A field scores 10 when its lower-cased
    name is in ``TEXT_FIELD_NAMES`` plus 2 for each ``PATH_HINT_BONUS`` entry
    contained in that name. For string fields the nesting level is added and
    the result is recorded.

    Returns:
        A list of ``(field_path, score)`` tuples in traversal order, where
        ``field_path`` is the list of field descriptors from *desc* down to
        the string field.
    """
    results: List[Tuple[List[FD], int]] = []

    def visit(node, trail: List[FD], level: int):
        if level > max_depth:
            return
        for field in node.fields:
            lowered = field.name.lower()
            score = 10 if lowered in TEXT_FIELD_NAMES else 0
            score += 2 * sum(1 for hint in PATH_HINT_BONUS if hint in lowered)
            extended = trail + [field]
            if field.type == FD.TYPE_STRING:
                results.append((extended, score + level))
            elif field.type == FD.TYPE_MESSAGE:
                visit(field.message_type, extended, level + 1)

    visit(desc, [], 0)
    return results
def _pick_best_request_schema() -> Tuple[str, List[FD]]:
    """Choose the request message type and the field path that holds the text.

    The known layout ``warp.multi_agent.v1.Request`` with the path
    ``input.user_inputs.inputs.user_query.query`` is tried first. The path is
    validated while it is resolved: every intermediate field must be a
    message and the final field must be a string. Any failure there is logged
    and triggers auto-detection, which scores every string path of every
    loaded message type and keeps the first one with the highest total.

    Returns:
        ``(full_message_name, field_path)`` where ``field_path`` is the list
        of field descriptors leading to the text field.

    Raises:
        RuntimeError: auto-detection found no candidate string field.
    """
    ensure_proto_runtime()

    try:
        request_type = "warp.multi_agent.v1.Request"
        d = _pool.FindMessageTypeByName(request_type)  # type: ignore
        path_names = ["input", "user_inputs", "inputs", "user_query", "query"]
        path_fields = []
        current_desc = d
        last_index = len(path_names) - 1
        for index, field_name in enumerate(path_names):
            field = current_desc.fields_by_name.get(field_name)
            if field is None:
                raise RuntimeError(f"Field '{field_name}' not found")
            path_fields.append(field)
            if index < last_index:
                # A scalar in the middle of the path would make the next
                # lookup run against the wrong descriptor; reject it here so
                # the fallback below is used instead of returning a path that
                # _set_text_at_path cannot follow.
                if field.type != FD.TYPE_MESSAGE:
                    raise RuntimeError(
                        f"Field '{field_name}' is not a message; cannot descend further"
                    )
                current_desc = field.message_type
            elif field.type != FD.TYPE_STRING:
                raise RuntimeError(f"Field '{field_name}' is not a string field")
        log("using modern request format:", request_type, " :: ", ".".join(path_names))
        return request_type, path_fields
    except Exception as e:
        # Deliberately broad: any problem with the fixed layout means
        # "fall back", not "fail".
        log(f"Failed to use modern format, falling back to auto-detection: {e}")

    name_weights = (
        ("request", 8), ("multi_agent", 6), ("multiagent", 6),
        ("chat", 5), ("client", 2), ("message", 1), ("input", 1),
    )
    best: Optional[Tuple[str, List[FD], int]] = None
    for full in ALL_MSGS:
        try:
            d = _pool.FindMessageTypeByName(full)  # type: ignore
        except Exception:
            continue
        lname = full.lower()
        name_bias = sum(w for kw, w in name_weights if kw in lname)
        for path, score in _list_text_paths(d):
            # Shorter paths earn up to 6 extra points.
            total = score + name_bias + max(0, 6 - len(path))
            # Strict comparison: on a tie the earlier candidate is kept.
            if best is None or total > best[2]:
                best = (full, path, total)

    if not best:
        raise RuntimeError("Could not auto-detect request root & text field from proto/")
    full, path, _ = best
    log("auto-detected request:", full, " :: ", ".".join(f.name for f in path))
    return full, path
# Result of the first _pick_best_request_schema() call; None until then.
_REQ_CACHE: Optional[Tuple[str, List[FD]]] = None


def get_request_schema() -> Tuple[str, List[FD]]:
    """Return the request schema, computing and caching it on first call."""
    global _REQ_CACHE
    if _REQ_CACHE is not None:
        return _REQ_CACHE
    _REQ_CACHE = _pick_best_request_schema()
    return _REQ_CACHE
def _set_text_at_path(msg, path_fields: List[FD], text: str):
    """Write *text* into *msg* at the string field reached via *path_fields*.

    Walks the path one field at a time. A repeated message field gets a new
    element appended and the walk continues inside it; a singular message
    field is entered directly. The final field must be a string: a repeated
    string has *text* appended, a singular string is assigned.

    Raises:
        TypeError: the path continues past a string field, ends on a singular
            message field, or contains a non-string scalar.
        RuntimeError: the walk finished without writing anything (for
            example an empty path or one ending on a repeated message field).
    """

    def _repeated(field) -> bool:
        # Prefer the is_repeated attribute; fall back to the label comparison
        # when the descriptor does not provide it.
        try:
            return field.is_repeated
        except AttributeError:
            return field.label == FD.LABEL_REPEATED

    final_index = len(path_fields) - 1
    node = msg
    for index, field in enumerate(path_fields):
        at_end = index == final_index

        if _repeated(field):
            container = getattr(node, field.name)
            if field.type == FD.TYPE_MESSAGE:
                node = container.add()
                continue
            if field.type != FD.TYPE_STRING:
                raise TypeError(f"unsupported repeated scalar at '{field.name}'")
            if not at_end:
                raise TypeError(f"path continues after repeated string field '{field.name}'")
            container.append(text)
            return

        if field.type == FD.TYPE_MESSAGE:
            node = getattr(node, field.name)
            if at_end:
                raise TypeError(f"last field '{field.name}' is a message, not string")
            continue
        if field.type != FD.TYPE_STRING:
            raise TypeError(f"unsupported scalar at '{field.name}'")
        if not at_end:
            raise TypeError(f"path continues after string field '{field.name}'")
        setattr(node, field.name, text)
        return

    raise RuntimeError("failed to set text")
def build_request_bytes(user_text: str, model: str = "auto") -> bytes:
    """Build and serialize a request message carrying *user_text*.

    Args:
        user_text: Text written at the schema's text field path.
        model: Name passed to ``get_model_config`` to obtain the ``base``,
            ``planning`` and ``coding`` values for ``settings.model_config``.

    Returns:
        The serialized request message.

    Steps, each applied only when the message has the relevant field:
    fill ``settings`` (model config, capability flags, empty tool list),
    assign a fresh ``metadata.conversation_id``, and copy client/OS
    identification strings into matching top-level string fields.
    """
    from ..config.models import get_model_config

    def _is_singular_optional(field) -> bool:
        # Same compatibility approach as _set_text_at_path: use the newer
        # descriptor attributes when present and fall back to the label
        # comparison on descriptors that lack them.
        try:
            return not (field.is_repeated or field.is_required)
        except AttributeError:
            return field.label == FD.LABEL_OPTIONAL

    full, path = get_request_schema()
    Cls = msg_cls(full)
    msg = Cls()
    _set_text_at_path(msg, path, user_text)

    if hasattr(msg, 'settings'):
        settings = msg.settings
        if hasattr(settings, 'model_config'):
            model_config_dict = get_model_config(model)
            model_config = settings.model_config
            model_config.base = model_config_dict["base"]
            model_config.planning = model_config_dict["planning"]
            model_config.coding = model_config_dict["coding"]
            logger.debug(f"Set model config: base={model_config.base}, planning={model_config.planning}, coding={model_config.coding}")

        # Every flag is switched off except file-content preservation.
        settings_flags = (
            ("rules_enabled", False),
            ("web_context_retrieval_enabled", False),
            ("supports_parallel_tool_calls", False),
            ("planning_enabled", False),
            ("supports_create_files", False),
            ("supports_long_running_commands", False),
            ("supports_todos_ui", False),
            ("supports_linked_code_blocks", False),
            ("use_anthropic_text_editor_tools", False),
            ("warp_drive_context_enabled", False),
            ("should_preserve_file_content_in_history", True),
        )
        for flag_name, flag_value in settings_flags:
            setattr(settings, flag_name, flag_value)

        # Best effort: a schema without supported_tools is only logged.
        try:
            tool_types = []
            settings.supported_tools[:] = tool_types
            logger.debug(f"Set supported_tools (legacy): {tool_types}")
        except Exception as e:
            logger.debug(f"Could not set supported_tools: {e}")
        logger.debug("Applied all valid Settings fields based on proto definition")

    if hasattr(msg, 'metadata'):
        metadata = msg.metadata
        metadata.conversation_id = f"rest-api-{uuid.uuid4().hex[:8]}"

    # Identification strings are only set on top-level singular string fields
    # that the detected root message actually declares.
    rootd = msg.DESCRIPTOR
    for fn, val in (
        ("client_version", CLIENT_VERSION),
        ("version", CLIENT_VERSION),
        ("os_name", OS_NAME),
        ("os_category", OS_CATEGORY),
        ("os_version", OS_VERSION),
    ):
        f = rootd.fields_by_name.get(fn)
        if f and f.type == FD.TYPE_STRING and _is_singular_optional(f):
            setattr(msg, fn, val)

    return msg.SerializeToString()