""" | |
Groq Medical RAG System v2.0 | |
FREE Groq Cloud API integration for advanced medical reasoning | |
""" | |
import os
import time
import logging
import argparse
import shutil
import numpy as np
from typing import List, Dict, Optional
from dataclasses import dataclass
from dotenv import load_dotenv
from pathlib import Path

# Langchain for document loading and splitting
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Sentence Transformers for re-ranking
from sentence_transformers import CrossEncoder

# Groq API integration
from groq import Groq
from tenacity import retry, stop_after_attempt, wait_fixed, before_sleep_log

# Load environment variables from .env file
load_dotenv()

# Import our simplified components
from .simple_vector_store import SimpleVectorStore, SearchResult


@dataclass
class MedicalResponse:
    """Enhanced medical response structure"""
    answer: str
    confidence: float
    sources: List[str]
    query_time: float


class GroqMedicalRAG:
    """Groq-powered Medical RAG System v2.0 - FREE LLM integration"""

    def __init__(self,
                 vector_store_dir: str = "simple_vector_store",
                 processed_docs_dir: str = "src/processed_markdown",
                 groq_api_key: Optional[str] = None):
        """Initialize the Groq medical RAG system."""
        # Get the absolute path to the project root directory
        project_root = Path(__file__).parent.parent.resolve()
        self.vector_store_dir = project_root / vector_store_dir
        self.processed_docs_dir = project_root / processed_docs_dir

        # Initialize Groq client
        self.groq_api_key = groq_api_key or os.getenv("GROQ_API_KEY")
        if not self.groq_api_key:
            raise ValueError(
                "GROQ_API_KEY environment variable not set. "
                "Get your free API key from https://console.groq.com/keys"
            )
        self.groq_client = Groq(api_key=self.groq_api_key)
        self.model_name = "llama3-70b-8192"

        # Initialize Cross-Encoder for re-ranking
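        # A cross-encoder scores each (query, passage) pair jointly: slower than
        # the embedding-based retrieval step, but considerably more accurate, so
        # it is applied only to the small candidate set the vector store returns.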
        self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

        # Initialize components
        self.vector_store = None
        self.setup_logging()
        self._initialize_system()

    def setup_logging(self):
        """Set up logging for the RAG system."""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(__name__)

    def _initialize_system(self, force_recreate: bool = False):
        """Initialize the RAG system components."""
        try:
            # If forcing recreation, delete the old vector store
            if force_recreate and self.vector_store_dir.exists():
                self.logger.warning(f"Recreating index as requested. Deleting {self.vector_store_dir}...")
                shutil.rmtree(self.vector_store_dir)

            # Initialize vector store
            self.vector_store = SimpleVectorStore(vector_store_dir=self.vector_store_dir)

            # Try to load existing vector store
            if not self.vector_store.load_vector_store():
                self.logger.info("Creating new vector store from documents...")
                self._create_vector_store()
            else:
                self.logger.info("Loaded existing vector store")

            # Test Groq connection
            self._test_groq_connection()
            self.logger.info("Groq Medical RAG system initialized successfully")
        except Exception as e:
            self.logger.error(f"Error initializing RAG system: {e}")
            raise
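
    # The docstring below promises retry logic and tenacity is imported for it,
    # so a minimal sketch of that wiring is applied here. The parameters are an
    # assumption: 3 attempts, 2 s apart, logging a warning before each retry.
    @retry(stop=stop_after_attempt(3), wait=wait_fixed(2),
           before_sleep=before_sleep_log(logging.getLogger(__name__), logging.WARNING),
           reraise=True)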
    def _test_groq_connection(self):
        """Test Groq API connection with retry logic."""
        try:
            self.groq_client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": "Test"}],
                max_tokens=10,
            )
            self.logger.info("✅ Groq API connection successful")
        except Exception as e:
            self.logger.error(f"❌ Groq API connection failed: {e}")
            raise

    def _create_vector_store(self):
        """Create vector store from processed markdown documents."""
        self.logger.info(f"Checking for documents in {self.processed_docs_dir}...")
        doc_files = list(self.processed_docs_dir.glob("**/*.md"))
        if not doc_files:
            self.logger.error(f"No markdown files found in {self.processed_docs_dir}. Please run the enhanced_pdf_processor.py script first.")
            raise FileNotFoundError(f"No markdown files found in {self.processed_docs_dir}")
        self.logger.info(f"Found {len(doc_files)} markdown documents to process.")

        # Load documents using UnstructuredMarkdownLoader
        all_docs = []
        for doc_path in doc_files:
            try:
                loader = UnstructuredMarkdownLoader(str(doc_path))
                loaded_docs = loader.load()
                # Ensure 'source' metadata is present for our context string.
                for doc in loaded_docs:
                    if 'source' not in doc.metadata:
                        doc.metadata['source'] = str(doc_path)
                all_docs.extend(loaded_docs)
            except Exception as e:
                self.logger.error(f"Error loading {doc_path}: {e}")

        if not all_docs:
            self.logger.error("Failed to load any documents. Vector store not created.")
            return

        # Split documents into chunks with smaller size and overlap
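        # Smaller chunks make retrieval more precise but carry less context per
        # hit; the overlap keeps sentences that straddle a chunk boundary from
        # losing their surrounding context.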
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1024,    # Reduced from 2048
            chunk_overlap=128,  # Reduced from 256
            separators=["\n\n", "\n", " ", ""]
        )
        chunks = text_splitter.split_documents(all_docs)
        self.logger.info(f"Created {len(chunks)} chunks from {len(all_docs)} documents.")

        # Create embeddings and build index
        embeddings, count = self.vector_store.create_embeddings(chunks)
        self.vector_store.build_index(embeddings)
        self.vector_store.save_vector_store()
        self.logger.info(f"Created vector store with {count} embeddings.")

    def query(self,
              query: str,
              history: Optional[List[Dict[str, str]]] = None,
              k: int = 15,            # Reduced from 30
              top_n_rerank: int = 3,  # Reduced from 5
              use_llm: bool = True) -> MedicalResponse:
        """Query the Groq medical RAG system with re-ranking."""
        start_time = time.time()

        # Stage 1: Initial retrieval from vector store
        docs = self.vector_store.search(query=query, k=k)
        if not docs:
            return self._create_no_results_response(query)

        # Stage 2: Re-ranking with Cross-Encoder
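        # CrossEncoder.predict returns one relevance score per (query, passage)
        # pair; higher means more relevant. For this model the scores are raw
        # logits, not probabilities, which is why _calculate_confidence squashes
        # them through a sigmoid later.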
        sentence_pairs = [[query, doc.content] for doc in docs]
        scores = self.reranker.predict(sentence_pairs)

        # Combine docs with scores and sort, best first
        doc_score_pairs = list(zip(docs, scores))
        doc_score_pairs.sort(key=lambda x: x[1], reverse=True)

        # Select top N results after re-ranking
        reranked_docs = [pair[0] for pair in doc_score_pairs[:top_n_rerank]]
        reranked_scores = [pair[1] for pair in doc_score_pairs[:top_n_rerank]]

        # Prepare context with rich metadata for the LLM
        context_parts = []
        for i, doc in enumerate(reranked_docs, 1):
            citation = doc.metadata.get('citation')
            if not citation:
                source_path = doc.metadata.get('source', 'Unknown')
                citation = Path(source_path).parent.name
            # Number each context block so the LLM can cite it as [1], [2], ...
            context_parts.append(f"[{i}] Citation: {citation}\n\nContent: {doc.content}")
        context = "\n\n---\n\n".join(context_parts)
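        # The assembled context string looks like:
        #   [1] Citation: <guideline name>
        #
        #   Content: <chunk text>
        #
        #   ---
        #
        #   [2] Citation: ...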

        confidence = self._calculate_confidence(reranked_scores, use_llm)

        # Use a set to get unique citations for display
        sources = list(set([
            doc.metadata.get('citation', Path(doc.metadata.get('source', 'Unknown')).parent.name)
            for doc in reranked_docs
        ]))

        if use_llm:
            # Phase 4: Persona-driven, structured response generation
            system_prompt = (
                "You are 'VedaMD', a world-class medical expert and a compassionate assistant for healthcare professionals in Sri Lanka. "
                "Your primary goal is to provide accurate, evidence-based clinical information based ONLY on the provided context, which is sourced from official Sri Lankan maternal health guidelines. "
                "Your tone should be professional, clear, and supportive.\n\n"
                "**CRITICAL INSTRUCTIONS:**\n"
                "1. **Strictly Context-Bound:** Your answer MUST be based exclusively on the 'Content' provided for each source. Do not use any external knowledge or provide information not present in the context.\n"
                "2. **Markdown Formatting:** Structure your answers for maximum clarity. Use markdown for formatting:\n"
                "   - Use headings (`##`) for main topics.\n"
                "   - Use bullet points (`-` or `*`) for lists of symptoms, recommendations, or steps.\n"
                "   - Use bold (`**text**`) to emphasize key terms, dosages, or critical warnings.\n"
                "3. **Synthesize, Don't Just Copy:** Read all context pieces, synthesize the information, and provide a comprehensive answer. Do not repeat information.\n"
                "4. **Scientific Citations:** Use numbered citations [1], [2], etc. in your answer text to reference specific information. At the end, list all sources under a 'References:' heading in scientific format:\n"
                "   [1] Title of Guideline/Document\n"
                "   [2] Title of Another Guideline/Document\n"
                "5. **Disclaimer:** At the end of EVERY response, include the following disclaimer: '_This information is for clinical reference based on Sri Lankan guidelines and does not replace professional medical judgment._'"
            )
            return self._create_llm_response(system_prompt, context, query, confidence, sources, start_time, history)
        else:
            # If not using the LLM, return the retrieved context directly
            return MedicalResponse(
                answer=context,
                confidence=confidence,
                sources=sources,
                query_time=time.time() - start_time
            )

    def _create_llm_response(self, system_prompt: str, context: str, query: str, confidence: float, sources: List[str], start_time: float, history: Optional[List[Dict[str, str]]] = None) -> MedicalResponse:
        """Helper to generate a response from the LLM."""
        try:
            messages = [{"role": "system", "content": system_prompt}]

            # Add conversation history to the messages
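            # Each history entry is assumed to already be in the chat format used
            # elsewhere in this file: {"role": "user", "content": "..."} followed
            # by {"role": "assistant", "content": "..."}.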
            if history:
                messages.extend(history)

            # Add the current query, with the retrieved context prepended
            messages.append({"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"})

            chat_completion = self.groq_client.chat.completions.create(
                messages=messages,
                model=self.model_name,
                temperature=0.7,
                max_tokens=2048,
                top_p=1,
                stream=False
            )
            response_content = chat_completion.choices[0].message.content
            return MedicalResponse(
                answer=response_content,
                confidence=confidence,
                sources=sources,
                query_time=time.time() - start_time,
            )
        except Exception as e:
            self.logger.error(f"Error during Groq API call: {e}")
            return MedicalResponse(
                answer=f"Sorry, I encountered an error while generating the response: {e}",
                confidence=0,
                sources=sources,
                query_time=time.time() - start_time
            )

    def _calculate_confidence(self, scores: List[float], use_llm: bool) -> float:
        """
        Calculate confidence score based on re-ranked results.
        For LLM responses, we can be more optimistic.
        """
        if not scores:
            return 0.0

        # Simple average of the re-ranker scores
        avg_score = sum(scores) / len(scores)

        # Sigmoid scaling for better confidence representation
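        # Cross-encoder relevance scores are unbounded logits, so the sigmoid
        # 1 / (1 + e^(-x)) maps the average into (0, 1) for use as a confidence.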
        confidence = 1 / (1 + np.exp(-avg_score))
        if use_llm:
            return min(confidence * 1.2, 1.0)  # Boost confidence for LLM responses
        return confidence

    def _create_no_results_response(self, query: str) -> MedicalResponse:
        """Helper for the no-results response."""
        return MedicalResponse(
            answer="No relevant documents found for your query. Please try rephrasing your question.",
            confidence=0,
            sources=[],
            query_time=0
        )


def main(recreate_index: bool = False):
    """Main function to initialize and test the RAG system."""
    print("Initializing Groq Medical RAG system...")
    try:
        rag_system = GroqMedicalRAG()
        if recreate_index:
            print("Recreating index as requested...")
            # Re-initialize with force_recreate=True
            rag_system._initialize_system(force_recreate=True)
            print("✅ Index recreated successfully.")
            return  # Exit after recreating the index
        print("✅ System initialized successfully.")

        # Example query for testing
        print("\n--- Testing with an example query ---")
        query = "What is the management for puerperal sepsis?"
        print(f"Query: {query}")
        response = rag_system.query(query)
        print("\n--- Response ---")
        print(f"Answer: {response.answer}")
        print(f"Confidence: {response.confidence:.2f}")
        print(f"Sources: {response.sources}")
        print(f"Query Time: {response.query_time:.2f}s")
        print("--------------------\n")
    except Exception as e:
        print(f"An error occurred: {e}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Groq Medical RAG System CLI")
    parser.add_argument(
        "--recreate-index",
        action="store_true",
        help="If set, deletes the existing vector store and creates a new one."
    )
    args = parser.parse_args()
    main(recreate_index=args.recreate_index)