""" HuggingFace LLM adapter implementation. This adapter provides integration with HuggingFace Inference API, handling the specific API format and response structure of HuggingFace models. Architecture Notes: - Converts between unified interface and HuggingFace API format - Handles both chat completion and text generation endpoints - Supports automatic model selection and fallback - Maps HuggingFace errors to standard LLMError types """ import os import logging import time from typing import Dict, Any, Optional, List, Iterator from datetime import datetime from .base_adapter import BaseLLMAdapter, LLMError, ModelNotFoundError, AuthenticationError, RateLimitError from ..base import GenerationParams logger = logging.getLogger(__name__) # Check for HuggingFace Hub availability try: from huggingface_hub import InferenceClient HF_HUB_AVAILABLE = True except ImportError: HF_HUB_AVAILABLE = False logger.warning("huggingface_hub not available. Install with: pip install huggingface-hub") class HuggingFaceAdapter(BaseLLMAdapter): """ Adapter for HuggingFace Inference API integration. Features: - Support for both chat completion and text generation - Automatic model selection and fallback - OpenAI-compatible chat completion format - Comprehensive error handling and retry logic - Multiple model support with automatic fallback Configuration: - api_token: HuggingFace API token (required) - timeout: Request timeout in seconds (default: 30) - use_chat_completion: Prefer chat completion over text generation - fallback_models: List of models to try if primary fails """ # Models that work well with chat completion format CHAT_MODELS = [ "microsoft/DialoGPT-medium", # Proven conversational model "google/gemma-2-2b-it", # Instruction-tuned, good for Q&A "meta-llama/Llama-3.2-3B-Instruct", # If available with token "Qwen/Qwen2.5-1.5B-Instruct", # Small, fast, good quality ] # Fallback models for classic text generation CLASSIC_MODELS = [ "google/flan-t5-small", # Good for instructions "deepset/roberta-base-squad2", # Q&A specific "facebook/bart-base", # Summarization ] def __init__(self, model_name: str = "microsoft/DialoGPT-medium", api_token: Optional[str] = None, timeout: int = 30, use_chat_completion: bool = True, fallback_models: Optional[List[str]] = None, config: Optional[Dict[str, Any]] = None): """ Initialize HuggingFace adapter. Args: model_name: HuggingFace model name api_token: HuggingFace API token timeout: Request timeout in seconds use_chat_completion: Prefer chat completion over text generation fallback_models: List of fallback models to try config: Additional configuration """ if not HF_HUB_AVAILABLE: raise ImportError("huggingface_hub is required for HuggingFace adapter. Install with: pip install huggingface-hub") # Get API token from various sources self.api_token = ( api_token or os.getenv("HUGGINGFACE_API_TOKEN") or os.getenv("HF_TOKEN") or os.getenv("HF_API_TOKEN") ) if not self.api_token: raise AuthenticationError("HuggingFace API token required. 

    def __init__(self,
                 model_name: str = "microsoft/DialoGPT-medium",
                 api_token: Optional[str] = None,
                 timeout: int = 30,
                 use_chat_completion: bool = True,
                 fallback_models: Optional[List[str]] = None,
                 config: Optional[Dict[str, Any]] = None):
        """
        Initialize HuggingFace adapter.

        Args:
            model_name: HuggingFace model name
            api_token: HuggingFace API token
            timeout: Request timeout in seconds
            use_chat_completion: Prefer chat completion over text generation
            fallback_models: List of fallback models to try
            config: Additional configuration
        """
        if not HF_HUB_AVAILABLE:
            raise ImportError("huggingface_hub is required for the HuggingFace adapter. Install with: pip install huggingface-hub")

        # Resolve the API token from the explicit argument or the environment
        self.api_token = (
            api_token
            or os.getenv("HUGGINGFACE_API_TOKEN")
            or os.getenv("HF_TOKEN")
            or os.getenv("HF_API_TOKEN")
        )

        if not self.api_token:
            raise AuthenticationError(
                "HuggingFace API token required. "
                "Set the HF_TOKEN environment variable or pass the api_token parameter."
            )

        # Merge configuration
        adapter_config = {
            'api_token': self.api_token,
            'timeout': timeout,
            'use_chat_completion': use_chat_completion,
            'fallback_models': fallback_models or [],
            **(config or {})
        }

        super().__init__(model_name, adapter_config)

        self.timeout = adapter_config['timeout']
        self.use_chat_completion = adapter_config['use_chat_completion']
        self.fallback_models = adapter_config['fallback_models']

        # Initialize client (pass the timeout so the configured value takes effect)
        self.client = InferenceClient(token=self.api_token, timeout=self.timeout)

        # Test connection and determine the best model (skipped for dummy tokens)
        if not self.api_token.startswith("dummy_"):
            self._test_connection()
        else:
            logger.info("Using dummy token, skipping connection test")

        logger.info(f"Initialized HuggingFace adapter for model '{self.model_name}' (chat_completion: {self.use_chat_completion})")

    def _make_request(self, prompt: str, params: GenerationParams) -> Dict[str, Any]:
        """
        Make a request to the HuggingFace API.

        Args:
            prompt: The prompt to send
            params: Generation parameters

        Returns:
            HuggingFace API response

        Raises:
            LLMError (or a subclass) if the primary model and all fallbacks fail
        """
        try:
            if self.use_chat_completion:
                return self._make_chat_completion_request(prompt, params)
            else:
                return self._make_text_generation_request(prompt, params)

        except Exception as e:
            # Try fallback models if the primary fails; capture the original
            # model name once so it can be restored after each failed attempt
            original_model = self.model_name
            for fallback_model in self.fallback_models:
                try:
                    logger.info(f"Trying fallback model: {fallback_model}")
                    self.model_name = fallback_model

                    if self.use_chat_completion:
                        result = self._make_chat_completion_request(prompt, params)
                    else:
                        result = self._make_text_generation_request(prompt, params)

                    # Success with fallback; keep the working model selected
                    logger.info(f"Successfully used fallback model: {fallback_model}")
                    return result

                except Exception as fallback_error:
                    logger.warning(f"Fallback model {fallback_model} failed: {fallback_error}")
                    # Restore the original model name before the next attempt
                    self.model_name = original_model
                    continue

            # All models failed
            self._handle_provider_error(e)

    def _make_chat_completion_request(self, prompt: str, params: GenerationParams) -> Dict[str, Any]:
        """Make a chat completion request."""
        messages = [{"role": "user", "content": prompt}]

        try:
            response = self.client.chat_completion(
                messages=messages,
                model=self.model_name,
                temperature=params.temperature,
                max_tokens=params.max_tokens,
                stream=False
            )

            # Extract content from the response
            if hasattr(response, 'choices') and response.choices:
                content = response.choices[0].message.content
                return {
                    'content': content,
                    'model': self.model_name,
                    'usage': getattr(response, 'usage', {}),
                    'response_type': 'chat_completion'
                }
            else:
                # Handle non-standard response shapes
                if hasattr(response, 'generated_text'):
                    content = response.generated_text
                else:
                    content = str(response)

                return {
                    'content': content,
                    'model': self.model_name,
                    'usage': {},
                    'response_type': 'chat_completion'
                }

        except Exception as e:
            logger.error(f"Chat completion failed: {e}")
            raise
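
    # Both request helpers return the same normalized dict shape, which
    # _parse_response() below relies on. Illustrative example (the values
    # shown are hypothetical):
    #
    #   {
    #       'content': 'Hello! How can I help?',
    #       'model': 'google/gemma-2-2b-it',
    #       'usage': {'total_tokens': 14},
    #       'response_type': 'chat_completion',
    #   }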

    def _make_text_generation_request(self, prompt: str, params: GenerationParams) -> Dict[str, Any]:
        """Make a text generation request."""
        try:
            response = self.client.text_generation(
                prompt=prompt,
                model=self.model_name,
                max_new_tokens=params.max_tokens,
                temperature=params.temperature,
                do_sample=params.temperature > 0,
                top_p=params.top_p,
                stop_sequences=params.stop_sequences
            )

            # Handle response format
            if isinstance(response, str):
                content = response
            else:
                content = getattr(response, 'generated_text', str(response))

            return {
                'content': content,
                'model': self.model_name,
                'usage': {},
                'response_type': 'text_generation'
            }

        except Exception as e:
            logger.error(f"Text generation failed: {e}")
            raise

    def _parse_response(self, response: Dict[str, Any]) -> str:
        """
        Parse a HuggingFace response and extract the generated text.

        Args:
            response: HuggingFace API response

        Returns:
            Generated text
        """
        content = response.get('content', '')

        # Log usage if available; usage may be a plain dict or a dataclass-like
        # object depending on the huggingface_hub version
        usage = response.get('usage')
        if usage:
            if isinstance(usage, dict):
                total_tokens = usage.get('total_tokens', 0)
            else:
                total_tokens = getattr(usage, 'total_tokens', 0) or 0
            if total_tokens > 0:
                logger.debug(f"HuggingFace used {total_tokens} tokens for generation")

        return content

    def generate_streaming(self, prompt: str, params: GenerationParams) -> Iterator[str]:
        """
        Generate a streaming response from HuggingFace.

        Args:
            prompt: The prompt to send
            params: Generation parameters

        Yields:
            Generated text chunks
        """
        try:
            if self.use_chat_completion:
                # Try streaming chat completion
                messages = [{"role": "user", "content": prompt}]
                response = self.client.chat_completion(
                    messages=messages,
                    model=self.model_name,
                    temperature=params.temperature,
                    max_tokens=params.max_tokens,
                    stream=True
                )

                for chunk in response:
                    if hasattr(chunk, 'choices') and chunk.choices:
                        delta = chunk.choices[0].delta
                        if hasattr(delta, 'content') and delta.content:
                            yield delta.content
            else:
                # Fall back to non-streaming for text generation
                logger.warning("Streaming not supported for text generation, falling back to non-streaming")
                yield self.generate(prompt, params)

        except Exception as e:
            logger.error(f"Streaming generation failed: {e}")
            # Fall back to non-streaming
            yield self.generate(prompt, params)

    def _get_provider_name(self) -> str:
        """Return the provider name."""
        return "HuggingFace"

    def _validate_model(self) -> bool:
        """Check whether the current model is reachable on HuggingFace."""
        try:
            # Issue a tiny test request
            test_prompt = "Hello"

            if self.use_chat_completion:
                test_messages = [{"role": "user", "content": test_prompt}]
                response = self.client.chat_completion(
                    messages=test_messages,
                    model=self.model_name,
                    max_tokens=10,
                    temperature=0.1
                )
            else:
                response = self.client.text_generation(
                    prompt=test_prompt,
                    model=self.model_name,
                    max_new_tokens=10
                )

            return bool(response)

        except Exception as e:
            logger.warning(f"Model validation failed: {e}")
            return False

    def _supports_streaming(self) -> bool:
        """HuggingFace supports streaming for chat completion."""
        return self.use_chat_completion

    def _get_max_tokens(self) -> Optional[int]:
        """Get the maximum token limit for the current model."""
        # Model-specific limits (approximate)
        model_limits = {
            'microsoft/DialoGPT-medium': 1024,
            'google/gemma-2-2b-it': 8192,
            'meta-llama/Llama-3.2-3B-Instruct': 4096,
            'Qwen/Qwen2.5-1.5B-Instruct': 32768,
            'google/flan-t5-small': 512,
            'deepset/roberta-base-squad2': 512,
            'facebook/bart-base': 1024,
        }

        # Check for an exact match
        if self.model_name in model_limits:
            return model_limits[self.model_name]

        # Check for a partial match (e.g. a fine-tune whose name contains a
        # known base model inherits that base model's limit)
        for model_prefix, limit in model_limits.items():
            if model_prefix in self.model_name:
                return limit

        # Default for unknown models
        return 1024
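
    # _test_connection() below probes in three stages: the configured model
    # first, then the CHAT_MODELS candidates (when chat completion is enabled),
    # and finally the CLASSIC_MODELS candidates with chat completion disabled.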

    def _test_connection(self) -> None:
        """Test the connection and find the best working model."""
        logger.info("Testing HuggingFace API connection...")

        # Test the primary model first
        if self._validate_model():
            logger.info(f"Primary model {self.model_name} is working")
            return

        # Try chat models if using chat completion
        if self.use_chat_completion:
            original_model = self.model_name
            for model in self.CHAT_MODELS:
                if model == self.model_name:
                    continue
                try:
                    logger.info(f"Testing chat model: {model}")
                    self.model_name = model

                    if self._validate_model():
                        logger.info(f"Found working chat model: {model}")
                        return

                    # Restore the original if validation failed
                    self.model_name = original_model

                except Exception as e:
                    logger.warning(f"Chat model {model} failed: {e}")
                    # Restore the original before trying the next candidate
                    self.model_name = original_model
                    continue

        # Try classic models as a fallback
        logger.info("Trying classic text generation models...")
        original_model = self.model_name
        original_chat = self.use_chat_completion
        for model in self.CLASSIC_MODELS:
            try:
                logger.info(f"Testing classic model: {model}")
                self.model_name = model
                self.use_chat_completion = False

                if self._validate_model():
                    logger.info(f"Found working classic model: {model}")
                    return

                # Restore the original settings if validation failed
                self.model_name = original_model
                self.use_chat_completion = original_chat

            except Exception as e:
                logger.warning(f"Classic model {model} failed: {e}")
                # Restore the original settings before trying the next candidate
                self.model_name = original_model
                self.use_chat_completion = original_chat
                continue

        # If we get here, no models worked
        raise ModelNotFoundError(f"No working models found. Original model '{self.model_name}' is not accessible.")

    def _handle_provider_error(self, error: Exception) -> None:
        """Map HuggingFace-specific errors to standard errors."""
        error_msg = str(error).lower()

        if 'rate limit' in error_msg or '429' in error_msg:
            raise RateLimitError(f"HuggingFace rate limit exceeded: {error}")
        elif 'unauthorized' in error_msg or '401' in error_msg or 'token' in error_msg:
            raise AuthenticationError(f"HuggingFace authentication failed: {error}")
        elif 'not found' in error_msg or '404' in error_msg:
            raise ModelNotFoundError(f"HuggingFace model not found: {error}")
        elif 'timeout' in error_msg:
            raise LLMError(f"HuggingFace request timed out: {error}")
        elif 'connection' in error_msg:
            raise LLMError(f"Failed to connect to HuggingFace API: {error}")
        else:
            raise LLMError(f"HuggingFace API error: {error}")
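

if __name__ == "__main__":
    # Minimal smoke test; a sketch, not part of the adapter itself. It assumes
    # the package layout implied by the relative imports above (run it with
    # `python -m`), that HF_TOKEN is set in the environment, and that
    # BaseLLMAdapter exposes the generate() method used by generate_streaming()
    # above. The GenerationParams fields are inferred from their use in this
    # module; adjust them to the actual dataclass definition.
    adapter = HuggingFaceAdapter(
        model_name="google/gemma-2-2b-it",
        fallback_models=["Qwen/Qwen2.5-1.5B-Instruct"],
    )
    demo_params = GenerationParams(temperature=0.7, max_tokens=64)
    print(adapter.generate("Say hello in one short sentence.", demo_params))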