import hashlib
import logging
import time
from typing import Dict, List, Optional, Union

from src.llm.base_provider import LLMProvider
from src.llm.hf_provider import HuggingFaceProvider
from src.llm.ollama_provider import OllamaProvider
from core.session import session_manager
from utils.config import config

logger = logging.getLogger(__name__)

class HybridProvider(LLMProvider):
"""Hybrid provider that uses HF for heavy lifting and Ollama for local caching/summarization"""
def __init__(self, model_name: str, timeout: int = 120, max_retries: int = 2):
super().__init__(model_name, timeout, max_retries)
self.hf_provider = None
self.ollama_provider = None
# Initialize providers
try:
if config.hf_token:
self.hf_provider = HuggingFaceProvider(
model_name="DavidAU/OpenAi-GPT-oss-20b-abliterated-uncensored-NEO-Imatrix-gguf",
timeout=120
)
except Exception as e:
logger.warning(f"Failed to initialize HF provider: {e}")
try:
if config.ollama_host:
self.ollama_provider = OllamaProvider(
model_name=config.local_model_name,
timeout=60
)
except Exception as e:
logger.warning(f"Failed to initialize Ollama provider: {e}")
    def generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
        """Generate a response using the hybrid approach."""
        try:
            # Step 1: Get heavy lifting from the HF Endpoint
            hf_response = self._get_hf_response(prompt, conversation_history)
            if not hf_response:
                raise Exception("HF Endpoint failed to provide response")

            # Step 2: Store the HF response in the local cache via Ollama
            self._cache_response_locally(prompt, hf_response, conversation_history)

            # Step 3: Optionally create a local summary (if needed)
            # For now, return the HF response directly but with a local backup
            return hf_response
        except Exception as e:
            logger.error(f"Hybrid generation failed: {e}")
            # Fall back to Ollama if available
            if self.ollama_provider:
                try:
                    logger.info("Falling back to Ollama for local response")
                    return self.ollama_provider.generate(prompt, conversation_history)
                except Exception as fallback_error:
                    logger.error(f"Ollama fallback also failed: {fallback_error}")
            raise Exception(f"Both HF Endpoint and Ollama failed: {str(e)}")
    def stream_generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[Union[str, List[str]]]:
        """Stream a response using the hybrid approach."""
        try:
            # Get the streaming response from HF, falling back to Ollama
            if self.hf_provider:
                return self.hf_provider.stream_generate(prompt, conversation_history)
            elif self.ollama_provider:
                return self.ollama_provider.stream_generate(prompt, conversation_history)
            else:
                raise Exception("No providers available")
        except Exception as e:
            logger.error(f"Hybrid stream generation failed: {e}")
            raise
    def _get_hf_response(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
        """Get a response from the HF Endpoint with fallback handling."""
        if not self.hf_provider:
            return None
        try:
            logger.info("🚀 Getting detailed response from HF Endpoint...")
            response = self.hf_provider.generate(prompt, conversation_history)
            logger.info("✅ HF Endpoint response received")
            return response
        except Exception as e:
            logger.error(f"HF Endpoint failed: {e}")
            # Don't raise here; let the hybrid provider handle the fallback
            return None
    def _cache_response_locally(self, prompt: str, response: str, conversation_history: List[Dict]):
        """Cache the HF response locally for future reference."""
        if not self.ollama_provider:
            return
        try:
            # Create a simplified cache entry (truncate long responses)
            cache_prompt = f"Cache this response for future reference:\n\nQuestion: {prompt}\n\nResponse: {response[:500]}..."
            # Warm the local Ollama model with the cached context; the durable
            # copy lives in the Redis session cache written below, which helps
            # if the HF connection fails later
            logger.info("💾 Caching response locally with Ollama...")
            self.ollama_provider.generate(cache_prompt, [])
            # Also store in the Redis session for persistence
            self._store_in_session_cache(prompt, response)
        except Exception as e:
            logger.warning(f"Failed to cache response locally: {e}")
    def _store_in_session_cache(self, prompt: str, response: str):
        """Store the response in the Redis session cache."""
        try:
            user_session = session_manager.get_session("default_user")
            cache = user_session.get("response_cache", {})
            # Stable cache key: built-in hash() is randomized per process,
            # which would break lookups against a persisted cache
            cache_key = hashlib.sha256(prompt.encode("utf-8")).hexdigest()[:16]
            cache[cache_key] = {
                "prompt": prompt,
                "response": response,
                "timestamp": time.time(),
            }
            # Keep only the 50 most recent cached responses
            if len(cache) > 50:
                sorted_keys = sorted(cache.keys(), key=lambda k: cache[k]["timestamp"])
                for key in sorted_keys[:-50]:
                    del cache[key]
            user_session["response_cache"] = cache
            session_manager.update_session("default_user", user_session)
        except Exception as e:
            logger.warning(f"Failed to store in session cache: {e}")
    def get_cached_response(self, prompt: str) -> Optional[str]:
        """Get a cached response if one is available and still fresh."""
        try:
            user_session = session_manager.get_session("default_user")
            cache = user_session.get("response_cache", {})
            cache_key = hashlib.sha256(prompt.encode("utf-8")).hexdigest()[:16]
            if cache_key in cache:
                cached_entry = cache[cache_key]
                # Check whether the cached entry is still valid (1-hour TTL)
                if time.time() - cached_entry["timestamp"] < 3600:
                    return cached_entry["response"]
        except Exception as e:
            logger.warning(f"Failed to retrieve cached response: {e}")
        return None
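

# Minimal usage sketch (illustrative only). It assumes `config` exposes
# `hf_token`, `ollama_host`, and `local_model_name`, and that `session_manager`
# is backed by Redis as described above; the model name and question below are
# placeholder values, and real calling code would normally live elsewhere.
if __name__ == "__main__":
    provider = HybridProvider(model_name="hybrid-default")

    question = "Summarize the benefits of a hybrid LLM setup."

    # Check the Redis-backed cache before spending an HF Endpoint call
    answer = provider.get_cached_response(question)
    if answer is None:
        answer = provider.generate(question, conversation_history=[])

    print(answer)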