File size: 3,228 Bytes
83ce746 adf8222 83ce746 adf8222 83ce746 adf8222 83ce746 adf8222 83ce746 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
import time
import logging
from abc import ABC, abstractmethod
from typing import List, Dict, Optional, Union
logger = logging.getLogger(__name__)


class LLMProvider(ABC):
    """Abstract base class for all LLM providers, with retry and a circuit breaker.

    Subclasses implement ``generate`` and ``stream_generate`` and may wrap the
    actual provider call in ``_retry_with_backoff`` to get exponential-backoff
    retries plus a simple circuit breaker:

    * every failed attempt increments ``failure_count``;
    * once ``failure_count`` reaches ``failure_threshold`` the circuit opens and
      ``_retry_with_backoff`` rejects calls without invoking the provider;
    * after more than ``reset_timeout`` seconds since the last failure the
      circuit closes again and calls are allowed;
    * a successful call through ``_retry_with_backoff`` resets the counter and
      closes the circuit.

    Args:
        model_name: Identifier of the model to use.
        timeout: Request timeout in seconds. Stored for subclasses; not used by
            this base class.
        max_retries: Total number of attempts made by ``_retry_with_backoff``
            (values below 1 are treated as a single attempt).
        failure_threshold: Failures without an intervening success that open
            the circuit.
        reset_timeout: Seconds after the last failure before an open circuit is
            allowed to close again.
        backoff_base: Delay in seconds before the first retry; doubled for each
            subsequent retry.
        max_backoff: Upper bound in seconds for a single retry delay.
    """

    def __init__(
        self,
        model_name: str,
        timeout: int = 30,
        max_retries: int = 3,
        *,
        failure_threshold: int = 3,
        reset_timeout: float = 60,
        backoff_base: float = 1.0,
        max_backoff: float = 10.0,
    ):
        self.model_name = model_name
        self.timeout = timeout
        self.max_retries = max_retries
        # Retry backoff settings
        self.backoff_base = backoff_base
        self.max_backoff = max_backoff
        # Circuit breaker state and settings
        self.failure_count = 0
        self.last_failure_time: Optional[float] = None  # time.time() of the latest failure
        self.circuit_open = False
        self.failure_threshold = failure_threshold
        self.reset_timeout = reset_timeout  # seconds before an open circuit may close

    @abstractmethod
    def generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
        """Generate a complete response synchronously.

        Args:
            prompt: The new user prompt.
            conversation_history: Prior conversation turns; the dict layout is
                defined by the concrete provider.

        Returns:
            The generated text. The return type allows ``None``; what ``None``
            signals is up to the concrete provider.
        """

    @abstractmethod
    def stream_generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[Union[str, List[str]]]:
        """Generate a response with streaming support.

        Args:
            prompt: The new user prompt.
            conversation_history: Prior conversation turns; the dict layout is
                defined by the concrete provider.

        Returns:
            Either the full text or a list of chunks (or ``None``); the exact
            shape is chosen by the concrete provider.
        """

    def _check_circuit_breaker(self) -> bool:
        """Return True if a call may proceed, False if the open circuit blocks it.

        An open circuit is closed again (and the failure counter reset) once
        more than ``reset_timeout`` seconds have passed since the last failure.
        """
        if not self.circuit_open:
            return True
        # Enough quiet time since the last failure: close the circuit and let the call through.
        if self.last_failure_time is not None and (time.time() - self.last_failure_time) > self.reset_timeout:
            logger.info("Circuit breaker reset - allowing call")
            self.circuit_open = False
            self.failure_count = 0
            return True
        logger.warning("Circuit breaker is OPEN - preventing call")
        return False

    def _record_failure(self) -> None:
        """Count one failure and open the circuit once the threshold is reached.

        Does not raise, so a retry loop keeps control over whether to try again.
        """
        self.failure_count += 1
        self.last_failure_time = time.time()
        if self.failure_count >= self.failure_threshold:
            self.circuit_open = True
            logger.warning(
                "Circuit breaker OPEN for %s after %s failures",
                self.__class__.__name__,
                self.failure_count,
            )

    def _handle_failure(self, error: Exception):
        """Record a failure in the circuit breaker, then re-raise ``error``.

        Raises:
            Exception: always re-raises ``error`` after recording it.
        """
        self._record_failure()
        raise error

    def _retry_with_backoff(self, func, *args, **kwargs):
        """Call ``func(*args, **kwargs)`` with exponential-backoff retries.

        At most ``max_retries`` attempts are made (always at least one). Before
        retry ``n`` (0-based) the method sleeps
        ``min(backoff_base * 2**n, max_backoff)`` seconds. Every failure is
        recorded in the circuit breaker; if that opens the circuit, the
        remaining retries are abandoned.

        Returns:
            Whatever ``func`` returns on the first successful attempt.

        Raises:
            Exception: ``"Circuit breaker is open"`` if the circuit blocks the
                call up front; otherwise the exception from the last failed
                attempt.
        """
        # Reject up front and outside the try: a blocked call is not a provider
        # failure, so it must not bump failure_count or refresh last_failure_time
        # (that would keep the circuit open forever under steady traffic).
        if not self._check_circuit_breaker():
            raise Exception("Circuit breaker is open")

        attempts = max(1, self.max_retries)
        last_exception: Optional[Exception] = None
        for attempt in range(attempts):
            try:
                result = func(*args, **kwargs)
            except Exception as e:
                last_exception = e
                self._record_failure()
                if attempt == attempts - 1:
                    logger.error("All %s attempts failed. Last error: %s", attempts, e)
                elif self.circuit_open:
                    # Our own failures tripped the breaker; further attempts would be blocked anyway.
                    logger.error(
                        "Circuit breaker opened after attempt %s; aborting retries. Last error: %s",
                        attempt + 1,
                        e,
                    )
                    break
                else:
                    sleep_time = min(self.backoff_base * (2 ** attempt), self.max_backoff)
                    logger.warning("Attempt %s failed: %s. Retrying in %ss...", attempt + 1, e, sleep_time)
                    time.sleep(sleep_time)
            else:
                # Success clears the failure streak and closes the circuit.
                self.failure_count = 0
                self.circuit_open = False
                return result
        # Reached only after a failed attempt, so last_exception is set.
        raise last_exception
|