File size: 5,524 Bytes
5b5f50c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import logging
import time
from typing import List, Optional

from core.providers.base import LLMProvider
from core.providers.ollama import OllamaProvider
from core.providers.huggingface import HuggingFaceProvider
from core.providers.openai import OpenAIProvider
from utils.config import config
# Module-level logger, named after this module, used by the factory below.
logger = logging.getLogger(__name__)
class ProviderNotAvailableError(Exception):
    """Signal that no configured LLM provider could be selected for use."""
class LLMFactory:
    """Factory for creating LLM providers with fallback support.

    The class is a process-wide singleton: every ``LLMFactory()`` call returns
    the same instance, and provider initialization runs only on the first call.

    Each initialized provider is guarded by a simple circuit breaker. After
    ``_failure_threshold`` consecutive ``validate_model()`` errors the provider
    is skipped. Once ``_reset_timeout`` seconds have passed since its last
    failure it gets one more attempt (half-open); a successful validation
    closes the breaker and clears the failure count, while another failure
    re-trips it immediately.
    """

    _instance = None
    # Consecutive validation errors after which a provider's breaker trips.
    _failure_threshold = 3
    # Seconds a tripped breaker stays open before the provider is retried.
    _reset_timeout = 60.0

    def __new__(cls):
        # Singleton: create the instance once; __init__ checks _initialized
        # so that later LLMFactory() calls do not re-run initialization.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return
        self._initialized = True
        # All state is per-instance. A class-level dict would survive a reset
        # of _instance and leak stale providers into a new factory.
        self._providers = {}
        self._provider_chain = []
        self._circuit_breakers = {}
        self._initialize_providers()

    def _initialize_providers(self):
        """Instantiate every enabled provider, in priority order.

        A provider is enabled when its corresponding ``config`` setting is
        truthy. Construction errors are logged and that provider is skipped,
        so one misconfigured backend does not prevent the others from loading.
        """
        # Priority order: earlier entries are preferred during fallback.
        provider_configs = [
            {
                'name': 'ollama',
                'class': OllamaProvider,
                'enabled': bool(config.ollama_host),
                'model': config.local_model_name,
            },
            {
                'name': 'huggingface',
                'class': HuggingFaceProvider,
                'enabled': bool(config.hf_token),
                'model': "meta-llama/Llama-2-7b-chat-hf",  # Default HF model
            },
            {
                'name': 'openai',
                'class': OpenAIProvider,
                'enabled': bool(config.openai_api_key),
                'model': "gpt-3.5-turbo",  # Default OpenAI model
            },
        ]
        for provider_config in provider_configs:
            if not provider_config['enabled']:
                continue
            name = provider_config['name']
            try:
                provider = provider_config['class'](
                    model_name=provider_config['model']
                )
            except Exception as e:
                # Deliberate best-effort: log and move on to the next backend.
                logger.warning("Failed to initialize %s provider: %s", name, e)
                continue
            self._providers[name] = provider
            self._provider_chain.append(name)
            self._circuit_breakers[name] = {
                'failures': 0,
                'last_failure': None,  # time.monotonic() of the latest failure
                'tripped': False,
            }
            logger.info("Initialized %s provider", name)

    def get_provider(self, preferred_provider: Optional[str] = None) -> LLMProvider:
        """
        Get an LLM provider based on preference and availability.

        The preferred provider, if given and initialized, is tried first. Then
        the remaining providers are tried in priority order. A provider is used
        when its circuit breaker is closed and ``validate_model()`` returns a
        truthy value. Exceptions from ``validate_model()`` are logged and count
        toward that provider's circuit breaker; they never propagate.

        Args:
            preferred_provider: Preferred provider name (ollama, huggingface, openai)

        Returns:
            LLMProvider instance

        Raises:
            ProviderNotAvailableError: When no providers are available
        """
        already_tried = None
        if preferred_provider:
            if preferred_provider in self._providers:
                already_tried = preferred_provider
                if self._try_provider(preferred_provider):
                    logger.info("Using preferred provider: %s", preferred_provider)
                    return self._providers[preferred_provider]
            else:
                logger.warning(
                    "Preferred provider %s is not initialized; falling back",
                    preferred_provider,
                )

        for provider_name in self._provider_chain:
            # Don't validate (and possibly double-count a failure for) the
            # preferred provider a second time.
            if provider_name == already_tried:
                continue
            if self._try_provider(provider_name):
                logger.info("Using fallback provider: %s", provider_name)
                return self._providers[provider_name]

        raise ProviderNotAvailableError("No LLM providers are available")

    def _try_provider(self, provider_name: str) -> bool:
        """Return True if the provider's breaker allows it and its model validates.

        An exception from ``validate_model()`` is logged and recorded as a
        circuit-breaker failure. A successful validation resets the breaker.
        A falsy return is treated as "not usable" without counting a failure.
        """
        if not self._is_provider_available(provider_name):
            return False
        provider = self._providers[provider_name]
        try:
            is_valid = provider.validate_model()
        except Exception as e:
            logger.warning("Provider %s model validation failed: %s", provider_name, e)
            self._record_provider_failure(provider_name)
            return False
        if is_valid:
            self._record_provider_success(provider_name)
            return True
        return False

    def get_all_providers(self) -> List[LLMProvider]:
        """Get all initialized providers, in priority (initialization) order."""
        return list(self._providers.values())

    def _is_provider_available(self, provider_name: str) -> bool:
        """Check whether a provider may be tried right now.

        Returns False for unknown providers and for tripped breakers whose
        cooldown has not yet elapsed. Once ``_reset_timeout`` seconds have
        passed since the last failure, the breaker is half-opened and the
        provider is allowed another attempt.
        """
        breaker = self._circuit_breakers.get(provider_name)
        if breaker is None:
            return False
        if not breaker['tripped']:
            return True
        last_failure = breaker['last_failure']
        if last_failure is not None and time.monotonic() - last_failure >= self._reset_timeout:
            # Half-open: the failure count is still at or above the threshold,
            # so a single further failure re-trips the breaker immediately.
            breaker['tripped'] = False
            logger.info("Circuit breaker half-open for provider: %s", provider_name)
            return True
        return False

    def _record_provider_failure(self, provider_name: str):
        """Record a provider failure and trip its breaker at the threshold."""
        breaker = self._circuit_breakers.get(provider_name)
        if breaker is None:
            return
        breaker['failures'] += 1
        breaker['last_failure'] = time.monotonic()
        if breaker['failures'] >= self._failure_threshold:
            breaker['tripped'] = True
            logger.warning("Circuit breaker tripped for provider: %s", provider_name)

    def _record_provider_success(self, provider_name: str):
        """Close the provider's breaker and clear its consecutive-failure count."""
        breaker = self._circuit_breakers.get(provider_name)
        if breaker is None:
            return
        breaker['failures'] = 0
        breaker['tripped'] = False
# Global factory instance. It is created at import time, so importing this
# module runs LLMFactory.__init__, which initializes every enabled provider.
llm_factory = LLMFactory()
|