# File size: 2,398 Bytes
# 8474f02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from typing import Dict, Any
from .openai_api import OpenAIAPI
from .anthropic_api import AnthropicAPI
from .grok_api import GrokAPI
from .base_api import BaseAPI

class APIFactory:
    """Factory class to create API instances based on model name.

    ``MODEL_PROVIDERS`` maps each supported model name to a provider key, and
    ``PROVIDER_APIS`` maps each provider key to the API client class used for
    it. ``create_api`` joins the two tables and wires in settings taken from a
    configuration mapping.
    """

    # Model name -> provider key. Only models listed here can be created.
    MODEL_PROVIDERS = {
        # OpenAI models
        'gpt-4o': 'openai',
        'gpt-4-turbo': 'openai',
        'gpt-3.5-turbo': 'openai',

        # Anthropic models
        'claude-3-5-sonnet-20241022': 'anthropic',
        'claude-3-opus-20240229': 'anthropic',
        'claude-3-haiku-20240307': 'anthropic',

        # Grok models
        'grok-4-0709': 'grok',
        'grok-beta': 'grok',
        'grok-2-latest': 'grok',
        'grok-vision-beta': 'grok',
    }

    # Provider key -> API client class instantiated by create_api.
    PROVIDER_APIS = {
        'openai': OpenAIAPI,
        'anthropic': AnthropicAPI,
        'grok': GrokAPI,
    }

    @classmethod
    def create_api(cls, model_name: str, config: Dict[str, Any]) -> BaseAPI:
        """Create an API instance for the given model.

        Args:
            model_name: A key of ``MODEL_PROVIDERS``.
            config: Configuration mapping.
                - ``config['models'][provider]`` must contain a truthy
                  ``'api_key'``. For the ``'grok'`` provider, an optional
                  ``'base_url'`` is also read (default
                  ``'https://api.x.ai/v1'``).
                - ``config['evaluation']`` may supply ``'rate_limit_delay'``
                  (default 1.0), ``'max_retries'`` (default 3) and
                  ``'timeout'`` (default 30). The section itself is optional.

        Returns:
            ``api_class(api_key, model_name, **kwargs)``, where ``api_class``
            is the ``PROVIDER_APIS`` entry for the model's provider.

        Raises:
            ValueError: If the model is unknown, if the provider has no
                configuration or no API key, or if no API class is
                registered for the provider.
        """
        # Determine provider
        provider = cls.MODEL_PROVIDERS.get(model_name)
        if not provider:
            raise ValueError(f"Unknown model: {model_name}")

        # Get provider config. A missing/empty 'models' section is reported
        # through the same ValueError as a missing provider entry, instead
        # of escaping as a bare KeyError.
        models_config = config.get('models') or {}
        provider_config = models_config.get(provider)
        if not provider_config:
            raise ValueError(f"No configuration found for provider: {provider}")

        # Get API key
        api_key = provider_config.get('api_key')
        if not api_key:
            raise ValueError(f"No API key found for provider: {provider}")

        # Get API class (defensive: MODEL_PROVIDERS and PROVIDER_APIS are
        # separate tables that could drift apart)
        api_class = cls.PROVIDER_APIS.get(provider)
        if not api_class:
            raise ValueError(f"No API implementation for provider: {provider}")

        # Every evaluation setting has a default, so the whole section is
        # optional; `or {}` also covers an explicit null value.
        eval_config = config.get('evaluation') or {}
        kwargs: Dict[str, Any] = {
            'rate_limit_delay': eval_config.get('rate_limit_delay', 1.0),
            'max_retries': eval_config.get('max_retries', 3),
            'timeout': eval_config.get('timeout', 30),
        }

        # Add provider-specific config
        if provider == 'grok':
            kwargs['base_url'] = provider_config.get('base_url', 'https://api.x.ai/v1')

        return api_class(api_key, model_name, **kwargs)