import logging
from typing import Dict, List, Optional, Union

from core.providers.base import LLMProvider
from utils.config import config

logger = logging.getLogger(__name__)

try:
    from openai import OpenAI
    OPENAI_SDK_AVAILABLE = True
except ImportError:
    OPENAI_SDK_AVAILABLE = False
    OpenAI = None


class OpenAIProvider(LLMProvider):
    """OpenAI LLM provider implementation"""

    def __init__(self, model_name: str, timeout: int = 30, max_retries: int = 3):
        super().__init__(model_name, timeout, max_retries)

        if not OPENAI_SDK_AVAILABLE:
            raise ImportError("OpenAI provider requires the 'openai' package")

        if not config.openai_api_key:
            raise ValueError("OPENAI_API_KEY not set - required for OpenAI provider")

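        # The OpenAI SDK client itself accepts `timeout` and `max_retries`
        # kwargs; they are not forwarded here on the assumption that retries
        # and timeouts are handled by the base class's _retry_with_backoff,
        # so configuring both layers would retry twice.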
        self.client = OpenAI(api_key=config.openai_api_key)

    def generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
        """Generate a response synchronously.

        `conversation_history` is a list of Chat Completions-style message
        dicts ({"role": ..., "content": ...}). Returns None on failure.
        """
        try:
            return self._retry_with_backoff(self._generate_impl, prompt, conversation_history)
        except Exception as e:
            logger.error(f"OpenAI generation failed: {e}")
            return None

    def stream_generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[Union[str, List[str]]]:
        """Generate a response with streaming.

        Returns the collected content chunks, or None on failure.
        """
        try:
            return self._retry_with_backoff(self._stream_generate_impl, prompt, conversation_history)
        except Exception as e:
            logger.error(f"OpenAI stream generation failed: {e}")
            return None

    def validate_model(self) -> bool:
        """Check whether the configured model is available from the API."""
        try:
            # Iterating the returned page object auto-paginates through
            # all models, rather than reading only the first page's .data.
            model_ids = [model.id for model in self.client.models.list()]
            return self.model_name in model_ids
        except Exception as e:
            logger.warning(f"OpenAI model validation failed: {e}")
            return False

    def _generate_impl(self, prompt: str, conversation_history: List[Dict]) -> str:
        """Implementation of synchronous generation."""
        # The history is assumed not to include the current prompt yet;
        # append it as the latest user turn so it reaches the model.
        messages = conversation_history + [{"role": "user", "content": prompt}]
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            max_tokens=500,
            temperature=0.7,
        )
        return response.choices[0].message.content

    def _stream_generate_impl(self, prompt: str, conversation_history: List[Dict]) -> List[str]:
        """Implementation of streaming generation; collects the content deltas."""
        # Same assumption as _generate_impl: the prompt is appended as the
        # latest user turn rather than being silently dropped.
        messages = conversation_history + [{"role": "user", "content": prompt}]
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            max_tokens=500,
            temperature=0.7,
            stream=True,
        )

        chunks = []
        for chunk in response:
            # Some stream events can arrive with an empty choices list
            # (e.g. trailing usage events); skip them.
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            if content:
                chunks.append(content)

        return chunks
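

# A minimal usage sketch (illustrative only): it assumes OPENAI_API_KEY is
# present in config and that "gpt-4o-mini" names an available model; both
# are assumptions for the example, not requirements of this module.
if __name__ == "__main__":
    provider = OpenAIProvider("gpt-4o-mini")
    if provider.validate_model():
        history = [{"role": "system", "content": "You are a helpful assistant."}]
        print(provider.generate("Say hello in one short sentence.", history))
    else:
        logger.warning("Model not available; skipping demo call.")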