import time
import logging
from datetime import datetime
from typing import List, Dict, Optional, Union

from core.providers.base import LLMProvider
from utils.config import config
from services.weather import weather_service

logger = logging.getLogger(__name__)
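
# 'openai' is the only SDK dependency: the endpoint is addressed through an
# OpenAI-compatible chat completions API (hence the base_url override below).
# The import is guarded so a missing package surfaces as a clear ImportError
# in __init__ rather than breaking module import.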
try:
    from openai import OpenAI
    HUGGINGFACE_SDK_AVAILABLE = True
except ImportError:
    HUGGINGFACE_SDK_AVAILABLE = False
    OpenAI = None


class HuggingFaceProvider(LLMProvider):
    """Hugging Face LLM provider implementation with cached model validation."""

    def __init__(self, model_name: str, timeout: int = 30, max_retries: int = 3):
        super().__init__(model_name, timeout, max_retries)
        logger.info("Initializing HuggingFaceProvider with:")
        logger.info(f"  HF_API_URL: {config.hf_api_url}")
        logger.info(f"  HF_TOKEN set: {bool(config.hf_token)}")

        if not HUGGINGFACE_SDK_AVAILABLE:
            raise ImportError("Hugging Face provider requires the 'openai' package")

        if not config.hf_token:
            raise ValueError("HF_TOKEN not set - required for Hugging Face provider")

        try:
            self.client = OpenAI(
                base_url=config.hf_api_url,
                api_key=config.hf_token,
            )
            logger.info("HuggingFaceProvider initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize HuggingFaceProvider: {e}")
            logger.error(f"Error type: {type(e)}")
            raise

        # Validation results are cached so repeated health checks do not hit
        # the endpoint more than once per cache window.
        self._model_validated = False
        self._last_validation = 0
        self._validation_cache_duration = 300  # seconds
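
    # Both public entry points delegate to _retry_with_backoff, presumably
    # inherited from the LLMProvider base class (it is not defined in this
    # file), so transient endpoint failures are retried uniformly, bounded by
    # max_retries.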
    def generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
        """Generate a response synchronously."""
        try:
            return self._retry_with_backoff(self._generate_impl, prompt, conversation_history)
        except Exception as e:
            logger.error(f"Hugging Face generation failed: {e}")
            return None

    def stream_generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[Union[str, List[str]]]:
        """Generate a response with streaming support."""
        try:
            return self._retry_with_backoff(self._stream_generate_impl, prompt, conversation_history)
        except Exception as e:
            logger.error(f"Hugging Face stream generation failed: {e}")
            return None

    def validate_model(self) -> bool:
        """Validate that the model endpoint is reachable, caching the result."""
        current_time = time.time()
        if (self._model_validated and
                current_time - self._last_validation < self._validation_cache_duration):
            return True

        try:
            # Listing models doubles as a cheap liveness probe for the endpoint.
            self.client.models.list()
            self._model_validated = True
            self._last_validation = current_time
            return True
        except Exception as e:
            logger.warning(f"Hugging Face model validation failed: {e}")
            return False
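
    # Both generation paths prepend the same time-and-weather system message;
    # the helper below (a name local to this file, not part of the LLMProvider
    # interface) keeps that context injection in one place.
    def _with_context(self, conversation_history: List[Dict]) -> List[Dict]:
        """Prepend a system message carrying the current time and weather."""
        current_time = datetime.now().strftime("%A, %B %d, %Y at %I:%M %p")
        weather_summary = weather_service.get_weather_summary()
        context_msg = {
            "role": "system",
            "content": f"[Current Context: {current_time} | Weather: {weather_summary}]"
        }
        return [context_msg] + conversation_history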

    def _generate_impl(self, prompt: str, conversation_history: List[Dict]) -> str:
        """Synchronous generation with context injection and a scale-to-zero retry.

        Note: `prompt` is unused here; the caller appears to include the
        latest user message in conversation_history.
        """
        enhanced_history = self._with_context(conversation_history)
        request = dict(
            model=self.model_name,
            messages=enhanced_history,
            max_tokens=8192,
            temperature=0.7,
            top_p=0.9,
            frequency_penalty=0.1,
            presence_penalty=0.1,
        )
        try:
            response = self.client.chat.completions.create(**request)
            return response.choices[0].message.content
        except Exception as e:
            if not self._is_scale_to_zero_error(e):
                raise
            # The endpoint is waking from zero replicas: wait once, then retry
            # the identical request once.
            logger.info("Hugging Face endpoint is scaling up, waiting...")
            time.sleep(60)
            response = self.client.chat.completions.create(**request)
            return response.choices[0].message.content
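
    # The streaming variant mirrors _generate_impl but drains the stream into
    # a list of text deltas, so callers receive the chunks only after the
    # stream has closed rather than as a live generator.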
    def _stream_generate_impl(self, prompt: str, conversation_history: List[Dict]) -> List[str]:
        """Streaming generation with context injection and a scale-to-zero retry."""
        enhanced_history = self._with_context(conversation_history)
        request = dict(
            model=self.model_name,
            messages=enhanced_history,
            max_tokens=8192,
            temperature=0.7,
            top_p=0.9,
            frequency_penalty=0.1,
            presence_penalty=0.1,
            stream=True,
        )

        def drain(stream) -> List[str]:
            """Collect the non-empty text deltas from a completion stream."""
            chunks: List[str] = []
            for chunk in stream:
                content = chunk.choices[0].delta.content
                if content:
                    chunks.append(content)
            return chunks

        try:
            return drain(self.client.chat.completions.create(**request))
        except Exception as e:
            if not self._is_scale_to_zero_error(e):
                raise
            # The endpoint is waking from zero replicas: wait once, then retry
            # the identical request once.
            logger.info("Hugging Face endpoint is scaling up, waiting...")
            time.sleep(60)
            return drain(self.client.chat.completions.create(**request))
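
    # Scale-to-zero endpoints surface cold starts as transient HTTP errors;
    # the check below matches on the stringified error rather than a specific
    # exception type, so messages like "503 Service Unavailable" or
    # "Endpoint is initializing" both qualify.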
    def _is_scale_to_zero_error(self, error: Exception) -> bool:
        """Check whether the error looks like a scale-to-zero cold start."""
        error_str = str(error).lower()
        scale_to_zero_indicators = [
            "503",
            "service unavailable",
            "initializing",
            "cold start",
        ]
        return any(indicator in error_str for indicator in scale_to_zero_indicators)

    def _get_weather_summary(self) -> str:
        """Return a formatted weather summary, falling back to a neutral default.

        Not referenced by the generation paths above, which call
        weather_service.get_weather_summary() directly.
        """
        try:
            weather = weather_service.get_current_weather_cached(
                "New York",
                ttl_hash=weather_service._get_ttl_hash(300)
            )
            if weather:
                return f"{weather.get('temperature', 'N/A')}°C, {weather.get('description', 'Clear skies')}"
            return "Clear skies"
        except Exception:
            return "Clear skies"