"""
HuggingFace LLM adapter implementation.

This adapter provides integration with the HuggingFace Inference API, handling
the specific API format and response structure of HuggingFace models.

Architecture Notes:
- Converts between unified interface and HuggingFace API format
- Handles both chat completion and text generation endpoints
- Supports automatic model selection and fallback
- Maps HuggingFace errors to standard LLMError types
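
Example (a minimal usage sketch; assumes a valid token in the HF_TOKEN
environment variable, and that GenerationParams accepts the fields shown and
the base class exposes generate(), as the streaming fallback below implies):

    adapter = HuggingFaceAdapter(model_name="google/gemma-2-2b-it")
    params = GenerationParams(temperature=0.7, max_tokens=128)
    text = adapter.generate("Explain retrieval-augmented generation briefly.", params)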
"""

import os
import logging
from typing import Dict, Any, Optional, List, Iterator

from .base_adapter import BaseLLMAdapter, LLMError, ModelNotFoundError, AuthenticationError, RateLimitError
from ..base import GenerationParams

logger = logging.getLogger(__name__)

# Check for HuggingFace Hub availability
try:
    from huggingface_hub import InferenceClient
    HF_HUB_AVAILABLE = True
except ImportError:
    HF_HUB_AVAILABLE = False
    logger.warning("huggingface_hub not available. Install with: pip install huggingface-hub")


class HuggingFaceAdapter(BaseLLMAdapter):
    """
    Adapter for HuggingFace Inference API integration.
    
    Features:
    - Support for both chat completion and text generation endpoints
    - Automatic model selection with fallback across multiple models
    - OpenAI-compatible chat completion format
    - Maps provider errors to the standard LLMError hierarchy
    
    Configuration:
    - api_token: HuggingFace API token (required)
    - timeout: Request timeout in seconds (default: 30)
    - use_chat_completion: Prefer chat completion over text generation
    - fallback_models: List of models to try if primary fails
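
    Example (a constructor sketch; model names and values are illustrative):

        adapter = HuggingFaceAdapter(
            model_name="Qwen/Qwen2.5-1.5B-Instruct",
            timeout=60,
            use_chat_completion=True,
            fallback_models=["google/flan-t5-small"],
        )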
    """
    
    # Models that work well with chat completion format
    CHAT_MODELS = [
        "microsoft/DialoGPT-medium",       # Proven conversational model
        "google/gemma-2-2b-it",            # Instruction-tuned, good for Q&A
        "meta-llama/Llama-3.2-3B-Instruct", # If available with token
        "Qwen/Qwen2.5-1.5B-Instruct",     # Small, fast, good quality
    ]
    
    # Fallback models for classic text generation
    CLASSIC_MODELS = [
        "google/flan-t5-small",            # Good for instructions
        "deepset/roberta-base-squad2",     # Q&A specific
        "facebook/bart-base",              # Summarization
    ]
    
    def __init__(self,
                 model_name: str = "microsoft/DialoGPT-medium",
                 api_token: Optional[str] = None,
                 timeout: int = 30,
                 use_chat_completion: bool = True,
                 fallback_models: Optional[List[str]] = None,
                 config: Optional[Dict[str, Any]] = None):
        """
        Initialize HuggingFace adapter.
        
        Args:
            model_name: HuggingFace model name
            api_token: HuggingFace API token
            timeout: Request timeout in seconds
            use_chat_completion: Prefer chat completion over text generation
            fallback_models: List of fallback models to try
            config: Additional configuration
        """
        if not HF_HUB_AVAILABLE:
            raise ImportError("huggingface_hub is required for HuggingFace adapter. Install with: pip install huggingface-hub")
        
        # Get API token from various sources
        self.api_token = (
            api_token or 
            os.getenv("HUGGINGFACE_API_TOKEN") or 
            os.getenv("HF_TOKEN") or 
            os.getenv("HF_API_TOKEN")
        )
        
        if not self.api_token:
            raise AuthenticationError("HuggingFace API token required. Set HF_TOKEN environment variable or pass api_token parameter.")
        
        # Merge configuration
        adapter_config = {
            'api_token': self.api_token,
            'timeout': timeout,
            'use_chat_completion': use_chat_completion,
            'fallback_models': fallback_models or [],
            **(config or {})
        }
        
        super().__init__(model_name, adapter_config)
        
        self.timeout = adapter_config['timeout']
        self.use_chat_completion = adapter_config['use_chat_completion']
        self.fallback_models = adapter_config['fallback_models']
        
        # Initialize client, passing the configured timeout so it is honored
        self.client = InferenceClient(token=self.api_token, timeout=self.timeout)
        
        # Test connection and determine best model (only if not using dummy token)
        if not self.api_token.startswith("dummy_"):
            self._test_connection()
        else:
            logger.info("Using dummy token, skipping connection test")
        
        logger.info(f"Initialized HuggingFace adapter for model '{self.model_name}' (chat_completion: {self.use_chat_completion})")
    
    def _make_request(self, prompt: str, params: GenerationParams) -> Dict[str, Any]:
        """
        Make a request to HuggingFace API.
        
        Args:
            prompt: The prompt to send
            params: Generation parameters
            
        Returns:
            Normalized response dict with 'content', 'model', 'usage', and
            'response_type' keys
            
        Raises:
            LLMError (or a subclass) if the primary model and all configured
            fallback models fail
        """
        try:
            if self.use_chat_completion:
                return self._make_chat_completion_request(prompt, params)
            else:
                return self._make_text_generation_request(prompt, params)
        except Exception as e:
            # Try fallback models if the primary fails. Capture the original
            # model name once, up front, so it can always be restored even if
            # an attempt fails before the per-iteration model switch.
            original_model = self.model_name
            for fallback_model in self.fallback_models:
                try:
                    logger.info(f"Trying fallback model: {fallback_model}")
                    self.model_name = fallback_model
                    
                    if self.use_chat_completion:
                        result = self._make_chat_completion_request(prompt, params)
                    else:
                        result = self._make_text_generation_request(prompt, params)
                    
                    # Success: keep the working fallback model for later calls
                    logger.info(f"Successfully used fallback model: {fallback_model}")
                    return result
                    
                except Exception as fallback_error:
                    logger.warning(f"Fallback model {fallback_model} failed: {fallback_error}")
                    # Restore the original model name before the next attempt
                    self.model_name = original_model
                    continue
            
            # All models failed; map and re-raise as a standard error
            self._handle_provider_error(e)
            raise  # unreachable: _handle_provider_error always raises
    
    def _make_chat_completion_request(self, prompt: str, params: GenerationParams) -> Dict[str, Any]:
        """Make a chat completion request."""
        messages = [{"role": "user", "content": prompt}]
        
        try:
            response = self.client.chat_completion(
                messages=messages,
                model=self.model_name,
                temperature=params.temperature,
                max_tokens=params.max_tokens,
                stream=False
            )
            
            # Extract content from response
            if hasattr(response, 'choices') and response.choices:
                content = response.choices[0].message.content
                return {
                    'content': content,
                    'model': self.model_name,
                    'usage': getattr(response, 'usage', {}),
                    'response_type': 'chat_completion'
                }
            else:
                # Handle different response formats
                if hasattr(response, 'generated_text'):
                    content = response.generated_text
                else:
                    content = str(response)
                
                return {
                    'content': content,
                    'model': self.model_name,
                    'usage': {},
                    'response_type': 'chat_completion'
                }
                
        except Exception as e:
            logger.error(f"Chat completion failed: {e}")
            raise
    
    def _make_text_generation_request(self, prompt: str, params: GenerationParams) -> Dict[str, Any]:
        """Make a text generation request."""
        try:
            response = self.client.text_generation(
                model=self.model_name,
                prompt=prompt,
                max_new_tokens=params.max_tokens,
                temperature=params.temperature,
                do_sample=params.temperature > 0,
                top_p=params.top_p,
                stop_sequences=params.stop_sequences
            )
            
            # Handle response format
            if isinstance(response, str):
                content = response
            else:
                content = getattr(response, 'generated_text', str(response))
            
            return {
                'content': content,
                'model': self.model_name,
                'usage': {},
                'response_type': 'text_generation'
            }
            
        except Exception as e:
            logger.error(f"Text generation failed: {e}")
            raise
    
    def _parse_response(self, response: Dict[str, Any]) -> str:
        """
        Parse HuggingFace response to extract generated text.
        
        Args:
            response: HuggingFace API response
            
        Returns:
            Generated text
        """
        content = response.get('content', '')
        
        # Log usage if available (may be a dict or a usage object, depending
        # on the huggingface_hub version)
        usage = response.get('usage')
        if usage:
            if isinstance(usage, dict):
                total_tokens = usage.get('total_tokens', 0)
            else:
                total_tokens = getattr(usage, 'total_tokens', 0) or 0
            if total_tokens > 0:
                logger.debug(f"HuggingFace used {total_tokens} tokens for generation")
        
        return content
    
    def generate_streaming(self, prompt: str, params: GenerationParams) -> Iterator[str]:
        """
        Generate a streaming response from HuggingFace.
        
        Args:
            prompt: The prompt to send
            params: Generation parameters
            
        Yields:
            Generated text chunks
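
        Example (a sketch; `adapter` and `params` as in the module
        docstring example):

            for chunk in adapter.generate_streaming("Hello", params):
                print(chunk, end="", flush=True)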
        """
        try:
            if self.use_chat_completion:
                # Try streaming chat completion
                messages = [{"role": "user", "content": prompt}]
                
                response = self.client.chat_completion(
                    messages=messages,
                    model=self.model_name,
                    temperature=params.temperature,
                    max_tokens=params.max_tokens,
                    stream=True
                )
                
                for chunk in response:
                    if hasattr(chunk, 'choices') and chunk.choices:
                        delta = chunk.choices[0].delta
                        if hasattr(delta, 'content') and delta.content:
                            yield delta.content
            else:
                # Fallback to non-streaming for text generation
                logger.warning("Streaming not supported for text generation, falling back to non-streaming")
                yield self.generate(prompt, params)
                
        except Exception as e:
            logger.error(f"Streaming generation failed: {e}")
            # Fallback to non-streaming
            yield self.generate(prompt, params)
    
    def _get_provider_name(self) -> str:
        """Return the provider name."""
        return "HuggingFace"
    
    def _validate_model(self) -> bool:
        """Check if the model exists in HuggingFace."""
        try:
            # Try a simple test request
            test_prompt = "Hello"
            
            if self.use_chat_completion:
                test_messages = [{"role": "user", "content": test_prompt}]
                response = self.client.chat_completion(
                    messages=test_messages,
                    model=self.model_name,
                    max_tokens=10,
                    temperature=0.1
                )
            else:
                response = self.client.text_generation(
                    model=self.model_name,
                    prompt=test_prompt,
                    max_new_tokens=10
                )
            
            return bool(response)
            
        except Exception as e:
            logger.warning(f"Model validation failed: {e}")
            return False
    
    def _supports_streaming(self) -> bool:
        """HuggingFace supports streaming for chat completion."""
        return self.use_chat_completion
    
    def _get_max_tokens(self) -> Optional[int]:
        """Get max tokens for current model."""
        # Model-specific limits (approximate)
        model_limits = {
            'microsoft/DialoGPT-medium': 1024,
            'google/gemma-2-2b-it': 8192,
            'meta-llama/Llama-3.2-3B-Instruct': 4096,
            'Qwen/Qwen2.5-1.5B-Instruct': 32768,
            'google/flan-t5-small': 512,
            'deepset/roberta-base-squad2': 512,
            'facebook/bart-base': 1024,
        }
        
        # Check for exact match
        if self.model_name in model_limits:
            return model_limits[self.model_name]
        
        # Check for partial match
        for model_prefix, limit in model_limits.items():
            if model_prefix in self.model_name:
                return limit
        
        # Default for unknown models
        return 1024
    
    def _test_connection(self) -> None:
        """Test the connection and find the best working model."""
        logger.info("Testing HuggingFace API connection...")
        
        # Test primary model first
        if self._validate_model():
            logger.info(f"Primary model {self.model_name} is working")
            return
        
        # Try chat models if using chat completion
        if self.use_chat_completion:
            original_model = self.model_name
            for model in self.CHAT_MODELS:
                if model == self.model_name:
                    continue
                try:
                    logger.info(f"Testing chat model: {model}")
                    self.model_name = model
                    
                    if self._validate_model():
                        logger.info(f"Found working chat model: {model}")
                        return
                    
                    # Restore the original model before trying the next one
                    self.model_name = original_model
                    
                except Exception as e:
                    logger.warning(f"Chat model {model} failed: {e}")
                    # Restore even on exception so state never stays mutated
                    self.model_name = original_model
                    continue
        
        # Try classic text generation models as a last resort
        logger.info("Trying classic text generation models...")
        original_model = self.model_name
        original_chat = self.use_chat_completion
        for model in self.CLASSIC_MODELS:
            try:
                logger.info(f"Testing classic model: {model}")
                self.model_name = model
                self.use_chat_completion = False
                
                if self._validate_model():
                    logger.info(f"Found working classic model: {model}")
                    return
                
                # Restore original settings before the next candidate
                self.model_name = original_model
                self.use_chat_completion = original_chat
                
            except Exception as e:
                logger.warning(f"Classic model {model} failed: {e}")
                # Restore even on exception so state never stays mutated
                self.model_name = original_model
                self.use_chat_completion = original_chat
                continue
        
        # If we get here, no models worked
        raise ModelNotFoundError(f"No working models found. Original model '{self.model_name}' is not accessible.")
    
    def _handle_provider_error(self, error: Exception) -> None:
        """Map HuggingFace-specific errors to standard errors."""
        error_msg = str(error).lower()
        
        if 'rate limit' in error_msg or '429' in error_msg:
            raise RateLimitError(f"HuggingFace rate limit exceeded: {error}")
        elif 'unauthorized' in error_msg or '401' in error_msg or '403' in error_msg or 'invalid token' in error_msg:
            # A bare 'token' match would misclassify token-limit errors as
            # authentication failures, so the match is kept narrow.
            raise AuthenticationError(f"HuggingFace authentication failed: {error}")
        elif 'not found' in error_msg or '404' in error_msg:
            raise ModelNotFoundError(f"HuggingFace model not found: {error}")
        elif 'timeout' in error_msg:
            raise LLMError(f"HuggingFace request timed out: {error}")
        elif 'connection' in error_msg:
            raise LLMError(f"Failed to connect to HuggingFace API: {error}")
        else:
            raise LLMError(f"HuggingFace API error: {error}")