""" Ollama LLM adapter implementation. This adapter provides integration with local Ollama models, handling the specific API format and response structure of Ollama. Architecture Notes: - Converts between unified interface and Ollama API format - Handles Ollama-specific parameters and responses - Supports both regular and streaming generation - Maps Ollama errors to standard LLMError types """ import json import requests import logging from typing import Dict, Any, Optional, Iterator from urllib.parse import urljoin from .base_adapter import BaseLLMAdapter, LLMError, ModelNotFoundError from ..base import GenerationParams logger = logging.getLogger(__name__) class OllamaAdapter(BaseLLMAdapter): """ Adapter for Ollama local LLM integration. Features: - Support for all Ollama models - Streaming response support - Automatic model pulling if not available - Context window management - Format conversion for Ollama API Configuration: - base_url: Ollama server URL (default: http://localhost:11434) - timeout: Request timeout in seconds (default: 120) - auto_pull: Automatically pull models if not found (default: False) """ def __init__(self, model_name: str = "llama3.2", base_url: str = "http://localhost:11434", timeout: int = 120, auto_pull: bool = False, config: Optional[Dict[str, Any]] = None): """ Initialize Ollama adapter. Args: model_name: Ollama model name (e.g., "llama3.2", "mistral") base_url: Ollama server URL timeout: Request timeout in seconds auto_pull: Automatically pull models if not found config: Additional configuration """ # Merge config adapter_config = { 'base_url': base_url, 'timeout': timeout, 'auto_pull': auto_pull, **(config or {}) } super().__init__(model_name, adapter_config) self.base_url = adapter_config['base_url'].rstrip('/') self.timeout = adapter_config['timeout'] self.auto_pull = adapter_config['auto_pull'] # API endpoints self.generate_url = urljoin(self.base_url + '/', 'api/generate') self.chat_url = urljoin(self.base_url + '/', 'api/chat') self.tags_url = urljoin(self.base_url + '/', 'api/tags') self.pull_url = urljoin(self.base_url + '/', 'api/pull') logger.info(f"Initialized Ollama adapter for model '{model_name}' at {base_url}") def _make_request(self, prompt: str, params: GenerationParams) -> Dict[str, Any]: """ Make a request to Ollama API. Args: prompt: The prompt to send params: Generation parameters Returns: Ollama API response Raises: Various request exceptions """ # Convert to Ollama format ollama_params = self._convert_params(params) # Prepare request payload payload = { "model": self.model_name, "prompt": prompt, "stream": False, "options": ollama_params } try: # Make request response = requests.post( self.generate_url, json=payload, timeout=self.timeout ) # Check for errors if response.status_code == 404: # Model not found if self.auto_pull: logger.info(f"Model '{self.model_name}' not found, attempting to pull...") self._pull_model() # Retry request response = requests.post( self.generate_url, json=payload, timeout=self.timeout ) else: raise ModelNotFoundError(f"Model '{self.model_name}' not found. Set auto_pull=True to download it.") response.raise_for_status() return response.json() except requests.exceptions.Timeout: raise LLMError(f"Ollama request timed out after {self.timeout}s") except requests.exceptions.ConnectionError: raise LLMError(f"Failed to connect to Ollama at {self.base_url}. 
        except (ModelNotFoundError, LLMError):
            # Re-raise adapter errors unchanged so details such as the
            # auto_pull hint above are not lost in _handle_provider_error
            raise
        except requests.exceptions.HTTPError as e:
            self._handle_http_error(e)
        except Exception as e:
            self._handle_provider_error(e)

    def _parse_response(self, response: Dict[str, Any]) -> str:
        """
        Parse an Ollama response and extract the generated text.

        Args:
            response: Ollama API response

        Returns:
            Generated text
        """
        # Ollama returns the generated text in the 'response' field
        text = response.get('response', '')

        # Log token usage if available
        if 'eval_count' in response:
            logger.debug(f"Ollama used {response['eval_count']} tokens for generation")

        return text

    def generate_streaming(self, prompt: str, params: GenerationParams) -> Iterator[str]:
        """
        Generate a streaming response from Ollama.

        Args:
            prompt: The prompt to send
            params: Generation parameters

        Yields:
            Generated text chunks
        """
        # Convert parameters
        ollama_params = self._convert_params(params)

        # Prepare streaming request
        payload = {
            "model": self.model_name,
            "prompt": prompt,
            "stream": True,
            "options": ollama_params
        }

        try:
            # Make streaming request
            response = requests.post(
                self.generate_url,
                json=payload,
                stream=True,
                timeout=self.timeout
            )
            response.raise_for_status()

            # Ollama streams one JSON object per line
            for line in response.iter_lines():
                if line:
                    try:
                        chunk = json.loads(line)
                        if 'response' in chunk:
                            yield chunk['response']

                        # Check if generation is done
                        if chunk.get('done', False):
                            break
                    except json.JSONDecodeError:
                        logger.warning(f"Failed to parse streaming chunk: {line}")
                        continue

        except Exception as e:
            self._handle_provider_error(e)

    def _get_provider_name(self) -> str:
        """Return the provider name."""
        return "Ollama"

    def _validate_model(self) -> bool:
        """Check whether the model exists on the Ollama server."""
        try:
            response = requests.get(self.tags_url, timeout=10)
            response.raise_for_status()

            models = response.json().get('models', [])
            model_names = [model['name'] for model in models]

            # Accept an exact match or a partial match in either direction
            # (e.g., "llama3.2" matches "llama3.2:latest")
            for name in model_names:
                if self.model_name in name or name in self.model_name:
                    return True
            return False

        except Exception as e:
            logger.warning(f"Failed to validate model: {str(e)}")
            # Assume the model exists if we cannot check
            return True

    def _supports_streaming(self) -> bool:
        """Ollama supports streaming."""
        return True

    def _get_max_tokens(self) -> Optional[int]:
        """Get the context window limit for the current model."""
        # Model-specific limits; more specific keys (e.g., 'llama3.2')
        # are listed before their prefixes (e.g., 'llama3') so the
        # substring check below matches them first
        model_limits = {
            'llama3.2': 4096,
            'llama3.1': 128000,
            'llama3': 8192,
            'llama2': 4096,
            'mistral': 8192,
            'mixtral': 32768,
            'gemma': 8192,
            'gemma2': 8192,
            'phi3': 4096,
            'qwen2.5': 32768,
        }

        # Check for an exact or substring match of the model name
        for model, limit in model_limits.items():
            if model in self.model_name.lower():
                return limit

        # Default for unknown models
        return 4096
    def _convert_params(self, params: GenerationParams) -> Dict[str, Any]:
        """
        Convert unified parameters to Ollama format.

        Args:
            params: Unified generation parameters

        Returns:
            Ollama-specific parameters
        """
        ollama_params = {}

        # Map common parameters
        if params.temperature is not None:
            ollama_params['temperature'] = params.temperature
        if params.max_tokens is not None:
            ollama_params['num_predict'] = params.max_tokens
        if params.top_p is not None:
            ollama_params['top_p'] = params.top_p
        if params.frequency_penalty is not None:
            # Ollama has no direct frequency_penalty; approximate it with
            # repeat_penalty, which is centered at 1.0 rather than 0.0
            ollama_params['repeat_penalty'] = 1.0 + params.frequency_penalty
        if params.stop_sequences:
            ollama_params['stop'] = params.stop_sequences

        # Add Ollama-specific defaults
        ollama_params.setdefault('seed', -1)       # Random seed
        ollama_params.setdefault('num_ctx', 2048)  # Context window

        return ollama_params

    def _pull_model(self) -> None:
        """Pull the model from the Ollama registry."""
        logger.info(f"Pulling model '{self.model_name}'...")

        payload = {"name": self.model_name, "stream": False}

        try:
            response = requests.post(
                self.pull_url,
                json=payload,
                timeout=600  # Allow up to 10 minutes for the model download
            )
            response.raise_for_status()
            logger.info(f"Successfully pulled model '{self.model_name}'")

        except Exception as e:
            raise LLMError(f"Failed to pull model '{self.model_name}': {str(e)}")

    def _handle_http_error(self, error: requests.exceptions.HTTPError) -> None:
        """Handle HTTP errors from Ollama."""
        if error.response.status_code == 404:
            raise ModelNotFoundError(f"Model '{self.model_name}' not found")
        elif error.response.status_code == 400:
            raise LLMError(f"Bad request to Ollama: {error.response.text}")
        elif error.response.status_code == 500:
            raise LLMError(f"Ollama server error: {error.response.text}")
        else:
            raise LLMError(f"HTTP error {error.response.status_code}: {error.response.text}")

    def _handle_provider_error(self, error: Exception) -> None:
        """Map Ollama-specific errors to standard errors."""
        error_msg = str(error).lower()

        if 'connection' in error_msg:
            raise LLMError(f"Cannot connect to Ollama at {self.base_url}. Is Ollama running?")
        elif 'timeout' in error_msg:
            raise LLMError("Request to Ollama timed out")
        elif 'model' in error_msg and 'not found' in error_msg:
            raise ModelNotFoundError(f"Model '{self.model_name}' not found")
        else:
            super()._handle_provider_error(error)
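

# ---------------------------------------------------------------------------
# Illustrative usage sketch, not part of the adapter. Run from the package
# root with `python -m <package>.ollama_adapter`, since the relative imports
# above prevent executing this file directly. One assumption not taken from
# this module: that GenerationParams can be constructed with the keyword
# arguments shown below (this module only *reads* those fields in
# _convert_params, so the constructor signature may differ in ..base).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Local Ollama server on the default port; no automatic model download
    adapter = OllamaAdapter(model_name="llama3.2", auto_pull=False)

    # Hypothetical construction; field names mirror those read by
    # _convert_params(): temperature, max_tokens, top_p, stop_sequences, ...
    params = GenerationParams(temperature=0.7, max_tokens=256)

    try:
        # Blocking generation via the helpers defined in this module
        raw = adapter._make_request("Why is the sky blue?", params)
        print(adapter._parse_response(raw))

        # Streaming generation: chunks are printed as they arrive
        for chunk in adapter.generate_streaming("Count to five.", params):
            print(chunk, end="", flush=True)
        print()
    except LLMError as exc:
        print(f"Generation failed: {exc}")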