import logging
from typing import Dict, List, Optional, Union

from core.providers.base import LLMProvider
from utils.config import config

logger = logging.getLogger(__name__)

try:
    from openai import OpenAI
    OPENAI_SDK_AVAILABLE = True
except ImportError:
    OPENAI_SDK_AVAILABLE = False
    OpenAI = None


class OpenAIProvider(LLMProvider):
    """OpenAI LLM provider implementation"""

    def __init__(self, model_name: str, timeout: int = 30, max_retries: int = 3):
        super().__init__(model_name, timeout, max_retries)
        if not OPENAI_SDK_AVAILABLE:
            raise ImportError("OpenAI provider requires the 'openai' package")
        if not config.openai_api_key:
            raise ValueError("OPENAI_API_KEY not set - required for OpenAI provider")
        # Pass the timeout through so slow requests fail instead of hanging
        self.client = OpenAI(api_key=config.openai_api_key, timeout=timeout)

    def generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
        """Generate a response synchronously"""
        try:
            return self._retry_with_backoff(self._generate_impl, prompt, conversation_history)
        except Exception as e:
            logger.error(f"OpenAI generation failed: {e}")
            return None

    def stream_generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[Union[str, List[str]]]:
        """Generate a response with streaming support"""
        try:
            return self._retry_with_backoff(self._stream_generate_impl, prompt, conversation_history)
        except Exception as e:
            logger.error(f"OpenAI stream generation failed: {e}")
            return None

    def validate_model(self) -> bool:
        """Validate if the model is available"""
        try:
            models = self.client.models.list()
            model_ids = [model.id for model in models.data]
            return self.model_name in model_ids
        except Exception as e:
            logger.warning(f"OpenAI model validation failed: {e}")
            return False

    def _generate_impl(self, prompt: str, conversation_history: List[Dict]) -> str:
        """Implementation of synchronous generation"""
        # conversation_history is expected to already contain the latest user
        # prompt; the prompt argument mirrors the retry helper's call signature.
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=conversation_history,
            max_tokens=500,
            temperature=0.7
        )
        # message.content can be None (e.g. for tool calls), so normalize to str
        return response.choices[0].message.content or ""

    def _stream_generate_impl(self, prompt: str, conversation_history: List[Dict]) -> List[str]:
        """Implementation of streaming generation"""
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=conversation_history,
            max_tokens=500,
            temperature=0.7,
            stream=True
        )
        chunks = []
        for chunk in response:
            # Some stream events carry an empty choices list, so guard before indexing
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            if content:
                chunks.append(content)
        return chunks
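

# Minimal usage sketch. Assumptions: OPENAI_API_KEY is set so utils.config can
# pick it up, and the model name below ("gpt-4o-mini") is available on the
# account; substitute whatever model your deployment actually uses.
if __name__ == "__main__":
    provider = OpenAIProvider("gpt-4o-mini")
    if provider.validate_model():
        history = [{"role": "user", "content": "Say hello in one sentence."}]
        print(provider.generate("Say hello in one sentence.", history))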