File size: 3,689 Bytes
5b5f50c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45df059
 
5b5f50c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45df059
5b5f50c
 
 
 
45df059
 
 
 
 
 
 
 
 
 
 
 
5b5f50c
 
45df059
 
 
5b5f50c
45df059
5b5f50c
 
 
 
 
 
 
45df059
 
5b5f50c
45df059
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import time
import logging
from abc import ABC, abstractmethod
from typing import List, Dict, Optional, Union
from utils.config import config

# Module-level logger named after this module (standard ``logging`` convention).
logger = logging.getLogger(__name__)

class LLMProvider(ABC):
    """Abstract base class for all LLM providers.

    Subclasses implement ``generate``, ``stream_generate`` and
    ``validate_model``. The base class supplies ``_retry_with_backoff``
    (exponential backoff plus a simple consecutive-failure circuit breaker)
    and heuristics for classifying exceptions as recoverable or not.
    """

    # The circuit breaker opens once ``failure_count`` is strictly greater
    # than this value (i.e. after 6 consecutive failures with the default).
    circuit_breaker_threshold = 5
    # Seconds, measured from the most recent failure, that the breaker stays open.
    circuit_breaker_cooldown = 60.0
    # Upper bound, in seconds, on a single backoff sleep between attempts.
    max_backoff = 10.0

    # Ordered (category, keywords) table used by ``_classify_error``;
    # the first category with a keyword found in the message wins.
    _ERROR_KEYWORDS = (
        ('network', ('connection', 'timeout', 'resolve', 'unreachable')),
        ('authentication', ('auth', 'unauthorized', 'invalid token', '401', '403')),
        ('rate_limit', ('rate limit', 'too many requests', 'quota exceeded', '429')),
        ('server', ('500', '502', '503', 'server error')),
    )
    # Categories that ``_is_recoverable_error`` treats as worth retrying.
    _RECOVERABLE_ERROR_TYPES = frozenset({'network', 'rate_limit', 'server'})

    def __init__(self, model_name: str, timeout: int = 30, max_retries: int = 3):
        """Store provider settings and initialise failure-tracking state.

        Args:
            model_name: Identifier of the model this provider uses.
            timeout: Stored on the instance for subclasses; this base class
                does not read it (presumably seconds — confirm in subclasses).
            max_retries: Number of attempts ``_retry_with_backoff`` makes;
                values below 1 are treated as a single attempt.
        """
        self.model_name = model_name
        self.timeout = timeout
        self.max_retries = max_retries
        # Not read by this base class; presumably maintained by subclasses/callers.
        self.is_available = True
        # Consecutive failures recorded by ``_retry_with_backoff``; reset on success.
        self.failure_count = 0
        # ``time.time()`` of the most recent recorded failure, or None.
        self.last_failure_time = None

    @abstractmethod
    def generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
        """Generate a response synchronously."""
        pass

    @abstractmethod
    def stream_generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[Union[str, List[str]]]:
        """Generate a response with streaming support."""
        pass

    @abstractmethod
    def validate_model(self) -> bool:
        """Validate if the model is available."""
        pass

    def _circuit_breaker_open(self) -> bool:
        """Return True while recent consecutive failures should short-circuit calls."""
        if self.failure_count <= self.circuit_breaker_threshold or self.last_failure_time is None:
            return False
        return time.time() - self.last_failure_time < self.circuit_breaker_cooldown

    def _retry_with_backoff(self, func, *args, **kwargs):
        """Call ``func(*args, **kwargs)`` with retries, backoff and a circuit breaker.

        Makes up to ``max_retries`` attempts (always at least one). After a
        failed attempt that is not the last, sleeps ``2 ** attempt`` seconds,
        capped at ``max_backoff``. Each failed attempt increments
        ``failure_count`` and stamps ``last_failure_time``; a success resets both.

        While the circuit breaker is open (see ``_circuit_breaker_open``) the
        call fails fast without invoking ``func``. A rejected call is not
        recorded as a new failure, so the breaker closes once the cooldown
        since the last real failure has elapsed.

        Returns:
            The value returned by ``func`` on the first successful attempt.

        Raises:
            Exception: "Circuit breaker tripped - too many recent failures"
                while the breaker is open.
            The last exception raised by ``func`` once all attempts have failed.
        """
        # max_retries <= 0 used to skip the loop entirely and end in ``raise None``.
        attempts = max(1, self.max_retries)
        last_exception = None

        for attempt in range(attempts):
            # Checked outside the ``try`` so the rejection is neither counted as a
            # failure (which would keep re-arming the cooldown) nor retried.
            if self._circuit_breaker_open():
                raise Exception("Circuit breaker tripped - too many recent failures") from last_exception

            try:
                result = func(*args, **kwargs)
            except Exception as e:
                last_exception = e
                self.failure_count += 1
                self.last_failure_time = time.time()

                if attempt < attempts - 1:  # Don't sleep after the final attempt
                    # float(min(...)) avoids int->float overflow for huge attempt counts.
                    sleep_time = float(min(2 ** attempt, self.max_backoff))
                    logger.warning("Attempt %d failed: %s. Retrying in %ss...", attempt + 1, e, sleep_time)
                    time.sleep(sleep_time)
                else:
                    logger.error("All %d attempts failed. Last error: %s", attempts, e)
            else:
                # Success closes the breaker.
                self.failure_count = 0
                self.last_failure_time = None
                return result

        raise last_exception

    def _classify_error(self, error: Exception) -> str:
        """Classify ``error`` for retry decisions.

        Returns one of 'network', 'authentication', 'rate_limit', 'server'
        or 'other'. Built-in ``ConnectionError``/``TimeoutError`` instances
        are 'network' regardless of message (a bare ``TimeoutError()`` has an
        empty message). Otherwise the lower-cased message is matched against
        ``_ERROR_KEYWORDS`` in order; the first matching category wins.
        """
        if isinstance(error, (ConnectionError, TimeoutError)):
            return 'network'

        error_str = str(error).lower()
        for category, terms in self._ERROR_KEYWORDS:
            if any(term in error_str for term in terms):
                return category
        return 'other'

    def _is_recoverable_error(self, error: Exception) -> bool:
        """Return True if ``error`` classifies as network, rate-limit or server."""
        return self._classify_error(error) in self._RECOVERABLE_ERROR_TYPES