Spaces:
Running
Running
from typing import Dict, Any | |
from .openai_api import OpenAIAPI | |
from .anthropic_api import AnthropicAPI | |
from .grok_api import GrokAPI | |
from .base_api import BaseAPI | |
class APIFactory:
    """Factory that builds the right API client for a given model name.

    The model name is mapped to a provider via ``MODEL_PROVIDERS``, and the
    provider is mapped to its client class via ``PROVIDER_APIS``.
    """

    # Model name -> provider key.
    MODEL_PROVIDERS = {
        # OpenAI models
        'gpt-4o': 'openai',
        'gpt-4-turbo': 'openai',
        'gpt-3.5-turbo': 'openai',
        # Anthropic models
        'claude-3-5-sonnet-20241022': 'anthropic',
        'claude-3-opus-20240229': 'anthropic',
        'claude-3-haiku-20240307': 'anthropic',
        # Grok models
        'grok-4-0709': 'grok',
        'grok-beta': 'grok',
        'grok-2-latest': 'grok',
        'grok-vision-beta': 'grok',
    }

    # Provider key -> API client class.
    PROVIDER_APIS = {
        'openai': OpenAIAPI,
        'anthropic': AnthropicAPI,
        'grok': GrokAPI,
    }

    @classmethod
    def create_api(cls, model_name: str, config: Dict[str, Any]) -> BaseAPI:
        """Create an API instance for the given model.

        Args:
            model_name: A key of ``MODEL_PROVIDERS``.
            config: Configuration mapping. ``config['models'][<provider>]``
                must hold a non-empty ``'api_key'`` (and, for Grok, may hold
                ``'base_url'``). ``config['evaluation']`` is optional and may
                supply ``'rate_limit_delay'``, ``'max_retries'`` and
                ``'timeout'``; each falls back to a default when absent.

        Returns:
            The provider's API class instantiated as
            ``api_class(api_key, model_name, **kwargs)``.

        Raises:
            ValueError: If the model is unknown, the provider has no
                configuration or API key, or no API class is registered
                for the provider.
        """
        # Determine provider
        provider = cls.MODEL_PROVIDERS.get(model_name)
        if not provider:
            raise ValueError(f"Unknown model: {model_name}")

        # Get provider config. A missing/empty 'models' section is reported
        # with the same descriptive error as a missing provider entry rather
        # than leaking a bare KeyError.
        provider_config = (config.get('models') or {}).get(provider)
        if not provider_config:
            raise ValueError(f"No configuration found for provider: {provider}")

        # Get API key
        api_key = provider_config.get('api_key')
        if not api_key:
            raise ValueError(f"No API key found for provider: {provider}")

        # Get API class
        api_class = cls.PROVIDER_APIS.get(provider)
        if not api_class:
            raise ValueError(f"No API implementation for provider: {provider}")

        # Every evaluation setting has a default, so the whole section is
        # treated as optional.
        evaluation = config.get('evaluation') or {}
        kwargs = {
            'rate_limit_delay': evaluation.get('rate_limit_delay', 1.0),
            'max_retries': evaluation.get('max_retries', 3),
            'timeout': evaluation.get('timeout', 30),
        }

        # Add provider-specific config
        if provider == 'grok':
            kwargs['base_url'] = provider_config.get('base_url', 'https://api.x.ai/v1')

        return api_class(api_key, model_name, **kwargs)