import logging
import time
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union

from utils.config import config

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

class LLMProvider(ABC):
    """Abstract base class for all LLM providers.

    Concrete providers implement :meth:`generate`, :meth:`stream_generate`
    and :meth:`validate_model`. The base class contributes a retry helper
    with exponential backoff and a simple circuit breaker, plus message-based
    error classification.

    Attributes:
        model_name: Identifier of the model this provider talks to.
        timeout: Timeout value stored for subclasses; the base class itself
            never reads it.
        max_retries: Maximum number of attempts made by
            :meth:`_retry_with_backoff` (values below 1 are treated as 1).
        is_available: Availability flag, initialised to ``True``; the base
            class never changes it.
        failure_count: Number of failures recorded since the last success.
        last_failure_time: ``time.time()`` stamp of the most recent failure,
            or ``None`` when no failure is outstanding.
    """

    # The breaker opens once MORE than this many failures have been recorded
    # without an intervening success...
    CIRCUIT_BREAKER_THRESHOLD = 5
    # ...and stays open until this many seconds have passed since the last one.
    CIRCUIT_BREAKER_COOLDOWN = 60.0

    # Backoff before retry k (0-based) is BACKOFF_BASE_SECONDS * 2**k, capped.
    BACKOFF_BASE_SECONDS = 1.0
    MAX_BACKOFF_SECONDS = 10.0

    # Ordered (category, substrings) pairs; the first category with a matching
    # substring wins, so the order is part of the behaviour.
    _ERROR_TERMS = (
        ('network', ('connection', 'timeout', 'resolve', 'unreachable')),
        ('authentication', ('auth', 'unauthorized', 'invalid token', '401', '403')),
        ('rate_limit', ('rate limit', 'too many requests', 'quota exceeded', '429')),
        ('server', ('500', '502', '503', 'server error')),
    )

    # Categories that _is_recoverable_error reports as worth retrying.
    _RECOVERABLE_CATEGORIES = frozenset(('network', 'rate_limit', 'server'))

    def __init__(self, model_name: str, timeout: int = 30, max_retries: int = 3):
        """Store the provider configuration and reset the failure tracking."""
        self.model_name = model_name
        self.timeout = timeout
        self.max_retries = max_retries
        self.is_available = True
        self.failure_count = 0
        self.last_failure_time: Optional[float] = None

    @abstractmethod
    def generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
        """Generate a response synchronously."""

    @abstractmethod
    def stream_generate(
        self, prompt: str, conversation_history: List[Dict]
    ) -> Optional[Union[str, List[str]]]:
        """Generate a response with streaming support."""

    @abstractmethod
    def validate_model(self) -> bool:
        """Validate if the model is available."""

    def _check_circuit_breaker(self, cause: Optional[BaseException] = None) -> None:
        """Raise if the circuit breaker is currently open.

        The breaker is open while more than ``CIRCUIT_BREAKER_THRESHOLD``
        failures are recorded and the latest one is younger than
        ``CIRCUIT_BREAKER_COOLDOWN`` seconds.

        Args:
            cause: Optional exception to chain as the ``__cause__`` of the
                breaker error (the failure that opened the breaker).

        Raises:
            RuntimeError: If the breaker is open.
        """
        if self.failure_count <= self.CIRCUIT_BREAKER_THRESHOLD:
            return
        if self.last_failure_time is None:
            return
        if time.time() - self.last_failure_time < self.CIRCUIT_BREAKER_COOLDOWN:
            raise RuntimeError(
                "Circuit breaker tripped - too many recent failures"
            ) from cause

    def _retry_with_backoff(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        """Call ``func`` with retries, exponential backoff and a circuit breaker.

        ``func`` is always attempted at least once, even when ``max_retries``
        is zero or negative. Between failed attempts the calling thread sleeps
        for ``BACKOFF_BASE_SECONDS * 2**attempt`` seconds, capped at
        ``MAX_BACKOFF_SECONDS``. A success resets the failure tracking.

        Args:
            func: Callable to invoke.
            *args: Positional arguments forwarded to ``func``.
            **kwargs: Keyword arguments forwarded to ``func``.

        Returns:
            Whatever ``func`` returns on its first successful attempt.

        Raises:
            RuntimeError: If the circuit breaker is open; raised immediately,
                without calling ``func``, sleeping, or recording a failure.
            Exception: The exception from the final attempt when every
                attempt fails.
        """
        attempts = max(1, self.max_retries)
        last_exception: Optional[Exception] = None

        for attempt in range(attempts):
            # Deliberately outside the try block: a tripped breaker must fail
            # fast. If it were caught below it would be counted as a fresh
            # failure, push last_failure_time forward (so the cool-down never
            # expires) and trigger pointless backoff sleeps.
            self._check_circuit_breaker(last_exception)

            try:
                result = func(*args, **kwargs)
            except Exception as exc:
                last_exception = exc
                self.failure_count += 1
                self.last_failure_time = time.time()

                if attempt < attempts - 1:
                    sleep_time = min(
                        (2 ** attempt) * self.BACKOFF_BASE_SECONDS,
                        self.MAX_BACKOFF_SECONDS,
                    )
                    logger.warning(
                        "Attempt %d failed: %s. Retrying in %ss...",
                        attempt + 1, exc, sleep_time,
                    )
                    time.sleep(sleep_time)
                else:
                    logger.error(
                        "All %d attempts failed. Last error: %s", attempts, exc
                    )
                    # Bare raise keeps the original traceback intact.
                    raise
            else:
                # Any success clears the failure history and closes the breaker.
                self.failure_count = 0
                self.last_failure_time = None
                return result

    def _classify_error(self, error: Exception) -> str:
        """Classify an error for better handling.

        The lower-cased message is scanned for known substrings, checking the
        categories in the order network, authentication, rate limit, server.
        If nothing matches, the exception type is used as a fallback so that
        built-in connection/timeout errors with an empty message are still
        recognised.

        Args:
            error: The exception to classify.

        Returns:
            One of ``'network'``, ``'authentication'``, ``'rate_limit'``,
            ``'server'`` or ``'other'``.
        """
        message = str(error).lower()

        for category, terms in self._ERROR_TERMS:
            if any(term in message for term in terms):
                return category

        # Fallback only: the message matched nothing, so this can never
        # override a category chosen from the message text.
        if isinstance(error, (ConnectionError, TimeoutError)):
            return 'network'

        return 'other'

    def _is_recoverable_error(self, error: Exception) -> bool:
        """Return True if the error is classified as network, rate limit or server."""
        return self._classify_error(error) in self._RECOVERABLE_CATEGORIES