# Source: grok4-gpqa-eval / apis/api_factory.py
# Uploaded by TeddyYao ("Upload 38 files"), commit 8474f02.
from typing import Dict, Any
from .openai_api import OpenAIAPI
from .anthropic_api import AnthropicAPI
from .grok_api import GrokAPI
from .base_api import BaseAPI
class APIFactory:
    """Factory class to create API instances based on model name.

    A model name is mapped to a provider key via ``MODEL_PROVIDERS``. That
    provider's settings are read from ``config['models'][provider]``, and the
    provider's implementation class from ``PROVIDER_APIS`` is instantiated.
    """

    # Model name -> provider key. Only model names listed here are accepted
    # by create_api.
    MODEL_PROVIDERS = {
        # OpenAI models
        'gpt-4o': 'openai',
        'gpt-4-turbo': 'openai',
        'gpt-3.5-turbo': 'openai',
        # Anthropic models
        'claude-3-5-sonnet-20241022': 'anthropic',
        'claude-3-opus-20240229': 'anthropic',
        'claude-3-haiku-20240307': 'anthropic',
        # Grok models
        'grok-4-0709': 'grok',
        'grok-beta': 'grok',
        'grok-2-latest': 'grok',
        'grok-vision-beta': 'grok',
    }

    # Provider key -> API implementation class.
    PROVIDER_APIS = {
        'openai': OpenAIAPI,
        'anthropic': AnthropicAPI,
        'grok': GrokAPI,
    }

    @classmethod
    def create_api(cls, model_name: str, config: Dict[str, Any]) -> BaseAPI:
        """Create an API instance for the given model.

        Args:
            model_name: Model identifier; must be a key of ``MODEL_PROVIDERS``.
            config: Parsed configuration mapping.
                ``config['models'][provider]`` must be a non-empty mapping
                containing an ``api_key``. The optional ``config['evaluation']``
                section may set ``rate_limit_delay`` (default 1.0),
                ``max_retries`` (default 3) and ``timeout`` (default 30).
                For the ``grok`` provider, ``base_url`` may also be set in the
                provider section; it defaults to ``https://api.x.ai/v1``.

        Returns:
            The object produced by ``api_class(api_key, model_name, **kwargs)``
            for the provider's registered class.

        Raises:
            ValueError: If the model is unknown, the provider has no
                configuration or no API key, or no implementation class is
                registered for the provider.
        """
        # Determine provider
        provider = cls.MODEL_PROVIDERS.get(model_name)
        if not provider:
            raise ValueError(f"Unknown model: {model_name}")

        # Tolerate a missing or empty (None) 'models' section so callers get
        # the descriptive ValueError below instead of KeyError/AttributeError.
        models_config = config.get('models') or {}
        provider_config = models_config.get(provider)
        if not provider_config:
            raise ValueError(f"No configuration found for provider: {provider}")

        api_key = provider_config.get('api_key')
        if not api_key:
            raise ValueError(f"No API key found for provider: {provider}")

        api_class = cls.PROVIDER_APIS.get(provider)
        if not api_class:
            raise ValueError(f"No API implementation for provider: {provider}")

        # Every evaluation setting has a default, so the section is optional.
        eval_config = config.get('evaluation') or {}
        kwargs = {
            'rate_limit_delay': eval_config.get('rate_limit_delay', 1.0),
            'max_retries': eval_config.get('max_retries', 3),
            'timeout': eval_config.get('timeout', 30),
        }

        # Provider-specific settings.
        if provider == 'grok':
            # A present-but-empty base_url (e.g. a bare YAML key -> None) also
            # falls back to the default xAI endpoint.
            kwargs['base_url'] = (
                provider_config.get('base_url') or 'https://api.x.ai/v1'
            )

        return api_class(api_key, model_name, **kwargs)