File size: 3,032 Bytes
5b5f50c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import time
import logging
from typing import List, Dict, Optional, Union
from core.providers.base import LLMProvider
from utils.config import config

logger = logging.getLogger(__name__)

# The openai SDK is optional: record whether it imported so that
# OpenAIProvider.__init__ can raise a clear ImportError on use, instead of
# this module failing to import when the package is absent.
try:
    from openai import OpenAI
    OPENAI_SDK_AVAILABLE = True
except ImportError:
    OPENAI_SDK_AVAILABLE = False
    OpenAI = None  # placeholder so the name is always defined at module level

class OpenAIProvider(LLMProvider):
    """OpenAI LLM provider implementation.

    Wraps the official ``openai`` SDK client and exposes synchronous and
    buffered-streaming chat-completion calls through the ``LLMProvider``
    interface. Failures are retried through the base class's
    ``_retry_with_backoff`` helper; public methods return ``None`` rather
    than raising once retries are exhausted.
    """

    def __init__(
        self,
        model_name: str,
        timeout: int = 30,
        max_retries: int = 3,
        max_tokens: int = 500,
        temperature: float = 0.7,
    ):
        """Create the provider and its OpenAI client.

        Args:
            model_name: Model identifier sent with every request.
            timeout: Request timeout in seconds; forwarded to the base class
                and to the OpenAI client.
            max_retries: Retry budget forwarded to the base class.
            max_tokens: ``max_tokens`` sent with every completion request.
            temperature: Sampling temperature sent with every request.

        Raises:
            ImportError: If the ``openai`` package is not installed.
            ValueError: If ``config.openai_api_key`` is empty or unset.
        """
        super().__init__(model_name, timeout, max_retries)

        if not OPENAI_SDK_AVAILABLE:
            raise ImportError("OpenAI provider requires 'openai' package")

        if not config.openai_api_key:
            raise ValueError("OPENAI_API_KEY not set - required for OpenAI provider")

        self.max_tokens = max_tokens
        self.temperature = temperature
        # Forward the timeout to the SDK; previously it was accepted but never
        # applied, so requests used the SDK's own default timeout instead.
        # NOTE(review): the SDK also retries internally by default, on top of
        # _retry_with_backoff — confirm whether compounded retries are wanted.
        self.client = OpenAI(api_key=config.openai_api_key, timeout=timeout)

    def generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
        """Generate a complete response synchronously.

        Returns:
            The assistant message content, or ``None`` if every attempt failed
            (the failure is logged).
        """
        try:
            return self._retry_with_backoff(self._generate_impl, prompt, conversation_history)
        except Exception as e:
            logger.error("OpenAI generation failed: %s", e)
            return None

    def stream_generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[Union[str, List[str]]]:
        """Generate a response via the streaming API, buffered into a list.

        Returns:
            The list of non-empty content fragments in arrival order, or
            ``None`` if every attempt failed (the failure is logged).
        """
        try:
            return self._retry_with_backoff(self._stream_generate_impl, prompt, conversation_history)
        except Exception as e:
            logger.error("OpenAI stream generation failed: %s", e)
            return None

    def validate_model(self) -> bool:
        """Return True if ``model_name`` appears in the account's model list.

        Any API error is logged as a warning and reported as ``False``.
        """
        try:
            models = self.client.models.list()
            return any(model.id == self.model_name for model in models.data)
        except Exception as e:
            logger.warning("OpenAI model validation failed: %s", e)
            return False

    def _completion_kwargs(self, conversation_history: List[Dict]) -> Dict:
        """Build the keyword arguments shared by both completion calls."""
        return {
            "model": self.model_name,
            "messages": conversation_history,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
        }

    def _generate_impl(self, prompt: str, conversation_history: List[Dict]) -> str:
        """Perform one non-streaming chat-completion request."""
        # NOTE(review): `prompt` is not sent; this assumes the caller already
        # appended it to `conversation_history` — confirm against callers.
        response = self.client.chat.completions.create(
            **self._completion_kwargs(conversation_history)
        )
        return response.choices[0].message.content

    def _stream_generate_impl(self, prompt: str, conversation_history: List[Dict]) -> List[str]:
        """Perform one streaming request and collect its content fragments."""
        # NOTE(review): as in _generate_impl, `prompt` is not sent separately.
        response = self.client.chat.completions.create(
            **self._completion_kwargs(conversation_history),
            stream=True,
        )

        chunks: List[str] = []
        for chunk in response:
            # Some stream chunks carry no choices (e.g. a usage-only chunk or
            # provider content-filter metadata); skip them instead of raising
            # IndexError and discarding the whole stream.
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta
            if delta is not None and delta.content:
                chunks.append(delta.content)

        return chunks