#!/usr/bin/env python3
"""
Ollama-based answer generator for local inference.

Provides the same interface as HuggingFaceAnswerGenerator
but uses a local Ollama server for model inference.
"""

import time
import requests
import json
import re
import sys
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass

# Import shared components
from .hf_answer_generator import Citation, GeneratedAnswer
from .prompt_templates import TechnicalPromptTemplates

# Import standard interfaces (add this for the adapter)
try:
    from pathlib import Path
    import sys

    project_root = Path(__file__).parent.parent.parent.parent.parent
    sys.path.append(str(project_root))
    from src.core.interfaces import Document, Answer, AnswerGenerator
except ImportError:
    # Fallback for standalone usage
    Document = None
    Answer = None
    AnswerGenerator = object


class OllamaAnswerGenerator(AnswerGenerator if AnswerGenerator != object else object):
    """
    Generates answers using a local Ollama server.

    Perfect for:
    - Local development
    - Privacy-sensitive applications
    - No API rate limits
    - Consistent performance
    - Offline operation
    """

    def __init__(
        self,
        model_name: str = "llama3.2:3b",
        base_url: str = "http://localhost:11434",
        temperature: float = 0.3,
        max_tokens: int = 512,
    ):
        """
        Initialize Ollama answer generator.

        Args:
            model_name: Ollama model to use (e.g., "llama3.2:3b", "mistral")
            base_url: Ollama server URL
            temperature: Generation temperature
            max_tokens: Maximum tokens to generate
        """
        self.model_name = model_name
        self.base_url = base_url.rstrip("/")
        self.temperature = temperature
        self.max_tokens = max_tokens

        # Test connection
        self._test_connection()

    def _test_connection(self):
        """Test if Ollama server is accessible."""
        # Reduce retries for faster initialization - container should be ready quickly
        max_retries = 12  # Wait up to 60 seconds for Ollama to start
        retry_delay = 5

        print(
            f"🔧 Testing connection to {self.base_url}/api/tags...",
            file=sys.stderr,
            flush=True,
        )

        for attempt in range(max_retries):
            try:
                response = requests.get(f"{self.base_url}/api/tags", timeout=8)
                if response.status_code == 200:
                    print(
                        f"✅ Connected to Ollama at {self.base_url}",
                        file=sys.stderr,
                        flush=True,
                    )

                    # Check if our model is available
                    models = response.json().get("models", [])
                    model_names = [m["name"] for m in models]

                    if self.model_name in model_names:
                        print(
                            f"✅ Model {self.model_name} is available",
                            file=sys.stderr,
                            flush=True,
                        )
                        return  # Success!
                    else:
                        print(
                            f"⚠️ Model {self.model_name} not found. Available: {model_names}",
                            file=sys.stderr,
                            flush=True,
                        )
                        if models:
                            # If any models are available, use the first one
                            fallback_model = model_names[0]
                            print(
                                f"🔄 Using fallback model: {fallback_model}",
                                file=sys.stderr,
                                flush=True,
                            )
                            self.model_name = fallback_model
                            return
                        else:
                            print(
                                f"📥 No models found, will try to pull {self.model_name}",
                                file=sys.stderr,
                                flush=True,
                            )
                            # Try to pull the model
                            self._pull_model(self.model_name)
                            return
                else:
                    print(f"⚠️ Ollama server returned status {response.status_code}")
                    if attempt < max_retries - 1:
                        print(
                            f"🔄 Retry {attempt + 1}/{max_retries} in {retry_delay} seconds..."
                        )
                        time.sleep(retry_delay)
                        continue

            except requests.exceptions.ConnectionError:
                if attempt < max_retries - 1:
                    print(
                        f"⏳ Ollama not ready yet, retry {attempt + 1}/{max_retries} in {retry_delay} seconds..."
                    )
                    time.sleep(retry_delay)
                    continue
                else:
                    raise Exception(
                        f"Cannot connect to Ollama server at {self.base_url} after 60 seconds. Check if it's running."
                    )
            except requests.exceptions.Timeout:
                if attempt < max_retries - 1:
                    print(f"⏳ Ollama timeout, retry {attempt + 1}/{max_retries}...")
                    time.sleep(retry_delay)
                    continue
                else:
                    raise Exception("Ollama server timeout after multiple retries.")
            except Exception as e:
                if attempt < max_retries - 1:
                    print(f"⚠️ Ollama error: {e}, retry {attempt + 1}/{max_retries}...")
                    time.sleep(retry_delay)
                    continue
                else:
                    raise Exception(
                        f"Ollama connection failed after {max_retries} attempts: {e}"
                    )

        raise Exception("Failed to connect to Ollama after all retries")

    def _pull_model(self, model_name: str):
        """Pull a model if it's not available."""
        try:
            print(f"📥 Pulling model {model_name}...")
            pull_response = requests.post(
                f"{self.base_url}/api/pull",
                json={"name": model_name},
                timeout=300,  # 5 minutes for model download
            )
            if pull_response.status_code == 200:
                print(f"✅ Successfully pulled {model_name}")
            else:
                print(f"⚠️ Failed to pull {model_name}: {pull_response.status_code}")
                # Try smaller models as fallback
                fallback_models = ["llama3.2:1b", "llama2:latest", "mistral:latest"]
                for fallback in fallback_models:
                    try:
                        print(f"🔄 Trying fallback model: {fallback}")
                        fallback_response = requests.post(
                            f"{self.base_url}/api/pull",
                            json={"name": fallback},
                            timeout=300,
                        )
                        if fallback_response.status_code == 200:
                            print(f"✅ Successfully pulled fallback {fallback}")
                            self.model_name = fallback
                            return
                    except Exception:
                        continue

                raise Exception(f"Failed to pull {model_name} or any fallback models")

        except Exception as e:
            print(f"❌ Model pull failed: {e}")
            raise

    def _format_context(self, chunks: List[Dict[str, Any]]) -> str:
        """Format retrieved chunks into context."""
        context_parts = []

        for i, chunk in enumerate(chunks):
            chunk_text = chunk.get("content", chunk.get("text", ""))
            page_num = chunk.get("metadata", {}).get("page_number", "unknown")
            source = chunk.get("metadata", {}).get("source", "unknown")

            context_parts.append(
                f"[chunk_{i+1}] (Page {page_num} from {source}):\n{chunk_text}\n"
            )

        return "\n---\n".join(context_parts)
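    # Illustrative sketch of what _format_context() produces for two chunks
    # (the document names and snippets are hypothetical, shown only to document
    # the expected layout):
    #
    #   [chunk_1] (Page 12 from riscv-spec.pdf):
    #   RISC-V defines 32 integer registers...
    #
    #   ---
    #   [chunk_2] (Page 13 from riscv-spec.pdf):
    #   The base ISA supports 32-bit and 64-bit variants...
    #
    # The [chunk_N] labels here are the same labels the citation rules in
    # _create_prompt() tell the model to use.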
    def _create_prompt(self, query: str, context: str, chunks: List[Dict[str, Any]]) -> str:
        """Create optimized prompt with dynamic length constraints and citation instructions."""
        # Get the appropriate template based on query type
        prompt_data = TechnicalPromptTemplates.format_prompt_with_template(
            query=query, context=context
        )

        # Create dynamic citation instructions based on available chunks
        num_chunks = len(chunks)
        available_chunks = ", ".join(
            [f"[chunk_{i+1}]" for i in range(min(num_chunks, 5))]
        )  # Show max 5 examples

        # Create appropriate example based on actual chunks
        if num_chunks == 1:
            citation_example = "RISC-V is an open-source ISA [chunk_1]."
        elif num_chunks == 2:
            citation_example = "RISC-V is an open-source ISA [chunk_1] that supports multiple data widths [chunk_2]."
        else:
            citation_example = "RISC-V is an open-source ISA [chunk_1] that supports multiple data widths [chunk_2] and provides extensions [chunk_3]."

        # Determine optimal answer length based on query complexity
        target_length = self._determine_target_length(query, chunks)
        length_instruction = self._create_length_instruction(target_length)

        # Format for different model types
        if "llama" in self.model_name.lower():
            # Llama-3.2 format with technical prompt templates
            return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

{prompt_data['system']}

MANDATORY CITATION RULES:
- ONLY use available chunks: {available_chunks}
- You have {num_chunks} chunks available
- DO NOT cite chunk numbers higher than {num_chunks}
- Every technical claim MUST have a citation from available chunks
- Example: "{citation_example}"

{length_instruction}
<|eot_id|><|start_header_id|>user<|end_header_id|>

{prompt_data['user']}

CRITICAL: You MUST cite sources ONLY from available chunks: {available_chunks}. DO NOT use chunk numbers > {num_chunks}.

{length_instruction}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""

        elif "mistral" in self.model_name.lower():
            # Mistral format with technical templates
            return f"""[INST] {prompt_data['system']}

Context:
{context}

Question: {query}

MANDATORY: ONLY use available chunks: {available_chunks}. DO NOT cite chunk numbers > {num_chunks}.

{length_instruction} [/INST]"""

        else:
            # Generic format with technical templates
            return f"""{prompt_data['system']}

Context:
{context}

Question: {query}

MANDATORY CITATIONS: ONLY use available chunks: {available_chunks}. DO NOT cite chunk numbers > {num_chunks}.

{length_instruction}

Answer:"""
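    # Worked example for the length heuristic below (illustrative only):
    #   query = "What is RISC-V?" with 2 retrieved chunks
    #     is_simple      -> base_length = 200
    #     context_factor = min(2 * 25, 75) = 50
    #     query_factor   = min(3 * 5, 50)  = 15
    #     target_length  = 200 + 50 + 15 = 265 (already inside the 150-400 clamp)
    # A query containing "explain" with 3+ chunks instead starts at 350 and is
    # clamped to the 400-character ceiling.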
    def _determine_target_length(self, query: str, chunks: List[Dict[str, Any]]) -> int:
        """
        Determine optimal answer length based on query complexity.

        Target range: 150-400 characters (down from 1000-2600)
        """
        # Analyze query complexity
        query_length = len(query)
        query_words = len(query.split())

        # Check for complexity indicators
        complex_words = [
            "explain", "describe", "analyze", "compare", "contrast",
            "evaluate", "discuss", "detail", "elaborate", "comprehensive",
        ]
        simple_words = [
            "what", "define", "list", "name", "identify", "is", "are",
        ]

        query_lower = query.lower()
        is_complex = any(word in query_lower for word in complex_words)
        is_simple = any(word in query_lower for word in simple_words)

        # Base length from query type
        if is_complex:
            base_length = 350  # Complex queries get longer answers
        elif is_simple:
            base_length = 200  # Simple queries get shorter answers
        else:
            base_length = 275  # Default middle ground

        # Adjust based on available context
        context_factor = min(len(chunks) * 25, 75)  # More context allows longer answers

        # Adjust based on query length
        query_factor = min(query_words * 5, 50)  # Longer queries allow longer answers

        target_length = base_length + context_factor + query_factor

        # Constrain to target range
        return max(150, min(target_length, 400))

    def _create_length_instruction(self, target_length: int) -> str:
        """Create length instruction based on target length."""
        if target_length <= 200:
            return f"ANSWER LENGTH: Keep your answer concise and focused, approximately {target_length} characters. Be direct and to the point."
        elif target_length <= 300:
            return f"ANSWER LENGTH: Provide a clear and informative answer, approximately {target_length} characters. Include key details but avoid unnecessary elaboration."
        else:
            return f"ANSWER LENGTH: Provide a comprehensive but concise answer, approximately {target_length} characters. Include important details while maintaining clarity."

    def _call_ollama(self, prompt: str) -> str:
        """Call Ollama API for generation."""
        payload = {
            "model": self.model_name,
            "prompt": prompt,
            "stream": False,
            "options": {
                "temperature": self.temperature,
                "num_predict": self.max_tokens,
                "top_p": 0.9,
                "repeat_penalty": 1.1,
            },
        }

        try:
            response = requests.post(
                f"{self.base_url}/api/generate", json=payload, timeout=300
            )
            response.raise_for_status()

            result = response.json()
            return result.get("response", "").strip()

        except requests.exceptions.RequestException as e:
            print(f"❌ Ollama API error: {e}")
            return f"Error communicating with Ollama: {str(e)}"
        except Exception as e:
            print(f"❌ Unexpected error: {e}")
            return f"Unexpected error: {str(e)}"

    def _extract_citations(
        self, answer: str, chunks: List[Dict[str, Any]]
    ) -> Tuple[str, List[Citation]]:
        """Extract citations from the generated answer."""
        citations = []
        citation_pattern = r"\[chunk_(\d+)\]"
        cited_chunks = set()

        # Find [chunk_X] citations
        matches = re.finditer(citation_pattern, answer)
        for match in matches:
            chunk_idx = int(match.group(1)) - 1  # Convert to 0-based index
            if 0 <= chunk_idx < len(chunks):
                cited_chunks.add(chunk_idx)

        # FALLBACK: If no explicit citations found but we have an answer and chunks,
        # create citations for the top chunks that were likely used
        if not cited_chunks and chunks and len(answer.strip()) > 50:
            # Use the top chunks that were provided as likely sources
            num_fallback_citations = min(3, len(chunks))  # Use top 3 chunks max
            cited_chunks = set(range(num_fallback_citations))
            print(
                f"🔧 Fallback: Creating {num_fallback_citations} citations for answer without explicit [chunk_X] references",
                file=sys.stderr,
                flush=True,
            )

        # Create Citation objects
        for idx in cited_chunks:
            chunk = chunks[idx]
            citation = Citation(
                chunk_id=chunk.get("id", f"chunk_{idx}"),
                page_number=chunk.get("metadata", {}).get("page_number", 0),
                source_file=chunk.get("metadata", {}).get("source", "unknown"),
                relevance_score=chunk.get("score", 0.0),
                text_snippet=chunk.get("content", chunk.get("text", ""))[:200] + "...",
            )
            citations.append(citation)

        # Keep the [chunk_X] references in the answer as proper citations;
        # replacing them with phrases like "the documentation" produced
        # repetitive text. Only normalize whitespace here.
        natural_answer = re.sub(r"\s+", " ", answer).strip()

        return natural_answer, citations
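    # Rough worked example for the confidence scoring below (numbers are
    # illustrative, not measured): 3 substantial chunks give context_quality
    # = 0.6; two citations drawn from one source add roughly 0.08; moderate
    # keyword overlap adds about 0.05; a ~250-character answer adds 0.05; no
    # off-topic phrases, so no penalty. Total ≈ 0.78, inside the 0.3-0.9 clamp.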
    def _calculate_confidence(
        self, answer: str, citations: List[Citation], chunks: List[Dict[str, Any]]
    ) -> float:
        """
        Calculate confidence score with expanded multi-factor assessment.

        Enhanced algorithm expands range from 0.75-0.95 to 0.3-0.9 with:
        - Context quality assessment
        - Citation quality evaluation
        - Semantic relevance scoring
        - Off-topic detection
        - Answer completeness analysis
        """
        if not answer or len(answer.strip()) < 10:
            return 0.1

        # 1. Context Quality Assessment (0.3-0.6 base range)
        context_quality = self._assess_context_quality(chunks)

        # 2. Citation Quality Evaluation (0.0-0.2 boost)
        citation_quality = self._assess_citation_quality(citations, chunks)

        # 3. Semantic Relevance Scoring (0.0-0.15 boost)
        semantic_relevance = self._assess_semantic_relevance(answer, chunks)

        # 4. Off-topic Detection (-0.4 penalty if off-topic)
        off_topic_penalty = self._detect_off_topic(answer, chunks)

        # 5. Answer Completeness Analysis (0.0-0.1 boost)
        completeness_bonus = self._assess_answer_completeness(answer, len(chunks))

        # Combine all factors
        confidence = (
            context_quality
            + citation_quality
            + semantic_relevance
            + completeness_bonus
            + off_topic_penalty
        )

        # Apply uncertainty penalty (case-insensitive match)
        uncertainty_phrases = [
            "insufficient information",
            "cannot determine",
            "not available in the provided documents",
            "I don't have enough context",
            "the context doesn't seem to provide",
        ]
        if any(phrase.lower() in answer.lower() for phrase in uncertainty_phrases):
            confidence *= 0.4  # Stronger penalty for uncertainty

        # Constrain to target range 0.3-0.9
        return max(0.3, min(confidence, 0.9))

    def _assess_context_quality(self, chunks: List[Dict[str, Any]]) -> float:
        """Assess quality of context chunks (0.3-0.6 range)."""
        if not chunks:
            return 0.3

        # Base score from chunk count
        if len(chunks) >= 3:
            base_score = 0.6
        elif len(chunks) >= 2:
            base_score = 0.5
        else:
            base_score = 0.4

        # Quality adjustments based on chunk content
        avg_chunk_length = sum(
            len(chunk.get("content", chunk.get("text", ""))) for chunk in chunks
        ) / len(chunks)

        if avg_chunk_length > 500:  # Rich content
            base_score += 0.05
        elif avg_chunk_length < 100:  # Sparse content
            base_score -= 0.05

        return max(0.3, min(base_score, 0.6))

    def _assess_citation_quality(self, citations: List[Citation], chunks: List[Dict[str, Any]]) -> float:
        """Assess citation quality (0.0-0.2 range)."""
        if not citations or not chunks:
            return 0.0

        # Citation coverage bonus
        citation_ratio = len(citations) / min(len(chunks), 3)
        coverage_bonus = 0.1 * citation_ratio

        # Citation diversity bonus (multiple sources)
        unique_sources = len(set(cit.source_file for cit in citations))
        diversity_bonus = 0.05 * min(unique_sources / max(len(chunks), 1), 1.0)

        return min(coverage_bonus + diversity_bonus, 0.2)

    def _assess_semantic_relevance(self, answer: str, chunks: List[Dict[str, Any]]) -> float:
        """Assess semantic relevance between answer and context (0.0-0.15 range)."""
        if not answer or not chunks:
            return 0.0

        # Simple keyword overlap assessment
        answer_words = set(answer.lower().split())
        context_words = set()

        for chunk in chunks:
            chunk_text = chunk.get("content", chunk.get("text", ""))
            context_words.update(chunk_text.lower().split())

        if not context_words:
            return 0.0

        # Calculate overlap ratio
        overlap = len(answer_words & context_words)
        total_unique = len(answer_words | context_words)

        if total_unique == 0:
            return 0.0

        overlap_ratio = overlap / total_unique
        return min(0.15 * overlap_ratio, 0.15)

    def _detect_off_topic(self, answer: str, chunks: List[Dict[str, Any]]) -> float:
        """Detect if answer is off-topic (-0.4 penalty if off-topic)."""
        if not answer or not chunks:
            return 0.0

        # Check for off-topic indicators
        off_topic_phrases = [
            "but I have to say that the context doesn't seem to provide",
            "these documents appear to be focused on",
            "but they don't seem to cover",
            "I'd recommend consulting a different type of documentation",
            "without more context or information",
        ]

        answer_lower = answer.lower()
        for phrase in off_topic_phrases:
            if phrase.lower() in answer_lower:  # case-insensitive match
                return -0.4  # Strong penalty for off-topic responses

        return 0.0

    def _assess_answer_completeness(self, answer: str, chunk_count: int) -> float:
        """Assess answer completeness (0.0-0.1 range)."""
        if not answer:
            return 0.0

        # Length-based completeness assessment
        answer_length = len(answer)

        if answer_length > 500:  # Comprehensive answer
            return 0.1
        elif answer_length > 200:  # Adequate answer
            return 0.05
        else:  # Brief answer
            return 0.0
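    # The public generate() below accepts Document objects from
    # src.core.interfaces. Only two attributes are actually consumed by the
    # internal adapter (_documents_to_ollama_chunks): doc.content (str) and
    # doc.metadata (dict, ideally carrying "source" and "start_page").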
    def generate(self, query: str, context: List[Document]) -> Answer:
        """
        Generate an answer from query and context documents (standard interface).

        This is the public interface that conforms to the AnswerGenerator protocol.
        It handles the conversion between standard Document objects and Ollama's
        internal chunk format.

        Args:
            query: User's question
            context: List of relevant Document objects

        Returns:
            Answer object conforming to standard interface

        Raises:
            ValueError: If query is empty or context is None
        """
        if not query.strip():
            raise ValueError("Query cannot be empty")
        if context is None:
            raise ValueError("Context cannot be None")

        # Internal adapter: Convert Documents to Ollama chunk format
        ollama_chunks = self._documents_to_ollama_chunks(context)

        # Use existing Ollama-specific generation logic
        ollama_result = self._generate_internal(query, ollama_chunks)

        # Internal adapter: Convert Ollama result to standard Answer
        return self._ollama_result_to_answer(ollama_result, context)

    def _generate_internal(self, query: str, chunks: List[Dict[str, Any]]) -> GeneratedAnswer:
        """
        Generate an answer based on the query and retrieved chunks.

        Args:
            query: User's question
            chunks: Retrieved document chunks

        Returns:
            GeneratedAnswer object with answer, citations, and metadata
        """
        start_time = datetime.now()

        # Check for no-context situation
        if not chunks or all(
            len(chunk.get("content", chunk.get("text", ""))) < 20 for chunk in chunks
        ):
            return GeneratedAnswer(
                answer="This information isn't available in the provided documents.",
                citations=[],
                confidence_score=0.05,
                generation_time=0.1,
                model_used=self.model_name,
                context_used=chunks,
            )

        # Format context
        context = self._format_context(chunks)

        # Create prompt with chunks parameter for dynamic citation instructions
        prompt = self._create_prompt(query, context, chunks)

        # Generate answer
        print(
            f"🤖 Calling Ollama with {self.model_name}...", file=sys.stderr, flush=True
        )
        answer_with_citations = self._call_ollama(prompt)

        generation_time = (datetime.now() - start_time).total_seconds()

        # Extract citations and create natural answer
        natural_answer, citations = self._extract_citations(
            answer_with_citations, chunks
        )

        # Calculate confidence
        confidence = self._calculate_confidence(natural_answer, citations, chunks)

        return GeneratedAnswer(
            answer=natural_answer,
            citations=citations,
            confidence_score=confidence,
            generation_time=generation_time,
            model_used=self.model_name,
            context_used=chunks,
        )

    def generate_with_custom_prompt(
        self, query: str, chunks: List[Dict[str, Any]], custom_prompt: Dict[str, str]
    ) -> GeneratedAnswer:
        """
        Generate answer using a custom prompt (for adaptive prompting).

        Args:
            query: User's question
            chunks: Retrieved context chunks
            custom_prompt: Dict with 'system' and 'user' prompts

        Returns:
            GeneratedAnswer with custom prompt enhancement
        """
        start_time = datetime.now()

        if not chunks:
            return GeneratedAnswer(
                answer="I don't have enough context to answer your question.",
                citations=[],
                confidence_score=0.0,
                generation_time=0.1,
                model_used=self.model_name,
                context_used=chunks,
            )

        # Build custom prompt based on model type
        if "llama" in self.model_name.lower():
            prompt = f"""[INST] {custom_prompt['system']}

{custom_prompt['user']}

MANDATORY: Use [chunk_1], [chunk_2] etc. for all facts. [/INST]"""

        elif "mistral" in self.model_name.lower():
            prompt = f"""[INST] {custom_prompt['system']}

{custom_prompt['user']}

MANDATORY: Use [chunk_1], [chunk_2] etc. for all facts. [/INST]"""
        else:
            # Generic format for other models
            prompt = f"""{custom_prompt['system']}

{custom_prompt['user']}

MANDATORY: Use [chunk_1], [chunk_2] etc. for all factual statements.

Answer:"""

        # Generate answer
        print(
            f"🤖 Calling Ollama with custom prompt using {self.model_name}...",
            file=sys.stderr,
            flush=True,
        )
        answer_with_citations = self._call_ollama(prompt)

        generation_time = (datetime.now() - start_time).total_seconds()

        # Extract citations and create natural answer
        natural_answer, citations = self._extract_citations(
            answer_with_citations, chunks
        )

        # Calculate confidence
        confidence = self._calculate_confidence(natural_answer, citations, chunks)

        return GeneratedAnswer(
            answer=natural_answer,
            citations=citations,
            confidence_score=confidence,
            generation_time=generation_time,
            model_used=self.model_name,
            context_used=chunks,
        )

    def _documents_to_ollama_chunks(self, documents: List[Document]) -> List[Dict[str, Any]]:
        """
        Convert Document objects to Ollama's internal chunk format.

        This internal adapter ensures that Document objects are properly
        formatted for Ollama's processing pipeline while keeping the
        format requirements encapsulated within this class.

        Args:
            documents: List of Document objects from the standard interface

        Returns:
            List of chunk dictionaries in Ollama's expected format
        """
        if not documents:
            return []

        chunks = []
        for i, doc in enumerate(documents):
            chunk = {
                "id": f"chunk_{i+1}",
                "content": doc.content,  # Ollama expects "content" field
                "text": doc.content,  # Fallback field for compatibility
                "score": 1.0,  # Default relevance score
                "metadata": {
                    "source": doc.metadata.get("source", "unknown"),
                    "page_number": doc.metadata.get("start_page", 1),
                    **doc.metadata,  # Include all original metadata
                },
            }
            chunks.append(chunk)

        return chunks

    def _ollama_result_to_answer(self, ollama_result: GeneratedAnswer, original_context: List[Document]) -> Answer:
        """
        Convert Ollama's GeneratedAnswer to the standard Answer format.

        This internal adapter converts Ollama's result format back to the
        standard interface format expected by the rest of the system.

        Args:
            ollama_result: Result from Ollama's internal generation
            original_context: Original Document objects for sources

        Returns:
            Answer object conforming to standard interface
        """
        if not Answer:
            # Fallback if standard interface not available
            return ollama_result

        # Convert to standard Answer format
        return Answer(
            text=ollama_result.answer,
            sources=original_context,  # Use original Document objects
            confidence=ollama_result.confidence_score,
            metadata={
                "model_used": ollama_result.model_used,
                "generation_time": ollama_result.generation_time,
                "citations": [
                    {
                        "chunk_id": cit.chunk_id,
                        "page_number": cit.page_number,
                        "source_file": cit.source_file,
                        "relevance_score": cit.relevance_score,
                        "text_snippet": cit.text_snippet,
                    }
                    for cit in ollama_result.citations
                ],
                "provider": "ollama",
                "temperature": self.temperature,
                "max_tokens": self.max_tokens,
            },
        )


# Example usage
if __name__ == "__main__":
    # Test Ollama connection
    generator = OllamaAnswerGenerator(model_name="llama3.2:3b")

    # Mock chunks for testing
    test_chunks = [
        {
            "content": "RISC-V is a free and open-source ISA.",
            "metadata": {"page_number": 1, "source": "riscv-spec.pdf"},
            "score": 0.9,
        }
    ]

    # Test generation via the chunk-based internal path; the public generate()
    # expects Document objects from src.core.interfaces, not raw chunk dicts.
    result = generator._generate_internal("What is RISC-V?", test_chunks)

    print(f"Answer: {result.answer}")
    print(f"Confidence: {result.confidence_score:.2%}")
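    # Optional sketch: exercise the custom-prompt path as well. The 'system' and
    # 'user' strings below are made-up placeholders; an adaptive-prompting
    # component supplying this dict would provide its own.
    custom_prompt = {
        "system": "You are a concise RISC-V documentation assistant.",
        "user": f"Context:\n{test_chunks[0]['content']}\n\nQuestion: What is RISC-V?",
    }
    custom_result = generator.generate_with_custom_prompt(
        "What is RISC-V?", test_chunks, custom_prompt
    )
    print(f"Custom-prompt answer: {custom_result.answer}")
    print(f"Citations: {len(custom_result.citations)}")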