# src/llm/hybrid_provider.py
import hashlib
import logging
import time
from typing import Dict, List, Optional, Union

from src.llm.base_provider import LLMProvider
from src.llm.hf_provider import HuggingFaceProvider
from src.llm.ollama_provider import OllamaProvider
from core.session import session_manager
from utils.config import config

logger = logging.getLogger(__name__)
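
# How the hybrid flow works (summarized from the methods below):
#   1. generate() asks the HF Endpoint for the detailed ("heavy lifting") response.
#   2. The response is pushed through the local Ollama model and stored in the
#      Redis-backed session cache so it can be reused later.
#   3. If the HF Endpoint fails, generate() falls back to Ollama directly.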

class HybridProvider(LLMProvider):
    """Hybrid provider that uses HF for heavy lifting and Ollama for local caching/summarization."""

    def __init__(self, model_name: str, timeout: int = 120, max_retries: int = 2):
        super().__init__(model_name, timeout, max_retries)
        self.hf_provider = None
        self.ollama_provider = None

        # Initialize providers
        try:
            if config.hf_token:
                self.hf_provider = HuggingFaceProvider(
                    model_name="DavidAU/OpenAi-GPT-oss-20b-abliterated-uncensored-NEO-Imatrix-gguf",
                    timeout=120
                )
        except Exception as e:
            logger.warning(f"Failed to initialize HF provider: {e}")

        try:
            if config.ollama_host:
                self.ollama_provider = OllamaProvider(
                    model_name=config.local_model_name,
                    timeout=60
                )
        except Exception as e:
            logger.warning(f"Failed to initialize Ollama provider: {e}")

    def generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
        """Generate a response using the hybrid approach."""
        try:
            # Step 1: Get the heavy-lifting response from the HF Endpoint
            hf_response = self._get_hf_response(prompt, conversation_history)
            if not hf_response:
                raise Exception("HF Endpoint failed to provide response")

            # Step 2: Store the HF response in the local cache via Ollama
            self._cache_response_locally(prompt, hf_response, conversation_history)

            # Step 3: Optionally create a local summary (not implemented yet);
            # for now, return the HF response directly, with a local backup cached.
            return hf_response
        except Exception as e:
            logger.error(f"Hybrid generation failed: {e}")
            # Fall back to Ollama if a local provider is available
            if self.ollama_provider:
                try:
                    logger.info("Falling back to Ollama for local response")
                    return self.ollama_provider.generate(prompt, conversation_history)
                except Exception as fallback_error:
                    logger.error(f"Ollama fallback also failed: {fallback_error}")
            raise Exception(f"Both HF Endpoint and Ollama failed: {str(e)}")

    def stream_generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[Union[str, List[str]]]:
        """Stream a response using the hybrid approach."""
        try:
            # Get the streaming response from HF, falling back to Ollama
            if self.hf_provider:
                return self.hf_provider.stream_generate(prompt, conversation_history)
            elif self.ollama_provider:
                return self.ollama_provider.stream_generate(prompt, conversation_history)
            else:
                raise Exception("No providers available")
        except Exception as e:
            logger.error(f"Hybrid stream generation failed: {e}")
            raise

    def _get_hf_response(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
        """Get a response from the HF Endpoint, returning None on failure."""
        if not self.hf_provider:
            return None
        try:
            logger.info("🚀 Getting detailed response from HF Endpoint...")
            response = self.hf_provider.generate(prompt, conversation_history)
            logger.info("✅ HF Endpoint response received")
            return response
        except Exception as e:
            logger.error(f"HF Endpoint failed: {e}")
            # Don't raise here; let the hybrid provider handle the fallback
            return None

    def _cache_response_locally(self, prompt: str, response: str, conversation_history: List[Dict]):
        """Cache the HF response locally using Ollama for future reference."""
        if not self.ollama_provider:
            return
        try:
            # Create a simplified cache entry (truncate long responses)
            cache_prompt = f"Cache this response for future reference:\n\nQuestion: {prompt}\n\nResponse: {response[:500]}..."
            # Run the cache prompt through local Ollama; this helps if the HF connection fails later
            logger.info("💾 Caching response locally with Ollama...")
            self.ollama_provider.generate(cache_prompt, [])
            # Also store in the Redis session cache for persistence
            self._store_in_session_cache(prompt, response)
        except Exception as e:
            logger.warning(f"Failed to cache response locally: {e}")

    def _store_in_session_cache(self, prompt: str, response: str):
        """Store the response in the Redis session cache."""
        try:
            user_session = session_manager.get_session("default_user")
            cache = user_session.get("response_cache", {})
            # Stable cache key: hash() is salted per process, so use hashlib for persistence
            cache_key = hashlib.sha256(prompt.encode("utf-8")).hexdigest()[:16]
            cache[cache_key] = {
                "prompt": prompt,
                "response": response,
                "timestamp": time.time()
            }
            # Keep only the 50 most recent cached responses
            if len(cache) > 50:
                # Remove the oldest entries
                sorted_keys = sorted(cache.keys(), key=lambda k: cache[k]["timestamp"])
                for key in sorted_keys[:-50]:
                    del cache[key]
            user_session["response_cache"] = cache
            session_manager.update_session("default_user", user_session)
        except Exception as e:
            logger.warning(f"Failed to store in session cache: {e}")

    def get_cached_response(self, prompt: str) -> Optional[str]:
        """Get a cached response if one is available and still fresh."""
        try:
            user_session = session_manager.get_session("default_user")
            cache = user_session.get("response_cache", {})
            # Must match the key scheme used in _store_in_session_cache
            cache_key = hashlib.sha256(prompt.encode("utf-8")).hexdigest()[:16]
            if cache_key in cache:
                cached_entry = cache[cache_key]
                # Check whether the cache entry is still valid (1-hour TTL)
                if time.time() - cached_entry["timestamp"] < 3600:
                    return cached_entry["response"]
        except Exception as e:
            logger.warning(f"Failed to retrieve cached response: {e}")
        return None
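

# Illustrative usage sketch (not part of the provider itself): a minimal example of
# how this class might be wired up, assuming `config.hf_token` / `config.ollama_host`
# are set and that the session manager can serve a "default_user" session. The model
# name "hybrid" and the sample question below are placeholders, not real values.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    provider = HybridProvider(model_name="hybrid", timeout=120)
    question = "How can I build a consistent morning routine?"

    # Check the Redis-backed session cache first, then fall back to full generation
    answer = provider.get_cached_response(question)
    if answer is None:
        answer = provider.generate(question, conversation_history=[])

    print(answer)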