"""
Ollama LLM adapter implementation.

This adapter provides integration with local Ollama models, handling
the specific API format and response structure of Ollama.

Architecture Notes:
- Converts between unified interface and Ollama API format
- Handles Ollama-specific parameters and responses
- Supports both regular and streaming generation
- Maps Ollama errors to standard LLMError types
"""

import json
import logging

import requests
from typing import Dict, Any, Optional, Iterator
from urllib.parse import urljoin

from .base_adapter import BaseLLMAdapter, LLMError, ModelNotFoundError
from ..base import GenerationParams

logger = logging.getLogger(__name__)


class OllamaAdapter(BaseLLMAdapter):
    """
    Adapter for Ollama local LLM integration.
    
    Features:
    - Support for all Ollama models
    - Streaming response support
    - Automatic model pulling if not available
    - Context window management
    - Format conversion for Ollama API
    
    Configuration:
    - base_url: Ollama server URL (default: http://localhost:11434)
    - timeout: Request timeout in seconds (default: 120)
    - auto_pull: Automatically pull models if not found (default: False)
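
    Example:
        Illustrative usage sketch. It assumes a local Ollama server is
        reachable at the default URL and that GenerationParams accepts
        these keyword arguments (its constructor is not defined in this
        module).

        adapter = OllamaAdapter(model_name="llama3.2")
        params = GenerationParams(temperature=0.7, max_tokens=128)
        for chunk in adapter.generate_streaming("Why is the sky blue?", params):
            print(chunk, end="", flush=True)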
    """
    
    def __init__(self,
                 model_name: str = "llama3.2",
                 base_url: str = "http://localhost:11434",
                 timeout: int = 120,
                 auto_pull: bool = False,
                 config: Optional[Dict[str, Any]] = None):
        """
        Initialize Ollama adapter.
        
        Args:
            model_name: Ollama model name (e.g., "llama3.2", "mistral")
            base_url: Ollama server URL
            timeout: Request timeout in seconds
            auto_pull: Automatically pull models if not found
            config: Additional configuration
        """
        # Merge config
        adapter_config = {
            'base_url': base_url,
            'timeout': timeout,
            'auto_pull': auto_pull,
            **(config or {})
        }
        
        super().__init__(model_name, adapter_config)
        
        self.base_url = adapter_config['base_url'].rstrip('/')
        self.timeout = adapter_config['timeout']
        self.auto_pull = adapter_config['auto_pull']
        
        # API endpoints
        self.generate_url = urljoin(self.base_url + '/', 'api/generate')
        self.chat_url = urljoin(self.base_url + '/', 'api/chat')
        self.tags_url = urljoin(self.base_url + '/', 'api/tags')
        self.pull_url = urljoin(self.base_url + '/', 'api/pull')
        
        logger.info(f"Initialized Ollama adapter for model '{model_name}' at {base_url}")
    
    def _make_request(self, prompt: str, params: GenerationParams) -> Dict[str, Any]:
        """
        Make a request to Ollama API.
        
        Args:
            prompt: The prompt to send
            params: Generation parameters
            
        Returns:
            Ollama API response
            
        Raises:
            Various request exceptions
        """
        # Convert to Ollama format
        ollama_params = self._convert_params(params)
        
        # Prepare request payload
        payload = {
            "model": self.model_name,
            "prompt": prompt,
            "stream": False,
            "options": ollama_params
        }
        
        try:
            # Make request
            response = requests.post(
                self.generate_url,
                json=payload,
                timeout=self.timeout
            )
            
            # Check for errors
            if response.status_code == 404:
                # Model not found
                if self.auto_pull:
                    logger.info(f"Model '{self.model_name}' not found, attempting to pull...")
                    self._pull_model()
                    # Retry request
                    response = requests.post(
                        self.generate_url,
                        json=payload,
                        timeout=self.timeout
                    )
                else:
                    raise ModelNotFoundError(f"Model '{self.model_name}' not found. Set auto_pull=True to download it.")
            
            response.raise_for_status()
            
            return response.json()
            
        except requests.exceptions.Timeout:
            raise LLMError(f"Ollama request timed out after {self.timeout}s")
        except requests.exceptions.ConnectionError:
            raise LLMError(f"Failed to connect to Ollama at {self.base_url}. Is Ollama running?")
        except requests.exceptions.HTTPError as e:
            self._handle_http_error(e)
        except (ModelNotFoundError, LLMError):
            # Re-raise errors this adapter raised itself (e.g. the
            # ModelNotFoundError from the auto_pull branch above) instead of
            # re-mapping them in the generic handler below.
            raise
        except Exception as e:
            self._handle_provider_error(e)
    
    def _parse_response(self, response: Dict[str, Any]) -> str:
        """
        Parse Ollama response to extract generated text.
        
        Args:
            response: Ollama API response
            
        Returns:
            Generated text
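
        Example:
            A (truncated) non-streaming response from /api/generate has
            roughly this shape once decoded to a dict (illustrative; the
            exact field set may vary across Ollama versions):

            {
                "model": "llama3.2",
                "response": "The sky appears blue because ...",
                "done": True,
                "eval_count": 42,
            }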
        """
        # Ollama returns response in 'response' field
        text = response.get('response', '')
        
        # Log token usage if available
        if 'eval_count' in response:
            logger.debug(f"Ollama used {response['eval_count']} tokens for generation")
        
        return text
    
    def generate_streaming(self, prompt: str, params: GenerationParams) -> Iterator[str]:
        """
        Generate a streaming response from Ollama.
        
        Args:
            prompt: The prompt to send
            params: Generation parameters
            
        Yields:
            Generated text chunks
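
        Example:
            Each streamed line from Ollama is a standalone JSON object,
            roughly of this shape (illustrative; fields may vary by
            version):

            {"model": "llama3.2", "response": " blue", "done": false}

            The loop below yields each 'response' fragment until a chunk
            with "done": true arrives.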
        """
        # Convert parameters
        ollama_params = self._convert_params(params)
        
        # Prepare streaming request
        payload = {
            "model": self.model_name,
            "prompt": prompt,
            "stream": True,
            "options": ollama_params
        }
        
        try:
            # Make streaming request
            response = requests.post(
                self.generate_url,
                json=payload,
                stream=True,
                timeout=self.timeout
            )
            
            response.raise_for_status()
            
            # Process streaming response
            for line in response.iter_lines():
                if line:
                    try:
                        chunk = json.loads(line)
                        if 'response' in chunk:
                            yield chunk['response']
                        
                        # Check if generation is done
                        if chunk.get('done', False):
                            break
                            
                    except json.JSONDecodeError:
                        logger.warning(f"Failed to parse streaming chunk: {line}")
                        continue
                        
        except requests.exceptions.HTTPError as e:
            self._handle_http_error(e)
        except Exception as e:
            self._handle_provider_error(e)
    
    def _get_provider_name(self) -> str:
        """Return the provider name."""
        return "Ollama"
    
    def _validate_model(self) -> bool:
        """Check if the model exists in Ollama."""
        try:
            response = requests.get(self.tags_url, timeout=10)
            response.raise_for_status()
            
            models = response.json().get('models', [])
            model_names = [model['name'] for model in models]
            
            # Check exact match or partial match (e.g., "llama3.2" matches "llama3.2:latest")
            for name in model_names:
                if self.model_name in name or name in self.model_name:
                    return True
                    
            return False
            
        except Exception as e:
            logger.warning(f"Failed to validate model: {str(e)}")
            # Assume model exists if we can't check
            return True
    
    def _supports_streaming(self) -> bool:
        """Ollama supports streaming."""
        return True
    
    def _get_max_tokens(self) -> Optional[int]:
        """Get max tokens for current model."""
        # Model-specific context limits. More specific names must come before
        # their prefixes (e.g. 'llama3.2' before 'llama3'), because the lookup
        # below is an ordered substring match.
        model_limits = {
            'llama3.2': 4096,
            'llama3.1': 128000,
            'llama3': 8192,
            'llama2': 4096,
            'mistral': 8192,
            'mixtral': 32768,
            'gemma': 8192,
            'gemma2': 8192,
            'phi3': 4096,
            'qwen2.5': 32768,
        }
        
        # Check for exact match or prefix
        for model, limit in model_limits.items():
            if model in self.model_name.lower():
                return limit
                
        # Default for unknown models
        return 4096
    
    def _convert_params(self, params: GenerationParams) -> Dict[str, Any]:
        """
        Convert unified parameters to Ollama format.
        
        Args:
            params: Unified generation parameters
            
        Returns:
            Ollama-specific parameters
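
        Example:
            Illustrative mapping, assuming GenerationParams carries the
            fields referenced below:

            GenerationParams(temperature=0.2, max_tokens=100, top_p=0.9)
            maps to
            {'temperature': 0.2, 'num_predict': 100, 'top_p': 0.9,
             'seed': -1, 'num_ctx': 2048}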
        """
        ollama_params = {}
        
        # Map common parameters
        if params.temperature is not None:
            ollama_params['temperature'] = params.temperature
        if params.max_tokens is not None:
            ollama_params['num_predict'] = params.max_tokens
        if params.top_p is not None:
            ollama_params['top_p'] = params.top_p
        if params.frequency_penalty is not None:
            # Approximate mapping: Ollama's repeat_penalty is multiplicative
            # (1.0 = no penalty), not an additive OpenAI-style frequency penalty.
            ollama_params['repeat_penalty'] = 1.0 + params.frequency_penalty
        if params.stop_sequences:
            ollama_params['stop'] = params.stop_sequences
            
        # Add Ollama-specific defaults
        ollama_params.setdefault('seed', -1)  # Random seed
        ollama_params.setdefault('num_ctx', 2048)  # Context window
        
        return ollama_params
    
    def _pull_model(self) -> None:
        """Pull a model from Ollama registry."""
        logger.info(f"Pulling model '{self.model_name}'...")
        
        payload = {"name": self.model_name, "stream": False}
        
        try:
            response = requests.post(
                self.pull_url,
                json=payload,
                timeout=600  # 10 minutes for model download
            )
            response.raise_for_status()
            
            logger.info(f"Successfully pulled model '{self.model_name}'")
            
        except Exception as e:
            raise LLMError(f"Failed to pull model '{self.model_name}': {str(e)}")
    
    def _handle_http_error(self, error: requests.exceptions.HTTPError) -> None:
        """Handle HTTP errors from Ollama."""
        if error.response.status_code == 404:
            raise ModelNotFoundError(f"Model '{self.model_name}' not found")
        elif error.response.status_code == 400:
            raise LLMError(f"Bad request to Ollama: {error.response.text}")
        elif error.response.status_code == 500:
            raise LLMError(f"Ollama server error: {error.response.text}")
        else:
            raise LLMError(f"HTTP error {error.response.status_code}: {error.response.text}")
    
    def _handle_provider_error(self, error: Exception) -> None:
        """Map Ollama-specific errors to standard errors."""
        error_msg = str(error).lower()
        
        if 'connection' in error_msg:
            raise LLMError(f"Cannot connect to Ollama at {self.base_url}. Is Ollama running?")
        elif 'timeout' in error_msg:
            raise LLMError(f"Request to Ollama timed out")
        elif 'model' in error_msg and 'not found' in error_msg:
            raise ModelNotFoundError(f"Model '{self.model_name}' not found")
        else:
            super()._handle_provider_error(error)