import json
import logging
import re
from typing import List, Dict, Optional, Union

import requests

from core.providers.base import LLMProvider
from utils.config import config

logger = logging.getLogger(__name__)


class OllamaProvider(LLMProvider):
    """Ollama LLM provider implementation."""

    def __init__(self, model_name: str, timeout: int = 60, max_retries: int = 3):
        # Timeout increased from 30 to 60 seconds to accommodate slower models
        super().__init__(model_name, timeout, max_retries)
        self.host = self._sanitize_host(config.ollama_host or "http://localhost:11434")
        # Headers to skip the ngrok browser warning page
        self.headers = {
            "ngrok-skip-browser-warning": "true",
            "User-Agent": "AI-Life-Coach-Ollama",
        }

    def _sanitize_host(self, host: str) -> str:
        """Sanitize the host URL by removing whitespace and control characters."""
        if not host:
            return "http://localhost:11434"
        # Remove leading/trailing whitespace
        host = host.strip()
        # Remove any newlines, tabs, or other control characters
        host = re.sub(r'[\r\n\t\0]+', '', host)
        # Ensure the URL has a scheme
        if not host.startswith(('http://', 'https://')):
            host = 'http://' + host
        return host

    def generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
        """Generate a response synchronously."""
        try:
            return self._retry_with_backoff(self._generate_impl, prompt, conversation_history)
        except Exception as e:
            logger.error(f"Ollama generation failed: {e}")
            return None

    def stream_generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[Union[str, List[str]]]:
        """Generate a response with streaming support."""
        try:
            return self._retry_with_backoff(self._stream_generate_impl, prompt, conversation_history)
        except Exception as e:
            logger.error(f"Ollama stream generation failed: {e}")
            return None

    def validate_model(self) -> bool:
        """Validate that the model is available on the Ollama server."""
        try:
            response = requests.get(
                f"{self.host}/api/tags",
                headers=self.headers,
                timeout=self.timeout
            )
            if response.status_code == 200:
                models = response.json().get("models", [])
                model_names = [model.get("name") for model in models]
                return self.model_name in model_names
            elif response.status_code == 404:
                # Fall back to checking that the server itself is reachable
                response2 = requests.get(
                    f"{self.host}",
                    headers=self.headers,
                    timeout=self.timeout
                )
                return response2.status_code == 200
            return False
        except Exception as e:
            logger.error(f"Model validation failed: {e}")
            return False

    def _generate_impl(self, prompt: str, conversation_history: List[Dict]) -> str:
        """Implementation of synchronous generation."""
        url = f"{self.host}/api/chat"
        messages = conversation_history.copy()
        # Add the current prompt if it is not already the last message
        if not messages or messages[-1].get("content") != prompt:
            messages.append({"role": "user", "content": prompt})
        payload = {
            "model": self.model_name,
            "messages": messages,
            "stream": False
        }
        response = requests.post(
            url,
            json=payload,
            headers=self.headers,
            timeout=self.timeout
        )
        response.raise_for_status()
        result = response.json()
        return result["message"]["content"]

    def _stream_generate_impl(self, prompt: str, conversation_history: List[Dict]) -> List[str]:
        """Implementation of streaming generation."""
        url = f"{self.host}/api/chat"
        messages = conversation_history.copy()
        # Add the current prompt if it is not already the last message
        if not messages or messages[-1].get("content") != prompt:
            messages.append({"role": "user", "content": prompt})
        payload = {
            "model": self.model_name,
            "messages": messages,
            "stream": True
        }
        response = requests.post(
            url,
            json=payload,
            headers=self.headers,
            timeout=self.timeout,
            stream=True
        )
        response.raise_for_status()
        chunks = []
        for line in response.iter_lines():
            if line:
                chunk = line.decode('utf-8')
                try:
                    # Each streamed line is a standalone JSON object; parse it
                    # safely instead of using eval()
                    data = json.loads(chunk)
                    content = data.get("message", {}).get("content", "")
                    if content:
                        chunks.append(content)
                except json.JSONDecodeError:
                    continue
        return chunks
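

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal example of how this provider might be exercised directly. It
# assumes `config.ollama_host` points at a reachable Ollama server and that
# the model named below has already been pulled; both the model name and the
# prompt are placeholders, not part of this module's contract.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    provider = OllamaProvider(model_name="llama3", timeout=60, max_retries=3)
    if provider.validate_model():
        reply = provider.generate("Give me one tip for better sleep.", conversation_history=[])
        print(reply)
    else:
        logger.warning("Model 'llama3' is not available on the configured Ollama host")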