|
import logging
import time
from typing import Optional, List

from core.providers.base import LLMProvider
from core.providers.ollama import OllamaProvider
from core.providers.huggingface import HuggingFaceProvider
from core.providers.openai import OpenAIProvider
from utils.config import config
|
|
|
# Module-scoped logger, named after this module so records can be filtered
# per module by the application's logging configuration.
logger = logging.getLogger(__name__)
|
|
|
class ProviderNotAvailableError(Exception):
    """Signal that no LLM provider could be selected for use."""
|
|
|
class LLMFactory:
    """Singleton factory that creates LLM providers and serves them with fallback.

    Providers are instantiated once, in the fixed order ollama ->
    huggingface -> openai, and only when their configuration value is set.
    Every registered provider gets a small circuit breaker: once
    ``_FAILURE_THRESHOLD`` failures have been recorded it is skipped until
    ``_RESET_TIMEOUT_SECONDS`` have elapsed since its most recent failure.
    """

    # Number of recorded failures at which a provider's breaker trips.
    _FAILURE_THRESHOLD = 3
    # Seconds a tripped breaker stays open before the provider is retried.
    _RESET_TIMEOUT_SECONDS = 60.0

    _instance = None
    # Declared on the class, so the registry is shared state; __new__ only
    # ever creates one instance, so there is a single writer in practice.
    _providers = {}

    def __new__(cls):
        """Return the single shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # Flag read by __init__, which runs on every LLMFactory() call.
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        """Set up provider state the first time the singleton is constructed."""
        if self._initialized:
            return

        self._initialized = True
        self._provider_chain = []    # provider names, in fallback order
        self._circuit_breakers = {}  # provider name -> breaker state dict
        self._initialize_providers()

    def _initialize_providers(self) -> None:
        """Instantiate every enabled provider and register it for fallback.

        A provider is enabled when its configuration value is truthy.
        Construction errors are logged and the provider is left out, so one
        broken backend does not prevent the others from being registered.
        """
        provider_configs = [
            {
                'name': 'ollama',
                'class': OllamaProvider,
                'enabled': bool(config.ollama_host),
                'model': config.local_model_name,
            },
            {
                'name': 'huggingface',
                'class': HuggingFaceProvider,
                'enabled': bool(config.hf_token),
                'model': "meta-llama/Llama-2-7b-chat-hf",
            },
            {
                'name': 'openai',
                'class': OpenAIProvider,
                'enabled': bool(config.openai_api_key),
                'model': "gpt-3.5-turbo",
            },
        ]

        for provider_config in provider_configs:
            if not provider_config['enabled']:
                continue

            name = provider_config['name']
            try:
                provider = provider_config['class'](
                    model_name=provider_config['model']
                )
            except Exception as e:
                # Deliberately broad: any constructor failure just means
                # this provider is unavailable.
                logger.warning(f"Failed to initialize {name} provider: {e}")
                continue

            self._providers[name] = provider
            self._provider_chain.append(name)
            self._circuit_breakers[name] = {
                'failures': 0,
                'last_failure': None,  # time.monotonic() of the latest failure
                'tripped': False,
            }
            logger.info(f"Initialized {name} provider")

    def get_provider(self, preferred_provider: Optional[str] = None) -> LLMProvider:
        """Get an LLM provider based on preference and availability.

        The preferred provider is tried first when it is registered; the
        remaining providers are then tried in registration order. A provider
        is returned only if its circuit breaker allows it and its
        ``validate_model()`` call returns a truthy value.

        Args:
            preferred_provider: Preferred provider name (ollama, huggingface,
                openai). Unknown or empty names are ignored.

        Returns:
            LLMProvider instance.

        Raises:
            ProviderNotAvailableError: When no providers are available.
        """
        if preferred_provider and preferred_provider in self._providers:
            if self._is_provider_usable(preferred_provider):
                logger.info(f"Using preferred provider: {preferred_provider}")
                return self._providers[preferred_provider]

        for provider_name in self._provider_chain:
            if provider_name == preferred_provider:
                # Already tried above; validating again would repeat the
                # call and could count the same failure twice.
                continue
            if self._is_provider_usable(provider_name):
                logger.info(f"Using fallback provider: {provider_name}")
                return self._providers[provider_name]

        raise ProviderNotAvailableError("No LLM providers are available")

    def get_all_providers(self) -> List[LLMProvider]:
        """Return a new list holding every successfully initialized provider."""
        return list(self._providers.values())

    def _is_provider_usable(self, provider_name: str) -> bool:
        """Return True if the breaker allows the provider and its model validates.

        An exception raised by ``validate_model()`` is logged, recorded as a
        failure for the circuit breaker, and reported as "not usable" so the
        caller can move on to the next provider.
        """
        if not self._is_provider_available(provider_name):
            return False

        try:
            return bool(self._providers[provider_name].validate_model())
        except Exception as e:
            logger.warning(f"Provider {provider_name} model validation failed: {e}")
            self._record_provider_failure(provider_name)
            return False

    def _is_provider_available(self, provider_name: str) -> bool:
        """Check if a provider is available (not tripped by circuit breaker).

        A tripped breaker is closed again, with its failure count cleared,
        once ``_RESET_TIMEOUT_SECONDS`` have passed since the last recorded
        failure. Providers without a breaker entry are never available.
        """
        breaker = self._circuit_breakers.get(provider_name)
        if breaker is None:
            return False

        if not breaker['tripped']:
            return True

        last_failure = breaker['last_failure']
        if (
            last_failure is not None
            and time.monotonic() - last_failure >= self._RESET_TIMEOUT_SECONDS
        ):
            # Cool-down elapsed: give the provider a fresh start.
            breaker['tripped'] = False
            breaker['failures'] = 0
            logger.info(f"Circuit breaker reset for provider: {provider_name}")
            return True

        return False

    def _record_provider_failure(self, provider_name: str) -> None:
        """Record a provider failure for circuit breaker logic.

        Increments the failure count, stamps the failure time, and trips the
        breaker once the count reaches ``_FAILURE_THRESHOLD``. Unknown
        provider names are ignored.
        """
        breaker = self._circuit_breakers.get(provider_name)
        if breaker is None:
            return

        breaker['failures'] += 1
        # Monotonic clock: the cool-down must not be affected by wall-clock
        # adjustments.
        breaker['last_failure'] = time.monotonic()

        if breaker['failures'] >= self._FAILURE_THRESHOLD:
            breaker['tripped'] = True
            logger.warning(f"Circuit breaker tripped for provider: {provider_name}")
|
|
|
|
|
# Shared factory created at import time; because LLMFactory is a singleton,
# any later LLMFactory() call returns this same object.
llm_factory = LLMFactory()
|
|