import time
import logging
from typing import List, Dict, Optional, Union
from src.llm.enhanced_provider import EnhancedLLMProvider
from utils.config import config
from src.services.context_provider import context_provider

# The OpenAI SDK is optional; the HF endpoint speaks its chat-completions API.
try:
    from openai import OpenAI
    HF_SDK_AVAILABLE = True
except ImportError:
    HF_SDK_AVAILABLE = False
    OpenAI = None

logger = logging.getLogger(__name__)

class HuggingFaceProvider(EnhancedLLMProvider):
    """Hugging Face LLM provider for your custom endpoint"""

    def __init__(self, model_name: str, timeout: int = 120, max_retries: int = 2):
        super().__init__(model_name, timeout, max_retries)
        
        if not HF_SDK_AVAILABLE:
            raise ImportError("Hugging Face provider requires 'openai' package")
            
        if not config.hf_token:
            raise ValueError("HF_TOKEN not set - required for Hugging Face provider")

        # Point the OpenAI-compatible client at the configured HF endpoint
        self.client = OpenAI(
            base_url=config.hf_api_url,
            api_key=config.hf_token
        )
        logger.info(f"Initialized HF provider with endpoint: {config.hf_api_url}")

    def generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
        """Generate a response synchronously."""
        # Enrich the history with relevant context before calling the endpoint,
        # so the retry path below reuses the same enriched messages.
        enriched_history = self._enrich_context_intelligently(conversation_history)
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=enriched_history,
                max_tokens=8192,
                temperature=0.7,
                stream=False
            )
            return response.choices[0].message.content
        except Exception as e:
            logger.error(f"HF generation failed: {e}")
            # A scaled-to-zero endpoint rejects requests while it spins up;
            # wait for it to initialize, then retry once.
            if self._is_scale_to_zero_error(e):
                logger.info("HF endpoint is scaling up, waiting...")
                time.sleep(60)
                response = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=enriched_history,
                    max_tokens=8192,
                    temperature=0.7,
                    stream=False
                )
                return response.choices[0].message.content
            raise

    def stream_generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[Union[str, List[str]]]:
        """Generate a response with streaming support."""
        enriched_history = self._enrich_context_intelligently(conversation_history)
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=enriched_history,
                max_tokens=8192,
                temperature=0.7,
                stream=True
            )
            return self._collect_stream(response)
        except Exception as e:
            logger.error(f"HF stream generation failed: {e}")
            # Same scale-to-zero handling as generate(): wait, then retry once
            # with the already-enriched history.
            if self._is_scale_to_zero_error(e):
                logger.info("HF endpoint is scaling up, waiting...")
                time.sleep(60)
                response = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=enriched_history,
                    max_tokens=8192,
                    temperature=0.7,
                    stream=True
                )
                return self._collect_stream(response)
            raise

    @staticmethod
    def _collect_stream(response) -> List[str]:
        """Drain a streaming response into an ordered list of content chunks."""
        chunks = []
        for chunk in response:
            content = chunk.choices[0].delta.content
            if content:
                chunks.append(content)
        return chunks

    def _enrich_context_intelligently(self, conversation_history: List[Dict]) -> List[Dict]:
        """Intelligently add context only when relevant"""
        if not conversation_history:
            return conversation_history
        
        # Get the last user message to determine context needs
        last_user_message = ""
        for msg in reversed(conversation_history):
            if msg["role"] == "user":
                last_user_message = msg["content"]
                break
        
        # Get intelligent context
        context_string = context_provider.get_context_for_llm(
            last_user_message, 
            conversation_history
        )
        
        # Only add context if it's relevant
        if context_string:
            context_message = {
                "role": "system",
                "content": context_string
            }
            # Insert context at the beginning
            enriched_history = [context_message] + conversation_history
            return enriched_history
        
        # Return original history if no context needed
        return conversation_history

    def _is_scale_to_zero_error(self, error: Exception) -> bool:
        """Check if the error is related to scale-to-zero initialization"""
        error_str = str(error).lower()
        scale_to_zero_indicators = [
            "503",
            "service unavailable",
            "initializing",
            "cold start"
        ]
        return any(indicator in error_str for indicator in scale_to_zero_indicators)
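
# Minimal usage sketch, assuming config.hf_api_url and config.hf_token are set
# and the endpoint serves a chat model. The model name below is a placeholder,
# not something this module prescribes.
if __name__ == "__main__":
    provider = HuggingFaceProvider(model_name="tgi")  # hypothetical model name
    history = [{"role": "user", "content": "What does scale-to-zero mean?"}]
    print(provider.generate("", history))  # generate() reads the history, not the prompt arg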