Spaces:
Running
Running
import openai | |
from .base_api import BaseAPI | |
class GrokAPI(BaseAPI):
    """Grok API implementation (uses OpenAI-compatible interface).

    Wraps an ``openai.AsyncOpenAI`` client pointed at the xAI endpoint and
    exposes a single-prompt chat completion call.
    """

    # Endpoint used when the caller supplies no (or an empty) ``base_url``.
    _DEFAULT_BASE_URL = 'https://api.x.ai/v1'

    def __init__(self, api_key: str, model_name: str, **kwargs):
        """Create the async client for the Grok endpoint.

        Args:
            api_key: API key passed to the OpenAI-compatible client.
            model_name: Model identifier sent with every request.
            **kwargs: Forwarded unchanged to ``BaseAPI.__init__``. The
                optional ``base_url`` entry overrides the default endpoint.
        """
        super().__init__(api_key, model_name, **kwargs)
        # Fall back on any falsy value (None, ''), not only a missing key:
        # an explicit ``base_url=None`` would otherwise reach the client,
        # which then selects its own default endpoint instead of xAI's.
        self.base_url = kwargs.get('base_url') or self._DEFAULT_BASE_URL
        self.client = openai.AsyncOpenAI(
            api_key=api_key,
            base_url=self.base_url,
        )

    async def generate_response(self, prompt: str, **kwargs) -> str:
        """Generate response using Grok API.

        Sends ``prompt`` as a single user message and returns the content of
        the first choice in the completion.

        Args:
            prompt: Text sent as the user message.
            **kwargs: Optional ``temperature`` (default ``0.0``) and
                ``max_tokens`` (default ``2048``; ignored for
                ``grok-4-0709``).

        Returns:
            The message content of the first returned choice.

        Raises:
            Exception: Any failure while building or sending the request, or
                while reading the response, re-raised with a
                ``"Grok API error: ..."`` message. The original exception is
                attached as ``__cause__``.
        """
        try:
            # Build parameters
            params = {
                "model": self.model_name,
                "messages": [{"role": "user", "content": prompt}],
                "temperature": kwargs.get('temperature', 0.0),
                # NOTE(review): self.timeout is not set in this class;
                # presumably BaseAPI.__init__ provides it -- confirm.
                "timeout": self.timeout,
            }
            # For grok-4-0709, don't set max_tokens to allow full reasoning
            if self.model_name != 'grok-4-0709':
                params['max_tokens'] = kwargs.get('max_tokens', 2048)

            response = await self.client.chat.completions.create(**params)
            return response.choices[0].message.content
        except Exception as e:
            # Chain explicitly so the underlying error (HTTP failure, empty
            # choices list, ...) stays visible in tracebacks as the cause.
            raise Exception(f"Grok API error: {e}") from e

    def get_model_info(self) -> dict:
        """Get model information.

        Returns:
            A dict with the keys ``provider``, ``model``, ``api_version`` and
            ``base_url`` describing this client's configuration.
        """
        return {
            "provider": "Grok",
            "model": self.model_name,
            "api_version": "v1",
            "base_url": self.base_url,
        }