import requests
import logging
import re
import json
from typing import List, Dict, Optional, Union

from core.providers.base import LLMProvider
from utils.config import config

logger = logging.getLogger(__name__)


class OllamaProvider(LLMProvider):
    """Ollama LLM provider implementation"""

    def __init__(self, model_name: str, timeout: int = 60, max_retries: int = 3):
        super().__init__(model_name, timeout, max_retries)
        self.host = self._sanitize_host(config.ollama_host or "http://localhost:11434")

        # Extra headers so requests pass through an ngrok tunnel without the browser-warning interstitial
        self.headers = {
            "ngrok-skip-browser-warning": "true",
            "User-Agent": "AI-Life-Coach-Ollama"
        }

    def _sanitize_host(self, host: str) -> str:
        """Sanitize host URL by removing whitespace and control characters"""
        if not host:
            return "http://localhost:11434"

        host = host.strip()
        host = re.sub(r'[\r\n\t\0]+', '', host)

        if not host.startswith(('http://', 'https://')):
            host = 'http://' + host
        return host

    def generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
        """Generate a response synchronously"""
        try:
            return self._retry_with_backoff(self._generate_impl, prompt, conversation_history)
        except Exception as e:
            logger.error(f"Ollama generation failed: {e}")
            return None

    def stream_generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[Union[str, List[str]]]:
        """Generate a response with streaming support"""
        try:
            return self._retry_with_backoff(self._stream_generate_impl, prompt, conversation_history)
        except Exception as e:
            logger.error(f"Ollama stream generation failed: {e}")
            return None

    def validate_model(self) -> bool:
        """Validate if the model is available"""
        try:
            response = requests.get(
                f"{self.host}/api/tags",
                headers=self.headers,
                timeout=self.timeout
            )
            if response.status_code == 200:
                models = response.json().get("models", [])
                model_names = [model.get("name") for model in models]
                return self.model_name in model_names
            elif response.status_code == 404:
                # /api/tags is not exposed on this host; fall back to checking that the host itself responds
                response2 = requests.get(
                    f"{self.host}",
                    headers=self.headers,
                    timeout=self.timeout
                )
                return response2.status_code == 200
            return False
        except Exception as e:
            logger.error(f"Model validation failed: {e}")
            return False

    def _generate_impl(self, prompt: str, conversation_history: List[Dict]) -> str:
        """Implementation of synchronous generation"""
        url = f"{self.host}/api/chat"
        messages = conversation_history.copy()

        # Avoid duplicating the prompt if it is already the last message in the history
        if not messages or messages[-1].get("content") != prompt:
            messages.append({"role": "user", "content": prompt})

        payload = {
            "model": self.model_name,
            "messages": messages,
            "stream": False
        }
        response = requests.post(
            url,
            json=payload,
            headers=self.headers,
            timeout=self.timeout
        )
        response.raise_for_status()
        result = response.json()
        return result["message"]["content"]

    def _stream_generate_impl(self, prompt: str, conversation_history: List[Dict]) -> List[str]:
        """Implementation of streaming generation"""
        url = f"{self.host}/api/chat"
        messages = conversation_history.copy()

        # Avoid duplicating the prompt if it is already the last message in the history
        if not messages or messages[-1].get("content") != prompt:
            messages.append({"role": "user", "content": prompt})

        payload = {
            "model": self.model_name,
            "messages": messages,
            "stream": True
        }
        response = requests.post(
            url,
            json=payload,
            headers=self.headers,
            timeout=self.timeout,
            stream=True
        )
        response.raise_for_status()

        chunks = []
        # The streaming endpoint returns one JSON object per line; parse each line and collect content deltas
        for line in response.iter_lines():
            if line:
                chunk = line.decode('utf-8')
                try:
                    data = json.loads(chunk)
                except json.JSONDecodeError:
                    continue
                content = data.get("message", {}).get("content", "")
                if content:
                    chunks.append(content)
        return chunks
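

# Minimal usage sketch (an illustrative assumption, not part of the provider itself): shows how the
# class might be exercised directly, assuming the LLMProvider base class and config are set up as in
# the rest of this repo. The model name "llama3" is hypothetical; use one that is pulled locally.
if __name__ == "__main__":
    provider = OllamaProvider(model_name="llama3", timeout=30)
    if provider.validate_model():
        reply = provider.generate("Give me one tip for better sleep.", conversation_history=[])
        print(reply)
    else:
        logger.warning("Model %s is not available on %s", provider.model_name, provider.host)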