""" | |
Enhanced RAG system with integrated answer generation. | |
This module extends BasicRAG to include answer generation capabilities | |
using local LLMs via Ollama, with specialized prompt templates for | |
technical documentation. | |
""" | |
import time | |
from pathlib import Path | |
from typing import Dict, List, Optional, Union, Generator | |
import sys | |
# Import from same directory | |
from src.basic_rag import BasicRAG | |
# Import from shared utils - Support HF API, Ollama, and Inference Providers | |
from src.shared_utils.generation.hf_answer_generator import HuggingFaceAnswerGenerator, GeneratedAnswer | |
from src.shared_utils.generation.ollama_answer_generator import OllamaAnswerGenerator | |
from src.shared_utils.generation.inference_providers_generator import InferenceProvidersGenerator | |
from src.shared_utils.generation.prompt_templates import TechnicalPromptTemplates | |


class RAGWithGeneration(BasicRAG):
    """
    Extended RAG system with answer generation capabilities.

    Combines hybrid search with LLM-based answer generation,
    optimized for technical documentation Q&A.
    """

    def __init__(
        self,
        model_name: str = "sshleifer/distilbart-cnn-12-6",
        api_token: Optional[str] = None,
        temperature: float = 0.3,
        max_tokens: int = 512,
        use_ollama: bool = False,
        ollama_url: str = "http://localhost:11434",
        use_inference_providers: bool = False,
    ):
        """
        Initialize RAG with generation capabilities.

        Args:
            model_name: Model for generation (HF model or Ollama model)
            api_token: HF API token (for HF models only)
            temperature: Generation temperature
            max_tokens: Maximum tokens to generate
            use_ollama: If True, use local Ollama instead of the HuggingFace API
            ollama_url: Ollama server URL (if using Ollama)
            use_inference_providers: If True, use the new Inference Providers API
        """
        super().__init__()

        # Choose a generator based on configuration, with a fallback chain
        if use_inference_providers:
            # Try the new Inference Providers API first
            print("🚀 Trying HuggingFace Inference Providers API...", file=sys.stderr, flush=True)
            try:
                self.answer_generator = InferenceProvidersGenerator(
                    model_name=None,  # Let it auto-select the best available model
                    api_token=api_token,
                    temperature=temperature,
                    max_tokens=max_tokens,
                )
                print("✅ Inference Providers API connected successfully", file=sys.stderr, flush=True)
                self._using_ollama = False
                self._using_inference_providers = True
            except Exception as e:
                print(f"❌ Inference Providers failed: {e}", file=sys.stderr, flush=True)
                print("🔄 Falling back to classic HuggingFace API...", file=sys.stderr, flush=True)
                # Fall back to the classic HF API
                self.answer_generator = HuggingFaceAnswerGenerator(
                    model_name=model_name,
                    api_token=api_token,
                    temperature=temperature,
                    max_tokens=max_tokens,
                )
                print("✅ HuggingFace classic API ready", file=sys.stderr, flush=True)
                self._using_ollama = False
                self._using_inference_providers = False
        elif use_ollama:
            print(f"🦙 Trying local Ollama server at {ollama_url}...", file=sys.stderr, flush=True)
            try:
                self.answer_generator = OllamaAnswerGenerator(
                    model_name=model_name,
                    base_url=ollama_url,
                    temperature=temperature,
                    max_tokens=max_tokens,
                )
                print(f"✅ Ollama connected successfully with {model_name}", file=sys.stderr, flush=True)
                self._using_ollama = True
                self._using_inference_providers = False
            except Exception as e:
                print(f"❌ Ollama failed: {e}", file=sys.stderr, flush=True)
                print("🔄 Falling back to HuggingFace API...", file=sys.stderr, flush=True)
                # Fall back to HuggingFace
                hf_model = "sshleifer/distilbart-cnn-12-6"
                self.answer_generator = HuggingFaceAnswerGenerator(
                    model_name=hf_model,
                    api_token=api_token,
                    temperature=temperature,
                    max_tokens=max_tokens,
                )
                print(f"✅ HuggingFace fallback ready with {hf_model}", file=sys.stderr, flush=True)
                self._using_ollama = False
                self._using_inference_providers = False
        else:
            print("🤗 Using HuggingFace classic API...", file=sys.stderr, flush=True)
            self.answer_generator = HuggingFaceAnswerGenerator(
                model_name=model_name,
                api_token=api_token,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            self._using_ollama = False
            self._using_inference_providers = False

        self.prompt_templates = TechnicalPromptTemplates()
        self.enable_streaming = False  # The HF API doesn't support streaming in this implementation
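
    # Configuration sketch for the three generator modes supported above (the Ollama
    # model name is only an example; any model served by the Ollama instance works):
    #
    #     RAGWithGeneration()                                            # classic HF API
    #     RAGWithGeneration(use_ollama=True, model_name="llama3.2:3b")   # local Ollama
    #     RAGWithGeneration(use_inference_providers=True)                # Inference Providers API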

    def get_generator_info(self) -> Dict[str, str]:
        """Get information about the current answer generator."""
        return {
            "using_ollama": getattr(self, "_using_ollama", False),
            "using_inference_providers": getattr(self, "_using_inference_providers", False),
            "generator_type": type(self.answer_generator).__name__,
            "model_name": getattr(self.answer_generator, "model_name", "unknown"),
            "base_url": getattr(self.answer_generator, "base_url", None),
        }
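
    # Illustrative shape of the dict returned by get_generator_info() when running
    # against a local Ollama server (values are examples and depend on configuration):
    #
    #     {
    #         "using_ollama": True,
    #         "using_inference_providers": False,
    #         "generator_type": "OllamaAnswerGenerator",
    #         "model_name": "llama3.2:3b",
    #         "base_url": "http://localhost:11434",
    #     }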

    def query_with_answer(
        self,
        question: str,
        top_k: int = 5,
        use_hybrid: bool = True,
        dense_weight: float = 0.7,
        use_fallback_llm: bool = False,
        return_context: bool = False,
        similarity_threshold: float = 0.3,
    ) -> Dict:
        """
        Query the system and generate a complete answer.

        Args:
            question: User's question
            top_k: Number of chunks to retrieve
            use_hybrid: Whether to use hybrid search (vs. basic semantic search)
            dense_weight: Weight for dense retrieval in hybrid search
            use_fallback_llm: Whether to use the fallback LLM model
            return_context: Whether to include retrieved chunks in the response
            similarity_threshold: Minimum similarity score to include results (0.3 = 30%)

        Returns:
            Dict containing:
                - answer: Generated answer text
                - citations: List of citations with sources
                - confidence: Confidence score
                - sources: List of unique source documents
                - retrieval_stats: Statistics from retrieval
                - generation_stats: Statistics from generation
                - context (optional): Retrieved chunks if requested
        """
        start_time = time.time()

        # Debug: show which generator is being used
        generator_info = self.get_generator_info()
        print(
            f"🔧 Debug: Using {generator_info['generator_type']} "
            f"(Ollama: {generator_info['using_ollama']}) with model {generator_info['model_name']}",
            file=sys.stderr,
            flush=True,
        )

        # Step 1: Retrieve relevant chunks
        if use_hybrid and self.hybrid_retriever is not None:
            retrieval_result = self.hybrid_query(question, top_k, dense_weight, similarity_threshold)
        else:
            retrieval_result = self.query(question, top_k, similarity_threshold)

        retrieval_time = time.time() - start_time

        # Step 2: Generate an answer from the retrieved chunks
        chunks = retrieval_result.get("chunks", [])

        if not chunks:
            return {
                "answer": "I couldn't find relevant information in the documentation to answer your question.",
                "citations": [],
                "confidence": 0.0,
                "sources": [],
                "retrieval_stats": {
                    "method": retrieval_result.get("retrieval_method", "none"),
                    "chunks_retrieved": 0,
                    "retrieval_time": retrieval_time,
                },
                "generation_stats": {
                    "model": "none",
                    "generation_time": 0.0,
                },
            }

        # Prepare chunks for the answer generator
        formatted_chunks = []
        for chunk in chunks:
            formatted_chunk = {
                "id": f"chunk_{chunk.get('chunk_id', 0)}",
                "content": chunk.get("text", ""),
                "metadata": {
                    "page_number": chunk.get("page", 0),
                    "source": Path(chunk.get("source", "unknown")).name,
                    "quality_score": chunk.get("quality_score", 0.0),
                },
                "score": chunk.get("hybrid_score", chunk.get("similarity_score", 0.0)),
            }
            formatted_chunks.append(formatted_chunk)

        # Generate the answer
        generation_start = time.time()
        generated_answer = self.answer_generator.generate(
            query=question,
            chunks=formatted_chunks,
        )
        generation_time = time.time() - generation_start

        # Prepare the response
        response = {
            "answer": generated_answer.answer,
            "citations": [
                {
                    "source": citation.source_file,
                    "page": citation.page_number,
                    "relevance": citation.relevance_score,
                    "snippet": citation.text_snippet,
                }
                for citation in generated_answer.citations
            ],
            "confidence": generated_answer.confidence_score,
            "sources": list(set(chunk.get("source", "unknown") for chunk in chunks)),
            "retrieval_stats": {
                "method": retrieval_result.get("retrieval_method", "semantic"),
                "chunks_retrieved": len(chunks),
                "retrieval_time": retrieval_time,
                "dense_weight": retrieval_result.get("dense_weight", 1.0),
                "sparse_weight": retrieval_result.get("sparse_weight", 0.0),
            },
            "generation_stats": {
                "model": generated_answer.model_used,
                "generation_time": generation_time,
                "total_time": time.time() - start_time,
            },
        }

        # Optionally include the retrieved context
        if return_context:
            response["context"] = chunks

        return response
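
    # Minimal usage sketch for query_with_answer (assumes documents have already been
    # indexed; the question and settings below are placeholders, not fixed values):
    #
    #     rag = RAGWithGeneration(use_ollama=True, model_name="llama3.2:3b")
    #     result = rag.query_with_answer("What is RISC-V?", top_k=5, use_hybrid=True)
    #     print(result["answer"])
    #     for cite in result["citations"]:
    #         print(f"- {cite['source']} (page {cite['page']})")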

    def query_with_answer_stream(
        self,
        question: str,
        top_k: int = 5,
        use_hybrid: bool = True,
        dense_weight: float = 0.7,
        use_fallback_llm: bool = False,
        similarity_threshold: float = 0.3,
    ) -> Generator[Union[str, Dict], None, None]:
        """
        Query the system and stream the answer generation.

        Args:
            question: User's question
            top_k: Number of chunks to retrieve
            use_hybrid: Whether to use hybrid search
            dense_weight: Weight for dense retrieval
            use_fallback_llm: Whether to use the fallback LLM
            similarity_threshold: Minimum similarity score to include results

        Yields:
            Partial answer strings during generation, followed by the final
            complete response dict once generation finishes.
        """
        if not self.enable_streaming:
            # Fall back to non-streaming mode if streaming is disabled
            result = self.query_with_answer(
                question,
                top_k,
                use_hybrid,
                dense_weight,
                use_fallback_llm,
                similarity_threshold=similarity_threshold,
            )
            yield result["answer"]
            yield result
            return
        start_time = time.time()

        # Step 1: Retrieve relevant chunks
        if use_hybrid and self.hybrid_retriever is not None:
            retrieval_result = self.hybrid_query(question, top_k, dense_weight, similarity_threshold)
        else:
            retrieval_result = self.query(question, top_k, similarity_threshold)

        retrieval_time = time.time() - start_time

        # Step 2: Stream the answer generation
        chunks = retrieval_result.get("chunks", [])

        if not chunks:
            yield "I couldn't find relevant information in the documentation to answer your question."
            yield {
                "answer": "I couldn't find relevant information in the documentation to answer your question.",
                "citations": [],
                "confidence": 0.0,
                "sources": [],
                "retrieval_stats": {"chunks_retrieved": 0, "retrieval_time": retrieval_time},
            }
            return

        # Prepare chunks for the answer generator
        formatted_chunks = []
        for chunk in chunks:
            formatted_chunk = {
                "id": f"chunk_{chunk.get('chunk_id', 0)}",
                "content": chunk.get("text", ""),
                "metadata": {
                    "page_number": chunk.get("page", 0),
                    "source": Path(chunk.get("source", "unknown")).name,
                    "quality_score": chunk.get("quality_score", 0.0),
                },
                "score": chunk.get("hybrid_score", chunk.get("similarity_score", 0.0)),
            }
            formatted_chunks.append(formatted_chunk)

        # Stream the generation
        generation_start = time.time()
        stream_generator = self.answer_generator.generate_stream(
            query=question,
            chunks=formatted_chunks,
            use_fallback=use_fallback_llm,
        )

        # Yield partial results as they arrive
        for partial in stream_generator:
            if isinstance(partial, str):
                yield partial
            elif isinstance(partial, GeneratedAnswer):
                # Final result
                generation_time = time.time() - generation_start
                final_response = {
                    "answer": partial.answer,
                    "citations": [
                        {
                            "source": citation.source_file,
                            "page": citation.page_number,
                            "relevance": citation.relevance_score,
                            "snippet": citation.text_snippet,
                        }
                        for citation in partial.citations
                    ],
                    "confidence": partial.confidence_score,
                    "sources": list(set(chunk.get("source", "unknown") for chunk in chunks)),
                    "retrieval_stats": {
                        "method": retrieval_result.get("retrieval_method", "semantic"),
                        "chunks_retrieved": len(chunks),
                        "retrieval_time": retrieval_time,
                    },
                    "generation_stats": {
                        "model": partial.model_used,
                        "generation_time": generation_time,
                        "total_time": time.time() - start_time,
                    },
                }
                yield final_response
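
    # Consumption sketch for the streaming path (assumes enable_streaming has been set
    # to True and the configured generator implements generate_stream):
    #
    #     rag.enable_streaming = True
    #     for item in rag.query_with_answer_stream("What is RISC-V?"):
    #         if isinstance(item, str):
    #             print(item, end="", flush=True)   # partial answer text
    #         else:
    #             final = item                      # complete response dict with citations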

    def get_formatted_answer(self, response: Dict) -> str:
        """
        Format a query response for display.

        Args:
            response: Response dict from query_with_answer

        Returns:
            Formatted string for display
        """
        formatted = f"**Answer:**\n{response['answer']}\n\n"

        if response["citations"]:
            formatted += "**Sources:**\n"
            for i, citation in enumerate(response["citations"], 1):
                formatted += f"{i}. {citation['source']} (Page {citation['page']})\n"

        formatted += f"\n*Confidence: {response['confidence']:.1%} | "
        formatted += f"Model: {response['generation_stats']['model']} | "
        formatted += f"Time: {response['generation_stats']['total_time']:.2f}s*"

        return formatted
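
    # Illustrative output of get_formatted_answer (all values below are made up
    # purely for the example):
    #
    #     **Answer:**
    #     RISC-V is an open instruction set architecture ...
    #
    #     **Sources:**
    #     1. riscv-spec.pdf (Page 17)
    #
    #     *Confidence: 82.0% | Model: llama3.2:3b | Time: 3.42s*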


# Example usage
if __name__ == "__main__":
    # Initialize RAG with generation
    rag = RAGWithGeneration()

    # Example query (requires indexed documents first)
    question = "What is RISC-V and what are its main features?"

    print("Initializing system...")
    info = rag.get_generator_info()
    print(f"Generator: {info['generator_type']}")
    print(f"Model: {info['model_name']}")

    # Note: this only works after indexing documents.
    # Example of how to use:
    # rag.index_document(Path("path/to/document.pdf"))
    # result = rag.query_with_answer(question)
    # print(rag.get_formatted_answer(result))