Spaces:
Running
Running
import asyncio
import base64
import codecs
import enum
import json
import re
import time
from typing import Dict, Any, List, Union, Optional, AsyncGenerator

import aiohttp
# Process-wide memo of discovered project IDs: maps API key -> project ID.
# Shared by all DirectVertexClient instances so a later client with the same
# key can reuse the ID instead of repeating discovery.
PROJECT_ID_CACHE: Dict[str, str] = dict()
class DirectVertexClient:
    """
    A client that connects to Vertex AI using direct URLs instead of the SDK.
    Mimics the interface of genai.Client for seamless integration.

    Typical use::

        client = DirectVertexClient(api_key)
        await client.discover_project_id()
        response = await client.aio.models.generate_content(model, contents, config)
        stream = await client.aio.models.generate_content_stream(model, contents, config)
        async for chunk in stream:
            ...
        await client.close()

    Requests authenticate with ``api_key`` as the ``key`` query parameter and
    target the ``global`` location. Responses are returned as attribute-access
    objects mirroring the REST JSON field names.
    """

    # Deliberately non-existent model: requesting it produces an error whose
    # message contains "projects/<number>/locations/...".
    _DISCOVERY_MODEL = "gemini-2.7-pro-preview-05-06"
    _PROJECT_ID_PATTERN = re.compile(r'projects/(\d+)/locations/')

    # Config keys that are top-level fields of the REST GenerateContentRequest
    # (not generationConfig fields), mapped to their REST names.
    _TOP_LEVEL_CONFIG_KEYS = {
        "tool_config": "toolConfig",
        "cached_content": "cachedContent",
        "labels": "labels",
    }
    # Every config key that must NOT be copied into generationConfig.
    _NON_GENERATION_KEYS = frozenset(
        {"system_instruction", "safety_settings", "tools", *_TOP_LEVEL_CONFIG_KEYS}
    )

    def __init__(self, api_key: str):
        self.api_key = api_key
        self.project_id: Optional[str] = None
        self.base_url = "https://aiplatform.googleapis.com/v1"
        self.session: Optional[aiohttp.ClientSession] = None
        # Mimic the model_name attribute that might be accessed
        self.model_name = "direct_vertex_client"
        # Nested namespaces so callers can use client.aio.models.<method>(...)
        self.aio = self._AioNamespace(self)

    async def __aenter__(self) -> "DirectVertexClient":
        """Support ``async with DirectVertexClient(key) as client:``."""
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Close the HTTP session when leaving the ``async with`` block."""
        await self.close()

    class _AioNamespace:
        """Stand-in for ``genai.Client.aio``."""

        def __init__(self, parent):
            self.parent = parent
            self.models = self._ModelsNamespace(parent)

        class _ModelsNamespace:
            """Stand-in for ``genai.Client.aio.models``."""

            def __init__(self, parent):
                self.parent = parent

            async def generate_content(self, model: str, contents: Any,
                                       config: Optional[Dict[str, Any]] = None) -> Any:
                """Non-streaming content generation; returns the parsed response."""
                return await self.parent._generate_content(model, contents, config, stream=False)

            async def generate_content_stream(self, model: str, contents: Any,
                                              config: Optional[Dict[str, Any]] = None):
                """Streaming content generation.

                As in the SDK, the call itself is awaited and returns an async
                iterator of response chunks.
                """
                return self.parent._generate_content_stream(model, contents, config)

    async def _ensure_session(self) -> None:
        """Create the aiohttp session lazily, or recreate it if it was closed."""
        if self.session is None or self.session.closed:
            self.session = aiohttp.ClientSession()

    async def close(self) -> None:
        """Close the HTTP session, if one is open."""
        if self.session:
            await self.session.close()
            self.session = None

    async def discover_project_id(self) -> None:
        """Discover the project ID behind ``api_key`` and store it on the client.

        Requests a non-existent model; the error message in the response names
        ``projects/<id>/locations/...``. Results are cached per API key in the
        module-level ``PROJECT_ID_CACHE``.

        Raises:
            Exception: if the request fails or no project ID appears in the response.
        """
        if self.api_key in PROJECT_ID_CACHE:
            self.project_id = PROJECT_ID_CACHE[self.api_key]
            print(f"INFO: Using cached project ID: {self.project_id}")
            return
        await self._ensure_session()
        error_url = (
            f"{self.base_url}/publishers/google/models/"
            f"{self._DISCOVERY_MODEL}:streamGenerateContent?key={self.api_key}"
        )
        # Minimal request body; only the error it triggers matters.
        payload = {"contents": [{"role": "user", "parts": [{"text": "test"}]}]}
        try:
            async with self.session.post(error_url, json=payload) as response:
                response_text = await response.text()
                status = response.status
            project_id = self._project_id_from_error_json(response_text)
            source = ""
            if project_id is None:
                # Not JSON, or the JSON message lacked the ID: scan the raw body.
                match = self._PROJECT_ID_PATTERN.search(response_text)
                if match is None:
                    raise Exception(
                        f"Failed to discover project ID. Status: {status}, "
                        f"Response: {response_text[:500]}"
                    )
                project_id = match.group(1)
                source = " from raw response"
            self.project_id = project_id
            PROJECT_ID_CACHE[self.api_key] = project_id
            print(f"INFO: Discovered project ID{source}: {project_id}")
        except Exception as e:
            print(f"ERROR: Failed to discover project ID: {e}")
            raise

    @classmethod
    def _project_id_from_error_json(cls, response_text: str) -> Optional[str]:
        """Return the project ID named in a JSON error body's ``error.message``.

        Accepts ``{"error": {...}}`` and the array form ``[{"error": {...}}]``.
        Returns None if the body is not JSON, has no usable ``error`` object,
        or the message does not mention a project.
        """
        try:
            error_data = json.loads(response_text)
        except json.JSONDecodeError:
            return None
        if isinstance(error_data, list) and error_data:
            error_data = error_data[0]
        if not isinstance(error_data, dict):
            return None
        error = error_data.get("error")
        if not isinstance(error, dict):
            return None
        match = cls._PROJECT_ID_PATTERN.search(str(error.get("message", "")))
        return match.group(1) if match else None

    def _convert_contents(self, contents: Any) -> List[Dict[str, Any]]:
        """Convert SDK Content objects (or a single one) to REST API format."""
        if isinstance(contents, list):
            return [self._convert_content_item(item) for item in contents]
        return [self._convert_content_item(contents)]

    def _convert_content_item(self, content: Any) -> Dict[str, Any]:
        """Convert a single content item to REST API format.

        Dicts pass through unchanged; a bare string becomes one user turn with
        a single text part; other objects contribute their ``role`` and
        ``parts`` attributes. Parts that carry no data are dropped.
        """
        if isinstance(content, dict):
            return content
        if isinstance(content, str):
            return {'role': 'user', 'parts': [{'text': content}]}
        result: Dict[str, Any] = {}
        if hasattr(content, 'role'):
            result['role'] = content.role
        if hasattr(content, 'parts'):
            converted = (self._convert_part(part) for part in (content.parts or []))
            result['parts'] = [part for part in converted if part is not None]
        return result

    def _convert_part(self, part: Any) -> Optional[Dict[str, Any]]:
        """Convert one content part to REST format, or return None if it holds no data.

        SDK Part objects may expose every field (``text``, ``inline_data``,
        ``function_call``, ...) and leave unused ones as None, so the checks
        below test for a non-None value rather than attribute presence.
        """
        if isinstance(part, dict):
            return part
        if isinstance(part, str):
            return {'text': part}
        text = getattr(part, 'text', None)
        if text is not None:
            return {'text': text}
        inline_data = getattr(part, 'inline_data', None)
        if inline_data is not None:
            return {
                'inline_data': {
                    'mime_type': inline_data.mime_type,
                    # bytes are sent as base64 text, the JSON encoding of binary fields
                    'data': self._convert_sdk_object(inline_data.data),
                }
            }
        # Other part kinds (function calls/responses, file data, ...).
        converted = self._convert_sdk_object(part)
        return converted if isinstance(converted, dict) and converted else None

    def _convert_safety_settings(self, safety_settings: Any) -> List[Dict[str, Any]]:
        """Convert SDK SafetySetting objects to REST API format.

        Dict entries pass through; objects must have ``category`` and
        ``threshold``. Entries matching neither form are skipped.
        """
        if not safety_settings:
            return []
        result = []
        for setting in safety_settings:
            if isinstance(setting, dict):
                result.append(setting)
            elif hasattr(setting, 'category') and hasattr(setting, 'threshold'):
                result.append({
                    'category': self._convert_sdk_object(setting.category),
                    'threshold': self._convert_sdk_object(setting.threshold),
                })
        return result

    def _convert_tools(self, tools: Any) -> List[Dict[str, Any]]:
        """Convert SDK Tool objects (or dicts) to REST API format."""
        if not tools:
            return []
        return [self._convert_tool_item(tool) for tool in tools]

    def _convert_tool_item(self, tool: Any) -> Dict[str, Any]:
        """Convert a single tool to a REST dict of its non-None public attributes.

        Attribute names become camelCase. ``google_search`` is always sent as
        an empty object. Nested values (including function declarations) go
        through ``_convert_sdk_object`` so SDK objects become JSON-serializable.
        """
        if isinstance(tool, dict):
            return tool
        tool_dict: Dict[str, Any] = {}
        for attr_name, attr_value in getattr(tool, '__dict__', {}).items():
            if attr_name.startswith('_') or attr_value is None:
                continue
            rest_api_name = self._to_camel_case(attr_name)
            if attr_name == 'google_search':
                tool_dict[rest_api_name] = {}  # GoogleSearch is empty object in REST
            else:
                tool_dict[rest_api_name] = self._convert_sdk_object(attr_value)
        return tool_dict

    def _to_camel_case(self, snake_str: str) -> str:
        """Convert snake_case to lowerCamelCase.

        Only the first character of each later segment is upper-cased and the
        rest is kept, so ``"max_v2beta"`` -> ``"maxV2beta"`` (``str.title``
        would give ``"maxV2Beta"``).
        """
        head, *rest = snake_str.split('_')
        return head + ''.join(segment[:1].upper() + segment[1:] for segment in rest)

    def _convert_sdk_object(self, obj: Any) -> Any:
        """Recursively convert SDK objects into JSON-serializable REST values.

        Enums become their ``value``; bytes become base64 text; dicts and lists
        are converted element-wise (dict keys unchanged); other objects with a
        ``__dict__`` become dicts of their public, non-None attributes with
        camelCase names. Anything else is returned unchanged.
        """
        if isinstance(obj, enum.Enum):
            return obj.value
        if isinstance(obj, (str, int, float, bool, type(None))):
            return obj
        if isinstance(obj, (bytes, bytearray)):
            return base64.b64encode(obj).decode('ascii')
        if isinstance(obj, dict):
            return {k: self._convert_sdk_object(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [self._convert_sdk_object(item) for item in obj]
        if hasattr(obj, '__dict__'):
            # None attributes are unset optional fields; omit them.
            return {
                self._to_camel_case(key): self._convert_sdk_object(value)
                for key, value in obj.__dict__.items()
                if not key.startswith('_') and value is not None
            }
        return obj

    def _build_url(self, model: str, endpoint: str) -> str:
        """Return the model URL for ``endpoint`` (e.g. ``generateContent``)."""
        return (
            f"{self.base_url}/projects/{self.project_id}/locations/global/"
            f"publishers/google/models/{model}:{endpoint}?key={self.api_key}"
        )

    def _build_payload(self, contents: Any, config: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """Build the REST request body used by both generate methods.

        ``system_instruction``, ``safety_settings`` and ``tools`` get dedicated
        conversions. Keys in ``_TOP_LEVEL_CONFIG_KEYS`` become top-level request
        fields. Every other config key is forwarded, name unchanged, inside
        ``generationConfig``. A None config is treated as empty.
        """
        config = config or {}
        payload: Dict[str, Any] = {"contents": self._convert_contents(contents)}
        if "system_instruction" in config:
            instruction = config["system_instruction"]
            if isinstance(instruction, dict):
                payload["systemInstruction"] = instruction
            elif isinstance(instruction, str):
                # A bare string is shorthand for a single text part.
                payload["systemInstruction"] = {"parts": [{"text": instruction}]}
            else:
                payload["systemInstruction"] = self._convert_content_item(instruction)
        if "safety_settings" in config:
            payload["safetySettings"] = self._convert_safety_settings(config["safety_settings"])
        if "tools" in config:
            payload["tools"] = self._convert_tools(config["tools"])
        for key, rest_name in self._TOP_LEVEL_CONFIG_KEYS.items():
            if key in config:
                payload[rest_name] = self._convert_sdk_object(config[key])
        generation_config = {
            key: self._convert_sdk_object(value)
            for key, value in config.items()
            if key not in self._NON_GENERATION_KEYS
        }
        if generation_config:
            payload["generationConfig"] = generation_config
        return payload

    @staticmethod
    def _extract_error_message(status: int, body: str) -> str:
        """Pull a readable message out of a non-200 response body.

        Handles ``{"error": {"message": ...}}``, the array form
        ``[{"error": {...}}]`` and non-JSON bodies (e.g. an HTML error page).
        Falls back to ``"HTTP <status>"``.
        """
        fallback = f"HTTP {status}"
        try:
            error_data = json.loads(body)
        except json.JSONDecodeError:
            return f"{fallback}: {body[:500]}" if body.strip() else fallback
        if isinstance(error_data, list) and error_data:
            error_data = error_data[0]
        if not isinstance(error_data, dict):
            return str(error_data)
        error = error_data.get("error")
        if isinstance(error, dict):
            return error.get("message", fallback)
        return str(error) if error else fallback

    async def _generate_content(self, model: str, contents: Any,
                                config: Optional[Dict[str, Any]] = None,
                                stream: bool = False) -> Any:
        """Send one generate request and return the parsed response.

        With ``stream=True`` the ``streamGenerateContent`` endpoint is used and
        its whole JSON array is returned at once (as a list).

        Raises:
            ValueError: if ``discover_project_id()`` has not been run.
            Exception: on a non-200 response (message taken from the body).
        """
        if not self.project_id:
            raise ValueError("Project ID not discovered. Call discover_project_id() first.")
        await self._ensure_session()
        endpoint = "streamGenerateContent" if stream else "generateContent"
        url = self._build_url(model, endpoint)
        payload = self._build_payload(contents, config)
        try:
            async with self.session.post(url, json=payload) as response:
                if response.status != 200:
                    body = await response.text()
                    error_msg = self._extract_error_message(response.status, body)
                    raise Exception(f"Vertex AI API error: {error_msg}")
                response_data = await response.json()
                # Convert dict to object with attributes for compatibility
                return self._dict_to_obj(response_data)
        except Exception as e:
            print(f"ERROR: Direct Vertex API call failed: {e}")
            raise

    class _AttrDict:
        """Attribute-access wrapper around one JSON object (each key becomes an attribute)."""

        def __init__(self, attrs: Dict[str, Any]):
            self.__dict__.update(attrs)

        def __repr__(self) -> str:
            fields = ", ".join(f"{key}={value!r}" for key, value in self.__dict__.items())
            return f"{type(self).__name__}({fields})"

    def _dict_to_obj(self, data: Any) -> Any:
        """Recursively wrap dicts as ``_AttrDict`` objects for attribute access.

        Lists are converted element-wise; other values are returned unchanged.
        """
        if isinstance(data, dict):
            return self._AttrDict({key: self._dict_to_obj(value) for key, value in data.items()})
        if isinstance(data, list):
            return [self._dict_to_obj(item) for item in data]
        return data

    async def _generate_content_stream(self, model: str, contents: Any,
                                       config: Optional[Dict[str, Any]] = None) -> AsyncGenerator:
        """Stream a generate request, yielding each response chunk as it arrives.

        The endpoint returns one JSON array whose elements arrive incrementally.
        Each complete element is parsed and yielded; generation stops at the
        closing ``]``.

        Raises:
            ValueError: if ``discover_project_id()`` has not been run.
            Exception: on a non-200 response (message taken from the body).
        """
        if not self.project_id:
            raise ValueError("Project ID not discovered. Call discover_project_id() first.")
        await self._ensure_session()
        url = self._build_url(model, "streamGenerateContent")
        payload = self._build_payload(contents, config)
        try:
            async with self.session.post(url, json=payload) as response:
                if response.status != 200:
                    body = await response.text()
                    error_msg = self._extract_error_message(response.status, body)
                    raise Exception(f"Vertex AI API error: {error_msg}")
                # Network chunks can split a multi-byte UTF-8 character, so
                # decode incrementally instead of decoding each chunk alone.
                decoder = codecs.getincrementaldecoder("utf-8")()
                buffer = ""
                async for chunk in response.content.iter_any():
                    buffer += decoder.decode(chunk)
                    objects, buffer, finished = self._pop_json_objects(buffer)
                    for obj_str in objects:
                        try:
                            chunk_data = json.loads(obj_str)
                        except json.JSONDecodeError as e:
                            print(f"ERROR: DirectVertexClient - Failed to parse JSON: {e}")
                            continue
                        yield self._dict_to_obj(chunk_data)
                    if finished:
                        return
        except Exception as e:
            print(f"ERROR: Direct Vertex streaming API call failed: {e}")
            raise

    @classmethod
    def _pop_json_objects(cls, buffer: str) -> tuple:
        """Split complete top-level objects off a partially received JSON array.

        Skips whitespace, ``[`` and ``,`` separators. Returns
        ``(objects, remainder, finished)``:

        - ``objects`` is the raw text of each complete ``{...}`` element.
        - ``remainder`` is the unconsumed text (an incomplete object, or an
          unexpected token, which is left in place).
        - ``finished`` is True once the closing ``]`` has been consumed.
        """
        objects = []
        while True:
            buffer = buffer.lstrip()
            if not buffer:
                return objects, buffer, False
            head = buffer[0]
            if head in "[,":
                buffer = buffer[1:]
                continue
            if head == "]":
                return objects, buffer[1:], True
            if head != "{":
                return objects, buffer, False
            end = cls._find_object_end(buffer)
            if end < 0:
                # Incomplete object: wait for more data.
                return objects, buffer, False
            objects.append(buffer[:end + 1])
            buffer = buffer[end + 1:]

    @staticmethod
    def _find_object_end(text: str) -> int:
        """Return the index of the ``}`` closing the object that opens at ``text[0]``, or -1.

        Braces inside string literals (including escaped quotes) are ignored.
        """
        depth = 0
        in_string = False
        escaped = False
        for index, char in enumerate(text):
            if in_string:
                if escaped:
                    escaped = False
                elif char == '\\':
                    escaped = True
                elif char == '"':
                    in_string = False
            elif char == '"':
                in_string = True
            elif char == '{':
                depth += 1
            elif char == '}':
                depth -= 1
                if depth == 0:
                    return index
        return -1