import time
import logging
from typing import List, Dict, Optional, Union

from core.providers.base import LLMProvider
from utils.config import config

logger = logging.getLogger(__name__)
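
# NOTE: the endpoint configured via HF_API_URL is expected to expose an
# OpenAI-compatible chat-completions API (as the Hugging Face router and
# TGI-backed Inference Endpoints do), which is why the `openai` SDK is used
# as the HTTP client below.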
try:
    from openai import OpenAI

    HUGGINGFACE_SDK_AVAILABLE = True
except ImportError:
    HUGGINGFACE_SDK_AVAILABLE = False
    OpenAI = None


class HuggingFaceProvider(LLMProvider):
    """Hugging Face LLM provider implementation.

    Talks to an OpenAI-compatible Hugging Face endpoint and adds a single
    wait-and-retry pass when the endpoint is scaling up from zero.
    """

    def __init__(self, model_name: str, timeout: int = 30, max_retries: int = 3):
        super().__init__(model_name, timeout, max_retries)
        logger.info("Initializing HuggingFaceProvider with:")
        logger.info(f"  HF_API_URL: {config.hf_api_url}")
        logger.info(f"  HF_TOKEN SET: {bool(config.hf_token)}")

        if not HUGGINGFACE_SDK_AVAILABLE:
            raise ImportError("Hugging Face provider requires the 'openai' package")
        if not config.hf_token:
            raise ValueError("HF_TOKEN not set - required for Hugging Face provider")

        try:
            self.client = OpenAI(
                base_url=config.hf_api_url,
                api_key=config.hf_token
            )
            logger.info("HuggingFaceProvider initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize HuggingFaceProvider: {e}")
            logger.error(f"Error type: {type(e)}")
            raise

    def generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
        """Generate a response synchronously."""
        try:
            return self._retry_with_backoff(self._generate_impl, prompt, conversation_history)
        except Exception as e:
            logger.error(f"Hugging Face generation failed: {e}")
            return None

    def stream_generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[Union[str, List[str]]]:
        """Generate a response with streaming support.

        Note: streamed deltas are collected eagerly and returned as a list of
        text chunks rather than yielded as a generator.
        """
        try:
            return self._retry_with_backoff(self._stream_generate_impl, prompt, conversation_history)
        except Exception as e:
            logger.error(f"Hugging Face stream generation failed: {e}")
            return None

    def validate_model(self) -> bool:
        """Validate that the model endpoint is available."""
        try:
            # models.list() only confirms the endpoint responds; it does not
            # verify that self.model_name is among the served models.
            self.client.models.list()
            return True
        except Exception as e:
            logger.warning(f"Hugging Face model validation failed: {e}")
            return False

    def _generate_impl(self, prompt: str, conversation_history: List[Dict]) -> str:
        """Implementation of synchronous generation with proper configuration."""
        # `prompt` is accepted to satisfy the provider interface; the request is
        # built from `conversation_history`, which is assumed to already contain
        # the latest user message.
        request_kwargs = dict(
            model=self.model_name,
            messages=conversation_history,
            max_tokens=8192,
            temperature=0.7,
            top_p=0.9,
            frequency_penalty=0.1,
            presence_penalty=0.1,
        )
        try:
            response = self.client.chat.completions.create(**request_kwargs)
            return response.choices[0].message.content
        except Exception as e:
            if self._is_scale_to_zero_error(e):
                # A scale-to-zero endpoint returns 503s while a replica spins up;
                # wait once and retry the identical request.
                logger.info("Hugging Face endpoint is scaling up, waiting...")
                time.sleep(60)
                response = self.client.chat.completions.create(**request_kwargs)
                return response.choices[0].message.content
            raise

    def _stream_generate_impl(self, prompt: str, conversation_history: List[Dict]) -> List[str]:
        """Implementation of streaming generation with proper configuration."""
        request_kwargs = dict(
            model=self.model_name,
            messages=conversation_history,
            max_tokens=8192,
            temperature=0.7,
            top_p=0.9,
            frequency_penalty=0.1,
            presence_penalty=0.1,
            stream=True,
        )
        try:
            response = self.client.chat.completions.create(**request_kwargs)

            # Collect the streamed deltas into a list of text chunks.
            chunks = []
            for chunk in response:
                content = chunk.choices[0].delta.content
                if content:
                    chunks.append(content)
            return chunks
        except Exception as e:
            if self._is_scale_to_zero_error(e):
                # Same scale-to-zero handling as _generate_impl: wait once, then
                # retry the identical streaming request.
                logger.info("Hugging Face endpoint is scaling up, waiting...")
                time.sleep(60)

                response = self.client.chat.completions.create(**request_kwargs)

                chunks = []
                for chunk in response:
                    content = chunk.choices[0].delta.content
                    if content:
                        chunks.append(content)
                return chunks
            raise

    def _is_scale_to_zero_error(self, error: Exception) -> bool:
        """Check if the error is related to scale-to-zero initialization."""
        error_str = str(error).lower()
        scale_to_zero_indicators = [
            "503", "service unavailable", "initializing", "cold start"
        ]
        return any(indicator in error_str for indicator in scale_to_zero_indicators)
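

# Example usage (a minimal sketch; the model name and conversation below are
# illustrative, and HF_API_URL / HF_TOKEN must point at a valid
# OpenAI-compatible Hugging Face endpoint):
#
#     provider = HuggingFaceProvider("meta-llama/Llama-3.1-8B-Instruct")
#     history = [{"role": "user", "content": "Hello!"}]
#     reply = provider.generate("Hello!", history)
#     print(reply)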