import time
import logging
from typing import List, Dict, Optional, Union

from src.llm.enhanced_provider import EnhancedLLMProvider
from utils.config import config
from src.services.context_provider import context_provider

logger = logging.getLogger(__name__)

# The endpoint speaks the OpenAI-compatible chat API, so the provider reuses
# the `openai` SDK as its client; degrade gracefully if it is not installed.
try:
    from openai import OpenAI
    HF_SDK_AVAILABLE = True
except ImportError:
    HF_SDK_AVAILABLE = False
    OpenAI = None

class HuggingFaceProvider(EnhancedLLMProvider):
    """LLM provider backed by a custom Hugging Face inference endpoint."""

    def __init__(self, model_name: str, timeout: int = 120, max_retries: int = 2):
        super().__init__(model_name, timeout, max_retries)

        if not HF_SDK_AVAILABLE:
            raise ImportError("Hugging Face provider requires 'openai' package")

        if not config.hf_token:
            raise ValueError("HF_TOKEN not set - required for Hugging Face provider")

        self.client = OpenAI(
            base_url=config.hf_api_url,
            api_key=config.hf_token,
        )
        logger.info(f"Initialized HF provider with endpoint: {config.hf_api_url}")

    def generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
        """Generate a response synchronously."""
        try:
            enriched_history = self._enrich_context_intelligently(conversation_history)
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=enriched_history,
                max_tokens=8192,
                temperature=0.7,
                stream=False,
            )
            return response.choices[0].message.content
        except Exception as e:
            logger.error(f"HF generation failed: {e}")
            if self._is_scale_to_zero_error(e):
                # A scaled-to-zero endpoint returns 503 while a replica boots;
                # wait once, then retry with the same enriched history.
                logger.info("HF endpoint is scaling up, waiting...")
                time.sleep(60)
                response = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=enriched_history,
                    max_tokens=8192,
                    temperature=0.7,
                    stream=False,
                )
                return response.choices[0].message.content
            raise

    def stream_generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[Union[str, List[str]]]:
        """Generate a response via streaming, returning the collected chunks."""
        try:
            enriched_history = self._enrich_context_intelligently(conversation_history)
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=enriched_history,
                max_tokens=8192,
                temperature=0.7,
                stream=True,
            )

            chunks = []
            for chunk in response:
                content = chunk.choices[0].delta.content
                if content:
                    chunks.append(content)
            return chunks
        except Exception as e:
            logger.error(f"HF stream generation failed: {e}")
            if self._is_scale_to_zero_error(e):
                # Same cold-start handling as generate(): wait for the replica,
                # then retry the streaming call with the enriched history.
                logger.info("HF endpoint is scaling up, waiting...")
                time.sleep(60)
                response = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=enriched_history,
                    max_tokens=8192,
                    temperature=0.7,
                    stream=True,
                )

                chunks = []
                for chunk in response:
                    content = chunk.choices[0].delta.content
                    if content:
                        chunks.append(content)
                return chunks
            raise

    def _enrich_context_intelligently(self, conversation_history: List[Dict]) -> List[Dict]:
        """Prepend a system context message only when relevant context exists."""
        if not conversation_history:
            return conversation_history

        # Use the most recent user message as the query for the context provider.
        last_user_message = ""
        for msg in reversed(conversation_history):
            if msg["role"] == "user":
                last_user_message = msg["content"]
                break

        context_string = context_provider.get_context_for_llm(
            last_user_message,
            conversation_history,
        )

        if context_string:
            context_message = {
                "role": "system",
                "content": context_string,
            }
            return [context_message] + conversation_history

        return conversation_history

    def _is_scale_to_zero_error(self, error: Exception) -> bool:
        """Check whether the error looks like a scale-to-zero cold start."""
        error_str = str(error).lower()
        scale_to_zero_indicators = [
            "503",
            "service unavailable",
            "initializing",
            "cold start",
        ]
        return any(indicator in error_str for indicator in scale_to_zero_indicators)
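
# A minimal usage sketch, not part of the provider itself. It assumes
# `config.hf_api_url` and `config.hf_token` are populated (e.g. from the
# environment) and that the endpoint hosts the named model; the model name
# below is a placeholder, not a project constant.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    provider = HuggingFaceProvider(model_name="my-endpoint-model")
    history = [{"role": "user", "content": "Hello!"}]
    print(provider.generate("Hello!", history))
    for chunk in provider.stream_generate("Hello!", history) or []:
        print(chunk, end="", flush=True)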