import hashlib
import logging
import time
from typing import List, Dict, Optional, Union

from src.llm.base_provider import LLMProvider
from src.llm.hf_provider import HuggingFaceProvider
from src.llm.ollama_provider import OllamaProvider
from core.session import session_manager
from utils.config import config

logger = logging.getLogger(__name__)


class HybridProvider(LLMProvider):
    """Hybrid provider that uses the HF endpoint for heavy lifting and Ollama for local caching/summarization."""

    def __init__(self, model_name: str, timeout: int = 120, max_retries: int = 2):
        super().__init__(model_name, timeout, max_retries)
        self.hf_provider = None
        self.ollama_provider = None

        # Initialize the remote HF endpoint provider if a token is configured.
        try:
            if config.hf_token:
                self.hf_provider = HuggingFaceProvider(
                    model_name="DavidAU/OpenAi-GPT-oss-20b-abliterated-uncensored-NEO-Imatrix-gguf",
                    timeout=120,
                )
        except Exception as e:
            logger.warning(f"Failed to initialize HF provider: {e}")

        # Initialize the local Ollama provider if a host is configured.
        try:
            if config.ollama_host:
                self.ollama_provider = OllamaProvider(
                    model_name=config.local_model_name,
                    timeout=60,
                )
        except Exception as e:
            logger.warning(f"Failed to initialize Ollama provider: {e}")

    def generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
        """Generate a response using the hybrid approach: HF endpoint first, Ollama as fallback."""
        try:
            # Primary path: get the detailed response from the HF endpoint.
            hf_response = self._get_hf_response(prompt, conversation_history)
            if not hf_response:
                raise Exception("HF Endpoint failed to provide response")

            # Cache the HF response locally so it can be reused later.
            self._cache_response_locally(prompt, hf_response, conversation_history)
            return hf_response

        except Exception as e:
            logger.error(f"Hybrid generation failed: {e}")

            # Fall back to the local Ollama provider if it is available.
            if self.ollama_provider:
                try:
                    logger.info("Falling back to Ollama for local response")
                    return self.ollama_provider.generate(prompt, conversation_history)
                except Exception as fallback_error:
                    logger.error(f"Ollama fallback also failed: {fallback_error}")

            raise Exception(f"Both HF Endpoint and Ollama failed: {str(e)}")

    def stream_generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[Union[str, List[str]]]:
        """Stream response using hybrid approach"""
        try:
            if self.hf_provider:
                return self.hf_provider.stream_generate(prompt, conversation_history)
            elif self.ollama_provider:
                return self.ollama_provider.stream_generate(prompt, conversation_history)
            else:
                raise Exception("No providers available")
        except Exception as e:
            logger.error(f"Hybrid stream generation failed: {e}")
            raise

    def _get_hf_response(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
        """Get a response from the HF endpoint, returning None on failure so callers can fall back."""
        if not self.hf_provider:
            return None

        try:
            logger.info("🚀 Getting detailed response from HF Endpoint...")
            response = self.hf_provider.generate(prompt, conversation_history)
            logger.info("✅ HF Endpoint response received")
            return response
        except Exception as e:
            logger.error(f"HF Endpoint failed: {e}")
            return None

    def _cache_response_locally(self, prompt: str, response: str, conversation_history: List[Dict]):
        """Cache the HF response locally using Ollama and the session cache for future reference."""
        if not self.ollama_provider:
            return

        try:
            # Truncate the response so the local caching prompt stays small.
            cache_prompt = (
                f"Cache this response for future reference:\n\n"
                f"Question: {prompt}\n\nResponse: {response[:500]}..."
            )

            logger.info("💾 Caching response locally with Ollama...")
            self.ollama_provider.generate(cache_prompt, [])

            # Also store the full response in the session cache for quick lookups.
            self._store_in_session_cache(prompt, response)

        except Exception as e:
            logger.warning(f"Failed to cache response locally: {e}")

    def _store_in_session_cache(self, prompt: str, response: str):
        """Store the response in the Redis session cache."""
        try:
            user_session = session_manager.get_session("default_user")
            cache = user_session.get("response_cache", {})

            # Use a stable content hash as the key; Python's built-in hash() is
            # randomized per process, so its keys would not survive restarts.
            cache_key = hashlib.sha256(prompt.encode("utf-8")).hexdigest()[:16]
            cache[cache_key] = {
                "prompt": prompt,
                "response": response,
                "timestamp": time.time(),
            }

            # Keep only the 50 most recent entries.
            if len(cache) > 50:
                sorted_keys = sorted(cache.keys(), key=lambda k: cache[k]["timestamp"])
                for key in sorted_keys[:-50]:
                    del cache[key]

            user_session["response_cache"] = cache
            session_manager.update_session("default_user", user_session)

        except Exception as e:
            logger.warning(f"Failed to store in session cache: {e}")

    def get_cached_response(self, prompt: str) -> Optional[str]:
        """Return the cached response for a prompt if it exists and is less than an hour old."""
        try:
            user_session = session_manager.get_session("default_user")
            cache = user_session.get("response_cache", {})

            # Use the same stable content hash as _store_in_session_cache.
            cache_key = hashlib.sha256(prompt.encode("utf-8")).hexdigest()[:16]
            if cache_key in cache:
                cached_entry = cache[cache_key]
                # Only serve cache hits younger than one hour.
                if time.time() - cached_entry["timestamp"] < 3600:
                    return cached_entry["response"]
        except Exception as e:
            logger.warning(f"Failed to retrieve cached response: {e}")

        return None
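

# Minimal usage sketch, not part of the provider itself: it assumes config supplies a
# valid hf_token and/or ollama_host, and that conversation history entries follow the
# common {"role": ..., "content": ...} shape. The model name below is only illustrative.
if __name__ == "__main__":
    provider = HybridProvider(model_name="hybrid-default")
    history = [{"role": "user", "content": "Hello!"}]
    question = "Explain the hybrid fallback flow."

    # Check the session cache first; otherwise generate (HF endpoint with Ollama fallback).
    cached = provider.get_cached_response(question)
    if cached:
        print(cached)
    else:
        print(provider.generate(question, history))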