File size: 2,341 Bytes
8474f02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from abc import ABC, abstractmethod
import time
import asyncio
from typing import List, Dict, Any, Optional

class BaseAPI(ABC):
    """Base class for all API implementations.

    Subclasses implement :meth:`generate_response`. On top of it, this class
    provides retry-with-backoff (:meth:`generate_with_retry`) and sequential,
    paced batching (:meth:`batch_generate`).
    """

    def __init__(self, api_key: str, model_name: str, **kwargs):
        """Store credentials and tuning options.

        Args:
            api_key: Credential, stored as ``self.api_key`` for subclasses.
            model_name: Model identifier, stored as ``self.model_name``.
            **kwargs: Optional tuning knobs:
                rate_limit_delay: seconds to sleep after each prompt in
                    ``batch_generate`` (default 1.0).
                max_retries: total attempts allowed for non-timeout errors
                    (default 3). Values below 1 still get one attempt.
                timeout: stored as ``self.timeout`` (default 30). It is not
                    read by this base class; presumably seconds, so confirm
                    against subclass implementations.
        """
        self.api_key = api_key
        self.model_name = model_name
        self.rate_limit_delay = kwargs.get('rate_limit_delay', 1.0)
        self.max_retries = kwargs.get('max_retries', 3)
        self.timeout = kwargs.get('timeout', 30)

    @abstractmethod
    async def generate_response(self, prompt: str, **kwargs) -> str:
        """Generate a response from the model (implemented by subclasses)."""

    @staticmethod
    def _is_timeout_error(exc: BaseException) -> bool:
        """Return True if *exc* looks like a timeout.

        ``TimeoutError`` / ``asyncio.TimeoutError`` usually carry an empty
        message, so match them by type. Otherwise fall back to the message
        text used by client libraries that wrap timeouts in other types.
        """
        if isinstance(exc, (TimeoutError, asyncio.TimeoutError)):
            return True
        message = str(exc).lower()
        return 'timeout' in message or 'timed out' in message

    async def generate_with_retry(self, prompt: str, **kwargs) -> str:
        """Call :meth:`generate_response`, retrying failures with backoff.

        Attempt budgets (counted across all failures of one call):

        * Non-timeout errors: ``max(1, self.max_retries)`` total attempts,
          sleeping ``min(30, 2 ** attempt)`` seconds between them.
        * Timeout errors: ``min(self.max_retries + 2, 5)`` total attempts,
          but never fewer than the non-timeout budget. The sleep is
          ``min(60, 5 * 2 ** attempt)`` seconds between them.

        Returns:
            The response from the first successful attempt.

        Raises:
            The last exception raised by ``generate_response`` once the
            budget that applies to it is exhausted.
        """
        max_attempts = max(1, self.max_retries)
        # Timeouts are usually transient, so they get a slightly larger budget.
        timeout_attempts = max(max_attempts, min(self.max_retries + 2, 5))

        attempt = 0
        while True:
            try:
                return await self.generate_response(prompt, **kwargs)
            except Exception as exc:
                is_timeout = self._is_timeout_error(exc)
                limit = timeout_attempts if is_timeout else max_attempts
                if attempt + 1 >= limit:
                    # Budget exhausted: propagate instead of silently returning None.
                    raise
                if is_timeout:
                    backoff = min(60, 5 * (2 ** attempt))  # Max 60 seconds wait
                    print(f"Timeout error, retrying in {backoff}s... (attempt {attempt + 1}/{limit})")
                else:
                    backoff = min(30, 2 ** attempt)  # Max 30 seconds for other errors
            # Sleep outside the except block so a cancellation during the wait
            # is not chained onto the already-handled error.
            await asyncio.sleep(backoff)
            attempt += 1

    async def batch_generate(self, prompts: List[str], **kwargs) -> List[str]:
        """Generate responses for *prompts* sequentially, in order.

        After every prompt, including the last, sleeps ``self.rate_limit_delay``
        seconds. The trailing sleep also spaces out back-to-back batch calls.
        Any error that exhausts the retry budget propagates, aborting the batch.
        """
        responses: List[str] = []
        for prompt in prompts:
            responses.append(await self.generate_with_retry(prompt, **kwargs))
            await asyncio.sleep(self.rate_limit_delay)
        return responses