File size: 3,689 Bytes
5b5f50c 45df059 5b5f50c 45df059 5b5f50c 45df059 5b5f50c 45df059 5b5f50c 45df059 5b5f50c 45df059 5b5f50c 45df059 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import time
import logging
from abc import ABC, abstractmethod
from typing import List, Dict, Optional, Union
from utils.config import config
logger = logging.getLogger(__name__)


class LLMProvider(ABC):
    """Abstract base class for all LLM providers.

    Concrete providers implement :meth:`generate`, :meth:`stream_generate`
    and :meth:`validate_model`. The base class contributes shared plumbing:

    * :meth:`_retry_with_backoff` - calls a function, retrying failures with
      exponential backoff behind a simple consecutive-failure circuit breaker.
    * :meth:`_classify_error` / :meth:`_is_recoverable_error` - heuristic,
      mostly message-based error classification.
    """

    # Circuit breaker tuning. The breaker is open while MORE than
    # ``circuit_breaker_threshold`` consecutive failures have been recorded and
    # the most recent one happened less than ``circuit_breaker_cooldown``
    # seconds ago. Class attributes so subclasses/instances can override them.
    circuit_breaker_threshold: int = 5
    circuit_breaker_cooldown: float = 60.0

    # Sleep before retrying after failed attempt ``attempt`` (0-based) is
    # ``min(backoff_base * 2 ** attempt, backoff_cap)`` seconds.
    backoff_base: float = 1.0
    backoff_cap: float = 10.0

    def __init__(self, model_name: str, timeout: int = 30, max_retries: int = 3):
        """Store provider configuration and initialise failure tracking.

        Args:
            model_name: Identifier of the model this provider targets.
            timeout: Stored for subclasses; not read by this base class
                (presumably seconds - confirm in the concrete providers).
            max_retries: Total number of attempts made by
                :meth:`_retry_with_backoff`; values below 1 mean one attempt.
        """
        self.model_name = model_name
        self.timeout = timeout
        self.max_retries = max_retries
        # Not read or updated by this class; presumably maintained by
        # subclasses or callers.
        self.is_available = True
        # Consecutive failures since the last success; drives the breaker.
        self.failure_count = 0
        # time.time() of the most recent failure, or None after a success.
        self.last_failure_time: Optional[float] = None

    @abstractmethod
    def generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
        """Generate a complete response synchronously.

        Args:
            prompt: The prompt to respond to.
            conversation_history: Prior turns as dicts; the schema is defined
                by the concrete provider.

        Returns:
            The response text; implementations may return None.
        """
        pass

    @abstractmethod
    def stream_generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[Union[str, List[str]]]:
        """Generate a response with streaming support.

        Returns:
            A string, a list of string chunks, or None, per the declared
            return type; the exact shape is provider-specific.
        """
        pass

    @abstractmethod
    def validate_model(self) -> bool:
        """Return True if the configured model is available."""
        pass

    def _circuit_open(self) -> bool:
        """Return True while the circuit breaker should reject new attempts."""
        if self.failure_count <= self.circuit_breaker_threshold:
            return False
        if self.last_failure_time is None:
            return False
        return time.time() - self.last_failure_time < self.circuit_breaker_cooldown

    def _retry_with_backoff(self, func, *args, **kwargs):
        """Call ``func(*args, **kwargs)``, retrying failures with exponential backoff.

        Up to ``max(1, self.max_retries)`` attempts are made. A success resets
        the failure tracking; every failed attempt increments
        ``failure_count`` and stamps ``last_failure_time``.

        Returns:
            Whatever ``func`` returns on its first successful attempt.

        Raises:
            RuntimeError: If the circuit breaker is open before an attempt;
                ``func`` is not called and failure tracking is not touched.
            Exception: Whatever ``func`` raised on the final attempt.
        """
        # Always make at least one attempt: with max_retries <= 0 the old loop
        # never ran and the call silently returned None.
        attempts = max(1, self.max_retries)
        for attempt in range(attempts):
            # Deliberately OUTSIDE the try: if this were caught below it would
            # count as a failure, refresh last_failure_time and back off, so a
            # tripped breaker would never fail fast nor ever cool down.
            if self._circuit_open():
                raise RuntimeError("Circuit breaker tripped - too many recent failures")
            try:
                result = func(*args, **kwargs)
            except Exception as e:
                self.failure_count += 1
                self.last_failure_time = time.time()
                if attempt >= attempts - 1:
                    logger.error("All %d attempts failed. Last error: %s", attempts, e)
                    raise
                sleep_time = min(self.backoff_base * (2 ** attempt), self.backoff_cap)
                logger.warning("Attempt %d failed: %s. Retrying in %ss...", attempt + 1, e, sleep_time)
                time.sleep(sleep_time)
            else:
                # Reset failure tracking on success.
                self.failure_count = 0
                self.last_failure_time = None
                return result

    def _classify_error(self, error: Exception) -> str:
        """Classify an error for handling decisions.

        Matching is a case-insensitive substring search on ``str(error)``,
        checked in order: network, authentication, rate limit, server. If no
        term matches, builtin ``ConnectionError``/``TimeoutError`` instances
        (whose messages are often empty, e.g. ``str(TimeoutError()) == ''``)
        are still classified as network errors.

        Returns:
            One of ``'network'``, ``'authentication'``, ``'rate_limit'``,
            ``'server'`` or ``'other'``.
        """
        error_str = str(error).lower()

        def mentions(*terms: str) -> bool:
            return any(term in error_str for term in terms)

        # Network errors ('timed out' covers socket.timeout's message).
        if mentions('connection', 'timeout', 'timed out', 'resolve', 'unreachable'):
            return 'network'
        # Authentication errors
        if mentions('auth', 'unauthorized', 'invalid token', '401', '403'):
            return 'authentication'
        # Rate limiting
        if mentions('rate limit', 'too many requests', 'quota exceeded', '429'):
            return 'rate_limit'
        # Server errors (504 = gateway timeout)
        if mentions('500', '502', '503', '504', 'server error'):
            return 'server'
        # Type-based fallback for terse/empty builtin network exceptions.
        if isinstance(error, (ConnectionError, TimeoutError)):
            return 'network'
        return 'other'

    def _is_recoverable_error(self, error: Exception) -> bool:
        """Return True if the error class suggests retrying may succeed."""
        return self._classify_error(error) in ('network', 'rate_limit', 'server')
|