"""
Enhanced RAG system with integrated answer generation.
This module extends BasicRAG to include answer generation capabilities
using local LLMs via Ollama, with specialized prompt templates for
technical documentation.
"""
import time
from pathlib import Path
from typing import Dict, List, Optional, Union, Generator
import sys
# Import the base RAG implementation from the src package
from src.basic_rag import BasicRAG
# Import from shared utils - Support HF API, Ollama, and Inference Providers
from src.shared_utils.generation.hf_answer_generator import HuggingFaceAnswerGenerator, GeneratedAnswer
from src.shared_utils.generation.ollama_answer_generator import OllamaAnswerGenerator
from src.shared_utils.generation.inference_providers_generator import InferenceProvidersGenerator
from src.shared_utils.generation.prompt_templates import TechnicalPromptTemplates
class RAGWithGeneration(BasicRAG):
"""
Extended RAG system with answer generation capabilities.
Combines hybrid search with LLM-based answer generation,
optimized for technical documentation Q&A.
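
    Example (an illustrative sketch; the document path, model name, and
    question are placeholders, and documents must be indexed before querying):

        rag = RAGWithGeneration(use_ollama=True, model_name="llama3.2:3b")
        rag.index_document(Path("docs/riscv-spec.pdf"))
        result = rag.query_with_answer("What is RISC-V?")
        print(rag.get_formatted_answer(result))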
"""
def __init__(
self,
model_name: str = "sshleifer/distilbart-cnn-12-6",
        api_token: Optional[str] = None,
temperature: float = 0.3,
max_tokens: int = 512,
use_ollama: bool = False,
ollama_url: str = "http://localhost:11434",
use_inference_providers: bool = False
):
"""
Initialize RAG with generation capabilities.
Args:
model_name: Model for generation (HF model or Ollama model)
api_token: HF API token (for HF models only)
temperature: Generation temperature
max_tokens: Maximum tokens to generate
use_ollama: If True, use local Ollama instead of HuggingFace API
ollama_url: Ollama server URL (if using Ollama)
            use_inference_providers: If True, use the HuggingFace Inference Providers API (takes precedence over use_ollama)
"""
super().__init__()
# Choose generator based on configuration with fallback chain
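        # Illustrative configurations (model names are placeholders):
        #   RAGWithGeneration(use_inference_providers=True)              -> Inference Providers API, classic HF API on failure
        #   RAGWithGeneration(use_ollama=True, model_name="llama3.2:3b") -> local Ollama, classic HF API on failure
        #   RAGWithGeneration()                                          -> classic HuggingFace API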
if use_inference_providers:
# Try new Inference Providers API first
print(f"🚀 Trying HuggingFace Inference Providers API...", file=sys.stderr, flush=True)
try:
self.answer_generator = InferenceProvidersGenerator(
model_name=None, # Let it auto-select best available model
api_token=api_token,
temperature=temperature,
max_tokens=max_tokens
)
print(f"✅ Inference Providers API connected successfully", file=sys.stderr, flush=True)
self._using_ollama = False
self._using_inference_providers = True
except Exception as e:
print(f"❌ Inference Providers failed: {e}", file=sys.stderr, flush=True)
print(f"🔄 Falling back to classic HuggingFace API...", file=sys.stderr, flush=True)
# Fallback to classic HF API
self.answer_generator = HuggingFaceAnswerGenerator(
model_name=model_name,
api_token=api_token,
temperature=temperature,
max_tokens=max_tokens
)
print(f"✅ HuggingFace classic API ready", file=sys.stderr, flush=True)
self._using_ollama = False
self._using_inference_providers = False
elif use_ollama:
print(f"🦙 Trying local Ollama server at {ollama_url}...", file=sys.stderr, flush=True)
try:
self.answer_generator = OllamaAnswerGenerator(
model_name=model_name,
base_url=ollama_url,
temperature=temperature,
max_tokens=max_tokens
)
print(f"✅ Ollama connected successfully with {model_name}", file=sys.stderr, flush=True)
self._using_ollama = True
self._using_inference_providers = False
except Exception as e:
print(f"❌ Ollama failed: {e}", file=sys.stderr, flush=True)
print(f"🔄 Falling back to HuggingFace API...", file=sys.stderr, flush=True)
# Fallback to HuggingFace
hf_model = "sshleifer/distilbart-cnn-12-6"
self.answer_generator = HuggingFaceAnswerGenerator(
model_name=hf_model,
api_token=api_token,
temperature=temperature,
max_tokens=max_tokens
)
print(f"✅ HuggingFace fallback ready with {hf_model}", file=sys.stderr, flush=True)
self._using_ollama = False
self._using_inference_providers = False
else:
print("🤗 Using HuggingFace classic API...", file=sys.stderr, flush=True)
self.answer_generator = HuggingFaceAnswerGenerator(
model_name=model_name,
api_token=api_token,
temperature=temperature,
max_tokens=max_tokens
)
self._using_ollama = False
self._using_inference_providers = False
self.prompt_templates = TechnicalPromptTemplates()
self.enable_streaming = False # HF API doesn't support streaming in this implementation
    def get_generator_info(self) -> Dict[str, Union[str, bool, None]]:
"""Get information about the current answer generator."""
return {
"using_ollama": getattr(self, '_using_ollama', False),
"using_inference_providers": getattr(self, '_using_inference_providers', False),
"generator_type": type(self.answer_generator).__name__,
"model_name": getattr(self.answer_generator, 'model_name', 'unknown'),
"base_url": getattr(self.answer_generator, 'base_url', None)
}
def query_with_answer(
self,
question: str,
top_k: int = 5,
use_hybrid: bool = True,
dense_weight: float = 0.7,
use_fallback_llm: bool = False,
return_context: bool = False,
similarity_threshold: float = 0.3
) -> Dict:
"""
Query the system and generate a complete answer.
Args:
question: User's question
top_k: Number of chunks to retrieve
use_hybrid: Whether to use hybrid search (vs basic semantic)
dense_weight: Weight for dense retrieval in hybrid search
            use_fallback_llm: Whether to use the fallback LLM model (accepted for API compatibility; not used in this code path)
return_context: Whether to include retrieved chunks in response
similarity_threshold: Minimum similarity score to include results (0.3 = 30%)
Returns:
Dict containing:
- answer: Generated answer text
- citations: List of citations with sources
- confidence: Confidence score
- sources: List of unique source documents
- retrieval_stats: Statistics from retrieval
- generation_stats: Statistics from generation
- context (optional): Retrieved chunks if requested
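
        Example (a minimal sketch; assumes documents have already been indexed
        and uses a placeholder question):

            result = rag.query_with_answer(
                "What is RISC-V?",
                top_k=5,
                use_hybrid=True,
                similarity_threshold=0.3,
            )
            print(result["answer"])
            print(result["retrieval_stats"]["chunks_retrieved"])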
"""
start_time = time.time()
# Debug: Show which generator is being used
generator_info = self.get_generator_info()
print(f"🔧 Debug: Using {generator_info['generator_type']} (Ollama: {generator_info['using_ollama']}) with model {generator_info['model_name']}", file=sys.stderr, flush=True)
# Step 1: Retrieve relevant chunks
if use_hybrid and self.hybrid_retriever is not None:
retrieval_result = self.hybrid_query(question, top_k, dense_weight, similarity_threshold)
else:
retrieval_result = self.query(question, top_k, similarity_threshold)
retrieval_time = time.time() - start_time
# Step 2: Generate answer using retrieved chunks
chunks = retrieval_result.get("chunks", [])
if not chunks:
return {
"answer": "I couldn't find relevant information in the documentation to answer your question.",
"citations": [],
"confidence": 0.0,
"sources": [],
"retrieval_stats": {
"method": retrieval_result.get("retrieval_method", "none"),
"chunks_retrieved": 0,
"retrieval_time": retrieval_time
},
"generation_stats": {
"model": "none",
"generation_time": 0.0
}
}
# Prepare chunks for answer generator
formatted_chunks = []
for chunk in chunks:
formatted_chunk = {
"id": f"chunk_{chunk.get('chunk_id', 0)}",
"content": chunk.get("text", ""),
"metadata": {
"page_number": chunk.get("page", 0),
"source": Path(chunk.get("source", "unknown")).name,
"quality_score": chunk.get("quality_score", 0.0)
},
"score": chunk.get("hybrid_score", chunk.get("similarity_score", 0.0))
}
formatted_chunks.append(formatted_chunk)
# Generate answer
generation_start = time.time()
generated_answer = self.answer_generator.generate(
query=question,
chunks=formatted_chunks
)
generation_time = time.time() - generation_start
# Prepare response
response = {
"answer": generated_answer.answer,
"citations": [
{
"source": citation.source_file,
"page": citation.page_number,
"relevance": citation.relevance_score,
"snippet": citation.text_snippet
}
for citation in generated_answer.citations
],
"confidence": generated_answer.confidence_score,
"sources": list(set(chunk.get("source", "unknown") for chunk in chunks)),
"retrieval_stats": {
"method": retrieval_result.get("retrieval_method", "semantic"),
"chunks_retrieved": len(chunks),
"retrieval_time": retrieval_time,
"dense_weight": retrieval_result.get("dense_weight", 1.0),
"sparse_weight": retrieval_result.get("sparse_weight", 0.0)
},
"generation_stats": {
"model": generated_answer.model_used,
"generation_time": generation_time,
"total_time": time.time() - start_time
}
}
# Optionally include context
if return_context:
response["context"] = chunks
return response
    def query_with_answer_stream(
        self,
        question: str,
        top_k: int = 5,
        use_hybrid: bool = True,
        dense_weight: float = 0.7,
        use_fallback_llm: bool = False,
        similarity_threshold: float = 0.3
    ) -> Generator[Union[str, Dict], None, None]:
"""
Query the system and stream the answer generation.
Args:
question: User's question
top_k: Number of chunks to retrieve
use_hybrid: Whether to use hybrid search
dense_weight: Weight for dense retrieval
            use_fallback_llm: Whether to use fallback LLM
            similarity_threshold: Minimum similarity score to include results (0.3 = 30%)
        Yields:
            Partial answer strings during generation, followed by the final
            complete response dict as the last item
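
        Example (illustrative; streaming requires self.enable_streaming = True
        and an answer generator that implements generate_stream):

            final_response = None
            for item in rag.query_with_answer_stream("What is RISC-V?"):
                if isinstance(item, str):
                    print(item, end="", flush=True)
                else:
                    final_response = item  # complete response dict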
"""
if not self.enable_streaming:
# Fall back to non-streaming if disabled
            result = self.query_with_answer(
                question, top_k, use_hybrid, dense_weight, use_fallback_llm,
                similarity_threshold=similarity_threshold
            )
yield result["answer"]
yield result
return
start_time = time.time()
# Step 1: Retrieve relevant chunks
if use_hybrid and self.hybrid_retriever is not None:
retrieval_result = self.hybrid_query(question, top_k, dense_weight, similarity_threshold)
else:
retrieval_result = self.query(question, top_k, similarity_threshold)
retrieval_time = time.time() - start_time
# Step 2: Stream answer generation
chunks = retrieval_result.get("chunks", [])
if not chunks:
yield "I couldn't find relevant information in the documentation to answer your question."
yield {
"answer": "I couldn't find relevant information in the documentation to answer your question.",
"citations": [],
"confidence": 0.0,
"sources": [],
"retrieval_stats": {"chunks_retrieved": 0, "retrieval_time": retrieval_time}
}
return
# Prepare chunks
formatted_chunks = []
for chunk in chunks:
formatted_chunk = {
"id": f"chunk_{chunk.get('chunk_id', 0)}",
"content": chunk.get("text", ""),
"metadata": {
"page_number": chunk.get("page", 0),
"source": Path(chunk.get("source", "unknown")).name,
"quality_score": chunk.get("quality_score", 0.0)
},
"score": chunk.get("hybrid_score", chunk.get("similarity_score", 0.0))
}
formatted_chunks.append(formatted_chunk)
# Stream generation
generation_start = time.time()
stream_generator = self.answer_generator.generate_stream(
query=question,
chunks=formatted_chunks,
use_fallback=use_fallback_llm
)
# Stream partial results
for partial in stream_generator:
if isinstance(partial, str):
yield partial
elif isinstance(partial, GeneratedAnswer):
# Final result
generation_time = time.time() - generation_start
final_response = {
"answer": partial.answer,
"citations": [
{
"source": citation.source_file,
"page": citation.page_number,
"relevance": citation.relevance_score,
"snippet": citation.text_snippet
}
for citation in partial.citations
],
"confidence": partial.confidence_score,
"sources": list(set(chunk.get("source", "unknown") for chunk in chunks)),
"retrieval_stats": {
"method": retrieval_result.get("retrieval_method", "semantic"),
"chunks_retrieved": len(chunks),
"retrieval_time": retrieval_time
},
"generation_stats": {
"model": partial.model_used,
"generation_time": generation_time,
"total_time": time.time() - start_time
}
}
yield final_response
def get_formatted_answer(self, response: Dict) -> str:
"""
Format a query response for display.
Args:
response: Response dict from query_with_answer
Returns:
Formatted string for display
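
        Example output (illustrative values):

            **Answer:**
            <generated answer text>

            **Sources:**
            1. riscv-spec.pdf (Page 12)

            *Confidence: 85.0% | Model: llama3.2:3b | Time: 1.42s*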
"""
formatted = f"**Answer:**\n{response['answer']}\n\n"
if response['citations']:
formatted += "**Sources:**\n"
for i, citation in enumerate(response['citations'], 1):
formatted += f"{i}. {citation['source']} (Page {citation['page']})\n"
formatted += f"\n*Confidence: {response['confidence']:.1%} | "
formatted += f"Model: {response['generation_stats']['model']} | "
formatted += f"Time: {response['generation_stats']['total_time']:.2f}s*"
return formatted
# Example usage
if __name__ == "__main__":
# Initialize RAG with generation
rag = RAGWithGeneration()
# Example query (would need indexed documents first)
question = "What is RISC-V and what are its main features?"
print("Initializing system...")
print(f"Primary model: llama3.2:3b")
print(f"Fallback model: mistral:latest")
# Note: This would only work after indexing documents
# Example of how to use:
# rag.index_document(Path("path/to/document.pdf"))
# result = rag.query_with_answer(question)
# print(rag.get_formatted_answer(result))