import logging
import re
import time
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union

from utils.config import config

logger = logging.getLogger(__name__)

_T = TypeVar("_T")


class CircuitBreakerOpenError(Exception):
    """Raised by ``LLMProvider._retry_with_backoff`` while the circuit breaker is open.

    Subclasses ``Exception`` so existing ``except Exception`` handlers still catch it.
    """


class LLMProvider(ABC):
    """Abstract base class for all LLM providers.

    Concrete providers implement :meth:`generate`, :meth:`stream_generate` and
    :meth:`validate_model`. This base class supplies shared retry with
    exponential back-off, a simple circuit breaker, and heuristic
    (message-based) error classification.
    """

    # The circuit breaker opens once MORE than this many failed attempts have
    # accumulated since the last success...
    _CIRCUIT_FAILURE_THRESHOLD = 5
    # ...and stays open until this many seconds have elapsed since the most
    # recent failed attempt.
    _CIRCUIT_COOLDOWN_SECONDS = 60.0
    # Upper bound on the exponential back-off delay between attempts.
    _MAX_BACKOFF_SECONDS = 10.0

    # Ordered classification rules: (category, lower-case substrings, status-code
    # regex or None). The first category whose substring or status code appears in
    # the lower-cased error message wins. Status codes must not be part of a longer
    # number, so e.g. "5000 tokens" is not mistaken for HTTP 500.
    _ERROR_RULES = (
        ('network', ('connection', 'timeout', 'resolve', 'unreachable'), None),
        ('authentication', ('auth', 'unauthorized', 'invalid token'),
         re.compile(r'(?<!\d)(?:401|403)(?!\d)')),
        ('rate_limit', ('rate limit', 'too many requests', 'quota exceeded'),
         re.compile(r'(?<!\d)429(?!\d)')),
        ('server', ('server error',),
         re.compile(r'(?<!\d)(?:500|502|503)(?!\d)')),
    )
    # Categories that _is_recoverable_error treats as transient.
    _RECOVERABLE_ERROR_TYPES = frozenset({'network', 'rate_limit', 'server'})

    def __init__(self, model_name: str, timeout: int = 30, max_retries: int = 3):
        """Initialise shared provider state.

        Args:
            model_name: Identifier of the model this provider targets.
            timeout: Stored as ``self.timeout``; not read by this base class
                (presumably a per-request timeout in seconds used by
                subclasses -- confirm).
            max_retries: Total number of attempts ``_retry_with_backoff`` makes
                per call; values below 1 are treated as a single attempt.
        """
        self.model_name = model_name
        self.timeout = timeout
        self.max_retries = max_retries
        # Not read by this base class; presumably toggled by subclasses/callers.
        self.is_available = True
        # Failed attempts since the last successful call (accumulates across
        # calls); drives the circuit breaker.
        self.failure_count = 0
        # time.time() of the most recent failed attempt; None after a success.
        self.last_failure_time: Optional[float] = None

    @abstractmethod
    def generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
        """Generate a complete response synchronously.

        Args:
            prompt: The new user prompt.
            conversation_history: Prior turns; the dict schema is defined by
                subclasses/callers (not visible here).

        Returns:
            The generated text, or None (meaning is subclass-defined).
        """

    @abstractmethod
    def stream_generate(self, prompt: str,
                        conversation_history: List[Dict]) -> Optional[Union[str, List[str]]]:
        """Generate a response with streaming support.

        Returns:
            Either the full text or a list of streamed chunks, or None; which
            form is returned is subclass-defined.
        """

    @abstractmethod
    def validate_model(self) -> bool:
        """Return True if the configured model is available to this provider."""

    def _circuit_breaker_open(self) -> bool:
        """Return True while too many recent failures should short-circuit calls."""
        if self.failure_count <= self._CIRCUIT_FAILURE_THRESHOLD or self.last_failure_time is None:
            return False
        return time.time() - self.last_failure_time < self._CIRCUIT_COOLDOWN_SECONDS

    def _retry_with_backoff(self, func: Callable[..., _T], *args: Any, **kwargs: Any) -> _T:
        """Call ``func(*args, **kwargs)``, retrying failures with exponential back-off.

        Makes up to ``max(1, self.max_retries)`` attempts, sleeping 1s, 2s, 4s, ...
        (capped at ``_MAX_BACKOFF_SECONDS``) between them. Every failed attempt
        increments ``failure_count``; a success resets the breaker state.

        NOTE(review): every exception is retried, including ones that
        ``_is_recoverable_error`` classifies as non-recoverable (e.g. auth).

        Returns:
            Whatever ``func`` returns on the first successful attempt.

        Raises:
            CircuitBreakerOpenError: if the breaker is open before an attempt.
            Exception: the exception from the final failed attempt, re-raised.
        """
        attempts = max(1, self.max_retries)
        for attempt in range(attempts):
            # Checked outside the ``try`` so a tripped breaker fails fast instead
            # of being caught below, counted as a new failure (which would keep
            # the breaker open forever) and retried after sleeping.
            if self._circuit_breaker_open():
                raise CircuitBreakerOpenError("Circuit breaker tripped - too many recent failures")
            try:
                result = func(*args, **kwargs)
            except Exception as e:
                self.failure_count += 1
                self.last_failure_time = time.time()
                if attempt == attempts - 1:
                    logger.error("All %d attempts failed. Last error: %s", attempts, e)
                    raise
                sleep_time = min(2.0 ** attempt, self._MAX_BACKOFF_SECONDS)
                logger.warning("Attempt %d failed: %s. Retrying in %ss...",
                               attempt + 1, e, sleep_time)
                time.sleep(sleep_time)
            else:
                self.failure_count = 0
                self.last_failure_time = None
                return result
        raise AssertionError("unreachable: every attempt returns or raises")

    def _classify_error(self, error: Exception) -> str:
        """Classify an error by inspecting its lower-cased message text.

        Returns:
            One of 'network', 'authentication', 'rate_limit', 'server' or
            'other', using the first matching rule in ``_ERROR_RULES``.
        """
        error_str = str(error).lower()
        for category, phrases, status_codes in self._ERROR_RULES:
            if any(phrase in error_str for phrase in phrases):
                return category
            if status_codes is not None and status_codes.search(error_str):
                return category
        return 'other'

    def _is_recoverable_error(self, error: Exception) -> bool:
        """Return True if the error's category is a likely-transient one."""
        return self._classify_error(error) in self._RECOVERABLE_ERROR_TYPES