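"""LLM provider factory.

Builds a prioritized chain of providers (Ollama, Hugging Face, OpenAI) from the
application configuration and hands out the first one that is configured,
passes model validation, and is not blocked by its circuit breaker.
"""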
import logging
import time
from typing import Optional, List
from core.providers.base import LLMProvider
from core.providers.ollama import OllamaProvider
from core.providers.huggingface import HuggingFaceProvider
from core.providers.openai import OpenAIProvider
from utils.config import config

logger = logging.getLogger(__name__)

class ProviderNotAvailableError(Exception):
    """Raised when no provider is available"""
    pass

class LLMFactory:
    """Factory for creating LLM providers with fallback support"""
    
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return

        self._initialized = True
        self._providers = {}
        self._provider_chain = []
        self._circuit_breakers = {}
        self._initialize_providers()
    
    def _initialize_providers(self):
        """Initialize all available providers based on configuration"""
        # Define provider priority order
        provider_configs = [
            {
                'name': 'ollama',
                'class': OllamaProvider,
                'enabled': bool(config.ollama_host),
                'model': config.local_model_name
            },
            {
                'name': 'huggingface',
                'class': HuggingFaceProvider,
                'enabled': bool(config.hf_token),
                'model': "meta-llama/Llama-2-7b-chat-hf"  # Default HF model
            },
            {
                'name': 'openai',
                'class': OpenAIProvider,
                'enabled': bool(config.openai_api_key),
                'model': "gpt-3.5-turbo"  # Default OpenAI model
            }
        ]
        
        # Initialize providers in priority order
        for provider_config in provider_configs:
            if provider_config['enabled']:
                try:
                    provider = provider_config['class'](
                        model_name=provider_config['model']
                    )
                    self._providers[provider_config['name']] = provider
                    self._provider_chain.append(provider_config['name'])
                    self._circuit_breakers[provider_config['name']] = {
                        'failures': 0,
                        'last_failure': None,
                        'tripped': False
                    }
                    logger.info(f"Initialized {provider_config['name']} provider")
                except Exception as e:
                    logger.warning(f"Failed to initialize {provider_config['name']} provider: {e}")
    
    def get_provider(self, preferred_provider: Optional[str] = None) -> LLMProvider:
        """
        Get an LLM provider based on preference and availability
        
        Args:
            preferred_provider: Preferred provider name (ollama, huggingface, openai)
            
        Returns:
            LLMProvider instance
            
        Raises:
            ProviderNotAvailableError: When no providers are available
        """
        # Check the preferred provider first
        if preferred_provider and preferred_provider in self._providers:
            provider = self._providers[preferred_provider]
            try:
                if self._is_provider_available(preferred_provider) and provider.validate_model():
                    logger.info(f"Using preferred provider: {preferred_provider}")
                    return provider
            except Exception as e:
                logger.warning(f"Preferred provider {preferred_provider} validation failed: {e}")
                self._record_provider_failure(preferred_provider)

        # Fall back through the provider chain in priority order
        for provider_name in self._provider_chain:
            if self._is_provider_available(provider_name):
                provider = self._providers[provider_name]
                try:
                    if provider.validate_model():
                        logger.info(f"Using fallback provider: {provider_name}")
                        return provider
                except Exception as e:
                    logger.warning(f"Provider {provider_name} model validation failed: {e}")
                    self._record_provider_failure(provider_name)

        raise ProviderNotAvailableError("No LLM providers are available")
    
    def get_all_providers(self) -> List[LLMProvider]:
        """Get all initialized providers"""
        return list(self._providers.values())
    
    def _is_provider_available(self, provider_name: str) -> bool:
        """Check if a provider is available (not tripped by its circuit breaker)"""
        if provider_name not in self._circuit_breakers:
            return False

        breaker = self._circuit_breakers[provider_name]
        if not breaker['tripped']:
            return True

        # Allow the breaker to reset after a cooldown so the provider gets
        # another chance; the 60-second window is an illustrative default.
        if breaker['last_failure'] is not None and time.time() - breaker['last_failure'] > 60:
            breaker['failures'] = 0
            breaker['tripped'] = False
            logger.info(f"Circuit breaker reset for provider: {provider_name}")
            return True

        return False
    
    def _record_provider_failure(self, provider_name: str):
        """Record a provider failure for circuit breaker logic"""
        if provider_name in self._circuit_breakers:
            breaker = self._circuit_breakers[provider_name]
            breaker['failures'] += 1
            breaker['last_failure'] = time.time()
            # Trip the circuit breaker after 3 consecutive failures
            if breaker['failures'] >= 3:
                breaker['tripped'] = True
                logger.warning(f"Circuit breaker tripped for provider: {provider_name}")

# Global factory instance
llm_factory = LLMFactory()
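
# Minimal usage sketch (illustrative, not part of the factory itself): shows
# preferred-provider selection with fallback through the chain. Only names
# defined above (llm_factory, get_provider, ProviderNotAvailableError) are used.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    try:
        # Ask for Ollama first; the factory falls back to huggingface/openai
        # if Ollama is not configured, fails validation, or is tripped.
        provider = llm_factory.get_provider(preferred_provider="ollama")
        print(f"Selected provider: {provider.__class__.__name__}")
    except ProviderNotAvailableError as exc:
        print(f"No LLM provider available: {exc}")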