#!/usr/bin/env python3
"""
Ollama-based answer generator for local inference.
Provides the same interface as HuggingFaceAnswerGenerator but uses a
local Ollama server for model inference.
"""
import time
import requests
import json
import re
import sys
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass
# Import shared components
from .hf_answer_generator import Citation, GeneratedAnswer
from .prompt_templates import TechnicalPromptTemplates
# Import standard interfaces (add this for the adapter)
try:
project_root = Path(__file__).parent.parent.parent.parent.parent
sys.path.append(str(project_root))
from src.core.interfaces import Document, Answer, AnswerGenerator
except ImportError:
# Fallback for standalone usage
Document = None
Answer = None
AnswerGenerator = object
class OllamaAnswerGenerator(AnswerGenerator):  # AnswerGenerator falls back to object when interfaces are unavailable
"""
Generates answers using local Ollama server.
Perfect for:
- Local development
- Privacy-sensitive applications
- No API rate limits
- Consistent performance
- Offline operation
"""
def __init__(
self,
model_name: str = "llama3.2:3b",
base_url: str = "http://localhost:11434",
temperature: float = 0.3,
max_tokens: int = 512,
):
"""
Initialize Ollama answer generator.
Args:
model_name: Ollama model to use (e.g., "llama3.2:3b", "mistral")
base_url: Ollama server URL
temperature: Generation temperature
max_tokens: Maximum tokens to generate
"""
self.model_name = model_name
self.base_url = base_url.rstrip("/")
self.temperature = temperature
self.max_tokens = max_tokens
# Test connection
self._test_connection()
def _test_connection(self):
"""Test if Ollama server is accessible."""
        # Wait up to 60 seconds (12 retries x 5 s) for the Ollama server to become ready
        max_retries = 12
retry_delay = 5
print(
f"🔧 Testing connection to {self.base_url}/api/tags...",
file=sys.stderr,
flush=True,
)
for attempt in range(max_retries):
try:
response = requests.get(f"{self.base_url}/api/tags", timeout=8)
if response.status_code == 200:
print(
f"✅ Connected to Ollama at {self.base_url}",
file=sys.stderr,
flush=True,
)
# Check if our model is available
models = response.json().get("models", [])
model_names = [m["name"] for m in models]
if self.model_name in model_names:
print(
f"✅ Model {self.model_name} is available",
file=sys.stderr,
flush=True,
)
return # Success!
else:
print(
f"⚠️ Model {self.model_name} not found. Available: {model_names}",
file=sys.stderr,
flush=True,
)
if models: # If any models are available, use the first one
fallback_model = model_names[0]
print(
f"🔄 Using fallback model: {fallback_model}",
file=sys.stderr,
flush=True,
)
self.model_name = fallback_model
return
else:
print(
f"📥 No models found, will try to pull {self.model_name}",
file=sys.stderr,
flush=True,
)
# Try to pull the model
self._pull_model(self.model_name)
return
else:
print(f"⚠️ Ollama server returned status {response.status_code}")
if attempt < max_retries - 1:
print(
f"🔄 Retry {attempt + 1}/{max_retries} in {retry_delay} seconds..."
)
time.sleep(retry_delay)
continue
except requests.exceptions.ConnectionError:
if attempt < max_retries - 1:
print(
f"⏳ Ollama not ready yet, retry {attempt + 1}/{max_retries} in {retry_delay} seconds..."
)
time.sleep(retry_delay)
continue
else:
raise Exception(
f"Cannot connect to Ollama server at {self.base_url} after 60 seconds. Check if it's running."
)
except requests.exceptions.Timeout:
if attempt < max_retries - 1:
print(f"⏳ Ollama timeout, retry {attempt + 1}/{max_retries}...")
time.sleep(retry_delay)
continue
else:
raise Exception("Ollama server timeout after multiple retries.")
except Exception as e:
if attempt < max_retries - 1:
print(f"⚠️ Ollama error: {e}, retry {attempt + 1}/{max_retries}...")
time.sleep(retry_delay)
continue
else:
raise Exception(
f"Ollama connection failed after {max_retries} attempts: {e}"
)
raise Exception("Failed to connect to Ollama after all retries")
def _pull_model(self, model_name: str):
"""Pull a model if it's not available."""
try:
print(f"📥 Pulling model {model_name}...")
pull_response = requests.post(
f"{self.base_url}/api/pull",
json={"name": model_name},
timeout=300, # 5 minutes for model download
)
if pull_response.status_code == 200:
print(f"✅ Successfully pulled {model_name}")
else:
print(f"⚠️ Failed to pull {model_name}: {pull_response.status_code}")
# Try smaller models as fallback
fallback_models = ["llama3.2:1b", "llama2:latest", "mistral:latest"]
for fallback in fallback_models:
try:
print(f"🔄 Trying fallback model: {fallback}")
fallback_response = requests.post(
f"{self.base_url}/api/pull",
json={"name": fallback},
timeout=300,
)
if fallback_response.status_code == 200:
print(f"✅ Successfully pulled fallback {fallback}")
self.model_name = fallback
return
                    except requests.exceptions.RequestException:
                        continue
raise Exception(f"Failed to pull {model_name} or any fallback models")
except Exception as e:
print(f"❌ Model pull failed: {e}")
raise
def _format_context(self, chunks: List[Dict[str, Any]]) -> str:
"""Format retrieved chunks into context."""
context_parts = []
for i, chunk in enumerate(chunks):
chunk_text = chunk.get("content", chunk.get("text", ""))
page_num = chunk.get("metadata", {}).get("page_number", "unknown")
source = chunk.get("metadata", {}).get("source", "unknown")
context_parts.append(
f"[chunk_{i+1}] (Page {page_num} from {source}):\n{chunk_text}\n"
)
return "\n---\n".join(context_parts)
def _create_prompt(self, query: str, context: str, chunks: List[Dict[str, Any]]) -> str:
"""Create optimized prompt with dynamic length constraints and citation instructions."""
# Get the appropriate template based on query type
prompt_data = TechnicalPromptTemplates.format_prompt_with_template(
query=query, context=context
)
# Create dynamic citation instructions based on available chunks
num_chunks = len(chunks)
available_chunks = ", ".join([f"[chunk_{i+1}]" for i in range(min(num_chunks, 5))]) # Show max 5 examples
# Create appropriate example based on actual chunks
if num_chunks == 1:
citation_example = "RISC-V is an open-source ISA [chunk_1]."
elif num_chunks == 2:
citation_example = "RISC-V is an open-source ISA [chunk_1] that supports multiple data widths [chunk_2]."
else:
citation_example = "RISC-V is an open-source ISA [chunk_1] that supports multiple data widths [chunk_2] and provides extensions [chunk_3]."
# Determine optimal answer length based on query complexity
target_length = self._determine_target_length(query, chunks)
length_instruction = self._create_length_instruction(target_length)
# Format for different model types
if "llama" in self.model_name.lower():
# Llama-3.2 format with technical prompt templates
return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
{prompt_data['system']}
MANDATORY CITATION RULES:
- ONLY use available chunks: {available_chunks}
- You have {num_chunks} chunks available - DO NOT cite chunk numbers higher than {num_chunks}
- Every technical claim MUST have a citation from available chunks
- Example: "{citation_example}"
{length_instruction}
<|eot_id|><|start_header_id|>user<|end_header_id|>
{prompt_data['user']}
CRITICAL: You MUST cite sources ONLY from available chunks: {available_chunks}. DO NOT use chunk numbers > {num_chunks}.
{length_instruction}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
elif "mistral" in self.model_name.lower():
# Mistral format with technical templates
return f"""[INST] {prompt_data['system']}
Context:
{context}
Question: {query}
MANDATORY: ONLY use available chunks: {available_chunks}. DO NOT cite chunk numbers > {num_chunks}.
{length_instruction} [/INST]"""
else:
# Generic format with technical templates
return f"""{prompt_data['system']}
Context:
{context}
Question: {query}
MANDATORY CITATIONS: ONLY use available chunks: {available_chunks}. DO NOT cite chunk numbers > {num_chunks}.
{length_instruction}
Answer:"""
def _determine_target_length(self, query: str, chunks: List[Dict[str, Any]]) -> int:
"""
Determine optimal answer length based on query complexity.
Target range: 150-400 characters (down from 1000-2600)
"""
# Analyze query complexity
query_length = len(query)
query_words = len(query.split())
# Check for complexity indicators
complex_words = [
"explain", "describe", "analyze", "compare", "contrast",
"evaluate", "discuss", "detail", "elaborate", "comprehensive"
]
simple_words = [
"what", "define", "list", "name", "identify", "is", "are"
]
query_lower = query.lower()
is_complex = any(word in query_lower for word in complex_words)
is_simple = any(word in query_lower for word in simple_words)
# Base length from query type
if is_complex:
base_length = 350 # Complex queries get longer answers
elif is_simple:
base_length = 200 # Simple queries get shorter answers
else:
base_length = 275 # Default middle ground
# Adjust based on available context
context_factor = min(len(chunks) * 25, 75) # More context allows longer answers
# Adjust based on query length
query_factor = min(query_words * 5, 50) # Longer queries allow longer answers
target_length = base_length + context_factor + query_factor
# Constrain to target range
return max(150, min(target_length, 400))
def _create_length_instruction(self, target_length: int) -> str:
"""Create length instruction based on target length."""
if target_length <= 200:
return f"ANSWER LENGTH: Keep your answer concise and focused, approximately {target_length} characters. Be direct and to the point."
elif target_length <= 300:
return f"ANSWER LENGTH: Provide a clear and informative answer, approximately {target_length} characters. Include key details but avoid unnecessary elaboration."
else:
return f"ANSWER LENGTH: Provide a comprehensive but concise answer, approximately {target_length} characters. Include important details while maintaining clarity."
def _call_ollama(self, prompt: str) -> str:
"""Call Ollama API for generation."""
payload = {
"model": self.model_name,
"prompt": prompt,
"stream": False,
"options": {
"temperature": self.temperature,
"num_predict": self.max_tokens,
"top_p": 0.9,
"repeat_penalty": 1.1,
},
}
try:
response = requests.post(
f"{self.base_url}/api/generate", json=payload, timeout=300
)
response.raise_for_status()
result = response.json()
return result.get("response", "").strip()
except requests.exceptions.RequestException as e:
print(f"❌ Ollama API error: {e}")
return f"Error communicating with Ollama: {str(e)}"
except Exception as e:
print(f"❌ Unexpected error: {e}")
return f"Unexpected error: {str(e)}"
def _extract_citations(
self, answer: str, chunks: List[Dict[str, Any]]
) -> Tuple[str, List[Citation]]:
"""Extract citations from the generated answer."""
citations = []
citation_pattern = r"\[chunk_(\d+)\]"
cited_chunks = set()
# Find [chunk_X] citations
matches = re.finditer(citation_pattern, answer)
for match in matches:
chunk_idx = int(match.group(1)) - 1 # Convert to 0-based index
if 0 <= chunk_idx < len(chunks):
cited_chunks.add(chunk_idx)
# FALLBACK: If no explicit citations found but we have an answer and chunks,
# create citations for the top chunks that were likely used
if not cited_chunks and chunks and len(answer.strip()) > 50:
# Use the top chunks that were provided as likely sources
num_fallback_citations = min(3, len(chunks)) # Use top 3 chunks max
cited_chunks = set(range(num_fallback_citations))
print(
f"🔧 Fallback: Creating {num_fallback_citations} citations for answer without explicit [chunk_X] references",
file=sys.stderr,
flush=True,
)
# Create Citation objects
for idx in cited_chunks:
chunk = chunks[idx]
citation = Citation(
chunk_id=chunk.get("id", f"chunk_{idx}"),
page_number=chunk.get("metadata", {}).get("page_number", 0),
source_file=chunk.get("metadata", {}).get("source", "unknown"),
relevance_score=chunk.get("score", 0.0),
text_snippet=chunk.get("content", chunk.get("text", ""))[:200] + "...",
)
citations.append(citation)
        # Keep the [chunk_X] markers as proper citations; replacing them with
        # phrases like "the documentation" produced repetitive text.
natural_answer = re.sub(r"\s+", " ", answer).strip()
return natural_answer, citations
def _calculate_confidence(
self, answer: str, citations: List[Citation], chunks: List[Dict[str, Any]]
) -> float:
"""
Calculate confidence score with expanded multi-factor assessment.
Enhanced algorithm expands range from 0.75-0.95 to 0.3-0.9 with:
- Context quality assessment
- Citation quality evaluation
- Semantic relevance scoring
- Off-topic detection
- Answer completeness analysis
"""
if not answer or len(answer.strip()) < 10:
return 0.1
# 1. Context Quality Assessment (0.3-0.6 base range)
context_quality = self._assess_context_quality(chunks)
# 2. Citation Quality Evaluation (0.0-0.2 boost)
citation_quality = self._assess_citation_quality(citations, chunks)
# 3. Semantic Relevance Scoring (0.0-0.15 boost)
semantic_relevance = self._assess_semantic_relevance(answer, chunks)
# 4. Off-topic Detection (-0.4 penalty if off-topic)
off_topic_penalty = self._detect_off_topic(answer, chunks)
# 5. Answer Completeness Analysis (0.0-0.1 boost)
completeness_bonus = self._assess_answer_completeness(answer, len(chunks))
# Combine all factors
confidence = (
context_quality +
citation_quality +
semantic_relevance +
completeness_bonus +
off_topic_penalty
)
# Apply uncertainty penalty
uncertainty_phrases = [
"insufficient information",
"cannot determine",
"not available in the provided documents",
"I don't have enough context",
"the context doesn't seem to provide"
]
if any(phrase in answer.lower() for phrase in uncertainty_phrases):
confidence *= 0.4 # Stronger penalty for uncertainty
# Constrain to target range 0.3-0.9
return max(0.3, min(confidence, 0.9))
def _assess_context_quality(self, chunks: List[Dict[str, Any]]) -> float:
"""Assess quality of context chunks (0.3-0.6 range)."""
if not chunks:
return 0.3
# Base score from chunk count
if len(chunks) >= 3:
base_score = 0.6
elif len(chunks) >= 2:
base_score = 0.5
else:
base_score = 0.4
# Quality adjustments based on chunk content
avg_chunk_length = sum(len(chunk.get("content", chunk.get("text", ""))) for chunk in chunks) / len(chunks)
if avg_chunk_length > 500: # Rich content
base_score += 0.05
elif avg_chunk_length < 100: # Sparse content
base_score -= 0.05
return max(0.3, min(base_score, 0.6))
def _assess_citation_quality(self, citations: List[Citation], chunks: List[Dict[str, Any]]) -> float:
"""Assess citation quality (0.0-0.2 range)."""
if not citations or not chunks:
return 0.0
# Citation coverage bonus
citation_ratio = len(citations) / min(len(chunks), 3)
coverage_bonus = 0.1 * citation_ratio
# Citation diversity bonus (multiple sources)
unique_sources = len(set(cit.source_file for cit in citations))
diversity_bonus = 0.05 * min(unique_sources / max(len(chunks), 1), 1.0)
return min(coverage_bonus + diversity_bonus, 0.2)
def _assess_semantic_relevance(self, answer: str, chunks: List[Dict[str, Any]]) -> float:
"""Assess semantic relevance between answer and context (0.0-0.15 range)."""
if not answer or not chunks:
return 0.0
# Simple keyword overlap assessment
answer_words = set(answer.lower().split())
context_words = set()
for chunk in chunks:
chunk_text = chunk.get("content", chunk.get("text", ""))
context_words.update(chunk_text.lower().split())
if not context_words:
return 0.0
# Calculate overlap ratio
overlap = len(answer_words & context_words)
total_unique = len(answer_words | context_words)
if total_unique == 0:
return 0.0
overlap_ratio = overlap / total_unique
return min(0.15 * overlap_ratio, 0.15)
def _detect_off_topic(self, answer: str, chunks: List[Dict[str, Any]]) -> float:
"""Detect if answer is off-topic (-0.4 penalty if off-topic)."""
if not answer or not chunks:
return 0.0
# Check for off-topic indicators
off_topic_phrases = [
"but I have to say that the context doesn't seem to provide",
"these documents appear to be focused on",
"but they don't seem to cover",
"I'd recommend consulting a different type of documentation",
"without more context or information"
]
answer_lower = answer.lower()
for phrase in off_topic_phrases:
if phrase in answer_lower:
return -0.4 # Strong penalty for off-topic responses
return 0.0
def _assess_answer_completeness(self, answer: str, chunk_count: int) -> float:
"""Assess answer completeness (0.0-0.1 range)."""
if not answer:
return 0.0
# Length-based completeness assessment
answer_length = len(answer)
if answer_length > 500: # Comprehensive answer
return 0.1
elif answer_length > 200: # Adequate answer
return 0.05
else: # Brief answer
return 0.0
def generate(self, query: str, context: List[Document]) -> Answer:
"""
Generate an answer from query and context documents (standard interface).
This is the public interface that conforms to the AnswerGenerator protocol.
It handles the conversion between standard Document objects and Ollama's
internal chunk format.
Args:
query: User's question
context: List of relevant Document objects
Returns:
Answer object conforming to standard interface
Raises:
ValueError: If query is empty or context is None
"""
if not query.strip():
raise ValueError("Query cannot be empty")
if context is None:
raise ValueError("Context cannot be None")
# Internal adapter: Convert Documents to Ollama chunk format
ollama_chunks = self._documents_to_ollama_chunks(context)
# Use existing Ollama-specific generation logic
ollama_result = self._generate_internal(query, ollama_chunks)
# Internal adapter: Convert Ollama result to standard Answer
return self._ollama_result_to_answer(ollama_result, context)
def _generate_internal(self, query: str, chunks: List[Dict[str, Any]]) -> GeneratedAnswer:
"""
Generate an answer based on the query and retrieved chunks.
Args:
query: User's question
chunks: Retrieved document chunks
Returns:
GeneratedAnswer object with answer, citations, and metadata
"""
start_time = datetime.now()
# Check for no-context situation
if not chunks or all(
len(chunk.get("content", chunk.get("text", ""))) < 20 for chunk in chunks
):
return GeneratedAnswer(
answer="This information isn't available in the provided documents.",
citations=[],
confidence_score=0.05,
generation_time=0.1,
model_used=self.model_name,
context_used=chunks,
)
# Format context
context = self._format_context(chunks)
# Create prompt with chunks parameter for dynamic citation instructions
prompt = self._create_prompt(query, context, chunks)
# Generate answer
print(
f"🤖 Calling Ollama with {self.model_name}...", file=sys.stderr, flush=True
)
answer_with_citations = self._call_ollama(prompt)
generation_time = (datetime.now() - start_time).total_seconds()
# Extract citations and create natural answer
natural_answer, citations = self._extract_citations(
answer_with_citations, chunks
)
# Calculate confidence
confidence = self._calculate_confidence(natural_answer, citations, chunks)
return GeneratedAnswer(
answer=natural_answer,
citations=citations,
confidence_score=confidence,
generation_time=generation_time,
model_used=self.model_name,
context_used=chunks,
)
def generate_with_custom_prompt(
self,
query: str,
chunks: List[Dict[str, Any]],
custom_prompt: Dict[str, str]
) -> GeneratedAnswer:
"""
Generate answer using a custom prompt (for adaptive prompting).
Args:
query: User's question
chunks: Retrieved context chunks
custom_prompt: Dict with 'system' and 'user' prompts
Returns:
GeneratedAnswer with custom prompt enhancement
"""
start_time = datetime.now()
if not chunks:
return GeneratedAnswer(
answer="I don't have enough context to answer your question.",
citations=[],
confidence_score=0.0,
generation_time=0.1,
model_used=self.model_name,
context_used=chunks,
)
        # Build custom prompt based on model type (Llama and Mistral both accept the [INST] format here)
        if "llama" in self.model_name.lower() or "mistral" in self.model_name.lower():
            prompt = f"""[INST] {custom_prompt['system']}
{custom_prompt['user']}
MANDATORY: Use [chunk_1], [chunk_2] etc. for all facts. [/INST]"""
else:
# Generic format for other models
prompt = f"""{custom_prompt['system']}
{custom_prompt['user']}
MANDATORY: Use [chunk_1], [chunk_2] etc. for all factual statements.
Answer:"""
# Generate answer
print(f"🤖 Calling Ollama with custom prompt using {self.model_name}...", file=sys.stderr, flush=True)
answer_with_citations = self._call_ollama(prompt)
generation_time = (datetime.now() - start_time).total_seconds()
# Extract citations and create natural answer
natural_answer, citations = self._extract_citations(answer_with_citations, chunks)
# Calculate confidence
confidence = self._calculate_confidence(natural_answer, citations, chunks)
return GeneratedAnswer(
answer=natural_answer,
citations=citations,
confidence_score=confidence,
generation_time=generation_time,
model_used=self.model_name,
context_used=chunks,
)
def _documents_to_ollama_chunks(self, documents: List[Document]) -> List[Dict[str, Any]]:
"""
Convert Document objects to Ollama's internal chunk format.
This internal adapter ensures that Document objects are properly formatted
for Ollama's processing pipeline while keeping the format requirements
encapsulated within this class.
Args:
documents: List of Document objects from the standard interface
Returns:
List of chunk dictionaries in Ollama's expected format
"""
if not documents:
return []
chunks = []
for i, doc in enumerate(documents):
chunk = {
"id": f"chunk_{i+1}",
"content": doc.content, # Ollama expects "content" field
"text": doc.content, # Fallback field for compatibility
"score": 1.0, # Default relevance score
"metadata": {
"source": doc.metadata.get("source", "unknown"),
"page_number": doc.metadata.get("start_page", 1),
**doc.metadata # Include all original metadata
}
}
chunks.append(chunk)
return chunks
def _ollama_result_to_answer(self, ollama_result: GeneratedAnswer, original_context: List[Document]) -> Answer:
"""
Convert Ollama's GeneratedAnswer to the standard Answer format.
This internal adapter converts Ollama's result format back to the
standard interface format expected by the rest of the system.
Args:
ollama_result: Result from Ollama's internal generation
original_context: Original Document objects for sources
Returns:
Answer object conforming to standard interface
"""
if not Answer:
# Fallback if standard interface not available
return ollama_result
# Convert to standard Answer format
return Answer(
text=ollama_result.answer,
sources=original_context, # Use original Document objects
confidence=ollama_result.confidence_score,
metadata={
"model_used": ollama_result.model_used,
"generation_time": ollama_result.generation_time,
"citations": [
{
"chunk_id": cit.chunk_id,
"page_number": cit.page_number,
"source_file": cit.source_file,
"relevance_score": cit.relevance_score,
"text_snippet": cit.text_snippet
}
for cit in ollama_result.citations
],
"provider": "ollama",
"temperature": self.temperature,
"max_tokens": self.max_tokens
}
)
# Example usage
if __name__ == "__main__":
# Test Ollama connection
generator = OllamaAnswerGenerator(model_name="llama3.2:3b")
# Mock chunks for testing
test_chunks = [
{
"content": "RISC-V is a free and open-source ISA.",
"metadata": {"page_number": 1, "source": "riscv-spec.pdf"},
"score": 0.9,
}
]
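
    # A minimal sketch of the custom-prompt path (generate_with_custom_prompt).
    # The 'system'/'user' keys follow this method's docstring; the prompt wording
    # below is illustrative only, not a template shipped with this module.
    custom_prompt = {
        "system": "You are a technical documentation assistant. Cite sources as [chunk_N].",
        "user": f"Context:\n[chunk_1] {test_chunks[0]['content']}\n\nQuestion: What is RISC-V?",
    }
    custom_result = generator.generate_with_custom_prompt(
        "What is RISC-V?", test_chunks, custom_prompt
    )
    print(f"Custom-prompt answer: {custom_result.answer}")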
    # Test generation via the internal chunk-based API (dict chunks, not Document objects)
    result = generator._generate_internal("What is RISC-V?", test_chunks)
print(f"Answer: {result.answer}")
print(f"Confidence: {result.confidence_score:.2%}")